簡體   English   中英

在TensorFlow中將預測張量保存到圖像 - 圖形最終確定

[英]Saving predicted tensor to image in TensorFlow - Graph finalized

我能夠使用自己的數據在TensorFlow中訓練模型。 模型的輸入和輸出是圖像。 我現在嘗試獲取預測的輸出並將其保存到png圖像文件以查看最新情況。 不幸的是,我在運行我創建的以下函數時遇到錯誤,以測試預測。 我的目標是保存也是圖像的預測,這樣我就可以用普通的圖像查看器打開它。

代碼還有一些。 在我的主要我正在創建一個估計

def predict_element(my_model, features):
  eval_input_fn = tf.estimator.inputs.numpy_input_fn(
    x=features,
    num_epochs=1,
    shuffle=False)

  eval_results = my_model.predict(input_fn=eval_input_fn)

  predictions = eval_results.next() #this returns a dict with my tensors
  prediction_tensor = predictions["y"] #get the tensor from the dict

  image_tensor = tf.reshape(prediction_tensor, [IMG_WIDTH, -1]) #reshape to a matrix due my returned tensor is a 1D flat one
  decoded_image = tf.image.encode_png(image_tensor)
  write_image = tf.write_file("output/my_output_image.png", decoded_image)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(write_image))

def get_input():
  filename_dataset = tf.data.Dataset.list_files("features/*.png")
  label_dataset = tf.data.Dataset.list_files("labels/*.png")

  # Make a Dataset of image tensors by reading and decoding the files.
  image_dataset = filename_dataset.map(lambda x: tf.cast(tf.image.decode_png(tf.read_file(x), channels=1),tf.float32))
  l_dataset = label_dataset.map(lambda x: tf.cast(tf.image.decode_png(tf.read_file(x),channels=1),tf.float32))

  image_reshape = image_dataset.map(lambda x: tf.reshape(x, [IM_WIDTH * IM_HEIGHT]))
  label_reshape = l_dataset.map(lambda x: tf.reshape(x, [IM_WIDTH * IM_HEIGHT]))

  iterator = image_reshape.make_one_shot_iterator()
  iterator2 = label_reshape.make_one_shot_iterator()

  next_img = iterator.get_next()
  next_lbl = iterator2.get_next()

  features = []
  labels = []
  # read all 10 images and labels and put it in the array
  # so we can pass it to the estimator
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(10):
      t1, t2 = sess.run([next_img, next_lbl])
      features.append(t1)
      labels.append(t2)

  return {"x": np.array(features)}, np.array(labels)

def main(unused_argv):
    features, labels = get_input() # creating the features dict {"x": }
    my_estimator = tf.estimator.Estimator(model_fn=my_cnn_model, model_dir="/tmp/my_model")

    predict_element(my_estimator, features)

錯誤是

圖表已完成,無法修改

使用一些簡單的print()語句,我可以看到檢索dict

eval_results = my_model.predict(input_fn = eval_input_fn)

很可能是最終確定圖表的人。 我絕對不知道該做什么或在哪里尋找解決方案。 我怎么能保存輸出?

我在我的model_fn中試過這個:

#the last layer of my network is dropout
predictions = {
   "y": dropout
    }

  if mode == tf.estimator.ModeKeys.PREDICT:
    reshape1 = tf.reshape(dropout, [-1,IM_WIDTH, IM_HEIGHT])
    sliced = tf.slice(reshape1, [0,0,0], [1, IM_WIDTH, IM_HEIGHT])
    encoded = tf.image.encode_png(tf.cast(sliced, dtype=tf.uint8))
    outputfile = tf.write_file(params["output_path"], encoded)
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

我的問題是我無法傳回“outputfile”節點,所以我可以使用它。

那么你的圖表已經完成,無法修改。 您可以將此tensorflow操作添加到模型中(在運行之前),或者只是編寫一些python代碼,單獨保存圖像(不使用tensorflow)。 也許我會找一些我的舊代碼作為例子。

您還可以創建第二個圖形,然后可以使用tensorflow而不更改現有的模型圖形。

您必須區分圖節點和評估對象。 tf.reshape不將數組作為輸入,而是圖形節點。 https://www.tensorflow.org/programmers_guide/graphs

對於每個有同樣問題的人來說,我的解決方案。 我不知道這是否是正確的方法,但它的工作原理。

在我的預測函數中,我創建了第二個圖形,用於重塑,切片,編碼和保存,如:

  pred_dict = eval_results.next() #generator the predict function returns
  preds = pred_dict["y"] #get the predictions from the dict

  #create the second graph
  g = tf.Graph()
  with g.as_default():
    inp = tf.Variable(preds)
    reshape1 = tf.reshape(printnode, [IM_WIDTH, IM_HEIGHT, -1])
    sliced = tf.slice(reshape1, [0,0,0], [ IM_WIDTH, IM_HEIGHT,1])
    reshaped = tf.reshape(sliced, [IM_HEIGHT, IM_WIDTH, 1])
    encoded = tf.image.encode_png(tf.image.convert_image_dtype(reshaped,tf.uint16))
    outputfile = tf.write_file("/tmp/pred_output/prediction_img.png", encoded)
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(outputfile)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM