💼 Masking API

pytreeclass.is_nondiff(value)

Returns True for values of non-inexact types (e.g. int, integer arrays) and False for inexact (floating-point or complex) values.

Parameters:

value (Any) – A value to check.

Return type:

bool

Note

  • is_nondiff() uses single dispatch to support custom types. To define a custom behavior for a certain type, use is_nondiff.def_type(type, func).
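The single-dispatch mechanism can be illustrated with the standard library alone. The sketch below is not pytreeclass's code: it uses `functools.singledispatch` and its `register` method rather than `is_nondiff.def_type`, and the names `sketch_is_nondiff` and `Temperature` are hypothetical. The dispatch rules are an assumption chosen to match the documented behavior (inexact values are differentiable, everything else is not).

```python
import functools
from dataclasses import dataclass

import numpy as np


@functools.singledispatch
def sketch_is_nondiff(value) -> bool:
    # Fallback for any type without a registered rule: non-differentiable.
    return True


@sketch_is_nondiff.register(float)
@sketch_is_nondiff.register(complex)
def _(value) -> bool:
    # Python floats and complex numbers are inexact, hence differentiable.
    return False


@sketch_is_nondiff.register(np.ndarray)
def _(value) -> bool:
    # Arrays are differentiable only when their dtype is float or complex.
    return not np.issubdtype(value.dtype, np.inexact)


@dataclass
class Temperature:
    """Hypothetical user-defined type."""

    kelvin: float


@sketch_is_nondiff.register(Temperature)
def _(value) -> bool:
    # Custom rule for the user type: defer to its wrapped field.
    return sketch_is_nondiff(value.kelvin)
```

Registering a rule per type keeps the default conservative: any type nobody has described is treated as non-differentiable.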

Example

>>> import pytreeclass as tc
>>> import jax.numpy as jnp
>>> tc.is_nondiff(jnp.array(1))  # int array is non-diff type
True
>>> tc.is_nondiff(jnp.array(1.))  # float array is diff type
False
>>> tc.is_nondiff(1)  # int is non-diff type
True
>>> tc.is_nondiff(1.)  # float is diff type
False

Note

This function is meant to be used with jax.tree_map to build a boolean mask of the non-differentiable nodes in a tree. The mask can then be used to freeze those nodes before the tree is passed to a JAX transformation.
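As a rough illustration of the mask-building step, the sketch below maps a predicate over a tree with `jax.tree_util.tree_map` and uses the resulting boolean pytree to pick out the differentiable leaves. `nondiff` is a simplified, hypothetical stand-in for `tc.is_nondiff` written with NumPy dtype checks, so the example does not depend on pytreeclass; the tree contents are made up.

```python
import jax
import numpy as np


def nondiff(x) -> bool:
    # Hypothetical stand-in for tc.is_nondiff: True unless x has a float/complex dtype.
    return not np.issubdtype(np.asarray(x).dtype, np.inexact)


tree = {"weight": np.array([1.0, 2.0]), "steps": 10, "name": "layer"}

# One boolean per leaf, with the same structure as `tree`.
mask = jax.tree_util.tree_map(nondiff, tree)

# Leaves flagged False are the differentiable ones.
diff_leaves = [
    leaf
    for leaf, is_static in zip(
        jax.tree_util.tree_leaves(tree), jax.tree_util.tree_leaves(mask)
    )
    if not is_static
]
```

Because `mask` has the same structure as `tree`, a boolean pytree of this kind is also what `tree_mask` accepts as its `mask` argument.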

pytreeclass.freeze(value)

Freeze a value to prevent it from being updated by function transformations.

Parameters:

value (TypeVar(T)) – A value to freeze.

Return type:

_FrozenBase[TypeVar(T)]

Note

  • freeze() is idempotent, i.e. freeze(freeze(x)) == freeze(x).
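One way a wrapper can yield no leaves is to register it as a pytree node with zero children and keep the wrapped value as static auxiliary data. The minimal sketch below assumes that design; `Frozen`, `freeze` and `is_frozen` here are hypothetical re-implementations for illustration, not pytreeclass's `_FrozenBase`.

```python
import jax


class Frozen:
    """Hypothetical minimal frozen wrapper (not pytreeclass's _FrozenBase)."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return f"#{self.value!r}"

    def __eq__(self, other):
        return isinstance(other, Frozen) and self.value == other.value

    def __hash__(self):
        return hash(self.value)


# Zero children: the wrapped value is stored as static auxiliary data,
# so flattening a Frozen node yields no leaves.
jax.tree_util.register_pytree_node(
    Frozen,
    lambda frozen: ((), frozen.value),
    lambda value, _children: Frozen(value),
)


def freeze(value):
    # Idempotent: an already-frozen value is returned as-is.
    return value if isinstance(value, Frozen) else Frozen(value)


def is_frozen(value) -> bool:
    return isinstance(value, Frozen)
```

With this design, `jax.tree_util.tree_map` never calls its function on the frozen value, yet the wrapper is rebuilt unchanged when the tree is unflattened.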

Example

>>> import jax
>>> import pytreeclass as tc
>>> import jax.tree_util as jtu
>>> # Usage with `jax.tree_util.tree_leaves`
>>> # no leaves for a wrapped value
>>> jtu.tree_leaves(tc.freeze(2.))
[]
>>> # retrieve the frozen wrapper value using `is_leaf=tc.is_frozen`
>>> jtu.tree_leaves(tc.freeze(2.), is_leaf=tc.is_frozen)
[#2.0]
>>> # Usage with `jax.tree_util.tree_map`
>>> a = [1, 2, 3]
>>> a[1] = tc.freeze(a[1])
>>> jtu.tree_map(lambda x: x + 100, a)
[101, #2, 103]

pytreeclass.unfreeze(value)

Unfreeze a value wrapped by freeze(); any other value is returned unchanged.

Parameters:

value (TypeVar(T)) – A value to unfreeze.

Return type:

TypeVar(T)

Note

  • Use is_leaf=tc.is_frozen with tree_map to unfreeze a tree; frozen wrappers yield no leaves, so tree_map would otherwise never pass them to unfreeze().

Example

>>> import pytreeclass as tc
>>> import jax
>>> frozen_value = tc.freeze(1)
>>> tc.unfreeze(frozen_value)
1
>>> # usage with `jax.tree_map`
>>> frozen_tree = jax.tree_map(tc.freeze, {"a": 1, "b": 2})
>>> unfrozen_tree = jax.tree_map(tc.unfreeze, frozen_tree, is_leaf=tc.is_frozen)
>>> unfrozen_tree
{'a': 1, 'b': 2}

pytreeclass.is_frozen(value)

Returns True if the value is a frozen wrapper.

Return type:

bool

pytreeclass.tree_mask(tree, mask=<function is_nondiff>, *, is_leaf=None)

Mask leaves of a pytree based on mask boolean pytree or callable.

Parameters:
  • tree (T) – A pytree of values.

  • mask (MaskType) – A pytree of boolean values or a callable that accepts a leaf and returns a boolean. If a leaf is marked True, either in the mask pytree or by the callable, the leaf is wrapped with a wrapper that yields no leaves when tree_flatten is called on it; otherwise it is unchanged. Defaults to is_nondiff(), which returns True for non-differentiable nodes.

  • is_leaf (Callable[[Any], bool] | None) – A callable that accepts a value and returns a boolean. If provided, it is used to determine whether a value is a leaf. For example, is_leaf=lambda x: isinstance(x, list) will treat lists as leaves and will not recurse into them.

Note

  • Masked leaves are wrapped with a wrapper that yields no leaves when tree_flatten is called on it.

  • Masking is equivalent to applying freeze() to the masked leaves.

    >>> import pytreeclass as tc
    >>> import jax
    >>> tree = [1, 2, {"a": 3, "b": 4.}]
    >>> # freeze every non-differentiable leaf, as tree_mask does by default
    >>> def mask_if_nondiff(x):
    ...     return tc.freeze(x) if tc.is_nondiff(x) else x
    >>> masked_tree = jax.tree_map(mask_if_nondiff, tree)
    
  • Mask a tree that contains non-differentiable nodes before passing it to a JAX transformation.
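Since masking is equivalent to applying freeze() to the selected leaves, tree_mask and tree_unmask can be approximated by `tree_map` calls once the mask is resolved to a boolean pytree. The sketch below assumes that reading and is not the library's implementation: `Frozen` is a hypothetical zero-leaf wrapper, and `is_nondiff` is a simplified NumPy-based stand-in that does not handle JAX tracers.

```python
import jax
import jax.tree_util as jtu
import numpy as np


class Frozen:
    """Hypothetical zero-leaf wrapper standing in for the frozen type."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return f"#{self.value!r}"


# No children; the wrapped value rides along as static aux data.
jtu.register_pytree_node(Frozen, lambda f: ((), f.value), lambda v, _: Frozen(v))


def is_frozen(x) -> bool:
    return isinstance(x, Frozen)


def is_nondiff(x) -> bool:
    # Stand-in predicate: True unless x has a float/complex dtype.
    return not np.issubdtype(np.asarray(x).dtype, np.inexact)


def tree_mask(tree, mask=is_nondiff, *, is_leaf=None):
    # A callable mask is evaluated per leaf to build a boolean pytree.
    if callable(mask):
        mask = jtu.tree_map(mask, tree, is_leaf=is_leaf)
    # Wrap every leaf whose mask entry is True.
    return jtu.tree_map(lambda x, m: Frozen(x) if m else x, tree, mask, is_leaf=is_leaf)


def tree_unmask(tree, mask=lambda _: True):
    # is_leaf=is_frozen makes the zero-leaf wrappers visible to tree_map.
    if callable(mask):
        mask = jtu.tree_map(mask, tree, is_leaf=is_frozen)
    return jtu.tree_map(
        lambda x, m: x.value if is_frozen(x) and m else x, tree, mask, is_leaf=is_frozen
    )
```

Resolving a callable into a boolean pytree first lets both accepted mask forms share one code path; tree_unmask passes `is_leaf=is_frozen` so that the wrappers, which contain no leaves, are visited at all.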

Example

>>> import jax
>>> import pytreeclass as tc
>>> tree = [1, 2, {"a": 3, "b": 4.}]
>>> # mask all non-differentiable nodes by default
>>> masked_tree = tc.tree_mask(tree)
>>> masked_tree
[#1, #2, {'a': #3, 'b': 4.0}]
>>> jax.tree_util.tree_leaves(masked_tree)
[4.0]
>>> tc.tree_unmask(masked_tree)
[1, 2, {'a': 3, 'b': 4.0}]

Example

>>> # pass non-differentiable values to `jax.grad`
>>> import pytreeclass as tc
>>> import jax
>>> @jax.grad
... def square(tree):
...     tree = tc.tree_unmask(tree)
...     return tree[0]**2
>>> tree = (1., 2)  # contains a non-differentiable node
>>> square(tc.tree_mask(tree))
(Array(2., dtype=float32, weak_type=True), #2)

pytreeclass.tree_unmask(tree, mask=<function <lambda>>)

Undo the masking of tree leaves according to mask. Defaults to unmasking all leaves.

Parameters:
  • tree (TypeVar(T)) – A pytree of values.

  • mask (Union[TypeVar(T), Callable[[Any], bool]]) – A pytree of boolean values or a callable that accepts a leaf and returns a boolean. If a leaf is marked True, either in the mask pytree or by the callable, the leaf is unfrozen; otherwise it is unchanged. Defaults to unmasking all nodes.

Example

>>> import jax
>>> import pytreeclass as tc
>>> tree = [1, 2, {"a": 3, "b": 4.}]
>>> # mask all non-differentiable nodes by default
>>> masked_tree = tc.tree_mask(tree)
>>> masked_tree
[#1, #2, {'a': #3, 'b': 4.0}]
>>> jax.tree_util.tree_leaves(masked_tree)
[4.0]
>>> tc.tree_unmask(masked_tree)
[1, 2, {'a': 3, 'b': 4.0}]

Example

>>> # pass non-differentiable values to `jax.grad`
>>> import pytreeclass as tc
>>> import jax
>>> @jax.grad
... def square(tree):
...     tree = tc.tree_unmask(tree)
...     return tree[0]**2
>>> tree = (1., 2)  # contains a non-differentiable node
>>> square(tc.tree_mask(tree))
(Array(2., dtype=float32, weak_type=True), #2)

Note

  • Unmasking is equivalent to applying unfreeze() on the masked leaves.

    >>> import pytreeclass as tc
    >>> import jax
    >>> tree = [1, 2, {"a": 3, "b": 4.}]
    >>> masked_tree = tc.tree_mask(tree)
    >>> # unmask all nodes
    >>> unmasked_tree = jax.tree_map(tc.unfreeze, masked_tree, is_leaf=tc.is_frozen)
    >>> unmasked_tree
    [1, 2, {'a': 3, 'b': 4.0}]