繁体   English   中英

以张量迭代数组Tensorflow

[英]Iterate over tensor as an array Tensorflow

我正在尝试将预测的图像保存在我用Tensorflow编写的CNN网络上。 在我的代码中y_pred_cls包含我的预测标签,并且y_pred_cls是尺寸为1 x批处理大小的张量。 现在,我要遍历y_pred_cls作为数组,并创建一个包含pred类,true类和一些索引号的文件名,然后找出与预测标签有关的图像,并使用imsave将其另存为图像。

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
train_writer.add_graph(sess.graph)



print("{} Start training...".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
print("{} Open Tensorboard at --logdir {}".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), tensorboard_dir))

for epoch in range(FLAGS.num_epochs):
    print("{} Epoch number: {}".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), epoch + 1))
    step = 1

    # Start training
    while step < train_batches_per_epoch:
        batch_xs, batch_ys = train_preprocessor.next_batch(FLAGS.batch_size)
        opt, train_acc = sess.run([optimizer, accuracy], feed_dict={x: batch_xs, y_true: batch_ys})

        # Logging
        if step % FLAGS.log_step == 0:
            s = sess.run(sum, feed_dict={x: batch_xs, y_true: batch_ys})
            train_writer.add_summary(s, epoch * train_batches_per_epoch + step)

        step += 1

    # Epoch completed, start validation
    print("{} Start validation".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
    val_acc = 0.
    val_count = 0
    cm_running_total = None

    for _ in range(val_batches_per_epoch):
        batch_tx, batch_ty = val_preprocessor.next_batch(FLAGS.batch_size)
        acc, loss , conf_m= sess.run([accuracy, cost, tf.confusion_matrix(y_true_cls, y_pred_cls, FLAGS.num_classes)],
                                      feed_dict={x: batch_tx, y_true: batch_ty})



        if cm_running_total is None:
            cm_running_total = conf_m
        else:
            cm_running_total += conf_m


        val_acc += acc
        val_count += 1

    val_acc /= val_count

    s = tf.Summary(value=[
        tf.Summary.Value(tag="validation_accuracy", simple_value=val_acc),
        tf.Summary.Value(tag="validation_loss", simple_value=loss)
    ])

    val_writer.add_summary(s, epoch + 1)
    print("{} -- Training Accuracy = {:.4%} -- Validation Accuracy = {:.4%} -- Validation Loss = {:.4f}".format(
        datetime.now().strftime('%Y-%m-%d %H:%M:%S'), train_acc, val_acc, loss))

    # Reset the dataset pointers
    val_preprocessor.reset_pointer()
    train_preprocessor.reset_pointer()

    print("{} Saving checkpoint of model...".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S')))

    # save checkpoint of the model
    checkpoint_path = os.path.join(checkpoint_dir, 'model_epoch.ckpt' + str(epoch+1))
    save_path = saver.save(sess, checkpoint_path)
    print("{} Model checkpoint saved at {}".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), checkpoint_path))

batch_tx,batch_ty分别是我的RGB数据和标签。

提前致谢。

将数据从张量提取到python变量中

label = sess.run(y_pred_cls)

这将为您提供一个单矢量标签数组或一个用于标量标签的int变量。

要将阵列保存到图像,可以使用PIL库

from PIL import Image
img = Image.fromarray(data, 'RGB')
img.save('name.png')

其余的应该直接

  1. 从batch_tx,batch_ty和y_pred_cls张量中提取数据
  2. 遍历每个三元组
  3. 根据当前x创建RGB图像
  4. 创建一个形式为name = str(y)+'_'+str(y_hat)
  5. 保存图片

如果您在执行这些步骤时遇到困难,我可以为您提供进一步的帮助

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM