How to get a multiclass Keras model to output a string instead of OHE?

I understand the idea of passing labels in training data to a Keras model as a one-hot encoding (OHE), but I'm trying to create a model that returns a string during inference, not a one-hot encoding that I'd have to decode myself.

Specifically, I DON'T want to have to do the following:

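import numpy as np
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.utils import to_categorical

encoder = LabelEncoder()
encoder = encoder.fit(labels)
encoded_Y = encoder.transform(labels)
y_true = to_categorical(encoded_Y)  # Model accepts this during training

prediction = model.predict(query)
# inverse_transform expects class indices, so take the argmax of each row first
label_string = encoder.inverse_transform(np.argmax(prediction, axis=1))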

How can I create a model whose .predict() returns something customized, such as the string label of the highest prediction and its corresponding probability?

You can take a complete model and create a custom model on top of it that includes a post-processing function.

Let's take a pre-trained model:

import tensorflow as tf
from skimage.data import chelsea

model = tf.keras.applications.MobileNetV2()

That model comes with a function that takes the output probability matrix and turns it into strings. For a model trained on your own classes, you'll have to write your own decoder.

decoder = tf.keras.applications.mobilenet_v2.decode_predictions
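For a model trained on your own labels, a decoder might look roughly like the sketch below. It is not part of Keras: `make_decoder` and `class_names` are hypothetical names, and it assumes the class names are listed in the same order as the columns of the model's output (for example, the order of `encoder.classes_` from a `LabelEncoder`). It is written in plain Python so that it accepts a NumPy array or a list of lists.

```python
def make_decoder(class_names):
    """Build a decoder that maps each row of probabilities to (label, probability).

    class_names is assumed to be ordered like the model's output columns.
    """
    class_names = list(class_names)

    def decode(probabilities):
        results = []
        for row in probabilities:
            row = [float(p) for p in row]
            # Index of the highest probability in this row
            best = max(range(len(row)), key=row.__getitem__)
            results.append((class_names[best], row[best]))
        return results

    return decode


decode = make_decoder(["cat", "dog", "bird"])
print(decode([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]]))
# [('dog', 0.7), ('cat', 0.8)]
```

A function built this way takes the same kind of input as decode_predictions (a batch of probability rows), although it returns only the top class per row rather than the top five.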

Then, we create a model that has two parts: 1) the model and 2) the post-processing function.

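class NewModel(tf.keras.models.Model):
    def __init__(self, model, decoder):
        super().__init__()
        self.model = model      # the trained classifier
        self.decoder = decoder  # maps a probability matrix to strings

    def call(self, x):
        # .numpy() is only available in eager execution,
        # so call the model directly: m(x)
        x = self.model(x).numpy()
        x = self.decoder(x)
        return x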


m = NewModel(model, decoder)

Now, just call it with appropriate input:

cat_picture = tf.image.resize(chelsea(), (224, 224))[None, ...]/255

m(cat_picture)
[[('n02124075', 'Egyptian_cat', 0.7610773),
  ('n02123045', 'tabby', 0.12039327),
  ('n02123159', 'tiger_cat', 0.039847013),
  ('n02127052', 'lynx', 0.009957394),
  ('n04553703', 'washbasin', 0.0015057808)]]

It returns the output of the post-processing function.
