
🧮 Building a mini optimizer library#

In the following example, an optax-like mini optimizer library is built using pytreeclass and then used to train a simple neural network.

[1]:
!pip install pytreeclass --quiet

Imports#

[2]:
import jax
import jax.numpy as jnp
import pytreeclass as tc
from typing import Any, TypeVar, Generic, Callable
import abc
import matplotlib.pyplot as plt

PyTree = Any
T = TypeVar("T")

Template#

[3]:
class GradientTransformation(tc.TreeClass, Generic[T]):
    # TreeClass inherits from `abc.ABC`
    # following the same pattern as `optax` optimizers
    # https://github.com/deepmind/optax/blob/fc5de3d3951c4dfd87513c6426e30baf505d89ae/optax/_src/base.py#L85
    @abc.abstractmethod
    def update(self, grads: T) -> tuple[T, "GradientTransformation"]:
        pass
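
The contract is functional: update takes a pytree of gradients and returns the transformed gradients together with a new transformation instance carrying the updated state, instead of mutating self in place. As a minimal sketch (hypothetical, not part of the library built below), a stateless transformation that passes gradients through unchanged looks like this:

class Identity(GradientTransformation):
    def __init__(self, tree):
        # no state is needed; the argument only mirrors the shared constructor signature
        del tree

    def _update(self, grads: T) -> T:
        return grads

    def update(self, grads: T) -> tuple[T, "Identity"]:
        # `at["_update"](grads)` runs `_update` functionally and returns
        # (method output, updated instance) rather than mutating `self`
        return self.at["_update"](grads)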

Optimizer functions#

[4]:
def moment_update(grads, moments, *, beta: float, order: int):
    # running moment of the gradients: m <- beta * m + (1 - beta) * g**order
    def moment_step(grad, moment):
        return beta * moment + (1 - beta) * (grad**order)

    return jax.tree_map(moment_step, grads, moments)


def debias_update(moments, *, beta: float, count: int):
    # bias correction: m_hat = m / (1 - beta**count)
    def debias_step(moment):
        return moment / (1 - beta**count)

    return jax.tree_map(debias_step, moments)


def ema(decay_rate: float, debias: bool = True) -> type[GradientTransformation]:
    """Exponential moving average

    Args:
        decay_rate: The decay rate of the moving average.
        debias: Whether to debias the moving average.
    """

    class EMA(GradientTransformation):
        def __init__(self, tree):
            self.state = jax.tree_map(jnp.zeros_like, tree)
            self.count = 0

        def _update(self, grads: T) -> T:
            self.count += 1
            self.state = moment_update(grads, self.state, beta=decay_rate, order=1)
            if debias:
                return debias_update(self.state, beta=decay_rate, count=self.count)
            return self.state

        def update(self, grads: T) -> tuple[T, "EMA"]:
            return self.at["_update"](grads)

    return EMA


def adam(
    *,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
) -> type[GradientTransformation]:
    """Adam optimizer

    Args:
        beta1: The decay rate of the first moment.
        beta2: The decay rate of the second moment.
        eps: A small value to prevent division by zero.

    Note:
        Kingma et al, 2014: https://arxiv.org/abs/1412.6980
    """

    class Adam(GradientTransformation):
        def __init__(self, tree):
            self.mu = jax.tree_map(jnp.zeros_like, tree)
            self.nu = jax.tree_map(jnp.zeros_like, tree)
            self.count = 0

        def _update(self, grads: T) -> T:
            self.count += 1
            self.mu = moment_update(grads, self.mu, beta=beta1, order=1)
            self.nu = moment_update(grads, self.nu, beta=beta2, order=2)
            mu_hat = debias_update(self.mu, beta=beta1, count=self.count)
            nu_hat = debias_update(self.nu, beta=beta2, count=self.count)
            return jax.tree_map(
                lambda mu, nu: mu / (jnp.sqrt(nu) + eps), mu_hat, nu_hat
            )

        def update(self, grads: T) -> tuple[T, "Adam"]:
            # since self._update mutates the state, we need to use self.at
            # to return the method value and the mutated state
            return self.at["_update"](grads)

    return Adam


def scale(rate_func: Callable[[int], float]) -> type[GradientTransformation]:
    """Scale the gradients by a scheduler function"""

    class Scale(GradientTransformation):
        def __init__(self, _=None):
            # the parameter pytree is accepted (and ignored) to match the
            # constructor signature of the other transformations
            self.count = 0
            self.rate = rate_func(self.count)

        def _update(self, grads: T) -> T:
            self.count += 1
            self.rate = rate_func(self.count)
            return jax.tree_map(lambda x: x * self.rate, grads)

        def update(self, grads: T) -> tuple[T, "Scale"]:
            return self.at["_update"](grads)

    return Scale


def chain(
    *transformations: type[GradientTransformation],
) -> type[GradientTransformation]:
    """Chain multiple transformations together similar to `optax.chain`"""

    class Chain(GradientTransformation):
        def __init__(self, tree):
            # instantiate each transformation class with the parameter pytree
            self.transformations = tuple(cls(tree) for cls in transformations)

        def _update(self, grads: T) -> T:
            state = []
            for transformation in self.transformations:
                grads, optim_state = transformation.update(grads)
                state += [optim_state]
            self.transformations = tuple(state)
            return grads

        def update(self, grads: T) -> tuple[T, "Chain"]:
            return self.at["_update"](grads)

    return Chain
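
As a quick sanity check of the pieces above (a hypothetical usage example, not from the original notebook), each factory returns a class that is instantiated with a pytree of parameters and then driven through update:

params = {"w": jnp.ones([3]), "b": jnp.zeros([3])}
grads = {"w": jnp.full([3], 0.5), "b": jnp.full([3], -0.5)}

# exponential moving average of the gradients
ema_state = ema(decay_rate=0.9)(params)
smoothed, ema_state = ema_state.update(grads)

# Adam-style rescaling of the gradients
adam_state = adam()(params)
scaled, adam_state = adam_state.update(grads)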

Construct a fully connected neural network#

[5]:
class FNN(tc.TreeClass):
    def __init__(self):
        self.w1 = jax.random.normal(jax.random.PRNGKey(0), [1, 10])
        self.b1 = jnp.zeros([10], dtype=jnp.float32)
        self.w2 = jax.random.normal(jax.random.PRNGKey(1), [10, 1])
        self.b2 = jnp.zeros([1], dtype=jnp.float32)

    def __call__(self, x: jax.Array) -> jax.Array:
        x = x @ self.w1 + self.b1
        x = jax.nn.relu(x)
        x = x @ self.w2 + self.b2
        return x
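
A quick forward-pass check (hypothetical, not in the original notebook) confirms that the network maps a batch of scalars to a batch of scalars:

net = FNN()
print(net(jnp.ones([4, 1])).shape)  # (4, 1)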

Train function#

[6]:
def scheduler(count: int) -> float:
    # piecewise-constant schedule; the rates are negative because the scaled
    # gradients are added to the parameters in `train_step` below
    return jnp.where(count < 2_000, -1e-2, jnp.where(count < 5_000, -1e-3, -1e-4))


fnn = FNN()
optim = chain(adam(), scale(scheduler))
optim_state = optim(fnn)


def loss_func(fnn: FNN, x: jax.Array, y: jax.Array) -> jax.Array:
    return jnp.mean((fnn(x) - y) ** 2)


@jax.jit
def train_step(fnn, optim_state, x, y):
    grads = jax.grad(loss_func)(fnn, x, y)

    grads, optim_state = optim_state.update(grads)
    fnn = jax.tree_map(lambda p, g: p + g, fnn, grads)
    return fnn, optim_state


x = jnp.linspace(-1, 1, 100).reshape(-1, 1)
y = x**2 + 0.1

for i in range(1, 10_000 + 1):
    fnn, optim_state = train_step(fnn, optim_state, x, y)
    if i % 1_000 == 0:
        loss = loss_func(fnn, x, y)
        learning_rate = optim_state.transformations[1].rate
        print(
            f"Epoch={i:003d}\tLoss: {loss:.3e}\t Gradient scaler: {learning_rate:.3e}"
        )

plt.plot(x, y, label="true")
plt.plot(x, fnn(x), label="pred")
plt.legend()
Epoch=1000      Loss: 1.322e-04  Gradient scaler: -1.000e-02
Epoch=2000      Loss: 8.230e-05  Gradient scaler: -1.000e-03
Epoch=3000      Loss: 8.062e-05  Gradient scaler: -1.000e-03
Epoch=4000      Loss: 7.908e-05  Gradient scaler: -1.000e-03
Epoch=5000      Loss: 7.730e-05  Gradient scaler: -1.000e-04
Epoch=6000      Loss: 7.702e-05  Gradient scaler: -1.000e-04
Epoch=7000      Loss: 7.657e-05  Gradient scaler: -1.000e-04
Epoch=8000      Loss: 7.584e-05  Gradient scaler: -1.000e-04
Epoch=9000      Loss: 7.464e-05  Gradient scaler: -1.000e-04
Epoch=10000     Loss: 7.271e-05  Gradient scaler: -1.000e-04
[6]:
<matplotlib.legend.Legend at 0x1313c18d0>
[Figure: "true" target curve and the trained network's "pred" curve plotted over x]