
๐Ÿณ Common recipes#

This section introduces common recipes you might need while using pytreeclass to train/build models.

[1]:
!pip install pytreeclass --quiet

[1] Add a leaf to the instance after instantiation.#

The following recipe adds a method add_leaf that sets a leaf name and value. However, since this method mutates the internal state of the instance, .at['add_leaf'] is used to apply the method functionally and return the method's return value together with a new instance.

[2]:
import pytreeclass as tc


@tc.autoinit
class Tree(tc.TreeClass):
    a: float = 1.0
    b: float = 2.0
    c: float = 3.0

    def add_leaf(self, name: str, value):
        setattr(self, name, value)


tree = Tree()
# Tree(a=1.0, b=2.0, c=3.0)

_, tree_with_d = tree.at["add_leaf"]("d", 4.0)

tree_with_d
[2]:
Tree(a=1.0, b=2.0, c=3.0, d=4.0)
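
Since the method is applied functionally, the original tree is left untouched and only the returned copy carries the new leaf. A quick check (not part of the original notebook):

print(tree)         # Tree(a=1.0, b=2.0, c=3.0)  <- unchanged
print(tree_with_d)  # Tree(a=1.0, b=2.0, c=3.0, d=4.0)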

[2] Customize per-leaf optimizer updates using pytreeclass masks + Optax.#

In the following recipe, optax.masked is used to apply specific optimizers to specific leaves using masking.

[3]:
import optax
import pytreeclass as tc
import jax


@tc.autoinit
class Tree(tc.TreeClass):
    a: float = 1.0
    b: float = 2.0
    c: float = 3.0


tree = Tree()

false_mask = tree.at[...].set(False)

a_mask = false_mask.at["a"].set(True)
b_mask = false_mask.at["b"].set(True)
c_mask = false_mask.at["c"].set(True)

optim = optax.chain(
    # update `a` with sgd of learning rate 1
    optax.masked(optax.sgd(learning_rate=1), a_mask),
    # update `b` with sgd of learning rate -1
    optax.masked(optax.sgd(learning_rate=-1), b_mask),
    # update `c` with sgd of learning rate 0
    optax.masked(optax.sgd(learning_rate=0), c_mask),
)


# freeze non-differentiable parameters
# in this case all parameters are differentiable
# but we do it in case we add a non-differentiable parameter later
tree = tree.at[jax.tree_map(tc.is_nondiff, tree)].apply(tc.freeze)

optim_state = optim.init(tree)
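
With the masked chain in place, a gradient step follows the usual Optax pattern. Below is a minimal sketch (not part of the original notebook) with a toy scalar loss over the three leaves:

@jax.grad
def loss_func(tree):
    # toy loss; every leaf receives the same gradient here
    return (tree.a + tree.b + tree.c) ** 2

grads = loss_func(tree)
updates, optim_state = optim.update(grads, optim_state)
tree = optax.apply_updates(tree, updates)
# `a` moves against its gradient (lr=1), `b` moves with it (lr=-1), `c` stays put (lr=0)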

[3] Use numpy functions on a TreeClass instance.#

jax.numpy functions can be applied to a TreeClass instance by wrapping the numpy function with the function transformation bcmap. Decorating the class with @leafwise additionally enables per-leaf math operators; for example, tree + 1 adds 1 to all leaves.

[4]:
import pytreeclass as tc
import jax
import jax.numpy as jnp


@tc.leafwise
@tc.autoinit
class Tree(tc.TreeClass):
    a: int = 1
    b: tuple[float] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])


tree = Tree()

# make where work with arbitrary pytrees
tree_where = tc.bcmap(jnp.where)

print(tree_where(tree > 2, tree + 100, 0))
# Tree(a=0, b=(0.0, 103.0), c=[104. 105. 106.])

print(tree.at[tree > 1].apply(lambda x: x + 100))
# Tree(a=1, b=(102.0, 103.0), c=[104. 105. 106.])

mask = tree_where(tree > 1, True, False)
print(tree.at[mask].apply(lambda x: x + 100))
# Tree(a=1, b=(102.0, 103.0), c=[104. 105. 106.])
Tree(a=0, b=(0.0, 103.0), c=[104. 105. 106.])
Tree(a=1, b=(102.0, 103.0), c=[104. 105. 106.])
Tree(a=1, b=(102.0, 103.0), c=[104. 105. 106.])
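
bcmap works the same way with other jax.numpy functions. A minimal sketch (not part of the original notebook) that clips every leaf to the range [2, 5]:

tree_clip = tc.bcmap(jnp.clip)(tree, 2, 5)
print(tree_clip)
# every leaf is clipped to [2, 5]; the scalar bounds are broadcast to all leaves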

[4] Use visualization tools with arbitrary pytrees#

[5]:
import jax
import pytreeclass as tc

tree = [1, [2, 3], 4]

print(tc.tree_diagram(tree, depth=1))
print(tc.tree_diagram(tree, depth=2))
print(tc.tree_summary(tree, depth=1))
print(tc.tree_summary(tree, depth=2))
list
├── [0]=1
├── [1]=[...]
└── [2]=4
list
├── [0]=1
├── [1]:list
│   ├── [0]=2
│   └── [1]=3
└── [2]=4
┌────┬────┬─────┬────┐
│Name│Type│Count│Size│
├────┼────┼─────┼────┤
│[0] │int │1    │    │
├────┼────┼─────┼────┤
│[1] │list│2    │    │
├────┼────┼─────┼────┤
│[2] │int │1    │    │
├────┼────┼─────┼────┤
│Σ   │list│4    │    │
└────┴────┴─────┴────┘
┌──────┬────┬─────┬────┐
│Name  │Type│Count│Size│
├──────┼────┼─────┼────┤
│[0]   │int │1    │    │
├──────┼────┼─────┼────┤
│[1][0]│int │1    │    │
├──────┼────┼─────┼────┤
│[1][1]│int │1    │    │
├──────┼────┼─────┼────┤
│[2]   │int │1    │    │
├──────┼────┼─────┼────┤
│Σ     │list│4    │    │
└──────┴────┴─────┴────┘
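
The same arbitrary pytree also works with the text renderers tree_repr and tree_str (a minimal sketch, not part of the original notebook):

print(tc.tree_repr(tree))
print(tc.tree_str(tree))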

[5] Using on_setattr to validate/convert inputs#

Type and number range check

[6]:
import jax
import pytreeclass as tc


# you can use any function
@tc.autoinit
class Range(tc.TreeClass):
    min: int | float = -float("inf")
    max: int | float = float("inf")

    def __call__(self, x):
        if not (self.min <= x <= self.max):
            raise ValueError(f"{x} not in range [{self.min}, {self.max}]")
        return x


@tc.autoinit
class IsInstance(tc.TreeClass):
    klass: type | tuple[type, ...]

    def __call__(self, x):
        if not isinstance(x, self.klass):
            raise TypeError(f"{x} not an instance of {self.klass}")
        return x


@tc.autoinit
class Foo(tc.TreeClass):
    # allow in_dim to be an integer between [1,100]
    in_dim: int = tc.field(on_setattr=[IsInstance(int), Range(1, 100)])


tree = Foo(1)
# no error

try:
    tree = Foo(0)
except ValueError as e:
    print(e)

try:
    tree = Foo(1.0)
except TypeError as e:
    print(e)
On applying Range(min=1, max=100) for field=`in_dim`:
0 not in range [1, 100]
On applying IsInstance(klass=<class 'int'>) for field=`in_dim`:
1.0 not an instance of <class 'int'>

Array shape and dtype check, then dtype conversion

[7]:
import pytreeclass as tc
from typing import Any
import jax
import jax.numpy as jnp


class ArrayValidator(tc.TreeClass):
    def __init__(self, shape, dtype):
        """Validate shape and dtype of input array.

        Args:
            shape: Expected shape of array. available values are int, None, ...
                use int for fixed size, None for any size, and ... for any number
                of dimensions. for example (..., 1) allows any number of dimensions
                with the last dimension being 1. (1, ..., 1) allows any number of
                dimensions with the first and last dimensions being 1.
            dtype: Expected dtype of array.

        Example:
            >>> x = jnp.ones((5, 5))
            >>> # any number of dimensions with last dim=5
            >>> shape = (..., 5)
            >>> dtype = jnp.float32
            >>> validator = ArrayValidator(shape, dtype)
            >>> validator(x)  # no error

            >>> # must be 2 dimensions with first dim unconstrained and last dim=5
            >>> shape = (None, 5)
            >>> validator = ArrayValidator(shape, dtype)
            >>> validator(x)  # no error
        """

        if shape.count(...) > 1:
            raise ValueError("Only one ellipsis allowed")

        for si in shape:
            if not isinstance(si, (int, type(...), type(None))):
                raise TypeError(f"Expected int or ..., got {si}")

        self.shape = shape
        self.dtype = dtype

    def __call__(self, x):
        if not (hasattr(x, "shape") and hasattr(x, "dtype")):
            raise TypeError(f"Expected array with shape {self.shape}, got {x}")

        shape = list(self.shape)
        array_shape = list(x.shape)
        array_dtype = x.dtype

        if self.shape and array_dtype != self.dtype:
            raise TypeError(f"Dtype mismatch, {array_dtype=} != {self.dtype=}")

        if ... in shape:
            index = shape.index(...)
            shape = (
                shape[:index]
                + [None] * (len(array_shape) - len(shape) + 1)
                + shape[index + 1 :]
            )

        if len(shape) != len(array_shape):
            raise ValueError(f"{len(shape)=} != {len(array_shape)=}")

        for i, (li, ri) in enumerate(zip(shape, array_shape)):
            if li is None:
                continue
            if li != ri:
                raise ValueError(f"Size mismatch, {li} != {ri} at dimension {i}")
        return x


# any number of dimensions with first dim=3 and last dim=6
shape = (3, ..., 6)
# dtype must be float32
dtype = jnp.float32

validator = ArrayValidator(shape=shape, dtype=dtype)

# convert to half precision from float32
converter = lambda x: x.astype(jnp.float16)


@tc.autoinit
class Tree(tc.TreeClass):
    array: jax.Array = tc.field(on_setattr=[validator, converter])


x = jnp.ones([3, 1, 2, 6])
tree = Tree(array=x)


try:
    y = jnp.ones([1, 1, 2, 3])
    tree = Tree(array=y)
except ValueError as e:
    print(e, "\n")
    # On applying ArrayValidator(shape=(3, Ellipsis, 6), dtype=<class 'jax.numpy.float32'>) for field=`array`:
    # Size mismatch, 3 != 1 at dimension 0

try:
    z = x.astype(jnp.float16)
    tree = Tree(array=z)
except TypeError as e:
    print(e)
    # On applying ArrayValidator(shape=(3, Ellipsis, 6), dtype=<class 'jax.numpy.float32'>) for field=`array`:
    # Dtype mismatch, array_dtype=dtype('float16') != self.dtype=<class 'jax.numpy.float32'>
On applying ArrayValidator(shape=(3, Ellipsis, 6), dtype=<class 'jax.numpy.float32'>) for field=`array`:
Size mismatch, 3 != 1 at dimension 0

On applying ArrayValidator(shape=(3, Ellipsis, 6), dtype=<class 'jax.numpy.float32'>) for field=`array`:
Dtype mismatch, array_dtype=dtype('float16') != self.dtype=<class 'jax.numpy.float32'>
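
Note that the converter runs after the validator on each assignment, so the stored array ends up in half precision even though validation required a float32 input. A quick check (not part of the original notebook):

print(tree.array.dtype)
# float16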

[6] Freeze custom parameters using .at manually/with mask#

Some classes, such as Dropout, can contain leaves that are differentiable but that we do not wish to update. In the Dropout example, drop_rate is a float that should not be updated by the optimizer. The following recipe shows how to deal with such values.

[8]:
import pytreeclass as tc
import jax
import jax.numpy as jnp


@tc.autoinit
class Dropout(tc.TreeClass):
    drop_rate: float = 0.0  # dropout rate, 0 means no dropout

    def __call__(self, x, *, key):
        keep_rate = 1.0 - self.drop_rate
        mask = jax.random.bernoulli(key, keep_rate, x.shape)
        return jnp.where(mask, x / keep_rate, 0.0)


x = jnp.arange(10)
dropout = Dropout(drop_rate=0.5)
dropout(x, key=jax.random.PRNGKey(0))


@jax.grad
def f(layer: Dropout, x: jax.Array):
    return layer(x, key=jax.random.PRNGKey(0)).sum()


print(f(dropout, x))
# Dropout(drop_rate=108.0)  # <--- this is the gradient which is undesired


# let's fix this by freezing the dropout rate
@tc.autoinit
class Dropout(tc.TreeClass):
    drop_rate: float = tc.field(on_setattr=[tc.freeze], default=0.0)

    def __call__(self, x, *, key):
        keep_rate = 1.0 - self.drop_rate
        mask = jax.random.bernoulli(key, keep_rate, x.shape)
        return jnp.where(mask, x / keep_rate, 0.0)


x = jnp.arange(10)
dropout = Dropout(drop_rate=0.5)

dropout
# Dropout(drop_rate=#0.5)  # -> drop_rate is frozen; to call the dropout layer we need to unfreeze it first


@jax.grad
def f(layer: Dropout, x: jax.Array):
    layer = jax.tree_map(tc.unfreeze, layer, is_leaf=tc.is_frozen)
    return layer(x, key=jax.random.PRNGKey(0)).sum()


f(dropout, x)
# Dropout(drop_rate=#0.5)  # <- dropout rate is not updated, can be used safely with optax


# lets say, for evaluation we want to set the dropout rate to 0.0
# then we can do the following

disable_dropout = dropout.at["drop_rate"].set(0.0, is_leaf=tc.is_frozen)
print(disable_dropout)
# Dropout(drop_rate=0.0)  # now the dropout rate is 0. and unfrozen.
# this layer is now safe to use for evaluation without special handling (like eval in pytorch)
Dropout(drop_rate=108.0)
Dropout(drop_rate=0.0)

[7] Use pytreeclass with Flax/Equinox#

The following recipe adds at support for Flax and Equinox. Note that for Equinox, eqx.Module is used instead of struct.PyTreeNode.

[9]:
import jax
import pytreeclass as tc
from flax import struct

# note that flax is registered with `jax.tree_util.register_pytree_with_keys`;
# for arbitrary objects you would need to do the key registration yourself


class FlaxTree(struct.PyTreeNode):
    a: int = 1
    b: tuple[float] = (2.0, 3.0)
    c: jax.Array = jax.numpy.array([4.0, 5.0, 6.0])

    def __repr__(self) -> str:
        return tc.tree_repr(self)

    def __str__(self) -> str:
        return tc.tree_str(self)

    @property
    def at(self):
        return tc.AtIndexer(self)


flax_tree = FlaxTree()

print(f"{flax_tree!r}")
print(f"{flax_tree!s}")
print(tc.tree_diagram(flax_tree))
print(tc.tree_summary(flax_tree))

flax_tree.at["a"].set(10)
# FlaxTree(a=10, b=(2.0, 3.0), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))
FlaxTree(a=1, b=(2.0, 3.0), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))
FlaxTree(a=1, b=(2.0, 3.0), c=[4. 5. 6.])
FlaxTree
├── .a=1
├── .b:tuple
│   ├── [0]=2.0
│   └── [1]=3.0
└── .c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00])
┌─────┬────────┬─────┬──────┐
│Name │Type    │Count│Size  │
├─────┼────────┼─────┼──────┤
│.a   │int     │1    │      │
├─────┼────────┼─────┼──────┤
│.b[0]│float   │1    │      │
├─────┼────────┼─────┼──────┤
│.b[1]│float   │1    │      │
├─────┼────────┼─────┼──────┤
│.c   │f32[3]  │3    │12.00B│
├─────┼────────┼─────┼──────┤
│Σ    │FlaxTree│6    │12.00B│
└─────┴────────┴─────┴──────┘
[9]:
FlaxTree(a=10, b=(2.0, 3.0), c=f32[3](μ=5.00, σ=0.82, ∈[4.00,6.00]))

[8] named_parameters()-like iteration in pytreeclass#

[10]:
import pytreeclass as tc
import jax


@tc.autoinit
class Tree(tc.TreeClass):
    a: int = 1
    b: tuple[float, float] = (2.0, 3.0)


tree = Tree()

for path, leaf in jax.tree_util.tree_flatten_with_path(tree)[0]:
    print(path, leaf)
(NamedSequenceKey(idx=0, name='a'),) 1
(NamedSequenceKey(idx=1, name='b'), SequenceKey(idx=0)) 2.0
(NamedSequenceKey(idx=1, name='b'), SequenceKey(idx=1)) 3.0
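
To get flat, string-like names similar to PyTorch's named_parameters(), each key path can be joined into a single string. A minimal sketch (not part of the original notebook), assuming jax.tree_util.keystr is available in the installed jax version:

for path, leaf in jax.tree_util.tree_flatten_with_path(tree)[0]:
    print(jax.tree_util.keystr(path), leaf)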

[9] Initialize parameters based on input#

In this example, a Linear layer with a weight parameter based on the shape of the input will be created. Since this requires parameter creation (i.e., weight) after instance initialization, we will use .at to create a new instance with the added parameter.

[11]:
import pytreeclass as tc
from typing import Any
import jax
import jax.numpy as jnp
import jax.random as jr


@tc.autoinit
class LazyLinear(tc.TreeClass):
    out_features: int

    def param(self, name: str, value: Any):
        # return the value if it exists, otherwise set it and return it
        if name not in vars(self):
            setattr(self, name, value)
        return vars(self)[name]

    def __call__(self, x: jax.Array, *, key: jr.KeyArray = jr.PRNGKey(0)):
        weight = self.param("weight", jnp.ones((x.shape[-1], self.out_features)))
        bias = self.param("bias", jnp.zeros((self.out_features,)))
        return x @ weight + bias


x = jnp.ones([10, 1])

lazy_linear = LazyLinear(out_features=1)

lazy_linear
print(f"Layer before param is set:\t{lazy_linear}")


# first call will set the parameters
_, linear = lazy_linear.at["__call__"](x, key=jr.PRNGKey(0))

print(f"Layer after param is set:\t{linear}")
# subsequent calls will use the same parameters and not set them again
linear(x)
Layer before param is set:      LazyLinear(out_features=1)
Layer after param is set:       LazyLinear(out_features=1, weight=[[1.]], bias=[0.])
[11]:
Array([[1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.]], dtype=float32)

[10] Store intermediate values#

This example shows how to capture specific intermediate values within each function call.

Using state threading#

[12]:
from typing import Any
import pytreeclass as tc
import jax
import optax
import jax.numpy as jnp


@tc.autoinit
class Tree(tc.TreeClass):
    a: float = 1.0

    def __call__(self, x: jax.Array, intermediate: tuple[Any, ...]):
        x = x + self.a
        # store intermediate variables
        return x, intermediate + (x,)


def loss_func(tree: Tree, x: jax.Array, y: jax.Array, intermediate: tuple[Any, ...]):
    ypred, intermediate = tree(x, intermediate)
    loss = jnp.mean((ypred - y) ** 2)
    return loss, intermediate


@jax.jit
def train_step(
    tree: Tree,
    optim_state: optax.OptState,
    x: jax.Array,
    y: jax.Array,
    intermediate: tuple[Any, ...],
):
    grads, intermediate = jax.grad(loss_func, has_aux=True)(tree, x, y, intermediate)
    updates, optim_state = optim.update(grads, optim_state)
    tree = optax.apply_updates(tree, updates)
    return tree, optim_state, intermediate


tree = Tree()
optim = optax.adam(1e-1)
optim_state = optim.init(tree)

x = jnp.linspace(-1, 1, 5)[:, None]
y = x**2

intermediate = ()

for i in range(2):
    tree, optim_state, intermediate = train_step(tree, optim_state, x, y, intermediate)


print("Intermediate values:\t\n", intermediate)
print("\nFinal tree:\t\n", tree)
Intermediate values:
 (Array([[0. ],
       [0.5],
       [1. ],
       [1.5],
       [2. ]], dtype=float32), Array([[-0.09999937],
       [ 0.40000063],
       [ 0.90000063],
       [ 1.4000006 ],
       [ 1.9000006 ]], dtype=float32))

Final tree:
 Tree(a=0.801189)

Using oryx#

[13]:
from typing import Any
import pytreeclass as tc
import jax
import optax
import jax.numpy as jnp
import oryx


@tc.autoinit
class Tree(tc.TreeClass):
    a: float = 1.0

    def __call__(self, x: jax.Array):
        x = x + self.a
        # store intermediate variables with oryx
        x = oryx.core.sow(x, tag="intermediates", name="x")
        return x


def loss_func(tree: Tree, x: jax.Array, y: jax.Array):
    ypred = tree(x)
    loss = jnp.mean((ypred - y) ** 2)
    return loss


@jax.jit
def train_step(
    tree: Tree,
    optim_state: optax.OptState,
    x: jax.Array,
    y: jax.Array,
):
    grads = jax.grad(loss_func)(tree, x, y)
    updates, optim_state = optim.update(grads, optim_state)
    tree = optax.apply_updates(tree, updates)
    return tree, optim_state


tree = Tree()
optim = optax.adam(1e-1)
optim_state = optim.init(tree)

x = jnp.linspace(-1, 1, 5)[:, None]
y = x**2

intermediate = ()

train_step_reap = oryx.core.reap(train_step, tag="intermediates")

for i in range(2):
    intermediate += (train_step_reap(tree, optim_state, x, y),)
    tree, optim_state = train_step(tree, optim_state, x, y)


print("Intermediate values:\t\n", intermediate)
print("\nFinal tree:\t\n", tree)
Intermediate values:
 ({'x': Array([[0. ],
       [0.5],
       [1. ],
       [1.5],
       [2. ]], dtype=float32)}, {'x': Array([[-0.09999937],
       [ 0.40000063],
       [ 0.90000063],
       [ 1.4000006 ],
       [ 1.9000006 ]], dtype=float32)})

Final tree:
 Tree(a=0.801189)

[11] Create layers from configuration files#

The next example shows how to use pytreeclass.bcmap to loop over a configuration dictionary that defines the creation of simple linear layers.

[14]:
import pytreeclass as tc
import jax
import jax.numpy as jnp


class Linear(tc.TreeClass):
    def __init__(self, in_dim: int, out_dim: int, *, key: jax.random.KeyArray):
        self.weight = jax.random.normal(key, (in_dim, out_dim))
        self.bias = jnp.zeros((out_dim,))

    def __call__(self, x: jax.Array) -> jax.Array:
        return x @ self.weight + self.bias


config = {
    # each layer gets a different input dimension
    "in_dim": [1, 2, 3, 4],
    # out_dim is broadcasted to all layers
    "out_dim": 1,
    # each layer gets a different key
    "key": list(jax.random.split(jax.random.PRNGKey(0), 4)),
}


# `bcmap` transforms a function that takes leaf inputs into a function that
# accepts arbitrary pytree inputs. any leaf (non-pytree) argument is broadcast
# to match the tree structure of the first argument
# (in our example, a list of 4 inputs)


@tc.bcmap
def build_layer(in_dim, out_dim, *, key: jax.random.KeyArray):
    return Linear(in_dim, out_dim, key=key)


build_layer(config["in_dim"], config["out_dim"], key=config["key"])
[14]:
[Linear(
   weight=f32[1,1](μ=0.31, σ=0.00, ∈[0.31,0.31]),
   bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
 ),
 Linear(
   weight=f32[2,1](μ=-1.27, σ=0.33, ∈[-1.59,-0.94]),
   bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
 ),
 Linear(
   weight=f32[3,1](μ=0.24, σ=0.53, ∈[-0.48,0.77]),
   bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
 ),
 Linear(
   weight=f32[4,1](μ=-0.28, σ=0.21, ∈[-0.64,-0.08]),
   bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
 )]
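
The returned layers can then be used like any other list of modules. A minimal sketch (not part of the original notebook), feeding each layer an input matching its in_dim:

layers = build_layer(config["in_dim"], config["out_dim"], key=config["key"])
ys = [layer(jnp.ones([5, d])) for layer, d in zip(layers, config["in_dim"])]
print([y.shape for y in ys])
# [(5, 1), (5, 1), (5, 1), (5, 1)]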

[12] Model ensembles using jax.vmap#

In this example, simple models are stacked along a leading axis of their weights using jax.vmap. This is useful when different instances of the model should run in a vectorized fashion (a model ensemble).

For more details on model ensembling with vmap, see the JAX documentation.

[15]:
import jax
import jax.numpy as jnp
import jax.random as jr
import pytreeclass as tc
import functools as ft
from typing import Generic, TypeVar

T = TypeVar("T")


class Batched(Generic[T]):
    ...


class Linear(tc.TreeClass):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        *,
        key: jr.KeyArray,
        name: str,
    ):
        self.weight = jr.normal(key, (in_dim, out_dim))
        self.bias = jnp.zeros((out_dim,))
        self.name = name  # non-jax type for `tree_mask`/`tree_unmask` demonstration

    def __call__(self, x: jax.Array) -> jax.Array:
        return x @ self.weight + self.bias


class FNN(tc.TreeClass):
    def __init__(self, key: jr.KeyArray):
        k1, k2, k3 = jr.split(key, 3)
        self.l1 = Linear(1, 10, key=k1, name="l1")
        self.l2 = Linear(10, 10, key=k2, name="l2")
        self.l3 = Linear(10, 1, key=k3, name="l3")

    def __call__(self, x: jax.Array) -> jax.Array:
        x = self.l1(x)
        x = jax.nn.relu(x)
        x = self.l2(x)
        x = jax.nn.relu(x)
        x = self.l3(x)
        return x


def build_ensemble(keys: jr.KeyArray) -> Batched[FNN]:
    @jax.vmap
    def build_fnn(key: jr.KeyArray):
        # `jax.vmap` requires jax-type returns,
        # so use `tree_mask` on the return value
        return tc.tree_mask(FNN(key=key))

    return tc.tree_unmask(build_fnn(keys))


def run_single_input_ensemble(fnns: Batched[FNN], x: jax.Array):
    def run_fnn(fnn: FNN):
        # `jax.vmap` requires jax-type returns,
        # so use `tree_mask` on the return value
        return tc.tree_mask(fnn(x))

    return jax.vmap(run_fnn)(tc.tree_mask(fnns))


def run_multi_input_ensemble(fnns: Batched[FNN], x: Batched[jax.Array]):
    def run_fnn(fnn: FNN, x: jax.Array):
        # `jax.vmap` requires jax-type returns,
        # so use `tree_mask` on the return value
        return tc.tree_mask(fnn(x))

    return jax.vmap(run_fnn)(tc.tree_mask(fnns), x)


num_layers = 4
keys = jr.split(jr.PRNGKey(0), num_layers)

# single input ensemble
# e.g. each model in the ensemble gets the same input
x = jnp.ones([10, 1])
fnns = build_ensemble(keys=keys)
y = run_single_input_ensemble(fnns, x)
print(f"Single input ensemble shape:\t{y.shape}")

# multi input ensemble
# e.g. each model in the ensemble gets a different input
xs = jnp.stack([x, x * 2, x * 3, x * 4])
fnns = build_ensemble(keys=keys)
ys = run_multi_input_ensemble(fnns, xs)
print(f"Multi input ensemble shape:\t{ys.shape}")
Single input ensemble shape:    (4, 10, 1)
Multi input ensemble shape:     (4, 10, 1)

[13] Functional method chaining.#

If a class has a method that mutates its internal state, then .at[method_name](*args, **kwargs) returns a tuple of the method's return value and a new instance. This example shows how to leverage .at to enable method chaining.

The objective is to achieve the following pattern in a functional way.

instance = instance.method1(...).method2(...).method3(...)
[16]:
import pytreeclass as tc


@tc.autoinit
class Tree(tc.TreeClass):
    a: int

    def _add(self, x):
        self.a += x

    def add(self, x):
        # use `.at` to return the new instance
        # and avoid mutating the original instance
        _, self = self.at["_add"](x)
        # return the new instance and discard the return value
        return self

    def _mul(self, x):
        self.a *= x

    def mul(self, x):
        # use `.at` to return the new instance
        # and avoid mutating the original instance
        _, self = self.at["_mul"](x)
        # return the new instance and discard the return value
        return self


tree0 = Tree(a=1)
tree1 = tree0.add(1).mul(2).add(1)  # ((1 + 1) * 2) + 1 = 5

print("original instance:\t", tree0)
print("new instance:\t\t", tree1)
original instance:       Tree(a=1)
new instance:            Tree(a=5)

[14] Using regular expression masking.#

In this example, the positive values of tree leaves whose names start with weight_ will be manipulated.

[17]:
import pytreeclass as tc
import jax
import jax.numpy as jnp
import re


@tc.autoinit
class Tree(tc.TreeClass):
    weight_1: jax.Array = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    weight_2: jax.Array = jnp.array([-1, -2, -3, -4, -5, 6, 7, 8, 9, 10])
    bias: jax.Array = jnp.ones(10)


tree = Tree()

positive_mask = jax.tree_map(lambda x: x > 0, tree)  # positive mask
tree = tree.at[positive_mask][re.compile(r"weight_.*")].apply(lambda x: x**2)
print(tree)
Tree(
  weight_1=[  1   4   9  16  25  36  49  64  81 100],
  weight_2=[ -1  -2  -3  -4  -5  36  49  64  81 100],
  bias=[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
)
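
The regex key can also be used on its own, e.g. to select only the weight leaves without a boolean mask (a minimal sketch, not part of the original notebook; unselected leaves follow pytreeclass's get semantics):

weights_only = tree.at[re.compile(r"weight_.*")].get()
print(weights_only)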

[15] Creating buffers#

In this example, a certain array is marked as non-trainable using jax.lax.stop_gradient and field.

[18]:
import pytreeclass as tc
import jax
import jax.numpy as jnp


@tc.autoinit
class Tree(tc.TreeClass):
    buffer: jax.Array = tc.field(on_getattr=[jax.lax.stop_gradient])

    def __call__(self, x):
        return self.buffer**x


tree = Tree(buffer=jnp.array([1.0, 2.0, 3.0]))
tree(2.0)  # Array([1., 4., 9.], dtype=float32)


@jax.jit
def f(tree, x):
    return jnp.sum(tree(x))


print(f(tree, 1.0))
print(jax.grad(f)(tree, 1.0))
6.0
Tree(buffer=[0. 0. 0.])

[16] Creating frozen fields#

In this example, field value freezing is done at the class level using on_getattr and on_setattr. This effectively hides the field value across jax transformations.

[19]:
import pytreeclass as tc
import jax


@tc.autoinit
class Tree(tc.TreeClass):
    frozen_a: int = tc.field(on_getattr=[tc.unfreeze], on_setattr=[tc.freeze])

    def __call__(self, x):
        return self.frozen_a + x


tree = Tree(frozen_a=1)  # 1 is non-jaxtype
# can be used in jax transformations


@jax.jit
def f(tree, x):
    return tree(x)


print(f(tree, 1.0))
print(jax.grad(f)(tree, 1.0))
2.0
Tree(frozen_a=#1)
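
Because the value is wrapped by tc.freeze on every setattr, it is hidden from jax and does not appear as a leaf. A quick check (a sketch, not part of the original notebook):

print(jax.tree_util.tree_leaves(tree))
# [] -> frozen_a is invisible to jax transformations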

[17] Parameterization#

In this example, the field value is parameterized using on_getattr.

[20]:
import pytreeclass as tc
import jax
import jax.numpy as jnp


def symmetric(array: jax.Array) -> jax.Array:
    triangle = jnp.triu(array)  # upper triangle
    return triangle + triangle.transpose(-1, -2)


@tc.autoinit
class Tree(tc.TreeClass):
    symmetric_matrix: jax.Array = tc.field(on_getattr=[symmetric])


tree = Tree(symmetric_matrix=jnp.arange(9).reshape(3, 3))
print(tree.symmetric_matrix)
[[ 0  1  2]
 [ 1  8  5]
 [ 2  5 16]]

[18] Working on data pipelines#

In this example, AtIndexer is used in a similar fashion to PyFunctional to work on general data pipelines.

[21]:
from pytreeclass import AtIndexer
import jax


class Transaction:
    def __init__(self, reason, amount):
        self.reason = reason
        self.amount = amount


# this example is adapted from https://github.com/EntilZha/PyFunctional
transactions = [
    Transaction("github", 7),
    Transaction("food", 10),
    Transaction("coffee", 5),
    Transaction("digitalocean", 5),
    Transaction("food", 5),
    Transaction("riotgames", 25),
    Transaction("food", 10),
    Transaction("amazon", 200),
    Transaction("paycheck", -1000),
]

indexer = AtIndexer(transactions)
where = jax.tree_map(lambda x: x.reason == "food", transactions)
food_cost = indexer[where].reduce(lambda x, y: x + y.amount, initializer=0)
food_cost
[21]:
25
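
The same pattern composes with any boolean mask over the objects. As a final sketch (not part of the original notebook), totaling every positive-amount transaction:

where = jax.tree_map(lambda x: x.amount > 0, transactions)
total_spent = indexer[where].reduce(lambda x, y: x + y.amount, initializer=0)
total_spent
# 267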