Keras Model return predictions when evaluating

I have a dataset with multiple fields, but only two of them are relevant to my machine learning model. The rest should not be used for predictions, but they might reveal interesting correlations.

Is there a way to return prediction results when calling model.evaluate? For example:

[loss, accuracy, predicted_results] = model.evaluate(input, results)

AFAIK, we can't get predictions on x from model.evaluate; it only returns the loss and the metrics passed to compile (here, acc) (source). For your need, though, you can write a custom class that calls both .evaluate and .predict. Let's define a simple model to demonstrate.

Train and Run

import tensorflow as tf
import numpy as np

# Toy data: 20 samples with 32 features each, plus binary targets
img = tf.random.normal([20, 32], 0, 1, tf.float32)
tar = np.random.randint(2, size=(20, 1))

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(10, input_dim=32,
                                kernel_initializer='normal', activation='relu'))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy',
              optimizer='adam', metrics=['accuracy'])
model.fit(img, tar, epochs=2, verbose=2)

Epoch 1/2
1/1 - 1s - loss: 0.7083 - accuracy: 0.5000
Epoch 2/2
1/1 - 0s - loss: 0.6983 - accuracy: 0.5000

Now, to get the loss, the accuracy and the predictions from a single call, we can do something like this:

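class Custom_Evaluate:
    """Wraps a compiled Keras model so one call returns loss, accuracy and predictions."""
    def __init__(self, model):
        self.model = model

    def eval_predict(self, x, y):
        # evaluate() returns [loss, accuracy] because the model was compiled with metrics=['accuracy']
        loss, acc = self.model.evaluate(x, y)
        # predict() runs a separate forward pass and returns the per-sample sigmoid outputs
        pred = self.model.predict(x)
        return loss, acc, pred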
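Unpacking `loss, acc = ...` assumes exactly one compiled metric, so adding another metric to `compile` would break it. A minimal sketch of a more tolerant variant, assuming a Keras version whose `evaluate` accepts `return_dict=True` (TF 2.2 and later); `eval_predict_dict` is a hypothetical name, not a Keras API:

```python
def eval_predict_dict(model, x, y):
    """Hypothetical helper: evaluate and predict without hard-coding the metric count.

    Assumes `model` behaves like a compiled Keras model whose `evaluate`
    accepts `return_dict=True` (TF 2.2+).
    """
    # Dict keyed by the compiled names, e.g. {'loss': ..., 'accuracy': ...}
    scores = model.evaluate(x, y, verbose=0, return_dict=True)
    # Separate forward pass over the same inputs for per-sample outputs
    pred = model.predict(x, verbose=0)
    return scores, pred
```

With the model above, `scores, pred = eval_predict_dict(model, img, tar)` would give `scores['loss']`, `scores['accuracy']` and the same kind of `pred` array. Like `eval_predict`, it still runs the model over `x` twice.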

custom_evaluate = Custom_Evaluate(model)
loss, acc, pred = custom_evaluate.eval_predict(img, tar)
print(loss, acc)
print(pred)
0.6886215806007385 0.6499999761581421
[[0.5457604 ]
 [0.6126752 ]
 [0.53668976]
 [0.40323135]
 [0.37159938]
 [0.5520069 ]
 [0.4959099 ]
 [0.5363802 ]
 [0.5033434 ]
 [0.65680957]
 [0.6863682 ]
 [0.44409862]
 [0.4672098 ]
 [0.49656072]
 [0.620726  ]
 [0.47991502]
 [0.58834356]
 [0.5245693 ]
 [0.5359181 ]
 [0.4575624 ]]
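`eval_predict` runs the model twice over the same data. If that matters, one option is to call only `predict` and compute the metrics from its output. A NumPy sketch, assuming the setup above (sigmoid output, `binary_crossentropy`, and `metrics=['accuracy']`, which Keras resolves to binary accuracy with a 0.5 threshold for this loss). The numbers should be close to what `evaluate` reports, but they are not guaranteed to match bit-for-bit. `metrics_from_predictions` is a hypothetical helper:

```python
import numpy as np


def metrics_from_predictions(y_true, y_prob, threshold=0.5, eps=1e-7):
    """Return (binary cross-entropy, binary accuracy) from predicted probabilities.

    Hypothetical helper: approximates what Keras reports for a sigmoid output
    compiled with loss='binary_crossentropy' and metrics=['accuracy'].
    """
    y_true = np.asarray(y_true, dtype=np.float64).reshape(-1)
    # Clip probabilities so log() never sees exactly 0 or 1
    p = np.clip(np.asarray(y_prob, dtype=np.float64).reshape(-1), eps, 1 - eps)
    loss = -np.mean(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))
    # A sample counts as predicted positive when p > threshold
    acc = np.mean((p > threshold) == (y_true == 1))
    return float(loss), float(acc)


# Small example with made-up labels and probabilities
y_true = np.array([[1], [0], [1], [0]])
y_prob = np.array([[0.9], [0.2], [0.4], [0.6]])
loss, acc = metrics_from_predictions(y_true, y_prob)
print(loss, acc)
```

With the model above, this would be `pred = model.predict(img)` followed by `loss, acc = metrics_from_predictions(tar, pred)`, and `pred` can then be lined up row by row with the other fields of the dataset.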
