
🥶 Dealing with non-jax types

In essence: how to pass non-inexact types (e.g. int, str, Callables, …) through JAX transformations such as jax.grad.

jax and inexact data types

By default, JAX transformations such as jax.grad only accept pytrees of inexact data types (float, complex, and arrays of float/complex). Any other input type leads to a TypeError, as the following example shows.

[1]:
!pip install pytreeclass --quiet
[2]:
import jax


@jax.grad
def identity_grad(x):
    return sum(x)


# valid input
identity_grad([1.0, 1.0])

# invalid input (not inexact)
try:
    identity_grad([1])
except TypeError as e:
    print(e)
grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got int32. If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True.
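The error message suggests allow_int=True as an alternative. The short sketch below uses only documented jax.grad behaviour to show why that does not solve the general problem. Integer leaves are accepted, but they receive a dummy gradient of dtype float0. Non-JAX types such as str are still rejected.

```python
import jax


def f(x):
    # only the float entry x[1] contributes to the output
    return x[1] ** 2


# allow_int=True accepts the int leaf; its "gradient" has the trivial dtype float0
grads = jax.grad(f, allow_int=True)([1, 1.0])
print(grads[0].dtype, grads[1])

# a str is not a valid JAX type at all, so allow_int does not help here
try:
    jax.grad(f, allow_int=True)(["a", 1.0])
    str_rejected = False
except TypeError as e:
    str_rejected = True
    print("TypeError:", e)
```

Masking avoids both issues by hiding such leaves from jax.grad entirely.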

Using tree_mask

However, when your function needs to receive a non-inexact data type, we can mask the non-inexact leaves with a frozen wrapper using pytreeclass.tree_mask. Masked leaves are placed inside a wrapper that yields no leaves when interacting with JAX transformations, so they are effectively frozen.
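The following sketch shows why a wrapper that yields no leaves passes through jax.grad. It uses only jax.tree_util. The Frozen class is a hypothetical stand-in written for illustration and is not pytreeclass's actual implementation. It registers itself as a pytree node with zero children and carries the wrapped value as auxiliary data.

```python
import jax
import jax.tree_util as jtu


class Frozen:
    """Hypothetical minimal frozen wrapper (illustration only, not pytreeclass's)."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return f"#{self.value!r}"


# Flatten to *no* children and keep the wrapped value as auxiliary data,
# so tree utilities (and therefore JAX transformations) see zero leaves.
jtu.register_pytree_node(
    Frozen,
    lambda frozen: ((), frozen.value),       # flatten: (children, aux_data)
    lambda value, _children: Frozen(value),  # unflatten: rebuild from aux_data
)

tree = [Frozen(1), 1.0]
print(jtu.tree_leaves(tree))  # only the float is a leaf


@jax.grad
def square_second(x):
    return x[1] ** 2


grads = square_second(tree)
print(grads)  # the Frozen entry is passed through untouched
```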

The following example shows how to use tree_mask to handle a non-inexact data type:

[3]:
import pytreeclass as tc


@jax.grad
def identity_grad(x):
    return 1.0


try:
    # this will fail
    identity_grad([1, 1.0])
except TypeError as e:
    print(e)
# this will work because the tree_mask will
# wrap the non-inexact type (int)
identity_grad([tc.tree_mask(1), 1.0])
grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got int32. If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True.
[3]:
[#1, Array(0., dtype=float32, weak_type=True)]

Notice that with tc.tree_mask we were able to pass a non-inexact type through a JAX transformation without JAX complaining. However, if we want to use this value inside the function, we need to unmask it first. If the function does not use the value, there is no need to unfreeze it. The following example shows the latter case.

[4]:
import pytreeclass as tc


@jax.grad
def identity_grad(x):
    # this function does not use the frozen value
    return x[1] ** 2


print(identity_grad([tc.tree_mask(1), 1.0]))
[#1, Array(2., dtype=float32, weak_type=True)]

If, however, the function needs to use the non-inexact value, we freeze it before passing it to the function and unfreeze it inside the function. The next example illustrates this concept.

[5]:
import pytreeclass as tc


@jax.grad
def func(x):
    # this function uses the non-inexact and inexact values
    # the non-inexact value is frozen so we need to unfreeze it
    x = tc.tree_unmask(x)
    return x[0] ** 2 + x[1] ** 2


print(func([tc.tree_mask(1), 1.0]))
[#1, Array(2., dtype=float32, weak_type=True)]

The result of the previous cell reveals something interesting. We know that \(\frac{d}{dx} x^2 = 2x\). However, the derivative was only evaluated for the inexact (float) value, which returned Array(2.). The int value, which was frozen on input, is returned unchanged. This is working as intended. We can use this mechanism to pass otherwise invalid types through JAX transformations without raising an error. We can also use it to prevent values from being updated or differentiated with respect to, even when they are inexact. The following example shows this:

[6]:
import pytreeclass as tc
import jax


@jax.grad
def func(x):
    x = tc.tree_unmask(x)
    return x**2


# using `tree_mask` with a mask that always returns `True`
# to select all leaves
print(func(tc.tree_mask(1.0, mask=lambda _: True)))

# or use `tc.freeze` to apply the frozen wrapper directly
print(func(tc.freeze(1.0)))
#1.0
#1.0
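The same pattern covers the Callables mentioned at the start: freeze the function on the way in and unwrap it inside. The sketch below is self-contained and again uses a hypothetical Frozen wrapper rather than pytreeclass's API. The unfreeze_all helper is also hypothetical. Because a Frozen node has no leaves, the helper passes is_leaf so that tree_map visits the wrappers themselves.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


class Frozen:
    """Hypothetical minimal frozen wrapper (illustration only, not pytreeclass's)."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return f"#{self.value!r}"


# no children; the wrapped value is stored as auxiliary data
jtu.register_pytree_node(Frozen, lambda f: ((), f.value), lambda v, _: Frozen(v))


def unfreeze_all(tree):
    # Frozen nodes contribute no leaves, so `is_leaf` is needed for
    # tree_map to visit the wrappers themselves and unwrap them
    return jtu.tree_map(
        lambda x: x.value if isinstance(x, Frozen) else x,
        tree,
        is_leaf=lambda x: isinstance(x, Frozen),
    )


@jax.grad
def loss(params):
    params = unfreeze_all(params)  # unwrap the frozen callable before using it
    return params["activation"](params["x"])


grads = loss({"activation": Frozen(jnp.tanh), "x": 0.5})
print(grads)  # the callable comes back still frozen; x gets 1 - tanh(0.5)**2
```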

Using tree_mask with mask recipes

The following examples show how to use tree_mask effectively with TreeClass instances to freeze selected values.

[7]:
from __future__ import annotations
import jax
import jax.tree_util as jtu
import jax.numpy as jnp
import pytreeclass as tc


@tc.autoinit
class Tree(tc.TreeClass):
    a: int = 1
    b: float = 2.0
    c: jax.Array = jnp.array([3.0, 4.0, 5.0])


tree = Tree()
tree
[7]:
Tree(a=1, b=2.0, c=f32[3](μ=4.00, σ=0.82, ∈[3.00,5.00]))

Freeze leaves by specifying a mask

[8]:
# let's freeze all int values
mask = jtu.tree_map(lambda x: isinstance(x, int), tree)
frozen_tree = tc.tree_mask(tree, mask)
print(frozen_tree)
# Tree(a=#1, b=2.0, c=[3. 4. 5.])

# frozen values are excluded from `tree_leaves`
print(jtu.tree_leaves(frozen_tree))
# [2.0, Array([3., 4., 5.], dtype=float32)]

# `a` does not get updated by `tree_map`
print(jtu.tree_map(lambda x: x + 100, frozen_tree))
# Tree(a=#1, b=102.0, c=[103. 104. 105.])

# unfreeze the frozen leaves with `tree_unmask`
unfrozen_tree = tc.tree_unmask(frozen_tree)
print(unfrozen_tree)
# Tree(a=1, b=2.0, c=[3. 4. 5.])
Tree(a=#1, b=2.0, c=[3. 4. 5.])
[2.0, Array([3., 4., 5.], dtype=float32)]
Tree(a=#1, b=102.0, c=[103. 104. 105.])
Tree(a=1, b=2.0, c=[3. 4. 5.])
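Conceptually, a boolean mask pytree like the one above pairs each leaf with a flag, and every flagged leaf gets wrapped. The sketch below mimics this with jax.tree_util.tree_map over two trees. It then runs a gradient-descent-style update to show that frozen leaves are left alone. Frozen, mask_by and the 0.1 step size are hypothetical choices made for illustration. This is not pytreeclass's implementation.

```python
import jax
import jax.tree_util as jtu


class Frozen:
    """Hypothetical minimal frozen wrapper (illustration only, not pytreeclass's)."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return f"#{self.value!r}"


jtu.register_pytree_node(Frozen, lambda f: ((), f.value), lambda v, _: Frozen(v))


def mask_by(tree, mask):
    # `mask` has the same structure as `tree`; wrap every leaf whose flag is True
    return jtu.tree_map(lambda x, m: Frozen(x) if m else x, tree, mask)


params = {"w": 3.0, "b": 2.0, "n": 1}
mask = jtu.tree_map(lambda x: isinstance(x, int), params)  # flags only the int leaf `n`
mask["w"] = True  # additionally freeze the float `w` so it is not trained
frozen = mask_by(params, mask)


@jax.grad
def loss(p):
    w = p["w"].value  # unwrap the frozen weight to use it
    return w * p["b"]


grads = loss(frozen)
# tree_map over two same-structure trees: Frozen nodes contribute no leaves,
# so `w` and `n` pass through the update unchanged
updated = jtu.tree_map(lambda p, g: p - 0.1 * g, frozen, grads)
print(updated)
```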

Freeze leaves by specifying the leaf name

Since tree_mask applies freeze through tree_map, we can apply freeze directly when freezing a single leaf. We select that leaf by name with the .at indexer.

[9]:
frozen_tree = tree.at["a"].apply(tc.freeze)
print(frozen_tree)  # `a` has a prefix `#`
# Tree(a=#1, b=2.0, c=[3. 4. 5.])

# frozen value are excluded from `tree_leaves`
print(jtu.tree_leaves(frozen_tree))
# [2.0, Array([3., 4., 5.], dtype=float32)]

# `a` does not get updated by `tree_map`
print(jtu.tree_map(lambda x: x + 100, frozen_tree))
# Tree(a=#1, b=102.0, c=[103. 104. 105.])

# unfreeze `a`
unfrozen_tree = tc.tree_unmask(frozen_tree)
print(unfrozen_tree)
# Tree(a=1, b=2.0, c=[3. 4. 5.])
Tree(a=#1, b=2.0, c=[3. 4. 5.])
[2.0, Array([3., 4., 5.], dtype=float32)]
Tree(a=#1, b=102.0, c=[103. 104. 105.])
Tree(a=1, b=2.0, c=[3. 4. 5.])
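For a plain dict pytree, freezing a leaf by name amounts to wrapping a single entry. The minimal sketch below uses the same hypothetical Frozen wrapper as a stand-in for tc.freeze, plus a hypothetical freeze_key helper. Neither is part of pytreeclass.

```python
import jax.tree_util as jtu


class Frozen:
    """Hypothetical minimal frozen wrapper (illustration only, not pytreeclass's)."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return f"#{self.value!r}"


jtu.register_pytree_node(Frozen, lambda f: ((), f.value), lambda v, _: Frozen(v))


def freeze_key(tree, key):
    # hypothetical helper: shallow copy of a dict with only tree[key] wrapped
    return {**tree, key: Frozen(tree[key])}


tree = {"a": 1, "b": 2.0, "c": 3.0}
frozen_tree = freeze_key(tree, "a")
print(jtu.tree_leaves(frozen_tree))  # "a" no longer contributes a leaf

shifted = jtu.tree_map(lambda x: x + 100, frozen_tree)
print(shifted)  # "a" stays frozen and is not shifted
```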