Pytorch 1.11看上去很美,但是.....
Pytorch 1.11看上去很美,但是.....
作者:金雪锋
来源:https://zhuanlan.zhihu.com/p/488413600
Pytorch团队是一个非常善于学习的团队,从动态图起家,不停在吸收其他框架的优点,比如最新的1.11版本,提供对标JAX的function graph、非常类似TF/MindSpore的TorchData等。
https://github.com/pytorch/pytorch/releases/tag/v1.11.0
终于可用可组合函数转换库!PyTorch 1.11发布,弥补JAX短板,支持Python 3.10
我们先看看对标JAX的function graph,大家知道,JAX的设计思路是Numpy/Scipy这些标准的Python科学计算库上提供自动微分/分布式并行/jit等扩展,包括:
- 自动微分:vjp/jvp/grad/......
- 分布式:pmap/.....
- 向量化并行:vmap/....
- JIT加速
从特性上看,最关键是自动微分和分布式这两块
从架构上看,JAX采用的是函数式风格,允许上面的接口通过闭包等形式组合使用。
而这次Pytorch提供的function graph,缺失了最重要的pmap接口,因为Pytorch的静态图不支持分布式并行;
另外,个人认为最重要的问题是,function graph所提供的函数式编程风格其实和Pytorch nn的Tensor为中心的风格是不统一的,这一点在自动微分的时候,尤其明显,容易让开发者感到困惑,参考:
在看一下TorchData,前几天团队的小伙伴们还在说,Torch的Dataloader比较灵活,易用性好,结果没几天,就被打脸了,TorchData就出现了,居然和MindSpore的设计类似了,参考:
实际上,随着训练数据的越来越多,模型的越来越大,数据的处理已经成为大规模训练的瓶颈,我个人感觉TorchData还是改造的不够彻底,大量的处理还是依赖Python,后面还会是瓶颈。
回溯Pytorch的历史,其实从动态图取得优势后,Pytorch做了很多工作,比如torchscript、tracing、FX等等,但是,我们也看到这些工作基本上是作为动态图的补充,我想这也可能是function graph和torch data设计成这样的根因。
面向未来,我更期待是,随着NV H100以及更多的AI芯片发布,Pytorch如何解决既要维护当前开发者的开发习惯又要充分发挥越来越DSA的硬件性能的矛盾。