簡體   English   中英

如何不破壞模型輸出的可微性?

[英]How to not break differentiability with a model's output?

我在 Pytorch 中有一個自回歸語言模型,它生成文本,它是句子的集合,給定一個輸入:

output_text = ["sentence_1. sentence_2. sentence_3. sentence_4."]

請注意,語言模型的輸出是 logits(詞匯表上的概率)的形式,可以轉換為 token IDS 或字符串。

其中一些句子需要進入另一個模型才能獲得只影響這些句子的損失:

loss1 = model2("sentence_2")
loss2 = model2("sentence_4")
loss_total = loss1+loss2

在不破壞可微性的情況下,從第一個模型中分解/拆分生成的文本的正確方法是什么? 也就是說,相應的文本(從上面)看起來像一個張量的 pytorch 張量(以便在下一個模型中使用其中的一些):

"[["sentence_1."]
["sentence_2."] 
["sentence_3."]
["sentence_4."]]

例如,Python 的split(".")方法很可能會破壞可微性,但允許我將每個單獨的句子插入到第二個模型中以獲得損失。

好的解決了。 發布答案以完成。

由於輸出是 logits 的形式,我可以使用argmax來獲取每個標記的索引。 這應該讓我知道每個period點在哪里(知道句子的結尾在哪里)。 然后我可以通過以下方式拆分句子以保持漸變:

sentences_list = []
r = torch.rand(50) #imagine that this is the output logits (though instead of a tensor of values it will be a tensor of tensors)
period_indices = [10,30,49]
sentences_list.append(r[0:10])
sentences_list.append(r[10:30])
sentences_list.append(r[30:])

現在sentences_list中的每個元素都是一個句子,我可以將其發送到另一個模型以獲取損失

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM