繁体   English   中英

将 Keras 生成器转换为 Tensorflow 数据集以训练 Resnet50

[英]Convert Keras generator to Tensorflow Dataset to train Resnet50

我正在将 python 代码从 keras 命名空间转换为 tf.keras。 它训练 Resnet50。 新的 Model.fit() 方法找不到适合我的简单生成器的适配器,validation_data 甚至不再支持生成器。 所以我正在尝试使用 tensorflow.data.Dataset.from_generator 方法将其转换为数据集。

图像是灰度的并以原始字节存储——一个字节对应一个像素。 生成器有这样的行

        def __next__( self ):
            return self.next()

        def __call__( self ):
            return self.next()

        def next( self ):
            #reading files
            ...

            resultLabels = numpy.zeros( ( count, len( classes ) ), "float32" )
            resultImages = numpy.zeros( ( count, patchSize, patchSize, 3 ), "float32" )

            #filling result with images and labels
                ...
                fileBytes = numpy.reshape( numpy.fromfile( self.ImageLabelsAndPaths[i][1], "uint8" ), (patchSize, patchSize), "F" ).astype( "float32" )

                imageWithChannels = numpy.zeros( ( patchSize, patchSize, 3 ), "float32" )
                # Because Resnet50 requires RGB images and we have grayscale
                imageWithChannels[:,:,0] = fileBytes
                imageWithChannels[:,:,1] = fileBytes
                imageWithChannels[:,:,2] = fileBytes

                resultImages[i - cursor] = imageWithChannels

            return ( resultImages, resultLabels )

所以 resultImages 是一个长度为 batch_size=16 的数组,其中包含图像像素的 arrays。 Numpy.shape 是 (16, 256, 256, 3) 并且 resultLabels 形状是 (16, 3) - 现在有 3 个类。

接下来我将其转换为数据集

            trainGenerator = FileIterator( "train" )
            trainDataset = tf.data.Dataset.from_generator( trainGenerator, (tf.float32, tf.float32), (tf.TensorShape([batchSize, patchSize, patchSize, 3]), tf.TensorShape([batchSize, len(classes)]) ) )
            validationGenerator = FileIterator( "validate" )
            validationDataset = tf.data.Dataset.from_generator( validationGenerator, (tf.float32, tf.float32), (tf.TensorShape([batchSize, patchSize, patchSize, 3]), tf.TensorShape([batchSize, len(classes)]) ) )

但我收到错误

TypeError: `generator` yielded an element that did not match the expected structure. The expected structure was (tf.float32, tf.float32), but the yielded element was [[[[185. 185. 185.]
   [158. 158. 158.]
   [145. 145. 145.]
   ...

Dataset.from_generator 的代码示例有一个数组作为元组中的第二项和类似的 output_types=(tf.int64, tf.int64)。 我猜它在那里工作。

尝试将 arrays 添加到 type 导致另一个错误

TypeError: unhashable type: 'list'

我应该改变什么才能让它工作?

好的,在又花了两天时间,试图修复一些真正误导性的错误,并让 python.exe 在它最终工作时崩溃,我能够将我的生成器转换为 tensorflow 数据集。

我无法使其与批处理一起使用,并且 numpy.array 不被 Dataset 接受,因为它在 Dataset 的世界中不是顺序的,并且返回一个元组很重要,不知道如何使用“yield”和“返回数据,标签”的作品。

发电机

        def __iter__(self):
            return self

        def __call__( self ):
            return self

        def __len__(self):
            return self.TotalCount

        def __next__( self ):
            ...
            resultLabel = numpy.zeros( len( classes ), "float32" )
            resultImage = numpy.zeros( ( patchSize, patchSize, 3 ), "float32" )
            # fill those two
            ...

            return (resultImage.tolist(), resultLabel.tolist())

和数据集 + model.fit

            trainGenerator = FileIterator( "train" )
            validationGenerator = FileIterator( "validate" )

            trainDataset = tf.data.Dataset.from_generator( trainGenerator, output_types=(tf.float32, tf.float32), output_shapes=(tf.TensorShape([patchSize, patchSize, 3]), tf.TensorShape([len(classes)]) ) )
            trainDataset = trainDataset.batch( batchSize )
            validationDataset = tf.data.Dataset.from_generator( validationGenerator, output_types=(tf.float32, tf.float32), output_shapes=(tf.TensorShape([patchSize, patchSize, 3]), tf.TensorShape([len(classes)]) ) )
            validationDataset = validationDataset.batch( batchSize )


            trainResult = model.fit( x = trainDataset,
                                     epochs = epochsForDenseLayer,
                                     steps_per_epoch = trainGenerator.StepsPerEpoch,
                                     verbose = 2,
                                     validation_data = validationDataset,
                                     validation_steps = validationGenerator.StepsPerEpoch,
                                     validation_freq = 1,
                                     shuffle = False, # already shuffled by generator
                                     workers = cpuCoresCount,
                                     use_multiprocessing = False
                                    )

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM