
About custom operations in Tensorflow and PyTorch

I have to implement an energy function, termed Rigidity Energy, as in Eq. 7 of this paper here.

The energy function takes as input two 3D object meshes and returns the energy between them. The first mesh is the source mesh, and the second mesh is the deformed version of the source mesh. In rough pseudo-code, the computation would go like this:

Iterate over all the vertices in the source mesh. (A concrete sketch of these steps follows the list below.)

  1. For every vertex, compute its covariance matrix with its neighboring vertices.
  2. Perform SVD on the computed covariance matrix and find the rotation matrix of the vertex.
  3. Use the computed rotation matrix, the point coordinates in the original mesh and the corresponding coordinates in the deformed mesh to compute the energy deviation of the vertex.
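
A rough NumPy sketch of this loop (assuming uniform edge weights and an ARAP-style covariance/rotation step, which may differ from the exact form of Eq. 7; `neighbors` is a precomputed adjacency list, and the reflection check on the rotation is skipped):

```python
import numpy as np

def rigidity_energy(src, deformed, neighbors):
    """src, deformed: (N, 3) arrays of vertex positions.
    neighbors: list where neighbors[i] holds the vertex indices of N(i)."""
    energy = 0.0
    for i in range(len(src)):
        nbr = neighbors[i]
        e_src = src[i] - src[nbr]            # edge vectors around vertex i in the source mesh, (k, 3)
        e_def = deformed[i] - deformed[nbr]  # corresponding edge vectors in the deformed mesh
        S = e_src.T @ e_def                  # 3x3 covariance of the two edge sets
        U, _, Vt = np.linalg.svd(S)
        R = Vt.T @ U.T                       # best-fit rotation (reflection check omitted)
        energy += np.sum((e_def - e_src @ R.T) ** 2)
    return energy
```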

Thus this energy function requires me to iterate over each point in the mesh, and the mesh could have more than 2k such points. In Tensorflow, there are two ways to do this. I can have 2 tensors of shape (N,3), one representing the points of the source mesh and the other those of the deformed mesh.

  1. Do it purely using Tensorflow tensors. That is, iterate over elements of the above tensors using tf.gather and perform the computation on each point using only existing TF operations. This method would be extremely slow. I've tried to define loss functions that iterate over 1000s of points before, and the graph construction itself takes too much time to be practical. (A sketch of what this looks like follows the list.)
  2. Add a new TF OP as explained in the TF documentation here. This involves writing the function in CPP (and Cuda, for GPU support), and registering the new OP with TF.
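
For reference, the first approach would look roughly like this (a hypothetical TF 1.x graph-mode sketch; `neighbor_idx_list` is a precomputed Python list of neighbor indices). Every iteration of the Python loop adds its own gather/matmul/SVD nodes to the graph, which is why graph construction alone becomes prohibitive for thousands of vertices:

```python
import tensorflow as tf

def rigidity_energy_tf(src, deformed, neighbor_idx_list):
    """src, deformed: (N, 3) tensors; neighbor_idx_list[i] is a Python list of the indices of N(i)."""
    terms = []
    for i, nbr in enumerate(neighbor_idx_list):
        e_src = src[i] - tf.gather(src, nbr)            # (k, 3)
        e_def = deformed[i] - tf.gather(deformed, nbr)  # (k, 3)
        S = tf.matmul(e_src, e_def, transpose_a=True)   # (3, 3) covariance
        _, U, V = tf.linalg.svd(S)
        R = tf.matmul(V, U, transpose_b=True)           # best-fit rotation (reflection check omitted)
        terms.append(tf.reduce_sum(tf.square(e_def - tf.matmul(e_src, R, transpose_b=True))))
    return tf.add_n(terms)
```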

The first method is easy to write, but impractically slow. The second method is a pain to write.

I've used TF for 3 years and have never used PyTorch, but at this point I'm considering switching to it if it offers a better alternative for such cases.

Does PyTorch offer a way to implement such loss functions that is both easy to write and runs as fast as it would on the GPU? That is, a Pythonic way of writing my own loss functions that run on the GPU, without any C or CUDA code on my part?

As far as I understand, you are essentially asking whether this operation can be vectorized. The answer is no, at least not fully, because the svd implementation in PyTorch is not vectorized.

If you showed the tensorflow implementation, it would help in understanding your starting point. I don't know what you mean by finding the rotation matrix of the vertex, but I would guess this can be vectorized. This would mean that svd is the only non-vectorized operation, and you could perhaps get away with writing just a single custom OP, the vectorized svd - which is likely quite easy, because it would amount to calling some library routines in a loop in C++.
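
To make that concrete, here is a rough sketch of what that structure could look like in PyTorch. It assumes the ARAP-style energy from the question, a padded `(N, K)` neighbor index matrix `nbr_idx` with a matching `nbr_mask` (both assumptions on my part), and skips the reflection check on the rotation. Everything up to the SVD is batched; only the SVD itself runs in a Python loop:

```python
import torch

def rigidity_energy_torch(src, deformed, nbr_idx, nbr_mask):
    """src, deformed: (N, 3) vertex positions.
    nbr_idx:  (N, K) neighbor indices, padded to a common size K.
    nbr_mask: (N, K) float mask, 1 for real neighbors, 0 for padding."""
    # Edge vectors for every vertex at once: (N, K, 3); padded slots are zeroed out.
    e_src = (src.unsqueeze(1) - src[nbr_idx]) * nbr_mask.unsqueeze(-1)
    e_def = (deformed.unsqueeze(1) - deformed[nbr_idx]) * nbr_mask.unsqueeze(-1)

    # Per-vertex 3x3 covariances in one batched matmul: (N, 3, 3)
    S = e_src.transpose(1, 2) @ e_def

    energy = src.new_zeros(())
    for i in range(S.shape[0]):     # the only Python-level loop: one SVD per vertex
        U, _, V = torch.svd(S[i])
        R = V @ U.t()               # best-fit rotation (reflection check omitted)
        energy = energy + ((e_def[i] - e_src[i] @ R.t()) ** 2).sum()
    return energy
```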

Two possible sources of problems I see are

  1. if the neighborhoods N(i) in equation 7 can be of significantly different sizes (which would mean that the covariance matrices are of different sizes and vectorization would require some dirty tricks - one common padding workaround is sketched after this list)
  2. the general problem of dealing with meshes and neighborhoods could be difficult. This is an innate property of irregular meshes, but PyTorch has support for sparse matrices and a dedicated package torch_geometric, which at least helps.
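
For the first point, one such workaround (just an illustration, not something from the paper) is to pad every neighborhood to the maximum degree and carry a mask, which is what the padded `nbr_idx` / `nbr_mask` in the sketch above assume. Building that padded adjacency from an edge list could look like this:

```python
import torch

def build_padded_adjacency(edges, num_vertices):
    """edges: (E, 2) LongTensor of undirected edges (assumed input format).
    Returns nbr_idx (N, K) and nbr_mask (N, K); padded slots point at vertex 0
    but are zeroed out by the mask."""
    # Collect neighbors per vertex from both edge directions.
    adj = [[] for _ in range(num_vertices)]
    for a, b in edges.tolist():
        adj[a].append(b)
        adj[b].append(a)

    max_deg = max(len(n) for n in adj)
    nbr_idx = torch.zeros(num_vertices, max_deg, dtype=torch.long)
    nbr_mask = torch.zeros(num_vertices, max_deg)
    for i, n in enumerate(adj):
        nbr_idx[i, :len(n)] = torch.tensor(n, dtype=torch.long)
        nbr_mask[i, :len(n)] = 1.0
    return nbr_idx, nbr_mask
```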
