
Short-circuit computation in a mixture-of-experts model using the TensorFlow Keras functional API

I am trying to switch between multiple different "expert" layers based on the output of a "gating" layer (as in a mixture of experts). I created a custom layer that takes in the outputs of the expert and gating layers, but this ends up computing every expert and then throwing away some of the outputs, rather than not computing them in the first place.

How can I make the model "short-circuit" so that it evaluates only the gating layer and the selected expert layer(s), to save computation time?

I am using TensorFlow 2.0 (GPU build) and the Keras functional API.
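A rough sketch of the wasteful approach described above (illustrative only; the layer sizes and names here are placeholders, not my actual code):

import tensorflow as tf

# Illustrative only: every expert runs on the full batch, and the
# gating output merely weights the results, so most of the expert
# computation is discarded.
inputs = tf.keras.Input(shape=(8,))
gate = tf.keras.layers.Dense(2, activation="softmax")(inputs)  # gating weights
expert_a = tf.keras.layers.Dense(8)(inputs)  # always computed
expert_b = tf.keras.layers.Dense(8)(inputs)  # always computed
stacked = tf.stack([expert_a, expert_b], axis=1)            # (batch, 2, 8)
outputs = tf.reduce_sum(stacked * gate[:, :, tf.newaxis], axis=1)
model = tf.keras.Model(inputs, outputs)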

Keras models can be implemented fully dynamically to support the efficient routing you mention. The following example shows one way this can be done. It is written with the following premises:

  1. It assumes there are two experts (LayerA and LayerB).
  2. It assumes that a mix-of-experts model (MixOfExpertsModel) switches dynamically between the two expert layer classes depending on the per-example output of a Keras Dense layer.
  3. It supports training the model in batches.

Pay attention to the comments in the code to see how the switching is done.

import numpy as np
import tensorflow as tf


# This is your Expert A class.
class LayerA(tf.keras.layers.Layer):

  def build(self, input_shape):
    self.weight = self.add_weight("weight_a", shape=input_shape[1:])

  @tf.function
  def call(self, x):
    return x + self.weight


# This is your Expert B class.
class LayerB(tf.keras.layers.Layer):

  def build(self, input_shape):
    self.weight = self.add_weight("weight_b", shape=input_shape[1:])

  @tf.function
  def call(self, x):
    return x * self.weight


class MixOfExpertsModel(tf.keras.models.Model):

  def __init__(self):
    super(MixOfExpertsModel, self).__init__()
    self._expert_a = LayerA()
    self._expert_b = LayerB()
    self._gating_layer = tf.keras.layers.Dense(1, activation="sigmoid")

  @tf.function
  def call(self, x):
    z = self._gating_layer(x)
    # The switching logic:
    #   - examples with gating output <= 0.5 are routed to expert A
    #   - examples with gating output > 0.5 are routed to expert B.
    mask_a = tf.squeeze(tf.less_equal(z, 0.5), axis=-1)
    mask_b = tf.squeeze(tf.greater(z, 0.5), axis=-1)
    # `input_a` is a subset of slices of the original input (`x`).
    # So is `input_b`. As such, no compute is wasted.
    input_a = tf.boolean_mask(x, mask_a, axis=0)
    input_b = tf.boolean_mask(x, mask_b, axis=0)
    # Inside a @tf.function, AutoGraph turns these `if` statements on
    # tensors into conditional ops, so an expert is executed only when
    # at least one example has been routed to it.
    if tf.size(input_a) > 0:
      output_a = self._expert_a(input_a)
    else:
      output_a = tf.zeros_like(input_a)
    if tf.size(input_b) > 0:
      output_b = self._expert_b(input_b)
    else:
      output_b = tf.zeros_like(input_b)
    # Return `mask_a` and `mask_b` as well, so that the caller knows
    # which example was routed to which expert and whether its output
    # appears in `output_a` or `output_b`. This is necessary for
    # writing a (custom) loss function for this model.
    return output_a, output_b, mask_a, mask_b


# Create an instance of the mix-of-experts model.
mix_of_experts_model = MixOfExpertsModel()

# Generate some dummy data.
num_examples = 32
xs = np.random.random([num_examples, 8]).astype(np.float32)

# Call the model.
print(mix_of_experts_model(xs))
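If you need the routed outputs back in the original batch order, one possibility (a sketch on top of the code above, not part of the original answer) is to recover each group's row indices from the masks and merge the pieces with tf.dynamic_stitch:

output_a, output_b, mask_a, mask_b = mix_of_experts_model(xs)
# Row indices of the examples routed to each expert.
indices_a = tf.cast(tf.squeeze(tf.where(mask_a), axis=-1), tf.int32)
indices_b = tf.cast(tf.squeeze(tf.where(mask_b), axis=-1), tf.int32)
# Interleave the two groups back into input order: shape (32, 8).
merged = tf.dynamic_stitch([indices_a, indices_b], [output_a, output_b])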

I didn't write a custom loss function that would support training this model, but that is doable using the return values of MixOfExpertsModel.call(), namely the outputs and the masks.
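For example, a loss could slice the labels with the same masks used to slice the inputs, so that each expert's output is compared only against the labels of the examples routed to it. A minimal sketch, continuing the script above; the mix_of_experts_loss name and the choice of mean squared error are my assumptions, not part of the original answer:

def mix_of_experts_loss(y_true, output_a, output_b, mask_a, mask_b):
  # Slice the labels with the same masks used to slice the inputs, so
  # labels_a[i] lines up with output_a[i] (and likewise for expert B).
  labels_a = tf.boolean_mask(y_true, mask_a, axis=0)
  labels_b = tf.boolean_mask(y_true, mask_b, axis=0)
  sq_err_a = tf.reduce_sum(tf.math.squared_difference(labels_a, output_a))
  sq_err_b = tf.reduce_sum(tf.math.squared_difference(labels_b, output_b))
  return (sq_err_a + sq_err_b) / tf.cast(tf.shape(y_true)[0], tf.float32)

# Dummy regression targets and one hand-rolled gradient computation.
ys = np.random.random([num_examples, 8]).astype(np.float32)
with tf.GradientTape() as tape:
  out_a, out_b, m_a, m_b = mix_of_experts_model(xs)
  loss = mix_of_experts_loss(ys, out_a, out_b, m_a, m_b)
grads = tape.gradient(loss, mix_of_experts_model.trainable_variables)

These gradients can then be passed to an optimizer's apply_gradients. Note, though, that the hard boolean masks are not differentiable, so this particular loss produces gradients for the experts but not for the gating layer.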
