# Copyright 2023 pytreeclass authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for pytrees."""
from __future__ import annotations
import functools as ft
import operator as op
from copy import copy
from math import ceil, floor, trunc
from typing import Any, Callable, Hashable, Iterator, Sequence, Tuple, TypeVar
from typing_extensions import ParamSpec
from pytreeclass._src.backend import arraylib, treelib
# Generic type variables used by annotations in this module.
T = TypeVar("T")
T1 = TypeVar("T1")
T2 = TypeVar("T2")
# Parameter specification of a wrapped callable (see the `bcmap` signature).
P = ParamSpec("P")
# Pytrees are not constrained further at the type level.
PyTree = Any
# NOTE(review): presumably a stand-in for the type of `...`; it is declared
# here as a plain type variable -- confirm intended usage.
EllipsisType = TypeVar("EllipsisType")
# One path key (bounded by `Hashable`) and one node type (bounded by `type`).
KeyEntry = TypeVar("KeyEntry", bound=Hashable)
TypeEntry = TypeVar("TypeEntry", bound=type)
# A single (key, type) step of a path.
TraceEntry = Tuple[KeyEntry, TypeEntry]
# Variable-length tuples of keys / of types.
KeyPath = Tuple[KeyEntry, ...]
TypePath = Tuple[TypeEntry, ...]
# A pair (keys, types): the typed path unpacked as `keys, types` in this module.
KeyTypePath = Tuple[KeyPath, TypePath]
def tree_hash(*trees: PyTree) -> int:
    """Return a hash derived from the leaves and the structure of ``trees``.

    All positional trees are flattened together (as one tuple); the resulting
    leaves followed by the tree definition are hashed as a single tuple.

    Args:
        trees: the pytrees to hash.

    Returns:
        The ``hash`` of ``(*leaves, treedef)``.
    """
    flat, structure = treelib.tree_flatten(trees)
    return hash(tuple(flat) + (structure,))
def tree_copy(tree: T) -> T:
    """Return a tree built by applying ``copy.copy`` to every leaf of ``tree``."""
    return treelib.tree_map(copy, tree)
def _is_leaf_rhs_equal(leaf, rhs) -> bool | arraylib.ndarray:
    """Compare a single leaf against ``rhs``.

    Args:
        leaf: the reference leaf.
        rhs: the value to compare ``leaf`` with.

    Returns:
        - For non-array leaves: the result of ``leaf == rhs``.
        - For array leaves: ``False`` if ``rhs`` is not an array or if the
          shapes/dtypes differ; otherwise whether all elements are equal,
          as a ``bool`` when the result can be converted, else the
          unconverted result of ``arraylib.all``.
    """
    if not isinstance(leaf, arraylib.ndarray):
        return leaf == rhs

    if not isinstance(rhs, arraylib.ndarray):
        return False

    if leaf.shape != rhs.shape:
        return False

    if leaf.dtype != rhs.dtype:
        return False

    # Computed outside the ``try`` so that a failure of the comparison itself
    # propagates as-is instead of being masked by an unbound local below.
    verdict = arraylib.all(leaf == rhs)

    try:
        return bool(verdict)
    except Exception:
        # ``bool`` conversion fails under `jit`; hand back the array verdict.
        # The backend-specific exception type is not importable here, hence
        # the broad ``except`` restricted to this single conversion.
        return verdict
def is_tree_equal(*trees: Any) -> bool | arraylib.ndarray:
    """Return ``True`` if all pytrees are equal.

    Every tree after the first is compared with the first one: the tree
    definitions must match and the leaves must compare equal pairwise.

    Note:
        trees are compared using their leaves and treedefs.

    Note:
        The result is a ``bool`` unless a leaf comparison cannot be converted
        to ``bool`` (e.g. when compiled), in which case a boolean ``Array``
        is returned.
    """
    reference, *others = trees
    ref_leaves, ref_def = treelib.tree_flatten(reference)
    result = True

    for other in others:
        other_leaves, other_def = treelib.tree_flatten(other)

        # stop early on a structure mismatch or a definite ``False`` so far
        if result is False or other_def != ref_def:
            return False

        for lhs, rhs in zip(ref_leaves, other_leaves):
            result = result & _is_leaf_rhs_equal(lhs, rhs)

    return result
class Partial:
    """``Partial`` function with support for positional partial application.

    Args:
        func: The function to be partially applied.
        args: Positional arguments to be partially applied. use ``...`` as a
            placeholder for positional arguments.
        kwargs: Keyword arguments to be partially applied.

    Raises:
        TypeError: on call, if fewer positional arguments are supplied than
            there are ``...`` placeholders.

    Example:
        >>> import pytreeclass as tc
        >>> def f(a, b, c):
        ...     print(f"a: {a}, b: {b}, c: {c}")
        ...     return a + b + c

        >>> # positional arguments using `...` placeholder
        >>> f_a = tc.Partial(f, ..., 2, 3)
        >>> f_a(1)
        a: 1, b: 2, c: 3
        6

        >>> # keyword arguments
        >>> f_b = tc.Partial(f, b=2, c=3)
        >>> f_b(1)
        a: 1, b: 2, c: 3
        6

    Note:
        - The ``...`` is used to indicate a placeholder for positional arguments.
        - https://stackoverflow.com/a/7811270
    """

    __slots__ = ["func", "args", "kwargs", "__weakref__"]  # type: ignore

    def __init__(self, func: Callable[..., Any], *args: Any, **kwargs: Any):
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Call ``func`` with placeholders filled from ``args``.

        Each ``...`` stored at construction is replaced, left to right, by the
        next call-time positional argument; leftover positional arguments are
        appended. Call-time keyword arguments override the stored ones.
        """
        remaining = iter(args)
        filled = []

        for stored in self.args:
            if stored is ...:
                try:
                    stored = next(remaining)
                except StopIteration:
                    # Without this guard the exhausted iterator would surface
                    # as an opaque "generator raised StopIteration" error.
                    needed = sum(1 for arg in self.args if arg is ...)
                    raise TypeError(
                        f"{type(self).__name__} has {needed} `...` placeholder(s) "
                        f"but only {len(args)} positional argument(s) were given"
                    ) from None
            filled.append(stored)

        return self.func(*filled, *remaining, **{**self.kwargs, **kwargs})

    def __repr__(self) -> str:
        return f"Partial({self.func}, {self.args}, {self.kwargs})"

    def __hash__(self) -> int:
        # hashing is delegated to the pytree hash of this instance
        return tree_hash(self)

    def __eq__(self, other: Any) -> bool:
        # equality is delegated to the pytree comparison
        return is_tree_equal(self, other)
# Register ``Partial`` with the tree backend through ``register_static``.
# NOTE(review): presumably this makes instances be treated as static
# (leafless) nodes by the backend -- confirm against `treelib`.
treelib.register_static(Partial)
def bcmap(
    func: Callable[P, T],
    *,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Callable[P, T]:
    """Map a function over pytree leaves with automatic broadcasting for scalar arguments.

    The first positional argument (or, when only keyword arguments are given,
    the first keyword argument) defines the reference tree structure. Every
    other argument whose structure equals the reference is mapped leaf-wise;
    any other argument is passed unchanged to every call of ``func``.

    Args:
        func: the function to be mapped over the pytree
        is_leaf: a predicate function that returns True if the node is a leaf.
            It is applied to every argument when comparing tree structures.

    Returns:
        A wrapper around ``func`` that returns a tree with the reference
        structure whose leaves are the results of ``func``.

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> import pytreeclass as tc
        >>> import functools as ft
        >>> @tc.autoinit
        ... @tc.leafwise
        ... class Test(tc.TreeClass):
        ...    a: tuple[int, int, int] = (1, 2, 3)
        ...    b: tuple[int, int, int] = (4, 5, 6)
        ...    c: jax.Array = jnp.array([1, 2, 3])
        >>> tree = Test()

        >>> # 0 is broadcasted to all leaves of the pytree
        >>> print(tc.bcmap(jnp.where)(tree > 1, tree, 0))
        Test(a=(0, 2, 3), b=(4, 5, 6), c=[0 2 3])
        >>> print(tc.bcmap(jnp.where)(tree > 1, 0, tree))
        Test(a=(1, 0, 0), b=(0, 0, 0), c=[1 0 0])

        >>> # 1 is broadcasted to all leaves of the list pytree
        >>> tc.bcmap(lambda x, y: x + y)([1, 2, 3], 1)
        [2, 3, 4]

        >>> # trees are summed leaf-wise
        >>> tc.bcmap(lambda x, y: x + y)([1, 2, 3], [1, 2, 3])
        [2, 4, 6]

        >>> # Non scalar second args case
        >>> try:
        ...     tc.bcmap(lambda x, y: x + y)([1, 2, 3], [[1, 2, 3], [1, 2, 3]])
        ... except TypeError as e:
        ...     print(e)
        unsupported operand type(s) for +: 'int' and 'list'

        >>> # using **numpy** functions on pytrees
        >>> tc.bcmap(jnp.add)([1, 2, 3], [1, 2, 3]) # doctest: +SKIP
        [2, 4, 6]
    """

    @ft.wraps(func)
    def wrapper(*args, **kwargs):
        if len(args) > 0:
            # positional arguments are present: the first positional argument
            # is the reference whose tree structure the others are compared to
            leaves0, treedef0 = treelib.tree_flatten(args[0], is_leaf=is_leaf)
            masked_args = [...]
            masked_kwargs = {}
            leaves = [leaves0]
            leaves_keys = []

            for arg in args[1:]:
                # use the same ``is_leaf`` as the reference, otherwise a
                # structurally matching argument would compare unequal
                _, argdef = treelib.tree_flatten(arg, is_leaf=is_leaf)
                if treedef0 == argdef:
                    # mapped leaf-wise: leave a placeholder in the partial
                    masked_args += [...]
                    leaves += [treedef0.flatten_up_to(arg)]
                else:
                    # broadcast: the argument is passed as-is to every call
                    masked_args += [arg]
        else:
            # only keyword arguments are present: the first keyword argument
            # is the reference whose tree structure the others are compared to
            key0 = next(iter(kwargs))
            leaves0, treedef0 = treelib.tree_flatten(kwargs.pop(key0), is_leaf=is_leaf)
            masked_args = []
            masked_kwargs = {key0: ...}
            leaves = [leaves0]
            leaves_keys = [key0]

        for key in kwargs:
            _, kwargdef = treelib.tree_flatten(kwargs[key], is_leaf=is_leaf)
            if treedef0 == kwargdef:
                masked_kwargs[key] = ...
                leaves += [treedef0.flatten_up_to(kwargs[key])]
                leaves_keys += [key]
            else:
                masked_kwargs[key] = kwargs[key]

        bfunc = Partial(func, *masked_args, **masked_kwargs)

        if len(leaves_keys) == 0:
            # no kwargs leaves are present, so we can immediately zip
            return treelib.tree_unflatten(treedef0, [bfunc(*xs) for xs in zip(*leaves)])

        # kwargs leaves are present: positional leaf groups come first in
        # ``leaves`` and the keyword ones last, in ``leaves_keys`` order
        kwargnum = len(leaves) - len(leaves_keys)
        all_leaves = []
        for xs in zip(*leaves):
            xs_args, xs_kwargs = xs[:kwargnum], xs[kwargnum:]
            all_leaves += [bfunc(*xs_args, **dict(zip(leaves_keys, xs_kwargs)))]
        return treelib.tree_unflatten(treedef0, all_leaves)

    name = getattr(func, "__name__", func)
    docs = f"Broadcasted version of {name}\n{func.__doc__}"
    wrapper.__doc__ = docs
    return wrapper
def uop(func):
    """Turn the unary ``func`` into a method that maps it over the leaves of ``self``."""

    @ft.wraps(func)
    def wrapper(self):
        return treelib.tree_map(func, self)

    return wrapper
def bop(func):
    """Turn the binary ``func`` into a leaf-wise method.

    If ``rhs`` is an instance of the type of ``leaf`` the two are mapped
    together leaf by leaf; otherwise ``rhs`` is passed unchanged as the second
    argument for every leaf.
    """

    @ft.wraps(func)
    def wrapper(leaf, rhs=None):
        if not isinstance(rhs, type(leaf)):
            return treelib.tree_map(lambda x: func(x, rhs), leaf)
        return treelib.tree_map(func, leaf, rhs)

    return wrapper
def swop(func):
    """Return ``func`` with its two arguments swapped (used for reflected operators)."""

    @ft.wraps(func)
    def swapped(leaf, rhs):
        return func(rhs, leaf)

    return swapped
def leafwise(klass: type[T]) -> type[T]:
    """A class decorator that adds leafwise operators to a class.

    Leafwise operators are operators that are applied to the leaves of a pytree.
    For example leafwise ``__add__`` is equivalent to:

    - ``tree_map(lambda x: x + rhs, tree)`` if ``rhs`` is a scalar.
    - ``tree_map(lambda x, y: x + y, tree, rhs)`` if ``rhs`` is a pytree
      with the same structure as ``tree``.

    Args:
        klass: The class to be decorated.

    Returns:
        The decorated class.

    Example:
        >>> # use ``numpy`` functions on :class:`TreeClass` classes decorated with ``leafwise``
        >>> import pytreeclass as tc
        >>> import jax.numpy as jnp
        >>> @tc.leafwise
        ... @tc.autoinit
        ... class Point(tc.TreeClass):
        ...    x: float = 0.5
        ...    y: float = 1.0
        ...    description: str = "point coordinates"
        >>> # use :func:`tree_mask` to mask the non-inexact part of the tree
        >>> # i.e. mask the string leaf ``description`` to ``Point`` work
        >>> # with ``jax.numpy`` functions
        >>> co = tc.tree_mask(Point())
        >>> print(tc.bcmap(jnp.where)(co > 0.5, co, 1000))
        Point(x=1000.0, y=1.0, description=#point coordinates)

    Note:
        A method is only added when the class itself does not already define
        an attribute with that name. Reflected variants (``__radd__``,
        ``__rsub__``, ...) of the binary arithmetic and bitwise operators are
        added as well.

        ==================      ============
        Method                  Operator
        ==================      ============
        ``__abs__``             ``abs``
        ``__add__``             ``+``
        ``__and__``             ``&``
        ``__ceil__``            ``math.ceil``
        ``__divmod__``          ``divmod``
        ``__eq__``              ``==``
        ``__floor__``           ``math.floor``
        ``__floordiv__``        ``//``
        ``__ge__``              ``>=``
        ``__gt__``              ``>``
        ``__invert__``          ``~``
        ``__le__``              ``<=``
        ``__lshift__``          ``<<``
        ``__lt__``              ``<``
        ``__matmul__``          ``@``
        ``__mod__``             ``%``
        ``__mul__``             ``*``
        ``__ne__``              ``!=``
        ``__neg__``             ``-``
        ``__or__``              ``|``
        ``__pos__``             ``+``
        ``__pow__``             ``**``
        ``__round__``           ``round``
        ``__rshift__``          ``>>``
        ``__sub__``             ``-``
        ``__truediv__``         ``/``
        ``__trunc__``           ``math.trunc``
        ``__xor__``             ``^``
        ==================      ============
    """
    # operators taking only ``self``
    unary = {
        "__abs__": abs,
        "__ceil__": ceil,
        "__floor__": floor,
        "__invert__": op.invert,
        "__neg__": op.neg,
        "__pos__": op.pos,
        "__trunc__": trunc,
    }
    # operators taking ``self`` and a right-hand side
    binary = {
        "__add__": op.add,
        "__and__": op.and_,
        "__divmod__": divmod,
        "__eq__": op.eq,
        "__floordiv__": op.floordiv,
        "__ge__": op.ge,
        "__gt__": op.gt,
        "__le__": op.le,
        "__lshift__": op.lshift,
        "__lt__": op.lt,
        "__matmul__": op.matmul,
        "__mod__": op.mod,
        "__mul__": op.mul,
        "__ne__": op.ne,
        "__or__": op.or_,
        "__pow__": op.pow,
        "__round__": round,
        "__rshift__": op.rshift,
        "__sub__": op.sub,
        "__truediv__": op.truediv,
        "__xor__": op.xor,
    }
    # reflected operators: same function with the operands swapped
    reflected = {
        "__radd__": op.add,
        "__rand__": op.and_,
        "__rdivmod__": divmod,
        "__rfloordiv__": op.floordiv,
        "__rlshift__": op.lshift,
        "__rmatmul__": op.matmul,
        "__rmod__": op.mod,
        "__rmul__": op.mul,
        "__ror__": op.or_,
        "__rpow__": op.pow,
        "__rrshift__": op.rshift,
        "__rsub__": op.sub,
        "__rtruediv__": op.truediv,
        "__rxor__": op.xor,
    }

    methods = {name: uop(function) for name, function in unary.items()}
    methods.update((name, bop(function)) for name, function in binary.items())
    methods.update((name, bop(swop(function))) for name, function in reflected.items())

    defined = vars(klass)
    for name in sorted(methods):
        if name in defined:
            # do not override any user defined methods
            # this behavior similar is to `dataclasses.dataclass`
            continue
        setattr(klass, name, methods[name])
    return klass
# Tree definition obtained by flattening a bare ``int``; it is compared
# against in `flatten_one_typed_path_level` to detect nodes that flatten to
# a single leaf.
_, atomicdef = treelib.tree_flatten(1)
def flatten_one_typed_path_level(
    typedpath: KeyTypePath,
    tree: PyTree,
    is_leaf: Callable[[Any], bool] | None,
    is_path_leaf: Callable[[KeyTypePath], bool] | None,
):
    """Yield ``(typed_path, leaf)`` pairs by descending ``tree`` one level at a time.

    Args:
        typedpath: the ``(keys, types)`` pair accumulated up to ``tree``.
        tree: the node to flatten.
        is_leaf: optional predicate on the node; a truthy result stops descent.
        is_path_leaf: optional predicate on the typed path; a truthy result
            stops descent.

    Yields:
        Tuples of the typed path and the node found at that path.
    """
    # either predicate may declare the current node a leaf
    stop_here = (is_leaf and is_leaf(tree)) or (
        is_path_leaf and is_path_leaf(typedpath)
    )
    if stop_here:
        yield typedpath, tree
        return

    def is_not_root(node):
        # every node except ``tree`` itself is a leaf, which limits the
        # flattening to the direct children
        return node is not tree

    children, structure = treelib.tree_path_flatten(tree, is_leaf=is_not_root)

    if structure == atomicdef:
        yield typedpath, tree
        return

    parent_keys, parent_types = typedpath
    for key, value in children:
        child_path = ((*parent_keys, *key), (*parent_types, type(value)))
        yield from flatten_one_typed_path_level(
            child_path, value, is_leaf, is_path_leaf
        )
def tree_leaves_with_typed_path(
    tree: PyTree,
    *,
    is_leaf: Callable[[Any], bool] | None = None,
    is_path_leaf: Callable[[KeyTypePath], bool] | None = None,
) -> Sequence[tuple[KeyTypePath, Any]]:
    """Return the list of ``(typed_path, leaf)`` pairs of ``tree``.

    Starts from the empty typed path ``((), ())`` and collects everything
    yielded by ``flatten_one_typed_path_level``. Mainly used for visualization.
    """
    empty_path = ((), ())
    return [*flatten_one_typed_path_level(empty_path, tree, is_leaf, is_path_leaf)]
class Node:
    """Tree node holding ``data`` plus parent/children links.

    ``data`` is a ``(trace_entry, value)`` pair; children are stored in a dict
    keyed by the trace entry of each child. Mainly used for visualization.
    """

    __slots__ = ["data", "parent", "children", "__weakref__"]

    def __init__(
        self,
        data: tuple[TraceEntry, Any],
        parent: Node | None = None,
    ):
        self.data = data
        self.parent = parent
        self.children: dict[TraceEntry, Node] = {}

    def add_child(self, child: Node) -> None:
        """Attach ``child`` under this node unless its key is already present.

        Raises:
            TypeError: if ``child`` is not a ``Node``.
        """
        if not isinstance(child, Node):
            raise TypeError(f"`child` must be a `Node`, got {type(child)}")

        entry = child.data[0]
        if entry in self.children:
            # a child with the same trace entry already exists: keep it
            return

        # establish parent-child relationship
        child.parent = self
        self.children[entry] = child

    def __iter__(self) -> Iterator[Node]:
        # iterate over children nodes
        return iter(self.children.values())

    def __repr__(self) -> str:
        return f"Node(data={self.data})"

    def __contains__(self, key: TraceEntry) -> bool:
        return key in self.children
def is_path_leaf_depth_factory(depth: int | float):
    """Build an ``is_path_leaf`` predicate that stops tracing at ``depth``.

    The returned predicate receives a ``(keys, types)`` trace and reports
    whether the number of keys has reached ``depth``. A ``depth`` of ``None``
    never stops.
    """

    def is_path_leaf(trace) -> bool:
        keys, _ = trace
        # the depth of a trace is the number of keys it contains
        return depth is not None and depth <= len(keys)

    return is_path_leaf
def construct_tree(
    tree: PyTree,
    is_leaf: Callable[[Any], bool] | None = None,
    is_path_leaf: Callable[[KeyTypePath], bool] | None = None,
) -> Node:
    """Build a ``Node`` hierarchy mirroring the typed paths of ``tree``.

    The root node carries ``((None, type(tree)), tree)``. For each typed path
    returned by ``tree_leaves_with_typed_path`` one node per ``(key, type)``
    step is created (or reused when it already exists); only the final step
    of a path stores the leaf, intermediate nodes store ``None``.

    Args:
        tree: the pytree to convert.
        is_leaf: optional predicate forwarded to ``tree_leaves_with_typed_path``.
        is_path_leaf: optional predicate forwarded to
            ``tree_leaves_with_typed_path``.

    Returns:
        The root ``Node``.
    """
    paths_and_leaves = tree_leaves_with_typed_path(
        tree,
        is_leaf=is_leaf,
        is_path_leaf=is_path_leaf,
    )

    root = Node(data=((None, type(tree)), tree))

    for (keys, types), leaf in paths_and_leaves:
        final_step = len(keys) - 1
        cursor = root

        for step, entry in enumerate(zip(keys, types)):
            if entry not in cursor:
                # new path: only the last step of the path holds the leaf
                value = leaf if step == final_step else None
                cursor.add_child(Node(data=(entry, value)))
            # descend into the (existing or freshly added) child
            cursor = cursor.children[entry]

    return root