➕ Advanced API

pytreeclass.bcmap(func, *, is_leaf=None)

Map a function over pytree leaves with automatic broadcasting for scalar arguments.

Parameters:
  • func (Callable[P, T]) – the function to be mapped over the pytree

  • is_leaf (Callable[[Any], bool] | None) – a predicate function that returns True if the node is a leaf

Return type:

Callable[P, T]

Example

>>> import jax
>>> import jax.numpy as jnp
>>> import pytreeclass as tc
>>> @tc.autoinit
... @tc.leafwise
... class Test(tc.TreeClass):
...    a: tuple[int, int, int] = (1, 2, 3)
...    b: tuple[int, int, int] = (4, 5, 6)
...    c: jax.Array = jnp.array([1, 2, 3])
>>> tree = Test()
>>> # 0 is broadcasted to all leaves of the pytree
>>> print(tc.bcmap(jnp.where)(tree > 1, tree, 0))
Test(a=(0, 2, 3), b=(4, 5, 6), c=[0 2 3])
>>> print(tc.bcmap(jnp.where)(tree > 1, 0, tree))
Test(a=(1, 0, 0), b=(0, 0, 0), c=[1 0 0])
>>> # 1 is broadcasted to all leaves of the list pytree
>>> tc.bcmap(lambda x, y: x + y)([1, 2, 3], 1)
[2, 3, 4]
>>> # trees are summed leaf-wise
>>> tc.bcmap(lambda x, y: x + y)([1, 2, 3], [1, 2, 3])
[2, 4, 6]
>>> # non-scalar second argument case
>>> try:
...     tc.bcmap(lambda x, y: x + y)([1, 2, 3], [[1, 2, 3], [1, 2, 3]])
... except TypeError as e:
...     print(e)
unsupported operand type(s) for +: 'int' and 'list'
>>> # using ``jax.numpy`` functions on pytrees
>>> import jax.numpy as jnp
>>> tc.bcmap(jnp.add)([1, 2, 3], [1, 2, 3]) 
[2, 4, 6]
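To make the broadcasting rule concrete, here is a minimal pure-Python sketch. It is not pytreeclass's implementation: `bcmap_sketch` and `_is_container` are hypothetical names, only lists, tuples and dicts are treated as pytree nodes, and the first argument is assumed to define the tree structure. Container arguments are walked in lockstep with it, and any other argument is broadcast to every leaf.

```python
from typing import Any, Callable


def _is_container(x: Any) -> bool:
    # Lists, tuples and dicts are treated as pytree nodes; anything else is a leaf.
    return isinstance(x, (list, tuple, dict))


def bcmap_sketch(func: Callable[..., Any]) -> Callable[..., Any]:
    """Map ``func`` over the leaves of the first argument.

    Extra arguments that are containers are walked in lockstep with the first
    argument; anything else (e.g. a scalar) is broadcast to every leaf.
    """

    def wrapper(tree: Any, *args: Any) -> Any:
        if isinstance(tree, dict):
            return {
                k: wrapper(v, *[a[k] if _is_container(a) else a for a in args])
                for k, v in tree.items()
            }
        if isinstance(tree, (list, tuple)):
            return type(tree)(
                wrapper(v, *[a[i] if _is_container(a) else a for a in args])
                for i, v in enumerate(tree)
            )
        # Leaf: call ``func`` with the matching leaves and broadcast scalars.
        return func(tree, *args)

    return wrapper


add = bcmap_sketch(lambda x, y: x + y)
print(add([1, 2, 3], 1))          # scalar 1 is broadcast: [2, 3, 4]
print(add([1, 2, 3], [1, 2, 3]))  # leaf-wise sum: [2, 4, 6]
```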
class pytreeclass.Partial(func, *args, **kwargs)

Partial function with support for positional partial application.

Parameters:
  • func (Callable[..., Any]) – The function to be partially applied.

  • args (Any) – Positional arguments to be partially applied. Use ... as a placeholder for a positional argument that is supplied at call time.

  • kwargs (Any) – Keyword arguments to be partially applied.

Example

>>> import pytreeclass as tc
>>> def f(a, b, c):
...     print(f"a: {a}, b: {b}, c: {c}")
...     return a + b + c
>>> # positional arguments using `...` placeholder
>>> f_a = tc.Partial(f, ..., 2, 3)
>>> f_a(1)
a: 1, b: 2, c: 3
6
>>> # keyword arguments
>>> f_b = tc.Partial(f, b=2, c=3)
>>> f_b(1)
a: 1, b: 2, c: 3
6
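The `...` placeholder rule can be sketched in a few lines of plain Python. `PartialSketch` is a hypothetical stand-in, not `tc.Partial` itself; it assumes placeholders are filled left to right by call-time positional arguments, that leftover positional arguments are appended, and that call-time keywords override stored ones.

```python
from typing import Any, Callable


class PartialSketch:
    """Partial application where ``...`` reserves a positional slot."""

    def __init__(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        remaining = iter(args)
        filled = []
        for arg in self.args:
            if arg is Ellipsis:
                # Fill the placeholder with the next call-time positional argument.
                try:
                    filled.append(next(remaining))
                except StopIteration:
                    raise TypeError("not enough arguments to fill the `...` placeholders") from None
            else:
                filled.append(arg)
        # Leftover call-time positional arguments go after the stored ones.
        filled.extend(remaining)
        # Call-time keyword arguments override the stored ones.
        return self.func(*filled, **{**self.kwargs, **kwargs})


def f(a, b, c):
    return a + b + c


print(PartialSketch(f, ..., 2, 3)(1))  # a=1 fills the placeholder: 6
print(PartialSketch(f, 1, ..., 3)(2))  # placeholder in the middle: 6
```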


class pytreeclass.AtIndexer(tree: PyTree, where: tuple[BaseKey | PyTree] | tuple[()] = ())

Index a pytree at a given location using a path or a mask.

Parameters:
  • tree – pytree to index

  • where

    one of the following:

    • str for mapping keys or class attributes.

    • int for positional indexing for sequences.

    • ... to select all leaves.

    • a boolean mask of the same structure as the tree

    • re.Pattern to index all keys matching a regex pattern.

    • an instance of BaseKey with custom logic to index a pytree.

    • a tuple of the above to match multiple keys at the same level.
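The matching rules for a single tree level can be sketched as a plain predicate. `matches` is a hypothetical helper, boolean masks and `BaseKey` instances are left out, and the use of `re.fullmatch` for patterns is an assumption (the library may match patterns differently).

```python
import re
from typing import Any


def matches(where: Any, key: Any) -> bool:
    """Return True if one ``where`` entry selects ``key`` at a single tree level."""
    if where is Ellipsis:
        return True  # ``...`` selects every key
    if isinstance(where, tuple):
        # a tuple selects any of several keys at the same level
        return any(matches(w, key) for w in where)
    if isinstance(where, re.Pattern):
        # assumption: the whole key must match the pattern
        return isinstance(key, str) and where.fullmatch(key) is not None
    # str (mapping key or attribute name) or int (sequence index)
    return where == key


tree = {"level1_0": {"level2_0": 100, "level2_1": 200}, "level1_1": 300}
print([k for k in tree if matches(re.compile(r"level1_\d"), k)])  # ['level1_0', 'level1_1']
print([k for k in tree["level1_0"] if matches(("level2_1",), k)])  # ['level2_1']
```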

Example

>>> # use `AtIndexer` on a pytree (e.g. dict, list, tuple, etc.)
>>> import jax
>>> import pytreeclass as tc
>>> tree = {"level1_0": {"level2_0": 100, "level2_1": 200}, "level1_1": 300}
>>> indexer = tc.AtIndexer(tree)
>>> indexer["level1_0"]["level2_0"].get()
{'level1_0': {'level2_0': 100, 'level2_1': None}, 'level1_1': None}
>>> # get multiple keys at once at the same level
>>> indexer["level1_0"]["level2_0", "level2_1"].get()
{'level1_0': {'level2_0': 100, 'level2_1': 200}, 'level1_1': None}
>>> # get with a mask
>>> mask = {"level1_0": {"level2_0": True, "level2_1": False}, "level1_1": True}
>>> indexer[mask].get()
{'level1_0': {'level2_0': 100, 'level2_1': None}, 'level1_1': 300}

Example

>>> # use ``AtIndexer`` in a class
>>> import jax
>>> import jax.tree_util as jtu
>>> import pytreeclass as tc
>>> @jtu.register_pytree_with_keys_class
... class Tree:
...    def __init__(self, a, b):
...        self.a = a
...        self.b = b
...    def tree_flatten_with_keys(self):
...        kva = (jtu.GetAttrKey("a"), self.a)
...        kvb = (jtu.GetAttrKey("b"), self.b)
...        return (kva, kvb), None
...    @classmethod
...    def tree_unflatten(cls, aux_data, children):
...        return cls(*children)
...    @property
...    def at(self):
...        return tc.AtIndexer(self)
...    def __repr__(self) -> str:
...        return f"{self.__class__.__name__}(a={self.a}, b={self.b})"
>>> Tree(1, 2).at["a"].get()
Tree(a=1, b=None)
get(*, is_leaf=None, is_parallel=False)

Get the leaf values at the specified location.

Parameters:
  • is_leaf (Callable[[Any], bool] | None) – a predicate function to determine if a value is a leaf.

  • is_parallel (bool | ParallelConfig) –

    accepts the following:

    • bool: run in parallel if True, otherwise serially.

    • dict: a dict with the following keys:

      • max_workers: maximum number of workers to use.

      • kind: kind of pool to use, either thread or process.

Return type:

PyTree

Returns:

A new pytree of leaf values at the specified location, with the non-selected leaf values set to None if the leaf is not an array.

Example

>>> import pytreeclass as tc
>>> tree = {"a": 1, "b": [1, 2, 3]}
>>> indexer = tc.AtIndexer(tree)  # construct an indexer
>>> indexer["b"][0].get()  # get the first element of "b"
{'a': None, 'b': [1, None, None]}

Example

>>> import pytreeclass as tc
>>> @tc.autoinit
... class Tree(tc.TreeClass):
...     a: int
...     b: int
>>> tree = Tree(a=1, b=2)
>>> # get ``a`` and return a new instance
>>> # with ``None`` for all other leaves
>>> tree.at['a'].get()
Tree(a=1, b=None)
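The `get` behaviour for non-array leaves can be sketched with a mask of the same structure as the tree. `get_with_mask` is a hypothetical helper: it handles only dicts, lists and tuples, assumes the mask mirrors the tree exactly, and does not model the array case.

```python
from typing import Any


def get_with_mask(tree: Any, mask: Any) -> Any:
    """Return a new tree: leaves where ``mask`` is True are kept, others become None.

    ``mask`` is assumed to have exactly the same nested structure as ``tree``.
    """
    if isinstance(tree, dict):
        return {k: get_with_mask(v, mask[k]) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(get_with_mask(v, m) for v, m in zip(tree, mask))
    return tree if mask else None


tree = {"a": 1, "b": [1, 2, 3]}
mask = {"a": False, "b": [True, False, False]}
print(get_with_mask(tree, mask))  # {'a': None, 'b': [1, None, None]}
```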
set(set_value, *, is_leaf=None, is_parallel=False)

Set the leaf values at the specified location.

Parameters:
  • set_value (Any) – the value to set at the specified location.

  • is_leaf (Callable[[Any], bool] | None) – a predicate function to determine if a value is a leaf.

  • is_parallel (bool | ParallelConfig) –

    accepts the following:

    • bool: run in parallel if True, otherwise serially.

    • dict: a dict with the following keys:

      • max_workers: maximum number of workers to use.

      • kind: kind of pool to use, either thread or process.

Return type:

PyTree

Returns:

A pytree with the leaf values at the specified location set to set_value.

Example

>>> import pytreeclass as tc
>>> tree = {"a": 1, "b": [1, 2, 3]}
>>> indexer = tc.AtIndexer(tree)
>>> indexer["b"][0].set(100)  # set the first element of "b" to 100
{'a': 1, 'b': [100, 2, 3]}

Example

>>> import pytreeclass as tc
>>> @tc.autoinit
... class Tree(tc.TreeClass):
...     a: int
...     b: int
>>> tree = Tree(a=1, b=2)
>>> # set ``a`` and return a new instance
>>> # with all other leaves unchanged
>>> tree.at['a'].set(100)
Tree(a=100, b=2)
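`set` and `apply` can be sketched the same way, assuming a boolean mask with exactly the same structure as the tree. `apply_with_mask` and `set_with_mask` are hypothetical helpers for dicts, lists and tuples only: selected leaves are transformed, unselected leaves are copied over unchanged, and `set` is expressed as `apply` with a constant function.

```python
from typing import Any, Callable


def apply_with_mask(tree: Any, mask: Any, func: Callable[[Any], Any]) -> Any:
    """Return a new tree with ``func`` applied where ``mask`` is True.

    Unselected leaves are copied over unchanged; ``mask`` must mirror ``tree``.
    """
    if isinstance(tree, dict):
        return {k: apply_with_mask(v, mask[k], func) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(apply_with_mask(v, m, func) for v, m in zip(tree, mask))
    return func(tree) if mask else tree


def set_with_mask(tree: Any, mask: Any, value: Any) -> Any:
    # ``set`` is ``apply`` with a function that ignores the old leaf.
    return apply_with_mask(tree, mask, lambda _: value)


tree = {"a": 1, "b": [1, 2, 3]}
mask = {"a": False, "b": [True, False, False]}
print(set_with_mask(tree, mask, 100))                  # {'a': 1, 'b': [100, 2, 3]}
print(apply_with_mask(tree, mask, lambda x: x + 100))  # {'a': 1, 'b': [101, 2, 3]}
```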
apply(func, *, is_leaf=None, is_parallel=False)

Apply a function to the leaf values at the specified location.

Parameters:
  • func (Callable[[Any], Any]) – the function to apply to the leaf values.

  • is_leaf (Callable[[Any], bool] | None) – a predicate function to determine if a value is a leaf.

  • is_parallel (bool | ParallelConfig) –

    accepts the following:

    • bool: apply func in parallel if True, otherwise serially.

    • dict: a dict with the following keys:

      • max_workers: maximum number of workers to use.

      • kind: kind of pool to use, either thread or process.

Return type:

PyTree

Returns:

A pytree with the leaf values at the specified location set to the result of applying func to the leaf values.

Example

>>> import pytreeclass as tc
>>> tree = {"a": 1, "b": [1, 2, 3]}
>>> indexer = tc.AtIndexer(tree)
>>> indexer["b"][0].apply(lambda x: x + 100)  # add 100 to the first element of "b"
{'a': 1, 'b': [101, 2, 3]}

Example

>>> import pytreeclass as tc
>>> @tc.autoinit
... class Tree(tc.TreeClass):
...     a: int
...     b: int
>>> tree = Tree(a=1, b=2)
>>> # apply to ``a`` and return a new instance
>>> # with all other leaves unchanged
>>> tree.at['a'].apply(lambda _: 100)
Tree(a=100, b=2)

Example

>>> # read images in parallel
>>> import pytreeclass as tc
>>> from matplotlib.pyplot import imread
>>> indexer = tc.AtIndexer({"lenna": "lenna.png", "baboon": "baboon.png"})
>>> images = indexer[...].apply(imread, is_parallel=dict(max_workers=2))  
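One plausible reading of `is_parallel=dict(max_workers=..., kind=...)` is a `concurrent.futures` pool over the selected leaves. The sketch below assumes that mapping; `parallel_apply` is hypothetical, works on a flat dict only, and uses `str.upper` in place of `imread` so it runs without image files.

```python
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from typing import Any, Callable, Dict, Optional


def parallel_apply(
    func: Callable[[Any], Any],
    leaves: Dict[str, Any],
    *,
    max_workers: Optional[int] = None,
    kind: str = "thread",
) -> Dict[str, Any]:
    """Apply ``func`` to every value of a flat dict using a worker pool."""
    if kind == "thread":
        pool_cls = ThreadPoolExecutor
    elif kind == "process":
        # process pools need a picklable ``func`` (e.g. a module-level function)
        pool_cls = ProcessPoolExecutor
    else:
        raise ValueError(f"kind must be 'thread' or 'process', got {kind!r}")
    with pool_cls(max_workers=max_workers) as pool:
        # submit every leaf first so the calls can run concurrently
        futures = {k: pool.submit(func, v) for k, v in leaves.items()}
        # then collect the results, preserving the original keys
        return {k: fut.result() for k, fut in futures.items()}


paths = {"lenna": "lenna.png", "baboon": "baboon.png"}
print(parallel_apply(str.upper, paths, max_workers=2))
```

Threads suit I/O-bound work such as reading files; a process pool is the usual choice for CPU-bound functions, at the cost of pickling arguments and results.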
scan(func, state, *, is_leaf=None)

Apply a function while carrying a state.

Parameters:
  • func – the function to apply to the leaf values. It accepts a leaf value and the running state, and returns a tuple of the new leaf value and the new state.

  • state – the initial state to carry.

  • is_leaf – a predicate function to determine if a value is a leaf. For example, lambda x: isinstance(x, list) will treat all lists as leaves and will not recurse into list items.

Returns:

A tuple of the pytree, with the leaf values at the specified location set to the result of applying func to them, and the final state.

Example

>>> import pytreeclass as tc
>>> tree = {"level1_0": {"level2_0": 100, "level2_1": 200}, "level1_1": 300}
>>> def scan_func(leaf, state):
...     return 'SET', state + 1
>>> init_state = 0
>>> indexer = tc.AtIndexer(tree)
>>> indexer["level1_0"]["level2_0"].scan(scan_func, state=init_state)
({'level1_0': {'level2_0': 'SET', 'level2_1': 200}, 'level1_1': 300}, 1)

Example

>>> import pytreeclass as tc
>>> from typing import NamedTuple
>>> class State(NamedTuple):
...     func_evals: int = 0
>>> @tc.autoinit
... class Tree(tc.TreeClass):
...     a: int
...     b: int
...     c: int
>>> tree = Tree(a=1, b=2, c=3)
>>> def scan_func(leaf, state: State):
...     state = State(state.func_evals + 1)
...     return leaf + 1, state
>>> # apply to ``a`` and ``b`` and return a new instance with all other
>>> # leaves unchanged and the new state that counts the number of
>>> # function evaluations
>>> tree.at['a','b'].scan(scan_func, state=State())
(Tree(a=2, b=3, c=3), State(func_evals=2))

Note

scan applies a binary func to the selected leaf values while carrying a state, and returns the tree with func applied to those leaves together with the final state. reduce, by contrast, applies a binary func to the leaf values while accumulating a state, and returns a single value.
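The scan contract can be sketched on a flat list of leaves. `scan_leaves` is a hypothetical helper that assumes the leaves have already been selected; it mirrors the `func(leaf, state) -> (new_leaf, new_state)` signature of the examples above.

```python
from typing import Any, Callable, List, Tuple


def scan_leaves(
    func: Callable[[Any, Any], Tuple[Any, Any]],
    leaves: List[Any],
    state: Any,
) -> Tuple[List[Any], Any]:
    """Thread ``state`` through ``func`` leaf by leaf.

    ``func(leaf, state)`` returns ``(new_leaf, new_state)``; the result is
    ``(new_leaves, final_state)``, one output leaf per input leaf.
    """
    new_leaves = []
    for leaf in leaves:
        leaf, state = func(leaf, state)
        new_leaves.append(leaf)
    return new_leaves, state


# increment the selected leaves and count the function evaluations
print(scan_leaves(lambda leaf, n: (leaf + 1, n + 1), [1, 2], 0))  # ([2, 3], 2)
```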

reduce(func, *, initializer=<object object>, is_leaf=None)

Reduce the leaf values at the specified location.

Parameters:
  • func (Callable[[Any, Any], Any]) – the function to reduce the leaf values.

  • initializer (Any) – the initializer value for the reduction.

  • is_leaf (Callable[[Any], bool] | None) – a predicate function to determine if a value is a leaf.

Return type:

Any

Returns:

The result of reducing the leaf values at the specified location.

Note

  • If initializer is not specified, the first leaf value is used as the initializer.

  • reduce applies a binary func to the leaf values while accumulating a state, and returns the final result. scan, by contrast, applies func to each leaf value while carrying a state, and returns the leaves of the tree with func applied to them together with the final state.

Example

>>> import pytreeclass as tc
>>> @tc.autoinit
... class Tree(tc.TreeClass):
...     a: int
...     b: int
>>> tree = Tree(a=1, b=2)
>>> tree.at[...].reduce(lambda a, b: a + b, initializer=0)
3
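Reduction over the selected leaves behaves like a left fold. The sketch below assumes that model and uses `functools.reduce`; `reduce_leaves` and the `_MISSING` sentinel are hypothetical, with the sentinel playing the role of the `initializer=<object object>` default.

```python
import functools
from typing import Any, Callable, List

_MISSING = object()  # sentinel meaning "no initializer was given"


def reduce_leaves(
    func: Callable[[Any, Any], Any],
    leaves: List[Any],
    initializer: Any = _MISSING,
) -> Any:
    """Fold ``func`` over ``leaves`` from left to right into a single value."""
    if initializer is _MISSING:
        # without an initializer the first leaf seeds the accumulator
        return functools.reduce(func, leaves)
    return functools.reduce(func, leaves, initializer)


leaves = [1, 2]  # the leaves of Tree(a=1, b=2)
print(reduce_leaves(lambda a, b: a + b, leaves, initializer=0))  # 3
print(reduce_leaves(lambda a, b: a + b, leaves))                 # 3
```

A sentinel object is used instead of `None` so that `None` remains a valid initializer.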
class pytreeclass.BaseKey

Parent class for all match classes.

  • Subclass this class to create custom match keys by implementing the __eq__ method. The __eq__ method should return True if the key matches the given path entry and False otherwise. The path entry refers to the entry defined in the tree_flatten_with_keys method of the pytree class.

  • Typical path entries in JAX are:

    • jax.tree_util.GetAttrKey for attributes

    • jax.tree_util.DictKey for mapping keys

    • jax.tree_util.SequenceKey for sequence indices

  • When implementing the __eq__ method, you can use functools.singledispatchmethod to unpack the path entry, for example:

    • jax.tree_util.GetAttrKey -> key.name

    • jax.tree_util.DictKey -> key.key

    • jax.tree_util.SequenceKey -> key.idx

    See Examples for more details.
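The singledispatch pattern can be tried without JAX by using stand-in key classes. `GetAttrKey`, `DictKey` and `SequenceKey` below are hypothetical dataclasses that only mimic the fields of their `jax.tree_util` namesakes, and `StartsWith` is a hypothetical match key; with real JAX keys you would register the `jax.tree_util` classes instead and subclass `tc.BaseKey`.

```python
import functools as ft
from dataclasses import dataclass
from typing import Any


# Hypothetical stand-ins mimicking the fields of jax.tree_util's
# GetAttrKey (.name), DictKey (.key) and SequenceKey (.idx).
@dataclass(frozen=True)
class GetAttrKey:
    name: str


@dataclass(frozen=True)
class DictKey:
    key: Any


@dataclass(frozen=True)
class SequenceKey:
    idx: int


class StartsWith:
    """Hypothetical match key: equal to entries whose unpacked name starts with ``prefix``."""

    def __init__(self, prefix: str) -> None:
        self.prefix = prefix

    def _check(self, value: Any) -> bool:
        return isinstance(value, str) and value.startswith(self.prefix)

    @ft.singledispatchmethod
    def __eq__(self, entry: Any) -> bool:
        return False  # unregistered entry types never match

    @__eq__.register(GetAttrKey)
    def _(self, entry: GetAttrKey) -> bool:
        return self._check(entry.name)  # unpack the attribute name

    @__eq__.register(DictKey)
    def _(self, entry: DictKey) -> bool:
        return self._check(entry.key)  # unpack the mapping key

    @__eq__.register(SequenceKey)
    def _(self, entry: SequenceKey) -> bool:
        return self._check(entry.idx)  # indices are ints, so they never match here


m = StartsWith("a")
print(m == GetAttrKey("alpha"), m == DictKey("beta"))  # True False
```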

Example

>>> # define a match strategy to match a leaf with a given name and type
>>> import pytreeclass as tc
>>> from typing import NamedTuple
>>> import jax
>>> class NameTypeContainer(NamedTuple):
...     name: str
...     type: type
>>> @jax.tree_util.register_pytree_with_keys_class
... class Tree:
...    def __init__(self, a, b) -> None:
...        self.a = a
...        self.b = b
...    def tree_flatten_with_keys(self):
...        ak = (NameTypeContainer("a", type(self.a)), self.a)
...        bk = (NameTypeContainer("b", type(self.b)), self.b)
...        return (ak, bk), None
...    @classmethod
...    def tree_unflatten(cls, aux_data, children):
...        return cls(*children)
...    @property
...    def at(self):
...        return tc.AtIndexer(self)
>>> tree = Tree(1, 2)
>>> class MatchNameType(tc.BaseKey):
...    def __init__(self, name, type):
...        self.name = name
...        self.type = type
...    def __eq__(self, other):
...        if isinstance(other, NameTypeContainer):
...            return other == (self.name, self.type)
...        return False
>>> tree = tree.at[MatchNameType("a", int)].get()
>>> assert jax.tree_util.tree_leaves(tree) == [1]

Note

  • Use BaseKey.def_alias(type, func) to define an index type alias for BaseKey subclasses. This is a convenience when creating new match strategies.

    >>> import pytreeclass as tc
    >>> import functools as ft
    >>> from types import FunctionType
    >>> import jax.tree_util as jtu
    >>> # let's define a new match strategy called `FuncKey` that applies
    >>> # a function to the path entry and returns True if the function
    >>> # returns True and False otherwise.
    >>> # for example `FuncKey(lambda x: x.startswith("a"))` will match
    >>> # all leaves whose key (attribute name or dict key) starts with "a".
    >>> class FuncKey(tc.BaseKey):
    ...    def __init__(self, func):
    ...        self.func = func
    ...    @ft.singledispatchmethod
    ...    def __eq__(self, key):
    ...        return self.func(key)
    ...    @__eq__.register(jtu.GetAttrKey)
    ...    def _(self, key: jtu.GetAttrKey):
    ...        # unpack the GetAttrKey
    ...        return self.func(key.name)
    ...    @__eq__.register(jtu.DictKey)
    ...    def _(self, key: jtu.DictKey):
    ...        # unpack the DictKey
    ...        return self.func(key.key)
    ...    @__eq__.register(jtu.SequenceKey)
    ...    def _(self, key: jtu.SequenceKey):
    ...        # unpack the SequenceKey
    ...        return self.func(key.idx)
    >>> # instead of using ``FuncKey(function)`` we can define an alias
    >>> # for `FuncKey`, for this example we will define any FunctionType
    >>> # as a `FuncKey` by default.
    >>> @tc.BaseKey.def_alias(FunctionType)
    ... def _(func):
    ...    return FuncKey(func)
    >>> # create a simple pytree
    >>> @tc.autoinit
    ... class Tree(tc.TreeClass):
    ...    a: int
    ...    b: str
    >>> tree = Tree(1, "string")
    >>> # now we can use the `FuncKey` alias to match all leaves whose
    >>> # attribute name is a string starting with "a"
    >>> tree.at[lambda x: isinstance(x, str) and x.startswith("a")].get()
    Tree(a=1, b=None)
    
abstract __eq__(entry)

Return True if this key matches the given path entry, and False otherwise.

Return type:

bool