➕ Advanced API#
- pytreeclass.bcmap(func, *, is_leaf=None)[source]#
Map a function over pytree leaves with automatic broadcasting for scalar arguments.
- Parameters:
func (Callable[P, T]) – the function to be mapped over the pytree
is_leaf (Callable[[Any], bool] | None) – a predicate function that returns True if the node is a leaf
- Return type:
Callable[P, T]
Example
>>> import jax >>> import jax.numpy as jnp >>> import pytreeclass as tc >>> import functools as ft
>>> @tc.autoinit ... @tc.leafwise ... class Test(tc.TreeClass): ... a: tuple[int, int, int] = (1, 2, 3) ... b: tuple[int, int, int] = (4, 5, 6) ... c: jax.Array = jnp.array([1, 2, 3])
>>> tree = Test()
>>> # 0 is broadcasted to all leaves of the pytree >>> print(tc.bcmap(jnp.where)(tree > 1, tree, 0)) Test(a=(0, 2, 3), b=(4, 5, 6), c=[0 2 3]) >>> print(tc.bcmap(jnp.where)(tree > 1, 0, tree)) Test(a=(1, 0, 0), b=(0, 0, 0), c=[1 0 0])
>>> # 1 is broadcasted to all leaves of the list pytree >>> tc.bcmap(lambda x, y: x + y)([1, 2, 3], 1) [2, 3, 4]
>>> # trees are summed leaf-wise >>> tc.bcmap(lambda x, y: x + y)([1, 2, 3], [1, 2, 3]) [2, 4, 6]
>>> # Non scalar second args case >>> try: ... tc.bcmap(lambda x, y: x + y)([1, 2, 3], [[1, 2, 3], [1, 2, 3]]) ... except TypeError as e: ... print(e) unsupported operand type(s) for +: 'int' and 'list'
>>> # using **numpy** functions on pytrees >>> import jax.numpy as jnp >>> tc.bcmap(jnp.add)([1, 2, 3], [1, 2, 3]) [2, 4, 6]
- class pytreeclass.Partial(func, *args, **kwargs)#
Partial function with support for positional partial application.
- Parameters:
func (Callable[..., Any]) – The function to be partially applied.
args (Any) – Positional arguments to be partially applied. Use ``...`` as a placeholder for positional arguments.
kwargs (Any) – Keyword arguments to be partially applied.
Example
>>> import pytreeclass as tc >>> def f(a, b, c): ... print(f"a: {a}, b: {b}, c: {c}") ... return a + b + c
>>> # positional arguments using `...` placeholder >>> f_a = tc.Partial(f, ..., 2, 3) >>> f_a(1) a: 1, b: 2, c: 3 6
>>> # keyword arguments >>> f_b = tc.Partial(f, b=2, c=3) >>> f_b(1) a: 1, b: 2, c: 3 6
Note
The ``...`` (Ellipsis) is used as a placeholder for positional arguments.
- class pytreeclass.AtIndexer(tree: PyTree, where: tuple[BaseKey | PyTree] | tuple[()] = ())#
Index a pytree at a given path using a path or mask.
- Parameters:
tree – pytree to index
where –
one of the following:
- ``str`` for mapping keys or class attributes.
- ``int`` for positional indexing for sequences.
- ``...`` to select all leaves.
- a boolean mask of the same structure as the tree.
- ``re.Pattern`` to index all keys matching a regex pattern.
- an instance of ``BaseKey`` with custom logic to index a pytree.
- a tuple of the above to match multiple keys at the same level.
Example
>>> # use `AtIndexer` on a pytree (e.g. dict,list,tuple,etc.) >>> import jax >>> import pytreeclass as tc >>> tree = {"level1_0": {"level2_0": 100, "level2_1": 200}, "level1_1": 300} >>> indexer = tc.AtIndexer(tree) >>> indexer["level1_0"]["level2_0"].get() {'level1_0': {'level2_0': 100, 'level2_1': None}, 'level1_1': None} >>> # get multiple keys at once at the same level >>> indexer["level1_0"]["level2_0", "level2_1"].get() {'level1_0': {'level2_0': 100, 'level2_1': 200}, 'level1_1': None} >>> # get with a mask >>> mask = {"level1_0": {"level2_0": True, "level2_1": False}, "level1_1": True} >>> indexer[mask].get() {'level1_0': {'level2_0': 100, 'level2_1': None}, 'level1_1': 300}
Example
>>> # use ``AtIndexer`` in a class >>> import jax.tree_util as jtu >>> import pytreeclass as tc >>> @jtu.register_pytree_with_keys_class ... class Tree: ... def __init__(self, a, b): ... self.a = a ... self.b = b ... def tree_flatten_with_keys(self): ... kva = (jtu.GetAttrKey("a"), self.a) ... kvb = (jtu.GetAttrKey("b"), self.b) ... return (kva, kvb), None ... @classmethod ... def tree_unflatten(cls, aux_data, children): ... return cls(*children) ... @property ... def at(self): ... return tc.AtIndexer(self) ... def __repr__(self) -> str: ... return f"{self.__class__.__name__}(a={self.a}, b={self.b})" >>> Tree(1, 2).at["a"].get() Tree(a=1, b=None)
- get(*, is_leaf=None, is_parallel=False)[source]#
Get the leaf values at the specified location.
- Parameters:
is_leaf (Callable[[Any], bool] | None) – a predicate function to determine if a value is a leaf.
is_parallel (bool | ParallelConfig) –
accepts the following:
- ``bool``: run in parallel if ``True``, otherwise run serially.
- ``dict``: a dict of:
  - ``max_workers``: maximum number of workers to use.
  - ``kind``: kind of pool to use, either ``thread`` or ``process``.
- Return type:
PyTree
- Returns:
A *new* pytree of leaf values at the specified location, with the non-selected leaf values set to None if the leaf is not an array.
Example
>>> import pytreeclass as tc >>> tree = {"a": 1, "b": [1, 2, 3]} >>> indexer = tc.AtIndexer(tree) # construct an indexer >>> indexer["b"][0].get() # get the first element of "b" {'a': None, 'b': [1, None, None]}
Example
>>> import pytreeclass as tc >>> @tc.autoinit ... class Tree(tc.TreeClass): ... a: int ... b: int >>> tree = Tree(a=1, b=2) >>> # get ``a`` and return a new instance >>> # with ``None`` for all other leaves >>> tree.at['a'].get() Tree(a=1, b=None)
- set(set_value, *, is_leaf=None, is_parallel=False)[source]#
Set the leaf values at the specified location.
- Parameters:
set_value (Any) – the value to set at the specified location.
is_leaf (Callable[[Any], bool] | None) – a predicate function to determine if a value is a leaf.
is_parallel (bool | ParallelConfig) –
accepts the following:
- ``bool``: run in parallel if ``True``, otherwise run serially.
- ``dict``: a dict of:
  - ``max_workers``: maximum number of workers to use.
  - ``kind``: kind of pool to use, either ``thread`` or ``process``.
- Return type:
PyTree
- Returns:
A pytree with the leaf values at the specified location set to
set_value
.
Example
>>> import pytreeclass as tc >>> tree = {"a": 1, "b": [1, 2, 3]} >>> indexer = tc.AtIndexer(tree) >>> indexer["b"][0].set(100) # set the first element of "b" to 100 {'a': 1, 'b': [100, 2, 3]}
Example
>>> import pytreeclass as tc >>> @tc.autoinit ... class Tree(tc.TreeClass): ... a: int ... b: int >>> tree = Tree(a=1, b=2) >>> # set ``a`` and return a new instance >>> # with all other leaves unchanged >>> tree.at['a'].set(100) Tree(a=100, b=2)
- apply(func, *, is_leaf=None, is_parallel=False)[source]#
Apply a function to the leaf values at the specified location.
- Parameters:
func (Callable[[Any], Any]) – the function to apply to the leaf values.
is_leaf (Callable[[Any], bool] | None) – a predicate function to determine if a value is a leaf.
is_parallel (bool | ParallelConfig) –
accepts the following:
- ``bool``: apply ``func`` in parallel if ``True``, otherwise apply it serially.
- ``dict``: a dict of:
  - ``max_workers``: maximum number of workers to use.
  - ``kind``: kind of pool to use, either ``thread`` or ``process``.
- Return type:
PyTree
- Returns:
A pytree with the leaf values at the specified location set to the result of applying
func
to the leaf values.
Example
>>> import pytreeclass as tc >>> tree = {"a": 1, "b": [1, 2, 3]} >>> indexer = tc.AtIndexer(tree) >>> indexer["b"][0].apply(lambda x: x + 100) # add 100 to the first element of "b" {'a': 1, 'b': [101, 2, 3]}
Example
>>> import pytreeclass as tc >>> @tc.autoinit ... class Tree(tc.TreeClass): ... a: int ... b: int >>> tree = Tree(a=1, b=2) >>> # apply to ``a`` and return a new instance >>> # with all other leaves unchanged >>> tree.at['a'].apply(lambda _: 100) Tree(a=100, b=2)
Example
>>> # read images in parallel >>> import pytreeclass as tc >>> from matplotlib.pyplot import imread >>> indexer = tc.AtIndexer({"lenna": "lenna.png", "baboon": "baboon.png"}) >>> images = indexer[...].apply(imread, is_parallel=dict(max_workers=2))
- scan(func, state, *, is_leaf=None)[source]#
Apply a function while carrying a state.
- Parameters:
func – the function to apply to the leaf values. The function accepts a leaf value and a running state, and returns a tuple of the new leaf value and the new state.
state – the initial state to carry.
is_leaf – a predicate function to determine if a value is a leaf. for example,
lambda x: isinstance(x, list)
will treat all lists as leaves and will not recurse into list items.
- Returns:
A tuple of the pytree (with the leaf values at the specified location set to the result of applying ``func`` to them) and the final state.
Example
>>> import pytreeclass as tc >>> tree = {"level1_0": {"level2_0": 100, "level2_1": 200}, "level1_1": 300} >>> def scan_func(leaf, state): ... return 'SET', state + 1 >>> init_state = 0 >>> indexer = tc.AtIndexer(tree) >>> indexer["level1_0"]["level2_0"].scan(scan_func, state=init_state) ({'level1_0': {'level2_0': 'SET', 'level2_1': 200}, 'level1_1': 300}, 1)
Example
>>> import pytreeclass as tc >>> from typing import NamedTuple >>> class State(NamedTuple): ... func_evals: int = 0 >>> @tc.autoinit ... class Tree(tc.TreeClass): ... a: int ... b: int ... c: int >>> tree = Tree(a=1, b=2, c=3) >>> def scan_func(leaf, state: State): ... state = State(state.func_evals + 1) ... return leaf + 1, state >>> # apply to ``a`` and ``b`` and return a new instance with all other >>> # leaves unchanged and the new state that counts the number of >>> # function evaluations >>> tree.at['a','b'].scan(scan_func, state=State()) (Tree(a=2, b=3, c=3), State(func_evals=2))
Note
``scan`` applies a binary ``func`` (taking a leaf value and a state) to the selected leaf values while carrying a state, and returns the tree with ``func`` applied to those leaves together with the final state. ``reduce``, by contrast, applies a binary ``func`` to the leaf values while carrying a state and returns a single value.
- reduce(func, *, initializer=<object object>, is_leaf=None)[source]#
Reduce the leaf values at the specified location.
- Parameters:
func (Callable[[Any, Any], Any]) – the function to reduce the leaf values.
initializer (Any) – the initializer value for the reduction.
is_leaf (Callable[[Any], bool] | None) – a predicate function to determine if a value is a leaf.
- Return type:
Any
- Returns:
The result of reducing the leaf values at the specified location.
Note
If ``initializer`` is not specified, the first leaf value is used as the initializer. ``reduce`` applies a binary ``func`` to each leaf value while accumulating a state and returns the final result, whereas ``scan`` applies ``func`` to each leaf value while carrying a state and returns the tree with ``func`` applied to each selected leaf, together with the final state.
Example
>>> import pytreeclass as tc >>> @tc.autoinit ... class Tree(tc.TreeClass): ... a: int ... b: int >>> tree = Tree(a=1, b=2) >>> tree.at[...].reduce(lambda a, b: a + b, initializer=0) 3
- class pytreeclass.BaseKey[source]#
Parent class for all match classes.
Subclass this class to create custom match keys by implementing the ``__eq__`` method. The ``__eq__`` method should return True if the key matches the given path entry and False otherwise. The path entry refers to the entry defined in the ``tree_flatten_with_keys`` method of the pytree class. Typical path entries in ``jax`` are:
- ``jax.tree_util.GetAttrKey`` for attributes
- ``jax.tree_util.DictKey`` for mapping keys
- ``jax.tree_util.SequenceKey`` for sequence indices
When implementing the ``__eq__`` method you can use ``functools.singledispatchmethod`` to unpack the path entry, for example:
- ``jax.tree_util.GetAttrKey`` -> ``key.name``
- ``jax.tree_util.DictKey`` -> ``key.key``
- ``jax.tree_util.SequenceKey`` -> ``key.index``
See Examples for more details.
Example
>>> # define a match strategy to match a leaf with a given name and type >>> import pytreeclass as tc >>> from typing import NamedTuple >>> import jax >>> class NameTypeContainer(NamedTuple): ... name: str ... type: type >>> @jax.tree_util.register_pytree_with_keys_class ... class Tree: ... def __init__(self, a, b) -> None: ... self.a = a ... self.b = b ... def tree_flatten_with_keys(self): ... ak = (NameTypeContainer("a", type(self.a)), self.a) ... bk = (NameTypeContainer("b", type(self.b)), self.b) ... return (ak, bk), None ... @classmethod ... def tree_unflatten(cls, aux_data, children): ... return cls(*children) ... @property ... def at(self): ... return tc.AtIndexer(self) >>> tree = Tree(1, 2) >>> class MatchNameType(tc.BaseKey): ... def __init__(self, name, type): ... self.name = name ... self.type = type ... def __eq__(self, other): ... if isinstance(other, NameTypeContainer): ... return other == (self.name, self.type) ... return False >>> tree = tree.at[MatchNameType("a", int)].get() >>> assert jax.tree_util.tree_leaves(tree) == [1]
Note
Use ``BaseKey.def_alias(type, func)`` to define an index type alias for BaseKey subclasses. This is useful for convenience when creating new match strategies.
>>> import pytreeclass as tc >>> import functools as ft >>> from types import FunctionType >>> import jax.tree_util as jtu >>> # let's define a new match strategy called `FuncKey` that applies >>> # a function to the path entry and returns True if the function >>> # returns True and False otherwise. >>> # for example `FuncKey(lambda x: x.startswith("a"))` will match >>> # all leaves whose key (e.g. attribute name) starts with "a". >>> class FuncKey(tc.BaseKey): ... def __init__(self, func): ... self.func = func ... @ft.singledispatchmethod ... def __eq__(self, key): ... return self.func(key) ... @__eq__.register(jtu.GetAttrKey) ... def _(self, key: jtu.GetAttrKey): ... # unpack the GetAttrKey ... return self.func(key.name) ... @__eq__.register(jtu.DictKey) ... def _(self, key: jtu.DictKey): ... # unpack the DictKey ... return self.func(key.key) ... @__eq__.register(jtu.SequenceKey) ... def _(self, key: jtu.SequenceKey): ... # unpack the SequenceKey ... return self.func(key.index) >>> # instead of using ``FuncKey(function)`` we can define an alias >>> # for `FuncKey`, for this example we will define any FunctionType >>> # as a `FuncKey` by default. >>> @tc.BaseKey.def_alias(FunctionType) ... def _(func): ... return FuncKey(func) >>> # create a simple pytree >>> @tc.autoinit ... class Tree(tc.TreeClass): ... a: int ... b: str >>> tree = Tree(1, "string") >>> # now we can use the `FuncKey` alias to match all leaves whose >>> # key (here, the attribute name) is a string that starts with "a" >>> tree.at[lambda x: isinstance(x, str) and x.startswith("a")].get() Tree(a=1, b=None)