簡體   English   中英

如何使 `fit_generator` 與 `tf.keras.Model` 一起工作

[英]How to make `fit_generator` work with `tf.keras.Model`

我正在實現一個tf.keras.Model (不是Sequential模型!),應該使用fit_generator進行訓練。 但是, fit_generator會引發錯誤,可能是因為在編譯時輸入形狀不可用。

這是一個最小的例子:

import tensorflow as tf
import numpy as np


class MyModel(tf.keras.Model):

    def __init__(self):
        super(MyModel, self).__init__()
        self.dense1 = tf.keras.layers.Dense(3, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(3, activation=tf.nn.softmax)

    def call(self, inputs, training=None, mask=None):
        return self.dense2(self.dense1(inputs))


class MyGenerator(tf.keras.utils.Sequence):

    def __len__(self):
        # Number of batches per epoch
        return 1

    def __getitem__(self, _):
        # Generate one batch of data
        x = np.array([[1., 2., 3.]])
        y = np.array([[0., 1., 0.5]])

        return x, y


if __name__ == '__main__':
    m = MyModel()    
    g = MyGenerator()

    m.compile(tf.keras.optimizers.SGD(), loss=tf.keras.losses.mean_squared_error)
    m.fit_generator(g)

最后一行提出

AttributeError: 'MyModel' object has no attribute 'total_loss'

那么在自定義fit_generator模型中使用fit_generator的正確方法是什么?

在 Tensorflow 2.x 中,默認啟用 Eager Execution。 Model.fit_generator已棄用,將在未來版本中刪除。 所以你必須使用支持生成器的Model.fit

請參考TF 2.4兼容代碼如下所示

import tensorflow as tf
print(tf.__version__)
import numpy as np


class MyModel(tf.keras.Model):

    def __init__(self):
        super(MyModel, self).__init__()
        self.dense1 = tf.keras.layers.Dense(3, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(3, activation=tf.nn.softmax)

    def call(self, inputs, training=None, mask=None):
        return self.dense2(self.dense1(inputs))


class MyGenerator(tf.keras.utils.Sequence):

    def __len__(self):
        # Number of batches per epoch
        return 1

    def __getitem__(self, _):
        # Generate one batch of data
        x = np.array([[1., 2., 3.]])
        y = np.array([[0., 1., 0.5]])

        return x, y


if __name__ == '__main__':
    m = MyModel()    
    g = MyGenerator()

    m.compile(tf.keras.optimizers.SGD(), loss=tf.keras.losses.mean_squared_error)
    m.fit(g)

輸出:

2.4.0
1/1 [==============================] - 0s 224ms/step - loss: 0.4725

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM