🎯 Core API
- class pytreeclass.TreeClass(*a, **k)
Convert a class to a pytree by inheriting from TreeClass.
A pytree is any nested structure of containers and leaves. A container is a pytree that holds other pytrees, each of which can be a container or a leaf. Examples of containers are a tuple, list, or dict. A leaf is a non-container data structure such as an int, float, string, or Array. TreeClass is a container pytree that holds other pytrees in its attributes.
Note
TreeClass is immutable by default. This means that setting or deleting attributes after initialization is not allowed. This behavior is intended to prevent accidental mutation of the tree. All tree modifications on TreeClass are out-of-place, meaning that every modification returns a new instance of the tree with the modified values.
There are two ways to set or delete attributes after initialization:
Using the at property to modify an existing leaf of the tree:
>>> import pytreeclass as tc
>>> class Tree(tc.TreeClass):
...     def __init__(self, leaf: int):
...         self.leaf = leaf
>>> tree = Tree(leaf=1)
>>> new_tree = tree.at["leaf"].set(100)
>>> tree is new_tree  # new instance is created
False
Using at[mutating_method_name] to call a mutating method and apply the mutation on a copy of the tree. This option allows writing methods that mutate the tree instance, with the updates applied on a copy of the tree:
>>> import pytreeclass as tc
>>> class Tree(tc.TreeClass):
...     def __init__(self, leaf: int):
...         self.leaf = leaf
...     def add_leaf(self, name: str, value: int) -> None:
...         # this method mutates the tree instance
...         # and will raise an `AttributeError` if called directly.
...         setattr(self, name, value)
>>> tree = Tree(leaf=1)
>>> # now let's try to call `add_leaf` directly
>>> tree.add_leaf(name="new_leaf", value=100)
Traceback (most recent call last):
    ...
AttributeError: Cannot set attribute value=100 to `key='new_leaf'` on an immutable instance of `Tree`.
>>> # now let's try to call `add_leaf` using `at["add_leaf"]`
>>> method_output, new_tree = tree.at["add_leaf"](name="new_leaf", value=100)
>>> new_tree
Tree(leaf=1, new_leaf=100)
This pattern is useful for writing freely mutating methods, at the expense of having to call them through at["method_name"] instead of calling them directly.
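The copy-then-mutate mechanism behind at[mutating_method_name] can be pictured in plain Python. The sketch below is a minimal illustration, not pytreeclass's implementation. `FrozenRecord`, its `_unlocked` flag, and `call_on_copy` are hypothetical names. The sketch also assumes that a shallow copy of the instance is enough, because the mutating methods rebind attributes rather than mutating leaves in place.

```python
import copy


class FrozenRecord:
    """Hypothetical immutable record: attributes can only be set while unlocked."""

    def __init__(self, **attrs):
        object.__setattr__(self, "_unlocked", True)  # allow setup inside __init__
        for name, value in attrs.items():
            setattr(self, name, value)
        object.__setattr__(self, "_unlocked", False)  # lock after initialization

    def __setattr__(self, name, value):
        if not self._unlocked:
            raise AttributeError(
                f"Cannot set attribute {name}={value!r} on an immutable instance."
            )
        object.__setattr__(self, name, value)

    def add_leaf(self, name, value):
        # a mutating method: fails when called on a locked instance
        setattr(self, name, value)

    def call_on_copy(self, method_name, *args, **kwargs):
        """Run a mutating method on a shallow copy; return (output, new_instance)."""
        new = copy.copy(self)  # copies __dict__ without going through __setattr__
        object.__setattr__(new, "_unlocked", True)
        try:
            output = getattr(new, method_name)(*args, **kwargs)
        finally:
            object.__setattr__(new, "_unlocked", False)  # re-lock the copy
        return output, new


tree = FrozenRecord(leaf=1)
try:
    tree.add_leaf("new_leaf", 100)  # direct call on the locked original
except AttributeError as err:
    print(err)

output, new_tree = tree.call_on_copy("add_leaf", "new_leaf", 100)
print(new_tree.new_leaf, hasattr(tree, "new_leaf"))  # 100 False
```

The original `tree` never gains `new_leaf`. Only the returned copy does, which mirrors the `tree.at["add_leaf"](...)` example above.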
Note
pytreeclass offers two ways to construct the __init__ method:
Manual __init__ method:
>>> import pytreeclass as tc
>>> class Tree(tc.TreeClass):
...     def __init__(self, a: int, b: float):
...         self.a = a
...         self.b = b
>>> tree = Tree(a=1, b=2.0)
Auto-generated __init__ method from type annotations:
Use either dataclasses.dataclass or the autoinit() decorator. In both cases the type annotations are used to generate the __init__ method. Compared to dataclasses.dataclass, autoinit() with field() objects can apply functions on the field values during initialization, supports multiple argument kinds, and can apply functions on field values when they are accessed. For more details see autoinit() and field().
>>> import pytreeclass as tc
>>> @tc.autoinit
... class Tree(tc.TreeClass):
...     a: int
...     b: float
>>> tree = Tree(a=1, b=2.0)
Note
Leaf-wise math operations are supported using the leafwise decorator. The leafwise decorator adds __add__, __sub__, __mul__, etc. to registered pytrees. These methods apply the math operation to each leaf of the tree. For example:
>>> import pytreeclass as tc
>>> @tc.leafwise
... @tc.autoinit
... class Tree(tc.TreeClass):
...     a: int = 1
...     b: float = 2.0
>>> tree = Tree()
>>> tree + 1  # will add 1 to each leaf
Tree(a=2, b=3.0)
Note
Advanced indexing is supported using the at property. Indexing can be used to get, set, or apply a function to a leaf or a group of leaves using the leaf name, an index, or a boolean mask.
>>> import pytreeclass as tc
>>> @tc.autoinit
... class Tree(tc.TreeClass):
...     a: int = 1
...     b: float = 2.0
>>> tree = Tree()
>>> tree.at["a"].get()
Tree(a=1, b=None)
>>> tree.at[0].get()
Tree(a=1, b=None)
Note
Under jax.tree_util.* functions, all TreeClass attributes are treated as leaves. To hide or ignore a specific attribute from the tree leaves during jax.tree_util.* operations, freeze the leaf using freeze() or tree_mask().
>>> # freeze (exclude) a leaf from the tree leaves:
>>> import jax
>>> import pytreeclass as tc
>>> @tc.autoinit
... class Tree(tc.TreeClass):
...     a: int = 1
...     b: float = 2.0
>>> tree = Tree()
>>> tree = tree.at["a"].apply(tc.freeze)
>>> jax.tree_util.tree_leaves(tree)
[2.0]
>>> # undo the freeze
>>> tree = tree.at["a"].apply(tc.unfreeze, is_leaf=tc.is_frozen)
>>> jax.tree_util.tree_leaves(tree)
[1, 2.0]
>>> # using `tree_mask` to exclude a leaf from the tree leaves
>>> freeze_mask = Tree(a=True, b=False)
>>> jax.tree_util.tree_leaves(tc.tree_mask(tree, freeze_mask))
[2.0]
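Freezing can be pictured as wrapping a value so that leaf traversal skips it. The following pure-Python sketch assumes that tuples, lists, and dicts are the only containers. `Frozen`, `freeze_leaf`, `unfreeze_leaf`, `leaves`, and `mask_tree` are hypothetical stand-ins for freeze(), unfreeze(), jax.tree_util.tree_leaves, and tree_mask(). They are not the actual implementations.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Frozen:
    """Hypothetical wrapper marking a value as hidden from leaf traversal."""
    value: Any


def freeze_leaf(x):
    return x if isinstance(x, Frozen) else Frozen(x)


def unfreeze_leaf(x):
    return x.value if isinstance(x, Frozen) else x


def leaves(tree):
    """Collect leaves depth-first; Frozen wrappers contribute nothing."""
    if isinstance(tree, Frozen):
        return []
    if isinstance(tree, dict):  # sorted keys, matching JAX's dict ordering
        return [leaf for key in sorted(tree) for leaf in leaves(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [leaf for item in tree for leaf in leaves(item)]
    return [tree]


def mask_tree(tree, mask):
    """Freeze the entries of a flat dict whose mask value is True."""
    return {key: (freeze_leaf(v) if mask[key] else v) for key, v in tree.items()}


tree = {"a": 1, "b": 2.0}
frozen = {**tree, "a": freeze_leaf(tree["a"])}
print(leaves(frozen))  # [2.0]
thawed = {key: unfreeze_leaf(v) for key, v in frozen.items()}
print(leaves(thawed))  # [1, 2.0]
print(leaves(mask_tree(tree, {"a": True, "b": False})))  # [2.0]
```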
Note
TreeClass inherits from abc.ABC, so @abstract... decorators (such as @abc.abstractmethod) can be used to define abstract behavior.
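The following sketch shows what the abc machinery enforces. It uses a plain abc.ABC base in place of TreeClass so the example runs without pytreeclass; that substitution is an assumption of the sketch. `Layer` and `Scale` are hypothetical names. A subclass that leaves an abstract method unimplemented cannot be instantiated.

```python
import abc


class Layer(abc.ABC):
    """Stand-in for an abstract TreeClass subclass."""

    @abc.abstractmethod
    def __call__(self, x):
        """Subclasses must implement this."""


class Scale(Layer):
    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        return x * self.factor


try:
    Layer()  # abstract method left unimplemented: TypeError
    instantiated = True
except TypeError as err:
    instantiated = False
    print("TypeError:", err)

print(Scale(3)(2))  # 6
```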
Warning
The structure must be organized as a tree. In particular, cyclic references are not allowed. The leaves of the tree are its values, and the branches are the containers that hold the leaves.
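The no-cycles rule can be checked by tracking the containers on the current path while descending. This is a minimal sketch that assumes tuples, lists, and dicts are the only containers. `assert_is_tree` is a hypothetical helper. The sketch accepts shared references to the same subtree (a DAG), which is a case the warning above does not address.

```python
def assert_is_tree(node, _ancestors=()):
    """Raise ValueError if a container is reachable from itself (a cycle)."""
    if not isinstance(node, (tuple, list, dict)):
        return  # leaves end the recursion
    if any(node is ancestor for ancestor in _ancestors):
        raise ValueError("cyclic reference: the structure is not a tree")
    children = node.values() if isinstance(node, dict) else node
    for child in children:
        assert_is_tree(child, _ancestors + (node,))


assert_is_tree({"a": [1, 2], "b": (3.0,)})  # a proper tree: no error

cyclic = [1]
cyclic.append(cyclic)  # the list now contains itself
try:
    assert_is_tree(cyclic)
    cycle_error = None
except ValueError as err:
    cycle_error = err
print(cycle_error)  # cyclic reference: the structure is not a tree
```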
- property at: TreeClassIndexer
Immutable out-of-place indexing.
.at[***].get(): Return a new instance with the value at the index, and None for all other leaves.
.at[***].set(value): Set the value and return a new instance with the updated value.
.at[***].apply(func): Apply func and return a new instance with the updated value.
.at['method'](*a, **k): Call method and return a (return value, new instance) tuple.
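The get/set/apply semantics listed above can be mimicked on a flat dict of leaves. This minimal sketch makes some simplifying assumptions. A dict stands in for the tree. The index is either a single key or a boolean mask (a dict of bools with the same keys). `at_get`, `at_set`, and `at_apply` are hypothetical helpers, not the TreeClassIndexer API.

```python
def _selected(where):
    """Return the set of keys selected by a str key or a boolean-mask dict."""
    if isinstance(where, str):
        return {where}
    if isinstance(where, dict):  # boolean mask with the same keys as the tree
        return {key for key, flag in where.items() if flag}
    raise TypeError(f"Unsupported index type: {type(where).__name__}")


def at_get(tree, where):
    # selected leaves keep their value; every other leaf becomes None
    keys = _selected(where)
    return {k: (v if k in keys else None) for k, v in tree.items()}


def at_set(tree, where, value):
    # out-of-place: build a new dict, the input is left untouched
    keys = _selected(where)
    return {k: (value if k in keys else v) for k, v in tree.items()}


def at_apply(tree, where, func):
    keys = _selected(where)
    return {k: (func(v) if k in keys else v) for k, v in tree.items()}


tree = {"a": 1, "b": 2.0}
print(at_get(tree, "a"))       # {'a': 1, 'b': None}
print(at_set(tree, "a", 100))  # {'a': 100, 'b': 2.0}
print(at_apply(tree, {"a": False, "b": True}, lambda x: x * 10))  # {'a': 1, 'b': 20.0}
print(tree)                    # {'a': 1, 'b': 2.0}, unchanged
```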
- Acceptable indexing types are:
  str for mapping keys or class attributes.
  int for positional indexing for sequences.
  ... (Ellipsis) to select all leaves.
  A boolean mask of the same structure as the tree.
  re.Pattern to index all keys matching a regex pattern.
  An instance of BaseKey with custom logic to index a pytree.
  A tuple of the above types to index multiple keys at the same level.
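The sketch below shows how several of these index types might select keys at a single level. The matching rules are assumptions made for illustration: exact match for str, position for int, full match for re.Pattern, every key for ..., and a union for tuples. Boolean masks and BaseKey are omitted, and `match_keys` is a hypothetical helper.

```python
import re


def match_keys(keys, index):
    """Return the keys (in order) selected by `index` at a single tree level."""
    if index is Ellipsis:                   # ... selects every key
        return list(keys)
    if isinstance(index, str):              # exact attribute / mapping key
        return [k for k in keys if k == index]
    if isinstance(index, int):              # positional index
        return [list(keys)[index]]
    if isinstance(index, re.Pattern):       # assumption: full-match semantics
        return [k for k in keys if index.fullmatch(k)]
    if isinstance(index, tuple):            # union of several indices at one level
        selected = set()
        for item in index:
            selected.update(match_keys(keys, item))
        return [k for k in keys if k in selected]
    raise TypeError(f"Unsupported index type: {type(index).__name__}")


keys = ["weight_1", "weight_2", "bias"]
print(match_keys(keys, "bias"))                    # ['bias']
print(match_keys(keys, 0))                         # ['weight_1']
print(match_keys(keys, re.compile(r"weight_\d")))  # ['weight_1', 'weight_2']
print(match_keys(keys, ("bias", 1)))               # ['weight_2', 'bias']
print(match_keys(keys, ...))                       # all three keys
```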
Example
>>> import pytreeclass as tc
>>> @tc.autoinit
... class Tree(tc.TreeClass):
...     a: int = 1
...     b: float = 2.0
...     def add(self, x: int) -> int:
...         self.a += x
...         return self.a
>>> tree = Tree()
>>> # get `a` and return a new instance
>>> # with `None` for all other leaves
>>> tree.at["a"].get()
Tree(a=1, b=None)
>>> # set `a` and return a new instance
>>> # with all other leaves unchanged
>>> tree.at["a"].set(100)
Tree(a=100, b=2.0)
>>> # apply to `a` and return a new instance
>>> # with all other leaves unchanged
>>> tree.at["a"].apply(lambda x: 100)
Tree(a=100, b=2.0)
>>> # call `add` and return a tuple of
>>> # (return value, new instance)
>>> tree.at["add"](99)
(100, Tree(a=100, b=2.0))
Note
pytree.at[*][**] is equivalent to selecting pytree.*.**. pytree.at[*, **] is equivalent to selecting pytree.* and pytree.**.
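The difference between chaining indices and grouping them in one tuple can be illustrated on nested dicts. The hypothetical `select_paths` helper below only computes which key paths each form selects. It does not perform the get/set/apply step.

```python
def select_paths(tree, *levels):
    """Return the key paths selected by successive index levels of nested dicts.

    Each level is a single key or a tuple of keys at the same depth, so
    select_paths(t, "a", "b") mirrors at["a"]["b"] (a nested path), while
    select_paths(t, ("a", "b")) mirrors at["a", "b"] (two siblings).
    """
    paths = [()]
    for level in levels:
        keys = level if isinstance(level, tuple) else (level,)
        next_paths = []
        for path in paths:
            node = tree
            for step in path:  # walk down to the node reached so far
                node = node[step]
            if isinstance(node, dict):
                next_paths.extend(path + (k,) for k in keys if k in node)
        paths = next_paths
    return paths


tree = {"a": {"b": 1, "c": 2}, "b": 3}
print(select_paths(tree, "a", "b"))    # [('a', 'b')]
print(select_paths(tree, ("a", "b")))  # [('a',), ('b',)]
```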
Note
An AttributeError is raised if a method that mutates the instance is called directly. Instead, use at["method_name"] to call a method that mutates the instance.
Example
Building immutable chainable methods with at:
The following example shows how to build chainable methods using the at property. Although the methods mutate the instance, the mutation is applied on a copy of the tree, and the original tree is not mutated.
>>> import pytreeclass as tc
>>> class Tree(tc.TreeClass):
...     def set_x(self, x):
...         self.x = x
...     def set_y(self, y):
...         self.y = y
...     def calculate(self):
...         return self.x + self.y
>>> tree = Tree()
>>> tree.at["set_x"](x=1)[1].at["set_y"](y=2)[1].calculate()
3
- pytreeclass.autoinit(klass)
A class decorator that generates the __init__ method from type hints.
Similar to dataclasses.dataclass, this decorator generates the __init__ method for the given class from the type hints or the field() objects set to the class attributes.
Compared to dataclasses.dataclass, autoinit with field() objects can be used to apply functions on the field values during initialization and/or to support multiple argument kinds.
Example
>>> import pytreeclass as tc
>>> @tc.autoinit
... class Tree:
...     x: int
...     y: int
>>> tree = Tree(1, 2)
>>> tree.x, tree.y
(1, 2)
Example
>>> # define fields with different argument kinds
>>> import pytreeclass as tc
>>> @tc.autoinit
... class Tree:
...     kw_only_field: int = tc.field(default=1, kind="KW_ONLY")
...     pos_only_field: int = tc.field(default=2, kind="POS_ONLY")
Example
>>> # define a converter to apply ``abs`` on the field value
>>> import pytreeclass as tc
>>> @tc.autoinit
... class Tree:
...     a: int = tc.field(on_setattr=[abs])
>>> Tree(a=-1).a
1
Warning
The autoinit decorator is a no-op if the class already has a user-defined __init__ method.
Note
In case of inheritance, the __init__ method is generated from the type hints of the current class and any base classes that are decorated with autoinit.
>>> import pytreeclass as tc
>>> import inspect
>>> @tc.autoinit
... class Base:
...     x: int
>>> @tc.autoinit
... class Derived(Base):
...     y: int
>>> obj = Derived(x=1, y=2)
>>> inspect.signature(obj.__init__)
<Signature (x: int, y: int) -> None>
Base classes that are not decorated with autoinit are ignored during synthesis of the __init__ method.
>>> import pytreeclass as tc
>>> import inspect
>>> class Base:
...     x: int
>>> @tc.autoinit
... class Derived(Base):
...     y: int
>>> obj = Derived(y=2)
>>> inspect.signature(obj.__init__)
<Signature (y: int) -> None>
Note
Use autoinit instead of dataclasses.dataclass if you want to use jax.Array as a field default value. Starting from Python 3.11, dataclasses.dataclass incorrectly raises an error complaining that jax.Array is not immutable.
Note
By default, autoinit will raise an error if the user uses mutable defaults. To register an additional type to be excluded from autoinit, use autoinit.register_excluded_type(), with an optional reason for excluding the type.
>>> import pytreeclass as tc
>>> class T:
...     pass
>>> tc.autoinit.register_excluded_type(T, reason="not allowed")
>>> @tc.autoinit
... class Tree:
...     x: T = tc.field(default=T())
Traceback (most recent call last):
    ...
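The Python 3.11 change mentioned above is that dataclasses now reject any default whose class is unhashable (its __hash__ is None), using unhashability as a proxy for mutability. Earlier versions only rejected list, dict, and set. The sketch below reproduces this behavior with a hypothetical `ArrayLike` stand-in instead of jax.Array, so it runs without JAX. Defining __eq__ without __hash__ is what makes the stand-in unhashable.

```python
import dataclasses
import sys


class ArrayLike:
    """Stand-in for an array type: __eq__ without __hash__ sets __hash__ to None."""

    def __init__(self, data):
        self.data = data

    def __eq__(self, other):
        return isinstance(other, ArrayLike) and self.data == other.data


def make_dataclass_with_array_default():
    @dataclasses.dataclass
    class Tree:
        buffer: ArrayLike = ArrayLike([1.0, 2.0])

    return Tree


try:
    make_dataclass_with_array_default()
    error = None  # Python <= 3.10: only list, dict, and set defaults are rejected
except ValueError as err:  # Python >= 3.11: unhashable default rejected
    error = err

print(sys.version_info[:2], "->", error)
```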
- pytreeclass.leafwise(klass)
A class decorator that adds leafwise operators to a class.
Leafwise operators are operators that are applied to the leaves of a pytree. For example, leafwise __add__ is equivalent to:
  tree_map(lambda x: x + rhs, tree) if rhs is a scalar.
  tree_map(lambda x, y: x + y, tree, rhs) if rhs is a pytree with the same structure as tree.
- Parameters:
klass – The class to be decorated.
- Returns:
The decorated class.
Example
>>> # use ``numpy`` functions on :class:`TreeClass` classes decorated with ``leafwise``
>>> import pytreeclass as tc
>>> import jax.numpy as jnp
>>> @tc.leafwise
... @tc.autoinit
... class Point(tc.TreeClass):
...     x: float = 0.5
...     y: float = 1.0
...     description: str = "point coordinates"
>>> # use :func:`tree_mask` to mask the non-inexact part of the tree,
>>> # i.e. mask the string leaf ``description`` so that ``Point`` works
>>> # with ``jax.numpy`` functions
>>> co = tc.tree_mask(Point())
>>> print(tc.bcmap(jnp.where)(co > 0.5, co, 1000))
Point(x=1000.0, y=1.0, description=#point coordinates)
Note
If a mathematically equivalent operator is already defined on the class, then it is not overridden.
Method          Operator
__add__         +
__and__         &
__ceil__        math.ceil
__divmod__      divmod
__eq__          ==
__floor__       math.floor
__floordiv__    //
__ge__          >=
__gt__          >
__invert__      ~
__le__          <=
__lshift__      <<
__lt__          <
__matmul__      @
__mod__         %
__mul__         *
__ne__          !=
__neg__         - (unary)
__or__          |
__pos__         + (unary)
__pow__         **
__round__       round
__sub__         -
__truediv__     /
__trunc__       math.trunc
__xor__         ^
- pytreeclass.is_tree_equal(*trees)
Return True if all pytrees are equal.
Return type: bool | arraylib.ndarray
Note
Trees are compared using their leaves and treedefs.
Note
When compiled (e.g. under jax.jit), the result is a boolean Array; otherwise it is a bool.
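Comparison by leaves and treedefs can be sketched by flattening each tree into a (structure, leaves) pair and requiring both parts to match. This minimal sketch assumes tuples, lists, and dicts as containers and uses plain == on scalar leaves. It does not handle array leaves or the compiled (Array-valued) case. `structure_and_leaves` and `trees_equal` are hypothetical helpers.

```python
def structure_and_leaves(tree):
    """Return (structure, leaves); structure encodes container types and keys."""
    if isinstance(tree, dict):
        keys = sorted(tree)
        parts = [structure_and_leaves(tree[k]) for k in keys]
        structure = ("dict", tuple(keys), tuple(s for s, _ in parts))
        return structure, [leaf for _, ls in parts for leaf in ls]
    if isinstance(tree, (list, tuple)):
        parts = [structure_and_leaves(item) for item in tree]
        structure = (type(tree).__name__, tuple(s for s, _ in parts))
        return structure, [leaf for _, ls in parts for leaf in ls]
    return "*", [tree]  # "*" marks a leaf position


def trees_equal(*trees):
    """True if every tree has the same structure and equal leaves as the first."""
    first_structure, first_leaves = structure_and_leaves(trees[0])
    for other in trees[1:]:
        structure, leaves = structure_and_leaves(other)
        if structure != first_structure or leaves != first_leaves:
            return False
    return True


print(trees_equal({"a": 1, "b": [2, 3]}, {"b": [2, 3], "a": 1}))  # True
print(trees_equal({"a": 1, "b": [2, 3]}, {"a": 1, "b": (2, 3)}))  # False: list vs tuple
print(trees_equal([1, 2], [1, 3]))                                # False: leaf differs
```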
- pytreeclass.field(*, default=NULL, init=True, repr=True, kind='POS_OR_KW', metadata=None, on_setattr=(), on_getattr=(), alias=None)
Field placeholder for type-hinted attributes.
- Parameters:
default – The default value of the field.
init – Whether the field is included in the object's __init__ function.
repr – Whether the field is included in the object's __repr__ function.
kind – Argument kind, one of:
  POS_ONLY: positional-only argument (e.g. x in def f(x, /):)
  VAR_POS: variable positional argument (e.g. *x in def f(*x):)
  POS_OR_KW: positional or keyword argument (e.g. x in def f(x):)
  KW_ONLY: keyword-only argument (e.g. x in def f(*, x):)
  VAR_KW: variable keyword argument (e.g. **x in def f(**x):)
metadata – A mapping of user-defined data for the field.
on_setattr – A sequence of functions to be called on __setattr__.
on_getattr – A sequence of functions to be called on __getattr__.
alias – An alias for the field name in the constructor. For example, name=x, alias=y will allow obj = Class(y=1) to be equivalent to obj = Class(x=1).
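The five kinds map onto the parameter kinds of Python's inspect module. This makes it easy to see where each kind lands in a generated constructor. The `KIND_MAP` table and the `build_signature` helper below are illustrative assumptions, not pytreeclass internals. Parameters are sorted by kind because Python requires that order in a signature.

```python
import inspect

P = inspect.Parameter
KIND_MAP = {
    "POS_ONLY": P.POSITIONAL_ONLY,
    "POS_OR_KW": P.POSITIONAL_OR_KEYWORD,
    "VAR_POS": P.VAR_POSITIONAL,
    "KW_ONLY": P.KEYWORD_ONLY,
    "VAR_KW": P.VAR_KEYWORD,
}


def build_signature(specs):
    """specs: iterable of (name, kind_name, default); use P.empty for no default."""
    params = [P(name, KIND_MAP[kind], default=default) for name, kind, default in specs]
    # ParameterKind is an IntEnum ordered as Python requires; the sort is stable
    params.sort(key=lambda p: p.kind)
    return inspect.Signature(params)


sig = build_signature([
    ("kw_only_field", "KW_ONLY", 1),
    ("pos_only_field", "POS_ONLY", 2),
    ("rest", "VAR_POS", P.empty),
    ("options", "VAR_KW", P.empty),
])
print(sig)  # (pos_only_field=2, /, *rest, kw_only_field=1, **options)
print(sig.bind(5, 6, 7, kw_only_field=3, extra=4).arguments)
```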
Example
Type and range validation using on_setattr:
>>> import pytreeclass as tc
>>> @tc.autoinit
... class IsInstance(tc.TreeClass):
...     klass: type
...     def __call__(self, x):
...         assert isinstance(x, self.klass)
...         return x
>>> @tc.autoinit
... class Range(tc.TreeClass):
...     start: int | float = float("-inf")
...     stop: int | float = float("inf")
...     def __call__(self, x):
...         assert self.start <= x <= self.stop
...         return x
>>> @tc.autoinit
... class Employee(tc.TreeClass):
...     # assert employee ``name`` is str
...     name: str = tc.field(on_setattr=[IsInstance(str)])
...     # use callback composition to assert employee ``age`` is int and positive
...     age: int = tc.field(on_setattr=[IsInstance(int), Range(1)])
>>> employee = Employee(name="Asem", age=10)
>>> print(employee)
Employee(name=Asem, age=10)
Example
Private attribute using alias:
>>> import pytreeclass as tc
>>> @tc.autoinit
... class Employee(tc.TreeClass):
...     # `alias` is the name used in the constructor
...     _name: str = tc.field(alias="name")
>>> employee = Employee(name="Asem")  # use `name` in the constructor
>>> print(employee)  # `_name` is the private attribute name
Employee(_name=Asem)
Example
Buffer creation using on_getattr:
>>> import jax
>>> import jax.numpy as jnp
>>> import pytreeclass as tc
>>> @tc.autoinit
... class Tree(tc.TreeClass):
...     buffer: jax.Array = tc.field(on_getattr=[jax.lax.stop_gradient])
>>> tree = Tree(buffer=jnp.array((1.0, 2.0)))
>>> def sum_buffer(tree):
...     return tree.buffer.sum()
>>> print(jax.grad(sum_buffer)(tree))  # no gradient on `buffer`
Tree(buffer=[0. 0.])
Example
Parameterization using on_getattr:
>>> import jax
>>> import jax.numpy as jnp
>>> import pytreeclass as tc
>>> def symmetric(array: jax.Array) -> jax.Array:
...     triangle = jnp.triu(array)  # upper triangle
...     return triangle + triangle.transpose(-1, -2)
>>> @tc.autoinit
... class Tree(tc.TreeClass):
...     symmetric_matrix: jax.Array = tc.field(on_getattr=[symmetric])
>>> tree = Tree(symmetric_matrix=jnp.arange(9).reshape(3, 3))
>>> print(tree.symmetric_matrix)
[[ 0  1  2]
 [ 1  8  5]
 [ 2  5 16]]
Note
field() is commonly used to annotate the class attributes that the autoinit() decorator uses to generate the __init__ method, similar to dataclasses.dataclass.
field() can also be used without autoinit() as a descriptor that applies functions on the field values when they are set or accessed, using the on_setattr/on_getattr arguments.
>>> import pytreeclass as tc
>>> def print_and_return(x):
...     print(f"Setting {x}")
...     return x
>>> class Tree:
...     # `a` must be defined as a class attribute for the descriptor to work
...     a: int = tc.field(on_setattr=[print_and_return])
...     def __init__(self, a):
...         self.a = a
>>> tree = Tree(1)
Setting 1
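Used as a descriptor, a field amounts to running one pipeline of callables when a value is stored and another when it is read. The sketch below shows that idea with a hypothetical `PipelineField` built on the standard __set_name__/__get__/__set__ descriptor protocol. It is not pytreeclass's Field class, and it omits defaults, aliases, and kinds.

```python
class PipelineField:
    """Hypothetical descriptor applying callables on set and on get."""

    def __init__(self, on_setattr=(), on_getattr=()):
        self.on_setattr = tuple(on_setattr)
        self.on_getattr = tuple(on_getattr)

    def __set_name__(self, owner, name):
        self.name = name
        self.storage = f"_{name}_value"  # per-instance storage key in __dict__

    def __set__(self, instance, value):
        for func in self.on_setattr:  # e.g. validation or conversion
            value = func(value)
        instance.__dict__[self.storage] = value

    def __get__(self, instance, owner=None):
        if instance is None:
            return self  # accessed on the class itself
        value = instance.__dict__[self.storage]
        for func in self.on_getattr:  # e.g. reparameterization on read
            value = func(value)
        return value


def positive(x):
    if x <= 0:
        raise ValueError(f"expected a positive value, got {x}")
    return x


class Tree:
    a = PipelineField(on_setattr=[abs, positive], on_getattr=[lambda v: v * 2])

    def __init__(self, a):
        self.a = a  # routed through PipelineField.__set__


tree = Tree(-3)
print(tree.a)                     # 6: stored as abs(-3) == 3, doubled on read
print(tree.__dict__["_a_value"])  # 3: on_getattr does not change the stored value
```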
- pytreeclass.fields(x)
Returns a tuple of Field objects for the given instance or class.
Field objects are generated from the class type hints and hold the field information when the user annotates the attributes using pytreeclass.field.
Note
If the class is not annotated, an empty tuple is returned.
Field generation is cached for the class and its bases.