➕ Advanced API
- pytreeclass.bcmap(func, *, is_leaf=None)
Map a function over pytree leaves with automatic broadcasting for scalar arguments.
- Parameters:
func (Callable) – the function to be mapped over the pytree.
is_leaf (Optional[Callable[[Any], bool]]) – a predicate function that returns True if the node is a leaf.
- Return type:
Callable
Example
>>> import jax
>>> import jax.numpy as jnp
>>> import pytreeclass as tc

>>> @tc.autoinit
... @tc.leafwise
... class Test(tc.TreeClass):
...     a: tuple[int, int, int] = (1, 2, 3)
...     b: tuple[int, int, int] = (4, 5, 6)
...     c: jax.Array = jnp.array([1, 2, 3])

>>> tree = Test()

>>> # 0 is broadcasted to all leaves of the pytree
>>> print(tc.bcmap(jnp.where)(tree > 1, tree, 0))
Test(a=(0, 2, 3), b=(4, 5, 6), c=[0 2 3])
>>> print(tc.bcmap(jnp.where)(tree > 1, 0, tree))
Test(a=(1, 0, 0), b=(0, 0, 0), c=[1 0 0])

>>> # 1 is broadcasted to all leaves of the list pytree
>>> tc.bcmap(lambda x, y: x + y)([1, 2, 3], 1)
[2, 3, 4]

>>> # trees are summed leaf-wise
>>> tc.bcmap(lambda x, y: x + y)([1, 2, 3], [1, 2, 3])
[2, 4, 6]

>>> # non-scalar second argument case
>>> try:
...     tc.bcmap(lambda x, y: x + y)([1, 2, 3], [[1, 2, 3], [1, 2, 3]])
... except TypeError as e:
...     print(e)
unsupported operand type(s) for +: 'int' and 'list'

>>> # using jax.numpy functions on pytrees
>>> tc.bcmap(jnp.add)([1, 2, 3], [1, 2, 3])
[2, 4, 6]
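To make the broadcasting rule concrete, here is a minimal stdlib-only sketch of the idea; it is not pytreeclass's implementation. It assumes that only lists, tuples, and dicts are containers, that the first argument fixes the tree structure, and that every non-container argument is a scalar to be broadcast. The name bcmap_sketch is hypothetical.

```python
from typing import Any, Callable


def bcmap_sketch(func: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical sketch of leaf-wise mapping with scalar broadcasting.

    Assumption: only list, tuple, and dict are containers; every other
    value is a leaf. This is NOT pytreeclass.bcmap.
    """

    def is_container(x: Any) -> bool:
        return isinstance(x, (list, tuple, dict))

    def mapped(tree: Any, *args: Any) -> Any:
        if isinstance(tree, dict):
            # index container args by the same key; broadcast scalar args as-is
            return {
                k: mapped(v, *(a[k] if is_container(a) else a for a in args))
                for k, v in tree.items()
            }
        if isinstance(tree, (list, tuple)):
            # index container args by position; broadcast scalar args as-is
            items = [
                mapped(v, *(a[i] if is_container(a) else a for a in args))
                for i, v in enumerate(tree)
            ]
            return type(tree)(items)  # rebuild a list or a plain tuple
        # leaf: apply func to the leaf and its matching (or broadcast) args
        return func(tree, *args)

    return mapped
```

For example, bcmap_sketch(lambda x, y: x + y)([1, 2, 3], 1) adds 1 to every leaf, mirroring the list example above.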
- class pytreeclass.Partial(func, *args, **kwargs)
jax-able Partial function with support for positional partial application.
Example
>>> import pytreeclass as tc
>>> def f(a, b, c):
...     print(f"a: {a}, b: {b}, c: {c}")
...     return a + b + c

>>> # positional arguments using `...` placeholder
>>> f_a = tc.Partial(f, ..., 2, 3)
>>> f_a(1)
a: 1, b: 2, c: 3
6

>>> # keyword arguments
>>> f_b = tc.Partial(f, b=2, c=3)
>>> f_b(1)
a: 1, b: 2, c: 3
6
Note
The ellipsis (...) is used as a placeholder for positional arguments. Partial() is used internally by bcmap(), which maps a function over pytree leaves with automatic broadcasting for scalar arguments.
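The placeholder mechanism described in the note can be illustrated with a small stdlib-only sketch. PlaceholderPartial is a hypothetical name and is not pytreeclass.Partial: it is not a pytree and makes no attempt to be jax-compatible. It assumes that call-time positional arguments fill the ... slots from left to right and that any leftovers are appended.

```python
from typing import Any, Callable


class PlaceholderPartial:
    """Hypothetical sketch of partial application with ``...`` placeholders.

    Not pytreeclass.Partial: no pytree registration, no jax support.
    """

    def __init__(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        remaining = iter(args)
        # replace each ``...`` with the next call-time positional argument
        filled = [next(remaining) if a is ... else a for a in self.args]
        # append any call-time positional arguments that are left over
        filled.extend(remaining)
        # call-time keyword arguments override the stored ones
        return self.func(*filled, **{**self.kwargs, **kwargs})
```

With def f(a, b, c): return a + b + c, both PlaceholderPartial(f, ..., 2, 3)(1) and PlaceholderPartial(f, 1, ..., 3)(2) evaluate to 6.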
- class pytreeclass.AtIndexer(tree: PyTree, where: tuple[BaseKey | PyTree] | tuple[()] = ())
Index a pytree at a given location using a path or a mask.
- Parameters:
tree – pytree to index.
where – one of the following:
  - str for mapping keys or class attributes.
  - int for positional indexing for sequences.
  - ... to select all leaves.
  - a boolean mask of the same structure as the tree.
  - re.Pattern to index all keys matching a regex pattern.
  - an instance of BaseKey with custom logic to index a pytree.
  - a tuple of the above to match multiple keys at the same level.
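As a rough illustration of how a single where entry could be tested against one key at a given level, here is a hypothetical matches helper. It covers only the str, int, ..., re.Pattern, and tuple cases, assumes full-match semantics for regular expressions, and is not how AtIndexer is implemented.

```python
import re
from typing import Any


def matches(where: Any, key: Any) -> bool:
    """Hypothetical sketch: does one ``where`` entry select a given key?

    ``key`` is a plain dict key / attribute name (str) or a sequence index
    (int). Boolean masks and BaseKey instances are left out of this sketch.
    """
    if where is ...:
        return True  # Ellipsis selects every key at this level
    if isinstance(where, re.Pattern):
        # assumption: the regex must match the whole string key
        return isinstance(key, str) and where.fullmatch(key) is not None
    if isinstance(where, tuple):
        # a tuple selects a key if any of its members does (same level)
        return any(matches(w, key) for w in where)
    return where == key  # str / int: exact match
```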
Example
>>> # use `AtIndexer` on a pytree (e.g. dict, list, tuple, etc.)
>>> import jax
>>> import pytreeclass as tc
>>> tree = {"level1_0": {"level2_0": 100, "level2_1": 200}, "level1_1": 300}
>>> indexer = tc.AtIndexer(tree)
>>> indexer["level1_0"]["level2_0"].get()
{'level1_0': {'level2_0': 100, 'level2_1': None}, 'level1_1': None}
>>> # get multiple keys at once at the same level
>>> indexer["level1_0"]["level2_0", "level2_1"].get()
{'level1_0': {'level2_0': 100, 'level2_1': 200}, 'level1_1': None}
>>> # get with a mask
>>> mask = {"level1_0": {"level2_0": True, "level2_1": False}, "level1_1": True}
>>> indexer[mask].get()
{'level1_0': {'level2_0': 100, 'level2_1': None}, 'level1_1': 300}
Example
>>> # use ``AtIndexer`` in a class
>>> import jax.tree_util as jtu
>>> import pytreeclass as tc
>>> @jtu.register_pytree_with_keys_class
... class Tree:
...     def __init__(self, a, b):
...         self.a = a
...         self.b = b
...     def tree_flatten_with_keys(self):
...         kva = (jtu.GetAttrKey("a"), self.a)
...         kvb = (jtu.GetAttrKey("b"), self.b)
...         return (kva, kvb), None
...     @classmethod
...     def tree_unflatten(cls, aux_data, children):
...         return cls(*children)
...     @property
...     def at(self):
...         return tc.AtIndexer(self)
...     def __repr__(self) -> str:
...         return f"{self.__class__.__name__}(a={self.a}, b={self.b})"
>>> Tree(1, 2).at["a"].get()
Tree(a=1, b=None)
- get(*, is_leaf=None)
Get the leaf values at the specified location.
- Parameters:
is_leaf (Optional[Callable[[Any], bool]]) – a predicate function to determine if a value is a leaf.
- Return type:
Any
- Returns:
A new pytree of leaf values at the specified location, with the non-selected leaf values set to None if the leaf is not an array.
Example
>>> import pytreeclass as tc
>>> tree = {"level1_0": {"level2_0": 100, "level2_1": 200}, "level1_1": 300}
>>> indexer = tc.AtIndexer(tree)
>>> indexer["level1_0"]["level2_0"].get()
{'level1_0': {'level2_0': 100, 'level2_1': None}, 'level1_1': None}
Example
>>> import pytreeclass as tc
>>> @tc.autoinit
... class Tree(tc.TreeClass):
...     a: int
...     b: int
>>> tree = Tree(a=1, b=2)
>>> # get ``a`` and return a new instance
>>> # with ``None`` for all other leaves
>>> tree.at['a'].get()
Tree(a=1, b=None)
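The mask semantics of get on nested dicts can be sketched in plain Python. masked_get is a hypothetical helper, not the library's code. It assumes the mask mirrors the tree, that a single boolean at a container position applies to that whole subtree, and it ignores the array special case mentioned in the return description.

```python
from typing import Any


def masked_get(tree: Any, mask: Any) -> Any:
    """Hypothetical sketch of mask-based ``get`` on nested dicts.

    A True leaf keeps its value, a False leaf becomes None, and a single
    boolean at a dict position applies to that whole subtree.
    """
    if isinstance(mask, bool):
        if isinstance(tree, dict):
            # broadcast one boolean over every leaf of the subtree
            return {k: masked_get(v, mask) for k, v in tree.items()}
        return tree if mask else None
    # mask is a dict with the same keys as tree: recurse key by key
    return {k: masked_get(tree[k], mask[k]) for k in tree}
```

Applied to the tree and mask from the AtIndexer example, it reproduces the {'level1_0': {'level2_0': 100, 'level2_1': None}, 'level1_1': 300} result shown there.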
- set(set_value, *, is_leaf=None)
Set the leaf values at the specified location.
- Parameters:
set_value (Any) – the value to set at the specified location.
is_leaf (Optional[Callable[[Any], bool]]) – a predicate function to determine if a value is a leaf.
- Returns:
A pytree with the leaf values at the specified location set to set_value.
Example
>>> import pytreeclass as tc
>>> tree = {"level1_0": {"level2_0": 100, "level2_1": 200}, "level1_1": 300}
>>> indexer = tc.AtIndexer(tree)
>>> indexer["level1_0"]["level2_0"].set('SET')
{'level1_0': {'level2_0': 'SET', 'level2_1': 200}, 'level1_1': 300}
Example
>>> import pytreeclass as tc
>>> @tc.autoinit
... class Tree(tc.TreeClass):
...     a: int
...     b: int
>>> tree = Tree(a=1, b=2)
>>> # set ``a`` and return a new instance
>>> # with all other leaves unchanged
>>> tree.at['a'].set(100)
Tree(a=100, b=2)
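The key property of set, that a new tree is returned and the input is left untouched, can be sketched in plain Python. masked_set is hypothetical and not the library's code; it assumes the selection is expressed as a boolean mask that mirrors a tree of nested dicts.

```python
from typing import Any


def masked_set(tree: Any, mask: Any, value: Any) -> Any:
    """Hypothetical sketch of mask-based ``set`` on nested dicts.

    Selected leaves are replaced by ``value``; all other leaves are kept.
    A fresh dict is built at every level, so ``tree`` is never mutated.
    """
    if isinstance(mask, bool):
        if isinstance(tree, dict):
            # broadcast one boolean over every leaf of the subtree
            return {k: masked_set(v, mask, value) for k, v in tree.items()}
        return value if mask else tree
    # mask is a dict with the same keys as tree: recurse key by key
    return {k: masked_set(tree[k], mask[k], value) for k in tree}
```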
- apply(func, *, is_leaf=None, parallel=None)
Apply a function to the leaf values at the specified location.
- Parameters:
func (Callable[[Any], Any]) – the function to apply to the leaf values.
is_leaf (IsLeafType) – a predicate function to determine if a value is a leaf.
parallel (ParallelApplyKwargs | bool | None) –
accepts the following:
  - None: apply func to the leaves in serial.
  - bool: apply func in parallel if True, otherwise in serial.
  - dict: a dict of:
    - max_workers: maximum number of workers to use.
    - callback: a function to apply to the result of func.
    - kind: kind of pool to use, either "thread" or "process".
- Returns:
A pytree with the leaf values at the specified location set to the result of applying func to the leaf values.
Example
>>> import pytreeclass as tc
>>> tree = {"level1_0": {"level2_0": 100, "level2_1": 200}, "level1_1": 300}
>>> indexer = tc.AtIndexer(tree)
>>> indexer["level1_0"]["level2_0"].apply(lambda _: 'SET')
{'level1_0': {'level2_0': 'SET', 'level2_1': 200}, 'level1_1': 300}
Example
>>> import pytreeclass as tc
>>> @tc.autoinit
... class Tree(tc.TreeClass):
...     a: int
...     b: int
>>> tree = Tree(a=1, b=2)
>>> # apply to ``a`` and return a new instance
>>> # with all other leaves unchanged
>>> tree.at['a'].apply(lambda _: 100)
Tree(a=100, b=2)
Example
>>> # read images in parallel
>>> import pytreeclass as tc
>>> from matplotlib.pyplot import imread
>>> indexer = tc.AtIndexer({"lenna": "lenna.png", "baboon": "baboon.png"})
>>> images = indexer[...].apply(imread, parallel=dict(max_workers=2))
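The general idea behind parallel=dict(max_workers=...) can be pictured with concurrent.futures. This is an assumption about the concept only: apply_parallel is a hypothetical helper that handles a flat dict, always uses threads, and ignores the callback and kind options that the real parameter accepts.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict


def apply_parallel(
    tree: Dict[str, Any], func: Callable[[Any], Any], max_workers: int = 2
) -> Dict[str, Any]:
    """Hypothetical sketch: apply ``func`` to the values of a flat dict in a thread pool."""
    keys = list(tree)
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map yields results in input order, so they line up with keys
        results = list(pool.map(func, (tree[k] for k in keys)))
    return dict(zip(keys, results))
```

Threads suit I/O-bound work such as reading image files; CPU-bound Python functions would usually need a process pool instead, which is presumably why a "process" kind is offered.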
- scan(func, state, *, is_leaf=None)
Apply a function while carrying a state.
- Parameters:
func – the function to apply to the leaf values. It accepts a leaf value and the running state, and returns a tuple of the new leaf value and the new state.
state – the initial state to carry.
is_leaf – a predicate function to determine if a value is a leaf. For example, lambda x: isinstance(x, list) will treat all lists as leaves and will not recurse into list items.
- Returns:
A tuple of the pytree with the leaf values at the specified location set to the result of applying func to the leaf values, and the final state.
Example
>>> import pytreeclass as tc
>>> tree = {"level1_0": {"level2_0": 100, "level2_1": 200}, "level1_1": 300}
>>> def scan_func(leaf, state):
...     return 'SET', state + 1
>>> init_state = 0
>>> indexer = tc.AtIndexer(tree)
>>> indexer["level1_0"]["level2_0"].scan(scan_func, state=init_state)
({'level1_0': {'level2_0': 'SET', 'level2_1': 200}, 'level1_1': 300}, 1)
Example
>>> import pytreeclass as tc
>>> from typing import NamedTuple
>>> class State(NamedTuple):
...     func_evals: int = 0
>>> @tc.autoinit
... class Tree(tc.TreeClass):
...     a: int
...     b: int
...     c: int
>>> tree = Tree(a=1, b=2, c=3)
>>> def scan_func(leaf, state: State):
...     state = State(state.func_evals + 1)
...     return leaf + 1, state
>>> # apply to ``a`` and ``b`` and return a new instance with all other
>>> # leaves unchanged and the new state that counts the number of
>>> # function evaluations
>>> tree.at['a','b'].scan(scan_func, state=State())
(Tree(a=2, b=3, c=3), State(func_evals=2))
Note
scan applies a binary func to the leaf values while carrying a state, and returns the tree with func applied to the selected leaves together with the final state. In contrast, reduce applies a binary func to the leaf values while carrying a state and returns a single value.
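The scan contract can be sketched over a flat list of leaves. scan_leaves is a hypothetical helper, not the library's code; it assumes func has the (leaf, state) -> (new_leaf, new_state) signature used in the examples above and returns the new leaves before the final state, matching the order of AtIndexer.scan's result.

```python
from typing import Any, Callable, List, Tuple


def scan_leaves(
    func: Callable[[Any, Any], Tuple[Any, Any]], leaves: List[Any], state: Any
) -> Tuple[List[Any], Any]:
    """Hypothetical sketch of ``scan`` over a flat list of leaves."""
    new_leaves = []
    for leaf in leaves:
        # func returns the transformed leaf and the state for the next leaf
        leaf, state = func(leaf, state)
        new_leaves.append(leaf)
    return new_leaves, state
```

Counting function evaluations as in the State example corresponds to scan_leaves(lambda leaf, s: (leaf + 1, s + 1), [1, 2], 0), which yields ([2, 3], 2).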
- reduce(func, *, initializer=<object object>, is_leaf=None)
Reduce the leaf values at the specified location.
- Parameters:
func (Callable[[Any, Any], Any]) – the function to reduce the leaf values.
initializer (Any) – the initializer value for the reduction.
is_leaf (Optional[Callable[[Any], bool]]) – a predicate function to determine if a value is a leaf.
- Return type:
Any
- Returns:
The result of reducing the leaf values at the specified location.
Note
If initializer is not specified, the first leaf value is used as the initializer. reduce applies a binary func to each leaf value while accumulating a state and returns the final result, whereas scan applies func to each leaf value while carrying a state and returns the leaves of the tree with the result of applying func to each leaf, together with the final state.
Example
>>> import pytreeclass as tc
>>> @tc.autoinit
... class Tree(tc.TreeClass):
...     a: int
...     b: int
>>> tree = Tree(a=1, b=2)
>>> tree.at[...].reduce(lambda a, b: a + b, initializer=0)
3
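The initializer=<object object> default in the signature suggests a sentinel object used to detect a missing initializer, so that None remains a valid initializer. A minimal sketch of that pattern over a flat list of leaves, built on functools.reduce, could look like this; reduce_leaves and _MISSING are hypothetical names, not the library's internals.

```python
import functools
from typing import Any, Callable, List

# hypothetical sentinel: distinguishes "no initializer" from initializer=None
_MISSING = object()


def reduce_leaves(
    func: Callable[[Any, Any], Any], leaves: List[Any], initializer: Any = _MISSING
) -> Any:
    """Hypothetical sketch of ``reduce`` over a flat list of leaves."""
    if initializer is _MISSING:
        # no initializer: the first leaf seeds the reduction
        return functools.reduce(func, leaves)
    return functools.reduce(func, leaves, initializer)
```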
- class pytreeclass.BaseKey
Parent class for all match classes.
Subclass this class to create custom match keys by implementing the __eq__ method. The __eq__ method should return True if the key matches the given path entry and False otherwise. The path entry refers to the entry defined in the tree_flatten_with_keys method of the pytree class. Typical path entries are:
  - jax.tree_util.GetAttrKey for attributes
  - jax.tree_util.DictKey for mapping keys
  - jax.tree_util.SequenceKey for sequence indices
When implementing the __eq__ method, you can use functools.singledispatchmethod to unpack the path entry, for example:
  - jax.tree_util.GetAttrKey -> key.name
  - jax.tree_util.DictKey -> key.key
  - jax.tree_util.SequenceKey -> key.idx
See Examples for more details.
Example
>>> # define a match strategy to match a leaf with a given name and type
>>> import pytreeclass as tc
>>> from typing import NamedTuple
>>> import jax
>>> class NameTypeContainer(NamedTuple):
...     name: str
...     type: type
>>> @jax.tree_util.register_pytree_with_keys_class
... class Tree:
...     def __init__(self, a, b) -> None:
...         self.a = a
...         self.b = b
...     def tree_flatten_with_keys(self):
...         ak = (NameTypeContainer("a", type(self.a)), self.a)
...         bk = (NameTypeContainer("b", type(self.b)), self.b)
...         return (ak, bk), None
...     @classmethod
...     def tree_unflatten(cls, aux_data, children):
...         return cls(*children)
...     @property
...     def at(self):
...         return tc.AtIndexer(self)
>>> tree = Tree(1, 2)
>>> class MatchNameType(tc.BaseKey):
...     def __init__(self, name, type):
...         self.name = name
...         self.type = type
...     def __eq__(self, other):
...         if isinstance(other, NameTypeContainer):
...             return other == (self.name, self.type)
...         return False
>>> tree = tree.at[MatchNameType("a", int)].get()
>>> assert jax.tree_util.tree_leaves(tree) == [1]
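Stripped of jax, the matching rule amounts to an equality test between the key object and each path entry. In the sketch below, AttrEntry is a hypothetical stand-in for jax.tree_util.GetAttrKey, StartsWith is a hypothetical BaseKey-like matcher, and select only mimics the None-filling behavior of get on a flat list of (entry, leaf) pairs; none of this is pytreeclass code.

```python
from dataclasses import dataclass
from typing import Any, List, Tuple


@dataclass(frozen=True)
class AttrEntry:
    """Hypothetical stand-in for an attribute path entry (like GetAttrKey)."""

    name: str


class StartsWith:
    """Hypothetical matcher: equal to any AttrEntry whose name has a given prefix."""

    def __init__(self, prefix: str) -> None:
        self.prefix = prefix

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, AttrEntry):
            return other.name.startswith(self.prefix)
        return False  # anything else never matches


def select(entries: List[Tuple[AttrEntry, Any]], matcher: Any) -> List[Any]:
    # keep a leaf if its path entry compares equal to the matcher, else None
    return [leaf if matcher == entry else None for entry, leaf in entries]
```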
Note
Use BaseKey.def_alias(type, func) to define an index type alias for BaseKey subclasses. This is a convenience when creating new match strategies.
>>> import pytreeclass as tc
>>> import functools as ft
>>> from types import FunctionType
>>> import jax.tree_util as jtu
>>> # let's define a new match strategy called `FuncKey` that applies
>>> # a function to the path entry and returns True if the function
>>> # returns True and False otherwise.
>>> # for example `FuncKey(lambda x: x.startswith("a"))` will match
>>> # all leaves whose key starts with "a".
>>> class FuncKey(tc.BaseKey):
...     def __init__(self, func):
...         self.func = func
...     @ft.singledispatchmethod
...     def __eq__(self, key):
...         return self.func(key)
...     @__eq__.register(jtu.GetAttrKey)
...     def _(self, key: jtu.GetAttrKey):
...         # unpack the GetAttrKey
...         return self.func(key.name)
...     @__eq__.register(jtu.DictKey)
...     def _(self, key: jtu.DictKey):
...         # unpack the DictKey
...         return self.func(key.key)
...     @__eq__.register(jtu.SequenceKey)
...     def _(self, key: jtu.SequenceKey):
...         # unpack the SequenceKey
...         return self.func(key.idx)
>>> # instead of using ``FuncKey(function)`` we can define an alias
>>> # for `FuncKey`; for this example we will treat any FunctionType
>>> # as a `FuncKey` by default.
>>> @tc.BaseKey.def_alias(FunctionType)
... def _(func):
...     return FuncKey(func)
>>> # create a simple pytree
>>> @tc.autoinit
... class Tree(tc.TreeClass):
...     a: int
...     b: str
>>> tree = Tree(1, "string")
>>> # now we can use the `FuncKey` alias to match all leaves whose
>>> # key (attribute name) is a string starting with "a"
>>> tree.at[lambda x: isinstance(x, str) and x.startswith("a")].get()
Tree(a=1, b=None)
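One way to think about the alias mechanism is as a registry from index types to factory functions that turn a raw index value into a matcher. The following stdlib-only sketch assumes that shape; _ALIASES, def_alias, resolve, and PredicateKey are hypothetical and do not reproduce BaseKey.def_alias, whose exact signature and lookup rules may differ.

```python
from types import FunctionType
from typing import Any, Callable, Dict

# hypothetical registry: index type -> factory that builds a matcher
_ALIASES: Dict[type, Callable[[Any], Any]] = {}


def def_alias(index_type: type) -> Callable[[Callable[[Any], Any]], Callable[[Any], Any]]:
    """Hypothetical decorator: register a factory for values of ``index_type``."""

    def decorator(factory: Callable[[Any], Any]) -> Callable[[Any], Any]:
        _ALIASES[index_type] = factory
        return factory

    return decorator


def resolve(where: Any) -> Any:
    # walk the MRO so subclasses of a registered type also resolve
    for cls in type(where).__mro__:
        if cls in _ALIASES:
            return _ALIASES[cls](where)
    return where  # unregistered values are used as-is


class PredicateKey:
    """Hypothetical matcher that calls ``func`` on a plain key name."""

    def __init__(self, func: Callable[[Any], bool]) -> None:
        self.func = func

    def __eq__(self, key: Any) -> bool:
        return bool(self.func(key))


@def_alias(FunctionType)
def _(func: Callable[[Any], bool]) -> PredicateKey:
    return PredicateKey(func)
```

With this registry, resolve(lambda k: k.startswith("a")) returns a PredicateKey, while a plain string such as "a" passes through unchanged.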