# Copyright 2023 pytreeclass authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Constructor code generation from type annotations."""
# this module contains lots of functionality similar to `dataclasses` and attrs.
# however, notable differences are:
# - allow marking fields as positional only, keyword only, variable positional,...
# - allow applying functions on the field values during initialization using descriptors.
# - does not allow mutable defaults.
# - allow registering additional types to be excluded from `autoinit`. e.g. raise an error.
# - only `__init__` code generation is done here. other functionality like `__repr__`,
# `__eq__`, `__hash__`, etc. are not done here.
# one design choice is that `autoinit` and `Field` are not tightly coupled.
# `Field` can be used without `autoinit` as a descriptor to apply functions on
# the field values during initialization. Moreover, `TreeClass` is not coupled with
# `autoinit` or `Field` and can be used without them. this simplifies the code
# by separating the functionality.
from __future__ import annotations
import functools as ft
import sys
from collections import defaultdict
from collections.abc import Callable, MutableMapping, MutableSequence, MutableSet
from typing import Any, Literal, Mapping, Sequence, TypeVar, Union, get_args
from typing_extensions import dataclass_transform
# generic type variable used in the signatures below
T = TypeVar("T")
# loose alias: no structure is enforced on pytree values here
PyTree = Any
# the type of the `...` (Ellipsis) singleton
EllipsisType = type(Ellipsis)
# the closed set of argument-kind names accepted for a field
ArgKindType = Literal["POS_ONLY", "POS_OR_KW", "VAR_POS", "KW_ONLY", "VAR_KW"]
# the same names as a runtime tuple of strings, used for membership checks
ArgKind = get_args(ArgKindType)
@ft.singledispatch
def check_excluded_type(value: T) -> None:
    """Validate a field default value; the fallback accepts every value.

    Types that must not be used as defaults register a handler on this
    single-dispatch function that raises an error.
    """
    return None


# mutable containers are rejected as defaults
@check_excluded_type.register(MutableSet)
@check_excluded_type.register(MutableMapping)
@check_excluded_type.register(MutableSequence)
def _(value) -> None:
    """Reject mutable sequence, mapping and set defaults."""
    message = f"Mutable {value=} is not allowed."
    raise TypeError(message)
class Null:
    """Sentinel type marking a missing value (distinct from ``None``)."""

    # no per-instance attributes
    __slots__ = []

    def __repr__(self) -> str:
        return "NULL"

    def __bool__(self) -> bool:
        # the sentinel is always falsy
        return False


# module-wide sentinel instance, compared by identity (`is NULL`)
NULL = Null()
def slots(klass) -> tuple[str, ...]:
    """Return ``klass.__slots__``, or an empty tuple if it is not defined."""
    try:
        return klass.__slots__
    except AttributeError:
        return ()
class Field:
    """Field placeholder for `autoinit`.

    A ``Field`` stores the description of one attribute (name, type hint,
    default, argument kind, ...) and also works as a data descriptor: reads
    pass the stored value through ``on_getattr`` and writes pass the new value
    through ``on_setattr`` before storing it in the instance ``__dict__``.
    """

    __slots__ = [
        "name",
        "type",
        "default",
        "init",
        "repr",
        "kind",
        "metadata",
        "on_setattr",
        "on_getattr",
        "alias",
    ]

    def __init__(
        self,
        *,
        name: str | Null = NULL,
        type: type | Null = NULL,
        default: Any = NULL,
        init: bool = True,
        repr: bool = True,
        kind: ArgKindType = "POS_OR_KW",
        metadata: dict[str, Any] | None = None,
        on_setattr: Sequence[Callable[[Any], Any]] = (),
        on_getattr: Sequence[Callable[[Any], Any]] = (),
        alias: str | None = None,
    ):
        """Initialize the field attributes.

        Args:
            name: The field name.
            type: The field type.
            default: The default value of the field.
            init: Whether the field is included in the object's ``__init__`` function.
            repr: Whether the field is included in the object's ``__repr__`` function.
            kind: Argument kind, one of:

                - ``POS_ONLY``: positional only argument (e.g. ``x`` in ``def f(x, /):``)
                - ``VAR_POS``: variable positional argument (e.g. ``*x`` in ``def f(*x):``)
                - ``POS_OR_KW``: positional or keyword argument (e.g. ``x`` in ``def f(x):``)
                - ``KW_ONLY``: keyword only argument (e.g. ``x`` in ``def f(*, x):``)
                - ``VAR_KW``: variable keyword argument (e.g. ``**x`` in ``def f(**x):``)

            metadata: A mapping of user-defined data for the field.
            on_setattr: A sequence of functions applied, in order, to a value
                before it is stored on the instance.
            on_getattr: A sequence of functions applied, in order, to the
                stored value when the attribute is read.
            alias: An alias for the field name in the constructor. e.g ``name=x``,
                ``alias=y`` will allow ``obj = Class(y=1)`` to be equivalent to
                ``obj = Class(x=1)``.
        """
        self.name = name
        self.type = type
        self.default = default
        self.init = init
        self.repr = repr
        self.kind = kind
        self.metadata = metadata
        self.on_setattr = on_setattr
        self.on_getattr = on_getattr
        self.alias = alias

    def replace(self, **kwargs) -> Field:
        """Return a new field with the attributes in ``kwargs`` replaced."""
        # define a `replace` method similar to `dataclasses.replace` or namedtuple
        # to allow the user to replace the field attributes.
        return type(self)(**{k: kwargs.get(k, getattr(self, k)) for k in slots(Field)})

    def pipe(self, funcs: Sequence[Callable[[Any], Any]], value: Any):
        """Apply a sequence of unary functions on the field value, in order.

        Raises:
            Exception: An exception of the same type as the one raised by a
                function in ``funcs``, with a message naming the function and
                the field. The original exception is attached as the cause.
        """
        for func in funcs:
            try:
                value = func(value)
            except Exception as e:
                # emit a *descriptive* error message with the name of the attribute
                # associated with the field and the name of the function that raised
                # the error.
                cname = getattr(func, "__name__", func)
                msg = f"On applying {cname} for field=`{self.name}`:\n{e}"
                raise type(e)(msg) from e
        return value

    def __repr__(self) -> str:
        """Return the string representation of the field."""
        attrs = [f"{k}={getattr(self, k)!r}" for k in slots(Field)]
        return f"{type(self).__name__}({', '.join(attrs)})"

    def __set_name__(self, owner, name: str) -> None:
        """Record the attribute name and, if annotated, its type hint."""
        self.name = name
        # in case the user uses `field` as a descriptor without annotating the class
        if "__annotations__" in (variables := vars(owner)):
            # the hint lives in the class `__annotations__` mapping; looking
            # `name` up directly in the class namespace would return this
            # descriptor instead of the hint
            self.type = variables["__annotations__"].get(name, NULL)

    def __get__(self: T, instance, _) -> T | Any:
        """Return the field value after applying ``on_getattr``.

        Accessing the attribute on the class (``instance is None``) returns
        the ``Field`` object itself.
        """
        if instance is None:
            return self
        return self.pipe(self.on_getattr, vars(instance)[self.name])

    def __set__(self: T, instance, value) -> None:
        """Store the field value after applying ``on_setattr``."""
        vars(instance)[self.name] = self.pipe(self.on_setattr, value)

    def __delete__(self: T, instance) -> None:
        """Delete the field value from the instance."""
        del vars(instance)[self.name]
def field(
    *,
    default: Any = NULL,
    init: bool = True,
    repr: bool = True,
    kind: ArgKindType = "POS_OR_KW",
    metadata: dict[str, Any] | None = None,  # type: ignore
    on_setattr: Sequence[Any] = (),
    on_getattr: Sequence[Any] = (),
    alias: str | None = None,
) -> Field:
    """Field placeholder for type hinted attributes.

    Args:
        default: The default value of the field.
        init: Whether the field is included in the object's ``__init__`` function.
        repr: Whether the field is included in the object's ``__repr__`` function.
        kind: Argument kind, one of:

            - ``POS_ONLY``: positional only argument (e.g. ``x`` in ``def f(x, /):``)
            - ``VAR_POS``: variable positional argument (e.g. ``*x`` in ``def f(*x):``)
            - ``POS_OR_KW``: positional or keyword argument (e.g. ``x`` in ``def f(x):``)
            - ``KW_ONLY``: keyword only argument (e.g. ``x`` in ``def f(*, x):``)
            - ``VAR_KW``: variable keyword argument (e.g. ``**x`` in ``def f(**x):``)

        metadata: A mapping of user-defined data for the field.
        on_setattr: A sequence of functions to be called on ``__setattr__``.
        on_getattr: A sequence of functions to be called on ``__getattr__``.
        alias: An alias for the field name in the constructor. e.g ``name=x``,
            ``alias=y`` will allow ``obj = Class(y=1)`` to be equivalent to
            ``obj = Class(x=1)``.

    Returns:
        A ``Field`` carrying the validated arguments.

    Raises:
        TypeError: If ``alias`` is not a string or ``None``, ``metadata`` is
            not a dict or ``None``, ``on_setattr`` / ``on_getattr`` is not a
            sequence or contains a non-callable, or ``init`` is not a bool.
        ValueError: If ``kind`` is not one of the supported argument kinds.

    Example:
        Type and range validation using :attr:`on_setattr`:

        >>> import pytreeclass as tc
        >>> @tc.autoinit
        ... class IsInstance(tc.TreeClass):
        ...    klass: type
        ...    def __call__(self, x):
        ...        assert isinstance(x, self.klass)
        ...        return x
        <BLANKLINE>
        >>> @tc.autoinit
        ... class Range(tc.TreeClass):
        ...    start: int|float = float("-inf")
        ...    stop: int|float = float("inf")
        ...    def __call__(self, x):
        ...        assert self.start <= x <= self.stop
        ...        return x
        <BLANKLINE>
        >>> @tc.autoinit
        ... class Employee(tc.TreeClass):
        ...    # assert employee ``name`` is str
        ...    name: str = tc.field(on_setattr=[IsInstance(str)])
        ...    # use callback composition to assert employee ``age`` is int and positive
        ...    age: int = tc.field(on_setattr=[IsInstance(int), Range(1)])
        >>> employee = Employee(name="Asem", age=10)
        >>> print(employee)
        Employee(name=Asem, age=10)

    Example:
        Private attribute using :attr:`alias`:

        >>> import pytreeclass as tc
        >>> @tc.autoinit
        ... class Employee(tc.TreeClass):
        ...     # `alias` is the name used in the constructor
        ...     _name: str = tc.field(alias="name")
        >>> employee = Employee(name="Asem")  # use `name` in the constructor
        >>> print(employee)  # `_name` is the private attribute name
        Employee(_name=Asem)

    Example:
        Buffer creation using :attr:`on_getattr`:

        >>> import pytreeclass as tc
        >>> import jax
        >>> import jax.numpy as jnp
        >>> @tc.autoinit
        ... class Tree(tc.TreeClass):
        ...     buffer: jax.Array = tc.field(on_getattr=[jax.lax.stop_gradient])
        >>> tree = Tree(buffer=jnp.array((1.0, 2.0)))
        >>> def sum_buffer(tree):
        ...     return tree.buffer.sum()
        >>> print(jax.grad(sum_buffer)(tree))  # no gradient on `buffer`
        Tree(buffer=[0. 0.])

    Example:
        Parameterization using :attr:`on_getattr`:

        >>> import pytreeclass as tc
        >>> import jax
        >>> import jax.numpy as jnp
        >>> def symmetric(array: jax.Array) -> jax.Array:
        ...    triangle = jnp.triu(array)  # upper triangle
        ...    return triangle + triangle.transpose(-1, -2)
        >>> @tc.autoinit
        ... class Tree(tc.TreeClass):
        ...    symmetric_matrix: jax.Array = tc.field(on_getattr=[symmetric])
        >>> tree = Tree(symmetric_matrix=jnp.arange(9).reshape(3, 3))
        >>> print(tree.symmetric_matrix)
        [[ 0  1  2]
         [ 1  8  5]
         [ 2  5 16]]

    Note:
        - :func:`field` is commonly used to annotate the class attributes to be
          used by the :func:`autoinit` decorator to generate the ``__init__``
          method similar to ``dataclasses.dataclass``.

        - :func:`field` can be used without the :func:`autoinit` as a descriptor
          to apply functions on the field values during initialization using
          the ``on_setattr`` / ``on_getattr`` argument.

          >>> import pytreeclass as tc
          >>> def print_and_return(x):
          ...    print(f"Setting {x}")
          ...    return x
          >>> class Tree:
          ...    # `a` must be defined as a class attribute for the descriptor to work
          ...    a: int = tc.field(on_setattr=[print_and_return])
          ...    def __init__(self, a):
          ...        self.a = a
          >>> tree = Tree(1)
          Setting 1
    """
    # validate eagerly so a bad argument is reported at class definition time
    # rather than when the generated `__init__` or the descriptor runs
    if not isinstance(alias, (str, type(None))):
        raise TypeError(f"Non-string {alias=} argument provided to `field`")

    if not isinstance(metadata, (dict, type(None))):
        raise TypeError(f"Non-dict {metadata=} argument provided to `field`")

    if kind not in ArgKind:
        raise ValueError(f"{kind=} not in {ArgKind}")

    if not isinstance(on_setattr, Sequence):
        raise TypeError(f"Non-sequence {on_setattr=} argument provided to `field`")

    if not isinstance(on_getattr, Sequence):
        raise TypeError(f"Non-sequence {on_getattr=} argument provided to `field`")

    if not isinstance(init, bool):
        raise TypeError(f"Non-bool {init=} argument provided to `field`")

    for func in on_setattr:
        if not callable(func):
            raise TypeError(f"Non-callable {func=} provided to `field` on_setattr")

    for func in on_getattr:
        if not callable(func):
            raise TypeError(f"Non-callable {func=} provided to `field` on_getattr")

    return Field(
        default=default,
        init=init,
        repr=repr,
        kind=kind,
        metadata=metadata,  # type: ignore
        on_setattr=on_setattr,
        on_getattr=on_getattr,
        alias=alias,
    )
def build_field_map(klass: type) -> dict[str, Field]:
    """Collect the ``Field`` objects of ``klass`` and its bases, keyed by name.

    Bases are visited from the end of the MRO towards ``klass``, so a field
    defined closer to ``klass`` overrides one with the same name defined
    further away. Only annotated attributes whose class-level value is a
    ``Field`` are collected.

    Raises:
        ValueError: If an annotation of ``klass`` uses a reserved name.
        TypeError: If a field default is of an excluded type, as reported by
            ``check_excluded_type`` (for example a mutable container).
    """
    collected: dict[str, Field] = {}

    # `object` carries no fields and terminates the recursion
    if klass is object:
        return collected

    excluded = {"self", "__post_init__", "__annotations__"}

    # walk the bases last-to-first so nearer classes win on name clashes
    for base in klass.__mro__[:0:-1]:
        collected.update(build_field_map(base))

    own = vars(klass)
    hint_map = own.get("__annotations__", NULL)
    if hint_map is NULL:
        # the class defines no annotations of its own
        return collected

    if not excluded.isdisjoint(hint_map):
        raise ValueError(f"`Field` name cannot be in {excluded=}")

    for key, hint in hint_map.items():
        value = own.get(key, NULL)
        if isinstance(value, Field):
            # reject mutable defaults and any user-registered excluded type
            check_excluded_type(value.default)
            # case: `x: Any = field(default=1)`
            collected[key] = value.replace(name=key, type=hint)
        # anything else (plain annotations, hints of classes not processed
        # by `autoinit`) is ignored

    return collected
def fields(x: Any) -> tuple[Field, ...]:
    """Return a tuple of ``Field`` objects for the given instance or class.

    ``Field`` objects are collected from the annotated attributes of the class
    and its bases whose class-level value is a ``Field`` (for example
    attributes declared with ``pytreeclass.field``).

    Args:
        x: A class, or an instance whose class is inspected.

    Returns:
        The collected fields, in the order ``build_field_map`` produces them.

    Note:
        - If the class is not annotated, an empty tuple is returned.
    """
    klass = x if isinstance(x, type) else type(x)
    return tuple(build_field_map(klass).values())
def convert_hints_to_fields(klass: type[T]) -> type[T]:
    """Wrap every annotated, non-``Field`` attribute of ``klass`` in a ``Field``.

    Only the annotations defined on ``klass`` itself are processed. The
    existing class-level value (or ``NULL`` when there is none) becomes the
    field default. ``klass`` is modified in place and returned.
    """
    namespace = vars(klass)
    hint_map = namespace.get("__annotations__", NULL)

    if hint_map is NULL:
        # nothing annotated on this class
        return klass

    for key, hint in hint_map.items():
        current = namespace.get(key, NULL)
        if isinstance(current, Field):
            # already a field, keep it untouched
            continue
        # name and type are passed explicitly to the new field
        setattr(klass, key, Field(default=current, type=hint, name=key))

    return klass
def check_duplicate_var_kind(field_map: dict[str, Field]) -> None:
    """Ensure at most one ``VAR_POS`` and at most one ``VAR_KW`` field exist.

    Raises:
        TypeError: On the first field that repeats an already seen
            ``VAR_POS`` or ``VAR_KW`` kind.
    """
    variadic = ("VAR_POS", "VAR_KW")
    seen: set[Literal["VAR_POS", "VAR_KW"]] = set()

    for field in field_map.values():
        if field.kind not in variadic:
            # other kinds may appear any number of times
            continue
        if field.kind in seen:
            # e.g. more than one field(kind="VAR_POS") is not allowed
            raise TypeError(f"Duplicate {field.kind=} for {field.name=}")
        seen.add(field.kind)
def build_init_method(klass: type[T]) -> type[T]:
    """Generate an ``__init__`` method from the fields of ``klass`` and attach it.

    The source of the method is assembled as text and executed with the
    globals of the module that defines ``klass``. Default values are never
    embedded in the source: the generated code reads them from a mapping
    captured by a closure (``refmap[name].default``).

    Parameters are ordered as ``POS_ONLY``, ``POS_OR_KW``, ``VAR_POS``,
    ``KW_ONLY``, ``VAR_KW``. If ``klass`` itself defines ``__post_init__``,
    the generated method calls it last.

    Args:
        klass: The class to generate the ``__init__`` method for.

    Returns:
        The same ``klass`` with ``__init__`` set.

    Raises:
        TypeError: If more than one ``VAR_POS`` or ``VAR_KW`` field is defined.
    """
    field_map: dict[str, Field] = build_field_map(klass)
    check_duplicate_var_kind(field_map)

    hints: dict[str, Any] = {"return": None}  # annotations of the generated method
    body: list[str] = []
    head: list[str] = ["self"]
    heads: dict[str, list[str]] = defaultdict(list)

    for field in field_map.values():
        if field.init:
            # how to name the field in the constructor
            alias = field.alias or field.name
            # annotate under the *parameter* name so the hint matches the
            # generated signature when an alias is used
            hints[alias] = field.type
            body += [f"self.{field.name}={alias}"]

            if field.default is NULL:
                # e.g. def __init__(.., x)
                heads[field.kind] += [alias]
            else:
                # e.g def __init__(.., x=value) but
                # pass reference to the default value
                heads[field.kind] += [f"{alias}=refmap['{field.name}'].default"]
        elif field.default is not NULL:
            # `init=False` field *with* a default: assign the default directly.
            # `init=False` fields without a default are not assigned here and
            # are usually set in `__post_init__`.
            body += [f"self.{field.name}=refmap['{field.name}'].default"]

    has_post = (key := "__post_init__") in vars(klass)
    body += [f"self.{key}()"] if has_post else ["pass"]

    # organize the arguments order:
    # (POS_ONLY, POS_OR_KW, VAR_POS, KW_ONLY, VAR_KW)
    head += (heads["POS_ONLY"] + ["/"]) if heads["POS_ONLY"] else []
    head += heads["POS_OR_KW"]
    head += ["*" + "".join(heads["VAR_POS"])] if heads["VAR_POS"] else []
    # case for ...(*a, b) and ...(a, *, b)
    head += ["*"] if (heads["KW_ONLY"] and not heads["VAR_POS"]) else []
    head += heads["KW_ONLY"]
    head += ["**" + "".join(heads["VAR_KW"])] if heads["VAR_KW"] else []

    # lookup table captured by the generated closure; kept separate from
    # `field_map` so the mapping returned by `build_field_map` is not mutated.
    # `__annotations__` is a reserved field name, so the key cannot clash.
    refmap: dict[str, Any] = {**field_map, "__annotations__": hints}

    # generate the code for the method
    code = "def closure(refmap):\n"
    code += f"\tdef __init__({','.join(head)}):"
    code += f"\n\t\t{';'.join(body)}"
    code += f"\n\t__init__.__qualname__ = '{klass.__qualname__}.__init__'"
    code += "\n\t__init__.__annotations__ = refmap['__annotations__']"
    code += "\n\treturn __init__"

    # execute the code with the globals of the module defining the class
    exec(code, vars(sys.modules[klass.__module__]), namespace := dict())
    method = namespace["closure"](refmap)

    # add the method to the class
    setattr(klass, "__init__", method)
    return klass
@dataclass_transform(field_specifiers=(Field, field))
def autoinit(klass: type[T]) -> type[T]:
    """A class decorator that generates the ``__init__`` method from type hints.

    Similar to ``dataclasses.dataclass``, this decorator generates the ``__init__``
    method for the given class from the type hints or the :func:`field` objects
    set to the class attributes.

    Compared to ``dataclasses.dataclass``, ``autoinit`` with :func:`field` objects
    can be used to apply functions on the field values during initialization,
    and/or support multiple argument kinds.

    Args:
        klass: The class to decorate.

    Returns:
        The same class, with a generated ``__init__`` unless it already
        defines one.

    Example:
        >>> import pytreeclass as tc
        >>> @tc.autoinit
        ... class Tree:
        ...     x: int
        ...     y: int
        >>> tree = Tree(1, 2)
        >>> tree.x, tree.y
        (1, 2)

    Example:
        >>> # define fields with different argument kinds
        >>> import pytreeclass as tc
        >>> @tc.autoinit
        ... class Tree:
        ...     kw_only_field: int = tc.field(default=1, kind="KW_ONLY")
        ...     pos_only_field: int = tc.field(default=2, kind="POS_ONLY")

    Example:
        >>> # define a converter to apply ``abs`` on the field value
        >>> @tc.autoinit
        ... class Tree:
        ...     a:int = tc.field(on_setattr=[abs])
        >>> Tree(a=-1).a
        1

    .. warning::
        - The ``autoinit`` decorator is a no-op if the class already has a
          user-defined ``__init__`` method.

    Note:
        - In case of inheritance, the ``__init__`` method is generated from
          the type hints of the current class and any base classes that
          are decorated with ``autoinit``.

        >>> import pytreeclass as tc
        >>> import inspect
        >>> @tc.autoinit
        ... class Base:
        ...     x: int
        >>> @tc.autoinit
        ... class Derived(Base):
        ...     y: int
        >>> obj = Derived(x=1, y=2)
        >>> inspect.signature(obj.__init__)
        <Signature (x: int, y: int) -> None>

        - Base classes that are not decorated with ``autoinit`` are ignored during
          synthesis of the ``__init__`` method.

        >>> import pytreeclass as tc
        >>> import inspect
        >>> class Base:
        ...     x: int
        >>> @tc.autoinit
        ... class Derived(Base):
        ...     y: int
        >>> obj = Derived(y=2)
        >>> inspect.signature(obj.__init__)
        <Signature (y: int) -> None>

    Note:
        Use ``autoinit`` instead of ``dataclasses.dataclass`` if you want to
        use ``jax.Array`` as a field default value. As ``dataclasses.dataclass``
        will incorrectly raise an error starting from python 3.11 complaining
        that ``jax.Array`` is not immutable.

    Note:
        By default ``autoinit`` will raise an error if the user uses mutable defaults.
        To register an additional type to be excluded from ``autoinit``, use
        :func:`autoinit.register_excluded_type`, with an optional ``reason``
        for excluding the type.

        >>> import pytreeclass as tc
        >>> class T:
        ...     pass
        >>> tc.autoinit.register_excluded_type(T, reason="not allowed")
        >>> @tc.autoinit
        ... class Tree:
        ...     x: T = tc.field(default=T())  # doctest: +SKIP
        Traceback (most recent call last):
            ...
    """
    # if the class already has a user-defined __init__ method
    # then return the class as is without any modification
    if "__init__" in vars(klass):
        return klass

    # first convert the current class hints to fields
    # then build the __init__ method from the fields of the current class
    # and any base classes that are decorated with `autoinit`
    return build_init_method(convert_hints_to_fields(klass))
def register_excluded_type(klass: type, reason: str | None = None) -> None:
    """Exclude a type from being used in the ``autoinit`` decorator.

    A handler that raises ``TypeError`` is registered for ``klass`` on
    ``check_excluded_type``.

    Args:
        klass: The type to be excluded.
        reason: The reason for excluding the type; appended to the error
            message when given.
    """
    # the suffix is empty when no reason is supplied
    if reason is None:
        reason = ""
    else:
        reason = f" {reason=}"

    def _(value) -> None:
        raise TypeError(f"{value=} is excluded from `autoinit`.{reason}")

    check_excluded_type.register(klass, _)


# expose the registration helper as an attribute of the decorator
autoinit.register_excluded_type = register_excluded_type