Source code for pytreeclass._src.tree_index

# Copyright 2023 pytreeclass authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Define lens-like indexing/masking for pytrees."""

# enable get/set/apply/scan/reduce operations on selected parts of a nested
# structure -pytree- in out-of-place manner. this process invovles defining two
# parts: 1) *where* to select the parts of the pytree and 2) *what* to do with
# the selected parts. the *where* part is defined either by a path or a boolean
# mask. the *what* part is defined by a set value, or a function to apply to
# the selected parts. once we have a *final* boolean mask that encompasses all
# path and the boolean mask, we can use `tree_map` to apply the *what* part to
# the *where* part. for example, for a tree = [[1, 2], 3, 4] and boolean mask
# [[True, False], False, True] and path mask [0][1], then we select only leaf
# 1 that is at the intersection of the boolean mask and the path mask. then we
# apply the *what* part to the *where* part.

from __future__ import annotations

import abc
import functools as ft
import re
from typing import Any, Callable, Hashable, NamedTuple, Tuple, TypeVar

from pytreeclass._src.backend import arraylib, treelib
from pytreeclass._src.backend.treelib.base import ParallelConfig

T = TypeVar("T")
S = TypeVar("S")
PyTree = Any
EllipsisType = TypeVar("EllipsisType")
KeyEntry = TypeVar("KeyEntry", bound=Hashable)
TypeEntry = TypeVar("TypeEntry", bound=type)
TraceEntry = Tuple[KeyEntry, TypeEntry]
KeyPath = Tuple[KeyEntry, ...]
TypePath = Tuple[TypeEntry, ...]
TraceType = Tuple[KeyPath, TypePath]
_no_initializer = object()

SequenceKeyType = type(treelib.sequence_key(0))
DictKeyType = type(treelib.dict_key("key"))
GetAttrKeyType = type(treelib.attribute_key("name"))


