
Use part of a Keras Sequential model to predict

In the following code I have defined a Sequential model that contains two parts, conv_encoder and conv_decoder. After training the model, I want to use conv_encoder to predict.

How can I access the trained conv_encoder? (See the last line of the code below.)

I also want to do this from inside a function.

from tensorflow.keras.layers import Conv2D, Flatten, BatchNormalization, MaxPooling2D
from tensorflow.keras.layers import Reshape, Conv2DTranspose

import tensorflow as tf
import keras
from keras.applications.vgg16 import VGG16
from skimage.feature import hog
import pandas as pd

# network parameters
input_shape = (200, 300, 3)
batch_size = 20

# Defining a custom metric
def rounded_accuracy(y_true, y_pred):
    return keras.metrics.binary_accuracy(tf.round(y_true), tf.round(y_pred))


# Create the CAE model
def create_cae():

    # Define encoder
    conv_encoder = keras.models.Sequential([
        keras.layers.Conv2D(256, kernel_size=3, padding="SAME", activation="relu", input_shape=[200, 300, 3]),
        keras.layers.BatchNormalization(),
        keras.layers.Conv2D(128, kernel_size=3, padding="SAME", activation="relu"),
        keras.layers.MaxPool2D(pool_size=2),
        keras.layers.Conv2D(64, kernel_size=3, padding="SAME", activation="relu"),
        keras.layers.BatchNormalization(),
        keras.layers.MaxPool2D(pool_size=2),
    ])

    # Define decoder
    conv_decoder = keras.models.Sequential([
        keras.layers.Conv2DTranspose(128, kernel_size=3, strides=2, padding="SAME", activation="relu", input_shape=[50, 75, 64]),
        keras.layers.BatchNormalization(),
        keras.layers.Conv2DTranspose(256, kernel_size=3, strides=2, padding="SAME", activation="relu"),
        keras.layers.BatchNormalization(),
        keras.layers.Conv2DTranspose(3, kernel_size=3, strides=1, padding="SAME", activation="sigmoid"),
    ])

    # Define AE
    conv_ae = keras.models.Sequential([conv_encoder, conv_decoder])

    # Display the model's architecture
    conv_encoder.summary()
    conv_decoder.summary()

    # Compile the model
    conv_ae.compile(loss="mse", optimizer=keras.optimizers.Adam(),
                    metrics=[rounded_accuracy])

    return conv_ae

# Create CAE
conv_ae = create_cae()
print("New CAE model created")

history = conv_ae.fit(gaussian_noise_imgs, sample_train_imgs, epochs=5)
sample_train = conv_ae.predict(gaussian_noise_imgs)
N_bin = 50
F = encoder.predict(sample_train)  # <-- how can I access the trained conv_encoder here?

After training the model, you can retrieve the encoder from the trained model as follows:

encoder = keras.Model(inputs=conv_ae.layers[0].input, outputs=conv_ae.layers[0].layers[-1].output)
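
Because conv_ae is itself a Sequential whose first element is the conv_encoder Sequential, you can also take that sub-model directly and call predict on it. A minimal sketch, assuming conv_ae and gaussian_noise_imgs are defined as in the question:

# The first "layer" of the outer Sequential is the trained encoder sub-model.
encoder = conv_ae.layers[0]

# With the (200, 300, 3) input shape above, the encoded feature maps are 50 x 75 x 64.
features = encoder.predict(gaussian_noise_imgs)
print(features.shape)

If you need this from inside create_cae, the simplest change is to return the encoder as well, e.g. return conv_ae, conv_encoder.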

I spent a bit of time with a Gaussian-filtered input to compare against the original image, and I think what you are really after is how to reach the encoder that lives inside the function's scope. There are a few ways to do it.

You can do it with (1), reusing the returned layers to build a new model:

# Rebuild a new Sequential model from previously created / trained layers
model = tf.keras.models.Sequential()
for layer in mainQ_outputs:
    model.add(layer)

# then add a new head on top of the reused layers
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(6, activation=tf.nn.softmax))
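
Applied to the autoencoder in the question, the same idea looks roughly like this (a sketch, assuming conv_ae and gaussian_noise_imgs from the question; new_encoder is just an illustrative name):

# conv_ae.layers[0] is the inner encoder Sequential, so its .layers are the
# trained Conv2D / BatchNormalization / MaxPool2D layers with their weights.
new_encoder = tf.keras.models.Sequential()
for layer in conv_ae.layers[0].layers:
    new_encoder.add(layer)

features = new_encoder.predict(gaussian_noise_imgs)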

You can also do it with (2), a callback:

import os

class custom_callback(tf.keras.callbacks.Callback):
    # log_dir is assumed to be defined elsewhere (a TensorBoard-style log directory)
    log_write_dir = log_dir
    val_dir = os.path.join(log_dir, 'validation')
    print('val_dir: ' + val_dir)  # e.g. F:\models\weights\characters\minds\validation

    def _val_writer(self):
        # Lazily create a single summary writer for the validation directory
        if not hasattr(self, '_writers'):
            self._writers = {}
        if 'val' not in self._writers:
            self._writers['val'] = tf.summary.create_file_writer(self.val_dir)
        return self._writers['val']

    def on_epoch_end(self, epoch, logs=None):
        print(self.model.inputs)
        # [<KerasTensor: shape=(None, 1, 32, 32, 3) dtype=float32 (created by layer 'input_1')>]

        # Build a feature extractor that exposes the output of every layer of
        # the model currently being trained
        feature_extractor = tf.keras.Model(inputs=self.model.inputs,
                                           outputs=[layer.output for layer in self.model.layers])
        print(feature_extractor)
        # <keras.engine.functional.Functional object at 0x000001450EFC4670>
        input('Press AnyKey!')

callback = custom_callback()  # create an instance of the callback
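
To use it, pass the instance to fit; for example, with the model and data from the question:

history = conv_ae.fit(gaussian_noise_imgs, sample_train_imgs, epochs=5, callbacks=[callback])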

You can also do it with (3), direct layer access or references:

# (Example from a separate LSTM model.) Reshape one input sample and predict
group_1_ShoryuKen_Left = tf.reshape(group_1_ShoryuKen_Left, [1, 1, 48])
predictions = model.predict(group_1_ShoryuKen_Left)

# Access a trained layer by name (or by index) and inspect its weights
layer_1 = model.get_layer(name="LSTM_32")
# layer_1 = model.get_layer(name="Dense_64")  # , index=None
# <keras.layers.core.dense.Dense object at 0x0000023959FB24C0>
print(layer_1)
print(layer_1.get_weights()[0].shape)
print(layer_1.get_weights()[1].shape)
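
For the autoencoder in the question, the same pattern reaches the trained encoder sub-model and its weights (a sketch, assuming conv_ae from the question; inner_encoder and first_conv are illustrative names):

# The first "layer" of the outer Sequential is the encoder Sequential itself.
inner_encoder = conv_ae.get_layer(index=0)

# Its own layers (Conv2D, BatchNormalization, ...) hold the trained weights.
first_conv = inner_encoder.get_layer(index=0)
print(first_conv.get_weights()[0].shape)  # kernel shape: (3, 3, 3, 256)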

Input:

import matplotlib.pyplot as plt
import tensorflow_addons as tfa

# Read a sample image and turn it into a batch of one
image_original = plt.imread('C:\\Users\\Jirayu Kaewprateep\\Pictures\\Cats\\samples\\08_resizes.jpg')
image_original = tf.keras.preprocessing.image.img_to_array(image_original)
image_original = tf.expand_dims(image_original, 0)

# Create a blurred version of the image to use as the autoencoder input
image = tfa.image.gaussian_filter2d(image_original, (3, 3), 1.0, 'REFLECT', 0, name='gaussian_filter2d')

history = conv_ae.fit(image, image_original, epochs=5)
sample_train = conv_ae.predict(image)
N_bin = 50

# conv_encoder is the trained encoder sub-model, retrieved as described above
F = conv_encoder.predict(sample_train)
print(F)

Output:

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 conv2d (Conv2D)             (None, 184, 136, 256)     7168

 batch_normalization (BatchN  (None, 184, 136, 256)    1024
 ormalization)

 conv2d_1 (Conv2D)           (None, 184, 136, 128)     295040

 max_pooling2d (MaxPooling2D  (None, 92, 68, 128)      0
 )

 conv2d_2 (Conv2D)           (None, 92, 68, 64)        73792

 batch_normalization_1 (Batc  (None, 92, 68, 64)       256
 hNormalization)

 max_pooling2d_1 (MaxPooling  (None, 46, 34, 64)       0
 2D)

=================================================================
Total params: 377,280
Trainable params: 376,640
Non-trainable params: 640
_________________________________________________________________
Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 conv2d_transpose (Conv2DTra  (None, 92, 68, 128)      73856
 nspose)

 batch_normalization_2 (Batc  (None, 92, 68, 128)      512
 hNormalization)

 conv2d_transpose_1 (Conv2DT  (None, 184, 136, 256)    295168
 ranspose)

 batch_normalization_3 (Batc  (None, 184, 136, 256)    1024
 hNormalization)

 conv2d_transpose_2 (Conv2DT  (None, 184, 136, 3)      6915
 ranspose)

=================================================================
Total params: 377,475
Trainable params: 376,707
Non-trainable params: 768
_________________________________________________________________
rounded_accuracy:
<function rounded_accuracy at 0x0000027BDFF65160>
New CAE model created
Epoch 1/5
2022-03-09 14:03:44.440068: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8100
1/1 [==============================] - 3s 3s/step - loss: 0.1221 - accuracy: 0.3148
Epoch 2/5
1/1 [==============================] - 0s 31ms/step - loss: 0.0645 - accuracy: 0.3573
Epoch 3/5
1/1 [==============================] - 0s 31ms/step - loss: 0.0512 - accuracy: 0.4318
Epoch 4/5
1/1 [==============================] - 0s 31ms/step - loss: 0.0439 - accuracy: 0.4564
Epoch 5/5
1/1 [==============================] - 0s 31ms/step - loss: 0.0385 - accuracy: 0.4731
[[[[-0.06346278  0.04637233 -0.02261382 ...  0.12098984  0.12070955
    -0.0229062 ]
   [-0.06346278  0.02278995 -0.02261382 ...  0.09531078  0.14106335
    -0.0229062 ]
   [-0.06346278  0.02279395 -0.02261382 ...  0.09496372  0.14070679
    -0.0229062 ]
   ...
   [-0.06346278  0.02212192 -0.02261382 ...  0.09469274  0.13895339
    -0.0229062 ]
   [-0.06346278  0.02219885 -0.02261382 ...  0.09484547  0.13892566
    -0.0229062 ]
   [-0.06346278  0.02447954 -0.02261382 ...  0.09182414  0.13579121
    -0.0229062 ]]

  [[-0.06346278  0.05012855 -0.02261382 ...  0.11684484  0.13889702
    -0.0229062 ]
   [-0.06346278  0.00757721 -0.02261382 ...  0.09060126  0.15116212
    -0.0229062 ]
   [-0.06346278  0.0072485  -0.02261382 ...  0.09044676  0.15090142
    -0.0229062 ]
   ...
   [-0.06346278  0.00701569 -0.02261382 ...  0.08867537  0.14858139
    -0.0229062 ]
   [-0.06346278  0.00706555 -0.02261382 ...  0.08872169  0.14858797
    -0.0229062 ]
   [-0.06346278  0.00225088 -0.02261382 ...  0.08776829  0.14434732
    -0.0229062 ]]

  [[-0.06346278  0.05006471 -0.02261382 ...  0.11678799  0.13892303
    -0.0229062 ]
   [-0.06346278  0.00763592 -0.02261382 ...  0.09031729  0.15117022
    -0.0229062 ]
   [-0.06346278  0.00749015 -0.02261382 ...  0.08962523  0.1507562
    -0.0229062 ]
   ...
   [-0.06346278  0.00702977 -0.02261382 ...  0.08858776  0.14859283
    -0.0229062 ]
   [-0.06346278  0.00706691 -0.02261382 ...  0.0886265   0.14861022
    -0.0229062 ]
   [-0.06346278  0.00223099 -0.02261382 ...  0.08776259  0.14446442
    -0.0229062 ]]

  ...

  [[-0.06346278  0.04950596 -0.02261382 ...  0.11570842  0.13801728
    -0.0229062 ]
   [-0.06346278  0.00701035 -0.02261382 ...  0.09086572  0.15059042
    -0.0229062 ]
   [-0.06346278  0.00680097 -0.02261382 ...  0.09112303  0.15117684
    -0.0229062 ]
   ...
   [-0.06346278  0.00777188 -0.02261382 ...  0.09041971  0.15148866
    -0.0229062 ]
   [-0.06346278  0.00767549 -0.02261382 ...  0.08995029  0.15114011
    -0.0229062 ]
   [-0.06346278  0.00224462 -0.02261382 ...  0.08866782  0.14654928
    -0.0229062 ]]

  [[-0.06346278  0.0488665  -0.02261382 ...  0.11492367  0.13678226
    -0.0229062 ]
   [-0.06346278  0.00693866 -0.02261382 ...  0.08889158  0.14880367
    -0.0229062 ]
   [-0.06346278  0.00687002 -0.02261382 ...  0.08928787  0.14929299
    -0.0229062 ]
   ...
   [-0.06346278  0.00782898 -0.02261382 ...  0.0905409   0.15163733
    -0.0229062 ]
   [-0.06346278  0.00775734 -0.02261382 ...  0.09018187  0.15157235
    -0.0229062 ]
   [-0.06346278  0.00236158 -0.02261382 ...  0.08874947  0.14714159
    -0.0229062 ]]

  [[-0.06346278  0.04358848 -0.02261382 ...  0.1652938   0.13935538
    -0.0229062 ]
   [-0.06346278 -0.00824893 -0.02261382 ...  0.1715379   0.1541569
    -0.0229062 ]
   [-0.06346278 -0.00822318 -0.02261382 ...  0.17140017  0.15413304
    -0.0229062 ]
   ...
   [-0.06346278 -0.00843395 -0.02261382 ...  0.1756423   0.1577251
    -0.0229062 ]
   [-0.06346278 -0.00859166 -0.02261382 ...  0.17557691  0.15755029
    -0.0229062 ]
   [-0.06346278 -0.01292683 -0.02261382 ...  0.18822822  0.1533695
    -0.0229062 ]]]]
...

... (screenshot)
