簡體   English   中英

TensorFlow 2.x:使用嵌入列時無法加載 h5 格式的訓練模型(ValueError: Shapes (101, 15) and (57218, 15) is incompatible)

[英]TensorFlow 2.x: Cannot load trained model in h5 format when using embedding columns (ValueError: Shapes (101, 15) and (57218, 15) are incompatible)

經過長時間的來回,我設法保存了我的模型(請參閱我的問題TensorFlow 2.x:無法以 h5 格式保存訓練模型(OSError:無法創建鏈接(名稱已存在)) )。 但是現在我在加載保存的模型時遇到了問題。 首先,我通過加載模型得到以下錯誤:

ValueError: You are trying to load a weight file containing 1 layers into a model with 0 layers.

將順序更改為功能 API 后,我收到以下錯誤:

ValueError: Cannot assign to variable dense_features/NAME1W1_embedding/embedding_weights:0 due to variable shape (101, 15) and value shape (57218, 15) are incompatible

我嘗試了不同版本的 TensorFlow。 我在 tf-nightly 版本中收到了描述的錯誤。 在 2.1 版中,我得到了一個非常相似的錯誤:

ValueError: Shapes (101, 15) and (57218, 15) are incompatible.

在 2.2 和 2.3 版本中,我什至無法保存我的模型(如我之前的問題所述)。

下面是函數式API的相關代碼:

def __loadModel(args):
    filepath = args.loadModel

    model = tf.keras.models.load_model(filepath)

    print("start preprocessing...")
    (_, _, test_ds) = preprocessing.getPreProcessedDatasets(args.data, args.batchSize)
    print("preprocessing completed")

    _, accuracy = model.evaluate(test_ds)
    print("Accuracy", accuracy)



def __trainModel(args):
    (train_ds, val_ds, test_ds) = preprocessing.getPreProcessedDatasets(args.data, args.batchSize)

    for bucketSizeGEO in args.bucketSizeGEO:
        print("start preprocessing...")
        feature_columns = preprocessing.getFutureColumns(args.data, args.zip, bucketSizeGEO, True)
        #Todo: compare trainable=False to trainable=True
        feature_layer = tf.keras.layers.DenseFeatures(feature_columns, trainable=False)
        print("preprocessing completed")


        feature_layer_inputs = preprocessing.getFeatureLayerInputs()
        feature_layer_outputs = feature_layer(feature_layer_inputs)
        output_layer = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)(feature_layer_outputs)

        model = tf.keras.Model(inputs=[v for v in feature_layer_inputs.values()], outputs=output_layer)

        model.compile(optimizer='sgd',
            loss='binary_crossentropy',
            metrics=['accuracy'])

        paramString = "Arg-e{}-b{}-z{}".format(args.epoch, args.batchSize, bucketSizeGEO)


        log_dir = "logs\\logR\\" + paramString + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)


        model.fit(train_ds,
                validation_data=val_ds,
                epochs=args.epoch,
                callbacks=[tensorboard_callback])


        model.summary()

        loss, accuracy = model.evaluate(test_ds)
        print("Accuracy", accuracy)

        paramString = paramString + "-a{:.4f}".format(accuracy)

        outputName = "logReg" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + paramString

        

        if args.saveModel:
            for i, w in enumerate(model.weights): print(i, w.name)

            path = './saved_models/' + outputName + '.h5'
            model.save(path, save_format='h5')

對於相關的預處理部分,請參閱本問題開頭提到的問題。 for i, w in enumerate(model.weights): print(i, w.name)返回以下內容:

0 dense_features/NAME1W1_embedding/embedding_weights:0
1 dense_features/NAME1W2_embedding/embedding_weights:0
2 dense_features/STREETW_embedding/embedding_weights:0
3 dense_features/ZIP_embedding/embedding_weights:0
4 dense/kernel:0
5 dense/bias:0

這個問題是由於訓練和預測中嵌入矩陣的維度不一致造成的。

通常,在我們使用嵌入矩陣之前,我們會形成一個字典。 這里暫時把這個字典叫做word_index。 如果代碼作者不細心,會導致訓練和預測兩個word_index不同(因為訓練和預測使用的數據不同),embedding matrix的維數發生變化。

從你的bug中可以看到,訓練時得到len(word_index)+1是57218,預測時得到len(word_index)+1是101。

如果我們想正確運行代碼,當需要使用word_index的預測時,我們不能在預測的時候重新生成一個word_index。 所以解決這個問題最簡單的辦法就是保存你訓練時得到的word_index,在預測的時候調用,這樣我們就可以正確加載訓練時得到的權重。

我能夠解決我相當愚蠢的錯誤:

我正在使用 feature_column 庫來預處理我的數據。 不幸的是,我在函數 categorical_column_with_identity 的參數 num_buckets 中指定了詞匯表的固定大小而不是實際大小。 錯誤版本:

street_voc = tf.feature_column.categorical_column_with_identity(
        key='STREETW', num_buckets=100)

正確版本:

street_voc = tf.feature_column.categorical_column_with_identity(
        key='STREETW', num_buckets= __getNumberOfWords(data, 'STREETPRO') + 1)

函數__getNumberOfWords(data, 'STREETPRO')返回 pandas 數據幀的'STREETPRO'列中不同單詞的數量。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM