[英]Implementation of multitask "nested" neural network
我正在嘗試實現論文使用的多任務神經網絡,但我很不確定我應該如何編碼多任務網絡,因為作者沒有提供該部分的代碼。
網絡架構看起來像(紙):
為了更簡單,網絡架構可以概括為(對於演示,我將它們對單個嵌入對的更復雜的操作更改為串聯):
作者正在總結單個任務和成對任務的損失,並使用總損失來優化每個批次中三個網絡(編碼器、MLP-1、MLP-2)的參數,但我有點不知所措關於如何將不同類型的數據組合在一個批次中以饋入共享初始編碼器的兩個不同網絡。 我試圖搜索具有類似結構的其他網絡,但沒有找到任何來源。 將不勝感激任何想法!
這實際上是一種常見的模式。 它將通過如下代碼解決。
class Network(nn.Module):
def __init__(self, ...):
self.encoder = DrugTargetInteractiongNetwork()
self.mlp1 = ClassificationMLP()
self.mlp2 = PairwiseMLP()
def forward(self, data_a, data_b):
a_encoded = self.encoder(data_a)
b_encoded = self.encoder(data_b)
a_classified = self.mlp1(a_encoded)
b_classified = self.mlp1(b_encoded)
# let me assume data_a and data_b are of shape
# [batch_size, n_molecules, n_features].
# and that those n_molecules are not necessarily
# equal.
# This can be generalized to more dimensions.
a_broadcast, b_broadcast = torch.broadcast_tensors(
a_encoded[:, None, :, :],
b_encoded[:, :, None, :],
)
# this will work if your mlp2 accepts an arbitrary number of
# learding dimensions and just broadcasts over them. That's true
# for example if it uses just Linear and pointwise
# operations, but may fail if it makes some specific assumptions
# about the number of dimensions of the inputs
pairwise_classified = self.mlp2(a_broadcast, b_broadcast)
# if that is a problem, you have to reshape it such that it
# works. Most torch models accept at least a leading batch dimension
# for vectorization, so we can "fold" the pairwise dimension
# into the batch dimension, presenting it as
# [batch*n_mol_1*n_mol_2, n_features]
# to mlp2 and then recover it back
B, N1, N_feat = a_broadcast.shape
_B, N2, _N_feat = b_broadcast.shape
a_batched = a_broadcast.reshape(B*N1*N2, N_feat)
b_batched = b_broadcast.reshape(B*N1*N2, N_feat)
# above, -1 would suffice instead of B*N1*N2, just being explicit
batch_output = self.mlp2(a_batched, b_batched)
# this should be exactly the same as `pairwise_classified`
alternative_classified = batch_output.reshape(B, N1, N2, -1)
return a_classified, b_classified, pairwise_classified
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.