
Implementation of a multitask "nested" neural network

I am trying to implement a multitask neural network described in a paper, but I am quite unsure how to code the multitask part because the authors did not provide code for it.

The network architecture looks like this (paper):

[Figure: network architecture]

To make it simpler, the architecture can be generalized as follows (for this demo, I replaced their more complicated operation on the pair of individual embeddings with concatenation):

[Figure: simplified version]
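For reference, a pairwise head that uses the concatenation simplification could be sketched as the small PyTorch module below. The class name `ConcatPairwiseMLP` and the layer sizes are hypothetical choices for illustration, not the paper's design.

```python
import torch
from torch import nn


class ConcatPairwiseMLP(nn.Module):
    """Hypothetical pairwise head: concatenate two embeddings, then apply an MLP."""

    def __init__(self, n_features: int, hidden: int, n_out: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * n_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_out),
        )

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # a and b: [..., n_features] with identical leading dimensions.
        # torch.cat joins them along the last axis; nn.Linear only acts on
        # the last axis, so all leading dimensions are preserved.
        return self.net(torch.cat([a, b], dim=-1))


head = ConcatPairwiseMLP(n_features=8, hidden=16, n_out=1)
out = head(torch.randn(4, 8), torch.randn(4, 8))
print(out.shape)  # torch.Size([4, 1])
```

Because `torch.cat` does not broadcast, both inputs must already share their leading dimensions.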

The authors sum the losses from the individual tasks and the pairwise tasks and use the total loss to optimize the parameters of all three networks (encoder, MLP-1, MLP-2) in each batch. However, I am at a loss as to how different types of data are combined in a single batch and fed into two different networks that share an initial encoder. I searched for other networks with a similar structure but did not find any sources. I would appreciate any thoughts!
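One possible way to combine both kinds of data in a single batch is to encode every item once with the shared encoder, send those embeddings to MLP-1 for the individual task, and send pairs of the same embeddings to MLP-2 for the pairwise task. Summing the two losses then backpropagates gradients from both heads into the encoder. The sketch below assumes binary labels for both tasks, pairs given as hypothetical index pairs into the batch, and toy `nn.Linear` stand-ins for the three networks; it is not the paper's code.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
N_IN, N_FEAT = 10, 8

# Toy stand-ins (assumptions) for the three networks.
encoder = nn.Sequential(nn.Linear(N_IN, N_FEAT), nn.ReLU())  # shared encoder
mlp1 = nn.Linear(N_FEAT, 1)        # MLP-1: individual (per-item) task
mlp2 = nn.Linear(2 * N_FEAT, 1)    # MLP-2: pairwise task on concatenated embeddings

params = [*encoder.parameters(), *mlp1.parameters(), *mlp2.parameters()]
optimizer = torch.optim.Adam(params, lr=1e-3)


def train_step(x, y_item, pairs, y_pair):
    """One optimization step on a batch holding items and pairs of those items."""
    emb = encoder(x)                                   # [n_items, N_FEAT], encoded once
    item_logits = mlp1(emb).squeeze(-1)                # [n_items]
    pair_emb = torch.cat([emb[pairs[:, 0]], emb[pairs[:, 1]]], dim=-1)
    pair_logits = mlp2(pair_emb).squeeze(-1)           # [n_pairs]

    loss_item = F.binary_cross_entropy_with_logits(item_logits, y_item)
    loss_pair = F.binary_cross_entropy_with_logits(pair_logits, y_pair)
    loss = loss_item + loss_pair                       # unweighted sum of task losses

    optimizer.zero_grad()
    loss.backward()                                    # gradients reach all three networks
    optimizer.step()
    return loss_item.item(), loss_pair.item()


# Example batch: 6 items and 3 labelled pairs (made-up data).
x = torch.randn(6, N_IN)
y_item = torch.randint(0, 2, (6,)).float()
pairs = torch.tensor([[0, 1], [2, 3], [4, 5]])
y_pair = torch.randint(0, 2, (3,)).float()
print(train_step(x, y_item, pairs, y_pair))
```

An unweighted sum matches the description above; weighting the two terms is a straightforward variation.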

This is actually a common pattern, and it can be handled with code like the following.

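import torch
from torch import nn


class Network(nn.Module):
   def __init__(self):
      super().__init__()
      # DrugTargetInteractionNetwork, ClassificationMLP and PairwiseMLP
      # stand for your own modules; pass whatever arguments they need
      self.encoder = DrugTargetInteractionNetwork()
      self.mlp1 = ClassificationMLP()
      self.mlp2 = PairwiseMLP()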

   def forward(self, data_a, data_b):
      a_encoded = self.encoder(data_a)
      b_encoded = self.encoder(data_b)

      a_classified = self.mlp1(a_encoded)
      b_classified = self.mlp1(b_encoded)

      # let me assume a_encoded and b_encoded are of shape
      # [batch_size, n_molecules, n_features] (the encoder keeps a
      # per-molecule dimension), and that the two n_molecules are
      # not necessarily equal.
      # This can be generalized to more dimensions.
      # a varies along dim 1 and b along dim 2, so both broadcast to
      # [batch_size, n_mol_1, n_mol_2, n_features]
      a_broadcast, b_broadcast = torch.broadcast_tensors(
         a_encoded[:, :, None, :],
         b_encoded[:, None, :, :],
      )

      # this will work if your mlp2 accepts an arbitrary number of
      # leading dimensions and just broadcasts over them. That's true,
      # for example, if it uses only Linear and pointwise
      # operations, but it may fail if it makes specific assumptions
      # about the number of dimensions of its inputs
      pairwise_classified = self.mlp2(a_broadcast, b_broadcast)

      # if that is a problem, you have to reshape the inputs so that it
      # works. Most torch models accept at least a leading batch dimension
      # for vectorization, so we can "fold" the pairwise dimensions
      # into the batch dimension, presenting them as
      # [batch*n_mol_1*n_mol_2, n_features]
      # to mlp2 and then recover the original shape afterwards
      B, N1, N2, N_feat = a_broadcast.shape
      a_batched = a_broadcast.reshape(B*N1*N2, N_feat)
      b_batched = b_broadcast.reshape(B*N1*N2, N_feat)
      # above, -1 would suffice instead of B*N1*N2, just being explicit
      batch_output = self.mlp2(a_batched, b_batched)

      # this should match `pairwise_classified` (up to floating-point rounding)
      alternative_classified = batch_output.reshape(B, N1, N2, -1)

      return a_classified, b_classified, pairwise_classified
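To check that the broadcast call and the folded call agree, here is a small self-contained sketch with a toy pairwise head built only from `torch.cat` and `nn.Linear` (an assumption; your `PairwiseMLP` may behave differently). It also shows that entry `[b, i, j]` is the prediction for molecule `i` of `a` paired with molecule `j` of `b`.

```python
import torch
from torch import nn

torch.manual_seed(0)
B, N1, N2, N_FEAT = 2, 3, 5, 4


class PairHead(nn.Module):
    """Hypothetical pairwise head: a Linear layer on concatenated embeddings.

    torch.cat(dim=-1) and nn.Linear both act on the last axis only,
    so this head accepts any number of leading dimensions.
    """

    def __init__(self, n_features, n_out=1):
        super().__init__()
        self.lin = nn.Linear(2 * n_features, n_out)

    def forward(self, a, b):
        return self.lin(torch.cat([a, b], dim=-1))


mlp2 = PairHead(N_FEAT)
a_encoded = torch.randn(B, N1, N_FEAT)
b_encoded = torch.randn(B, N2, N_FEAT)

# Every (i, j) pair: both tensors become [B, N1, N2, N_FEAT]
a_b, b_b = torch.broadcast_tensors(a_encoded[:, :, None, :], b_encoded[:, None, :, :])

with torch.no_grad():
    pairwise = mlp2(a_b, b_b)  # broadcast over the leading dimensions
    folded = mlp2(a_b.reshape(-1, N_FEAT), b_b.reshape(-1, N_FEAT)).reshape(B, N1, N2, -1)

print(pairwise.shape)                               # torch.Size([2, 3, 5, 1])
print(torch.allclose(pairwise, folded, atol=1e-6))  # True
```

The folded form is the safer choice when a sub-module only accepts 2-D `[batch, features]` input.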
