简体   繁体   English

Tensorflow:使用预训练的初始模型

[英]Tensorflow: use pretrained inception model

I would like to use one of these pretrained tensorflow models: https://github.com/tensorflow/models/tree/master/slim我想使用这些预训练张量流模型之一: https : //github.com/tensorflow/models/tree/master/slim

After loading the inceptionv4 model, I've had problems with some test predictions.加载 inceptionv4 模型后,我遇到了一些测试预测问题。 There is a similar question: Using pre-trained inception_resnet_v2 with Tensorflow有一个类似的问题: Using pre-trained inception_resnet_v2 with Tensorflow

In that question, the solution was to fix the image preprocessing.在那个问题中,解决方案是修复图像预处理。 I tried using ranges for the color channels from 0 to 1 and from -1 to 1.我尝试使用从 0 到 1 和从 -1 到 1 的颜色通道范围。

Here is my code (I've imported everything from the inceptionv4 source file):这是我的代码(我已经从 inceptionv4 源文件中导入了所有内容):

checkpoint_file = '..\checkpoints\inception_resnet_v2_2016_08_30.ckpt'
sample_images = ['horse.jpg', 'hound.jpg']
sess = tf.Session()

im_size = 299
inception_v4.default_image_size = im_size

arg_scope = inception_utils.inception_arg_scope()
inputs = tf.placeholder(tf.float32, (None, im_size, im_size, 3))

with slim.arg_scope(arg_scope):
    net, logits, end_points = inception_v4(inputs)

saver = tf.train.Saver()

saver.restore(sess,'..\checkpoints\inception_v4.ckpt')

for image in sample_images:
    im = Image.open(image)
    im = im.resize((299, 299))
    im = np.array(im)
    im = im.reshape(-1, 299, 299, 3)
    im = 2. * (im / 255.) - 1.
    logit_values = sess.run(logits, feed_dict={inputs: im})
    print(np.max(logit_values))
    print(np.argmax(logit_values))

In the code, I'm testing the network with a horse.在代码中,我正在用马测试网络。 This is the picture.这是图片。 在此处输入图片说明

With the current preprocessing, color channels from -1 to 1, the network thinks that this horse is a bathing cap.以目前的预处理,颜色通道从-1到1,网络认为这匹马是浴帽。 For scaling from 0 to 1, it becomes a bittern, apparently a small bird.对于从 0 到 1 的缩放,它变成了一种卤水,显然是一只小鸟。 I've used this table to find out the predicted classes: https://gist.github.com/aaronpolhamus/964a4411c0906315deb9f4a3723aac57我用这个表来找出预测的类: https : //gist.github.com/aaronpolhamus/964a4411c0906315deb9f4a3723aac57

I've also checked more than one image.我还检查了不止一张图片。 The network is consistently off.网络一直处于关闭状态。

What is going wrong ?出了什么问题?

I think you used the wrong synsets for Imagenet.我认为您对 Imagenet 使用了错误的同义词集。 To be specific, the one you used is 2012 version.具体来说,您使用的是2012版。 You may try these two: imagenet_lsvrc_2015_synsets.txt and imagenet_metadata .你可以试试这两个: imagenet_lsvrc_2015_synsets.txtimagenet_metadata

For example, if your output is 340, then 340->n02389026-> sorrel例如,如果你的输出是 340,那么 340->n02389026-> sorrel

I would agree for the wrong synset, it can be automatically downloaded with the imagenet file, that way your are sure to have the correct one:我同意错误的同义词集,它可以与 imagenet 文件一起自动下载,这样你肯定有正确的:

from datasets import imagenet
names = imagenet.create_readable_names_for_imagenet_labels()

print(names[0])

when you import the inceptions,当您导入初始时,

net, logits, end_points = inception_v4(inputs)

should be应该

logits, end_points = inception_v4(inputs, is_training=False, dropout_keep_prob=1.0) 

for inference用于推理

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM