簡體   English   中英

打印或保存 tf.keras 模型的輸出,與輸入配對

[英]Print or save output of tf.keras model, paired with inputs

def locations_model(...):
    input_shape = image_shape + (3,)
    base_model = tf.keras.applications.MobileNetV2(...)                                                                base_model.trainable = False 
    inputs = tf.keras.Input(...)  
... ...     
    outputs = tfl.Dense(5, activation = "softmax")(x)
    model = tf.keras.Model(inputs, outputs)
 
    return model

上面的代碼只是顯示 tf.keras 模型中的輸入和輸出,該模型將輸入圖像分為 5 個類別。 如何保存每個輸入圖像的輸出類別(“y_pred”)?

簡單語句ypreds = model(inputs)ypreds = model.predict(inputs)生成一組 5 元素數組,它們相加為 1,即概率。
因此,問題是如何輸出預測的類別,在這種情況下是整數:0-4,而不是概率。 更新:這是 Apostolova 對 Lodzz 的“從 Keras 功能模型獲取類標簽”問題的回答,如 test_probas = model.predict(test_data) test_classes = probas.argmax(axis = -1)

感謝@EduardoriosChicago的確認。 為了社區的利益,我在這里提到您的答案。

代碼是

probas = model(x_in); x_classes = probas.argmax( axis = - 1)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM