๐ŸŽจ Pretty printing API#

pytreeclass.tree_diagram(tree, *, depth=inf, is_leaf=None, tabwidth=4)[source]#

Pretty print arbitrary pytrees tree with tree structure diagram.

Parameters:
  • tree (Any) โ€“ arbitrary pytree.

  • depth (int | float) โ€“ depth of the tree to print. default is max depth.

  • is_leaf (Callable[[Any], None] | None) โ€“ function to determine if a node is a leaf. default is None.

  • tabwidth (int) โ€“ tab width of the repr string. default is 4.

Example

>>> import pytreeclass as tc
>>> @tc.autoinit
... class A(tc.TreeClass):
...     x: int = 10
...     y: int = (20,30)
...     z: int = 40
>>> @tc.autoinit
... class B(tc.TreeClass):
...     a: int = 10
...     b: tuple = (20,30, A())
>>> print(tc.tree_diagram(B(), depth=0))
B(...)
>>> print(tc.tree_diagram(B(), depth=1))
B
โ”œโ”€โ”€ .a=10
โ””โ”€โ”€ .b=(...)
>>> print(tc.tree_diagram(B(), depth=2))
B
โ”œโ”€โ”€ .a=10
โ””โ”€โ”€ .b:tuple
    โ”œโ”€โ”€ [0]=20
    โ”œโ”€โ”€ [1]=30
    โ””โ”€โ”€ [2]=A(...)
pytreeclass.tree_graph(tree, depth=inf, is_leaf=None, tabwidth=4)[source]#

Generate a dot diagram syntax for arbitrary pytrees.

Parameters:
  • tree (PyTree) โ€“ pytree

  • depth (int | float) โ€“ depth of the tree to print. default is max depth

  • is_leaf (Callable[[Any], None] | None) โ€“ function to determine if a node is a leaf. default is None

  • tabwidth (int | None) โ€“ tab width of the repr string. default is 4.

Returns:

dot diagram syntax

Return type:

str

Example

>>> import pytreeclass as tc
>>> tree = [1, 2, dict(a=3)]
>>> # as rendered by graphviz
../_images/tree_graph.svg

Example

>>> # define custom style for a node by dispatching on the value
>>> # the defined function should return a dict of attributes
>>> # that will be passed to graphviz.
>>> import pytreeclass as tc
>>> tree = [1, 2, dict(a=3)]
>>> @tc.tree_graph.def_nodestyle(list)
... def _(_) -> dict[str, str]:
...     return dict(shape="circle", style="filled", fillcolor="lightblue")
../_images/tree_graph_stylized.svg
pytreeclass.tree_mermaid(tree, depth=inf, is_leaf=None, tabwidth=4)[source]#

Generate a mermaid diagram syntax for arbitrary pytrees.

Parameters:
  • tree (PyTree) โ€“ PyTree

  • depth (int | float) โ€“ depth of the tree to print. default is max depth

  • is_leaf (Callable[[Any], None] | None) โ€“ function to determine if a node is a leaf. default is None

  • tabwidth (int | None) โ€“ tab width of the repr string. default is 4.

Return type:

str

Example

>>> import pytreeclass as tc
>>> tree = [1, 2, dict(a=3)]
>>> # as rendered by mermaid
>>> print(tc.tree_mermaid(tree))  
../_images/tree_mermaid.jpg

Note

  • Copy the output and paste it in the mermaid live editor to interact with the diagram. https://mermaid.live

pytreeclass.tree_repr(tree, *, width=80, tabwidth=2, depth=inf)[source]#

Prertty print arbitrary pytrees __repr__.

Parameters:
  • tree (PyTree) โ€“ arbitrary pytree.

  • width (int) โ€“ max width of the repr string.

  • tabwidth (int) โ€“ tab width of the repr string.

  • depth (int | float) โ€“ max depth of the repr string.

Return type:

str

Example

>>> import pytreeclass as tc
>>> import jax.numpy as jnp
>>> tree = {'a' : 1, 'b' : [2, 3], 'c' : {'d' : 4, 'e' : 5} , 'f' : jnp.array([6, 7])}
>>> print(tc.tree_repr(tree, depth=0))
{...}
>>> print(tc.tree_repr(tree, depth=1))
{a:1, b:[...], c:{...}, f:i32[2](ฮผ=6.50, ฯƒ=0.50, โˆˆ[6,7])}
>>> print(tc.tree_repr(tree, depth=2))
{a:1, b:[2, 3], c:{d:4, e:5}, f:i32[2](ฮผ=6.50, ฯƒ=0.50, โˆˆ[6,7])}
pytreeclass.tree_str(tree, *, width=80, tabwidth=2, depth=inf)[source]#

Prertty print arbitrary pytrees __str__.

Parameters:
  • tree (PyTree) โ€“ arbitrary pytree.

  • width (int) โ€“ max width of the str string.

  • tabwidth (int) โ€“ tab width of the repr string.

  • depth (int | float) โ€“ max depth of the repr string.

Return type:

str

Example

>>> import pytreeclass as tc
>>> import jax.numpy as jnp
>>> tree = {'a' : 1, 'b' : [2, 3], 'c' : {'d' : 4, 'e' : 5} , 'f' : jnp.array([6, 7])}
>>> print(tc.tree_str(tree, depth=1))
{a:1, b:[...], c:{...}, f:[6 7]}
>>> print(tc.tree_str(tree, depth=2))
{a:1, b:[2, 3], c:{d:4, e:5}, f:[6 7]}
pytreeclass.tree_summary(tree, *, depth=inf, is_leaf=None)[source]#

Print a summary of an arbitrary pytree.

Parameters:
  • tree (PyTree) โ€“ a registered pytree to summarize.

  • depth (int | float) โ€“ max depth to display the tree. defaults to maximum depth.

  • is_leaf (Callable[[Any], None] | None) โ€“ function to determine if a node is a leaf. defaults to None

Returns:

  • First column: path to the node.

  • Second column: type of the node. to control the displayed type use

    tree_summary.def_type(type, func) to define a custom type display function.

  • Third column: number of leaves in the node. for arrays the number of leaves

    is the number of elements in the array, otherwise its 1. to control the number of leaves of a node use tree_summary.def_count(type,func)

  • Fourth column: size of the node in bytes. if the node is array the size

    is the size of the array in bytes, otherwise its the size is not displayed. to control the size of a node use tree_summary.def_size(type,func)

  • Last row: type of parent, number of leaves of the parent

Return type:

String summary of the tree structure

Example

>>> import pytreeclass as tc
>>> import jax.numpy as jnp
>>> print(tc.tree_summary([1, [2, [3]], jnp.array([1, 2, 3])]))
โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚Name     โ”‚Type  โ”‚Countโ”‚Size  โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚[0]      โ”‚int   โ”‚1    โ”‚      โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚[1][0]   โ”‚int   โ”‚1    โ”‚      โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚[1][1][0]โ”‚int   โ”‚1    โ”‚      โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚[2]      โ”‚i32[3]โ”‚3    โ”‚12.00Bโ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚ฮฃ        โ”‚list  โ”‚6    โ”‚12.00Bโ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”˜

Example

>>> # set python `int` to have 4 bytes using dispatching
>>> import pytreeclass as tc
>>> print(tc.tree_summary(1))
โ”Œโ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”
โ”‚Nameโ”‚Typeโ”‚Countโ”‚Sizeโ”‚
โ”œโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”ค
โ”‚ฮฃ   โ”‚int โ”‚1    โ”‚    โ”‚
โ””โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”˜
>>> @tc.tree_summary.def_size(int)
... def _(node: int) -> int:
...     return 4
>>> print(tc.tree_summary(1))
โ”Œโ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”
โ”‚Nameโ”‚Typeโ”‚Countโ”‚Size โ”‚
โ”œโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚ฮฃ   โ”‚int โ”‚1    โ”‚4.00Bโ”‚
โ””โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”˜

Example

>>> # set custom type display for jaxprs
>>> import jax
>>> import pytreeclass as tc
>>> ClosedJaxprType = type(jax.make_jaxpr(lambda x: x)(1))
>>> @tc.tree_summary.def_type(ClosedJaxprType)
... def _(expr: ClosedJaxprType) -> str:
...     jaxpr = expr.jaxpr
...     return f"Jaxpr({jaxpr.invars}, {jaxpr.outvars})"
>>> def func(x, y):
...     return x
>>> jaxpr = jax.make_jaxpr(func)(1, 2)
>>> print(tc.tree_summary(jaxpr))
โ”Œโ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”
โ”‚Nameโ”‚Type              โ”‚Countโ”‚Sizeโ”‚
โ”œโ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”ค
โ”‚ฮฃ   โ”‚Jaxpr([a, b], [a])โ”‚1    โ”‚    โ”‚
โ””โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”˜