Loading numpy weights in TensorFlow 2.0

I have a neural network architecture for the MNIST dataset, as follows:

import numpy as np
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras import layers as l
from tensorflow.keras.layers import Dense, Flatten

# Assumed value: MNIST has ten digit classes
num_classes = 10

def create_nn():
    """
    Function to create NN model for MNIST
    classification using the 300-100 architecture
    """
    model = Sequential()
    model.add(l.InputLayer(input_shape = (784, )))
    model.add(Flatten())
    model.add(Dense(units = 300, activation='relu', kernel_initializer = tf.initializers.GlorotUniform()))
    # model.add(l.Dropout(0.2))
    model.add(Dense(units = 100, activation='relu', kernel_initializer = tf.initializers.GlorotUniform()))
    # model.add(l.Dropout(0.1))
    model.add(Dense(units = num_classes, activation='softmax'))

    # Compile the designed NN-
    model.compile(
        loss = tf.keras.losses.categorical_crossentropy,
        # optimizer = 'adam',
        # 'learning_rate' replaces the deprecated 'lr' argument
        optimizer = tf.keras.optimizers.Adam(learning_rate = 0.001),
        metrics = ['accuracy'])

    return model

# Instantiate a new NN model instance-
orig_model = create_nn()


# Load original weights from when designed model was initialized-
orig_model.load_weights("300_100_MNIST.h5")

type(orig_model.trainable_weights), len(orig_model.trainable_weights)
# (list, 6)

# Instantiate a new NN model instance-
pruned_model = create_nn()

# Load pruned weights AFTER pruning algorithm was applied to prune NN-
pruned_model.load_weights("300_100_Pruned_Model.h5")

Now I build a list of weights processed according to the following criterion: wherever a pruned weight is zero, keep the zero; otherwise take the original weight.

# List to store extracted weights-
weight_extracted = []

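for orig_wts, pruned_wts in zip(orig_model.trainable_weights, pruned_model.trainable_weights):
    # Convert the tf.Variable objects to NumPy arrays before comparing
    orig_arr, pruned_arr = orig_wts.numpy(), pruned_wts.numpy()
    # Keep zeros where the weight was pruned; otherwise use the original value
    weight_extracted.append(np.where(pruned_arr == 0, pruned_arr, orig_arr))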


len(weight_extracted)
# 6
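
As a quick check of the masking rule used in the loop above, here is a minimal sketch on small hand-made arrays rather than the real model weights. The names orig_w and pruned_w and their values are hypothetical; only the np.where call mirrors the code above.

```python
import numpy as np

# Hypothetical stand-ins for one layer's original and pruned weights
orig_w = np.array([[0.5, -1.2], [0.3, 0.8]])
pruned_w = np.array([[0.0, -0.9], [0.4, 0.0]])

# Where the pruned weight is exactly zero keep the zero,
# everywhere else restore the original value
masked = np.where(pruned_w == 0, pruned_w, orig_w)
print(masked)
```

The zeros of pruned_w survive in masked, and every other position holds the value from orig_w.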

How can I load the weights and biases stored in the list of NumPy arrays 'weight_extracted' into a NN defined as above?

Thanks!

1- Convert the weight_extracted list to NumPy arrays (one array per layer parameter, since the shapes differ).

2- Save these arrays to an h5 file with the h5py module.

3- Load the extracted weights again and train your network with the new weights.

These are the steps I would take; if anything here is wrong or I have misunderstood, please let me know.
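
The save-and-reload steps above could look roughly like the following sketch. To keep it free of extra dependencies, it swaps in NumPy's own .npz format in place of h5py, and it uses random arrays as hypothetical stand-ins for weight_extracted. The shapes assume the 784-300-100 network with ten output classes, and the file name is made up for illustration.

```python
import os
import tempfile
import numpy as np

# Hypothetical stand-ins with the shapes of the 784-300-100-10 network
shapes = [(784, 300), (300,), (300, 100), (100,), (100, 10), (10,)]
rng = np.random.default_rng(0)
weight_extracted = [rng.standard_normal(s).astype(np.float32) for s in shapes]

# Store each array under its own key, since the shapes differ
path = os.path.join(tempfile.mkdtemp(), "extracted_weights.npz")
np.savez(path, **{f"w{i}": w for i, w in enumerate(weight_extracted)})

# Load the arrays back in their original order
with np.load(path) as data:
    restored = [data[f"w{i}"] for i in range(len(data.files))]

print([w.shape for w in restored])
```

In Keras, a list like restored can be passed straight to model.set_weights(restored), provided its order and shapes match model.get_weights(). That call is not exercised in this sketch, and with it the round trip through a file may not be needed at all.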
