簡體   English   中英

什么是 Pytorch 相當於 Pandas groupby.apply(list)?

[英]What is Pytorch equivalent of Pandas groupby.apply(list)?

我有以下 pytorch 張量long_format

tensor([[ 1.,  1.],
        [ 1.,  2.],
        [ 1.,  3.],
        [ 1.,  4.],
        [ 0.,  5.],
        [ 0.,  6.],
        [ 0.,  7.],
        [ 1.,  8.],
        [ 0.,  9.],
        [ 0., 10.]])

我想對第一列進行分組並將第二列存儲為張量。 不保證每個分組的結果大小相同。 請參見下面的示例。

[tensor([ 1., 2., 3., 4., 8.]),
 tensor([ 5.,  6., 7., 9., 10.])]

有沒有什么好的方法可以使用純粹的 Pytorch 運算符來做到這一點? 我想避免將 for 循環用於可追溯性目的。

我試過使用 for 循環和空張量的空列表,但這導致跟蹤不正確(不同的輸入值給出相同的結果)

n_groups = 2
inverted = [torch.empty([0]) for _ in range(n_groups)]
for index, value in long_format:
   value = value.unsqueeze(dim=0)
   index = index.int()
   if type(inverted[index]) != torch.Tensor:
      inverted[index] = value
   else:
      inverted[index] = torch.cat((inverted[index], value))

您可以使用此代碼:

import torch
x = torch.tensor([[ 1.,  1.],
        [ 1.,  2.],
        [ 1.,  3.],
        [ 1.,  4.],
        [ 0.,  5.],
        [ 0.,  6.],
        [ 0.,  7.],
        [ 1.,  8.],
        [ 0.,  9.],
        [ 0., 10.]])

result =  [x[x[:,0]==i][:,1] for i in x[:,0].unique()]

output

[tensor([ 5.,  6.,  7.,  9., 10.]), tensor([1., 2., 3., 4., 8.])]

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM