簡體   English   中英

在pytorch中連接兩個張量(有一個扭曲)

[英]concatenating two tensors in pytorch(with a twist)

我有一個大小為torch.Size([8, 768])存儲在變量embeddings中,看起來像這樣:-

 tensor([[-0.0687, -0.1327,  0.0112,  ...,  0.0715, -0.0297, -0.0477],
        [ 0.0115, -0.0029,  0.0323,  ...,  0.0277, -0.0297, -0.0599],
        [ 0.0760,  0.0788,  0.1640,  ...,  0.0574, -0.0805,  0.0066],
        ...,
        [-0.0110, -0.1773,  0.1143,  ...,  0.1397,  0.3021,  0.1670],
        [-0.1379, -0.0294, -0.0026,  ..., -0.0966, -0.0726,  0.1160],
        [ 0.0466, -0.0113,  0.0283,  ..., -0.0735,  0.0496,  0.0963]],
       grad_fn=<IndexBackward>)

現在,我希望取一些嵌入的平均值並將平均值放回張量中。 例如,(我將在列表而不是張量的幫助下進行解釋)

a = [1,2,3,4,5]
output = [1.5, 3, 4, 5]

因此,在這里我取了 1 和 2 的平均值,然后通過將元素移到列表中的左側,將其放入list output中。 我也想對張量做同樣的事情。

我將索引存儲在變量i中,我需要從中取平均值,並且j變量用於停止索引。 現在,讓我們看一下代碼:-

if i != len(embeddings):
  sum = 0
  count = 0
  #Calculating sum 
  for x in range(i-1, j):
    sum += text_index[x]
    count += 1

  avg = sum/count

  #Inserting the average in place of the other embeddings
  embeddings = embeddings[:i-1] + [avg] + embeddings[j:]
else :
  pass

現在,我在這一行遇到錯誤embeddings = embeddings[:i-1] + [avg] + embeddings[j:]錯誤是:-

TypeError: unsupported operand type(s) for +: 'Tensor' and 'list'

現在,我知道如果embeddings是一個列表但它是一個張量,上面的代碼會很好地工作。 我該怎么做?

筆記:

*1。 *embeddings.shape: torch.Size([8, 768])
2. avg 是浮點型**

要連接多個張量,您可以使用torch.cat ,其中張量列表在指定維度上連接。 這要求所有張量具有相同數量的維度,並且除了它們連接的維度之外的所有維度都需要具有相同的大小。

您的embeddings大小為[8, 768] ,因此左側和右側的大小分別為[num_left, 768][num_right, 768] 並且avg具有大小[768] (它是張量,而不是單個float ),因為您將多個嵌入平均為一個。 為了將它們與其他兩個部分連接起來,它需要具有大小[1, 768] ,以便可以在第一個維度上連接以創建大小為[num_left + 1 + num_right, 768]的張量。 可以使用torch.unsqueeze添加奇異的第一個維度。

embeddings = torch.cat([embeddings[:i-1], avg.unsqueeze(0), embeddings[j:]], dim=0)

for 循環也可以通過切片張量並使用torch.mean取平均值來替換。

# keepdim=True keeps the dimension that the average is taken on
# So the output has size [1, 768] instead of [768]
avg = torch.mean(embeddings[i-1:j], dim=0, keepdim=True)

embeddings = torch.cat([embeddings[:i-1], avg, embeddings[j:]], dim=0)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM