TF Keras: how to get the expected input shape when loading a model?

Is it possible to get the expected input shape from a 'model.h5' file? I have two models for the same dataset but with different options and shapes. The first one expects an input of shape (None, 64, 48, 1) and the second model needs an input of shape (None, 128, 96, 3). (Note: the width and height are not fixed and can change when I train again.) The channels problem was easy to "fix" (or rather bypass) with try/except, because there are only two options (1 for a grayscale image and 3 for an RGB image):

        channels = self.df["channels"][0]
        file = ""
        try:
            images, src_images, data = self.get_images()
            images = self.preprocess_data(images, channels)
            predictions, file = self.load_model(images, file)
            self.predict_data(src_images, predictions, data)
        except Exception:
            # The first attempt failed, presumably on a channel mismatch,
            # so retry once with the other channel count.
            channels = 3 if channels == 1 else 1
            print("Except channels =", channels)
            images, src_images, data = self.get_images()
            images = self.preprocess_data(images, channels)
            # Unpack both return values, as in the try block above.
            predictions, file = self.load_model(images, file)
            self.predict_data(src_images, predictions, data)

This workaround cannot be used for the width and height of an image, however, because there is a basically unlimited number of options. Besides, it is rather slow, because I read and preprocess all the data twice for no reason.

Is there a way to load the model.h5 file and print the expected input shape in a form like this?

[None, 128, 96, 3]

I finally found the answer myself.

from tensorflow import keras
model = keras.models.load_model("model.h5")
config = model.get_config()  # dict describing the model's architecture
print(config["layers"][0]["config"]["batch_input_shape"])  # tuple: batch size, the two image dimensions, channels

This will output the following:

(None, 128, 96, 3)
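
The lookup above can be wrapped in a small helper so that the channel count is known before any data is read. The following is only a sketch: the function name `input_shape_from_config` and the hand-written `example_config` dict are made up for illustration, and the fallback key `"batch_shape"` is an assumption about how newer Keras releases name this entry.

```python
from typing import Optional, Tuple

Shape = Tuple[Optional[int], ...]


def input_shape_from_config(config: dict) -> Shape:
    """Return the input shape stored in the first layer of a model config.

    `config` is assumed to be the dict returned by `model.get_config()`.
    """
    first_layer = config["layers"][0]["config"]
    # Older tf.keras versions store the shape under "batch_input_shape";
    # newer Keras releases are assumed to use "batch_shape" instead.
    for key in ("batch_input_shape", "batch_shape"):
        if key in first_layer:
            return tuple(first_layer[key])
    raise KeyError("first layer does not declare an input shape")


# Hand-written stand-in for model.get_config(), for illustration only.
example_config = {
    "name": "sequential",
    "layers": [
        {
            "class_name": "InputLayer",
            "config": {"batch_input_shape": (None, 128, 96, 3)},
        },
    ],
}

shape = input_shape_from_config(example_config)
# The first entry is the batch size, the last one the channel count.
_, dim1, dim2, channels = shape
print(shape)     # (None, 128, 96, 3)
print(channels)  # 3
```

With a real model, the call would be `input_shape_from_config(model.get_config())`. For a model with a single input, the `model.input_shape` attribute should report the same tuple without going through the config. Either way, `channels` and the two image dimensions can then be passed to the preprocessing step up front, so the data only has to be read and preprocessed once.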
