[英]How to get multiclass Keras model to output string instead of OHE?
I understand the idea of passing labels in training data to a Keras model as a one-hot encoding, but I'm trying to create a model that will return a string during inference, and not a one-hot encoding that I'd have to decode myself.我理解将训练数据中的标签作为单热编码传递给 Keras 模型的想法,但我正在尝试创建一个模型,该模型将在推理过程中返回一个字符串,而不是我所拥有的单热编码解码我自己。
Specifically, I DON'T want to have to do the following:具体来说,我不想做以下事情:
encoder = LabelEncoder()
encoder = encoder.fit(labels)
encoded_Y = encoder.transform(labels)
y_true = np_utils.to_categorical(encoded_Y). # Model accepts this during training
prediction = model.predict(query)
label_string = encoder.inverse_transform(prediction)
How can I create a model that will call .predict()
and return something customized, such as the string of the highest prediction and its corresponding probability?如何创建一个将调用
.predict()
并返回自定义内容的模型,例如最高预测的字符串及其相应的概率?
You can take a complete model, and create a custom model on top of it that includes a post-processing function.您可以采用一个完整的模型,并在其上创建一个包含后处理功能的自定义模型。
Let's take a pre-trained model:让我们以一个预训练的模型为例:
import tensorflow as tf
from skimage.data import chelsea
model = tf.keras.applications.MobileNetV2()
That model comes with a function that takes the output probability matrix, and turns it into strings (you'll have to make your own)该模型带有一个函数,该函数采用输出概率矩阵,并将其转换为字符串(您必须自己制作)
decoder = tf.keras.applications.mobilenet_v2.decode_predictions
Then, we create a model that has two parts: 1) the model 2) the post-processing function然后,我们创建一个包含两部分的模型:1) 模型 2) 后处理功能
class NewModel(tf.keras.models.Model):
def __init__(self, model, decoder):
super(NewModel, self).__init__()
self.model = model
self.decoder = decoder
def call(self, x):
x = self.model(x).numpy()
x = self.decoder(x)
return x
m = NewModel(model, decoder)
Now, just call it with appropriate input:现在,只需使用适当的输入调用它:
cat_picture = tf.image.resize(chelsea(), (224, 224))[None, ...]/255
m(cat_picture)
[[('n02124075', 'Egyptian_cat', 0.7610773),
('n02123045', 'tabby', 0.12039327),
('n02123159', 'tiger_cat', 0.039847013),
('n02127052', 'lynx', 0.009957394),
('n04553703', 'washbasin', 0.0015057808)]]
It returns the output of the post-processing function.它返回后处理函数的输出。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.