簡體   English   中英

制作您自己的MNIST數據集(與MNIST格式相同)

[英]Making your own set of MNIST data (identical to MNIST format)

我正在嘗試創建自己的MNIST數據版本。 我已將我的培訓和測試數據轉換為以下文件;

test-images-idx3-ubyte.gz
test-labels-idx1-ubyte.gz
train-images-idx3-ubyte.gz
train-labels-idx1-ubyte.gz

(對於任何有興趣的人,我使用JPG-PNG-to-MNIST-NN-Format做了這個,這似乎讓我接近我的目標。)

但是,這與MNIST數據的文件類型和格式(mnist.pkl.gz)不完全相同。 我知道pkl意味着數據被腌制,但我並不真正了解酸洗數據的過程 - 酸洗是否有特定的順序? 有人可以提供我應該用來挑選我的數據的代碼嗎?

import gzip
import os

import numpy as np
import six
from six.moves.urllib import request

parent = 'http://yann.lecun.com/exdb/mnist'
train_images = 'train-images-idx3-ubyte.gz'
train_labels = 'train-labels-idx1-ubyte.gz'
test_images = 't10k-images-idx3-ubyte.gz'
test_labels = 't10k-labels-idx1-ubyte.gz'
num_train = 17010
num_test = 3010
dim = 32*32


def load_mnist(images, labels, num):
    data = np.zeros(num * dim, dtype=np.uint8).reshape((num, dim))
    target = np.zeros(num, dtype=np.uint8).reshape((num, ))

    with gzip.open(images, 'rb') as f_images,\
            gzip.open(labels, 'rb') as f_labels:
        f_images.read(16)
        f_labels.read(8)
        for i in six.moves.range(num):
            target[i] = ord(f_labels.read(1))
            for j in six.moves.range(dim):
                data[i, j] = ord(f_images.read(1))

    return data, target


def download_mnist_data():

    print('Converting training data...')
    data_train, target_train = load_mnist(train_images, train_labels,
                                          num_train)
    print('Done')
    print('Converting test data...')
    data_test, target_test = load_mnist(test_images, test_labels, num_test)
    mnist = {}
    mnist['data'] = np.append(data_train, data_test, axis=0)
    mnist['target'] = np.append(target_train, target_test, axis=0)

    print('Done')
    print('Save output...')
    with open('mnist.pkl', 'wb') as output:
        six.moves.cPickle.dump(mnist, output, -1)
    print('Done')
    print('Convert completed')


def load_mnist_data():
    if not os.path.exists('mnist.pkl'):
        download_mnist_data()
    with open('mnist.pkl', 'rb') as mnist_pickle:
        mnist = six.moves.cPickle.load(mnist_pickle)
    return mnist
download_mnist_data()

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM