Initialize TensorFlow CNN model with Numpy weight matrices

I am working on manually converting a pretrained MatConvNet model to a TensorFlow model. I pulled the weights and biases from the MatConvNet model's .mat file using scipy.io and obtained NumPy arrays for both.

Code snippets, where data is the dictionary returned from scipy.io:

for i in data['net2']['layers']:
    if i.type == 'conv':
        model.append({'weights': i.weights[0], 'bias': i.weights[1],
                      'stride': i.stride, 'padding': i.pad,
                      'momentum': i.momentum, 'lr': i.learningRate,
                      'weight_decay': i.weightDecay})

... ...

weights = {
    'wc1': tf.Variable(model[0]['weights']), 
    'wc2': tf.Variable(model[2]['weights']),
    'wc3': tf.Variable(model[4]['weights']),
    'wc4': tf.Variable(model[6]['weights'])
}

... ...

Here model[0]['weights'] is, for example, the 4x4x60 NumPy array pulled from the MatConvNet model for that layer. And this is how I define the placeholder for the 9x9 inputs:

X = tf.placeholder(tf.float32, [None, 9, 9]) #also tried with [None, 81] with a tf.reshape, [None, 9, 9, 1]

Current issue: I cannot get the ranks to match up. I consistently get a ValueError:

ValueError: Shape must be rank 4 but is rank 3 for 'Conv2D' (op: 'Conv2D') with input shapes: [?,9,9], [4,4,60]  

Summary

  • Is it possible to explicitly define a TensorFlow model's weights from NumPy arrays?
  • Why does the weight matrix need to be rank 4? Should my NumPy array be something more like [?, 4, 4, 60], and can I make it that way?
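On the rank question, Conv2D checks that the input is laid out as [batch, height, width, in_channels] and the filters as [filter_height, filter_width, in_channels, out_channels], which is why both must be rank 4. The following is a minimal NumPy sketch of adding the missing axes. It assumes (this is not confirmed by the question) that the 4x4x60 array holds 60 filters of size 4x4 over a single input channel, and it uses random stand-in arrays rather than the real weights.

```python
import numpy as np

# Stand-ins with the shapes from the question (random values, not real weights).
weights_hwo = np.random.rand(4, 4, 60).astype(np.float32)  # like model[0]['weights']
batch = np.random.rand(8, 9, 9).astype(np.float32)         # eight 9x9 inputs

# Input: add a trailing channel axis -> [batch, height, width, in_channels].
batch_nhwc = batch[..., np.newaxis]

# Filters: insert an in_channels axis before the output axis
# -> [filter_height, filter_width, in_channels, out_channels].
# ASSUMPTION: the last axis of the 4x4x60 array indexes the 60 filters.
weights_hwio = weights_hwo[:, :, np.newaxis, :]

print(batch_nhwc.shape, weights_hwio.shape)
```

Adding an axis this way does not reorder any values, so each filter keeps its original contents.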

Unsuccessfully attempted:

  • Rotating NumPy arrays: I know that MATLAB and Python index differently (1-based vs 0-based indexing, and column-major vs row-major). Even though I believe I have converted everything appropriately, I still experimented with functions like np.rot90(), trying to change the 4x4x60 array to 60x4x4.
  • Using tf.reshape: I attempted to reshape the weights after wrapping them in a tf.Variable, but I get: Variable has no attribute 'reshape'
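A note on the first attempt: np.rot90 rotates values within the plane of the first two axes and leaves a 4x4x60 array at shape 4x4x60, so it cannot produce 60x4x4. Moving an axis is done with np.transpose (or np.moveaxis), while a plain reshape to 60x4x4 gives the right shape with scrambled filters. A small sketch with a stand-in array (the values are illustrative only):

```python
import numpy as np

# Stand-in 4x4x60 array with distinct values so reordering is visible.
w = np.arange(4 * 4 * 60, dtype=np.float32).reshape(4, 4, 60)

rotated = np.rot90(w)                  # rotates axes 0 and 1; shape is unchanged
moved = np.transpose(w, (2, 0, 1))     # moves the filter axis to the front
reshaped = w.reshape(60, 4, 4)         # same shape as `moved`, different contents

print(rotated.shape, moved.shape, reshaped.shape)

# moved[k] is exactly the k-th 4x4 slice of the original array;
# reshaped[k] is not, because reshape only regroups the flat buffer.
print(np.array_equal(moved[5], w[:, :, 5]))
print(np.array_equal(reshaped[5], w[:, :, 5]))
```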

NOTE: I am aware that there are a number of scripts to go from MatConvNet to Caffe, and from Caffe to TensorFlow (as described here, for example: https://github.com/vlfeat/matconvnet/issues/1021 ). My question is specifically about TensorFlow weight initialization options.

I got over this hurdle by calling tf.reshape(...) instead of weights['wc1'].reshape(...). I am still not certain about the performance, or whether this is a horribly naive endeavor.
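To see what the rank-4 shapes buy you once the reshape is in place, here is a rough NumPy reference for a stride-1, no-padding 2D convolution in the layout TensorFlow's Conv2D uses (cross-correlation, with no filter flip). It is a sketch for checking shapes and spot-checking values, not the code from this answer; the function name conv2d_valid and the random inputs are hypothetical, and the single-input-channel interpretation of the 4x4x60 weights is an assumption.

```python
import numpy as np

def conv2d_valid(x, w):
    """Naive stride-1, no-padding convolution.

    x: [batch, height, width, in_channels]
    w: [filter_height, filter_width, in_channels, out_channels]
    """
    n, h, wd, cin = x.shape
    kh, kw, w_cin, cout = w.shape
    if cin != w_cin:
        raise ValueError("input channels do not match filter channels")
    out_h, out_w = h - kh + 1, wd - kw + 1
    out = np.zeros((n, out_h, out_w, cout), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            patch = x[:, i:i + kh, j:j + kw, :]  # [batch, kh, kw, cin]
            # Sum over kh, kw, cin for every filter at once -> [batch, cout].
            out[:, i, j, :] = np.tensordot(patch, w, axes=([1, 2, 3], [0, 1, 2]))
    return out

rng = np.random.default_rng(0)
x = rng.random((2, 9, 9))       # stand-in for two 9x9 inputs
w = rng.random((4, 4, 60))      # stand-in for a 4x4x60 weight array

# ASSUMPTION: one input channel, 60 output filters.
y = conv2d_valid(x.reshape(-1, 9, 9, 1), w.reshape(4, 4, 1, 60))
print(y.shape)
```

With a 9x9 input and a 4x4 filter, each spatial dimension shrinks to 9 - 4 + 1 = 6, so the output has one 6x6 map per filter.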

UPDATE: After further testing, this approach appears to be workable, at least functionally: I have created a TensorFlow CNN model that runs and produces predictions that appear consistent with the MatConvNet model. I make no claims about how closely the accuracies of the two match.

I am sharing my code. In my case it was a very small network; if you are attempting to use this code for your own MatConvNet to TensorFlow project, you will likely need many more modifications: https://github.com/melissadale/MatConv2TensorFlow
