
Keras train_on_batch() does not train the model vs fit()

I have a dataset that is too large to fit in RAM, so I opted to use train_on_batch to train my model incrementally. To test whether this approach works, I took a subset of my large dataset to run some preliminary tests.

However, I have been having issues training the model: with train_on_batch(), the model's accuracy gets stuck at 10%, whereas with fit() I reach 95% accuracy after 40 epochs. I have also tried fit_generator() and encountered similar issues.

Using fit():

results = model.fit(x_train,y_train,batch_size=128,nb_epoch=40)

Using train_on_batch():

#386 has been chosen so that each batch size is 128
splitSize = len(y_train) // 386

for j in range(20):
    print('epoch: '+str(j)+' ----------------------------')
    np.random.shuffle(x_train)
    np.random.shuffle(y_train)
    xb = np.array_split(x_train,386)
    yb = np.array_split(y_train,386)
    sumAcc = 0
    index = list(range(386))
    random.shuffle(index)
    for i in index:
        results = model.train_on_batch(xb[i],yb[i])
        sumAcc += results[1]
    print(sumAcc/(386))

The shuffle you are using is incorrect: after the shuffle, y_train no longer matches x_train. np.random.shuffle shuffles each array in place, each in a different random order, so labels end up paired with the wrong samples — which is why accuracy stays at chance level (10% for 10 classes). Instead, shuffle a single index array and apply it to both:

length = x_train.shape[0]
idxs = np.arange(0, length)
np.random.shuffle(idxs)

x_train = x_train[idxs]
y_train = y_train[idxs]
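To see why this keeps samples and labels paired, here is a minimal NumPy-only sketch (the array shapes and the use of integer labels as row markers are assumptions made for illustration, not taken from the question's data); the resulting xb[i], yb[i] batches can then be fed to model.train_on_batch(xb[i], yb[i]) as in the question's loop:

```python
import numpy as np

# Toy stand-ins for the question's data (shapes are assumptions).
rng = np.random.default_rng(0)
x_train = rng.normal(size=(1000, 4))
y_train = np.arange(1000)          # label i marks row i, so pairing is checkable
x_orig = x_train.copy()

# Shuffle one index array, then apply it to both arrays.
idxs = np.arange(x_train.shape[0])
np.random.shuffle(idxs)
x_train = x_train[idxs]
y_train = y_train[idxs]

# Rows and labels stay paired after the joint shuffle.
assert all(np.array_equal(x_train[k], x_orig[y_train[k]]) for k in range(1000))

# Split into aligned batches, as in the question's train_on_batch loop.
xb = np.array_split(x_train, 8)
yb = np.array_split(y_train, 8)
```

Shuffling two arrays independently, as in the original code, breaks this invariant: the k-th label would then refer to some other row, so the model is trained on randomly mislabeled data.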

