Keras Functional replace intermediate layers
I would like to replace BatchNorm layers with GroupNorm in built-in Keras models, e.g. ResNet50. I'm trying to reset the nodes' layers to my new layer, but nothing changes when I call model.summary().
    import tensorflow as tf
    import tensorflow_addons as tfa
    from tensorflow.keras import layers

    model = tf.keras.applications.resnet.ResNet50(include_top=False, weights=None)
    channels = 3
    for i, layer in enumerate(model.layers[:]):
        if 'bn' in layer.name:
            inbound_nodes = layer.inbound_nodes
            outbound_nodes = layer.outbound_nodes
            new_name = layer.name.replace('bn', 'gn')
            new_layer = tfa.layers.GroupNormalization(channels)
            new_layer._name = new_name
            for j in range(len(inbound_nodes)):
                inbound_nodes[j].layer = new_layer  # set end of node to this layer
            for k in range(len(outbound_nodes)):
                new_layer.outbound_nodes.append(outbound_nodes[k])
            layer = new_layer
I've written the following code, adapting this answer to make it work for your case:
    import tensorflow as tf
    import tensorflow_addons as tfa
    from tensorflow.keras import Model

    model = tf.keras.applications.resnet.ResNet50(include_top=False, weights=None)
    model.summary()

    channels = 64

    def insert_layer_nonseq(model, layer_regex, insert_layer_factory,
                            insert_layer_name=None, position='after'):
        # Note: layer_regex is matched as a name suffix (str.endswith),
        # and insert_layer_name is unused.

        # Auxiliary dictionary to describe the network graph
        network_dict = {'input_layers_of': {}, 'new_output_tensor_of': {}}

        # Set the input layers of each layer
        for layer in model.layers:
            for node in layer._outbound_nodes:
                layer_name = node.outbound_layer.name
                if layer_name not in network_dict['input_layers_of']:
                    network_dict['input_layers_of'].update(
                        {layer_name: [layer.name]})
                else:
                    network_dict['input_layers_of'][layer_name].append(layer.name)

        # Set the output tensor of the input layer
        network_dict['new_output_tensor_of'].update(
            {model.layers[0].name: model.input})

        # Iterate over all layers after the input
        model_outputs = []
        for layer in model.layers[1:]:
            # Determine input tensors
            layer_input = [network_dict['new_output_tensor_of'][layer_aux]
                           for layer_aux in network_dict['input_layers_of'][layer.name]]
            if len(layer_input) == 1:
                layer_input = layer_input[0]

            # Replace the layer if its name matches
            if layer.name.endswith(layer_regex):
                if position == 'replace':
                    x = layer_input
                else:
                    raise ValueError('position must be: replace')

                new_layer = insert_layer_factory()
                new_layer._name = '{}_{}'.format(layer.name, new_layer.name)
                x = new_layer(x)
            else:
                x = layer(layer_input)

            # Set new output tensor (the original one, or the one of the
            # replacement layer)
            network_dict['new_output_tensor_of'].update({layer.name: x})

            # Save tensor in output list if it is an output of the initial model.
            # Use layer.name here: layer_name is a leftover from the first loop.
            if layer.name in model.output_names:
                model_outputs.append(x)

        return Model(inputs=model.inputs, outputs=model_outputs)

    def replace_layer():
        return tfa.layers.GroupNormalization(channels)

    model = insert_layer_nonseq(model, 'bn', replace_layer, position='replace')
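The bookkeeping in the function above can be easier to follow without TensorFlow in the way. Here is a minimal sketch of the same idea on a toy graph: walk the layers in order, look up each layer's inputs, and swap in a replacement wherever the name ends with the given suffix. The layer names, the `rebuild` helper, and the string representation of "tensors" are all made up for illustration and are not part of Keras.

```python
# Toy stand-in for a functional model: layer names in topological order,
# plus the names of the layers feeding each one. All names are hypothetical.
layers = ["input", "conv", "bn", "relu", "add"]
input_layers_of = {
    "conv": ["input"],
    "bn": ["conv"],
    "relu": ["bn"],
    "add": ["relu", "input"],  # a skip connection, as in a residual block
}


def rebuild(layers, input_layers_of, suffix, make_replacement):
    """Replay the graph, swapping every layer whose name ends with `suffix`."""
    # The "output tensor" of the input layer is represented by its name.
    new_output_of = {layers[0]: layers[0]}
    for name in layers[1:]:
        # Inputs come from the rebuilt outputs, not from the original graph.
        inputs = [new_output_of[src] for src in input_layers_of[name]]
        op = make_replacement(name) if name.endswith(suffix) else name
        # Describe the output as a string showing how it was computed.
        new_output_of[name] = "{}({})".format(op, ", ".join(inputs))
    return new_output_of[layers[-1]]


result = rebuild(layers, input_layers_of, "bn",
                 lambda name: name.replace("bn", "gn"))
print(result)  # add(relu(gn(conv(input))), input)
```

The point to notice is that every layer downstream of the swapped one is called again on the new output, which is why the real function returns a new model rather than modifying the old one in place.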
Note: I've changed your channels variable from 3 to 64 for the following reason.
From the documentation of the groups argument:
Integer, the number of groups for Group Normalization. Can be in the range [1, N] where N is the input dimension. The input dimension must be divisible by the number of groups. Defaults to 32.
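With 3 groups the first normalized layer, which has 64 channels, already fails this check because 64 is not divisible by 3. 64 works here, but you should choose the value most appropriate for your model.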
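To pick a group count without trial and error, the divisibility rule quoted above can be checked in plain Python first. This is a small sketch: `valid_group_counts` is a hypothetical helper, and the channel list is an assumption based on the standard ResNet50 layout, so verify it against your own model's layer output shapes.

```python
def valid_group_counts(channel_counts):
    """Return every group count that divides all of the given channel counts."""
    limit = min(channel_counts)  # groups cannot exceed the smallest channel count
    return [g for g in range(1, limit + 1)
            if all(c % g == 0 for c in channel_counts)]


# Assumed channel counts of the normalized activations in ResNet50.
resnet50_channels = [64, 128, 256, 512, 1024, 2048]

candidates = valid_group_counts(resnet50_channels)
print(candidates)       # [1, 2, 4, 8, 16, 32, 64]
print(3 in candidates)  # False
```

Any value in the printed list satisfies the constraint for every layer being replaced, so the choice between them is a modeling decision rather than a correctness one.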