简体   繁体   English

Tensorflow2 关于形状不匹配的警告,仍在训练

[英]Tensorflow2 Warning on Shape Mismatch, Still Training

I am trying to use Keras/TF2.3.0 to do multilabel classification where I have 50 features and am classifying between five classes.我正在尝试使用 Keras/TF2.3.0 进行多标签分类,其中我有 50 个特征并在五个类别之间进行分类。 I am getting the following warning, although the model still trains, which confuses me.我收到以下警告,尽管模型仍在训练,这让我很困惑。


>>> model.fit(train_dataset, epochs=5, validation_data=val_dataset)

Epoch 1/5 WARNING:tensorflow:Model was constructed with shape (128, 1, 50) for input Tensor("input_1:0", shape=(128, 1, 50), dtype=float32), but it was called on an input with incompatible shape (None, 50). Epoch 1/5 WARNING:tensorflow:Model 被构造为形状 (128, 1, 50) for input Tensor("input_1:0", shape=(128, 1, 50), dtype=float32),但它被调用形状不兼容的输入(无,50)。

WARNING:tensorflow:Model was constructed with shape (128, 1, 50) for input Tensor("input_1:0", shape=(128, 1, 50), dtype=float32), but it was called on an input with incompatible shape (None, 50).警告:tensorflow:模型的输入 Tensor("input_1:0", shape=(128, 1, 50), dtype=float32) 的形状为 (128, 1, 50),但在不兼容的输入上调用它形状(无,50)。

1/5[..............................] - ETA: 0s - loss: 0.6996WARNING:tensorflow:Model was constructed with shape (128, 1, 50) for input Tensor("input_1:0", shape=(128, 1, 50), dtype=float32), but it was called on an input with incompatible shape (None, 50). 1/5[................................] - ETA: 0s - loss: 0.6996WARNING:tensorflow:Model was构建输入 Tensor("input_1:0", shape=(128, 1, 50), dtype=float32) 的形状为 (128, 1, 50),但它在形状不兼容的输入 (None, 50) 上被调用。 59/59 [==============================] - 0s 2ms/step - loss: 0.6941 - val_loss: 0.6935 59/59 [==============================] - 0s 2ms/步 - 损失:0.6941 - val_loss:0.6935

Epoch 2/5 59/59 [==============================]...纪元 2/5 59/59 [==============================]...

My full code, with random data to reproduce the error, is below.我的完整代码,带有随机数据来重现错误,如下所示。 What am I messing up with my NN architecture (or perhaps my dfs_to_tfds function?) to accept input records with num_vars features and output values distributed among num_classes classes in TF?我在处理我的 NN 架构(或者我的dfs_to_tfds函数?)来接受具有num_vars特征的输入记录和分布在 TF 中的num_classes类中的输出值的num_classes是什么?

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import Input, Dense, Flatten, Conv1D, AveragePooling1D
from tensorflow.keras.models import Model
import tensorflow as tf

# setup example input data and labels
num_rows = 10_000
num_vars = 50
num_classes = 5
data = np.random.rand(num_rows, num_vars)
labels = np.random.rand(num_rows, num_classes)

# convert input data to TF.data datasets
bs=128
def dfs_to_tfds(features, targets, bs):
  return tf.data.Dataset.from_tensor_slices((features, targets)).batch(bs)

X_train, X_val, y_train, y_val = train_test_split(data, labels)

train_dataset = dfs_to_tfds(X_train, y_train, bs)
val_dataset = dfs_to_tfds(X_val, y_val, bs)

# setup model
inputs = Input(shape = (1, num_vars), batch_size=bs)
h = Dense(units=32, activation='relu')(inputs)
h = Dense(units=32, activation='relu')(h)
h = Dense(units=32, activation='relu')(h)
outputs = Dense(units=num_classes, activation='sigmoid')(h)

model = Model(inputs=inputs, outputs=outputs)

model.compile(optimizer='rmsprop', 
              loss=['binary_crossentropy'], #tf.keras.losses.MSLE
              metrics=None, 
              loss_weights=None, 
              run_eagerly=None)

# train model
model.fit(train_dataset, epochs=5, validation_data=val_dataset)

Use

inputs = Input(shape=num_vars)

and specify your batch size when fitting the model:并在拟合模型时指定批量大小:

model.fit(train_dataset, epochs=5, validation_data=val_dataset, batch_size=bs)

Your data is not preorganized in subbatches so you dont have to specify it along with the input shape but when fitting.您的数据不是预先组织在子批次中的,因此您不必将其与输入形状一起指定,而是在拟合时指定。 So model.fit automatically takes batches of batch_size out of your input data when fitting the model因此,model.fit 在拟合模型时会自动从输入数据中取出批次的batch_size

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM