Open In Colab

๐Ÿƒ Getting started#

๐Ÿ› ๏ธ Installation#

pip install pytreeclass

Install development version

pip install git+https://github.com/ASEM000/pytreeclass

๐Ÿ“– Description#

pytreeclass is a JAX-compatible class builder to create and operate on stateful JAX PyTrees.

The package aims to achieve two goals:

  1. ๐Ÿ”’ To maintain safe and correct behaviour by using immutable modules with functional API.

  2. To achieve the most intuitive user experience in the JAX ecosystem by :

    • ๐Ÿ—๏ธ Defining layers similar to PyTorch or TensorFlow subclassing style.

    • โ˜๏ธ Filtering:nbsphinx-math:Indexing `layer values similar to ``jax.numpy.at[].{get,set,apply,โ€ฆ}`

    • ๐ŸŽจ Visualize defined layers in plethora of ways.

โฉ Quick Example#

[1]:
import jax
import jax.numpy as jnp
import pytreeclass as tc


@tc.autoinit
class Tree(tc.TreeClass):
    a: int = 1
    b: tuple[float] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])

    def __call__(self, x):
        return self.a + self.b[0] + self.c + x


tree = Tree()
tree
[1]:
Tree(a=1, b=(2.0, 3.0), c=f32[3](ฮผ=5.00, ฯƒ=0.82, โˆˆ[4.00,6.00]))

๐ŸŽจ Visualize#

tree_summary#

[2]:
tree = Tree()
print(tc.tree_summary(tree, depth=1))
โ”Œโ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚Nameโ”‚Type  โ”‚Countโ”‚Size  โ”‚
โ”œโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚.a  โ”‚int   โ”‚1    โ”‚      โ”‚
โ”œโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚.b  โ”‚tuple โ”‚2    โ”‚      โ”‚
โ”œโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚.c  โ”‚f32[3]โ”‚3    โ”‚12.00Bโ”‚
โ”œโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚ฮฃ   โ”‚Tree  โ”‚6    โ”‚12.00Bโ”‚
โ””โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”˜
[3]:
tree = Tree()
print(tc.tree_summary(tree, depth=2))
โ”Œโ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚Name โ”‚Type  โ”‚Countโ”‚Size  โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚.a   โ”‚int   โ”‚1    โ”‚      โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚.b[0]โ”‚float โ”‚1    โ”‚      โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚.b[1]โ”‚float โ”‚1    โ”‚      โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚.c   โ”‚f32[3]โ”‚3    โ”‚12.00Bโ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚ฮฃ    โ”‚Tree  โ”‚6    โ”‚12.00Bโ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”˜

tree_diagram#

[4]:
tree = Tree()
print(tc.tree_diagram(tree, depth=1))
Tree
โ”œโ”€โ”€ .a=1
โ”œโ”€โ”€ .b=(...)
โ””โ”€โ”€ .c=f32[3](ฮผ=5.00, ฯƒ=0.82, โˆˆ[4.00,6.00])
[5]:
tree = Tree()
print(tc.tree_diagram(tree, depth=2))
Tree
โ”œโ”€โ”€ .a=1
โ”œโ”€โ”€ .b:tuple
โ”‚   โ”œโ”€โ”€ [0]=2.0
โ”‚   โ””โ”€โ”€ [1]=3.0
โ””โ”€โ”€ .c=f32[3](ฮผ=5.00, ฯƒ=0.82, โˆˆ[4.00,6.00])

tree_repr#

[6]:
print(tc.tree_repr(tree, depth=1))
Tree(a=1, b=(...), c=f32[3](ฮผ=5.00, ฯƒ=0.82, โˆˆ[4.00,6.00]))
[7]:
print(tc.tree_repr(tree, depth=2))
Tree(a=1, b=(2.0, 3.0), c=f32[3](ฮผ=5.00, ฯƒ=0.82, โˆˆ[4.00,6.00]))

tree_str#

[8]:
print(tc.tree_str(tree, depth=1))
Tree(a=1, b=(...), c=[4. 5. 6.])
[9]:
print(tc.tree_str(tree, depth=2))
Tree(a=1, b=(2.0, 3.0), c=[4. 5. 6.])

๐Ÿƒ Working with jax transformation#

Parameters are defined in Tree at the top of class definition similar to defining dataclasses.dataclass field. Lets optimize our parameters

[10]:
@jax.grad
def loss_func(tree: Tree, x: jax.Array):
    tree = tc.tree_unmask(tree)  # <--- unmask the tree
    preds = jax.vmap(tree)(x)  # <--- vectorize the tree call over the leading axis
    return jnp.mean(preds**2)  # <--- return the mean squared error


@jax.jit
def train_step(tree: Tree, x: jax.Array):
    grads = loss_func(tree, x)
    # apply a small gradient step
    return jax.tree_util.tree_map(lambda x, g: x - 1e-3 * g, tree, grads)


# lets mask the non-differentiable parts of the tree with a frozen mask
# in essence any non inexact type should be frozen to
# make the tree differentiable and work with jax transformations
tree = tc.tree_mask(tree)

for epoch in range(1_000):
    tree = train_step(tree, jnp.ones([10, 1]))

print(tree)
# **the `frozen` params have "#" prefix**
# Tree(a=#1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778  3.4190805])


# unmask the frozen node (e.g. non-inexact) of  the tree
tree = tc.tree_unmask(tree)
print(tree)
# Tree(a=1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778  3.4190805])
Tree(a=#1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778  3.4190805])
Tree(a=1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778  3.4190805])

โ˜๏ธ Advanced Indexing with .at[]#

Out-of-place updates using mask, attribute name or index

pytreeclass offers 3 means of indexing through .at[]

  1. Indexing by boolean mask.

  2. Indexing by attribute name.

  3. Indexing by Leaf index.

Since ``treeclass`` wrapped class are immutable, ``.at[]`` operations returns new instance of the tree

Index update by boolean mask#

[11]:
tree = Tree()
# Tree(a=1, b=(2, 3), c=i32[3](ฮผ=5.00, ฯƒ=0.82, โˆˆ[4,6]))

# lets create a mask for values > 4
mask = jax.tree_util.tree_map(lambda x: x > 4, tree)

print(mask)
# Tree(a=False, b=(False, False), c=[False  True  True])

print(tree.at[mask].get())
# Tree(a=None, b=(None, None), c=[5 6])

print(tree.at[mask].set(10))
# Tree(a=1, b=(2, 3), c=[ 4 10 10])

print(tree.at[mask].apply(lambda x: 10))
# Tree(a=1, b=(2, 3), c=[ 4 10 10])
Tree(a=False, b=(False, False), c=[False  True  True])
Tree(a=None, b=(None, None), c=[5. 6.])
Tree(a=1, b=(2.0, 3.0), c=[ 4. 10. 10.])
Tree(a=1, b=(2.0, 3.0), c=[ 4. 10. 10.])

Index update by attribute name#

[12]:
tree = Tree()
# Tree(a=1, b=(2, 3), c=i32[3](ฮผ=5.00, ฯƒ=0.82, โˆˆ[4,6]))

print(tree.at["a"].get())
# Tree(a=1, b=(None, None), c=None)

print(tree.at["a"].set(10))
# Tree(a=10, b=(2, 3), c=[4 5 6])

print(tree.at["a"].apply(lambda x: 10))
# Tree(a=10, b=(2, 3), c=[4 5 6])
Tree(a=1, b=(None, None), c=None)
Tree(a=10, b=(2.0, 3.0), c=[4. 5. 6.])
Tree(a=10, b=(2.0, 3.0), c=[4. 5. 6.])

Index update by integer index#

[13]:
tree = Tree()
# Tree(a=1, b=(2, 3), c=i32[3](ฮผ=5.00, ฯƒ=0.82, โˆˆ[4,6]))

print(tree.at[1][0].get())
# Tree(a=None, b=(2.0, None), c=None)

print(tree.at[1][0].set(10))
# Tree(a=1, b=(10, 3.0), c=[4. 5. 6.])

print(tree.at[1][0].apply(lambda x: 10))
# Tree(a=1, b=(10, 3.0), c=[4. 5. 6.])
Tree(a=None, b=(2.0, None), c=None)
Tree(a=1, b=(10, 3.0), c=[4. 5. 6.])
Tree(a=1, b=(10, 3.0), c=[4. 5. 6.])