[docs]class BaseKey(abc.ABC): """Parent class for all match classes. - Subclass this class to create custom match keys by implementing the `__eq__` method. The ``__eq__`` method should return True if the key matches the given path entry and False otherwise. The path entry refers to the entry defined in the ``tree_flatten_with_keys`` method of the pytree class. - Typical path entries in ``jax`` are: - ``jax.tree_util.GetAttrKey`` for attributes - ``jax.tree_util.DictKey`` for mapping keys - ``jax.tree_util.SequenceKey`` for sequence indices - When implementing the ``__eq__`` method you can use the ``singledispatchmethod`` to unpack the path entry for example: - ``jax.tree_util.GetAttrKey`` -> `key.name` - ``jax.tree_util.DictKey`` -> `key.key` - ``jax.tree_util.SequenceKey`` -> `key.index` See Examples for more details. Example: >>> # define an match strategy to match a leaf with a given name and type >>> import pytreeclass as tc >>> from typing import NamedTuple >>> import jax >>> class NameTypeContainer(NamedTuple): ... name: str ... type: type >>> @jax.tree_util.register_pytree_with_keys_class ... class Tree: ... def __init__(self, a, b) -> None: ... self.a = a ... self.b = b ... def tree_flatten_with_keys(self): ... ak = (NameTypeContainer("a", type(self.a)), self.a) ... bk = (NameTypeContainer("b", type(self.b)), self.b) ... return (ak, bk), None ... @classmethod ... def tree_unflatten(cls, aux_data, children): ... return cls(*children) ... @property ... def at(self): ... return tc.AtIndexer(self) >>> tree = Tree(1, 2) >>> class MatchNameType(tc.BaseKey): ... def __init__(self, name, type): ... self.name = name ... self.type = type ... def __eq__(self, other): ... if isinstance(other, NameTypeContainer): ... return other == (self.name, self.type) ... return False >>> tree = tree.at[MatchNameType("a", int)].get() >>> assert jax.tree_util.tree_leaves(tree) == [1] Note: - use ``BaseKey.def_alias(type, func)`` to define an index type alias for `BaseKey` subclasses. This is useful for convience when creating new match strategies. >>> import pytreeclass as tc >>> import functools as ft >>> from types import FunctionType >>> import jax.tree_util as jtu >>> # lets define a new match strategy called `FuncKey` that applies >>> # a function to the path entry and returns True if the function >>> # returns True and False otherwise. >>> # for example `FuncKey(lambda x: x.startswith("a"))` will match >>> # all leaves that start with "a". >>> class FuncKey(tc.BaseKey): ... def __init__(self, func): ... self.func = func ... @ft.singledispatchmethod ... def __eq__(self, key): ... return self.func(key) ... @__eq__.register(jtu.GetAttrKey) ... def _(self, key: jtu.GetAttrKey): ... # unpack the GetAttrKey ... return self.func(key.name) ... @__eq__.register(jtu.DictKey) ... def _(self, key: jtu.DictKey): ... # unpack the DictKey ... return self.func(key.key) ... @__eq__.register(jtu.SequenceKey) ... def _(self, key: jtu.SequenceKey): ... return self.func(key.index) >>> # instead of using ``FuncKey(function)`` we can define an alias >>> # for `FuncKey`, for this example we will define any FunctionType >>> # as a `FuncKey` by default. >>> @tc.BaseKey.def_alias(FunctionType) ... def _(func): ... return FuncKey(func) >>> # create a simple pytree >>> @tc.autoinit ... class Tree(tc.TreeClass): ... a: int ... b: str >>> tree = Tree(1, "string") >>> # now we can use the `FuncKey` alias to match all leaves that >>> # are strings and start with "a" >>> tree.at[lambda x: isinstance(x, str) and x.startswith("a")].get() Tree(a=1, b=None) """
[docs] @abc.abstractmethod def __eq__(self, entry: KeyEntry) -> bool: pass
class IntKey(BaseKey): def __init__(self, idx: int) -> None: self.idx = idx @ft.singledispatchmethod def __eq__(self, _: KeyEntry) -> bool: return False @__eq__.register(int) def _(self, other: int) -> bool: return self.idx == other @__eq__.register(SequenceKeyType) def _(self, other: SequenceKeyType) -> bool: return self.idx == other.idx class NameKey(BaseKey): def __init__(self, name: str) -> None: self.name = name @ft.singledispatchmethod def __eq__(self, _: KeyEntry) -> bool: return False @__eq__.register(str) def _(self, other: str) -> bool: return self.name == other @__eq__.register(GetAttrKeyType) def _(self, other: GetAttrKeyType) -> bool: return self.name == other.name @__eq__.register(DictKeyType) def _(self, other: DictKeyType) -> bool: return self.name == other.key class EllipsisKey(BaseKey): """Match all leaves.""" def __init__(self, _): del _ def __eq__(self, _: KeyEntry) -> bool: return True class MultiKey(BaseKey): """Match a leaf with multiple keys at the same level.""" def __init__(self, *keys: tuple[BaseKey, ...]): self.keys = tuple(keys) def __eq__(self, entry) -> bool: return any(entry == key for key in self.keys) class RegexKey(BaseKey): """Match a leaf with a regex pattern inside 'at' property. Args: pattern: regex pattern to match. Example: >>> import pytreeclass as tc >>> import re >>> @tc.autoinit ... class Tree(tc.TreeClass): ... weight_1: float = 1.0 ... weight_2: float = 2.0 ... weight_3: float = 3.0 ... bias: float = 0.0 >>> tree = Tree() >>> tree.at[re.compile(r"weight_.*")].set(100.0) # set all weights to 100.0 Tree(weight_1=100.0, weight_2=100.0, weight_3=100.0, bias=0.0) """ def __init__(self, pattern: str) -> None: self.pattern = pattern @ft.singledispatchmethod def __eq__(self, _: KeyEntry) -> bool: return False @__eq__.register(str) def _(self, other: str) -> bool: return re.fullmatch(self.pattern, other) is not None @__eq__.register(GetAttrKeyType) def _(self, other) -> bool: return re.fullmatch(self.pattern, other.name) is not None @__eq__.register(DictKeyType) def _(self, other) -> bool: return re.fullmatch(self.pattern, other.key) is not None # dispatch on type of indexer to convert input item to at indexer # `__getitem__` to the appropriate key # avoid using container pytree types to avoid conflict between # matching as a mask or as an instance of `BaseKey` indexer_dispatcher = ft.singledispatch(lambda x: x) indexer_dispatcher.register(type(...), EllipsisKey) indexer_dispatcher.register(int, IntKey) indexer_dispatcher.register(str, NameKey) indexer_dispatcher.register(re.Pattern, RegexKey) BaseKey.def_alias = indexer_dispatcher.register _NOT_IMPLEMENTED_INDEXING = """Indexing with {} is not implemented, supported indexing types are: - `str` for mapping keys or class attributes. - `int` for positional indexing for sequences. - `...` to select all leaves. - Boolean mask of the same structure as the tree - `re.Pattern` to index all keys matching a regex pattern. - Instance of `BaseKey` with custom logic to index a pytree. - `tuple` of the above types to match multiple leaves at the same level. """ def _generate_path_mask( tree: PyTree, where: tuple[BaseKey, ...], is_leaf: Callable[[Any], None] | None = None, ) -> PyTree: # generate a boolean mask for `where` path in `tree` # where path is a tuple of indices or keys, for example # where=("a",) wil set all leaves of `tree` with key "a" to True and # all other leaves to False match = False def map_func(path, _: Any): if len(where) > len(path): # path is shorter than `where` path. for example # where=("a", "b") and the current path is ("a",) then # the current path is not a match return False for wi, ki in zip(where, path): if not (wi == ki): return False nonlocal match match = True return match mask = treelib.tree_path_map(map_func, tree, is_leaf=is_leaf) if not match: raise LookupError(f"No leaf match is found for {where=}.") return mask def _combine_bool_leaves(*leaves): verdict = True for leaf in leaves: verdict &= leaf return verdict def _is_bool_leaf(leaf: Any) -> bool: if isinstance(leaf, arraylib.ndarray): return arraylib.is_bool(leaf) return isinstance(leaf, bool) def _resolve_where( tree: T, where: tuple[Any, ...], # type: ignore is_leaf: Callable[[Any], None] | None = None, ) -> T | None: # given a pytree `tree` and a `where` path, that is composed of keys or # boolean masks, generate a boolean mask that will be eventually used to # with `tree_map` to select the leaves at the specified location. mask = None bool_masks: list[T] = [] path_masks: list[BaseKey] = [] _, treedef0 = treelib.tree_flatten(tree, is_leaf=is_leaf) seen_tuple = False # handle multiple keys at the same level level_paths = [] def verify_and_aggregate_is_leaf(x) -> bool: # use is_leaf with non-local to traverse the tree depth-first manner # required for verifying if a pytree is a valid indexing pytree nonlocal seen_tuple, level_paths, bool_masks # used to check if a pytree is a valid indexing pytree # used with `is_leaf` argument of any `tree_*` function leaves, treedef = treelib.tree_flatten(x) if treedef == treedef0 and all(map(_is_bool_leaf, leaves)): # boolean pytrees of same structure as `tree` is a valid indexing pytree bool_masks += [x] return True if isinstance(resolved_key := indexer_dispatcher(x), BaseKey): # valid resolution of `BaseKey` is a valid indexing leaf # makes it possible to dispatch on multi-leaf pytree level_paths += [resolved_key] return False if type(x) is tuple and seen_tuple is False: # e.g. `at[1,2,3]` but not `at[1,(2,3)]`` seen_tuple = True return False # not a container of other keys or a pytree of same structure raise NotImplementedError(_NOT_IMPLEMENTED_INDEXING.format(x)) for level_keys in where: # each for loop iteration is a level in the where path # this means that if where = ("a", "b", "c") then this means # we are travering the tree at level "a" then level "b" then level "c" treelib.tree_flatten(level_keys, is_leaf=verify_and_aggregate_is_leaf) # if len(level_paths) > 1 then this means that we have multiple keys # at the same level, for example where = ("a", ("b", "c")) then this # means that for a parent "a", select "b" and "c". path_masks += [MultiKey(*level_paths)] if len(level_paths) > 1 else level_paths level_paths = [] seen_tuple = False if path_masks: mask = _generate_path_mask(tree, path_masks, is_leaf=is_leaf) if bool_masks: all_masks = [mask, *bool_masks] if mask else bool_masks mask = treelib.tree_map(_combine_bool_leaves, *all_masks) return mask class AtIndexer(NamedTuple): """Index a pytree at a given path using a path or mask. Args: tree: pytree to index where: one of the following: - ``str`` for mapping keys or class attributes. - ``int`` for positional indexing for sequences. - ``...`` to select all leaves. - a boolean mask of the same structure as the tree - ``re.Pattern`` to index all keys matching a regex pattern. - an instance of ``BaseKey`` with custom logic to index a pytree. - a tuple of the above to match multiple keys at the same level. Example: >>> # use `AtIndexer` on a pytree (e.g. dict,list,tuple,etc.) >>> import jax >>> import pytreeclass as tc >>> tree = {"level1_0": {"level2_0": 100, "level2_1": 200}, "level1_1": 300} >>> indexer = tc.AtIndexer(tree) >>> indexer["level1_0"]["level2_0"].get() {'level1_0': {'level2_0': 100, 'level2_1': None}, 'level1_1': None} >>> # get multiple keys at once at the same level >>> indexer["level1_0"]["level2_0", "level2_1"].get() {'level1_0': {'level2_0': 100, 'level2_1': 200}, 'level1_1': None} >>> # get with a mask >>> mask = {"level1_0": {"level2_0": True, "level2_1": False}, "level1_1": True} >>> indexer[mask].get() {'level1_0': {'level2_0': 100, 'level2_1': None}, 'level1_1': 300} Example: >>> # use ``AtIndexer`` in a class >>> import jax.tree_util as jtu >>> import pytreeclass as tc >>> @jax.tree_util.register_pytree_with_keys_class ... class Tree: ... def __init__(self, a, b): ... self.a = a ... self.b = b ... def tree_flatten_with_keys(self): ... kva = (jtu.GetAttrKey("a"), self.a) ... kvb = (jtu.GetAttrKey("b"), self.b) ... return (kva, kvb), None ... @classmethod ... def tree_unflatten(cls, aux_data, children): ... return cls(*children) ... @property ... def at(self): ... return tc.AtIndexer(self) ... def __repr__(self) -> str: ... return f"{self.__class__.__name__}(a={self.a}, b={self.b})" >>> Tree(1, 2).at["a"].get() Tree(a=1, b=None) """ tree: PyTree where: tuple[BaseKey | PyTree] | tuple[()] = () def __getitem__(self, where: Any) -> AtIndexer: # AtIndexer[where] will extend the current path with `where` # for example AtIndexer[where1][where2] will extend the current path # with `where1` and `where2` to indicate the path to the leaves to # select. return type(self)(self.tree, (*self.where, where))
[docs] def get( self, *, is_leaf: Callable[[Any], None] | None = None, is_parallel: bool | ParallelConfig = False, ) -> PyTree: """Get the leaf values at the specified location. Args: is_leaf: a predicate function to determine if a value is a leaf. is_parallel: accepts the following: - ``bool``: apply ``func`` in parallel if ``True`` otherwise in serial. - ``dict``: a dict of of: - ``max_workers``: maximum number of workers to use. - ``kind``: kind of pool to use, either ``thread`` or ``process``. Returns: A _new_ pytree of leaf values at the specified location, with the non-selected leaf values set to None if the leaf is not an array. Example: >>> import pytreeclass as tc >>> tree = {"a": 1, "b": [1, 2, 3]} >>> indexer = tc.AtIndexer(tree) # construct an indexer >>> indexer["b"][0].get() # get the first element of "b" {'a': None, 'b': [1, None, None]} Example: >>> import pytreeclass as tc >>> @tc.autoinit ... class Tree(tc.TreeClass): ... a: int ... b: int >>> tree = Tree(a=1, b=2) >>> # get ``a`` and return a new instance >>> # with ``None`` for all other leaves >>> tree.at['a'].get() Tree(a=1, b=None) """ where = _resolve_where(self.tree, self.where, is_leaf) config = dict(is_leaf=is_leaf, is_parallel=is_parallel) def leaf_get(leaf: Any, where: Any): # support both array and non-array leaves # for array boolean mask we select **parts** of the array that # matches the mask, for example if the mask is Array([True, False, False]) # and the leaf is Array([1, 2, 3]) then the result is Array([1]) if isinstance(where, arraylib.ndarray) and arraylib.ndim(where) != 0: return leaf[where] # non-array boolean mask we select the leaf if the mask is True # and `None` otherwise return leaf if where else None return treelib.tree_map(leaf_get, self.tree, where, **config)
[docs] def set( self, set_value: Any, *, is_leaf: Callable[[Any], None] | None = None, is_parallel: bool | ParallelConfig = False, ) -> PyTree: """Set the leaf values at the specified location. Args: set_value: the value to set at the specified location. is_leaf: a predicate function to determine if a value is a leaf. is_parallel: accepts the following: - ``bool``: apply ``func`` in parallel if ``True`` otherwise in serial. - ``dict``: a dict of of: - ``max_workers``: maximum number of workers to use. - ``kind``: kind of pool to use, either ``thread`` or ``process``. Returns: A pytree with the leaf values at the specified location set to ``set_value``. Example: >>> import pytreeclass as tc >>> tree = {"a": 1, "b": [1, 2, 3]} >>> indexer = tc.AtIndexer(tree) >>> indexer["b"][0].set(100) # set the first element of "b" to 100 {'a': 1, 'b': [100, 2, 3]} Example: >>> import pytreeclass as tc >>> @tc.autoinit ... class Tree(tc.TreeClass): ... a: int ... b: int >>> tree = Tree(a=1, b=2) >>> # set ``a`` and return a new instance >>> # with all other leaves unchanged >>> tree.at['a'].set(100) Tree(a=100, b=2) """ where = _resolve_where(self.tree, self.where, is_leaf) config = dict(is_leaf=is_leaf, is_parallel=is_parallel) def leaf_set(leaf: Any, where: Any, set_value: Any): # support both array and non-array leaves # for array boolean mask we select **parts** of the array that # matches the mask, for example if the mask is Array([True, False, False]) # and the leaf is Array([1, 2, 3]) then the result is Array([1, 100, 100]) # with set_value = 100 if isinstance(where, arraylib.ndarray): return arraylib.where(where, set_value, leaf) return set_value if where else leaf _, lhsdef = treelib.tree_flatten(self.tree, is_leaf=is_leaf) _, rhsdef = treelib.tree_flatten(set_value, is_leaf=is_leaf) if lhsdef == rhsdef: # do not broadcast set_value if it is a pytree of same structure # for example tree.at[where].set(tree2) will set all tree leaves # to tree2 leaves if tree2 is a pytree of same structure as tree # instead of making each leaf of tree a copy of tree2 # is design is similar to ``numpy`` design `np.at[...].set(Array)` return treelib.tree_map(leaf_set, self.tree, where, set_value, **config) # set_value is broadcasted to tree leaves # for example tree.at[where].set(1) will set all tree leaves to 1 leaf_set_ = lambda leaf, where: leaf_set(leaf, where, set_value) return treelib.tree_map(leaf_set_, self.tree, where, **config)
[docs] def apply( self, func: Callable[[Any], Any], *, is_leaf: Callable[[Any], None] | None = None, is_parallel: bool | ParallelConfig = False, ) -> PyTree: """Apply a function to the leaf values at the specified location. Args: func: the function to apply to the leaf values. is_leaf: a predicate function to determine if a value is a leaf. is_parallel: accepts the following: - ``bool``: apply ``func`` in parallel if ``True`` otherwise in serial. - ``dict``: a dict of of: - ``max_workers``: maximum number of workers to use. - ``kind``: kind of pool to use, either ``thread`` or ``process``. Returns: A pytree with the leaf values at the specified location set to the result of applying ``func`` to the leaf values. Example: >>> import pytreeclass as tc >>> tree = {"a": 1, "b": [1, 2, 3]} >>> indexer = tc.AtIndexer(tree) >>> indexer["b"][0].apply(lambda x: x + 100) # add 100 to the first element of "b" {'a': 1, 'b': [101, 2, 3]} Example: >>> import pytreeclass as tc >>> @tc.autoinit ... class Tree(tc.TreeClass): ... a: int ... b: int >>> tree = Tree(a=1, b=2) >>> # apply to ``a`` and return a new instance >>> # with all other leaves unchanged >>> tree.at['a'].apply(lambda _: 100) Tree(a=100, b=2) Example: >>> # read images in parallel >>> import pytreeclass as tc >>> from matplotlib.pyplot import imread >>> indexer = tc.AtIndexer({"lenna": "lenna.png", "baboon": "baboon.png"}) >>> images = indexer[...].apply(imread, parallel=dict(max_workers=2)) # doctest: +SKIP """ where = _resolve_where(self.tree, self.where, is_leaf) config = dict(is_leaf=is_leaf, is_parallel=is_parallel) def leaf_apply(leaf: Any, where: bool): # same as `leaf_set` but with `func` applied to the leaf # one thing to note is that, the where mask select an array # then the function needs work properly when applied to the selected # array elements if isinstance(where, arraylib.ndarray): return arraylib.where(where, func(leaf), leaf) return func(leaf) if where else leaf return treelib.tree_map(leaf_apply, self.tree, where, **config)
[docs] def scan( self, func: Callable[[Any, S], tuple[Any, S]], state: S, *, is_leaf: Callable[[Any], None] | None = None, ) -> tuple[PyTree, S]: """Apply a function while carrying a state. Args: func: the function to apply to the leaf values. the function accepts a running state and leaf value and returns a tuple of the new leaf value and the new state. state: the initial state to carry. is_leaf: a predicate function to determine if a value is a leaf. for example, ``lambda x: isinstance(x, list)`` will treat all lists as leaves and will not recurse into list items. Returns: A tuple of the final state and pytree with the leaf values at the specified location set to the result of applying ``func`` to the leaf values. Example: >>> import pytreeclass as tc >>> tree = {"level1_0": {"level2_0": 100, "level2_1": 200}, "level1_1": 300} >>> def scan_func(leaf, state): ... return 'SET', state + 1 >>> init_state = 0 >>> indexer = tc.AtIndexer(tree) >>> indexer["level1_0"]["level2_0"].scan(scan_func, state=init_state) ({'level1_0': {'level2_0': 'SET', 'level2_1': 200}, 'level1_1': 300}, 1) Example: >>> import pytreeclass as tc >>> from typing import NamedTuple >>> class State(NamedTuple): ... func_evals: int = 0 >>> @tc.autoinit ... class Tree(tc.TreeClass): ... a: int ... b: int ... c: int >>> tree = Tree(a=1, b=2, c=3) >>> def scan_func(leaf, state: State): ... state = State(state.func_evals + 1) ... return leaf + 1, state >>> # apply to ``a`` and ``b`` and return a new instance with all other >>> # leaves unchanged and the new state that counts the number of >>> # function evaluations >>> tree.at['a','b'].scan(scan_func, state=State()) (Tree(a=2, b=3, c=3), State(func_evals=2)) Note: ``scan`` applies a binary ``func`` to the leaf values while carrying a state and returning a tree leaves with the the ``func`` applied to them with final state. While ``reduce`` applies a binary ``func`` to the leaf values while carrying a state and returning a single value. """ where = _resolve_where(self.tree, self.where, is_leaf) running_state = state def stateless_func(leaf): nonlocal running_state leaf, running_state = func(leaf, running_state) return leaf def leaf_apply(leaf: Any, where: bool): if isinstance(where, arraylib.ndarray): return arraylib.where(where, stateless_func(leaf), leaf) return stateless_func(leaf) if where else leaf out = treelib.tree_map(leaf_apply, self.tree, where, is_leaf=is_leaf) return out, running_state
[docs] def reduce( self, func: Callable[[Any, Any], Any], *, initializer: Any = _no_initializer, is_leaf: Callable[[Any], None] | None = None, ) -> Any: """Reduce the leaf values at the specified location. Args: func: the function to reduce the leaf values. initializer: the initializer value for the reduction. is_leaf: a predicate function to determine if a value is a leaf. Returns: The result of reducing the leaf values at the specified location. Note: - If ``initializer`` is not specified, the first leaf value is used as the initializer. - ``reduce`` applies a binary ``func`` to each leaf values while accumulating a state a returns the final result. while ``scan`` applies ``func`` to each leaf value while carrying a state and returns the final state and the leaves of the tree with the result of applying ``func`` to each leaf. Example: >>> import pytreeclass as tc >>> @tc.autoinit ... class Tree(tc.TreeClass): ... a: int ... b: int >>> tree = Tree(a=1, b=2) >>> tree.at[...].reduce(lambda a, b: a + b, initializer=0) 3 """ where = _resolve_where(self.tree, self.where, is_leaf) tree = self[where].get(is_leaf=is_leaf) # type: ignore leaves, _ = treelib.tree_flatten(tree, is_leaf=is_leaf) if initializer is _no_initializer: return ft.reduce(func, leaves) return ft.reduce(func, leaves, initializer)