# Copyright 2023 pytreeclass authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities to work with non-inexact type tree leaves across function transformations."""
from __future__ import annotations
import functools as ft
import hashlib
from typing import Any, Callable, Generic, NamedTuple, TypeVar, Union
from pytreeclass._src.backend import arraylib, treelib
from pytreeclass._src.tree_pprint import tree_repr, tree_str, tree_summary
from pytreeclass._src.tree_util import is_tree_equal, tree_copy, tree_hash
# generic placeholder for the wrapped value / pytree type used in annotations below
T = TypeVar("T")
# a mask is either a pytree (annotated with the same type variable as the tree)
# or a predicate that receives a leaf and returns a boolean
MaskType = Union[T, Callable[[Any], bool]]
class _FrozenError(NamedTuple):
opname: str
def __call__(self, *a, **k):
raise NotImplementedError(
f"Cannot apply `{self.opname}` operation to a frozen object "
f"{', '.join(map(str, a))} "
f"{', '.join(k + '=' + str(v) for k, v in k.items())}.\n"
"Unfreeze the object first by unmasking the frozen mask:\n"
"Example:\n"
">>> import jax\n"
">>> import pytreeclass as tc\n"
">>> tree = tc.tree_unmask(tree)"
)
class _FrozenBase(Generic[T]):
    """Immutable wrapper that hides a pytree node from function transformations.

    The objective of this class is to wrap a pytree node with a custom wrapper
    that yields no leaves when flattened. This is useful to avoid updating
    the node by effectively *hiding it* from function transformations that
    operate on flattened pytrees.

    The wrapped value is stored in ``__wrapped__``. Attribute assignment and
    deletion raise ``AttributeError``; arithmetic, bitwise, unary operators and
    calling the wrapper raise ``NotImplementedError`` through ``_FrozenError``.
    Every subclass is registered with ``treelib.register_static`` on creation.
    """

    __slots__ = ["__wrapped__", "__weakref__"]
    __wrapped__: T

    def __init__(self, node: T) -> None:
        # go through ``object.__setattr__`` because this class blocks
        # regular attribute assignment in ``__setattr__`` below
        object.__setattr__(self, "__wrapped__", node)

    def __setattr__(self, _, __) -> None:
        raise AttributeError("Cannot assign to frozen instance.")

    def __delattr__(self, _: str) -> None:
        raise AttributeError("Cannot delete from frozen instance.")

    def __repr__(self) -> str:
        # frozen values are displayed with a leading "#" marker
        return "#" + tree_repr(self.__wrapped__)

    def __str__(self) -> str:
        return "#" + tree_str(self.__wrapped__)

    def __copy__(self) -> _FrozenBase[T]:
        # copy the wrapped tree and re-wrap it with the same wrapper type
        return type(self)(tree_copy(self.__wrapped__))

    def __init_subclass__(cls, **k) -> None:
        # register subclasses as an empty pytree node
        super().__init_subclass__(**k)
        # register with the proper backend
        treelib.register_static(cls)

    # raise helpful error message when trying to interact with frozen object
    __add__ = __radd__ = __iadd__ = _FrozenError("+")
    __sub__ = __rsub__ = __isub__ = _FrozenError("-")
    __mul__ = __rmul__ = __imul__ = _FrozenError("*")
    __matmul__ = __rmatmul__ = __imatmul__ = _FrozenError("@")
    __truediv__ = __rtruediv__ = __itruediv__ = _FrozenError("/")
    __floordiv__ = __rfloordiv__ = __ifloordiv__ = _FrozenError("//")
    __mod__ = __rmod__ = __imod__ = _FrozenError("%")
    __pow__ = __rpow__ = __ipow__ = _FrozenError("**")
    __lshift__ = __rlshift__ = __ilshift__ = _FrozenError("<<")
    __rshift__ = __rrshift__ = __irshift__ = _FrozenError(">>")
    __and__ = __rand__ = __iand__ = _FrozenError("and")
    # name the operation so the error message does not show an empty opname
    __xor__ = __rxor__ = __ixor__ = _FrozenError("xor")
    __or__ = __ror__ = __ior__ = _FrozenError("or")
    __neg__ = __pos__ = __abs__ = __invert__ = _FrozenError("unary operation")
    __call__ = _FrozenError("__call__")
@tree_summary.def_type(_FrozenBase)
def _(node) -> str:
    # describe the wrapped value via the summary dispatcher, then prefix the
    # entry with "#", the same marker used by the wrapper's repr/str
    inner_entry = tree_summary.type_dispatcher(node.__wrapped__)
    return f"#{inner_entry}"
class _FrozenHashable(_FrozenBase):
    """Frozen wrapper whose hash and equality delegate to the tree utilities.

    Hashing uses ``tree_hash`` on the wrapped value and equality uses
    ``is_tree_equal`` between the two wrapped values.
    """

    def __hash__(self) -> int:
        wrapped = self.__wrapped__
        return tree_hash(wrapped)

    def __eq__(self, rhs: Any) -> bool | arraylib.ndarray:
        # anything that is not a `_FrozenHashable` compares unequal; otherwise
        # the result is whatever `is_tree_equal` yields for the wrapped values
        return isinstance(rhs, _FrozenHashable) and is_tree_equal(
            self.__wrapped__, rhs.__wrapped__
        )
class _FrozenArray(_FrozenBase):
    """Frozen wrapper for arrays with content-based hash and equality.

    Wrap arrays with a custom wrapper that implements hash and equality
    using the wrapped array's bytes representation and the sha256 hash
    function. This is useful to select some array to hold without updating it
    in the process of training a model.
    """

    def __hash__(self) -> int:
        # named `buffer` (not `bytes`) to avoid shadowing the builtin
        buffer = arraylib.tobytes(self.__wrapped__)
        return int(hashlib.sha256(buffer).hexdigest(), 16)

    def __eq__(self, other) -> bool | arraylib.ndarray:
        if not isinstance(other, _FrozenArray):
            return False
        lhs, rhs = self.__wrapped__, other.__wrapped__
        # fast path to avoid calling `all` on large arrays
        if arraylib.shape(lhs) != arraylib.shape(rhs):
            return False
        if arraylib.dtype(lhs) != arraylib.dtype(rhs):
            return False
        # NOTE(review): `arraylib.all` presumably returns a backend scalar
        # rather than a Python bool -- confirm against the array backend
        return arraylib.all(lhs == rhs)
def freeze(value: T) -> _FrozenBase[T]:
    """Freeze a value to avoid updating it through function transformations.

    Args:
        value: A value to freeze.

    Returns:
        The value wrapped in a frozen wrapper chosen by single dispatch on the
        type of ``value``.

    Note:
        - :func:`.freeze` is idempotent, i.e. ``freeze(freeze(x)) == freeze(x)``.

    Example:
        >>> import jax
        >>> import pytreeclass as tc
        >>> import jax.tree_util as jtu
        >>> # Usage with `jax.tree_util.tree_leaves`
        >>> # no leaves for a wrapped value
        >>> jtu.tree_leaves(tc.freeze(2.))
        []

        >>> # retrieve the frozen wrapper value using `is_leaf=tc.is_frozen`
        >>> jtu.tree_leaves(tc.freeze(2.), is_leaf=tc.is_frozen)
        [#2.0]

        >>> # Usage with `jax.tree_util.tree_map`
        >>> a= [1,2,3]
        >>> a[1] = tc.freeze(a[1])
        >>> jtu.tree_map(lambda x:x+100, a)
        [101, #2, 103]
    """
    # dispatching is used to customize the type of the wrapper based on the type
    # of the value. For instance, hashable values don't need custom hash and
    # equality implementations, so they are wrapped with a simpler wrapper.
    # this approach avoids type logic in the wrapper equality and hash methods,
    # thus effectively improving performance of the wrapper.
    return freeze.type_dispatcher(value)


# fallback for types without a dedicated handler: wrap with `_FrozenHashable`
freeze.type_dispatcher = ft.singledispatch(lambda x: _FrozenHashable(x))
# expose the dispatcher's `register` so handlers can be added per type
freeze.def_type = freeze.type_dispatcher.register
@freeze.def_type(arraylib.ndarray)
def _(value: T) -> _FrozenArray[T]:
    # arrays get the dedicated array wrapper, which hashes the array's bytes
    # and compares shape, dtype and contents for equality
    frozen_array = _FrozenArray(value)
    return frozen_array
@freeze.def_type(_FrozenBase)
def _(value: _FrozenBase[T]) -> _FrozenBase[T]:
    # freezing is idempotent: an already-frozen wrapper is handed back as-is
    # instead of being wrapped a second time, so freeze(freeze(x)) == freeze(x)
    # and no nested unwrapping is ever needed.
    return value
def is_frozen(value: Any) -> bool:
    """Return True if ``value`` is a frozen wrapper (a ``_FrozenBase`` instance)."""
    return isinstance(value, _FrozenBase)
def unfreeze(value: T) -> T:
    """Unfreeze :func:`.freeze` value, otherwise return the value itself.

    Args:
        value: A value to unfreeze.

    Returns:
        The wrapped value if ``value`` is a frozen wrapper, otherwise ``value``
        unchanged.

    Note:
        - use ``is_leaf=tc.is_frozen`` with ``tree_map`` to unfreeze a tree.

    Example:
        >>> import pytreeclass as tc
        >>> import jax
        >>> frozen_value = tc.freeze(1)
        >>> tc.unfreeze(frozen_value)
        1
        >>> # usage with `jax.tree_map`
        >>> frozen_tree = jax.tree_map(tc.freeze, {"a": 1, "b": 2})
        >>> unfrozen_tree = jax.tree_map(tc.unfreeze, frozen_tree, is_leaf=tc.is_frozen)
        >>> unfrozen_tree
        {'a': 1, 'b': 2}
    """
    return unfreeze.type_dispatcher(value)


