I'm trying to implement the model described in this post: "keras model with one connection per input node". However, the model crashes from RAM exhaustion every time I try to train it.
The model builds without a problem, as shown below:
```
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
Trend_Input (InputLayer)     [(None, 30490, 6)]        0
Flatten (Flatten)            (None, 182940)            0
Yhat_flat (SingleDense)      (None, 182940)            182940
Yhat (Reshape)               (None, 30490, 6)          0
=================================================================
```
The model has "only" 182,940 parameters, and I've trained far more complex models than this, yet every time I try to fit this model it crashes.
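A quick check of where that number comes from: flattening one (30490, 6) sample gives 182,940 values, and SingleDense keeps one weight per value with no bias. The comparison with an ordinary fully connected Dense layer of the same width is added here only for scale; it is not part of the original model.

```python
# Shape of one sample, taken from the summary above: (30490, 6).
n_rows, n_cols = 30490, 6

# Flatten turns each sample into a vector of this length.
flat_size = n_rows * n_cols

# SingleDense stores one weight per flattened input (no bias).
single_dense_params = flat_size

# For comparison only: a regular Dense(flat_size) layer would need a
# full (flat_size, flat_size) weight matrix plus a bias vector.
dense_params = flat_size * flat_size + flat_size

print(f"flattened size:     {flat_size:,}")
print(f"SingleDense params: {single_dense_params:,}")
print(f"Dense params:       {dense_params:,}")
```

So the trainable parameter count really is small; the problem has to come from something the layer computes, not from what it stores.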
The code I'm using is as follows:
```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Flatten, Reshape


class SingleDense(keras.layers.Layer):
    """One trainable weight per input unit (a 'diagonal' dense layer)."""

    def __init__(self, **kwargs):
        super(SingleDense, self).__init__(**kwargs)

    def build(self, input_shape):
        self.size = int(input_shape[-1])
        self.kernel = self.add_weight(name="kernel",
                                      shape=[self.size],
                                      initializer=tf.random_normal_initializer())

    def call(self, input):
        # kernel * tf.eye(size) materializes a dense (size, size) diagonal matrix
        linear = tf.matmul(input, self.kernel * tf.eye(self.size))
        return linear


def create_model(lookback, n_features, checkpoint):
    inputs = Input(shape=(n_features, lookback), name="Trend_Input")
    flatten = Flatten(name="Flatten")(inputs)
    yhat_flat = SingleDense(name="Yhat_flat")(flatten)
    yhat = Reshape((n_features, lookback), name="Yhat")(yhat_flat)

    model = keras.Model(inputs=inputs, outputs=yhat, name="Final_Model")
    model.compile(optimizer="adam", loss="mse")
    tf.keras.utils.plot_model(model, show_shapes=True)  # needs pydot + graphviz
    model.summary()
    return model
```
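To see what `self.kernel * tf.eye(self.size)` actually builds, here is a small NumPy sketch (NumPy is used only so it runs without TensorFlow; TensorFlow follows the same broadcasting rules). With a toy size of 5 instead of 182940, broadcasting the kernel against the identity yields a full diagonal matrix, and multiplying by that matrix gives exactly the same result as scaling each input by its own weight.

```python
import numpy as np

rng = np.random.default_rng(0)

size = 5                              # small stand-in for 182940
kernel = rng.normal(size=size)        # one weight per input, like self.kernel
x = rng.normal(size=(3, size))        # a batch of 3 flattened samples

# Broadcasting kernel (size,) against eye (size, size) scales column j of
# the identity by kernel[j], producing a dense (size, size) diagonal matrix.
diag_matrix = kernel * np.eye(size)
assert np.array_equal(diag_matrix, np.diag(kernel))

# Multiplying by that diagonal matrix...
via_matmul = x @ diag_matrix
# ...is the same as scaling each input by its own weight.
via_elementwise = x * kernel

print(np.allclose(via_matmul, via_elementwise))  # True
```

The diagonal matrix has size * size entries even though only size of them are non-zero, which is harmless at size 5 but not at 182940.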
Is there some "forbidden" calculation being done in the custom layer, or is something missing?
EDIT
After closer inspection, I verified that the self.kernel * tf.eye(self.size) part produces a (182940, 182940) matrix: about 3.35 × 10^10 float32 values, roughly 1070.9 Gb (gigabits), or about 134 GB. That is what exhausts the RAM.
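The size estimate can be reproduced with a few lines of arithmetic. This assumes float32 (the default dtype for both Keras weights and tf.eye) and counts a single (n, n) matrix; since tf.eye(n) and the product with the kernel are each that size, the real peak during the forward pass is at least twice this, before any gradients.

```python
n = 182940                      # flattened input size from the summary
elements = n * n                # entries in one (n, n) matrix
bytes_per_float32 = 4           # float32 = 4 bytes per value

size_bytes = elements * bytes_per_float32
size_bits = size_bytes * 8

print(f"elements:  {elements:,}")
print(f"gigabytes: {size_bytes / 1e9:.1f} GB")
print(f"gibibytes: {size_bytes / 2**30:.1f} GiB")
print(f"gigabits:  {size_bits / 1e9:.1f} Gb")
```

The 1070.9 figure only matches when measured in gigabits; in bytes it is about 133.9 GB (124.7 GiB), which is still far beyond typical RAM.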
Now the question is: is there a way to do this efficiently?
I replaced self.kernel * tf.eye(self.size) with just self.kernel, and now the model behaves as expected.
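Assuming the fix means the kernel is now applied element-wise (which is what multiplying by a diagonal matrix amounts to), memory drops to one weight per input. The NumPy sketch below, with toy sizes standing in for (n_features, lookback) = (30490, 6), also suggests that under this assumption the Flatten/Reshape pair is optional: a kernel shaped (n_features, lookback) gives the same output without flattening. In Keras that would presumably mean calling add_weight with shape=input_shape[1:] and returning inputs * self.kernel; that variant is an untested suggestion, not part of the original code.

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy stand-ins for (n_features, lookback) = (30490, 6).
n_features, lookback, batch = 4, 3, 2
x = rng.normal(size=(batch, n_features, lookback))
kernel = rng.normal(size=n_features * lookback)   # one weight per input value

# Element-wise version of the layer: Flatten -> scale -> Reshape.
flat = x.reshape(batch, -1)                        # like Flatten (row-major)
scaled = flat * kernel                             # like inputs * self.kernel
y_flat_path = scaled.reshape(batch, n_features, lookback)  # like Reshape

# Same result without flattening: reshape the kernel instead.
y_direct = x * kernel.reshape(n_features, lookback)

print(np.allclose(y_flat_path, y_direct))  # True
```

Either way, nothing larger than the input itself is ever created, so memory grows linearly with the number of inputs instead of quadratically.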