
📜 Stateful computations

!pip install pytreeclass --quiet

In this notebook, we demonstrate how to handle internal state with the immutable `TreeClass` from `pytreeclass` using a functional API.

First, under `jax.jit`, JAX requires state to be explicit: for any class instance, the variables must be separated from the class and passed explicitly as function inputs and outputs. With `TreeClass`, there is no need to separate the instance variables. Because a `TreeClass` instance is itself a pytree, the whole instance can be passed in (and returned) as the state.

# wip