๐จ Pretty printing API#
- pytreeclass.tree_diagram(tree, *, depth=inf, is_leaf=None, tabwidth=4)[source]#
Pretty print arbitrary pytrees tree with tree structure diagram.
- Parameters:
tree (Any) โ arbitrary pytree.
depth (int | float) โ depth of the tree to print. default is max depth.
is_leaf (Callable[[Any], None] | None) โ function to determine if a node is a leaf. default is None.
tabwidth (int) โ tab width of the repr string. default is 4.
Example
>>> import pytreeclass as tc >>> @tc.autoinit ... class A(tc.TreeClass): ... x: int = 10 ... y: int = (20,30) ... z: int = 40
>>> @tc.autoinit ... class B(tc.TreeClass): ... a: int = 10 ... b: tuple = (20,30, A())
>>> print(tc.tree_diagram(B(), depth=0)) B(...)
>>> print(tc.tree_diagram(B(), depth=1)) B โโโ .a=10 โโโ .b=(...)
>>> print(tc.tree_diagram(B(), depth=2)) B โโโ .a=10 โโโ .b:tuple โโโ [0]=20 โโโ [1]=30 โโโ [2]=A(...)
- pytreeclass.tree_graph(tree, depth=inf, is_leaf=None, tabwidth=4)[source]#
Generate a dot diagram syntax for arbitrary pytrees.
- Parameters:
tree (PyTree) โ pytree
depth (int | float) โ depth of the tree to print. default is max depth
is_leaf (Callable[[Any], None] | None) โ function to determine if a node is a leaf. default is None
tabwidth (int | None) โ tab width of the repr string. default is 4.
- Returns:
dot diagram syntax
- Return type:
str
Example
>>> import pytreeclass as tc >>> tree = [1, 2, dict(a=3)] >>> # as rendered by graphviz
Example
>>> # define custom style for a node by dispatching on the value >>> # the defined function should return a dict of attributes >>> # that will be passed to graphviz. >>> import pytreeclass as tc >>> tree = [1, 2, dict(a=3)] >>> @tc.tree_graph.def_nodestyle(list) ... def _(_) -> dict[str, str]: ... return dict(shape="circle", style="filled", fillcolor="lightblue")
- pytreeclass.tree_mermaid(tree, depth=inf, is_leaf=None, tabwidth=4)[source]#
Generate a mermaid diagram syntax for arbitrary pytrees.
- Parameters:
tree (PyTree) โ PyTree
depth (int | float) โ depth of the tree to print. default is max depth
is_leaf (Callable[[Any], None] | None) โ function to determine if a node is a leaf. default is None
tabwidth (int | None) โ tab width of the repr string. default is 4.
- Return type:
str
Example
>>> import pytreeclass as tc >>> tree = [1, 2, dict(a=3)] >>> # as rendered by mermaid >>> print(tc.tree_mermaid(tree))
Note
Copy the output and paste it in the mermaid live editor to interact with the diagram. https://mermaid.live
- pytreeclass.tree_repr(tree, *, width=80, tabwidth=2, depth=inf)[source]#
Prertty print arbitrary pytrees
__repr__
.- Parameters:
tree (PyTree) โ arbitrary pytree.
width (int) โ max width of the repr string.
tabwidth (int) โ tab width of the repr string.
depth (int | float) โ max depth of the repr string.
- Return type:
str
Example
>>> import pytreeclass as tc >>> import jax.numpy as jnp >>> tree = {'a' : 1, 'b' : [2, 3], 'c' : {'d' : 4, 'e' : 5} , 'f' : jnp.array([6, 7])}
>>> print(tc.tree_repr(tree, depth=0)) {...}
>>> print(tc.tree_repr(tree, depth=1)) {a:1, b:[...], c:{...}, f:i32[2](ฮผ=6.50, ฯ=0.50, โ[6,7])}
>>> print(tc.tree_repr(tree, depth=2)) {a:1, b:[2, 3], c:{d:4, e:5}, f:i32[2](ฮผ=6.50, ฯ=0.50, โ[6,7])}
- pytreeclass.tree_str(tree, *, width=80, tabwidth=2, depth=inf)[source]#
Prertty print arbitrary pytrees
__str__
.- Parameters:
tree (PyTree) โ arbitrary pytree.
width (int) โ max width of the str string.
tabwidth (int) โ tab width of the repr string.
depth (int | float) โ max depth of the repr string.
- Return type:
str
Example
>>> import pytreeclass as tc >>> import jax.numpy as jnp >>> tree = {'a' : 1, 'b' : [2, 3], 'c' : {'d' : 4, 'e' : 5} , 'f' : jnp.array([6, 7])}
>>> print(tc.tree_str(tree, depth=1)) {a:1, b:[...], c:{...}, f:[6 7]}
>>> print(tc.tree_str(tree, depth=2)) {a:1, b:[2, 3], c:{d:4, e:5}, f:[6 7]}
- pytreeclass.tree_summary(tree, *, depth=inf, is_leaf=None)[source]#
Print a summary of an arbitrary pytree.
- Parameters:
tree (PyTree) โ a registered pytree to summarize.
depth (int | float) โ max depth to display the tree. defaults to maximum depth.
is_leaf (Callable[[Any], None] | None) โ function to determine if a node is a leaf. defaults to None
- Returns:
First column: path to the node.
- Second column: type of the node. to control the displayed type use
tree_summary.def_type(type, func) to define a custom type display function.
- Third column: number of leaves in the node. for arrays the number of leaves
is the number of elements in the array, otherwise its 1. to control the number of leaves of a node use tree_summary.def_count(type,func)
- Fourth column: size of the node in bytes. if the node is array the size
is the size of the array in bytes, otherwise its the size is not displayed. to control the size of a node use tree_summary.def_size(type,func)
Last row: type of parent, number of leaves of the parent
- Return type:
String summary of the tree structure
Example
>>> import pytreeclass as tc >>> import jax.numpy as jnp >>> print(tc.tree_summary([1, [2, [3]], jnp.array([1, 2, 3])])) โโโโโโโโโโโฌโโโโโโโฌโโโโโโฌโโโโโโโ โName โType โCountโSize โ โโโโโโโโโโโผโโโโโโโผโโโโโโผโโโโโโโค โ[0] โint โ1 โ โ โโโโโโโโโโโผโโโโโโโผโโโโโโผโโโโโโโค โ[1][0] โint โ1 โ โ โโโโโโโโโโโผโโโโโโโผโโโโโโผโโโโโโโค โ[1][1][0]โint โ1 โ โ โโโโโโโโโโโผโโโโโโโผโโโโโโผโโโโโโโค โ[2] โi32[3]โ3 โ12.00Bโ โโโโโโโโโโโผโโโโโโโผโโโโโโผโโโโโโโค โฮฃ โlist โ6 โ12.00Bโ โโโโโโโโโโโดโโโโโโโดโโโโโโดโโโโโโโ
Example
>>> # set python `int` to have 4 bytes using dispatching >>> import pytreeclass as tc >>> print(tc.tree_summary(1)) โโโโโโฌโโโโโฌโโโโโโฌโโโโโ โNameโTypeโCountโSizeโ โโโโโโผโโโโโผโโโโโโผโโโโโค โฮฃ โint โ1 โ โ โโโโโโดโโโโโดโโโโโโดโโโโโ >>> @tc.tree_summary.def_size(int) ... def _(node: int) -> int: ... return 4 >>> print(tc.tree_summary(1)) โโโโโโฌโโโโโฌโโโโโโฌโโโโโโ โNameโTypeโCountโSize โ โโโโโโผโโโโโผโโโโโโผโโโโโโค โฮฃ โint โ1 โ4.00Bโ โโโโโโดโโโโโดโโโโโโดโโโโโโ
Example
>>> # set custom type display for jaxprs >>> import jax >>> import pytreeclass as tc >>> ClosedJaxprType = type(jax.make_jaxpr(lambda x: x)(1)) >>> @tc.tree_summary.def_type(ClosedJaxprType) ... def _(expr: ClosedJaxprType) -> str: ... jaxpr = expr.jaxpr ... return f"Jaxpr({jaxpr.invars}, {jaxpr.outvars})" >>> def func(x, y): ... return x >>> jaxpr = jax.make_jaxpr(func)(1, 2) >>> print(tc.tree_summary(jaxpr)) โโโโโโฌโโโโโโโโโโโโโโโโโโโฌโโโโโโฌโโโโโ โNameโType โCountโSizeโ โโโโโโผโโโโโโโโโโโโโโโโโโโผโโโโโโผโโโโโค โฮฃ โJaxpr([a, b], [a])โ1 โ โ โโโโโโดโโโโโโโโโโโโโโโโโโโดโโโโโโดโโโโโ