
Using MXNet pre-trained image classification model in Python

I am trying to achieve in Python 3 everything the R tutorial describes, but so far I haven't gotten anywhere.

The tutorial in R is described here: http://mxnet.readthedocs.org/en/latest/R-package/classifyRealImageWithPretrainedModel.html

How can I do the same in Python? Using the following model: https://github.com/dmlc/mxnet-model-gallery/blob/master/imagenet-1k-inception-bn.md

Kind regards, Kevin

At the moment, you can do much more with MXNet in Python than in R. I use the Gluon API, which makes the code even simpler and lets you load pretrained models directly.

The model used in the tutorial you refer to is an Inception model (Inception-BN). The Gluon model zoo (mxnet.gluon.model_zoo.vision) lists all available pretrained models, including the closely related Inception v3 used below.

The remaining steps in the tutorial are data normalization and augmentation. You can normalize new data the same way the Gluon model zoo API page does:

# image: H x W x 3 float32 NDArray with values in [0, 255]
image = image/255
normalized = mx.image.color_normalize(image,
                                      mean=mx.nd.array([0.485, 0.456, 0.406]),
                                      std=mx.nd.array([0.229, 0.224, 0.225]))
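To make explicit what this normalization computes, here is a NumPy-only sketch (NumPy is used instead of MXNet only so it runs without MXNet installed; `normalize_hwc` is a hypothetical helper name). Each channel c of an H x W x 3 image becomes (x / 255 - mean[c]) / std[c], with mean and std broadcast along the last axis:

```python
import numpy as np

# ImageNet per-channel statistics used in the snippet above (RGB order)
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_hwc(image):
    """Hypothetical helper: scale an H x W x 3 image from [0, 255] to [0, 1],
    then subtract the per-channel mean and divide by the per-channel std.
    Intended to mirror image/255 followed by mx.image.color_normalize."""
    scaled = image.astype(np.float32) / 255.0
    # Shapes (H, W, 3) and (3,) broadcast along the last (channel) axis
    return (scaled - MEAN) / STD


if __name__ == "__main__":
    # A 2 x 2 image in which every pixel equals the channel means times 255
    image = np.tile(MEAN * 255.0, (2, 2, 1))
    out = normalize_hwc(image)
    print(out.shape)                          # (2, 2, 3)
    print(np.allclose(out, 0.0, atol=1e-5))   # True
```

A pixel that sits exactly at the dataset mean maps to 0 in every channel, which is the point of the normalization.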

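The list of possible augmentations is in the mxnet.image module documentation (mx.image.CreateAugmenter and the individual *Aug classes such as CenterCropAug and HorizontalFlipAug).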
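As a rough illustration of what two common augmenters do, here is a NumPy sketch. It is not MXNet's implementation: MXNet's CenterCropAug also resizes the crop to the target size, whereas the hypothetical `center_crop` below assumes the image is already at least as large as the crop.

```python
import numpy as np


def center_crop(image, height, width):
    """Cut a height x width window from the middle of an H x W x C image.
    Assumes the image is at least height x width (no resizing is done)."""
    h, w = image.shape[:2]
    if h < height or w < width:
        raise ValueError("image is smaller than the requested crop")
    top = (h - height) // 2
    left = (w - width) // 2
    return image[top:top + height, left:left + width]


def horizontal_flip(image):
    """Mirror an H x W x C image left to right (reverse the width axis)."""
    return image[:, ::-1]


if __name__ == "__main__":
    image = np.arange(5 * 6 * 3).reshape(5, 6, 3)
    crop = center_crop(image, 3, 4)
    print(crop.shape)                                    # (3, 4, 3)
    flipped = horizontal_flip(image)
    print(np.array_equal(flipped[:, 0], image[:, -1]))  # True
```

A center crop is deterministic, which makes it suitable at prediction time; random crops and flips are usually reserved for training.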

Here is a runnable example. I applied only one augmentation; you can pass more parameters to mx.image.CreateAugmenter if you want more of them:

# Jupyter magic; drop this line when running as a plain Python script
%matplotlib inline
import mxnet as mx
from mxnet.gluon.model_zoo import vision
from matplotlib.pyplot import imshow

def plot_mx_array(array, clip=False):
    """
    Array expected to be height x width x 3 (channels), with float values between 0 and 255.
    """
    assert array.shape[2] == 3, "RGB Channel should be last"
    if clip:
        array = array.clip(0,255)
    else:
        assert array.min().asscalar() >= 0, "Value in array is less than 0: found " + str(array.min().asscalar())
        assert array.max().asscalar() <= 255, "Value in array is greater than 255: found " + str(array.max().asscalar())
    array = array/255
    np_array = array.asnumpy()
    imshow(np_array)


# pretrained=True downloads the ImageNet weights on first use
inception_model = vision.inception_v3(pretrained=True)

# Replace this path with the path to your own image file
with open("/Volumes/Unix/workspace/MxNet/2018-02-20T19-43-45/types_of_data_augmentation/output_4_0.png", 'rb') as open_file:
    encoded_image = open_file.read()
    example_image = mx.image.imdecode(encoded_image)
    example_image = example_image.astype("float32")
    plot_mx_array(example_image)


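# data_shape is (channels, height, width); Inception v3 expects 3 x 299 x 299 input.
# With only data_shape set, the augmenters center-crop and resize to 299 x 299.
augmenters = mx.image.CreateAugmenter(data_shape=(3, 299, 299))

for augmenter in augmenters:
    example_image = augmenter(example_image)

# Resizing can overshoot slightly outside [0, 255], so clip before plotting
plot_mx_array(example_image, clip=True)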

# Scale to [0, 1], then normalize with the per-channel ImageNet mean/std
example_image = example_image / 255
normalized_image = mx.image.color_normalize(example_image,
                                      mean=mx.nd.array([0.485, 0.456, 0.406]),
                                      std=mx.nd.array([0.229, 0.224, 0.225]))
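The example above stops after normalization. To actually classify, the H x W x 3 image has to be rearranged into a 1 x 3 x H x W batch, passed through the model, and the 1000 outputs turned into top-k labels. In MXNet that is roughly `batch = normalized_image.transpose((2, 0, 1)).expand_dims(axis=0)` followed by `logits = inception_model(batch)`; Inception v3 is designed for 299 x 299 inputs, so the crop size should match. The NumPy sketch below shows only the array bookkeeping around that call so it runs without MXNet. `to_batch`, `top_k` and the `labels` list are assumptions: for example, labels could be read from a synset.txt file like the one in the R tutorial, one label per line in class-index order.

```python
import numpy as np


def to_batch(image_hwc):
    """Rearrange an H x W x C image into a 1 x C x H x W batch (NCHW),
    the layout Gluon vision models take as input."""
    chw = np.transpose(image_hwc, (2, 0, 1))
    return chw[np.newaxis, ...]


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def top_k(logits, labels, k=5):
    """Return the k most probable (label, probability) pairs for one image.
    `logits` has shape (1, num_classes); `labels[i]` names class i."""
    probs = softmax(logits)[0]
    best = np.argsort(probs)[::-1][:k]
    return [(labels[i], float(probs[i])) for i in best]


if __name__ == "__main__":
    image = np.zeros((299, 299, 3), dtype=np.float32)
    print(to_batch(image).shape)  # (1, 3, 299, 299)

    # Stand-in for the model output; with MXNet this would be
    # inception_model(batch).asnumpy()
    labels = ["cat", "dog", "car"]
    logits = np.array([[0.5, 2.0, -1.0]])
    print(top_k(logits, labels, k=2))
```

Applying softmax is optional if you only need the ranking, since it preserves the order of the logits; it just turns them into probabilities that are easier to read.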
