
Can't instantiate a Keras model when batch_normalization is used

I'm not sure what I'm doing wrong. I'm following code from a book to create a GAN model, and during instantiation the Python shell just freezes. The code below is a subset of the book's code, but the full book code also fails to create a model.

However, if I comment out the BatchNormalization layers, I can instantiate the model.

The book's code is here:

https://github.com/PacktPublishing/Advanced-Deep-Learning-with-Keras/blob/master/chapter4-gan/dcgan-mnist-4.2.1.py

Keras documentation for BatchNormalization: https://keras.io/layers/normalization/

    from keras.layers import Activation, Dense, Input
    from keras.layers import Conv2D, Flatten
    from keras.layers import Reshape, Conv2DTranspose
    from keras.layers import LeakyReLU
    from keras.layers import BatchNormalization
    from keras.optimizers import RMSprop
    from keras.models import Model
    from keras.datasets import mnist
    from keras.models import load_model
    import keras

    import numpy as np
    import math
    import matplotlib.pyplot as plt
    import os
    import argparse


    def generator_model(inputs, image_size, verbose=True):
        """Generator model

        args
        =======
        inputs = input layer
        image_size = size of one image dimension (e.g. 28 for MNIST);
                     must be divisible by 4

        """

        # The starting size depends on how many stride-2 Conv2DTranspose
        # layers follow: two of them upsample by 4 in total

        print("build generator model")

        image_resize = image_size // 4
        kernel_size = 5
        layer_filters = [128, 64]  # first two transposed convs (stride 2)
        final_layer_filters = [32, 1]  # last two transposed convs (stride 1)

        x = inputs
        x = Dense(image_resize * image_resize * layer_filters[0])(x)
        x = Reshape((image_resize, image_resize, layer_filters[0]))(x)
        print(x)

        for filter_ in layer_filters:
            x = BatchNormalization()(x)
            x = Activation('relu')(x)
            x = Conv2DTranspose(filters=filter_,
                                kernel_size=kernel_size,
                                strides=2,
                                padding='same')(x)

        print("built first part")
        for filter_ in final_layer_filters:
            x = BatchNormalization()(x)
            x = Activation('relu')(x)
            x = Conv2DTranspose(filters=filter_,
                                kernel_size=kernel_size,
                                strides=1,
                                padding='same')(x)

        x = Activation('sigmoid')(x)
        print("finised building")
        generator = Model(inputs, x, name='generator')
        if verbose:
            # summary() prints the table itself and returns None
            print(generator.summary())
        return generator


    print(keras.__version__)  # 2.2.4
    z_size = 100
    img_size = 28
    gen_input = Input(shape=(z_size,), name='gen_input')
    generator = generator_model(gen_input, img_size)
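As a cross-check of the layer shapes (not part of the original question), here is a pure-Python trace of the generator. It assumes Keras's rule that a Conv2DTranspose with `padding='same'` and no `output_padding` produces an axis length of input length × stride; BatchNormalization and Activation leave the shape unchanged. The helper names are hypothetical.

```python
def conv2d_transpose_same_length(size, stride):
    """Output length of one Conv2DTranspose axis with padding='same' (no output_padding)."""
    return size * stride


def trace_generator_shapes(image_size, layer_filters=(128, 64),
                           final_layer_filters=(32, 1)):
    """Return (layer description, output shape without batch axis) pairs
    mirroring generator_model."""
    image_resize = image_size // 4  # two stride-2 layers multiply the size by 4
    shapes = [("Dense", (image_resize * image_resize * layer_filters[0],)),
              ("Reshape", (image_resize, image_resize, layer_filters[0]))]
    size = image_resize
    for stride, filters_list in ((2, layer_filters), (1, final_layer_filters)):
        for filters in filters_list:
            # BatchNormalization and Activation before each conv keep the shape
            size = conv2d_transpose_same_length(size, stride)
            shapes.append((f"Conv2DTranspose stride={stride}", (size, size, filters)))
    return shapes


for name, shape in trace_generator_shapes(28):
    print(f"{name:28} {shape}")
```

For `image_size=28` this gives 7 → 14 → 28, matching the `(?, 7, 7, 128)` Reshape tensor printed below. With a size that is not divisible by 4, such as 30, the last shape is still `(28, 28, 1)`, so the generator output would not match the intended image size.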

The shell prints the following and then stalls: the process is still running, but the script never finishes.

    2.2.4
    build generator model
    Tensor("reshape_1/Reshape:0", shape=(?, 7, 7, 128), dtype=float32)
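A minimal diagnostic sketch, not from the original thread: the log stops after the `Reshape` tensor is printed and never reaches `built first part`, so the process is stuck somewhere in the first loop, most likely in the first `BatchNormalization()(x)` call given that removing BatchNormalization avoids the hang. The standard-library `faulthandler` module can print every thread's stack if a call is still running after a timeout, which shows where it is blocked. `call_with_hang_report` is a hypothetical helper name.

```python
import faulthandler
import sys


def call_with_hang_report(fn, *args, timeout=60.0, report_file=None, **kwargs):
    """Call fn(*args, **kwargs). If it is still running after `timeout`
    seconds (and every `timeout` seconds after that), dump the stack of
    every thread to report_file (sys.stderr by default)."""
    out = sys.stderr if report_file is None else report_file
    # faulthandler writes from a separate watchdog thread, so it can report
    # while the main thread is blocked inside a long-running call
    faulthandler.dump_traceback_later(timeout, repeat=True, file=out)
    try:
        return fn(*args, **kwargs)
    finally:
        # stop the watchdog whether fn returned or raised
        faulthandler.cancel_dump_traceback_later()
```

For example: `generator = call_with_hang_report(generator_model, gen_input, img_size, timeout=30)`. `faulthandler` needs a file with a real file descriptor, so in a shell such as IDLE, where `sys.stderr` may not be backed by one, pass an open file instead (for example `report_file=open("hang.log", "w")`).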

I tried your code in Google Colab, and it produced the output below, so I don't think the problem is in the code. You may want to check something else, such as your setup.
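A small sketch for that check, not part of the original answer: record the interpreter and package versions in both environments (local and Colab) and compare them. It uses the standard-library `importlib.metadata` (Python 3.8+); `environment_report` is a hypothetical helper, and the distribution names are assumptions (a GPU build of TensorFlow may be installed as `tensorflow-gpu`, for example).

```python
import platform
from importlib import metadata  # standard library, Python 3.8+


def environment_report(packages=("keras", "tensorflow", "tensorflow-gpu", "numpy")):
    """Map 'python' and each distribution name to its installed version,
    or to None when no distribution of that name is found."""
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    for name, version in environment_report().items():
        print(f"{name:15} {version or 'not found'}")
```

Running this in both places and diffing the results narrows down which part of the setup differs; the Colab run in this answer reports Keras 2.2.4 on the TensorFlow backend.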

    Using TensorFlow backend.
    2.2.4
    build generator model
    WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
    Instructions for updating:
    Colocations handled automatically by placer.
    Tensor("reshape_1/Reshape:0", shape=(?, 7, 7, 128), dtype=float32)
    built first part
    finised building
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #
    =================================================================
    gen_input (InputLayer)       (None, 100)               0
    _________________________________________________________________
    dense_1 (Dense)              (None, 6272)              633472
    _________________________________________________________________
    reshape_1 (Reshape)          (None, 7, 7, 128)         0
    _________________________________________________________________
    batch_normalization_1 (Batch (None, 7, 7, 128)         512
    _________________________________________________________________
    activation_1 (Activation)    (None, 7, 7, 128)         0
    _________________________________________________________________
    conv2d_transpose_1 (Conv2DTr (None, 14, 14, 128)       409728
    _________________________________________________________________
    batch_normalization_2 (Batch (None, 14, 14, 128)       512
    _________________________________________________________________
    activation_2 (Activation)    (None, 14, 14, 128)       0
    _________________________________________________________________
    conv2d_transpose_2 (Conv2DTr (None, 28, 28, 64)        204864
    _________________________________________________________________
    batch_normalization_3 (Batch (None, 28, 28, 64)        256
    _________________________________________________________________
    activation_3 (Activation)    (None, 28, 28, 64)        0
    _________________________________________________________________
    conv2d_transpose_3 (Conv2DTr (None, 28, 28, 32)        51232
    _________________________________________________________________
    batch_normalization_4 (Batch (None, 28, 28, 32)        128
    _________________________________________________________________
    activation_4 (Activation)    (None, 28, 28, 32)        0
    _________________________________________________________________
    conv2d_transpose_4 (Conv2DTr (None, 28, 28, 1)         801
    _________________________________________________________________
    activation_5 (Activation)    (None, 28, 28, 1)         0
    =================================================================
    Total params: 1,301,505
    Trainable params: 1,300,801
    Non-trainable params: 704
    _________________________________________________________________
    None
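The parameter counts in that summary can be reproduced by hand, which also shows that the BatchNormalization layers were built normally in Colab. This is a sketch with hypothetical helper names, assuming every Dense/Conv2DTranspose layer has a bias and BatchNormalization uses its default `center=True, scale=True`: each BatchNormalization layer stores four values per channel, and the 704 non-trainable parameters are the moving means and variances, 2 × (128 + 128 + 64 + 32).

```python
def dense_params(n_in, n_out):
    return n_in * n_out + n_out  # weight matrix + bias vector


def conv2d_transpose_params(kernel_size, c_in, c_out):
    return kernel_size * kernel_size * c_in * c_out + c_out  # kernel + bias


def batchnorm_params(channels):
    # gamma and beta are trainable; moving_mean and moving_variance are not
    return 2 * channels, 2 * channels


def generator_param_counts(z_size=100, image_resize=7, kernel_size=5,
                           filters=(128, 64, 32, 1)):
    """Return (total, trainable, non_trainable) for the generator above."""
    trainable = dense_params(z_size, image_resize * image_resize * filters[0])
    non_trainable = 0
    c_in = filters[0]  # channels coming out of the Reshape layer
    for c_out in filters:
        # each BatchNormalization normalizes the c_in channels fed to the next conv
        bn_trainable, bn_fixed = batchnorm_params(c_in)
        trainable += bn_trainable + conv2d_transpose_params(kernel_size, c_in, c_out)
        non_trainable += bn_fixed
        c_in = c_out
    return trainable + non_trainable, trainable, non_trainable


print(generator_param_counts())  # (1301505, 1300801, 704)
```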

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint this page, please credit the site URL or the original address. For questions, contact: yoyou2525@163.com.

© 2020-2024 STACKOOM.COM