
How do I use tensor2tensor's distillation.py to distill the knowledge from a teacher network to a student network?

Top Level Problem

I want to take a teacher network and distill its performance/knowledge into another, simpler model that has only a fraction of its capacity.

Attempted solution

I am trying to get started with the T2T distillation code: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/distillation.py

Problems with the attempted solution

I am struggling to understand how to use it for my teacher and student models. Are there any examples showing how it works that I could get running first? How do I get it to work on an existing model in T2T? How do I get it to work on a model defined in Keras?

I understand that this block is what I am supposed to use to register my model in T2T, but how do I start the training? I searched GitHub for distill_resnet_32_to_15_cifar20x5, and the only hits were duplicate forks of the T2T repo, with no examples of how to use it.

@registry.register_hparams
def distill_resnet_32_to_15_cifar20x5():
  """Set of hyperparameters."""
  hparams = distill_base()
  hparams.teacher_model = "resnet"
  hparams.teacher_hparams = "resnet_cifar_32"
  hparams.student_model = "resnet"
  hparams.student_hparams = "resnet_cifar_15"

  hparams.optimizer_momentum_nesterov = True
  # (base_lr=0.1) * (batch_size=128*8 (on TPU, or 8 GPUs)=1024) / (256.)
  hparams.teacher_learning_rate = 0.25 * 128. * 8. / 256.
  hparams.student_learning_rate = 0.2 * 128. * 8. / 256.
  hparams.learning_rate_decay_scheme = "piecewise"
  hparams.add_hparam("learning_rate_boundaries", [40000, 60000, 80000])
  hparams.add_hparam("learning_rate_multiples", [0.1, 0.01, 0.001])

  hparams.task_balance = 0.28
  hparams.distill_temperature = 2.0

  hparams.num_classes = 20

  return hparams
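For what it's worth, my understanding is that a custom teacher/student pairing would be registered the same way. Below is a sketch under that assumption; distill_resnet_to_my_student and my_student_hparams are placeholder names I made up, not existing T2T registrations:

from tensor2tensor.models import distillation
from tensor2tensor.utils import registry


@registry.register_hparams
def distill_resnet_to_my_student():
  """Hypothetical hparams set: distill resnet_cifar_32 into a custom student."""
  hparams = distillation.distill_base()
  hparams.teacher_model = "resnet"
  hparams.teacher_hparams = "resnet_cifar_32"      # an existing T2T hparams set
  hparams.student_model = "resnet"                 # or any other registered T2T model
  hparams.student_hparams = "my_student_hparams"   # must be a registered hparams set
  hparams.task_balance = 0.5          # weight on hard-label loss vs. distillation loss
  hparams.distill_temperature = 2.0   # softens the teacher logits before matching
  hparams.num_classes = 10            # must match the problem's label count
  return hparams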

Full Code: distillation.py

# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Traditional Student-Teacher Distillation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensor2tensor.layers import common_hparams
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model

import tensorflow as tf


@registry.register_model
class Distillation(t2t_model.T2TModel):
  """Distillation from a teacher to student network.
  First, a teacher is trained on a task; Second, a student is trained to perform
  the task while matching the teacher's softened outputs. For more details, see
  the paper below.
  In the hparams passed to this model include the desired
  {teacher/student}_model and {teacher/student}_hparams to be used. Also,
  specify the distillation temperature and task-distillation balance.
  Distilling the Knowledge in a Neural Network
  Hinton, Vinyals and Dean
  https://arxiv.org/abs/1503.02531
  """

  def __init__(self,
               hparams,
               mode=tf.estimator.ModeKeys.TRAIN,
               problem_hparams=None,
               data_parallelism=None,
               decode_hparams=None):
    assert hparams.distill_phase in ["train", "distill"]

    if hparams.distill_phase == "train" and hparams.teacher_learning_rate:
      hparams.learning_rate = hparams.teacher_learning_rate
    elif hparams.distill_phase == "distill" and hparams.student_learning_rate:
      hparams.learning_rate = hparams.student_learning_rate

    self.teacher_hparams = registry.hparams(hparams.teacher_hparams)
    self.teacher_model = registry.model(
        hparams.teacher_model)(self.teacher_hparams, mode, problem_hparams,
                               data_parallelism, decode_hparams)
    self.student_hparams = registry.hparams(hparams.student_hparams)
    self.student_model = registry.model(
        hparams.student_model)(self.student_hparams, mode, problem_hparams,
                               data_parallelism, decode_hparams)
    super(Distillation, self).__init__(hparams, mode, problem_hparams,
                                       data_parallelism, decode_hparams)

  def body(self, features):
    hp = self.hparams
    is_distill = hp.distill_phase == "distill"

    targets = features["targets_raw"]
    targets = tf.squeeze(targets, [1, 2, 3])
    one_hot_targets = tf.one_hot(targets, hp.num_classes, dtype=tf.float32)

    # Teacher Network
    with tf.variable_scope("teacher"):
      teacher_outputs = self.teacher_model.body(features)
      tf.logging.info("teacher output shape: %s" % teacher_outputs.get_shape())
      teacher_outputs = tf.reduce_mean(teacher_outputs, axis=[1, 2])
      teacher_logits = tf.layers.dense(teacher_outputs, hp.num_classes)

      teacher_task_xent = tf.nn.softmax_cross_entropy_with_logits_v2(
          labels=one_hot_targets, logits=teacher_logits)
      outputs = teacher_logits

    if is_distill:
      # Load teacher weights
      tf.train.init_from_checkpoint(hp.teacher_dir, {"teacher/": "teacher/"})
      # Do not train the teacher
      trainable_vars = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
      del trainable_vars[:]

    # Student Network
    if is_distill:
      with tf.variable_scope("student"):
        student_outputs = self.student_model.body(features)
        tf.logging.info(
            "student output shape: %s" % student_outputs.get_shape())
        student_outputs = tf.reduce_mean(student_outputs, axis=[1, 2])
        student_logits = tf.layers.dense(student_outputs, hp.num_classes)

        student_task_xent = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=one_hot_targets, logits=student_logits)
        teacher_targets = tf.nn.softmax(teacher_logits / hp.distill_temperature)
        student_distill_xent = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=tf.stop_gradient(teacher_targets), logits=student_logits)

        outputs = student_logits

        # Summaries
        tf.summary.scalar("distill_xent", student_distill_xent)

    if not is_distill:
      phase_loss = teacher_task_xent
    else:
      phase_loss = hp.task_balance * student_task_xent
      phase_loss += (1 - hp.task_balance) * student_distill_xent

    losses = {"training": phase_loss}
    outputs = tf.reshape(outputs, [-1, 1, 1, 1, outputs.shape[1]])

    return outputs, losses

  def top(self, body_output, features):
    return body_output


def distill_base():
  """Set of hyperparameters."""
  # Base
  hparams = common_hparams.basic_params1()

  # teacher/student parameters
  hparams.add_hparam("teacher_model", "")
  hparams.add_hparam("teacher_hparams", "")
  hparams.add_hparam("student_model", "")
  hparams.add_hparam("student_hparams", "")

  # Distillation parameters
  # WARNING: distill_phase hparam will be overwritten in /bin/t2t_distill.py
  hparams.add_hparam("distill_phase", None)
  hparams.add_hparam("task_balance", 1.0)
  hparams.add_hparam("distill_temperature", 1.0)
  hparams.add_hparam("num_classes", 10)

  # Optional Phase-specific hyperparameters
  hparams.add_hparam("teacher_learning_rate", None)
  hparams.add_hparam("student_learning_rate", None)

  # Training parameters (stolen from ResNet)
  hparams.batch_size = 128
  hparams.optimizer = "Momentum"
  hparams.optimizer_momentum_momentum = 0.9
  hparams.optimizer_momentum_nesterov = True
  hparams.weight_decay = 1e-4
  hparams.clip_grad_norm = 0.0
  # (base_lr=0.1) * (batch_size=128*8 (on TPU, or 8 GPUs)=1024) / (256.)
  hparams.learning_rate = 0.4
  hparams.learning_rate_decay_scheme = "cosine"
  # For image_imagenet224, 120k training steps, which effectively makes this a
  # cosine decay (i.e. no cycles).
  hparams.learning_rate_cosine_cycle_steps = 120000
  hparams.initializer = "normal_unit_scaling"
  hparams.initializer_gain = 2.

  return hparams


@registry.register_hparams
def distill_resnet_32_to_15_cifar20x5():
  """Set of hyperparameters."""
  hparams = distill_base()
  hparams.teacher_model = "resnet"
  hparams.teacher_hparams = "resnet_cifar_32"
  hparams.student_model = "resnet"
  hparams.student_hparams = "resnet_cifar_15"

  hparams.optimizer_momentum_nesterov = True
  # (base_lr=0.1) * (batch_size=128*8 (on TPU, or 8 GPUs)=1024) / (256.)
  hparams.teacher_learning_rate = 0.25 * 128. * 8. / 256.
  hparams.student_learning_rate = 0.2 * 128. * 8. / 256.
  hparams.learning_rate_decay_scheme = "piecewise"
  hparams.add_hparam("learning_rate_boundaries", [40000, 60000, 80000])
  hparams.add_hparam("learning_rate_multiples", [0.1, 0.01, 0.001])

  hparams.task_balance = 0.28
  hparams.distill_temperature = 2.0

  hparams.num_classes = 20

  return hparams

Have a look at the training script here: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/bin/t2t_distill.py

To run the example you should just need to check out the T2T repo and run: python bin/t2t_distill.py --model=distillation --hparams=distill_resnet_32_to_15_cifar20x5 --problem=image_cifar100 etc.
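For a complete run you will presumably also need the usual data and output directories from t2t-trainer, e.g. something like python bin/t2t_distill.py --model=distillation --hparams=distill_resnet_32_to_15_cifar20x5 --problem=image_cifar100 --data_dir=/tmp/t2t_data --output_dir=/tmp/t2t_train/distill_cifar100 (the paths are placeholders, and depending on your T2T version the hparams set name may need to be passed as --hparams_set rather than --hparams).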

My apologies for the lack of documentation; I would welcome your contribution toward cleaning things up and making it easier for people to figure out how to use this.
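For what it's worth, the mechanics are all in the model body you quoted: t2t_distill.py overwrites distill_phase (see the WARNING comment in distill_base), first training the teacher alone, then restoring it from hp.teacher_dir (presumably also supplied by the script), clearing the trainable-variable collection so the teacher stays frozen, and training the student. The student loss is the task_balance-weighted mix of the hard-label cross-entropy and the cross-entropy against the temperature-softened teacher probabilities. A tiny self-contained NumPy sketch of that combination (not T2T code, just the formula as I read it from the body above):

import numpy as np

def softmax(x, axis=-1):
  # Numerically stable softmax.
  e = np.exp(x - x.max(axis=axis, keepdims=True))
  return e / e.sum(axis=axis, keepdims=True)

def distill_loss(student_logits, teacher_logits, one_hot_targets,
                 task_balance=0.28, temperature=2.0):
  """Mix of hard-label xent and xent against temperature-softened teacher probs."""
  student_probs = softmax(student_logits)
  # Only the teacher logits are divided by the temperature, as in the model body.
  teacher_targets = softmax(teacher_logits / temperature)
  task_xent = -np.sum(one_hot_targets * np.log(student_probs), axis=-1)
  distill_xent = -np.sum(teacher_targets * np.log(student_probs), axis=-1)
  return task_balance * task_xent + (1.0 - task_balance) * distill_xent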
