Custom Single Connected "dense" Layer in Keras causing crash

I'm trying to implement the model described in this post: keras model with one connection per input node. However, the model crashes due to RAM exhaustion every time I try to train it.

The model builds without a problem, as shown below:

_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 Trend_Input (InputLayer)    [(None, 30490, 6)]        0         
                                                                 
 Flatten (Flatten)           (None, 182940)            0         
                                                                 
 Yhat_flat (SingleDense)     (None, 182940)            182940    
                                                                 
 Yhat (Reshape)              (None, 30490, 6)          0         
                                                                 
=================================================================

The model has "only" 182,940 parameters, and I've trained far more complex models than this. But every time I try to fit this model it crashes.

The code I'm using is as follows:

import tensorflow as tf
from tensorflow import keras

class SingleDense(keras.layers.Layer):
  def __init__(self, **kwargs):
    super(SingleDense, self).__init__(**kwargs)

  def build(self, input_shape):
    # One scalar weight per input feature
    self.size = int(input_shape[-1])
    self.kernel = self.add_weight("kernel",
                                  shape=[self.size],
                                  initializer=tf.random_normal_initializer())

  def call(self, input):
    # Scale the kernel along the diagonal of an identity matrix, then matmul
    linear = tf.matmul(input, self.kernel * tf.eye(self.size))
    return linear

def create_model(lookback, n_features, checkpoint):

    inputs = keras.layers.Input(shape=(n_features, lookback), name = "Trend_Input")

    flatten = keras.layers.Flatten(name = "Flatten")(inputs)
    yhat_flat = SingleDense(name = "Yhat_flat")(flatten)
    yhat = keras.layers.Reshape((n_features, lookback), name = "Yhat")(yhat_flat)

    model = keras.Model(inputs = inputs, outputs = yhat, name = "Final_Model")
    model.compile(optimizer = "adam", loss = "mse")
    tf.keras.utils.plot_model(model, show_shapes = True)
    model.summary()
    return model
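For reference, at a small scale one can verify that multiplying by `self.kernel * tf.eye(self.size)` is mathematically just an element-wise scaling: the broadcast places the kernel entries on the diagonal of a dense `(size, size)` matrix, and the matmul with that matrix scales each input feature independently. A minimal sketch (toy values, not the model above):

```python
import numpy as np
import tensorflow as tf

size = 4
kernel = tf.constant([1.0, 2.0, 3.0, 4.0])
x = tf.constant([[10.0, 20.0, 30.0, 40.0]])  # one flattened sample

# The layer's original computation: materializes a dense (size, size) matrix
# whose only nonzero entries are the kernel values on the diagonal.
diag = kernel * tf.eye(size)           # shape (4, 4)
via_matmul = tf.matmul(x, diag)        # shape (1, 4)

# Equivalent element-wise computation; no (size, size) tensor is ever built.
via_elementwise = x * kernel

print(np.allclose(via_matmul.numpy(), via_elementwise.numpy()))  # True
```

The two results are identical, which is why the dense diagonal matrix is pure overhead.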

Are there some "forbidden" calculations being done in the custom layer? Or is something left out?

EDIT

After closer inspection, I verified that the self.kernel * tf.eye(self.size) part materializes a dense (182940, 182940) matrix, roughly 134 GB in float32, thereby exhausting the RAM.
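The memory figure can be checked with quick arithmetic (assuming 4-byte float32 entries):

```python
size = 182940
elements = size * size        # entries in the dense (size, size) matrix
bytes_f32 = elements * 4      # 4 bytes per float32 element
print(f"{bytes_f32 / 1024**3:.1f} GiB")  # ~124.7 GiB (~134 GB decimal)
```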

Now the question is: is there a way of doing this efficiently?

I replaced the self.kernel * tf.eye(self.size) product with plain self.kernel (applied element-wise), and now it is behaving as expected.
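A minimal sketch of the corrected layer (the rest of the model is unchanged): broadcasting multiplies each column of the (batch, size) input by the matching entry of the (size,) kernel, giving one independent weight per input node without ever forming the (size, size) matrix:

```python
import tensorflow as tf
from tensorflow import keras

class SingleDense(keras.layers.Layer):
  def build(self, input_shape):
    self.size = int(input_shape[-1])
    self.kernel = self.add_weight(name="kernel",
                                  shape=(self.size,),
                                  initializer="random_normal")

  def call(self, inputs):
    # Element-wise product: (batch, size) * (size,) broadcasts per feature.
    return inputs * self.kernel

# Quick shape check on dummy data
x = tf.random.normal((2, 8))
y = SingleDense(name="Yhat_flat")(x)
print(y.shape)  # (2, 8)
```

This keeps the same 182,940 trainable parameters but the forward pass is O(size) in memory instead of O(size²).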
