💼 Masking API
- pytreeclass.is_nondiff(value)
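Returns True for values of non-inexact types and False otherwise. Inexact types are floating-point and complex types, so a float array or a Python float is treated as differentiable.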
- Parameters:
  value (Any) – A value to check.
- Return type:
  bool
Note
is_nondiff() uses single dispatch to support custom types. To define a custom behavior for a certain type, use is_nondiff.def_type(type, func).
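To illustrate the single-dispatch pattern this note describes, here is a minimal stdlib sketch built on functools.singledispatch. The names is_nondiff_sketch, def_type and Temperature are hypothetical, and the sketch only covers Python scalars; it is not pytreeclass's implementation, which also inspects array dtypes.

```python
from functools import singledispatch


# Hypothetical stand-in for a dispatch-based predicate; an illustration of
# the pattern only, not pytreeclass's actual is_nondiff.
@singledispatch
def is_nondiff_sketch(value) -> bool:
    # Fallback for unregistered types: treat them as non-differentiable.
    return True


@is_nondiff_sketch.register(float)
@is_nondiff_sketch.register(complex)
def _(value) -> bool:
    # Inexact Python scalars are differentiable.
    return False


def def_type(type_, func):
    # Hypothetical helper in the spirit of is_nondiff.def_type(type, func):
    # it simply registers func as the handler for type_.
    is_nondiff_sketch.register(type_, func)


class Temperature:
    # Hypothetical user-defined type.
    def __init__(self, kelvin: float):
        self.kelvin = kelvin


# Custom rule: treat Temperature values as differentiable.
def_type(Temperature, lambda value: False)
```

With the registration in place, is_nondiff_sketch(Temperature(300.0)) returns False, while unregistered types such as str fall through to the default and return True.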
Example
>>> import pytreeclass as tc
>>> import jax.numpy as jnp
>>> tc.is_nondiff(jnp.array(1))  # int array is non-diff type
True
>>> tc.is_nondiff(jnp.array(1.))  # float array is diff type
False
>>> tc.is_nondiff(1)  # int is non-diff type
True
>>> tc.is_nondiff(1.)  # float is diff type
False
Note
This function is meant to be used with jax.tree_map to create a mask of the non-differentiable nodes in a tree. The mask can then be used to freeze those nodes before passing the tree to a jax transformation.
- pytreeclass.freeze(value)
Freeze a value so that jax transformations do not update it.
- Parameters:
  value (TypeVar(T)) – A value to freeze.
- Return type:
  _FrozenBase[TypeVar(T)]
Note
freeze() is idempotent, i.e. freeze(freeze(x)) == freeze(x).
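One way to obtain this idempotence is to have freeze return an already-frozen value unchanged instead of wrapping it a second time. The stdlib sketch below assumes a hypothetical FrozenSketch class with hypothetical freeze_sketch and unfreeze_sketch helpers; it models only the wrapping behavior, not the pytree registration that makes the real wrapper yield no leaves.

```python
from typing import Any, Generic, TypeVar

T = TypeVar("T")


class FrozenSketch(Generic[T]):
    """Hypothetical wrapper standing in for pytreeclass's _FrozenBase."""

    def __init__(self, value: T) -> None:
        self.value = value

    def __eq__(self, other: Any) -> bool:
        # Two wrappers are equal when they wrap equal values.
        return isinstance(other, FrozenSketch) and self.value == other.value

    def __repr__(self) -> str:
        # Mirror the "#value" display used in this page's examples.
        return f"#{self.value!r}"


def freeze_sketch(value: Any) -> FrozenSketch:
    # Idempotence: an already-frozen value is returned unchanged instead of
    # being wrapped again, so freeze_sketch(freeze_sketch(x)) == freeze_sketch(x).
    if isinstance(value, FrozenSketch):
        return value
    return FrozenSketch(value)


def unfreeze_sketch(value: Any) -> Any:
    # Remove the wrapper if present; otherwise return the value itself.
    return value.value if isinstance(value, FrozenSketch) else value
```

Because freezing never nests, a single unfreeze_sketch call always recovers the original value, however many times freeze_sketch was applied.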
Example
>>> import jax
>>> import pytreeclass as tc
>>> import jax.tree_util as jtu
>>> # Usage with `jax.tree_util.tree_leaves`
>>> # no leaves for a wrapped value
>>> jtu.tree_leaves(tc.freeze(2.))
[]
>>> # retrieve the frozen wrapper value using `is_leaf=tc.is_frozen`
>>> jtu.tree_leaves(tc.freeze(2.), is_leaf=tc.is_frozen)
[#2.0]
>>> # Usage with `jax.tree_util.tree_map`
>>> a = [1, 2, 3]
>>> a[1] = tc.freeze(a[1])
>>> jtu.tree_map(lambda x: x + 100, a)
[101, #2, 103]
- pytreeclass.unfreeze(value)
Unfreeze a value wrapped by freeze(); any other value is returned as is.
- Parameters:
  value (TypeVar(T)) – A value to unfreeze.
- Return type:
  TypeVar(T)
Note
Use is_leaf=tc.is_frozen with jax.tree_map to unfreeze a tree.
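The reason is_leaf is needed: the frozen wrapper contributes no leaves, so a plain tree map never hands it to the mapped function. The stdlib sketch below makes this concrete with a toy tree_map over lists, tuples and dicts. Frozen, is_frozen, unfreeze and tree_map here are hypothetical stand-ins written for illustration, loosely modelled on jax.tree_util.tree_map; they are not the pytreeclass or JAX implementations.

```python
from typing import Any, Callable, Optional


class Frozen:
    """Hypothetical frozen wrapper: a tree node with no children."""

    def __init__(self, value: Any) -> None:
        self.value = value

    def __repr__(self) -> str:
        return f"#{self.value!r}"


def is_frozen(x: Any) -> bool:
    return isinstance(x, Frozen)


def unfreeze(x: Any) -> Any:
    return x.value if is_frozen(x) else x


def tree_map(func: Callable[[Any], Any], tree: Any,
             is_leaf: Optional[Callable[[Any], bool]] = None) -> Any:
    """Toy tree_map over lists, tuples and dicts (illustration only)."""
    # is_leaf is checked first, so it can stop recursion at any node.
    if is_leaf is not None and is_leaf(tree):
        return func(tree)
    if isinstance(tree, Frozen):
        # A frozen wrapper has no leaves, so func never reaches its contents.
        return tree
    if isinstance(tree, list):
        return [tree_map(func, t, is_leaf) for t in tree]
    if isinstance(tree, tuple):
        return tuple(tree_map(func, t, is_leaf) for t in tree)
    if isinstance(tree, dict):
        return {k: tree_map(func, v, is_leaf) for k, v in tree.items()}
    return func(tree)


frozen_tree = {"a": Frozen(1), "b": 2}
# Without is_leaf the wrapper is skipped, so nothing gets unfrozen.
still_frozen = tree_map(unfreeze, frozen_tree)
# With is_leaf=is_frozen the wrapper itself is passed to unfreeze.
unfrozen = tree_map(unfreeze, frozen_tree, is_leaf=is_frozen)
```

In this sketch still_frozen keeps the wrapper under "a", while unfrozen is {'a': 1, 'b': 2}.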
Example
>>> import pytreeclass as tc
>>> import jax
>>> frozen_value = tc.freeze(1)
>>> tc.unfreeze(frozen_value)
1
>>> # usage with `jax.tree_map`
>>> frozen_tree = jax.tree_map(tc.freeze, {"a": 1, "b": 2})
>>> unfrozen_tree = jax.tree_map(tc.unfreeze, frozen_tree, is_leaf=tc.is_frozen)
>>> unfrozen_tree
{'a': 1, 'b': 2}
- pytreeclass.is_frozen(value)
Returns True if the value is a frozen wrapper.
- Return type:
bool
- pytreeclass.tree_mask(tree, mask=<function is_nondiff>, *, is_leaf=None)
Mask leaves of a pytree based on mask, which is either a boolean pytree or a callable.
- Parameters:
  tree (TypeVar(T)) – A pytree of values.
  mask (Union[TypeVar(T), Callable[[Any], bool]]) – A pytree of boolean values or a callable that accepts a leaf and returns a boolean. If a leaf is True either in the mask or according to the callable, the leaf is wrapped with a wrapper that yields no leaves when jax.tree_util.tree_flatten is called on it; otherwise it is unchanged. Defaults to is_nondiff(), which returns True for non-differentiable nodes.
  is_leaf (Optional[Callable[[Any], bool]]) – A callable that accepts a value and returns a boolean. If provided, it is used to determine whether a value is a leaf. For example, is_leaf=lambda x: isinstance(x, list) treats lists as leaves and does not recurse into them.
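To show the two forms mask can take, here is a toy, stdlib-only masking function over lists, tuples and dicts. tree_mask_sketch, Frozen and is_nondiff_like are hypothetical names; is_nondiff_like is a simplified scalar-only rule, and the sketch omits is_leaf as well as the pytree registration that lets the real wrapper hide its value from jax.tree_util.tree_flatten.

```python
from typing import Any, Callable, Union


class Frozen:
    """Hypothetical stand-in for the wrapper produced by freeze()."""

    def __init__(self, value: Any) -> None:
        self.value = value

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, Frozen) and self.value == other.value

    def __repr__(self) -> str:
        return f"#{self.value!r}"


def is_nondiff_like(x: Any) -> bool:
    # Simplified rule for Python scalars only: float and complex are differentiable.
    return not isinstance(x, (float, complex))


def tree_mask_sketch(tree: Any,
                     mask: Union[Any, Callable[[Any], bool]] = is_nondiff_like) -> Any:
    """Toy masking: mask is a callable or a bool tree shaped like `tree`."""
    if isinstance(tree, (list, tuple)):
        # Reuse the callable for every child, or pair each child with its own flag.
        masks = [mask] * len(tree) if callable(mask) else mask
        return type(tree)(tree_mask_sketch(t, m) for t, m in zip(tree, masks))
    if isinstance(tree, dict):
        return {k: tree_mask_sketch(v, mask if callable(mask) else mask[k])
                for k, v in tree.items()}
    # Leaf: wrap it when the flag (or the callable's verdict) is True.
    flag = mask(tree) if callable(mask) else mask
    return Frozen(tree) if flag else tree


tree = [1, 2, {"a": 3, "b": 4.0}]
by_callable = tree_mask_sketch(tree)
by_bool_tree = tree_mask_sketch(tree, [False, True, {"a": False, "b": True}])
```

Here by_callable prints as [#1, #2, {'a': #3, 'b': 4.0}], while the boolean pytree mask freezes only the leaves flagged True, giving [1, #2, {'a': 3, 'b': #4.0}].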
Note
Masked leaves are wrapped with a wrapper that yields no leaves when jax.tree_util.tree_flatten is called on it. Masking is equivalent to applying freeze() to the masked leaves:
>>> import pytreeclass as tc
>>> import jax
>>> tree = [1, 2, {"a": 3, "b": 4.}]
>>> # freeze every non-differentiable leaf, as the default mask does
>>> def mask_if_nondiff(x):
...     return tc.freeze(x) if tc.is_nondiff(x) else x
>>> masked_tree = jax.tree_map(mask_if_nondiff, tree)
Use masking on a tree containing non-differentiable nodes before passing the tree to a jax transformation.
Example
>>> import pytreeclass as tc
>>> import jax
>>> tree = [1, 2, {"a": 3, "b": 4.}]
>>> # mask all non-differentiable nodes by default
>>> masked_tree = tc.tree_mask(tree)
>>> masked_tree
[#1, #2, {'a': #3, 'b': 4.0}]
>>> jax.tree_util.tree_leaves(masked_tree)
[4.0]
>>> tc.tree_unmask(masked_tree)
[1, 2, {'a': 3, 'b': 4.0}]
Example
>>> # pass non-differentiable values to `jax.grad`
>>> import pytreeclass as tc
>>> import jax
>>> @jax.grad
... def square(tree):
...     tree = tc.tree_unmask(tree)
...     return tree[0]**2
>>> tree = (1., 2)  # contains a non-differentiable node
>>> square(tc.tree_mask(tree))
(Array(2., dtype=float32, weak_type=True), #2)
- pytreeclass.tree_unmask(tree, mask=<function <lambda>>)
Undo the masking of tree leaves according to mask. Defaults to unmasking all leaves.
- Parameters:
  tree (TypeVar(T)) – A pytree of values.
  mask (Union[TypeVar(T), Callable[[Any], bool]]) – A pytree of boolean values or a callable that accepts a leaf and returns a boolean. If a leaf is True either in the mask or according to the callable, the leaf is unfrozen; otherwise it is unchanged. Defaults to unmasking all nodes.
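A comparable toy sketch for unmasking, again stdlib-only and using hypothetical names (Frozen, tree_unmask_sketch). It assumes a callable mask is applied to each node as encountered, wrapper included; the library's exact calling convention may differ. A boolean pytree mask shaped like the tree selects which wrapped leaves to unfreeze.

```python
from typing import Any, Callable, Union


class Frozen:
    """Hypothetical stand-in for the wrapper produced by freeze()."""

    def __init__(self, value: Any) -> None:
        self.value = value

    def __repr__(self) -> str:
        return f"#{self.value!r}"


def tree_unmask_sketch(tree: Any,
                       mask: Union[Any, Callable[[Any], bool]] = lambda _: True) -> Any:
    """Toy unmasking: mask is a callable or a bool tree shaped like `tree`."""
    if isinstance(tree, (list, tuple)):
        masks = [mask] * len(tree) if callable(mask) else mask
        return type(tree)(tree_unmask_sketch(t, m) for t, m in zip(tree, masks))
    if isinstance(tree, dict):
        return {k: tree_unmask_sketch(v, mask if callable(mask) else mask[k])
                for k, v in tree.items()}
    # Leaf (a Frozen wrapper counts as a leaf here): the callable sees it as-is.
    flag = mask(tree) if callable(mask) else mask
    if flag and isinstance(tree, Frozen):
        return tree.value
    return tree


masked = [Frozen(1), Frozen(2), {"a": Frozen(3), "b": 4.0}]
fully_unmasked = tree_unmask_sketch(masked)
partially_unmasked = tree_unmask_sketch(masked, [True, False, {"a": True, "b": False}])
```

With the default mask every wrapper is removed, giving [1, 2, {'a': 3, 'b': 4.0}]; the boolean pytree leaves the second leaf frozen, giving [1, #2, {'a': 3, 'b': 4.0}].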
Example
>>> import pytreeclass as tc
>>> import jax
>>> tree = [1, 2, {"a": 3, "b": 4.}]
>>> # mask all non-differentiable nodes by default
>>> masked_tree = tc.tree_mask(tree)
>>> masked_tree
[#1, #2, {'a': #3, 'b': 4.0}]
>>> jax.tree_util.tree_leaves(masked_tree)
[4.0]
>>> tc.tree_unmask(masked_tree)
[1, 2, {'a': 3, 'b': 4.0}]
Example
>>> # pass non-differentiable values to `jax.grad`
>>> import pytreeclass as tc
>>> import jax
>>> @jax.grad
... def square(tree):
...     tree = tc.tree_unmask(tree)
...     return tree[0]**2
>>> tree = (1., 2)  # contains a non-differentiable node
>>> square(tc.tree_mask(tree))
(Array(2., dtype=float32, weak_type=True), #2)
Note
Unmasking is equivalent to applying unfreeze() to the masked leaves:
>>> import pytreeclass as tc
>>> import jax
>>> tree = [1, 2, {"a": 3, "b": 4.}]
>>> masked_tree = tc.tree_mask(tree)
>>> # unmask all nodes
>>> tree = jax.tree_map(tc.unfreeze, masked_tree, is_leaf=tc.is_frozen)