简体   繁体   English

PyTorch 相当于 numpy reshape 函数

[英]PyTorch equivalent of numpy reshape function

Hi I have these to functions to flatten my complex type data to feed it to NN and reconstruct NN prediction to the original form.嗨,我有这些函数来展平我的复杂类型数据,以将其提供给 NN 并将 NN 预测重建为原始形式。

def flatten_input64(Input): #convert (:,4,4,2) complex matrix to (:,64) real vector
  Input1 = Input.reshape(-1, 32, order='F')
  Input_vector=np.zeros([19957,64],dtype = np.float64)
  Input_vector[:,0:32] = Input1.real
  Input_vector[:,32:64] = Input1.imag
  return Input_vector

def convert_output64(Output): #convert (:,64) real vector to (:,4,4,2) complex matrix  
  Output1 = Output[:,0:32] + 1j * Output[:,32:64]
  output_matrix = Output1.reshape(-1, 4 ,4 ,2 , order = 'F')
  return output_matrix

I am writing a customized loss that required all operation to be in torch and I should rewrite my conversion functions in PyTorch.我正在编写一个定制的损失,要求所有操作都在火炬中进行,我应该在 PyTorch 中重写我的转换函数。 The problem is that PyTorch doesn't have 'F' order reshape.问题是 PyTorch 没有“F”顺序重塑。 I tried to write my own version of F reorder but, it doesn't work.我尝试编写自己的 F reorder 版本,但是没有用。 Do you have any idea what is my mistake?你知道我的错误是什么吗?

def convert_output64_torch(input):
   # number_of_samples = defined
   for i in range(0, number_of_samples):
     Output1 = input[i,0:32] + 1j * input[i,32:64]
     Output2 = Output1.view(-1,4,4,2).permute(3,2,1,0)
   if i == 0:
     Output3 = Output2
   else:
     Output3 = torch.cat((Output3, Output2),0)
return Output3

Update: following @a_guest comment I tried to recreate my matrix with transpose and reshape and I got this code working same as F order reshape in numy:更新:在@a_guest 评论之后,我尝试使用转置和重塑重新创建我的矩阵,并且此代码的​​工作方式与 numy 中的 F order reshape 相同:

def convert_output64_torch(input):
   Output1 = input[:,0:32] + 1j * input[:,32:64]
   shape = (-1 , 4 , 4 , 2)
   Output3 = torch.transpose(torch.transpose(torch.reshape(torch.transpose(Output1,0,1),shape[::-1]),1,2),0,3)
return Output3

In both, Numpy and PyTorch, you can get the equivalent with the following operation: aTreshape(shape[::-1]).T (where a is either an array or a tensor):在 Numpy 和 PyTorch 中,您可以通过以下操作获得等效项: aTreshape(shape[::-1]).T (其中a是数组或张量):

>>> a = np.arange(16).reshape(4, 4)
>>> a
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]])
>>> shape = (2, 8)
>>> a.reshape(shape, order='F')
array([[ 0,  8,  1,  9,  2, 10,  3, 11],
       [ 4, 12,  5, 13,  6, 14,  7, 15]])
>>> a.T.reshape(shape[::-1]).T
array([[ 0,  8,  1,  9,  2, 10,  3, 11],
       [ 4, 12,  5, 13,  6, 14,  7, 15]])

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM