How to load pretrained LSTM model weights in TensorFlow
I want to implement an LSTM model with pretrained weights in TensorFlow. These weights may come from Caffe or Torch.
I found that the file rnn_cell.py contains LSTM cells, such as rnn_cell.BasicLSTMCell and rnn_cell.MultiRNNCell. But how can I load pretrained weights into these LSTM cells?
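Here's a solution for loading a pre-trained Caffe model. See the full code here, which is referenced in the discussion in this thread. The excerpt covers convolutional and fully connected layers: it reads each layer's arrays from Caffe, reorders them to TensorFlow's layout, and wraps them with tf.constant.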
import caffe
import tensorflow as tf

# prototxt and caffemodel are the paths to the Caffe model definition and
# weights; they are set in the full code (not shown here).
net_caffe = caffe.Net(prototxt, caffemodel, caffe.TEST)

# Map each Caffe layer name to its layer object.
caffe_layers = {}
for i, layer in enumerate(net_caffe.layers):
    layer_name = net_caffe._layer_names[i]
    caffe_layers[layer_name] = layer

def caffe_weights(layer_name):
    layer = caffe_layers[layer_name]
    return layer.blobs[0].data

def caffe_bias(layer_name):
    layer = caffe_layers[layer_name]
    return layer.blobs[1].data

# TensorFlow uses [filter_height, filter_width, in_channels, out_channels] 2-3-1-0
# Caffe uses [out_channels, in_channels, filter_height, filter_width] 0-1-2-3
def caffe2tf_filter(name):
    f = caffe_weights(name)
    return f.transpose((2, 3, 1, 0))

class ModelFromCaffe():
    def get_conv_filter(self, name):
        w = caffe2tf_filter(name)
        return tf.constant(w, dtype=tf.float32, name="filter")

    def get_bias(self, name):
        b = caffe_bias(name)
        return tf.constant(b, dtype=tf.float32, name="bias")

    def get_fc_weight(self, name):
        cw = caffe_weights(name)
        if name == "fc6":
            # fc6 follows pool5: reorder its inputs from Caffe's
            # (channels, height, width) flattening to (height, width, channels).
            assert cw.shape == (4096, 25088)
            cw = cw.reshape((4096, 512, 7, 7))
            cw = cw.transpose((2, 3, 1, 0))
            cw = cw.reshape(25088, 4096)
        else:
            # Caffe stores [out, in]; a TensorFlow matmul expects [in, out].
            cw = cw.transpose((1, 0))
        return tf.constant(cw, dtype=tf.float32, name="weight")

images = tf.placeholder("float", [None, 224, 224, 3], name="images")
m = ModelFromCaffe()

with tf.Session() as sess:
    # Deprecated in later TensorFlow 1.x releases in favor of
    # tf.global_variables_initializer().
    sess.run(tf.initialize_all_variables())
    # cat is the input image array, loaded in the full code (not shown here).
    batch = cat.reshape((1, 224, 224, 3))
    # m.prob, m.relu1_1, m.pool5 and m.fc6 are built in the full code
    # (not shown here).
    out = sess.run([m.prob, m.relu1_1, m.pool5, m.fc6], feed_dict={images: batch})
...
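The layout conversions used above can be checked without Caffe or TensorFlow installed. The following is a minimal NumPy sketch: the helper names `caffe_conv_to_tf` and `caffe_fc_to_tf`, the `in_shape` parameter, and the small array shapes are illustrative assumptions, not part of the original code. It generalizes the hard-coded fc6 branch to any fully connected layer that follows a convolution or pooling layer.

```python
import numpy as np

def caffe_conv_to_tf(filters):
    """[out, in, height, width] (Caffe) -> [height, width, in, out] (TensorFlow)."""
    return filters.transpose((2, 3, 1, 0))

def caffe_fc_to_tf(weights, in_shape=None):
    """Convert Caffe FC weights [out, in] to the [in, out] layout used by a matmul.

    in_shape=(channels, height, width) is only needed when the layer consumes a
    flattened feature map: Caffe flattens it as (C, H, W), while a TensorFlow
    NHWC tensor flattens as (H, W, C).
    """
    out_dim = weights.shape[0]
    if in_shape is None:
        return weights.transpose((1, 0))
    c, h, w = in_shape
    cw = weights.reshape((out_dim, c, h, w))
    cw = cw.transpose((2, 3, 1, 0))          # -> (H, W, C, out)
    return cw.reshape((h * w * c, out_dim))

# Small hypothetical shapes instead of VGG's (4096, 25088).
rng = np.random.RandomState(0)
conv = rng.randn(4, 3, 2, 2)                 # Caffe: out=4, in=3, 2x2 kernel
fc = rng.randn(6, 3 * 2 * 2)                 # Caffe: out=6, in=C*H*W
feature_map = rng.randn(3, 2, 2)             # one (C, H, W) activation

tf_conv = caffe_conv_to_tf(conv)
tf_fc = caffe_fc_to_tf(fc, in_shape=(3, 2, 2))

# The Caffe-style and TensorFlow-style products should agree.
caffe_out = fc.dot(feature_map.ravel())
tf_out = feature_map.transpose((1, 2, 0)).ravel().dot(tf_fc)
assert tf_conv.shape == (2, 2, 3, 4)
assert np.allclose(caffe_out, tf_out)
```

Comparing both products on random data is a cheap way to catch a wrong axis order before running the full network.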
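For the LSTM cells the question asks about, the same idea applies: export the weights as NumPy arrays, reorder them to the layout the TensorFlow cell expects, then load them. The sketch below is a hedged illustration, not a tested recipe for any particular checkpoint. It assumes the source weights follow the layout of PyTorch's `nn.LSTM` (separate input and hidden matrices of shape [4*H, in] and [4*H, H], gate order i, f, g, o) and that the target is `BasicLSTMCell`, which multiplies the concatenated [input, h] by a single kernel of shape [in + H, 4*H] with gate order i, j, f, o. Weights exported from Caffe or Lua Torch may order the gates differently, so check the source framework first. The function names are hypothetical.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def torch_lstm_to_tf_basic(w_ih, w_hh, b_ih, b_hh):
    """Reorder PyTorch-style LSTM arrays into a BasicLSTMCell-style kernel and bias.

    w_ih: [4*H, in], w_hh: [4*H, H], b_ih and b_hh: [4*H], gate order i, f, g, o.
    Returns kernel [in + H, 4*H] and bias [4*H] with gate order i, j (= g), f, o.
    """
    def reorder(a):
        i, f, g, o = np.split(a, 4, axis=0)
        return np.concatenate([i, g, f, o], axis=0)

    kernel = np.concatenate([reorder(w_ih), reorder(w_hh)], axis=1).T
    bias = reorder(b_ih + b_hh)   # the two source biases are simply summed
    return kernel, bias

def basic_lstm_step(x, h, c, kernel, bias, forget_bias=0.0):
    """NumPy emulation of one BasicLSTMCell step, used here only for checking."""
    gates = np.concatenate([x, h], axis=1).dot(kernel) + bias
    i, j, f, o = np.split(gates, 4, axis=1)
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_h, new_c

# Hypothetical sizes and random stand-ins for pretrained arrays.
rng = np.random.RandomState(0)
n_in, n_hidden, batch = 3, 5, 2
w_ih = rng.randn(4 * n_hidden, n_in)
w_hh = rng.randn(4 * n_hidden, n_hidden)
b_ih = rng.randn(4 * n_hidden)
b_hh = rng.randn(4 * n_hidden)

kernel, bias = torch_lstm_to_tf_basic(w_ih, w_hh, b_ih, b_hh)
assert kernel.shape == (n_in + n_hidden, 4 * n_hidden)
```

Two points to watch when loading the result into a real cell. `BasicLSTMCell` adds `forget_bias` (1.0 by default) to the forget gate, so either build the cell with `forget_bias=0.0` or subtract 1.0 from the forget-gate slice of the bias. The cell's variables are only created once it has been called on an input, and their names differ between TensorFlow versions, so it is safer to list them with `tf.trainable_variables()` and assign the arrays with `tf.assign` than to assume a name.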