Brief Introduction to PyTorch 1.11

PyTorch is an optimized tensor library for deep learning on GPUs and CPUs. Dynamic computational graphs are at its core, and it builds on the strengths of mainstream deep learning frameworks. The new PyTorch 1.11 introduces beta versions of functorch and TorchData, two new libraries heavily inspired by Google JAX and MindSpore, respectively. For details, see the official PyTorch release blog post. In this post, I'd like to share my thoughts on the new features of PyTorch 1.11.
JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research. It provides vjp, jvp, and grad for automatic differentiation; pmap for single-program multiple-data (SPMD) parallel programming across multiple accelerators; vmap for automatic vectorization; and jit for compilation speedup. Of these, automatic differentiation and SPMD programming are the most important and frequently used. In addition, JAX works by transforming numerical functions, so the interfaces above can be freely composed through closures. However, PyTorch's functorch does not provide a pmap interface, because PyTorch's dynamic graphs do not lend themselves to this style of distributed parallelism. Moreover, functorch offers a functional programming style that differs from the tensor-oriented PyTorch nn.Module, which may confuse developers during automatic differentiation.
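As a minimal sketch of the JAX-style transforms functorch exposes, here using `torch.func`, the name under which functorch was later merged into core PyTorch (functorch itself ships the same `grad` and `vmap` names in 1.11):

```python
import torch
from torch.func import grad, vmap  # functorch transforms; shipped as the
                                   # separate `functorch` package in 1.11,
                                   # later merged into core as `torch.func`

def f(x):
    # scalar-valued function of a vector: sum of element-wise sine
    return torch.sin(x).sum()

x = torch.randn(3)
g = grad(f)(x)                      # analytic gradient is cos(x)

# transforms compose through closures, just like in JAX:
# per-sample gradients over a batch via vmap(grad(f))
batch = torch.randn(4, 3)
per_sample = vmap(grad(f))(batch)   # shape (4, 3), one gradient per row
```

The composition `vmap(grad(f))` is the kind of closure-based stacking of transforms the paragraph above describes.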
The PyTorch DataLoader is flexible and easy to use, but it bundles too many features together and is difficult to extend. That is why the PyTorch team designed TorchData, a new library of common modular data-loading primitives for easily constructing flexible and efficient data pipelines. In practice, growing training datasets and ever-larger models demand convenient data preprocessing. However, TorchData's heavy dependence on Python may become a bottleneck that limits further performance improvement.
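The modular idea behind TorchData, small composable stages that each do one thing, can be sketched with plain Python generators. This is an illustration of the primitive-composition style, not the actual TorchData datapipe API:

```python
# Plain-Python sketch of composable data-pipeline stages (illustrative only,
# not the TorchData API itself).

def mapper(source, fn):
    # apply fn to every item flowing through the pipe
    for item in source:
        yield fn(item)

def filterer(source, pred):
    # keep only items satisfying pred
    for item in source:
        if pred(item):
            yield item

def batcher(source, batch_size):
    # group items into fixed-size batches (last batch may be short)
    batch = []
    for item in source:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch

# square 0..9, keep the even squares, batch them in pairs
pipeline = batcher(filterer(mapper(range(10), lambda x: x * x),
                            lambda x: x % 2 == 0), 2)
result = list(pipeline)   # [[0, 4], [16, 36], [64]]
```

Each stage is independently reusable and testable, which is exactly the extensibility the monolithic DataLoader lacks; it is also why a pure-Python implementation of every stage can become the performance bottleneck noted above.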
Although TorchScript, tracing, and FX can be used to compile models and code in PyTorch, they are essentially supplements to the dynamic computational graphs. Looking ahead, with the release of the NVIDIA H100 GPU and other AI processors, it remains an open challenge for PyTorch to better exploit the performance of DSA hardware while retaining its current development model.
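As a small example of the "supplement" role described above, `torch.fx` can symbolically trace an eager-mode module into a static graph IR that can be inspected or rewritten before execution (the `AddReLU` module is a hypothetical toy, used only for illustration):

```python
import torch
import torch.fx

class AddReLU(torch.nn.Module):
    # toy module used only for illustration
    def forward(self, x):
        return torch.relu(x) + 1.0

# symbolic_trace records the ops into a GraphModule: a static graph
# that can be printed, analyzed, or transformed, then executed.
traced = torch.fx.symbolic_trace(AddReLU())
print(traced.code)

out = traced(torch.tensor([-1.0, 2.0]))   # tensor([1., 3.])
```

The captured graph is what a compiler backend targeting DSA hardware would consume, but it is produced after the fact from dynamic-graph code, which is why these tools remain supplements rather than the core execution model.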