
How to remove multiple layers from pretrained ResNet50V2 Keras model

I'm trying to remove multiple layers from a pre-trained Keras model (ResNet50V2), but no matter what I do it isn't working. I've read countless questions on Stack Overflow, GitHub issues, and forum posts on this topic over the past month, and I still can't make it work, so I'll ask directly: what might I be doing wrong?

from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.models import ModelCatalog
tf = try_import_tf()

def resnet_core(x):
    x = tf.keras.applications.resnet_v2.preprocess_input(x)
    resnet = tf.keras.applications.ResNet50V2(
        include_top=False,
        weights="imagenet",
    )
    remove_n = 130
    for i in range(remove_n):
        resnet._layers.pop()
        print(len(resnet._layers))
    s = tf.keras.models.Model(resnet.input, resnet._layers[-1].output, name='resnet-core')
    for layer in s.layers:
        print('adding layer',layer.name)
    for layer in s.layers[:]:
        layer.trainable = False
    s.build(None)

    return s(x)

class ImpalaCNN(TFModelV2):

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        inputs = tf.keras.layers.Input(shape=obs_space.shape, name="observations")
        x = inputs
        x = resnet_core(x)
        x = tf.keras.layers.Flatten()(x)
        x = tf.keras.layers.ReLU()(x)
        x = tf.keras.layers.Dense(units=256, activation="relu", name="hidden")(x)

        logits = tf.keras.layers.Dense(units=num_outputs, name="pi")(x)
        value = tf.keras.layers.Dense(units=1, name="vf")(x)

        self.base_model = tf.keras.Model(inputs, [logits, value])
        self.register_variables(self.base_model.variables)

    def forward(self, input_dict, state, seq_lens):
        obs = tf.cast(input_dict["obs"], tf.float32)
        logits, self._value = self.base_model(obs)
        return logits, state

    def value_function(self):
        return tf.reshape(self._value, [-1])


# Register model in ModelCatalog
ModelCatalog.register_custom_model("impala_cnn_tf", ImpalaCNN)

The error I'm getting is:

  ...
  File "/Users/manu/anaconda3/envs/procgen/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 376, in __init__
    self._build_policy_map(policy_dict, policy_config)
  File "/Users/manu/anaconda3/envs/procgen/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 859, in _build_policy_map
    policy_map[name] = cls(obs_space, act_space, merged_conf)
  File "/Users/manu/anaconda3/envs/procgen/lib/python3.7/site-packages/ray/rllib/policy/tf_policy_template.py", line 143, in __init__
    obs_include_prev_action_reward=obs_include_prev_action_reward)
  File "/Users/manu/anaconda3/envs/procgen/lib/python3.7/site-packages/ray/rllib/policy/dynamic_tf_policy.py", line 163, in __init__
    framework="tf")
  File "/Users/manu/anaconda3/envs/procgen/lib/python3.7/site-packages/ray/rllib/models/catalog.py", line 317, in get_model_v2
    registered))
ValueError: It looks like variables {<tf.Variable 'default_policy/conv4_block4_1_conv/kernel:0' ... } were created as part of <impala_cnn_tf.ImpalaCNN object at 0x19a8ccc90> but does not appear in model.variables() ({<tf.Variable 'default_policy/pi/kernel:0' shape=(256, 15) dtype=float32> ...}). Did you forget to call model.register_variables() on the variables in question?

The error seems to indicate that some variables from the layers I'm trying to remove were never registered, but that's intentional: I don't want to use those layers at all. Any ideas?
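For reference, here is a standalone sketch (plain TF2 Keras, outside Ray/RLlib) of the mismatch I think the error is pointing at: instantiating ResNet50V2 creates all of its variables, while a model truncated at an earlier layer only reports the subset up to that layer.

# Standalone diagnostic, independent of Ray/RLlib: compare the variables
# created by instantiating ResNet50V2 with the variables a truncated
# model actually keeps.
import tensorflow as tf

resnet = tf.keras.applications.ResNet50V2(include_top=False, weights="imagenet")
truncated = tf.keras.models.Model(resnet.input, resnet.layers[-130].output)

created = {v.name for v in resnet.variables}
kept = {v.name for v in truncated.variables}
print(f"{len(created)} variables created, {len(kept)} kept in truncated model")
print("examples of dropped variables:", sorted(created - kept)[:3])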

More context in case it helps:

Thanks in advance!

Rather than popping off layers, you could try accessing the layer 130 positions from the end. Then you can build a new model using the input of your original model and the output of that layer.

model = tf.keras.models.Model(resnet.input, resnet.layers[-130].output)

This does essentially the same thing as what you tried, but it's much simpler and safer, since you aren't touching any of the model's private attributes.
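For example, resnet_core from your question could be rewritten along these lines (a sketch only; the -130 cut point and the frozen backbone are carried over from your code, and the rest of ImpalaCNN stays unchanged):

def resnet_core(x):
    # Preprocess exactly as ResNet50V2 expects.
    x = tf.keras.applications.resnet_v2.preprocess_input(x)
    resnet = tf.keras.applications.ResNet50V2(
        include_top=False,
        weights="imagenet",
    )
    # Build a new model ending 130 layers before the original output,
    # using the public `layers` list instead of the private `_layers`.
    s = tf.keras.models.Model(
        resnet.input, resnet.layers[-130].output, name="resnet-core"
    )
    # Freeze the backbone so only the layers added on top are trained.
    for layer in s.layers:
        layer.trainable = False
    return s(x)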
