簡體   English   中英

如何在Tensorflow中將預訓練網絡用作圖層?

[英]How do I use a pretrained network as a layer in Tensorflow?

我想使用特征提取器(例如ResNet101)並在之后添加使用特征提取器層輸出的圖層。 但是,我似乎無法弄清楚如何。 我只在網上找到了使用整個網絡的解決方案,而沒有添加額外的圖層。 我對Tensorflow缺乏經驗。

在下面的代碼中,您可以看到我嘗試過的內容。 我可以在沒有額外卷積層的情況下正確運行代碼,但我的目標是在ResNet之后添加更多層。 嘗試添加額外的轉換層時,會返回此類型錯誤:TypeError:預期的float32,得到OrderedDict([('resnet_v1_101 / conv1',...

一旦我添加了更多圖層,我想開始在一個非常小的測試集上進行訓練,看看我的模型是否能夠過度擬合。


import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.python.slim.nets import resnet_v1
import matplotlib.pyplot as plt

numclasses = 17

from google.colab import drive
drive.mount('/content/gdrive')

def decode_text(filename):
  img = tf.io.decode_jpeg(tf.io.read_file(filename))
  img = tf.image.resize_bilinear(tf.expand_dims(img, 0), [224, 224])
  img = tf.squeeze(img, 0)
  img.set_shape((None, None, 3))
  return img

dataset = tf.data.TextLineDataset(tf.cast('gdrive/My Drive/5LSM0collab/filenames.txt', tf.string))
dataset = dataset.map(decode_text)
dataset = dataset.batch(2, drop_remainder=True)

img_1 = dataset.make_one_shot_iterator().get_next()
net = resnet_v1.resnet_v1_101(img_1, 2048, is_training=False, global_pool=False, output_stride=8) 
net = slim.conv2d(net, numclasses, 1)


sess = tf.Session()

global_init = tf.global_variables_initializer()
local_init = tf.local_variables_initializer()
sess.run(global_init)
sess.run(local_init)
img_out, conv_out = sess.run((img_1, net))

resnet_v1.resnet_v1_101不會返回net ,而是返回一個元組net, end_points 第二個元素是字典,這可能是您收到此特定錯誤消息的原因。

有關此功能文檔

返回:

net:size-4張量的大小[batch,height_out,width_out,channels_out]。 如果global_pool為False,則height_out和width_out與相應的height_in和width_in相比減少output_stride因子,否則height_out和width_out都等於1。 如果num_classes為0或None,則net是最后一個ResNet塊的輸出,可能在全局平均池之后。 如果num_classes是非零整數,則net包含pre-softmax激活。

end_points:從網絡組件到相應激活的字典。

所以你可以寫例如:

net, _ = resnet_v1.resnet_v1_101(img_1, 2048, is_training=False, global_pool=False, output_stride=8) 
net = slim.conv2d(net, numclasses, 1)

您還可以選擇中間層,例如:

_, end_points = resnet_v1.resnet_v1_101(img_1, 2048, is_training=False, global_pool=False, output_stride=8) 
net = slim.conv2d(end_points["main_Scope/resnet_v1_101/block3"], numclasses, 1)

(您可以查看end_points以查找端點的名稱。您的范圍名稱將與main_Scope不同。)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM