# Source code for pytreeclass._src.tree_pprint

"""Utilities for pretty printing pytrees."""

from __future__ import annotations

import dataclasses as dc
import functools as ft
import importlib
import importlib.util
import inspect
import math
from contextlib import suppress
from itertools import zip_longest
from types import FunctionType
from typing import Any, Callable, Literal, NamedTuple, Sequence

from typing_extensions import TypeAlias, TypedDict, Unpack

from pytreeclass._src.backend import arraylib, treelib
from pytreeclass._src.tree_util import (
    Node,
    construct_tree,
    is_path_leaf_depth_factory,
    tree_leaves_with_typed_path,
)

class PPSpec(TypedDict):
    """Keyword options threaded through the pretty-printers in this module."""

    # number of tab characters used to indent continuation lines
    indent: int
    # "REPR" renders leaves with ``repr``; "STR" renders them with ``str``
    kind: Literal["REPR", "STR"]
    # output whose length (ignoring newlines/tabs) fits in this many characters
    # is collapsed onto a single line by ``format_width``
    width: int
    # remaining nesting levels to expand; ``float("inf")`` expands everything
    depth: int | float

# Alias used in signatures to mark arguments that may be any pytree.
PyTree = Any
# Signature of a pretty-printer: a node plus ``PPSpec`` keyword options -> text.
PP = Callable[[Any, Unpack[PPSpec]], str]

class ShapeDtypePP(NamedTuple):
    """A (shape, dtype) pair, pretty printed in the short ``dtype[shape]`` form."""

    # array shape, e.g. ``(2, 3)``
    shape: tuple[int, ...]
    # array dtype object (any object whose ``str`` names the dtype)
    dtype: Any

@ft.singledispatch
def pp_dispatcher(node: Any, **spec: Unpack[PPSpec]) -> str:
    """Pretty print ``node`` by dispatching on its type.

    This is a ``functools.singledispatch`` function: register a new, or
    override an existing, pretty printer for a type with
    ``pp_dispatcher.register(type, func)``. Types without a registered
    printer fall back to :func:`general_pp`.
    """
    return general_pp(node, **spec)


def dataclass_pp(node: Any, **spec: Unpack[PPSpec]) -> str:
    """Pretty print a dataclass instance as ``Name(field=value, ...)``.

    Only fields declared with ``repr=True`` are shown.
    """
    name = type(node).__name__
    # ``getattr`` (rather than ``vars(node)[...]``) also works for dataclasses
    # declared with ``slots=True``, which have no instance ``__dict__``.
    kvs = ((f.name, getattr(node, f.name)) for f in dc.fields(node) if f.repr)
    return name + "(" + pps(kvs, pp=attr_value_pp, **spec) + ")"


def general_pp(node: Any, **spec: Unpack[PPSpec]) -> str:
    """Fallback printer for types without a registered printer.

    Dataclass instances are printed field by field; everything else uses
    ``repr`` or ``str`` (per ``spec["kind"]``), with continuation lines of a
    multi-line result indented by ``spec["indent"]`` tabs.
    """
    # ducktyping and other fallbacks that are not covered by singledispatch.
    # ``dc.is_dataclass`` is also true for dataclass *classes*; those are
    # printed with their plain repr/str, not expanded field by field.
    if dc.is_dataclass(node) and not isinstance(node, type):
        return dataclass_pp(node, **spec)

    text = f"{node!r}" if spec["kind"] == "REPR" else f"{node!s}"

    if "\n" not in text:
        return text

    return ("\n" + "\t" * (spec["indent"])).join(text.split("\n"))


def pp(node: Any, **spec: Unpack[PPSpec]) -> str:
    """Pretty print ``node``; returns ``...`` once the depth budget is exhausted."""
    if spec["depth"] < 0:
        return "..."

    return format_width(pp_dispatcher(node, **spec), width=spec["width"])


def pps(xs: Sequence[Any], pp: PP, **spec: Unpack[PPSpec]) -> str:
    """Pretty print the items of a container with ``pp``, one item per line.

    ``xs`` is only iterated once, so generators, ``dict.items()`` views and
    sets are accepted too. Items are printed one indentation level deeper and
    with one less level of depth; the joined result is collapsed onto a single
    line by ``format_width`` when it fits within ``spec["width"]``.
    """
    if spec["depth"] == 0:
        return "..."

    # ``spec`` is a fresh dict built from ``**spec``; mutating it does not
    # affect the caller's options.
    spec["indent"] += 1
    spec["depth"] -= 1

    text = (
        "\n"
        + "\t" * spec["indent"]
        + (", \n" + "\t" * spec["indent"]).join(pp(x, **spec) for x in xs)
        + "\n"
        + "\t" * (spec["indent"] - 1)
    )

    return format_width(text, width=spec["width"])


def key_value_pp(x: tuple[str, Any], **spec: Unpack[PPSpec]) -> str:
    """Format a ``(key, value)`` pair as ``key:value`` (mapping style)."""
    return f"{x[0]}:{pp(x[1], **spec)}"


def attr_value_pp(x: tuple[str, Any], **spec: Unpack[PPSpec]) -> str:
    """Format a ``(name, value)`` pair as ``name=value`` (attribute style)."""
    return f"{x[0]}={pp(x[1], **spec)}"


@pp_dispatcher.register(ShapeDtypePP)
def shape_dtype_pp(node: Any, **spec: Unpack[PPSpec]) -> str:
    """Pretty print a node with dtype and shape, e.g. ``f32[2,3]``."""
    # "(2, 3)" -> "(2 3)" -> "[2 3]" -> "[2,3]"
    shape = f"{arraylib.shape(node)}".replace(",", "")
    shape = shape.replace("(", "[")
    shape = shape.replace(")", "]")
    shape = shape.replace(" ", ",")
    # "int32" -> "i32", "float32" -> "f32", "complex64" -> "c64"
    dtype = f"{arraylib.dtype(node)}".replace("int", "i")
    dtype = dtype.replace("float", "f")
    dtype = dtype.replace("complex", "c")
    return dtype + shape


@pp_dispatcher.register(arraylib.ndarray)
def array_pp(node: arraylib.ndarray, **spec: Unpack[PPSpec]) -> str:
    """Replace ndarray repr with short hand notation for type and shape.

    In ``"REPR"`` mode a non-empty integer/floating array is printed as
    ``dtype[shape](μ=mean, σ=std, ∈[min,max])``; other arrays print as
    ``dtype[shape]``. ``"STR"`` mode falls back to :func:`general_pp`.
    """
    if spec["kind"] == "STR":
        return general_pp(node, **spec)

    base = shape_dtype_pp(node, **spec)

    if not (arraylib.is_floating(node) or arraylib.is_integer(node)):
        return base
    if arraylib.size(node) == 0:
        return base

    # Extended repr for numpy array, with extended information
    # this part of the function is inspired by lovely-jax
    with suppress(Exception):
        # statistics may be unavailable (presumably e.g. for jax tracers);
        # in that case only the dtype/shape summary below is returned
        low, high = arraylib.min(node), arraylib.max(node)
        interval = "(" if math.isinf(low) else "["
        interval += (
            f"{low},{high}" if arraylib.is_integer(node) else f"{low:.2f},{high:.2f}"
        )
        interval += ")" if math.isinf(high) else "]"
        interval = interval.replace("inf", "∞")

        mean, std = f"{arraylib.mean(node):.2f}", f"{arraylib.std(node):.2f}"
        return f"{base}(μ={mean}, σ={std}, ∈{interval})"

    return base


@pp_dispatcher.register(FunctionType)
def func_pp(func: Callable, **spec: Unpack[PPSpec]) -> str:
    """Pretty print a function as its name followed by its parameter names.

    For example ``def example(a: int, b=1, *c, d, e=2, **f) -> str`` is
    printed as ``example(a, b, *c, d, e, **f)``; annotations and defaults
    are omitted.
    """
    del spec
    argspec = inspect.getfullargspec(func)
    parts = [
        ", ".join(argspec.args),
        f"*{argspec.varargs}" if argspec.varargs is not None else "",
        ", ".join(argspec.kwonlyargs),
        f"**{argspec.varkw}" if argspec.varkw is not None else "",
    ]
    name = getattr(func, "__name__", "")
    return f"{name}(" + ", ".join(part for part in parts if part) + ")"


@pp_dispatcher.register(ft.partial)
def partial_pp(node: ft.partial, **spec: Unpack[PPSpec]) -> str:
    """Pretty print a ``functools.partial`` as ``Partial(<wrapped function>)``."""
    return f"Partial({func_pp(node.func, **spec)})"


@pp_dispatcher.register(list)
def list_pp(node: list, **spec: Unpack[PPSpec]) -> str:
    """Pretty print a list as ``[item, ...]``."""
    return "[" + pps(node, pp=pp, **spec) + "]"


@pp_dispatcher.register(tuple)
def tuple_pp(node: tuple, **spec: Unpack[PPSpec]) -> str:
    """Pretty print a tuple as ``(item, ...)``, a namedtuple as ``Name(f=v, ...)``."""
    if not hasattr(node, "_fields"):
        return "(" + pps(node, pp=pp, **spec) + ")"
    name = type(node).__name__
    kvs = node._asdict().items()
    return name + "(" + pps(kvs, pp=attr_value_pp, **spec) + ")"


@pp_dispatcher.register(set)
def set_pp(node: set, **spec: Unpack[PPSpec]) -> str:
    """Pretty print a set as ``{item, ...}`` (in iteration order)."""
    return "{" + pps(node, pp=pp, **spec) + "}"


@pp_dispatcher.register(dict)
def dict_pp(node: dict, **spec: Unpack[PPSpec]) -> str:
    """Pretty print a dict as ``{key:value, ...}``."""
    return "{" + pps(node.items(), pp=key_value_pp, **spec) + "}"


@pp_dispatcher.register(str)
def str_pp(node: str, **spec: Unpack[PPSpec]) -> str:
    """Print a string as-is (without quotes), regardless of ``spec``."""
    return node

def tree_repr(
    tree: PyTree,
    *,
    width: int = 80,
    tabwidth: int = 2,
    depth: int | float = float("inf"),
) -> str:
    """Pretty print arbitrary pytrees ``__repr__``.

    Args:
        tree: arbitrary pytree.
        width: max width of the repr string.
        tabwidth: tab width of the repr string.
        depth: max depth of the repr string.

    Example:
        >>> import pytreeclass as tc
        >>> import jax.numpy as jnp
        >>> tree = {'a' : 1, 'b' : [2, 3], 'c' : {'d' : 4, 'e' : 5} , 'f' : jnp.array([6, 7])}

        >>> print(tc.tree_repr(tree, depth=0))
        {...}

        >>> print(tc.tree_repr(tree, depth=1))
        {a:1, b:[...], c:{...}, f:i32[2](μ=6.50, σ=0.50, ∈[6,7])}

        >>> print(tc.tree_repr(tree, depth=2))
        {a:1, b:[2, 3], c:{d:4, e:5}, f:i32[2](μ=6.50, σ=0.50, ∈[6,7])}
    """
    text = pp(tree, indent=0, kind="REPR", width=width, depth=depth)
    # indentation is built with tab characters; render them at ``tabwidth``
    return text.expandtabs(tabwidth)
def tree_str(
    tree: PyTree,
    *,
    width: int = 80,
    tabwidth: int = 2,
    depth: int | float = float("inf"),
) -> str:
    """Pretty print arbitrary pytrees ``__str__``.

    Args:
        tree: arbitrary pytree.
        width: max width of the str string.
        tabwidth: tab width of the str string.
        depth: max depth of the str string.

    Example:
        >>> import pytreeclass as tc
        >>> import jax.numpy as jnp
        >>> tree = {'a' : 1, 'b' : [2, 3], 'c' : {'d' : 4, 'e' : 5} , 'f' : jnp.array([6, 7])}

        >>> print(tc.tree_str(tree, depth=1))
        {a:1, b:[...], c:{...}, f:[6 7]}

        >>> print(tc.tree_str(tree, depth=2))
        {a:1, b:[2, 3], c:{d:4, e:5}, f:[6 7]}
    """
    text = pp(tree, indent=0, kind="STR", width=width, depth=depth)
    # indentation is built with tab characters; render them at ``tabwidth``
    return text.expandtabs(tabwidth)
def tree_diagram(
    tree: Any,
    *,
    depth: int | float = float("inf"),
    is_leaf: Callable[[Any], bool] | None = None,
    tabwidth: int = 4,
) -> str:
    """Pretty print arbitrary pytrees tree with tree structure diagram.

    Args:
        tree: arbitrary pytree.
        depth: depth of the tree to print. default is max depth.
        is_leaf: function to determine if a node is a leaf. default is None.
        tabwidth: tab width of the repr string. default is 4.

    Example:
        >>> import pytreeclass as tc
        >>> @tc.autoinit
        ... class A(tc.TreeClass):
        ...     x: int = 10
        ...     y: int = (20,30)
        ...     z: int = 40

        >>> @tc.autoinit
        ... class B(tc.TreeClass):
        ...     a: int = 10
        ...     b: tuple = (20,30, A())

        >>> print(tc.tree_diagram(B(), depth=0))
        B(...)

        >>> print(tc.tree_diagram(B(), depth=1))
        B
        ├── .a=10
        └── .b=(...)

        >>> print(tc.tree_diagram(B(), depth=2))
        B
        ├── .a=10
        └── .b:tuple
            ├── [0]=20
            ├── [1]=30
            └── [2]=A(...)
    """
    # Marks are padded with tabs and truncated to ``tabwidth`` characters so
    # that, after ``expandtabs(tabwidth)``, every level is ``tabwidth`` wide.
    vmark = ("│\t")[:tabwidth]  # vertical mark
    lmark = ("└" + "─" * (tabwidth - 2) + (" \t"))[:tabwidth]  # last mark
    cmark = ("├" + "─" * (tabwidth - 2) + (" \t"))[:tabwidth]  # connector mark
    smark = (" \t")[:tabwidth]  # space mark

    def step(
        node: Node,
        depth: int = 0,
        is_last: bool = False,
        is_lasts: tuple[bool, ...] = (),
    ) -> str:
        # One column per ancestor level: blank below a branch that was the
        # last child, a vertical bar below one that still has siblings.
        indent = "".join(smark if last else vmark for last in is_lasts[:-1])
        branch = (lmark if is_last else cmark) if depth > 0 else ""

        if (child_count := len(node.children)) == 0:
            (key, _), value = node.data
            text = f"{indent}"
            text += f"{branch}{key}=" if key is not None else ""
            text += tree_repr(value, depth=0)
            return text + "\n"

        (key, node_type), _ = node.data
        text = f"{indent}{branch}"
        text += f"{key}:" if key is not None else ""
        text += f"{node_type.__name__}\n"

        for i, child in enumerate(node.children.values()):
            last = i == child_count - 1
            text += step(
                child,
                depth=depth + 1,
                is_last=last,
                is_lasts=is_lasts + (last,),
            )
        return text

    is_path_leaf = is_path_leaf_depth_factory(depth)
    root = construct_tree(tree, is_leaf=is_leaf, is_path_leaf=is_path_leaf)
    # the root is drawn at depth 0, where no branch mark is used
    text = step(root)
    return text.expandtabs(tabwidth).rstrip()
def tree_mermaid(
    tree: PyTree,
    depth: int | float = float("inf"),
    is_leaf: Callable[[Any], bool] | None = None,
    tabwidth: int | None = 4,
) -> str:
    """Generate a mermaid diagram syntax for arbitrary pytrees.

    Args:
        tree: PyTree
        depth: depth of the tree to print. default is max depth
        is_leaf: function to determine if a node is a leaf. default is None
        tabwidth: tab width of the repr string. default is 4.

    Example:
        >>> import pytreeclass as tc
        >>> tree = [1, 2, dict(a=3)]
        >>> # as rendered by mermaid
        >>> print(tc.tree_mermaid(tree))  # doctest: +SKIP

        .. image:: ../_static/tree_mermaid.jpg
            :width: 300px
            :align: center

    Note:
        - Copy the output and paste it in the mermaid live editor to interact
          with the diagram.
    """

    def step(node: Node) -> str:
        (key, node_type), value = node.data

        if len(node.children) == 0:
            ppstr = f"{key}=" if key is not None else ""
            ppstr += tree_repr(value, depth=0)
        else:
            ppstr = f"{key}:" if key is not None else ""
            ppstr += f"{node_type.__name__}"
        ppstr = "<b>" + ppstr + "</b>"

        # Only link to a parent when there is one; this also covers a leaf
        # root (e.g. ``tree_mermaid(1)``), which previously got an edge from a
        # phantom ``id(None)`` node.
        if node.parent is None:
            text = f'\tid{id(node)}("{ppstr}")\n'
        else:
            text = f'\tid{id(node.parent)} --- id{id(node)}("{ppstr}")\n'

        for child in node.children.values():
            text += step(child)
        return text

    is_path_leaf = is_path_leaf_depth_factory(depth)
    root = construct_tree(tree, is_leaf=is_leaf, is_path_leaf=is_path_leaf)
    text = "flowchart LR\n" + step(root)
    return (text.expandtabs(tabwidth) if tabwidth is not None else text).rstrip()
# dispatcher for dot nodestyles: maps a node value to a dict of graphviz node
# attributes used by ``tree_graph``. Unregistered types are drawn as boxes;
# custom styles are registered via ``tree_graph.def_nodestyle(type)``.
dot_dispatcher = ft.singledispatch(lambda _: dict(shape="box"))
def tree_graph(
    tree: PyTree,
    depth: int | float = float("inf"),
    is_leaf: Callable[[Any], bool] | None = None,
    tabwidth: int | None = 4,
) -> str:
    """Generate a dot diagram syntax for arbitrary pytrees.

    Args:
        tree: pytree
        depth: depth of the tree to print. default is max depth
        is_leaf: function to determine if a node is a leaf. default is None
        tabwidth: tab width of the repr string. default is 4.

    Returns:
        str: dot diagram syntax

    Example:
        >>> import pytreeclass as tc
        >>> tree = [1, 2, dict(a=3)]
        >>> # as rendered by graphviz

        .. image:: ../_static/tree_graph.svg

    Example:
        >>> # define custom style for a node by dispatching on the value
        >>> # the defined function should return a dict of attributes
        >>> # that will be passed to graphviz.
        >>> import pytreeclass as tc
        >>> tree = [1, 2, dict(a=3)]
        >>> @tc.tree_graph.def_nodestyle(list)
        ... def _(_) -> dict[str, str]:
        ...     return dict(shape="circle", style="filled", fillcolor="lightblue")

        .. image:: ../_static/tree_graph_stylized.svg
    """

    def step(node: Node) -> str:
        (key, node_type), value = node.data

        # dispatch node style
        style = ", ".join(f"{k}={v}" for k, v in dot_dispatcher(value).items())

        if len(node.children) == 0:
            ppstr = f"{key}=" if key is not None else ""
            ppstr += tree_repr(value, depth=0)
        else:
            ppstr = f"{key}:" if key is not None else ""
            ppstr += f"{node_type.__name__}"

        text = f'\t{id(node)} [label="{ppstr}", {style}];\n'
        # Only link to a parent when there is one; a leaf root
        # (e.g. ``tree_graph(1)``) previously got an edge from ``id(None)``.
        if node.parent is not None:
            text += f"\t{id(node.parent)} -> {id(node)};\n"

        for child in node.children.values():
            text += step(child)
        return text

    is_path_leaf = is_path_leaf_depth_factory(depth)
    root = construct_tree(tree, is_leaf=is_leaf, is_path_leaf=is_path_leaf)
    text = "digraph G {\n" + step(root) + "}"
    return (text.expandtabs(tabwidth) if tabwidth is not None else text).rstrip()
tree_graph.def_nodestyle = dot_dispatcher.register


def format_width(string, width=60):
    """Strip newline/tab characters if the text fits within ``width``.

    The length is measured without newline and tab characters. Text of at
    most ``width`` characters is collapsed onto a single line; longer text
    is returned unchanged (keeping its multi-line layout).
    """
    children_length = len(string) - string.count("\n") - string.count("\t")
    if children_length > width:
        return string
    return string.replace("\n", "").replace("\t", "")


# table printing

Row: TypeAlias = Sequence[str]  # list of columns


def _table(rows: list[Row]) -> str:
    """Generate a box-drawn table from a list of rows.

    Every row (including the first, which callers use as the header) is
    separated by a horizontal rule. Cells may contain newlines; each column
    is left-justified to its widest line.
    """

    def line(text: Row, widths: list[int]) -> str:
        # a row whose cells span several lines is emitted line by line,
        # padding shorter cells with empty strings
        return "\n".join(
            "│" + "│".join(col.ljust(width) for col, width in zip(line_row, widths)) + "│"
            for line_row in zip_longest(*[t.split("\n") for t in text], fillvalue="")
        )

    widths = [max(map(len, "\n".join(col).split("\n"))) for col in zip(*rows)]
    spaces: Row = ["─" * width for width in widths]

    return (
        ("┌" + "┬".join(spaces) + "┐")
        + "\n"
        + ("\n├" + "┼".join(spaces) + "┤\n").join(line(row, widths) for row in rows)
        + "\n"
        + ("└" + "┴".join(spaces) + "┘")
    )


def size_pp(size: int, **spec: Unpack[PPSpec]) -> str:
    """Format a byte count with a 1024-based unit, e.g. ``12.00B`` or ``1.50KB``."""
    del spec
    order_alpha = ["B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"]
    # Choose the unit with exact integer comparisons instead of
    # ``int(math.log(size, 1024))``, whose float result can be off by one near
    # powers of 1024 and which overflowed ``order_alpha`` for sizes of
    # 1024**9 bytes or more (those are now reported in YB).
    size_order = 0
    while size_order < len(order_alpha) - 1 and size >= 1024 ** (size_order + 1):
        size_order += 1
    return f"{(size)/(1024**size_order):.2f}{order_alpha[size_order]}"
def tree_summary(
    tree: PyTree,
    *,
    depth: int | float = float("inf"),
    is_leaf: Callable[[Any], bool] | None = None,
) -> str:
    """Print a summary of an arbitrary pytree.

    Args:
        tree: a registered pytree to summarize.
        depth: max depth to display the tree. defaults to maximum depth.
        is_leaf: function to determine if a node is a leaf. defaults to None

    Returns:
        String summary of the tree structure:
            - First column: path to the node.
            - Second column: type of the node. To control the displayed type use
              ``tree_summary.def_type(type, func)`` to define a custom type
              display function.
            - Third column: number of leaves in the node. For arrays the number
              of leaves is the number of elements in the array, otherwise it is
              1. To control the number of leaves of a node use
              ``tree_summary.def_count(type, func)``.
            - Fourth column: size of the node in bytes. If the node is an array
              the size is the size of the array in bytes, otherwise the size is
              not displayed. To control the size of a node use
              ``tree_summary.def_size(type, func)``.
            - Last row (``Σ``): type of the root, total number of leaves and
              total size.

    Example:
        >>> import pytreeclass as tc
        >>> import jax.numpy as jnp
        >>> print(tc.tree_summary([1, [2, [3]], jnp.array([1, 2, 3])]))
        ┌─────────┬──────┬─────┬──────┐
        │Name     │Type  │Count│Size  │
        ├─────────┼──────┼─────┼──────┤
        │[0]      │int   │1    │      │
        ├─────────┼──────┼─────┼──────┤
        │[1][0]   │int   │1    │      │
        ├─────────┼──────┼─────┼──────┤
        │[1][1][0]│int   │1    │      │
        ├─────────┼──────┼─────┼──────┤
        │[2]      │i32[3]│3    │12.00B│
        ├─────────┼──────┼─────┼──────┤
        │Σ        │list  │6    │12.00B│
        └─────────┴──────┴─────┴──────┘

    Example:
        >>> # set python `int` to have 4 bytes using dispatching
        >>> import pytreeclass as tc
        >>> print(tc.tree_summary(1))
        ┌────┬────┬─────┬────┐
        │Name│Type│Count│Size│
        ├────┼────┼─────┼────┤
        │Σ   │int │1    │    │
        └────┴────┴─────┴────┘
        >>> @tc.tree_summary.def_size(int)
        ... def _(node: int) -> int:
        ...     return 4
        >>> print(tc.tree_summary(1))
        ┌────┬────┬─────┬─────┐
        │Name│Type│Count│Size │
        ├────┼────┼─────┼─────┤
        │Σ   │int │1    │4.00B│
        └────┴────┴─────┴─────┘

    Example:
        >>> # set custom type display for jaxprs
        >>> import jax
        >>> import pytreeclass as tc
        >>> ClosedJaxprType = type(jax.make_jaxpr(lambda x: x)(1))
        >>> @tc.tree_summary.def_type(ClosedJaxprType)
        ... def _(expr: ClosedJaxprType) -> str:
        ...     jaxpr = expr.jaxpr
        ...     return f"Jaxpr({jaxpr.invars}, {jaxpr.outvars})"
        >>> def func(x, y):
        ...     return x
        >>> jaxpr = jax.make_jaxpr(func)(1, 2)
        >>> print(tc.tree_summary(jaxpr))
        ┌────┬──────────────────┬─────┬────┐
        │Name│Type              │Count│Size│
        ├────┼──────────────────┼─────┼────┤
        │Σ   │Jaxpr([a, b], [a])│1    │    │
        └────┴──────────────────┴─────┴────┘
    """
    rows = [["Name", "Type", "Count", "Size"]]
    tcount = tsize = 0

    traces_leaves = tree_leaves_with_typed_path(
        tree,
        is_leaf=is_leaf,
        is_path_leaf=is_path_leaf_depth_factory(depth),
    )

    for trace, leaf in traces_leaves:
        tcount += (count := tree_count(leaf))
        tsize += (size := tree_size(leaf))

        if trace == ((), ()):
            # avoid printing the leaf trace (which is the root of the tree)
            # twice, once as a leaf and once as the root at the end
            continue

        paths, _ = trace
        pstr = treelib.keystr(paths)
        tstr = tree_summary.type_dispatcher(leaf)
        cstr = f"{count:,}" if count else ""
        sstr = size_pp(size) if size else ""
        rows += [[pstr, tstr, cstr, sstr]]

    # totals row for the whole tree
    pstr = "Σ"
    tstr = tree_summary.type_dispatcher(tree)
    cstr = f"{tcount:,}" if tcount else ""
    sstr = size_pp(tsize) if tsize else ""
    rows += [[pstr, tstr, cstr, sstr]]
    return _table(rows)
# Per-type hooks used by ``tree_summary``: leaf count (default 1), size in
# bytes (default 0, i.e. not shown) and displayed type name.
tree_summary.count_dispatcher = ft.singledispatch(lambda x: 1)
tree_summary.def_count = tree_summary.count_dispatcher.register

tree_summary.size_dispatcher = ft.singledispatch(lambda x: 0)
tree_summary.def_size = tree_summary.size_dispatcher.register

tree_summary.type_dispatcher = ft.singledispatch(lambda x: type(x).__name__)
tree_summary.def_type = tree_summary.type_dispatcher.register


@tree_summary.def_type(arraylib.ndarray)
def tree_summary_array(node: Any) -> str:
    """Return the type repr of the node, e.g. ``i32[3]``."""
    shape_dtype = node.shape, node.dtype
    spec = dict(indent=0, kind="REPR", width=80, depth=float("inf"))
    # routed through ``pp`` so a printer registered for ShapeDtypePP applies
    return pp(ShapeDtypePP(*shape_dtype), **spec)


@tree_summary.def_count(arraylib.ndarray)
def tree_summary_array_count(node: arraylib.ndarray) -> int:
    """Count every element of an array as a leaf."""
    return node.size


@tree_summary.def_size(arraylib.ndarray)
def tree_summary_array_size(node: arraylib.ndarray) -> int:
    """Report an array's size as its number of bytes."""
    return node.nbytes


def tree_size(tree: PyTree) -> int:
    """Sum ``tree_summary.size_dispatcher`` over the leaves of ``tree``."""
    leaves, _ = treelib.tree_flatten(tree)
    return sum(map(tree_summary.size_dispatcher, leaves))


def tree_count(tree: PyTree) -> int:
    """Sum ``tree_summary.count_dispatcher`` over the leaves of ``tree``."""
    leaves, _ = treelib.tree_flatten(tree)
    return sum(map(tree_summary.count_dispatcher, leaves))


if importlib.util.find_spec("jax"):
    # jax pretty printing extra handlers
    import jax

    # register jax types for pretty printing
    pp_dispatcher.register(jax.custom_jvp, func_pp)
    pp_dispatcher.register(jax.ShapeDtypeStruct, shape_dtype_pp)

    # register jax for tree_summary
    tree_summary.def_type(jax.ShapeDtypeStruct, tree_summary_array)