# fallback for types without a dedicated handler: return the value untouched
unfreeze.type_dispatcher = ft.singledispatch(lambda x: x)
# expose the dispatcher's `register` so handlers can be added per type
unfreeze.def_type = unfreeze.type_dispatcher.register
@unfreeze.def_type(_FrozenBase)
def _(value: _FrozenBase[T]) -> T:
    # unwrap: return the node stored inside the frozen wrapper
    return value.__wrapped__
def is_nondiff(value: Any) -> bool:
    """Returns True for non-inexact types, False otherwise.

    Args:
        value: A value to check.

    Note:
        - :func:`.is_nondiff` uses single dispatch to support custom types. To define
          a custom behavior for a certain type, use ``is_nondiff.def_type(type, func)``.

    Example:
        >>> import pytreeclass as tc
        >>> import jax.numpy as jnp
        >>> tc.is_nondiff(jnp.array(1))  # int array is non-diff type
        True
        >>> tc.is_nondiff(jnp.array(1.))  # float array is diff type
        False
        >>> tc.is_nondiff(1)  # int is non-diff type
        True
        >>> tc.is_nondiff(1.)  # float is diff type
        False

    Note:
        This function is meant to be used with ``jax.tree_map`` to
        create a mask for non-differentiable nodes in a tree, that can be used
        to freeze the non-differentiable nodes before passing the tree to a
        ``jax`` transformation.
    """
    return is_nondiff.type_dispatcher(value)


# fallback for types without a dedicated handler: treated as non-differentiable
is_nondiff.type_dispatcher = ft.singledispatch(lambda x: True)
# expose the dispatcher's `register` so handlers can be added per type
is_nondiff.def_type = is_nondiff.type_dispatcher.register
@is_nondiff.def_type(arraylib.ndarray)
def _(value: arraylib.ndarray) -> bool:
    # an array is non-differentiable exactly when it is not of an inexact type
    return not arraylib.is_inexact(value)
@is_nondiff.def_type(float)
@is_nondiff.def_type(complex)
def _(_: float | complex) -> bool:
    # python floats and complex numbers are inexact, hence differentiable;
    # the argument itself is irrelevant, only its type selects this handler
    return False
def _tree_mask_map(
    tree: T,
    mask: MaskType,
    func: type | Callable[[Any], Any],
    *,
    is_leaf: Callable[[Any], bool] | None = None,
):
    """Apply ``func`` to the leaves of ``tree`` selected by ``mask``.

    Args:
        tree: A pytree of values.
        mask: Either a pytree with the same structure and type as ``tree``
            holding truthy/falsy leaves, or a callable that accepts a leaf and
            returns a boolean.
        func: Callable applied to every selected leaf; unselected leaves are
            returned unchanged.
        is_leaf: Optional predicate forwarded to ``treelib.tree_flatten`` and
            ``treelib.tree_map`` for both ``tree`` and ``mask``.

    Returns:
        The result of ``treelib.tree_map`` over ``tree`` with ``func`` applied
        to the selected leaves.

    Raises:
        ValueError: If ``mask`` is neither a structurally matching pytree nor
            a callable.
    """
    # apply func to leaves satisfying mask pytree/condition
    _, lhsdef = treelib.tree_flatten(tree, is_leaf=is_leaf)
    _, rhsdef = treelib.tree_flatten(mask, is_leaf=is_leaf)

    # the structural check runs first so that a mask pytree that also happens
    # to be callable is still interpreted as a tree of boolean values
    if (lhsdef == rhsdef) and (type(mask) is type(tree)):
        # a tree with the same structure as tree with boolean values
        # and also a callable.
        def map_func(x, y):
            return func(x) if y else x

        return treelib.tree_map(map_func, tree, mask, is_leaf=is_leaf)

    if callable(mask):
        # a callable that accepts a leaf and returns a boolean
        # but *not* a tree with the same structure as tree with boolean values.
        def map_func(x):
            return func(x) if mask(x) else x

        return treelib.tree_map(map_func, tree, is_leaf=is_leaf)

    raise ValueError(
        "`mask` must be a callable that accepts a leaf and returns a boolean "
        "or a tree with the same structure as tree with boolean values."
        f" Got {mask=} and {tree=}."
    )
