🎯 Core API
- class pytreeclass.TreeClass(*a, **k)
Convert a class to a jax-compatible pytree by inheriting from TreeClass.
A pytree is any nested structure that can be used with jax functions. A pytree can be a container or a leaf. Container examples are a tuple, list, or dict. A leaf is a non-container data structure like an int, float, string, or jax.Array. TreeClass is a container pytree that holds other pytrees in its attributes.
Note
pytreeclass offers two methods to define the __init__ method:
Manual __init__ method:
>>> import pytreeclass as tc
>>> class Tree(tc.TreeClass):
...     def __init__(self, a: int, b: float):
...         self.a = a
...         self.b = b
>>> tree = Tree(a=1, b=2.0)
Auto-generated __init__ method:
The __init__ method can be generated either by dataclasses.dataclass or by the autoinit() decorator; in both cases the type annotations are used to generate the __init__ method. Compared to dataclasses.dataclass, autoinit with field() objects can be used to apply functions on the field values during initialization, and/or support multiple argument kinds. For more details see autoinit() and field().
>>> import pytreeclass as tc
>>> @tc.autoinit
... class Tree(tc.TreeClass):
...     a: int
...     b: float
>>> tree = Tree(a=1, b=2.0)
Note
Leaf-wise math operations are supported using the leafwise decorator. The leafwise decorator applies math operations to each leaf of the tree. For example:
>>> @tc.leafwise
... @tc.autoinit
... class Tree(tc.TreeClass):
...     a: int = 1
...     b: float = 2.0
>>> tree = Tree()
>>> tree + 1
Tree(a=2, b=3.0)
Note
Advanced indexing is supported using the at property. Indexing can be used to get, set, or apply a function to a leaf or a group of leaves using the leaf name, an index, or a boolean mask.
>>> @tc.autoinit
... class Tree(tc.TreeClass):
...     a: int = 1
...     b: float = 2.0
>>> tree = Tree()
>>> tree.at["a"].get()
Tree(a=1, b=None)
>>> tree.at[0].get()
Tree(a=1, b=None)
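The get semantics above (selected leaves kept, all others replaced by None) can be mimicked on a plain dict. The following is a minimal stdlib-only sketch; `at_get` is a hypothetical helper, not part of pytreeclass, and it only handles a leaf name, an integer position, or a boolean mask given as a dict with the same keys.

```python
from typing import Any, Dict, Union

# hypothetical selector types: leaf name, position, or boolean mask
Selector = Union[str, int, Dict[str, bool]]


def at_get(tree: Dict[str, Any], where: Selector) -> Dict[str, Any]:
    """Return a copy of `tree` keeping only the selected leaves; others become None."""
    names = list(tree)
    if isinstance(where, str):
        selected = {where}  # select by leaf name
    elif isinstance(where, int):
        selected = {names[where]}  # select by position in attribute order
    else:
        selected = {name for name in names if where[name]}  # boolean mask
    return {name: (value if name in selected else None) for name, value in tree.items()}


tree = {"a": 1, "b": 2.0}
print(at_get(tree, "a"))                      # {'a': 1, 'b': None}
print(at_get(tree, 0))                        # {'a': 1, 'b': None}
print(at_get(tree, {"a": False, "b": True}))  # {'a': None, 'b': 2.0}
```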
Note
Under jax.tree_util.* functions, all TreeClass attributes are treated as leaves.
To hide or ignore a specific attribute from the tree leaves during jax.tree_util.* operations, freeze the leaf using freeze() or tree_mask().
>>> # freeze (exclude) a leaf from the tree leaves:
>>> import jax
>>> import pytreeclass as tc
>>> @tc.autoinit
... class Tree(tc.TreeClass):
...     a: int = 1
...     b: float = 2.0
>>> tree = Tree()
>>> tree = tree.at["a"].apply(tc.freeze)
>>> jax.tree_util.tree_leaves(tree)
[2.0]
>>> # undo the freeze
>>> tree = tree.at["a"].apply(tc.unfreeze, is_leaf=tc.is_frozen)
>>> jax.tree_util.tree_leaves(tree)
[1, 2.0]
>>> # use `tree_mask` to exclude a leaf from the tree leaves
>>> freeze_mask = Tree(a=True, b=False)
>>> jax.tree_util.tree_leaves(tc.tree_mask(tree, freeze_mask))
[2.0]
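Conceptually, freezing or masking a leaf keeps it in the structure but removes it from the flattened leaves. A minimal stdlib-only sketch of that idea follows; `Frozen`, `tree_leaves`, and `mask_leaves` are hypothetical stand-ins, not the jax or pytreeclass implementations, and dict keys are visited in insertion order here (jax sorts them).

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass(frozen=True)
class Frozen:
    """Hypothetical wrapper: a wrapped value is skipped when collecting leaves."""
    value: Any


def tree_leaves(tree: Any) -> List[Any]:
    """Collect leaves of nested tuples/lists/dicts, skipping Frozen wrappers."""
    if isinstance(tree, Frozen):
        return []  # frozen: stays in the structure, hidden from the leaves
    if isinstance(tree, dict):
        return [leaf for key in tree for leaf in tree_leaves(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [leaf for item in tree for leaf in tree_leaves(item)]
    return [tree]  # anything else is a leaf


def mask_leaves(tree: Dict[str, Any], mask: Dict[str, bool]) -> Dict[str, Any]:
    """Wrap every value whose mask entry is True in Frozen (flat dicts only)."""
    return {key: Frozen(value) if mask[key] else value for key, value in tree.items()}


tree = {"a": 1, "b": 2.0}
print(tree_leaves(tree))                                        # [1, 2.0]
print(tree_leaves(mask_leaves(tree, {"a": True, "b": False})))  # [2.0]
```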
Note
TreeClass inherits from abc.ABC, so the abc decorators (e.g. @abc.abstractmethod) can be used to define abstract behavior.
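Since TreeClass inherits from abc.ABC, the standard abc rules apply: a subclass that leaves an abstract method unimplemented cannot be instantiated. The sketch below uses a plain abc.ABC base as a stand-in for a TreeClass subclass so that it runs with the standard library alone; `Base` and `Double` are illustrative names.

```python
import abc


class Base(abc.ABC):
    # stand-in for a TreeClass subclass (TreeClass itself inherits from abc.ABC)
    @abc.abstractmethod
    def __call__(self, x: float) -> float:
        ...


class Double(Base):
    # concrete subclass: implements the abstract method
    def __call__(self, x: float) -> float:
        return 2 * x


try:
    Base()  # abstract classes cannot be instantiated
except TypeError as error:
    print("TypeError:", error)

print(Double()(3.0))  # 6.0
```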
Warning
The structure should be organized as a tree; in particular, cyclic references are not allowed. The leaves of the tree are the values, and the branches are the containers that hold the leaves.
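One way to check the no-cycles requirement before building a tree is a depth-first walk that tracks the objects on the current path. This is only an illustration under stated assumptions: `has_cycle` is a hypothetical helper, not a pytreeclass API, and it only descends into lists, tuples, dicts, and instance `__dict__`s. Shared but acyclic references are not reported as cycles.

```python
from typing import Any, Optional, Set


def has_cycle(node: Any, path: Optional[Set[int]] = None) -> bool:
    """Return True if `node` can reach itself through containers or attributes."""
    path = set() if path is None else path
    if id(node) in path:
        return True  # revisited an object on the current path: a cycle
    if isinstance(node, dict):
        children = list(node.values())
    elif isinstance(node, (list, tuple)):
        children = list(node)
    elif hasattr(node, "__dict__"):
        children = list(vars(node).values())
    else:
        return False  # treat everything else as a leaf
    path.add(id(node))
    found = any(has_cycle(child, path) for child in children)
    path.discard(id(node))  # leave the path so shared references are allowed
    return found


class Node:
    def __init__(self) -> None:
        self.child = None


a, b = Node(), Node()
a.child = b
print(has_cycle(a))  # False
b.child = a
print(has_cycle(a))  # True
```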
- pytreeclass.autoinit(klass)
A class decorator that generates the __init__ method from type hints.
Similar to dataclasses.dataclass, this decorator generates the __init__ method for the given class from the type hints or the field() objects set to the class attributes.
Compared to dataclasses.dataclass, autoinit with field() objects can be used to apply functions on the field values during initialization, and/or support multiple argument kinds.
Example
>>> import pytreeclass as tc
>>> @tc.autoinit
... class Tree:
...     x: int
...     y: int
>>> tree = Tree(1, 2)
>>> tree.x, tree.y
(1, 2)
Example
>>> # define fields with different argument kinds
>>> import pytreeclass as tc
>>> @tc.autoinit
... class Tree:
...     kw_only_field: int = tc.field(default=1, kind="KW_ONLY")
...     pos_only_field: int = tc.field(default=2, kind="POS_ONLY")
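The kind strings map naturally onto the standard inspect.Parameter kinds. The sketch below assumes that mapping (inferred from the names, not taken from pytreeclass source) and builds a signature by hand to show how the kinds constrain calls. The positional-only parameter is listed first because inspect.Signature rejects any other order; how autoinit orders fields declared as in the example above is not covered here.

```python
import inspect

# assumed mapping from `kind` strings to inspect.Parameter kinds
KINDS = {
    "POS_ONLY": inspect.Parameter.POSITIONAL_ONLY,
    "VAR_POS": inspect.Parameter.VAR_POSITIONAL,
    "POS_OR_KW": inspect.Parameter.POSITIONAL_OR_KEYWORD,
    "KW_ONLY": inspect.Parameter.KEYWORD_ONLY,
    "VAR_KW": inspect.Parameter.VAR_KEYWORD,
}

signature = inspect.Signature([
    inspect.Parameter("pos_only_field", KINDS["POS_ONLY"], default=2),
    inspect.Parameter("kw_only_field", KINDS["KW_ONLY"], default=1),
])
print(signature)  # (pos_only_field=2, /, *, kw_only_field=1)
print(signature.bind(5, kw_only_field=6).arguments)  # {'pos_only_field': 5, 'kw_only_field': 6}

try:
    signature.bind(pos_only_field=5)  # positional-only names cannot be passed by keyword
except TypeError as error:
    print("TypeError:", error)
```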
Example
>>> # define a converter to apply `abs` on the field value
>>> @tc.autoinit
... class Tree:
...     a: int = tc.field(on_setattr=[abs])
>>> Tree(a=-1).a
1
Warning
The autoinit decorator is a no-op if the class already has a user-defined __init__ method.
Note
In case of inheritance, the __init__ method is generated from the type hints of the current class and any base classes that are decorated with autoinit.
>>> import pytreeclass as tc
>>> import inspect
>>> @tc.autoinit
... class Base:
...     x: int
>>> @tc.autoinit
... class Derived(Base):
...     y: int
>>> obj = Derived(x=1, y=2)
>>> inspect.signature(obj.__init__)
<Signature (x: int, y: int) -> None>
Base classes that are not decorated with autoinit are ignored during synthesis of the __init__ method.
>>> import pytreeclass as tc
>>> import inspect
>>> class Base:
...     x: int
>>> @tc.autoinit
... class Derived(Base):
...     y: int
>>> obj = Derived(y=2)
>>> inspect.signature(obj.__init__)
<Signature (y: int) -> None>
Note
Use autoinit instead of dataclasses.dataclass if you want to use a jax.Array as a field default value: starting from Python 3.11, dataclasses.dataclass incorrectly raises an error complaining that jax.Array is not immutable.
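The Python 3.11 behavior can be reproduced with the standard library alone. Since 3.11, dataclasses rejects any default whose class is unhashable (unhashability is used as a proxy for mutability), and jax.Array instances are unhashable. `ArrayLike` below is a hypothetical unhashable stand-in for jax.Array; the sketch only demonstrates the dataclasses check and its default_factory workaround, not autoinit.

```python
import dataclasses
import sys


class ArrayLike:
    """Hypothetical stand-in for jax.Array: defining __eq__ without __hash__ makes it unhashable."""

    def __eq__(self, other: object) -> bool:
        return self is other


try:
    @dataclasses.dataclass
    class Rejected:
        buffer: ArrayLike = ArrayLike()  # unhashable default value
    rejected = False
except ValueError:  # Python >= 3.11: "mutable default ... is not allowed"
    rejected = True


# default_factory is the workaround that dataclasses itself suggests
@dataclasses.dataclass
class Accepted:
    buffer: ArrayLike = dataclasses.field(default_factory=ArrayLike)


print(sys.version_info >= (3, 11), rejected)  # prints "True True" on Python 3.11+
```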
- pytreeclass.leafwise(klass)
A class decorator that adds leafwise operators to a class.
Leafwise operators are operators that are applied to the leaves of a pytree. For example, leafwise __add__ is equivalent to:
- jax.tree_util.tree_map(lambda x: x + rhs, tree) if rhs is a scalar.
- jax.tree_util.tree_map(lambda x, y: x + y, tree, rhs) if rhs is a pytree with the same structure as tree.
- Parameters:
klass – The class to be decorated.
- Returns:
The decorated class.
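The two equivalences described above (a scalar right-hand side versus a pytree with the same structure) can be sketched with a tiny stdlib tree_map. Both `tree_map` and `leafwise_add` here are hypothetical stand-ins for jax.tree_util.tree_map and the generated __add__; they handle only dicts, plain lists, and plain tuples, and do not validate that structures match.

```python
from typing import Any, Callable


def tree_map(fn: Callable[..., Any], tree: Any, *rest: Any) -> Any:
    """Apply `fn` leaf by leaf; trees in `rest` are assumed to share `tree`'s structure."""
    if isinstance(tree, dict):
        return {key: tree_map(fn, tree[key], *(r[key] for r in rest)) for key in tree}
    if isinstance(tree, (list, tuple)):  # plain lists/tuples only (no namedtuples)
        mapped = [tree_map(fn, *items) for items in zip(tree, *rest)]
        return type(tree)(mapped)
    return fn(tree, *rest)  # leaf


def leafwise_add(tree: Any, rhs: Any) -> Any:
    """Hypothetical helper mirroring the two rules above."""
    if isinstance(rhs, (int, float, complex)):
        return tree_map(lambda x: x + rhs, tree)  # scalar: added to every leaf
    return tree_map(lambda x, y: x + y, tree, rhs)  # pytree: matched leaf by leaf


tree = {"a": 1, "b": (2.0, 3)}
print(leafwise_add(tree, 1))                           # {'a': 2, 'b': (3.0, 4)}
print(leafwise_add(tree, {"a": 10, "b": (20.0, 30)}))  # {'a': 11, 'b': (22.0, 33)}
```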
Example
>>> # use jax.numpy functions on TreeClass classes decorated with leafwise
>>> import pytreeclass as tc
>>> import jax.numpy as jnp
>>> @tc.leafwise
... @tc.autoinit
... class Point(tc.TreeClass):
...     x: float = 0.5
...     y: float = 1.0
...     description: str = "point coordinates"
>>> # use tree_mask to mask the non-inexact part of the tree,
>>> # i.e. mask the string leaf `description` to make `Point` work
>>> # with jax.numpy functions
>>> co = tc.tree_mask(Point())
>>> print(tc.bcmap(jnp.where)(co > 0.5, co, 1000))
Point(x=1000.0, y=1.0, description=#point coordinates)
Note
If a mathematically equivalent operator is already defined on the class, then it is not overridden.
Method          Operator
__add__         +
__and__         &
__ceil__        math.ceil
__divmod__      divmod
__eq__          ==
__floor__       math.floor
__floordiv__    //
__ge__          >=
__gt__          >
__invert__      ~
__le__          <=
__lshift__      <<
__lt__          <
__matmul__      @
__mod__         %
__mul__         *
__ne__          !=
__neg__         - (unary)
__or__          |
__pos__         + (unary)
__pow__         **
__round__       round
__sub__         -
__truediv__     /
__trunc__       math.trunc
__xor__         ^
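A minimal sketch of how such a decorator can add operators while respecting user-defined ones: `simple_leafwise` is hypothetical, covers only three rows of the table, assumes the leaves live in the instance `__dict__`, and supports only scalar right-hand sides.

```python
import operator
from typing import Any, Callable, Dict

# hypothetical subset of the table above
OPERATORS: Dict[str, Callable[..., Any]] = {
    "__add__": operator.add,
    "__mul__": operator.mul,
    "__neg__": operator.neg,
}


def simple_leafwise(klass: type) -> type:
    """Hypothetical sketch: add leaf-by-leaf dunder methods unless already defined."""
    for name, op in OPERATORS.items():
        if name in vars(klass):
            continue  # keep a user-defined method, as the note above describes

        def method(self: Any, *args: Any, _op: Callable[..., Any] = op) -> Any:
            new = object.__new__(type(self))  # copy without calling __init__
            for key, value in vars(self).items():
                new.__dict__[key] = _op(value, *args)  # apply op to each leaf
            return new

        method.__name__ = name
        setattr(klass, name, method)
    return klass


@simple_leafwise
class Point:
    def __init__(self, x: float, y: float) -> None:
        self.x, self.y = x, y

    def __mul__(self, other: Any) -> str:
        return "custom mul"  # user-defined: not overridden


p = Point(1.0, 2.0)
q = p + 1
print(q.x, q.y)  # 2.0 3.0
print(p * 2)     # custom mul
r = -p
print(r.x, r.y)  # -1.0 -2.0
```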
- pytreeclass.is_tree_equal(*trees)
Return True if all pytrees are equal.
Return type: bool | jax.Array
Note
Trees are compared using their leaves and treedefs.
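Comparing leaves and treedefs means two trees are equal only if they have the same structure and pairwise-equal leaves. The stdlib sketch below makes that concrete; `flatten` and `is_tree_equal_sketch` are hypothetical helpers that sort dict keys (assumed sortable) and compare leaves with plain ==, so they do not cover jax.Array leaves or jit tracing.

```python
from typing import Any, List, Tuple


def flatten(tree: Any) -> Tuple[List[Any], Any]:
    """Split nested lists/tuples/dicts into (leaves, treedef-like structure)."""
    if isinstance(tree, dict):
        keys = sorted(tree)  # assumes sortable (e.g. string) keys
        children = [flatten(tree[key]) for key in keys]
        kind: Any = ("dict", tuple(keys))
    elif isinstance(tree, (list, tuple)):
        children = [flatten(item) for item in tree]
        kind = (type(tree).__name__, len(tree))
    else:
        return [tree], "*"  # "*" marks a leaf slot
    leaves = [leaf for child_leaves, _ in children for leaf in child_leaves]
    structure = (kind, tuple(child_def for _, child_def in children))
    return leaves, structure


def is_tree_equal_sketch(*trees: Any) -> bool:
    """Hypothetical analogue: True only if all structures and all leaves match."""
    flat = [flatten(tree) for tree in trees]  # expects at least one tree
    first_leaves, first_def = flat[0]
    return all(treedef == first_def and leaves == first_leaves for leaves, treedef in flat[1:])


print(is_tree_equal_sketch({"a": 1, "b": [2, 3]}, {"b": [2, 3], "a": 1}))  # True
print(is_tree_equal_sketch([1, 2], (1, 2)))  # False: same leaves, different treedef
print(is_tree_equal_sketch([1, 2], [1, 3]))  # False: same treedef, different leaves
```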
Note
Under jit, the return type is a boolean jax.Array instead of bool.
- pytreeclass.field(*, default=NULL, init=True, repr=True, kind='POS_OR_KW', metadata=None, on_setattr=(), on_getattr=(), alias=None)
Field placeholder for type hinted attributes.
- Parameters:
default – The default value of the field.
init – Whether the field is included in the object’s __init__ function.
repr – Whether the field is included in the object’s __repr__ function.
kind – Argument kind, one of:
  POS_ONLY: positional-only argument (e.g. x in def f(x, /):)
  VAR_POS: variable positional argument (e.g. *x in def f(*x):)
  POS_OR_KW: positional or keyword argument (e.g. x in def f(x):)
  KW_ONLY: keyword-only argument (e.g. x in def f(*, x):)
  VAR_KW: variable keyword argument (e.g. **x in def f(**x):)
metadata – A mapping of user-defined data for the field.
on_setattr – A sequence of functions to be called on __setattr__.
on_getattr – A sequence of functions to be called on __getattr__.
alias – An alias for the field name in the constructor, e.g. name=x, alias=y allows obj = Class(y=1) to be equivalent to obj = Class(x=1).
Example
Type and range validation using on_setattr:
>>> import pytreeclass as tc
>>> @tc.autoinit
... class IsInstance(tc.TreeClass):
...     klass: type
...     def __call__(self, x):
...         assert isinstance(x, self.klass)
...         return x
>>> @tc.autoinit
... class Range(tc.TreeClass):
...     start: int|float = float("-inf")
...     stop: int|float = float("inf")
...     def __call__(self, x):
...         assert self.start <= x <= self.stop
...         return x
>>> @tc.autoinit
... class Employee(tc.TreeClass):
...     # assert employee `name` is str
...     name: str = tc.field(on_setattr=[IsInstance(str)])
...     # use callback composition to assert employee `age` is int and positive
...     age: int = tc.field(on_setattr=[IsInstance(int), Range(1)])
>>> employee = Employee(name="Asem", age=10)
>>> print(employee)
Employee(name=Asem, age=10)
Example
Private attribute using alias:
>>> import pytreeclass as tc
>>> @tc.autoinit
... class Employee(tc.TreeClass):
...     # `alias` is the name used in the constructor
...     _name: str = tc.field(alias="name")
>>> employee = Employee(name="Asem")  # use `name` in the constructor
>>> print(employee)  # `_name` is the private attribute name
Employee(_name=Asem)
Example
Buffer creation using on_getattr:
>>> import jax
>>> import jax.numpy as jnp
>>> import pytreeclass as tc
>>> @tc.autoinit
... class Tree(tc.TreeClass):
...     buffer: jax.Array = tc.field(on_getattr=[jax.lax.stop_gradient])
>>> tree = Tree(buffer=jnp.array((1.0, 2.0)))
>>> def sum_buffer(tree):
...     return tree.buffer.sum()
>>> print(jax.grad(sum_buffer)(tree))  # no gradient on `buffer`
Tree(buffer=[0. 0.])
Example
Parameterization using on_getattr:
>>> import jax
>>> import jax.numpy as jnp
>>> import pytreeclass as tc
>>> def symmetric(array: jax.Array) -> jax.Array:
...     triangle = jnp.triu(array)  # upper triangle
...     return triangle + triangle.transpose(-1, -2)
>>> @tc.autoinit
... class Tree(tc.TreeClass):
...     symmetric_matrix: jax.Array = tc.field(on_getattr=[symmetric])
>>> tree = Tree(symmetric_matrix=jnp.arange(9).reshape(3, 3))
>>> print(tree.symmetric_matrix)
[[ 0  1  2]
 [ 1  8  5]
 [ 2  5 16]]
Note
field() is commonly used to annotate the class attributes used by the autoinit() decorator to generate the __init__ method, similar to dataclasses.dataclass. field() can also be used without autoinit() as a descriptor that applies functions to the field values when they are set or accessed, using the on_setattr / on_getattr arguments.
>>> import pytreeclass as tc
>>> def print_and_return(x):
...     print(f"Setting {x}")
...     return x
>>> class Tree:
...     # `a` must be defined as a class attribute for the descriptor to work
...     a: int = tc.field(on_setattr=[print_and_return])
...     def __init__(self, a):
...         self.a = a
>>> tree = Tree(1)
Setting 1
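The descriptor behavior described above can be sketched with the standard descriptor protocol. `SimpleField` is a hypothetical, stripped-down stand-in for field(): it runs on_setattr functions on assignment (storing the result) and on_getattr functions on every access (leaving the stored value unchanged). Whether pytreeclass stores values the same way is an assumption.

```python
from typing import Any, Callable, Sequence


class SimpleField:
    """Hypothetical descriptor: run on_setattr on assignment and on_getattr on access."""

    def __init__(
        self,
        on_setattr: Sequence[Callable[[Any], Any]] = (),
        on_getattr: Sequence[Callable[[Any], Any]] = (),
    ) -> None:
        self.on_setattr = tuple(on_setattr)
        self.on_getattr = tuple(on_getattr)

    def __set_name__(self, owner: type, name: str) -> None:
        self.name = name  # attribute name, e.g. "a"

    def __get__(self, instance: Any, owner: type) -> Any:
        if instance is None:
            return self  # accessed on the class itself
        value = instance.__dict__[self.name]
        for fn in self.on_getattr:
            value = fn(value)  # transform on read; stored value is untouched
        return value

    def __set__(self, instance: Any, value: Any) -> None:
        for fn in self.on_setattr:
            value = fn(value)  # transform (or validate) on write
        instance.__dict__[self.name] = value  # store the converted value


class Tree:
    # must be a class attribute for the descriptor protocol to apply
    a = SimpleField(on_setattr=[abs], on_getattr=[lambda x: x * 10])

    def __init__(self, a: int) -> None:
        self.a = a


tree = Tree(-3)
print(tree.__dict__["a"], tree.a)  # 3 30
```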
- pytreeclass.fields(x)
Returns a tuple of Field objects for the given instance or class.
Field objects are generated from the class type hints and contain the field information if pytreeclass.field is used to annotate the attributes.
Note
If the class is not annotated, an empty tuple is returned.
The Field generation is cached for the class and its bases.