# Source code for pytreeclass._src.tree_mask

# Copyright 2023 pytreeclass authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities to work with non-jax type tree leaves across JAX transformations."""

from __future__ import annotations

import functools as ft
import hashlib
from typing import Any, Callable, Generic, NamedTuple, TypeVar, Union

import jax
import jax.tree_util as jtu
import numpy as np

from pytreeclass._src.tree_pprint import tree_repr, tree_str, tree_summary
from pytreeclass._src.tree_util import IsLeafType, is_tree_equal, tree_copy, tree_hash

# Generic type variable for the value being wrapped / masked.
T = TypeVar("T")
# A mask is either a pytree of boolean values (mirroring the target tree) or a
# predicate that is called on each leaf and returns a boolean.
MaskType = Union[T, Callable[[Any], bool]]


class _FrozenError(NamedTuple):
    """Callable placeholder that raises a helpful error for frozen values.

    Instances are assigned as operator dunders on the frozen wrapper class;
    ``opname`` is the operator name/symbol shown in the error message.
    """

    opname: str

    def __call__(self, *args, **kwargs):
        # Echo whatever arguments the operator call passed in, to help the
        # user locate the offending expression.
        shown_args = ", ".join(str(arg) for arg in args)
        shown_kwargs = ", ".join(key + "=" + str(val) for key, val in kwargs.items())
        raise NotImplementedError(
            f"Cannot apply `{self.opname}` operation to a frozen object "
            f"{shown_args} "
            f"{shown_kwargs}.\n"
            "Unfreeze the object first by unmasking the frozen mask:\n"
            "Example:\n"
            ">>> import jax\n"
            ">>> import pytreeclass as tc\n"
            ">>> tree = tc.tree_unmask(tree)"
        )


class _FrozenBase(Generic[T]):
    """Immutable wrapper that hides its value from ``jax`` tree flattening.

    Every subclass is registered as a pytree node with no children whose
    auxiliary data is the wrapper instance itself, so flattening yields no
    leaves and unflattening returns the same wrapper. Arithmetic, bitwise,
    unary and call operators raise an informative ``NotImplementedError``
    instead of operating on the wrapper.
    """

    __slots__ = ["__wrapped__", "__weakref__"]
    __wrapped__: T

    def __init__(self, node: T) -> None:
        # bypass the blocking ``__setattr__`` defined below
        object.__setattr__(self, "__wrapped__", node)

    def __setattr__(self, _, __) -> None:
        raise AttributeError("Cannot assign to frozen instance.")

    def __delattr__(self, _: str) -> None:
        raise AttributeError("Cannot delete from frozen instance.")

    def __repr__(self) -> str:
        return "#" + tree_repr(self.__wrapped__)

    def __str__(self) -> str:
        return "#" + tree_str(self.__wrapped__)

    def __copy__(self) -> _FrozenBase[T]:
        return type(self)(tree_copy(self.__wrapped__))

    def __init_subclass__(klass, **k) -> None:
        super().__init_subclass__(**k)
        # Register the subclass as a pytree node with no children. The wrapper
        # instance itself is the aux data and is handed back on unflatten.
        # NOTE(review): jax compares/hashes aux data, which is presumably why
        # concrete subclasses define ``__hash__`` and ``__eq__``.
        jtu.register_pytree_node(
            nodetype=klass,
            flatten_func=lambda tree: ((), tree),
            unflatten_func=lambda aux, _: aux,
        )

    # raise helpful error message when trying to interact with frozen object
    __add__ = __radd__ = __iadd__ = _FrozenError("+")
    __sub__ = __rsub__ = __isub__ = _FrozenError("-")
    __mul__ = __rmul__ = __imul__ = _FrozenError("*")
    __matmul__ = __rmatmul__ = __imatmul__ = _FrozenError("@")
    __truediv__ = __rtruediv__ = __itruediv__ = _FrozenError("/")
    __floordiv__ = __rfloordiv__ = __ifloordiv__ = _FrozenError("//")
    __mod__ = __rmod__ = __imod__ = _FrozenError("%")
    __pow__ = __rpow__ = __ipow__ = _FrozenError("**")
    __lshift__ = __rlshift__ = __ilshift__ = _FrozenError("<<")
    __rshift__ = __rrshift__ = __irshift__ = _FrozenError(">>")
    # bitwise operators: report the actual operator symbol
    __and__ = __rand__ = __iand__ = _FrozenError("&")
    __xor__ = __rxor__ = __ixor__ = _FrozenError("^")
    __or__ = __ror__ = __ior__ = _FrozenError("|")
    __neg__ = __pos__ = __abs__ = __invert__ = _FrozenError("unary operation")
    __call__ = _FrozenError("__call__")


@tree_summary.def_type(_FrozenBase)
def _(node) -> str:
    # A frozen node is summarized as its wrapped value's summary prefixed
    # with ``#``, the same marker used by the frozen wrapper's repr/str.
    wrapped_summary = tree_summary.type_dispatcher(node.__wrapped__)
    return f"#{wrapped_summary}"


class _FrozenHashable(_FrozenBase):
    """Frozen wrapper for non-array values.

    Hashing and equality delegate to ``tree_hash`` and ``is_tree_equal`` on
    the wrapped value; comparison with anything that is not a
    ``_FrozenHashable`` is ``False``.
    """

    def __hash__(self) -> int:
        wrapped = self.__wrapped__
        return tree_hash(wrapped)

    def __eq__(self, rhs: Any) -> bool | jax.Array:
        if isinstance(rhs, _FrozenHashable):
            return is_tree_equal(self.__wrapped__, rhs.__wrapped__)
        return False


class _FrozenArray(_FrozenBase):
    """Frozen wrapper for ``numpy`` / ``jax`` arrays.

    Arrays are unhashable and their ``==`` is elementwise, so this wrapper
    provides a content-based hash and a scalar equality: two frozen arrays
    are equal when their shape, dtype and all elements match.
    """

    def __hash__(self) -> int:
        array = np.asarray(self.__wrapped__)
        if np.issubdtype(array.dtype, np.inexact):
            # ``-0.0 == 0.0`` although their bytes differ; adding ``+0.0``
            # maps ``-0.0`` to ``0.0`` so arrays that compare equal (see
            # ``__eq__``) also hash equally.
            array = array + 0.0
        digest = hashlib.sha256(array.tobytes()).hexdigest()
        return int(digest, 16)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, _FrozenArray):
            return False

        lhs, rhs = self.__wrapped__, other.__wrapped__
        # check metadata first: elementwise ``==`` on mismatched shapes would
        # broadcast or fail
        if lhs.shape != rhs.shape or lhs.dtype != rhs.dtype:
            return False
        # ``np.all`` returns ``np.bool_``; return a plain ``bool``
        return bool(np.all(lhs == rhs))


def freeze(value: T) -> _FrozenBase[T]:
    """Freeze a value to avoid updating it by ``jax`` transformations.

    Args:
        value: A value to freeze.

    Note:
        - :func:`.freeze` is idempotent, i.e. ``freeze(freeze(x)) == freeze(x)``.

    Example:
        >>> import jax
        >>> import pytreeclass as tc
        >>> import jax.tree_util as jtu
        >>> # Usage with `jax.tree_util.tree_leaves`
        >>> # no leaves for a wrapped value
        >>> jtu.tree_leaves(tc.freeze(2.))
        []
        >>> # retrieve the frozen wrapper value using `is_leaf=tc.is_frozen`
        >>> jtu.tree_leaves(tc.freeze(2.), is_leaf=tc.is_frozen)
        [#2.0]
        >>> # Usage with `jax.tree_util.tree_map`
        >>> a = [1, 2, 3]
        >>> a[1] = tc.freeze(a[1])
        >>> jtu.tree_map(lambda x: x + 100, a)
        [101, #2, 103]
    """
    dispatch = freeze.type_dispatcher
    return dispatch(value)


# Dispatch table: by default values are wrapped in ``_FrozenHashable``;
# ``numpy``/``jax`` arrays get ``_FrozenArray``; frozen values pass through.
freeze.type_dispatcher = ft.singledispatch(lambda x: _FrozenHashable(x))
freeze.def_type = freeze.type_dispatcher.register


def _freeze_array(value: T) -> _FrozenArray[T]:
    return _FrozenArray(value)


def _freeze_frozen(value: _FrozenBase[T]) -> _FrozenBase[T]:
    # idempotent freeze: never wrap an already frozen value again
    return value


freeze.def_type(jax.Array, _freeze_array)
freeze.def_type(np.ndarray, _freeze_array)
freeze.def_type(_FrozenBase, _freeze_frozen)
def is_frozen(value: Any) -> bool:
    """Return ``True`` if ``value`` is a frozen wrapper, ``False`` otherwise.

    Frozen wrappers are the objects produced by :func:`.freeze` (or by
    :func:`.tree_mask` on masked leaves).
    """
    frozen = isinstance(value, _FrozenBase)
    return frozen
def unfreeze(value: T) -> T:
    """Unfreeze a :func:`.freeze` value, otherwise return the value itself.

    Args:
        value: A value to unfreeze.

    Note:
        - Use ``is_leaf=tc.is_frozen`` with ``jax.tree_util.tree_map`` to
          unfreeze a tree.

    Example:
        >>> import pytreeclass as tc
        >>> import jax
        >>> frozen_value = tc.freeze(1)
        >>> tc.unfreeze(frozen_value)
        1
        >>> # usage with `jax.tree_util.tree_map`
        >>> frozen_tree = jax.tree_util.tree_map(tc.freeze, {"a": 1, "b": 2})
        >>> unfrozen_tree = jax.tree_util.tree_map(
        ...     tc.unfreeze, frozen_tree, is_leaf=tc.is_frozen
        ... )
        >>> unfrozen_tree
        {'a': 1, 'b': 2}
    """
    dispatch = unfreeze.type_dispatcher
    return dispatch(value)


# Non-frozen values are returned unchanged by default.
unfreeze.type_dispatcher = ft.singledispatch(lambda x: x)
unfreeze.def_type = unfreeze.type_dispatcher.register


def _unfreeze_frozen(value: _FrozenBase[T]) -> T:
    # strip the frozen wrapper and return the wrapped value
    return value.__wrapped__


unfreeze.def_type(_FrozenBase, _unfreeze_frozen)
def is_nondiff(value: Any) -> bool:
    """Returns True for non-inexact types, False otherwise.

    Args:
        value: A value to check.

    Note:
        - :func:`.is_nondiff` uses single dispatch to support custom types. To
          define a custom behavior for a certain type, use
          ``is_nondiff.def_type(type, func)``.
        - ``numpy``/``jax`` arrays and ``numpy`` scalars (e.g. ``np.float32``)
          are classified by their ``dtype``; Python ``float`` and ``complex``
          are differentiable; any other unregistered type is non-differentiable.

    Example:
        >>> import pytreeclass as tc
        >>> import jax.numpy as jnp
        >>> import numpy as np
        >>> tc.is_nondiff(jnp.array(1))  # int array is non-diff type
        True
        >>> tc.is_nondiff(jnp.array(1.))  # float array is diff type
        False
        >>> tc.is_nondiff(np.float32(1.))  # numpy float scalar is diff type
        False
        >>> tc.is_nondiff(1)  # int is non-diff type
        True
        >>> tc.is_nondiff(1.)  # float is diff type
        False

    Note:
        This function is meant to be used with ``jax.tree_util.tree_map`` to
        create a mask for non-differentiable nodes in a tree, that can be used
        to freeze the non-differentiable nodes before passing the tree to a
        ``jax`` transformation.
    """
    return is_nondiff.type_dispatcher(value)


# Fallback: any type not registered below is considered non-differentiable.
is_nondiff.type_dispatcher = ft.singledispatch(lambda x: True)
is_nondiff.def_type = is_nondiff.type_dispatcher.register


@is_nondiff.def_type(np.ndarray)
@is_nondiff.def_type(np.generic)
@is_nondiff.def_type(jax.Array)
def _(value: np.ndarray | np.generic | jax.Array) -> bool:
    # Only inexact (floating/complex) dtypes are differentiable. ``np.generic``
    # covers numpy scalars such as ``np.float32`` which, unlike ``np.float64``,
    # do not subclass ``float`` and would otherwise hit the fallback above.
    return not np.issubdtype(value.dtype, np.inexact)


@is_nondiff.def_type(float)
@is_nondiff.def_type(complex)
def _(_: float | complex) -> bool:
    return False


def _tree_mask_map(
    tree: T,
    mask: MaskType,
    func: type | Callable[[Any], Any],
    *,
    is_leaf: IsLeafType = None,
):
    """Apply ``func`` to the leaves of ``tree`` selected by ``mask``.

    Args:
        tree: The pytree to transform.
        mask: Either a pytree with the same structure as ``tree`` whose
            truthy leaves select the leaves of ``tree`` to transform, or a
            callable that is called on each leaf and returns a boolean.
        func: Applied to every selected leaf (e.g. ``freeze``/``unfreeze``).
        is_leaf: Optional predicate deciding what counts as a leaf; used when
            computing the structure of both ``tree`` and ``mask``.

    Returns:
        A pytree like ``tree`` with ``func`` applied to the selected leaves.

    Raises:
        ValueError: if ``mask`` is neither a matching pytree nor a callable.
    """
    treedef = jtu.tree_structure(tree, is_leaf=is_leaf)
    maskdef = jtu.tree_structure(mask, is_leaf=is_leaf)

    # A single-leaf ``tree`` has structure ``*``, which matches *any* leaf
    # mask, including a callable. In that case call the predicate instead of
    # treating the function object itself as a (always truthy) mask value.
    tree_is_leaf = treedef.num_nodes == 1 and treedef.num_leaves == 1

    if treedef == maskdef and not (tree_is_leaf and callable(mask)):
        # boolean mask pytree mirroring ``tree``
        return jtu.tree_map(
            lambda x, y: func(x) if y else x,
            tree,
            mask,
            is_leaf=is_leaf,
        )

    if callable(mask):
        # predicate evaluated on each leaf
        return jtu.tree_map(
            lambda x: func(x) if mask(x) else x,
            tree,
            is_leaf=is_leaf,
        )

    raise ValueError(
        "`mask` must be a callable that accepts a leaf and returns a boolean "
        "or a tree with the same structure as tree with boolean values."
        f" Got {mask=} and {tree=}."
    )
def tree_mask(tree: T, mask: MaskType = is_nondiff, *, is_leaf: IsLeafType = None):
    """Mask leaves of a pytree based on ``mask`` boolean pytree or callable.

    Args:
        tree: A pytree of values.
        mask: A pytree of boolean values or a callable that accepts a leaf and
            returns a boolean. If a leaf is ``True`` either in the mask or the
            callable, the leaf is wrapped with a wrapper that yields no leaves
            when ``jax.tree_util.tree_flatten`` is called on it, otherwise it
            is unchanged. Defaults to :func:`.is_nondiff` which returns
            ``True`` for non-differentiable nodes.
        is_leaf: A callable that accepts a leaf and returns a boolean. If
            provided, it is used to determine if a value is a leaf. For
            example, ``is_leaf=lambda x: isinstance(x, list)`` will treat
            lists as leaves and will not recurse into them.

    Note:
        - Masked leaves are wrapped with a wrapper that yields no leaves when
          ``jax.tree_util.tree_flatten`` is called on it.
        - Masking is equivalent to applying :func:`.freeze` to the masked
          leaves.

            >>> import pytreeclass as tc
            >>> import jax
            >>> tree = [1, 2, {"a": 3, "b": 4.}]
            >>> # manually mask all non-differentiable nodes
            >>> def mask_if_nondiff(x):
            ...     return tc.freeze(x) if tc.is_nondiff(x) else x
            >>> masked_tree = jax.tree_util.tree_map(mask_if_nondiff, tree)

        - Use masking on a tree containing non-differentiable nodes before
          passing the tree to a ``jax`` transformation.

    Example:
        >>> import jax
        >>> import pytreeclass as tc
        >>> tree = [1, 2, {"a": 3, "b": 4.}]
        >>> # mask all non-differentiable nodes by default
        >>> masked_tree = tc.tree_mask(tree)
        >>> masked_tree
        [#1, #2, {'a': #3, 'b': 4.0}]
        >>> jax.tree_util.tree_leaves(masked_tree)
        [4.0]
        >>> tc.tree_unmask(masked_tree)
        [1, 2, {'a': 3, 'b': 4.0}]

    Example:
        >>> # pass non-differentiable values to `jax.grad`
        >>> import pytreeclass as tc
        >>> import jax
        >>> @jax.grad
        ... def square(tree):
        ...     tree = tc.tree_unmask(tree)
        ...     return tree[0]**2
        >>> tree = (1., 2)  # contains a non-differentiable node
        >>> square(tc.tree_mask(tree))
        (Array(2., dtype=float32, weak_type=True), #2)
    """
    # selected leaves are frozen; the rest pass through untouched
    return _tree_mask_map(tree, mask, freeze, is_leaf=is_leaf)
def tree_unmask(tree: T, mask: MaskType = lambda _: True):
    """Undo the masking of tree leaves according to ``mask``.

    Defaults to unmasking all leaves.

    Args:
        tree: A pytree of values.
        mask: A pytree of boolean values or a callable that accepts a leaf and
            returns a boolean. If a leaf is ``True`` either in the mask or the
            callable, the leaf is unfrozen, otherwise it is unchanged.
            Defaults to unmasking all nodes.

    Example:
        >>> import jax
        >>> import pytreeclass as tc
        >>> tree = [1, 2, {"a": 3, "b": 4.}]
        >>> # mask all non-differentiable nodes by default
        >>> masked_tree = tc.tree_mask(tree)
        >>> masked_tree
        [#1, #2, {'a': #3, 'b': 4.0}]
        >>> jax.tree_util.tree_leaves(masked_tree)
        [4.0]
        >>> tc.tree_unmask(masked_tree)
        [1, 2, {'a': 3, 'b': 4.0}]

    Example:
        >>> # pass non-differentiable values to `jax.grad`
        >>> import pytreeclass as tc
        >>> import jax
        >>> @jax.grad
        ... def square(tree):
        ...     tree = tc.tree_unmask(tree)
        ...     return tree[0]**2
        >>> tree = (1., 2)  # contains a non-differentiable node
        >>> square(tc.tree_mask(tree))
        (Array(2., dtype=float32, weak_type=True), #2)

    Note:
        - Unmasking is equivalent to applying :func:`.unfreeze` on the masked
          leaves.

            >>> import pytreeclass as tc
            >>> import jax
            >>> tree = [1, 2, {"a": 3, "b": 4.}]
            >>> masked_tree = tc.tree_mask(tree)
            >>> # unmask all nodes
            >>> jax.tree_util.tree_map(tc.unfreeze, masked_tree, is_leaf=tc.is_frozen)
            [1, 2, {'a': 3, 'b': 4.0}]
    """
    # frozen wrappers are treated as leaves so ``mask`` can select them
    return _tree_mask_map(tree, mask, unfreeze, is_leaf=is_frozen)