def tree_mask(
    tree: T,
    mask: MaskType = is_nondiff,
    *,
    is_leaf: Callable[[Any], bool] | None = None,
) -> T:
    """Mask leaves of a pytree based on ``mask`` boolean pytree or callable.

    Args:
        tree: A pytree of values.
        mask: A pytree of boolean values or a callable that accepts a leaf and
            returns a boolean. If a leaf is ``True`` either in the mask or the
            callable, the leaf is wrapped with a wrapper that yields no
            leaves when ``tree_flatten`` is called on it, otherwise
            it is unchanged. defaults to :func:`.is_nondiff` which returns true for
            non-differentiable nodes.
        is_leaf: A callable that accepts a leaf and returns a boolean. If
            provided, it is used to determine if a value is a leaf. for example,
            ``is_leaf=lambda x: isinstance(x, list)`` will treat lists as leaves
            and will not recurse into them.

    Returns:
        A pytree like ``tree`` in which every selected leaf has been passed
        through :func:`.freeze`.

    Note:
        - Masked leaves are wrapped with a wrapper that yields no leaves when
          ``tree_flatten`` is called on it.

        - Masking is equivalent to applying :func:`.freeze` to the masked leaves.

            >>> import pytreeclass as tc
            >>> import jax
            >>> tree = [1, 2, {"a": 3, "b": 4.}]
            >>> # mask all non-differentiable nodes by default
            >>> def mask_if_nondiff(x):
            ...     return tc.freeze(x) if tc.is_nondiff(x) else x
            >>> masked_tree = jax.tree_map(mask_if_nondiff, tree)

        - Use masking on tree containing non-differentiable nodes before passing
          the tree to a ``jax`` transformation.

    Example:
        >>> import pytreeclass as tc
        >>> tree = [1, 2, {"a": 3, "b": 4.}]
        >>> # mask all non-differentiable nodes by default
        >>> masked_tree = tc.tree_mask(tree)
        >>> masked_tree
        [#1, #2, {'a': #3, 'b': 4.0}]
        >>> jax.tree_util.tree_leaves(masked_tree)
        [4.0]
        >>> tc.tree_unmask(masked_tree)
        [1, 2, {'a': 3, 'b': 4.0}]

    Example:
        >>> # pass non-differentiable values to `jax.grad`
        >>> import pytreeclass as tc
        >>> import jax
        >>> @jax.grad
        ... def square(tree):
        ...     tree = tc.tree_unmask(tree)
        ...     return tree[0]**2
        >>> tree = (1., 2)  # contains a non-differentiable node
        >>> square(tc.tree_mask(tree))
        (Array(2., dtype=float32, weak_type=True), #2)
    """
    return _tree_mask_map(tree, mask=mask, func=freeze, is_leaf=is_leaf)
def tree_unmask(tree: T, mask: MaskType = lambda _: True) -> T:
    """Undo the masking of tree leaves according to ``mask``.

    Defaults to unmasking all leaves.

    Args:
        tree: A pytree of values.
        mask: A pytree of boolean values or a callable that accepts a leaf and
            returns a boolean. If a leaf is True either in the mask or the
            callable, the leaf is unfrozen, otherwise it is unchanged. defaults
            to unmasking all nodes.

    Returns:
        A pytree like ``tree`` in which every selected leaf has been passed
        through :func:`.unfreeze`. Frozen wrappers are treated as leaves
        (``is_leaf=is_frozen``) while traversing.

    Example:
        >>> import jax
        >>> import pytreeclass as tc
        >>> tree = [1, 2, {"a": 3, "b": 4.}]
        >>> # mask all non-differentiable nodes by default
        >>> masked_tree = tc.tree_mask(tree)
        >>> masked_tree
        [#1, #2, {'a': #3, 'b': 4.0}]
        >>> jax.tree_util.tree_leaves(masked_tree)
        [4.0]
        >>> tc.tree_unmask(masked_tree)
        [1, 2, {'a': 3, 'b': 4.0}]

    Example:
        >>> # pass non-differentiable values to `jax.grad`
        >>> import pytreeclass as tc
        >>> import jax
        >>> @jax.grad
        ... def square(tree):
        ...     tree = tc.tree_unmask(tree)
        ...     return tree[0]**2
        >>> tree = (1., 2)  # contains a non-differentiable node
        >>> square(tc.tree_mask(tree))
        (Array(2., dtype=float32, weak_type=True), #2)

    Note:
        - Unmasking is equivalent to applying :func:`.unfreeze` on the masked leaves.

            >>> import pytreeclass as tc
            >>> import jax
            >>> tree = [1, 2, {"a": 3, "b": 4.}]
            >>> # unmask all nodes
            >>> tree = jax.tree_map(tc.unfreeze, tree, is_leaf=tc.is_frozen)
    """
    return _tree_mask_map(tree, mask=mask, func=unfreeze, is_leaf=is_frozen)