简体   繁体   English

如何使用tensorflow成对导入jpg和npy文件以进行深度学习?

[英]How to import jpg and npy files in pair for deep learning with tensorflow?

I have 1500 RGB files(.jpg) and 1500 feature map values(.npy). 我有1500个RGB文件(.jpg)和1500个特征图值(.npy)。 I want to use them as a dataset for my deep learning project. 我想将它们用作深度学习项目的数据集。 I am using tensorflow 1.12. 我正在使用tensorflow 1.12。

I wrote them into a .tfrecords file using the tf.Example. 我使用tf.Example将它们写入.tfrecords文件。 Here is the code I used to import this file with tf.data(Thanks to Uday's comment). 这是我用来通过tf.data导入该文件的代码(感谢Uday的评论)。

import tensorflow as tf
import numpy as np
import pdb

IMAGE_HEIGHT = 228
IMAGE_WIDTH = 304

def tfdata_generator(tfrname, is_training, batch_size):
    '''Construct a data generator using tf.Dataset'''
    ## You can write your own parse function
    def parse_function(example):

    features = tf.parse_single_example(example, features={

        'image_raw': tf.FixedLenFeature([], tf.string, default_value=""),
        'hint_raw': tf.FixedLenFeature([], tf.string, default_value="")
        })
    image = features['image_raw']
    hint = features['hint_raw']

    image = tf.decode_raw(image, tf.uint8)
    image = tf.cast(image, tf.float32)
    image = tf.reshape(image, [IMAGE_HEIGHT, IMAGE_WIDTH, 3])

    hint = tf.decode_raw(hint, tf.uint8)
    hint = tf.cast(hint, tf.float32)
    hint = tf.reshape(hint, [8, 10, 1024])

    return image, hint

dataset = tf.data.TFRecordDataset(tfrname)
#pdb.set_trace()
if is_training:
    dataset = dataset.shuffle(100)  # depends on sample size
#pdb.set_trace()
# Transform and batch data at the same time
dataset = dataset.apply(tf.data.experimental.map_and_batch(parse_function, 
        8, num_parallel_batches=4)) # cpu cores

dataset = dataset.repeat(-1)
dataset = dataset.prefetch(2)
return dataset

I set the batch_size to be 8. But when I did the debugging, the shape of the dataset is 我将batch_size设置为8。但是当我进行调试时,数据集的形状为

((?, 228, 304, 3), (?, 8, 10, 1024)), types: (tf.float32, tf.float32)

Is this correct? 这个对吗? Is this code wrong? 此代码是否错误? Or there are mistakes when I making the tfrecords?. 还是在制作tfrecords时出错?

you can use code like below, 您可以使用如下代码,

def tfdata_generator(images, labels, is_training, batch_size=32):
   '''Construct a data generator using tf.Dataset'''
   ## You can write your own parse function
   def parse_function(filename, label):
        image_string = tf.read_file(filename)
        image = tf.image.decode_jpeg(image_string)
        image = tf.image.convert_image_dtype(image, tf.float32)
        y = tf.one_hot(tf.cast(label, tf.uint8), 16)
    return image, y

    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    if is_training:
        dataset = dataset.shuffle(1000)  # depends on sample size

    # Transform and batch data at the same time
    dataset = dataset.apply(tf.data.experimental.map_and_batch( parse_function, 
            batch_size,num_parallel_batches=6,  # cpu cores
        drop_remainder=True if is_training else False))
    dataset = dataset.repeat()
    dataset = dataset.prefetch(no_of_prefetch_needed)
return dataset

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

相关问题 深度学习中的文本生成 Tensorflow - Text Generation in Deep Learning Tensorflow 深度学习 Tensorflow 重塑不起作用? - Deep Learning Tensorflow reshape is not working? 如何在 Google Cloud 深度学习 VM 上安装 tensorflow-transform? - How to install tensorflow-transform on a Google Cloud Deep Learning VM? 如何在 Tensorflow 的深度学习期间测量每批次的训练时间? - How to measure training time per batches during Deep Learning in Tensorflow? 如何在 Matlab 中读取 .npy 文件 - How to read .npy files in Matlab Tensorflow:如何将“.npy”文件加载到网络 - Tensorflow: how to load a ".npy" file to a net 如何从包含文件名的列表中加载来自 tensorflow 数据管道中不同目录的.npy 文件? - How to load .npy files from different directories in tensorflow data pipeline from a list containing filenames? TENSORFLOW的深度学习:保存和加载模型的问题 - Deep learning with TENSORFLOW: Issues with saving and loading models 在TensorFlow中使用广泛和深入的学习网络进行预测 - Prediction using wide and deep learning network in TensorFlow Deep Q - 在 Python 中使用 Tensorflow 学习 Cartpole - Deep Q - Learning for Cartpole with Tensorflow in Python
 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM