[英]Iterate over tensor as an array Tensorflow
我正在嘗試將預測的圖像保存在我用Tensorflow編寫的CNN網絡上。 在我的代碼中y_pred_cls
包含我的預測標簽,並且y_pred_cls
是尺寸為1 x批處理大小的張量。 現在,我要遍歷y_pred_cls作為數組,並創建一個包含pred類,true類和一些索引號的文件名,然后找出與預測標簽有關的圖像,並使用imsave
將其另存為圖像。
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
train_writer.add_graph(sess.graph)
print("{} Start training...".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
print("{} Open Tensorboard at --logdir {}".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), tensorboard_dir))
for epoch in range(FLAGS.num_epochs):
print("{} Epoch number: {}".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), epoch + 1))
step = 1
# Start training
while step < train_batches_per_epoch:
batch_xs, batch_ys = train_preprocessor.next_batch(FLAGS.batch_size)
opt, train_acc = sess.run([optimizer, accuracy], feed_dict={x: batch_xs, y_true: batch_ys})
# Logging
if step % FLAGS.log_step == 0:
s = sess.run(sum, feed_dict={x: batch_xs, y_true: batch_ys})
train_writer.add_summary(s, epoch * train_batches_per_epoch + step)
step += 1
# Epoch completed, start validation
print("{} Start validation".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
val_acc = 0.
val_count = 0
cm_running_total = None
for _ in range(val_batches_per_epoch):
batch_tx, batch_ty = val_preprocessor.next_batch(FLAGS.batch_size)
acc, loss , conf_m= sess.run([accuracy, cost, tf.confusion_matrix(y_true_cls, y_pred_cls, FLAGS.num_classes)],
feed_dict={x: batch_tx, y_true: batch_ty})
if cm_running_total is None:
cm_running_total = conf_m
else:
cm_running_total += conf_m
val_acc += acc
val_count += 1
val_acc /= val_count
s = tf.Summary(value=[
tf.Summary.Value(tag="validation_accuracy", simple_value=val_acc),
tf.Summary.Value(tag="validation_loss", simple_value=loss)
])
val_writer.add_summary(s, epoch + 1)
print("{} -- Training Accuracy = {:.4%} -- Validation Accuracy = {:.4%} -- Validation Loss = {:.4f}".format(
datetime.now().strftime('%Y-%m-%d %H:%M:%S'), train_acc, val_acc, loss))
# Reset the dataset pointers
val_preprocessor.reset_pointer()
train_preprocessor.reset_pointer()
print("{} Saving checkpoint of model...".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
# save checkpoint of the model
checkpoint_path = os.path.join(checkpoint_dir, 'model_epoch.ckpt' + str(epoch+1))
save_path = saver.save(sess, checkpoint_path)
print("{} Model checkpoint saved at {}".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), checkpoint_path))
batch_tx,batch_ty分別是我的RGB數據和標簽。
提前致謝。
將數據從張量提取到python變量中
label = sess.run(y_pred_cls)
這將為您提供一個單矢量標簽數組或一個用於標量標簽的int變量。
要將陣列保存到圖像,可以使用PIL庫
from PIL import Image
img = Image.fromarray(data, 'RGB')
img.save('name.png')
其余的應該直接
x
創建RGB圖像 name = str(y)+'_'+str(y_hat)
如果您在執行這些步驟時遇到困難,我可以為您提供進一步的幫助
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.