10 Working with pytrees
This chapter covers
- Representing complex data structures as pytrees
- Using functions for working with pytrees
- Creating custom pytree nodes
In previous chapters, we mostly used tensors to represent data and model parameters. That’s enough for simple cases, but it is not very convenient when your models and datasets become more complex.
Machine learning tasks frequently involve objects represented as lists of dicts, lists of arrays, dicts of arrays, and so on. For example, dataset elements can be represented this way, and neural network weights are typically organized in a hierarchy, with weights and biases stored for each layer. If you handle this complexity with low-level tools like tensors and basic tensor operations, your code quickly becomes larger and less clear. Eventually, you have to invent your own higher-level abstractions and more complex data structures. Other domains, such as astrophysics, bioinformatics, and weather modeling, may have their own convenient data structures.
Fortunately, JAX comes with some batteries included: tools for working with such structures are provided in core JAX. JAX calls tree-like data structures built out of container-like Python objects pytrees and supports them in many of its library functions.
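To make the idea concrete, here is a minimal sketch of a pytree. The parameter layout (a list of per-layer dicts with keys "w" and "b") and the layer shapes are assumptions chosen for illustration, not a structure prescribed by JAX. The sketch relies only on `jax.tree_util.tree_leaves` and `jax.tree_util.tree_map`.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters of a two-layer network: a list of dicts of arrays.
# Lists and dicts are containers (tree nodes); the arrays are the leaves.
params = [
    {"w": jnp.ones((3, 4)), "b": jnp.zeros(4)},
    {"w": jnp.ones((4, 2)), "b": jnp.zeros(2)},
]

# Flatten the nested structure into a plain list of its leaf arrays.
# Dict entries are visited in sorted key order, so "b" comes before "w".
leaves = jax.tree_util.tree_leaves(params)
shapes = [leaf.shape for leaf in leaves]
print(shapes)

# Apply a function to every leaf; the result has the same nested structure.
scaled = jax.tree_util.tree_map(lambda x: 2.0 * x, params)
print(jax.tree_util.tree_structure(scaled))
```

Note that the code never walks the lists and dicts by hand. The same two calls would work unchanged if the hierarchy were deeper or organized differently.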