💼 Masking API
- pytreeclass.is_nondiff(value)
Returns True if value has a non-inexact (non-floating, non-complex) type, i.e. it is non-differentiable, and False otherwise.
- Parameters:
value (Any) – A value to check.
- Return type:
bool
Note
is_nondiff() uses single dispatch to support custom types. To define custom behavior for a certain type, use is_nondiff.def_type(type, func).
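Single dispatch means the implementation that runs is selected by the type of the argument. As a rough illustration of the pattern only (not pytreeclass's code), the sketch below uses the standard library's functools.singledispatch; is_nondiff_sketch and the Scaled class are hypothetical names. Registering a handler for a new type plays a role similar to is_nondiff.def_type(type, func), although the exact behaviour of def_type may differ.

```python
from functools import singledispatch


@singledispatch
def is_nondiff_sketch(value) -> bool:
    # Hypothetical fallback for unregistered types: treat them as non-differentiable.
    return True


@is_nondiff_sketch.register(float)
@is_nondiff_sketch.register(complex)
def _(value) -> bool:
    # Inexact Python scalars are differentiable.
    return False


class Scaled:
    """Hypothetical user-defined type wrapping a float payload."""

    def __init__(self, value: float):
        self.value = value


@is_nondiff_sketch.register(Scaled)
def _(value: Scaled) -> bool:
    # Custom rule: a Scaled value is differentiable if its payload is.
    return is_nondiff_sketch(value.value)


print(is_nondiff_sketch(1))            # True  (int is non-diff)
print(is_nondiff_sketch(1.0))          # False (float is diff)
print(is_nondiff_sketch(Scaled(2.0)))  # False (dispatches to the Scaled handler)
```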
Example
>>> import pytreeclass as tc
>>> import jax.numpy as jnp
>>> tc.is_nondiff(jnp.array(1))  # int array is non-diff type
True
>>> tc.is_nondiff(jnp.array(1.))  # float array is diff type
False
>>> tc.is_nondiff(1)  # int is non-diff type
True
>>> tc.is_nondiff(1.)  # float is diff type
False
Note
This function is meant to be used with jax.tree_map to create a mask for the non-differentiable nodes of a tree. The mask can then be used to freeze those nodes before passing the tree to a jax transformation.
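A minimal sketch of building such a mask, assuming only JAX. is_nondiff_sketch is a hypothetical stand-in for tc.is_nondiff written for illustration; the sketch calls jax.tree_util.tree_map, of which jax.tree_map was an alias in older JAX releases.

```python
import jax
import jax.numpy as jnp


def is_nondiff_sketch(x) -> bool:
    # Hypothetical stand-in for tc.is_nondiff, for illustration only:
    # inexact (float/complex) scalars and arrays are differentiable,
    # everything else (ints, bools, strings, ...) is not.
    if isinstance(x, (float, complex)):
        return False
    if hasattr(x, "dtype"):
        return not jnp.issubdtype(x.dtype, jnp.inexact)
    return True


tree = {"w": jnp.array([1.0, 2.0]), "n": 3, "name": "layer", "b": 0.5}

# One boolean per leaf, with the same structure as `tree`:
# "w" and "b" map to False (differentiable), "n" and "name" to True.
mask = jax.tree_util.tree_map(is_nondiff_sketch, tree)
print(mask)
```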
- pytreeclass.freeze(value)
Freeze a value so that it is not updated by function transformations.
- Parameters:
value (T) – A value to freeze.
- Return type:
_FrozenBase[T]
Note
freeze() is idempotent, i.e. freeze(freeze(x)) == freeze(x).
Example
>>> import jax
>>> import pytreeclass as tc
>>> import jax.tree_util as jtu
>>> # Usage with `jax.tree_util.tree_leaves`
>>> # no leaves for a wrapped value
>>> jtu.tree_leaves(tc.freeze(2.))
[]
>>> # retrieve the frozen wrapper value using `is_leaf=tc.is_frozen`
>>> jtu.tree_leaves(tc.freeze(2.), is_leaf=tc.is_frozen)
[#2.0]
>>> # Usage with `jax.tree_util.tree_map`
>>> a = [1, 2, 3]
>>> a[1] = tc.freeze(a[1])
>>> jtu.tree_map(lambda x: x + 100, a)
[101, #2, 103]
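To see how a wrapper can contribute no leaves, here is a toy wrapper registered as a JAX pytree whose tree_flatten returns no children and keeps the wrapped value as static auxiliary data. Frozen and freeze_sketch are hypothetical names, and this is an illustration of the behaviour shown above under that assumption, not necessarily how pytreeclass's _FrozenBase is implemented.

```python
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class Frozen:
    """Hypothetical toy wrapper (not pytreeclass's _FrozenBase)."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        # No children: the wrapped value travels as static aux data,
        # so flattening a Frozen yields no leaves.
        return (), self.value

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)

    def __repr__(self):
        return f"#{self.value!r}"


def freeze_sketch(value):
    # Idempotent: an already frozen value is returned unchanged.
    return value if isinstance(value, Frozen) else Frozen(value)


print(jtu.tree_leaves(freeze_sketch(2.0)))  # []
print(jtu.tree_leaves(freeze_sketch(2.0), is_leaf=lambda x: isinstance(x, Frozen)))  # [#2.0]

a = [1, freeze_sketch(2), 3]
print(jtu.tree_map(lambda x: x + 100, a))  # [101, #2, 103]
```

Because the value lives in the aux data rather than among the children, tree_map rebuilds the wrapper without ever passing its value to the mapped function.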
- pytreeclass.unfreeze(value)
Unfreeze a value wrapped by freeze(); otherwise, return the value itself.
- Parameters:
value (T) – A value to unfreeze.
- Return type:
T
Note
Use is_leaf=tc.is_frozen with tree_map to unfreeze a tree. A frozen wrapper yields no leaves, so without is_leaf, tree_map never passes it to unfreeze().
Example
>>> import pytreeclass as tc
>>> import jax
>>> frozen_value = tc.freeze(1)
>>> tc.unfreeze(frozen_value)
1
>>> # usage with `jax.tree_map`
>>> frozen_tree = jax.tree_map(tc.freeze, {"a": 1, "b": 2})
>>> unfrozen_tree = jax.tree_map(tc.unfreeze, frozen_tree, is_leaf=tc.is_frozen)
>>> unfrozen_tree
{'a': 1, 'b': 2}
- pytreeclass.is_frozen(value)
Returns True if the value is a frozen wrapper (for example, the result of freeze()), False otherwise.
- Return type:
bool
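is_frozen is mainly useful as an is_leaf predicate. The sketch below assumes only JAX and uses a hypothetical toy Frozen wrapper (defined with no children, so it yields no leaves) to show why: without is_leaf, tree_map never reaches the wrapper. is_frozen_sketch and unfreeze_sketch are hypothetical stand-ins for tc.is_frozen and tc.unfreeze.

```python
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class Frozen:
    # Hypothetical toy wrapper: no children, value kept as aux data.
    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (), self.value

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)


def is_frozen_sketch(x) -> bool:
    return isinstance(x, Frozen)


def unfreeze_sketch(x):
    return x.value if is_frozen_sketch(x) else x


tree = {"a": Frozen(1), "b": 2.0}

# Without is_leaf, tree_map only visits the real leaf 2.0; the wrapper
# contributes no leaves, so unfreeze_sketch is never called on it.
still_frozen = jtu.tree_map(unfreeze_sketch, tree)

# With is_leaf, the wrapper itself is treated as a leaf and gets unwrapped.
unfrozen = jtu.tree_map(unfreeze_sketch, tree, is_leaf=is_frozen_sketch)
print(unfrozen)  # {'a': 1, 'b': 2.0}
```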
- pytreeclass.tree_mask(tree, mask=<function is_nondiff>, *, is_leaf=None)
Mask leaves of a pytree based on mask, which is either a boolean pytree or a callable.
- Parameters:
tree (T) – A pytree of values.
mask (MaskType) – A pytree of boolean values, or a callable that accepts a leaf and returns a boolean. If a leaf is True, either in the mask or as returned by the callable, the leaf is wrapped with a wrapper that yields no leaves when tree_flatten is called on it; otherwise it is unchanged. Defaults to is_nondiff(), which returns True for non-differentiable nodes.
is_leaf (Callable[[Any], bool] | None) – A callable that accepts a value and returns a boolean. If provided, it is used to determine whether a value is a leaf. For example, is_leaf=lambda x: isinstance(x, list) will treat lists as leaves and will not recurse into them.
Note
Masked leaves are wrapped with a wrapper that yields no leaves when tree_flatten is called on it. Masking is equivalent to applying freeze() to the masked leaves.
>>> import pytreeclass as tc
>>> import jax
>>> tree = [1, 2, {"a": 3, "b": 4.}]
>>> # equivalent to tc.tree_mask(tree): freeze the non-differentiable nodes
>>> def mask_if_nondiff(x):
...     return tc.freeze(x) if tc.is_nondiff(x) else x
>>> masked_tree = jax.tree_map(mask_if_nondiff, tree)
Use masking on a tree containing non-differentiable nodes before passing the tree to a jax transformation; for example, jax.grad raises an error when given integer-valued inputs.
Example
>>> import pytreeclass as tc
>>> import jax
>>> tree = [1, 2, {"a": 3, "b": 4.}]
>>> # mask all non-differentiable nodes by default
>>> masked_tree = tc.tree_mask(tree)
>>> masked_tree
[#1, #2, {'a': #3, 'b': 4.0}]
>>> jax.tree_util.tree_leaves(masked_tree)
[4.0]
>>> tc.tree_unmask(masked_tree)
[1, 2, {'a': 3, 'b': 4.0}]
Example
>>> # pass non-differentiable values to `jax.grad`
>>> import pytreeclass as tc
>>> import jax
>>> @jax.grad
... def square(tree):
...     tree = tc.tree_unmask(tree)
...     return tree[0]**2
>>> tree = (1., 2)  # contains a non-differentiable node
>>> square(tc.tree_mask(tree))
(Array(2., dtype=float32, weak_type=True), #2)
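The two forms of mask can be sketched as follows, assuming only JAX. Frozen, tree_mask_sketch and tree_unmask_sketch are hypothetical toy versions written for illustration, not pytreeclass's implementation: a callable mask is first mapped over the leaves to produce a boolean pytree, and that boolean pytree then decides which leaves get wrapped. The jax.grad part mirrors the example above; the integer leaf is hidden from jax.grad because the wrapper yields no leaves.

```python
import jax
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class Frozen:
    # Hypothetical toy wrapper: yields no leaves when flattened.
    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (), self.value

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)


def tree_mask_sketch(tree, mask):
    # `mask` is either a callable applied to every leaf or a boolean
    # pytree with the same structure as `tree`.
    if callable(mask):
        mask = jtu.tree_map(mask, tree)
    return jtu.tree_map(lambda m, x: Frozen(x) if m else x, mask, tree)


def is_frozen_sketch(x) -> bool:
    return isinstance(x, Frozen)


def tree_unmask_sketch(tree):
    # Treat wrappers as leaves so tree_map can reach and unwrap them.
    return jtu.tree_map(
        lambda x: x.value if is_frozen_sketch(x) else x, tree, is_leaf=is_frozen_sketch
    )


@jax.grad
def square(tree):
    tree = tree_unmask_sketch(tree)
    return tree[0] ** 2


tree = (1.0, 2)  # the int leaf is non-differentiable
by_callable = tree_mask_sketch(tree, lambda x: isinstance(x, int))
by_pytree = tree_mask_sketch(tree, (False, True))

grads = square(by_callable)
print(grads[0])  # 2.0
```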
- pytreeclass.tree_unmask(tree, mask=<function <lambda>>)
Undo the masking of tree leaves according to mask. Defaults to unmasking all leaves.
- Parameters:
tree (T) – A pytree of values.
mask (Union[T, Callable[[Any], bool]]) – A pytree of boolean values, or a callable that accepts a leaf and returns a boolean. If a leaf is True, either in the mask or as returned by the callable, the leaf is unfrozen; otherwise it is unchanged. Defaults to unmasking all nodes.
Example
>>> import pytreeclass as tc
>>> import jax
>>> tree = [1, 2, {"a": 3, "b": 4.}]
>>> # mask all non-differentiable nodes by default
>>> masked_tree = tc.tree_mask(tree)
>>> masked_tree
[#1, #2, {'a': #3, 'b': 4.0}]
>>> jax.tree_util.tree_leaves(masked_tree)
[4.0]
>>> tc.tree_unmask(masked_tree)
[1, 2, {'a': 3, 'b': 4.0}]
Example
>>> # pass non-differentiable values to `jax.grad`
>>> import pytreeclass as tc
>>> import jax
>>> @jax.grad
... def square(tree):
...     tree = tc.tree_unmask(tree)
...     return tree[0]**2
>>> tree = (1., 2)  # contains a non-differentiable node
>>> square(tc.tree_mask(tree))
(Array(2., dtype=float32, weak_type=True), #2)
Note
Unmasking is equivalent to applying unfreeze() to the masked leaves.
>>> import pytreeclass as tc
>>> import jax
>>> tree = [1, 2, {"a": 3, "b": 4.}]
>>> masked_tree = tc.tree_mask(tree)
>>> # unmask all nodes; equivalent to tc.tree_unmask(masked_tree)
>>> tree = jax.tree_map(tc.unfreeze, masked_tree, is_leaf=tc.is_frozen)