
keras model.predict_generator() not returning the correct number of instances

I followed this tutorial to learn how to use a data generator with a Keras model's fit_generator(): https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly. One problem I have run into is that when I call model.predict_generator() on a test-data generator, the length of the returned predictions is not the number of samples I passed to the generator. My test data has 229431 samples and I use a batch_size of 256. The generator class, including its __len__ method, is defined as follows:

import numpy as np
import keras

class DataGenerator(keras.utils.Sequence):
    """A simple generator"""

    def __init__(self, list_IDs, labels, dim, dim_label, batch_size=512, shuffle=True, is_training=True):
        """Initialization"""
        self.list_IDs = list_IDs
        self.labels = labels
        self.dim = dim
        self.dim_label = dim_label
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.is_training = is_training
        self.on_epoch_end()

    def __len__(self):
        """Denotes the number of batches per epoch"""
        return int(np.ceil(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        """Generate one batch of data"""
        # Generate indexes of the batch
        indexes = self.indexes[index * self.batch_size: (index + 1) * self.batch_size]

        # Find list of IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]
        list_labels_temp = [self.labels[k] for k in indexes]

        # Generate data
        result = self.__data_generation(list_IDs_temp, list_labels_temp, self.is_training)
        if self.is_training:
            X, y = result
            return X, y
        else:
            # only return X when test
            X = result
            return X

    def on_epoch_end(self):
        """Updates indexes after each epoch"""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp, list_labels_temp, is_training):
        """Generates data containing batch_size samples"""
        # Initialization
        # X is a list of np.array
        X = np.empty((self.batch_size, *self.dim))
        if is_training:
            # y could have multiple columns
            y = np.empty((self.batch_size, *self.dim_label), dtype=int)

        # Generate data
        for i, (ID, label) in enumerate(zip(list_IDs_temp, list_labels_temp)):
            # Store sample
            X[i,] = np.load(ID)
            if is_training:
                # Store class
                y[i,] = np.load(label)
        if is_training:
            return X, y
        else:
            return X

The returned predictions have length 229632. Here is the prediction code:

test_generator = DataGenerator(partition, labels, is_training=False, **self.params)
predict_raw = self.model.predict_generator(generator=test_generator, workers=12, verbose=2)

I noticed that 229632 / 256 = 897, which is the length of my generator. When I change the __len__ method of DataGenerator to return int(np.floor(len(self.list_IDs) / self.batch_size)), I get 229376 predictions, and 229376 / 256 = 896, which matches the new length. But I passed 229431 samples to the generator.
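The arithmetic above can be checked directly. This small sketch uses only the sample count and batch size from the question; it shows that ceil gives 897 batches and floor gives 896, and that the final partial batch holds only 55 real samples, so padding every batch to 256 rows yields exactly the 229632 predictions observed:

```python
import math

n_samples = 229431   # test samples passed to the generator
batch_size = 256

n_batches_ceil = math.ceil(n_samples / batch_size)    # 897: last batch is partial
n_batches_floor = math.floor(n_samples / batch_size)  # 896: partial batch dropped

# If every batch is padded to batch_size rows, the prediction count is:
padded_total = n_batches_ceil * batch_size   # 229632
floor_total = n_batches_floor * batch_size   # 229376

# Real samples in the final (partial) batch
last_batch_size = n_samples - n_batches_floor * batch_size  # 55

print(n_batches_ceil, padded_total)   # 897 229632
print(n_batches_floor, floor_total)   # 896 229376
print(last_batch_size)                # 55
```

In other words, the floor version silently drops the last 55 samples, and the ceil version predicts 201 extra rows.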

I expected that in __getitem__, the last batch would automatically contain fewer than 256 samples. Apparently that is not the case, so how can I make sure the model predicts exactly the right number of samples?
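It helps to separate the two steps involved. The sketch below, assuming a hypothetical per-sample feature shape of (3,), shows that slicing the index array past its end is clamped (so the last batch really does get only 55 indexes), while an array allocated with np.empty((batch_size, ...)), as in the question's __data_generation, still has 256 rows regardless:

```python
import numpy as np

n_samples = 229431
batch_size = 256
indexes = np.arange(n_samples)

# Index of the final batch when __len__ uses ceil (896)
last = int(np.ceil(n_samples / batch_size)) - 1
batch_idx = indexes[last * batch_size:(last + 1) * batch_size]
print(len(batch_idx))     # 55: NumPy clamps a slice that runs past the end

# Allocation with a fixed batch_size, as in __data_generation;
# (3,) is a hypothetical feature shape
X_padded = np.empty((batch_size, 3))
print(X_padded.shape[0])  # 256: 201 rows are uninitialized padding

# Allocating by the number of IDs actually in the batch avoids the padding
X_fixed = np.empty((len(batch_idx), 3))
print(X_fixed.shape[0])   # 55
```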

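For the last batch, the data built by __getitem__ does not have the correct size. Slicing self.indexes already stops at the end of the array, so the last batch contains only 55 IDs, but __data_generation allocates X with self.batch_size rows, so the remaining 201 uninitialized rows are predicted as well. To predict the right number of samples, bound the indexes of the last batch as follows (see post), and allocate X (and y) with len(list_IDs_temp) rows instead of self.batch_size: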

def __getitem__(self, index):
    """Generate one batch of data"""
    # The last batch may hold fewer than batch_size samples
    idx_min = index * self.batch_size
    idx_max = min(idx_min + self.batch_size, len(self.list_IDs))
    indexes = self.indexes[idx_min:idx_max]

    ...
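A minimal, self-contained sketch of a generator that combines the explicit index bounds above with batch arrays sized by the actual number of IDs in the batch (rather than self.batch_size, which is what pads the last batch in the question's __data_generation). Assumptions: the class name ArrayDataGenerator is hypothetical, samples come from in-memory arrays instead of np.load(ID), and it does not subclass keras.utils.Sequence so it runs without Keras installed; in real code you would subclass keras.utils.Sequence as in the question.

```python
import numpy as np


class ArrayDataGenerator:
    """Hypothetical sketch: a batch generator whose last batch is not padded.

    Samples come from in-memory arrays instead of np.load(ID); in Keras this
    class would subclass keras.utils.Sequence.
    """

    def __init__(self, X_all, y_all=None, batch_size=256, shuffle=False):
        self.X_all = X_all
        self.y_all = y_all            # None means test mode: return X only
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        # ceil so that the final partial batch is included
        return int(np.ceil(len(self.X_all) / self.batch_size))

    def __getitem__(self, index):
        idx_min = index * self.batch_size
        idx_max = min(idx_min + self.batch_size, len(self.X_all))
        batch = self.indexes[idx_min:idx_max]

        # Allocate len(batch) rows, not self.batch_size, so no padding rows appear
        X = np.empty((len(batch), *self.X_all.shape[1:]))
        for i, k in enumerate(batch):
            X[i] = self.X_all[k]
        if self.y_all is None:
            return X
        return X, self.y_all[batch]

    def on_epoch_end(self):
        self.indexes = np.arange(len(self.X_all))
        if self.shuffle:
            np.random.shuffle(self.indexes)


# Usage: 229431 samples with 2 hypothetical features each
gen = ArrayDataGenerator(np.zeros((229431, 2)), batch_size=256)
total = sum(len(gen[i]) for i in range(len(gen)))
print(len(gen), total)  # 897 229431
```

Shuffling defaults to False in this sketch because predictions are returned in batch order; with shuffling on, the output rows would no longer line up with the input IDs.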
