📜 Stateful computations
!pip install pytreeclass --quiet
This notebook demonstrates how to handle internal state in immutable pytreeclass classes using a functional API.
Under jax.jit, JAX requires state to be explicit: for any class instance, the variables must be separated from the class and passed explicitly. With TreeClass, there is no need to separate the instance variables; instead, the whole instance is passed as the state.
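The idea of passing the whole instance as the state can be sketched in plain JAX. This sketch does not use pytreeclass: it registers a hypothetical Counter class as a pytree with jax.tree_util.register_pytree_node_class, which is, roughly, what subclassing TreeClass does for you automatically (an assumption made for illustration). The method returns a new instance instead of mutating self, so a jitted function can take the instance in and hand the updated instance back.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Counter:
    """Hypothetical immutable counter; the whole instance is a pytree."""

    def __init__(self, count):
        self.count = count

    def increment(self, value):
        # Functional update: build a new instance rather than mutating self.
        return Counter(self.count + value)

    # Tell JAX how to split the instance into leaves (children) and static data.
    def tree_flatten(self):
        return (self.count,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def step(counter, value):
    # The instance itself is the state: it goes in and the updated copy comes out.
    return counter.increment(value)


counter = Counter(jnp.array(0.0))
for _ in range(3):
    counter = step(counter, 1.0)

print(float(counter.count))  # 3.0
```

Because step returns a fresh Counter, the loop only rebinds the name counter on each iteration; no instance is mutated in place, which is what jax.jit expects.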
# wip