
Can pytorch optimize sequential operations (like a tensorflow graph or JAX's jit)?

Originally, tensorflow and pytorch had a fundamental difference:

  • tensorflow is based on a computational graph. Building the graph and evaluating it in a session are two separate steps. The graph doesn't change while it is being evaluated, which allows for optimizations.
  • torch eagerly evaluates operations on a tensor (see the sketch below). This makes the API more convenient (no sessions) but also loses the potential to recognize and optimize operations that always occur in sequence.
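
To illustrate the contrast, here is a minimal sketch of the eager style (a hypothetical toy snippet, not from the original question):

    import torch

    # Eager evaluation: each statement executes immediately,
    # with no separate graph-building or session step.
    x = torch.ones(3)
    y = x * 2 + 1   # computed right away
    print(y)        # tensor([3., 3., 3.])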

Now this difference is becoming less clear. Tensorflow has responded to the popularity of torch with tf eager. There is also the JAX project, which builds on the same underlying compiler framework as tensorflow (XLA). JAX has no concept of a session, but it lets you compile multiple operations together by simply calling jit.
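
For reference, compiling a sequence of operations in JAX is a one-liner; a minimal sketch (the function fn below is a made-up toy example):

    import jax
    import jax.numpy as jnp

    def fn(x):
        return jnp.sum(jnp.tanh(x) ** 2)

    fast_fn = jax.jit(fn)    # XLA compiles the whole sequence as one graph
    x = jnp.ones((1000,))
    print(fast_fn(x))        # first call traces and compiles; later calls reuse it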

Since Tensorflow has moved to cover PyTorch functionality, is PyTorch also working on integrating Tensorflow's advantages? Is there something like a session or a jit functionality in PyTorch (or on its roadmap)?

The API docs have a jit section, but as far as I can see, that is more about exporting your models.

As you mentioned, there is torch.jit, and its purpose is also to introduce optimizations in the exported graph (e.g. kernel fusion, constant folding, etc.). IIRC you can find some source code regarding those passes in their github repo here, though I'm not sure whether they are explicitly mentioned anywhere in the docs (or explicitly enough to be remembered).
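
As a rough sketch of how that looks in practice (the module below is a hypothetical toy, not from the docs): torch.jit.script compiles a module to TorchScript, where those optimization passes can run, and the result can be saved for export:

    import torch

    class MyModel(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) * 2.0 + 1.0   # constants here can be folded

    scripted = torch.jit.script(MyModel())  # compile to a TorchScript graph
    print(scripted.graph)                   # inspect the intermediate representation
    scripted.save("model.pt")               # export, as the docs emphasize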

Since 1.3, quantization has also been introduced (see here for an introduction). In the tutorials section, namely here, you can see explicit fusion of Conv2d , BatchNorm and ReLU in order to improve performance. Of course, there also exists more specific stuff like using int instead of float for weights (quantization), mixed-precision arithmetic (using half-precision floats whenever possible, see NVidia's Apex) and others.
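
A minimal sketch of that Conv2d + BatchNorm + ReLU fusion, along the lines of the tutorial (the toy model here is made up):

    import torch
    import torch.nn as nn

    model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3),
        nn.BatchNorm2d(16),
        nn.ReLU(),
    )
    model.eval()  # Conv+BN folding requires eval mode

    # Fuse the three submodules (named '0', '1', '2' inside the Sequential):
    fused = torch.quantization.fuse_modules(model, [['0', '1', '2']])
    print(fused)  # BatchNorm is folded into the conv, ReLU is attached to it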

Last but not least, I don't think that for a well-written model using vectorized operations and exported with torchscript you are going to see substantial runtime differences from generic graph optimizations alone. Still, it differs depending on whether you use a GPU, CPU or TPU, which versions they are, whether you are after inference only or training as well, etc. It's pretty hard to pinpoint how fast tensorflow is in comparison to pytorch (besides some well-known issues in both frameworks). All in all, it depends, and measurements vary a lot AFAIK.

BTW. When it comes to the advantages of each framework, their cores indeed start to cover similar things (PyTorch got mobile support lately, see here). The real difference is still the underlying approach and what each framework has to do to circumvent the limitations of that approach.
