
How to get class predictions from a trained TensorFlow classifier?

I have trained a binary classifier model. The model class contains the self.cost, self.initial_state, self.final_state and self.logits attributes. It is saved simply with tf.train.Saver:

saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
saver.save(session, 'model.ckpt')

After training, I load the model as follows:

with tf.variable_scope("Model", reuse=False):
    model = MODEL(config, is_training=False)

with tf.Session() as session:
    saver = tf.train.Saver(tf.global_variables())
    saver.restore(session, 'model.ckpt')

However, my model.run function returns the cross-entropy loss, which is the last op in the graph. I don't need the loss; I need the model's predictions for each batch element:

logits = tf.sigmoid(tf.nn.xw_plus_b(last_layer, self.output_w, self.output_b))

where last_layer is an 800x1 matrix, which I later reshape into a 32x25x1 (batch_size, sequence_length, 1) matrix. This matrix contains the model's prediction values in the [0, 1] range.
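To make the reshape concrete: 800 = 32 * 25, so each of the 800 sigmoid outputs maps to exactly one (batch, step) position. The sketch below is a plain-Python stand-in for tf.reshape, assuming the 800 rows were flattened batch-major (all 25 steps of sequence 0, then sequence 1, and so on); if the model flattened time-major, the index formula would differ. The helper name reshape_predictions is hypothetical.

```python
def reshape_predictions(flat, batch_size, seq_len):
    """Group a flat list of per-step probabilities into [batch][step][1].

    Hypothetical helper mimicking the row-major order of tf.reshape,
    assuming `flat` was flattened batch-major (index = b * seq_len + t).
    """
    if len(flat) != batch_size * seq_len:
        raise ValueError(
            f"expected {batch_size * seq_len} values, got {len(flat)}")
    return [[[flat[b * seq_len + t]] for t in range(seq_len)]
            for b in range(batch_size)]


# 800 dummy probabilities in [0, 1), reshaped as in the question (32x25x1).
flat = [i / 800 for i in range(800)]
preds = reshape_predictions(flat, batch_size=32, seq_len=25)

# A 1x1x1 "matrix" is the degenerate case batch_size = seq_len = 1.
single = reshape_predictions([0.73], batch_size=1, seq_len=1)
```

This only rearranges values; it does not change which input shapes the graph itself accepts.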

So how can I use this model to make a prediction for a single-element 1x1x1 matrix?

Add the ops needed to compute accuracy to the graph, something like the snippet below (copied straight out of the closest model I had at hand).

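# Assumes one-hot tensors of shape [batch, num_classes] (multi-class model).
self.logits_flat = tf.argmax(logits, axis=1, output_type=tf.int32)
labels_flat = tf.argmax(labels, axis=1, output_type=tf.int32)
# Average the per-example 0/1 matches to get a scalar accuracy.
accuracy = tf.reduce_mean(
    tf.cast(tf.equal(self.logits_flat, labels_flat), tf.float32), name='accuracy')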
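The snippet above assumes one-hot, multi-class outputs. With a single sigmoid output per example, as in the question, the class axis has size 1, so argmax over it always returns 0. A class prediction is then usually obtained by thresholding the probability, commonly at 0.5, and accuracy is the mean of the per-example matches. Below is a plain-Python sketch of that logic; the 0.5 threshold and the helper names are assumptions, not part of the original model. In TensorFlow the analogous ops would be along the lines of tf.cast(tf.greater(probs, 0.5), tf.int32) followed by tf.reduce_mean over the matches.

```python
def argmax(row):
    """Index of the largest element (ties resolve to the first index)."""
    return max(range(len(row)), key=lambda i: row[i])


def predict_classes(probs, threshold=0.5):
    """Map sigmoid probabilities to 0/1 labels (assumed threshold of 0.5)."""
    return [1 if p > threshold else 0 for p in probs]


def accuracy(preds, labels):
    """Fraction of positions where the prediction equals the label."""
    if not preds or len(preds) != len(labels):
        raise ValueError("preds and labels must be non-empty and equal length")
    matches = [1.0 if p == y else 0.0 for p, y in zip(preds, labels)]
    return sum(matches) / len(matches)


# One sigmoid output per example: argmax cannot tell the classes apart.
print([argmax([p]) for p in (0.9, 0.1)])  # [0, 0]

probs = [0.91, 0.12, 0.67, 0.45]
labels = [1, 0, 0, 0]
preds = predict_classes(probs)   # [1, 0, 1, 0]
print(accuracy(preds, labels))   # 0.75
```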

Now, when you run the model (at training or test time), add accuracy to the fetches of the sess.run call:

sess.run([train_op, accuracy], feed_dict=...)

or

sess.run([accuracy, logits], feed_dict=...)

When you call sess.run, all you are doing is telling TensorFlow to compute the values of whatever you ask for. You must pass in, via feed_dict, any data it needs to perform those computations. TensorFlow is lazy: it won't perform any computation that isn't necessary to produce the results you request. For example, if you run the second version of sess.run above, the optimizer will not run, so your weights will not be updated.
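To illustrate that laziness, here is a toy, plain-Python model of fetch-driven evaluation. It is not TensorFlow's implementation; Node, placeholder, run, and the tiny graph are hypothetical. Only nodes that the requested fetches depend on are executed, so fetching the prediction never triggers the update op, and a placeholder that is needed but not fed raises an error.

```python
class Node:
    """A graph node: a function plus the nodes whose values it consumes."""

    def __init__(self, name, fn, deps=()):
        self.name, self.fn, self.deps = name, fn, tuple(deps)


def placeholder(name):
    """A node whose value must come from the feed dict."""
    def missing():
        raise KeyError(f"placeholder {name!r} was not fed")
    return Node(name, missing)


def run(fetches, feed):
    """Evaluate only what `fetches` need; return values and executed names."""
    cache, executed = {}, []

    def evaluate(node):
        if node.name in feed:
            return feed[node.name]
        if node.name not in cache:
            args = [evaluate(dep) for dep in node.deps]
            cache[node.name] = node.fn(*args)
            executed.append(node.name)
        return cache[node.name]

    return [evaluate(f) for f in fetches], executed


weights = {"w": 2.0}  # stands in for a trainable variable


def apply_update(_prediction):
    weights["w"] -= 0.1  # stands in for an optimizer step
    return None


x = placeholder("x")
w = Node("w", lambda: weights["w"])
prediction = Node("prediction", lambda xv, wv: xv * wv, (x, w))
train_op = Node("train_op", apply_update, (prediction,))

values, executed = run([prediction], feed={"x": 3.0})
print(values, executed)  # [6.0] ['w', 'prediction']  (train_op never ran)
```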

Note that you can add these ops after the network has been trained: none of them create variables, so they won't affect the save/restore process.
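A rough plain-Python analogy for why stateless ops are safe: a checkpoint maps variable names to saved values, and restoring looks up every variable the saver tracks by name. Ops such as argmax, equal, and cast create no variables, so the set of names to restore is unchanged, whereas a new variable (say, from an extra layer) would be a name the checkpoint lacks. The checkpoint contents, variable names, and the restore helper below are hypothetical.

```python
def restore(checkpoint, graph_variables):
    """Return values for every graph variable; fail on names the checkpoint lacks."""
    missing = sorted(set(graph_variables) - set(checkpoint))
    if missing:
        raise KeyError(f"not found in checkpoint: {missing}")
    return {name: checkpoint[name] for name in graph_variables}


# Hypothetical checkpoint written after training (name -> saved value).
checkpoint = {"Model/output_w": [[0.5]], "Model/output_b": [0.1]}

# Variables in the rebuilt graph. Adding argmax/equal/cast ops for
# accuracy leaves this list unchanged, so restoring still succeeds.
graph_variables = ["Model/output_w", "Model/output_b"]
restored = restore(checkpoint, graph_variables)

# A new layer would add a variable the checkpoint has never seen.
with_new_layer = graph_variables + ["Model/extra_w"]
```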
