๐ Getting started#
๐ ๏ธ Installation#
pip install pytreeclass
Install development version
pip install git+https://github.com/ASEM000/pytreeclass
๐ Description#
pytreeclass
is a JAX-compatible class builder to create and operate on stateful JAX PyTrees.
The package aims to achieve two goals:
๐ To maintain safe and correct behaviour by using immutable modules with functional API.
To achieve the most intuitive user experience in the
JAX
ecosystem by :๐๏ธ Defining layers similar to
PyTorch
orTensorFlow
subclassing style.โ๏ธ Filtering:nbsphinx-math:Indexing `layer values similar to ``jax.numpy.at[].{get,set,apply,โฆ}`
๐จ Visualize defined layers in plethora of ways.
โฉ Quick Example#
[1]:
import jax
import jax.numpy as jnp
import pytreeclass as tc
@tc.autoinit
class Tree(tc.TreeClass):
a: int = 1
b: tuple[float] = (2.0, 3.0)
c: jax.Array = jnp.array([4.0, 5.0, 6.0])
def __call__(self, x):
return self.a + self.b[0] + self.c + x
tree = Tree()
tree
[1]:
Tree(a=1, b=(2.0, 3.0), c=f32[3](ฮผ=5.00, ฯ=0.82, โ[4.00,6.00]))
๐จ Visualize#
tree_summary
#
[2]:
tree = Tree()
print(tc.tree_summary(tree, depth=1))
โโโโโโฌโโโโโโโฌโโโโโโฌโโโโโโโ
โNameโType โCountโSize โ
โโโโโโผโโโโโโโผโโโโโโผโโโโโโโค
โ.a โint โ1 โ โ
โโโโโโผโโโโโโโผโโโโโโผโโโโโโโค
โ.b โtuple โ2 โ โ
โโโโโโผโโโโโโโผโโโโโโผโโโโโโโค
โ.c โf32[3]โ3 โ12.00Bโ
โโโโโโผโโโโโโโผโโโโโโผโโโโโโโค
โฮฃ โTree โ6 โ12.00Bโ
โโโโโโดโโโโโโโดโโโโโโดโโโโโโโ
[3]:
tree = Tree()
print(tc.tree_summary(tree, depth=2))
โโโโโโโฌโโโโโโโฌโโโโโโฌโโโโโโโ
โName โType โCountโSize โ
โโโโโโโผโโโโโโโผโโโโโโผโโโโโโโค
โ.a โint โ1 โ โ
โโโโโโโผโโโโโโโผโโโโโโผโโโโโโโค
โ.b[0]โfloat โ1 โ โ
โโโโโโโผโโโโโโโผโโโโโโผโโโโโโโค
โ.b[1]โfloat โ1 โ โ
โโโโโโโผโโโโโโโผโโโโโโผโโโโโโโค
โ.c โf32[3]โ3 โ12.00Bโ
โโโโโโโผโโโโโโโผโโโโโโผโโโโโโโค
โฮฃ โTree โ6 โ12.00Bโ
โโโโโโโดโโโโโโโดโโโโโโดโโโโโโโ
tree_diagram
#
[4]:
tree = Tree()
print(tc.tree_diagram(tree, depth=1))
Tree
โโโ .a=1
โโโ .b=(...)
โโโ .c=f32[3](ฮผ=5.00, ฯ=0.82, โ[4.00,6.00])
[5]:
tree = Tree()
print(tc.tree_diagram(tree, depth=2))
Tree
โโโ .a=1
โโโ .b:tuple
โ โโโ [0]=2.0
โ โโโ [1]=3.0
โโโ .c=f32[3](ฮผ=5.00, ฯ=0.82, โ[4.00,6.00])
tree_repr
#
[6]:
print(tc.tree_repr(tree, depth=1))
Tree(a=1, b=(...), c=f32[3](ฮผ=5.00, ฯ=0.82, โ[4.00,6.00]))
[7]:
print(tc.tree_repr(tree, depth=2))
Tree(a=1, b=(2.0, 3.0), c=f32[3](ฮผ=5.00, ฯ=0.82, โ[4.00,6.00]))
tree_str
#
[8]:
print(tc.tree_str(tree, depth=1))
Tree(a=1, b=(...), c=[4. 5. 6.])
[9]:
print(tc.tree_str(tree, depth=2))
Tree(a=1, b=(2.0, 3.0), c=[4. 5. 6.])
๐ Working with jax
transformation#
Parameters are defined in Tree
at the top of class definition similar to defining dataclasses.dataclass
field. Lets optimize our parameters
[10]:
@jax.grad
def loss_func(tree: Tree, x: jax.Array):
tree = tc.tree_unmask(tree) # <--- unmask the tree
preds = jax.vmap(tree)(x) # <--- vectorize the tree call over the leading axis
return jnp.mean(preds**2) # <--- return the mean squared error
@jax.jit
def train_step(tree: Tree, x: jax.Array):
grads = loss_func(tree, x)
# apply a small gradient step
return jax.tree_util.tree_map(lambda x, g: x - 1e-3 * g, tree, grads)
# lets mask the non-differentiable parts of the tree with a frozen mask
# in essence any non inexact type should be frozen to
# make the tree differentiable and work with jax transformations
tree = tc.tree_mask(tree)
for epoch in range(1_000):
tree = train_step(tree, jnp.ones([10, 1]))
print(tree)
# **the `frozen` params have "#" prefix**
# Tree(a=#1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778 3.4190805])
# unmask the frozen node (e.g. non-inexact) of the tree
tree = tc.tree_unmask(tree)
print(tree)
# Tree(a=1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778 3.4190805])
Tree(a=#1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778 3.4190805])
Tree(a=1, b=(-4.2826524, 3.0), c=[2.3924797 2.905778 3.4190805])
โ๏ธ Advanced Indexing with .at[]
#
Out-of-place updates using mask, attribute name or index
pytreeclass
offers 3 means of indexing through .at[]
Indexing by boolean mask.
Indexing by attribute name.
Indexing by Leaf index.
Since ``treeclass`` wrapped class are immutable, ``.at[]`` operations returns new instance of the tree
Index update by boolean mask#
[11]:
tree = Tree()
# Tree(a=1, b=(2, 3), c=i32[3](ฮผ=5.00, ฯ=0.82, โ[4,6]))
# lets create a mask for values > 4
mask = jax.tree_util.tree_map(lambda x: x > 4, tree)
print(mask)
# Tree(a=False, b=(False, False), c=[False True True])
print(tree.at[mask].get())
# Tree(a=None, b=(None, None), c=[5 6])
print(tree.at[mask].set(10))
# Tree(a=1, b=(2, 3), c=[ 4 10 10])
print(tree.at[mask].apply(lambda x: 10))
# Tree(a=1, b=(2, 3), c=[ 4 10 10])
Tree(a=False, b=(False, False), c=[False True True])
Tree(a=None, b=(None, None), c=[5. 6.])
Tree(a=1, b=(2.0, 3.0), c=[ 4. 10. 10.])
Tree(a=1, b=(2.0, 3.0), c=[ 4. 10. 10.])
Index update by attribute name#
[12]:
tree = Tree()
# Tree(a=1, b=(2, 3), c=i32[3](ฮผ=5.00, ฯ=0.82, โ[4,6]))
print(tree.at["a"].get())
# Tree(a=1, b=(None, None), c=None)
print(tree.at["a"].set(10))
# Tree(a=10, b=(2, 3), c=[4 5 6])
print(tree.at["a"].apply(lambda x: 10))
# Tree(a=10, b=(2, 3), c=[4 5 6])
Tree(a=1, b=(None, None), c=None)
Tree(a=10, b=(2.0, 3.0), c=[4. 5. 6.])
Tree(a=10, b=(2.0, 3.0), c=[4. 5. 6.])
Index update by integer index#
[13]:
tree = Tree()
# Tree(a=1, b=(2, 3), c=i32[3](ฮผ=5.00, ฯ=0.82, โ[4,6]))
print(tree.at[1][0].get())
# Tree(a=None, b=(2.0, None), c=None)
print(tree.at[1][0].set(10))
# Tree(a=1, b=(10, 3.0), c=[4. 5. 6.])
print(tree.at[1][0].apply(lambda x: 10))
# Tree(a=1, b=(10, 3.0), c=[4. 5. 6.])
Tree(a=None, b=(2.0, None), c=None)
Tree(a=1, b=(10, 3.0), c=[4. 5. 6.])
Tree(a=1, b=(10, 3.0), c=[4. 5. 6.])