
Teacher-Student System: Training Student with Top-k Hypotheses List

I want to configure a teacher-student system in which a teacher seq2seq model generates a top-k list of hypotheses, which are then used to train a student seq2seq model.

My plan is to batch the teacher hypotheses. This means the teacher outputs a tensor whose batch axis has length k * B, where B is the length of the input batch axis. The output batch tensor then contains k hypotheses for each sequence in the input batch tensor, sorted by the position of the associated input sequence in the input batch.
This tensor is set as the student's training target. However, the student's batch tensor still has a batch axis of length B, so I use tf.repeat to repeat the sequences in the output tensor of the student's encoder k times before feeding that tensor into the student's decoder.
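The ordering matters here. tf.repeat repeats each batch entry k times consecutively, whereas tiling repeats the whole batch. The following pure-Python sketch uses plain lists in place of tensors, and the helper names repeat_batch and tile_batch are hypothetical. Assuming the teacher output is sorted by input position as described above, it shows that only the repeat-style layout pairs each encoder entry with the hypotheses of the same input sequence:

```python
def repeat_batch(batch, k):
    """Hypothetical helper: repeat each entry k times in place, [a, b] -> [a, a, b, b] (like tf.repeat)."""
    return [entry for entry in batch for _ in range(k)]


def tile_batch(batch, k):
    """Hypothetical helper: repeat the whole batch k times, [a, b] -> [a, b, a, b] (like tf.tile)."""
    return [entry for _ in range(k) for entry in batch]


B, k = 3, 2
encoder_out = [f"enc-{b}" for b in range(B)]
# Teacher hypotheses: k per input sequence, sorted by input position (assumed layout).
teacher_hyps = [f"hyp-{b}-{j}" for b in range(B) for j in range(k)]

for enc, hyp in zip(repeat_batch(encoder_out, k), teacher_hyps):
    # The input index of the encoder entry matches the input index of the hypothesis.
    assert enc.split("-")[1] == hyp.split("-")[1]
```

With tile_batch instead, the second row would pair enc-1 with hyp-0-1, i.e. the encoder state of the wrong input sequence.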

For debugging purposes, I have simplified this for now to repeating the teacher's single best hypothesis. I will implement the top-k list selection afterwards.

Here is a summary of my config file:

[...]

# Variables:

student_target = "teacher_hypotheses_stack"

[...]

# Custom repeat function:

def repeat(source, src_name="source", **kwargs):
    import tensorflow as tf

    input = source(0)
    input = tf.Print(input, [src_name, "in", input, tf.shape(input)])

    output = tf.repeat(input, repeats=3, axis=1)
    output = tf.Print(output, [src_name, "out", output, tf.shape(output)])

    return output

def repeat_t(source, **kwargs):
    return repeat(source, "teacher")


def repeat_s(source, **kwargs):
    return repeat(source, "student")


[...]

# Configuration of the teacher + repeating of its output

**teacher_network(),  # The teacher_network is an encoder-decoder seq2seq model. The teacher performs search during training and is not trainable.
"teacher_stack": {
    "class": "eval", "from": ["teacher_decision"], "eval": repeat_t,
    "trainable": False
    # "register_as_extern_data": "teacher_hypotheses_stack"
},
"teacher_stack_reinterpreter": {  # This is an attempt to explicitly (re-)select the batch axis. It is probably unnecessary...
    "class": "reinterpret_data",
    "set_axes": {"B": 1, "T": 0},
    "enforce_time_major": True,
    "from": ["teacher_stack"],
    "trainable": False,
    "register_as_extern_data": "teacher_hypotheses_stack"
}

[...]

# Repeating of the student's encoder output + configuration of its decoder

"student_encoder": {"class": "copy", "from": ["student_lstm6_fw", "student_lstm6_bw"]},  # dim: EncValueTotalDim
"student_encoder_repeater": {"class": "eval", "from": ["student_encoder"], "eval": repeat},
"student_encoder_stack": {  # This is an attempt to explicitly (re-)select the batch axis. It is probably unnecessary...
    "class": "reinterpret_data",
    "set_axes": {"B": 1, "T": 0},
    "enforce_time_major": True,
    "from": ["student_encoder_repeater"]
},

"student_enc_ctx": {"class": "linear", "activation": None, "with_bias": True, "from": ["student_encoder_stack"], "n_out": EncKeyTotalDim},  # preprocessed_attended in Blocks
"student_inv_fertility": {"class": "linear", "activation": "sigmoid", "with_bias": False, "from": ["student_encoder_stack"], "n_out": AttNumHeads},
"student_enc_value": {"class": "split_dims", "axis": "F", "dims": (AttNumHeads, EncValuePerHeadDim), "from": ["student_encoder_stack"]},  # (B, enc-T, H, D'/H)

"model1_output": {"class": "rec", "from": [], 'cheating': config.bool("cheating", False), "unit": {
    'output': {'class': 'choice', 'target': student_target, 'beam_size': beam_size, 'cheating': config.bool("cheating", False), 'from': ["model1_output_prob"], "initial_output": 0},
    "end": {"class": "compare", "from": ["output"], "value": 0},
    'model1_target_embed': {'class': 'linear', 'activation': None, "with_bias": False, 'from': ['output'], "n_out": target_embed_size, "initial_output": 0},  # feedback_input
    "model1_weight_feedback": {"class": "linear", "activation": None, "with_bias": False, "from": ["prev:model1_accum_att_weights"], "n_out": EncKeyTotalDim, "dropout": 0.3},
    "model1_s_transformed": {"class": "linear", "activation": None, "with_bias": False, "from": ["model1_s"], "n_out": EncKeyTotalDim, "dropout": 0.3},
    "model1_energy_in": {"class": "combine", "kind": "add", "from": ["base:student_enc_ctx", "model1_weight_feedback", "model1_s_transformed"], "n_out": EncKeyTotalDim},
    "model1_energy_tanh": {"class": "activation", "activation": "tanh", "from": ["model1_energy_in"]},
    "model1_energy": {"class": "linear", "activation": None, "with_bias": False, "from": ["model1_energy_tanh"], "n_out": AttNumHeads},  # (B, enc-T, H)
    "model1_att_weights": {"class": "softmax_over_spatial", "from": ["model1_energy"]},  # (B, enc-T, H)
    "model1_accum_att_weights": {"class": "eval", "from": ["prev:model1_accum_att_weights", "model1_att_weights", "base:student_inv_fertility"],
                                 "eval": "source(0) + source(1) * source(2) * 0.5", "out_type": {"dim": AttNumHeads, "shape": (None, AttNumHeads)}},
    "model1_att0": {"class": "generic_attention", "weights": "model1_att_weights", "base": "base:student_enc_value"},  # (B, H, V)
    "model1_att": {"class": "merge_dims", "axes": "except_batch", "from": ["model1_att0"]},  # (B, H*V)
    "model1_s": {"class": "rnn_cell", "unit": "LSTMBlock", "from": ["prev:model1_target_embed", "prev:model1_att"], "n_out": 1000, "dropout": 0.3},  # transform
    "model1_readout_in": {"class": "linear", "from": ["model1_s", "prev:model1_target_embed", "model1_att"], "activation": None, "n_out": 1000, "dropout": 0.3},  # merge + post_merge bias
    "model1_readout": {"class": "reduce_out", "mode": "max", "num_pieces": 2, "from": ["model1_readout_in"]},
    "model1_output_prob": {
        "class": "softmax", "from": ["model1_readout"], "dropout": 0.3,
        "target": student_target,
        "loss": "ce", "loss_opts": {"label_smoothing": 0.1}
    }
}, "target": student_target},

[...]

Running this config will print the following error message to the console:

[...]

Create Adam optimizer.
Initialize optimizer (default) with slots ['m', 'v'].
These additional variable were created by the optimizer: [<tf.Variable 'optimize/beta1_power:0' shape=() dtype=float32_ref>, <tf.Variable 'optimize/beta2_power:0' shape=() dtype=float32_ref>].
[teacher][in][[6656 6657 6658...]...][17 23]
[teacher][out][[6656 6656 6656...]...][17 69]
TensorFlow exception: assertion failed: [x.shape[0] != y.shape[0]] [69 17] [23]
     [[node objective/loss/error/sparse_labels/check_dim_equal/assert_equal_1/Assert/Assert (defined at home/philipp/Documents/bachelor-thesis/returnn/returnn-venv/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py:1748) ]]

[...]

Execute again to debug the op inputs...
FetchHelper(0): <tf.Tensor 'objective/loss/error/sparse_labels/check_dim_equal/Shape_1_1:0' shape=(1,) dtype=int32> = shape (1,), dtype int32, min/max 23/23, ([23])
FetchHelper(0): <tf.Tensor 'objective/loss/error/sparse_labels/check_dim_equal/assert_equal_1/Assert/Assert/data_0_1:0' shape=() dtype=string> = bytes(b'x.shape[0] != y.shape[0]')
FetchHelper(0): <tf.Tensor 'objective/loss/error/sparse_labels/check_dim_equal/Shape_2:0' shape=(2,) dtype=int32> = shape (2,), dtype int32, min/max 17/69, ([69 17])
FetchHelper(0): <tf.Tensor 'objective/loss/error/sparse_labels/check_dim_equal/assert_equal_1/All_1:0' shape=() dtype=bool> = bool_(False)
[teacher][in][[6656 6657 6658...]...][17 23]
[teacher][out][[6656 6656 6656...]...][17 69]
Op inputs:
  <tf.Tensor 'objective/loss/error/sparse_labels/check_dim_equal/assert_equal_1/All:0' shape=() dtype=bool>: bool_(False)
  <tf.Tensor 'objective/loss/error/sparse_labels/check_dim_equal/assert_equal_1/Assert/Assert/data_0:0' shape=() dtype=string>: bytes(b'x.shape[0] != y.shape[0]')
  <tf.Tensor 'objective/loss/error/sparse_labels/check_dim_equal/Shape:0' shape=(2,) dtype=int32>: shape (2,), dtype int32, min/max 17/69, ([69 17])
  <tf.Tensor 'objective/loss/error/sparse_labels/check_dim_equal/Shape_1:0' shape=(1,) dtype=int32>: shape (1,), dtype int32, min/max 23/23, ([23])
Step meta information:
{'seq_idx': [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22],
 'seq_tag': ['seq-0','seq-1','seq-2','seq-3','seq-4','seq-5','seq-6','seq-7','seq-8','seq-9','seq-10','seq-11','seq-12','seq-13','seq-14','seq-15','seq-16','seq-17','seq-18','seq-19','seq-20','seq-21','seq-22']}
Feed dict:
  <tf.Tensor 'extern_data/placeholders/data/data:0' shape=(?, ?, 80) dtype=float32>: shape (23, 42, 80), dtype float32, min/max -0.5/0.4, mean/stddev -0.050000004/0.28722814, Data(name='data', shape=(None, 80), batch_shape_meta=[B,T|'time:var:extern_data:data',F|80])
  <tf.Tensor 'extern_data/placeholders/data/data_dim0_size:0' shape=(?,) dtype=int32>: shape (23,), dtype int32, min/max 42/42, ([42 42 42 42 42 42 42 42 42 42 42 42 42 42 42 42 42 42 42 42 42 42 42])
  <tf.Tensor 'extern_data/placeholders/source_text/source_text:0' shape=(?, ?, 512) dtype=float32>: shape (23, 13, 512), dtype float32, min/max -0.5/0.4, mean/stddev -0.050011758/0.28722063, Data(name='source_text', shape=(None, 512), available_for_inference=False, batch_shape_meta=[B,T|'time:var:extern_data:source_text',F|512])
  <tf.Tensor 'extern_data/placeholders/source_text/source_text_dim0_size:0' shape=(?,) dtype=int32>: shape (23,), dtype int32, min/max 13/13, ([13 13 13 13 13 13 13 13 13 13 13 13 13 13 13 13 13 13 13 13 13 13 13])
  <tf.Tensor 'extern_data/placeholders/target_text/target_text:0' shape=(?, ?) dtype=int32>: shape (23, 17), dtype int32, min/max 6656/6694, Data(name='target_text', shape=(None,), dtype='int32', sparse=True, dim=35209, available_for_inference=False, batch_shape_meta=[B,T|'time:var:extern_data:target_text'])
  <tf.Tensor 'extern_data/placeholders/target_text/target_text_dim0_size:0' shape=(?,) dtype=int32>: shape (23,), dtype int32, min/max 17/17, ([17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17 17])
  <tf.Tensor 'globals/train_flag:0' shape=() dtype=bool>: bool(True)
EXCEPTION

[...]
File "home/philipp/Documents/bachelor-thesis/returnn/repository/TFUtil.py", line 4374, in sparse_labels_with_seq_lens
    x = check_dim_equal(x, 0, seq_lens, 0)
[...]

So the network is built without errors, but it crashes with an assertion error in the first training step. To me, it looks like RETURNN or TensorFlow somehow validates the batch length against its original value. But I don't know where or why, so I have no idea what to do about it.
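The numbers in the log can be reproduced with plain shape arithmetic. The sketch below uses hypothetical helpers (repeat_shape, and check_dim_equal_sketch as a simplified stand-in for the failing check, not RETURNN's check_dim_equal). The teacher output is time-major [T=17, B=23], repeat_t repeats axis 1 three times, which gives [17, 69]. After the loss transposes to batch-major, the batch dimension 69 is compared with the 23 sequence lengths, which still describe the original batch:

```python
def repeat_shape(shape, axis, repeats):
    """Hypothetical helper: shape of tf.repeat(x, repeats, axis=axis) for an integer repeats."""
    out = list(shape)
    out[axis] *= repeats
    return tuple(out)


def check_dim_equal_sketch(x_shape, x_axis, seq_lens):
    """Hypothetical simplified check: dim x_axis of x must equal the number of seq lengths."""
    if x_shape[x_axis] != len(seq_lens):
        raise ValueError(f"x.shape[0] != y.shape[0]: {list(x_shape)} {[len(seq_lens)]}")


teacher_in = (17, 23)                              # time-major [T, B], as printed by tf.Print
teacher_out = repeat_shape(teacher_in, axis=1, repeats=3)
batch_major = (teacher_out[1], teacher_out[0])     # [B, T], as seen by the loss
seq_lens = [17] * 23                               # the seq lengths still have 23 entries
try:
    check_dim_equal_sketch(batch_major, 0, seq_lens)
except ValueError as err:
    print(err)  # x.shape[0] != y.shape[0]: [69, 17] [23]
```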

What am I doing wrong? Is my idea even implementable with RETURNN this way?

EDIT (10th June 2020): To clarify: my ultimate goal is to let the teacher generate a top-k list of hypotheses for each input sequence, which are then used to train the student. So for each input sequence of the student, there are k solutions/target sequences. To train the student, it must predict the probability of each hypothesis, and then the cross-entropy loss is calculated to determine the gradients for the update. But if there are k target sequences for each input sequence, the student must decode the encoder states k times, each time targeting a different target sequence. This is why I want to repeat the encoder states k times: to make the student decoder's data parallel, so that I can use RETURNN's default cross-entropy loss implementation:

input-seq-1 --- teacher-hyp-1-1; 
input-seq-1 --- teacher-hyp-1-2; 
...; 
input-seq-1 --- teacher-hyp-1-k; 
input-seq-2 --- teacher-hyp-2-1; 
... 
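Written as a formula, the objective described in this edit is the summed cross-entropy over all k teacher hypotheses of all B input sequences, ignoring the label smoothing used in the config. The notation is introduced here for illustration: x_b is the b-th input sequence, y_{b,j} the j-th teacher hypothesis for it, and p_θ the student model.

```latex
\mathcal{L}(\theta) = -\sum_{b=1}^{B} \sum_{j=1}^{k} \log p_\theta\left(y_{b,j} \mid x_b\right),
\qquad
\log p_\theta\left(y_{b,j} \mid x_b\right) = \sum_{t=1}^{|y_{b,j}|} \log p_\theta\left(y_{b,j,t} \mid y_{b,j,<t},\, x_b\right)
```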

Is there a more proper way to achieve my goal?

EDIT (12th June 2020 #1): Yes, I know that the teacher's DecisionLayer already selects the best hypothesis, and that this way I'm only repeating that best hypothesis k times. I'm doing this as an intermediate step towards my ultimate goal. Later, I want to fetch the top-k list from the teacher's ChoiceLayer somehow, but I felt that this is a separate problem.
But Albert, you say RETURNN would somehow extend the data along the batch dimension automatically? How does that work?

EDIT (12th June 2020 #2): Okay, now I select the top-k hypotheses list (this time k=4) from the teacher's choice layer (or output layer) with:

"teacher_hypotheses": {
    "class": "copy", "from": ["extra.search:teacherMT_output"],
    "register_as_extern_data": "teacher_hypotheses_stack"
}

But using this Data as the training target of the student leads to the error:

TensorFlow exception: assertion failed: [shape[0]:] [92] [!=] [dim:] [23]
     [[node studentMT_output/rec/subnet_base/check_seq_len_batch_size/check_input_dim/assert_equal_1/Assert/Assert (defined at home/philipp/Documents/bachelor-thesis/returnn/returnn-venv/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py:1748) ]]

I assume this is because the student's target data, the hypotheses list, has a batch axis that is k=4 times longer than that of the student's input data/encoder state data. Doesn't the student's encoder state data need to be extended/repeated here to match the target data?
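The 92 in this message is exactly k * B = 4 * 23. Assume the search output stores the k hypotheses of each input sequence next to each other, which is the layout described at the top of this question. (How RETURNN lays out the beam internally is not confirmed here.) Under that assumption, the row index of a hypothesis determines its input sequence. A small sketch with hypothetical helper names:

```python
def row_to_input_and_rank(row, k):
    """Hypothetical helper: map a row of the flattened B*k batch to (input index b, rank j)."""
    return divmod(row, k)


def input_and_rank_to_row(b, j, k):
    """Hypothetical helper: inverse mapping, assuming the k hypotheses of one input are consecutive."""
    return b * k + j


B, k = 23, 4
assert B * k == 92  # batch size of the teacher hypotheses in the error message
rows = [row_to_input_and_rank(r, k) for r in range(B * k)]
print(rows[:5])  # [(0, 0), (0, 1), (0, 2), (0, 3), (1, 0)]
```

Each of the 23 student encoder states would therefore have to serve 4 consecutive target rows. Without such an extension, the student's batch of 23 cannot match the target batch of 92.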

EDIT (12th June 2020 #3): I consider the initial issue solved. The overall issue is continued in the follow-up question "Teacher-Student System: Training Student With k Target Sequences for Each Input Sequence".

It does not only validate the batch length. It collapses the batch and time axes (using flatten_with_seq_len_mask; see the code of Loss.init and of that function) and then calculates the loss on the flattened tensor. So the sequence lengths also need to match. This might be a problem, but I'm not sure. As you use the same target for the rec layer itself, it should have the same sequence length in training.
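To see why the sequence lengths matter, here is a pure-Python sketch of the idea behind flattening with a sequence-length mask. The helper name flatten_with_mask_sketch is hypothetical; RETURNN's flatten_with_seq_len_mask operates on TF tensors, and its exact signature is not reproduced here. Padded frames are dropped and the remaining frames of all sequences are concatenated, so outputs and targets only line up frame by frame if both have the same batch size and the same lengths:

```python
def flatten_with_mask_sketch(padded, seq_lens):
    """Hypothetical helper: concatenate the valid frames of each entry of a batch-major [B][T] list."""
    assert len(padded) == len(seq_lens), "batch size must match the number of seq lengths"
    flat = []
    for row, length in zip(padded, seq_lens):
        flat.extend(row[:length])  # keep only the first `length` frames, drop padding
    return flat


targets = [[5, 6, 7], [8, 9, 0]]  # 0 is padding
target_lens = [3, 2]
print(flatten_with_mask_sketch(targets, target_lens))  # [5, 6, 7, 8, 9]
```

If the output and the target are flattened with different lengths, the two flat sequences have different lengths or misaligned frames, and the per-frame loss cannot be computed.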

You can debug this by carefully checking the output of debug_print_layer_output_template, i.e. check the Data output (batch_shape_meta) to see whether all axes are as you expect them to be. (debug_print_layer_output_template can and should always be enabled; it does not slow anything down.) You can also temporarily enable debug_print_layer_output_shape, which prints the actual shape of all tensors. That way you can verify what it looks like.
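In the RETURNN config, both options are plain Python assignments. A minimal sketch that uses only the two option names mentioned above. The second one is verbose, so keep it only while debugging:

```python
# Print the Data template (including batch_shape_meta) of every layer when the network is built.
debug_print_layer_output_template = True

# Temporarily: print the actual runtime shape of every layer output (very verbose).
debug_print_layer_output_shape = True
```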

Your usage of ReinterpretDataLayer looks very wrong. You should never ever explicitly set the axes by integer (like "set_axes": {"B": 1, "T": 0}). Why are you doing this at all? This could be the reason why it is messed up in the end.
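One way this can go wrong is that reinterpreting changes only the axis metadata, not the tensor itself. The pure-Python sketch below illustrates that effect. The meta dict and reinterpret_sketch are hypothetical stand-ins for RETURNN's Data and ReinterpretDataLayer. If the tensor is actually batch-major [B=2, T=3] and you force "B" to axis 1, any code that later reads "the batch axis" gets the time axis instead:

```python
def reinterpret_sketch(meta, set_axes):
    """Hypothetical helper: return new axis metadata; the underlying data is NOT transposed."""
    new_meta = dict(meta)
    new_meta.update(set_axes)
    return new_meta


def batch_size(shape, meta):
    """Read the batch size from whichever axis the metadata claims is the batch axis."""
    return shape[meta["B"]]


shape = (2, 3)            # actually [B=2, T=3]
meta = {"B": 0, "T": 1}
forced = reinterpret_sketch(meta, {"B": 1, "T": 0})
print(batch_size(shape, meta), batch_size(shape, forced))  # 2 3
```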

Your repeat function is not very generic. You are using hard-coded axis integers there as well. You should never do that. Instead, you would write something like:

input_data = source(0, as_data=True)
input = input_data.placeholder
...
output = tf.repeat(input, repeats=3, axis=input_data.batch_dim_axis)
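RETURNN decides whether a layer output is time-major or batch-major, so a hard-coded axis=1 only happens to hit the batch axis for time-major data. The following illustration works on shapes only. The dicts are hypothetical stand-ins for Data objects, and batch_dim_axis mirrors the attribute used above:

```python
def repeat_shape(shape, axis, repeats):
    """Hypothetical helper: shape of tf.repeat(x, repeats, axis=axis) for an integer repeats."""
    out = list(shape)
    out[axis] *= repeats
    return tuple(out)


time_major = {"shape": (17, 23), "batch_dim_axis": 1}
batch_major = {"shape": (23, 17), "batch_dim_axis": 0}

for data in (time_major, batch_major):
    hard_coded = repeat_shape(data["shape"], 1, 3)                     # like axis=1
    generic = repeat_shape(data["shape"], data["batch_dim_axis"], 3)   # like axis=input_data.batch_dim_axis
    print(hard_coded, generic)
# (17, 69) (17, 69)
# (23, 51) (69, 17)
```

For batch-major data, the hard-coded version silently triples the time axis instead of the batch axis.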

Did I understand correctly that this is what you want to do, i.e. repeat along the batch axis? In that case, you also need to adapt the seq length information of the output of that layer. You cannot simply use that function as-is in an EvalLayer. You would also need to set out_type to a function that returns the correct Data template, e.g. like this:

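def repeat_out(out):
    import tensorflow as tf
    out = out.copy()
    # Repeat the seq lengths of the time axis in the same way as the data is repeated along the batch axis.
    out.size_placeholder[0] = tf.repeat(out.size_placeholder[0], axis=0, repeats=3)
    return out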

...
"student_encoder_repeater": {
    "class": "eval", "from": ["student_encoder"], "eval": repeat,
    "out_type": lambda sources, **kwargs: repeat_out(sources[0].output)
}

Now you have the additional problem that every time you call this repeat_out, you get a new seq length tensor. RETURNN will not be able to tell (at compile time) whether these seq lengths are all the same or different, and that will cause errors or strange effects. To solve this, you should reuse the same seq length, e.g. like this:

"teacher_stack_": {
    "class": "eval", "from": "teacher_decision", "eval": repeat
},
"teacher_stack": {
    "class": "reinterpret_data", "from": "teacher_stack_", "size_base": "student_encoder_repeater"
}
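Here is why reuse matters. At graph-construction time, only the seq-length objects themselves are available, not their runtime values. The sketch below mimics that with Python objects. It assumes an identity-based comparison, and make_repeated_lens is a hypothetical stand-in for one call of repeat_out. Two calls produce equal values but distinct objects, so an identity check cannot tell that they are the same. Reusing a single object, which is what size_base does in the config above, makes them trivially identical:

```python
def make_repeated_lens(seq_lens, k):
    """Hypothetical helper: each call creates a new length object (like a new tf.repeat op)."""
    return [length for length in seq_lens for _ in range(k)]


lens = [17, 12]
student_lens = make_repeated_lens(lens, 3)
teacher_lens = make_repeated_lens(lens, 3)
print(student_lens == teacher_lens)  # True: same values at runtime
print(student_lens is teacher_lens)  # False: different objects, an identity check fails

teacher_lens = student_lens          # reuse, analogous to "size_base"
print(student_lens is teacher_lens)  # True
```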

Btw, why do you want to do this repetition at all? What's the idea behind it? You repeat both the student and the teacher 3 times? So wouldn't simply increasing your learning rate by a factor of 3 do the same?
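A toy check of this remark, under two assumptions that the config does not state: the loss is summed over the batch (not averaged), and the optimizer is plain SGD. (With Adam, as used in this config, a constant gradient scale is largely normalized away.) Under these assumptions, duplicating every (input, target) pair 3 times yields exactly 3 times the gradient, i.e. the same update as a 3 times larger learning rate. The one-parameter least-squares model is a hypothetical stand-in, chosen only to make the arithmetic visible:

```python
def grad_summed_squared_error(w, pairs):
    """d/dw of sum((w*x - y)^2) over all (x, y) pairs (summed, not averaged)."""
    return sum(2 * (w * x - y) * x for x, y in pairs)


pairs = [(1.0, 2.0), (2.0, 3.0)]
repeated = [p for p in pairs for _ in range(3)]  # every pair duplicated 3 times
w, lr = 0.5, 0.01

step_repeated = lr * grad_summed_squared_error(w, repeated)
step_scaled_lr = (3 * lr) * grad_summed_squared_error(w, pairs)
print(abs(step_repeated - step_scaled_lr) < 1e-12)  # True
```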

Edit: It seems that this is done to match the top-k list. In that case, this is all wrong, as RETURNN should already do such repetition automatically. You should not do this manually.

Edit: To understand how the repetition (and beam search resolution in general) works, the first thing to look at is the log output. You must have debug_print_layer_output_template enabled for this, but you should have that enabled all the time anyway. You will see the output of each layer, especially its Data output object. This is already useful to check whether the shapes are all as you expect (check batch_shape_meta in the log). However, this is only the static shape at compile time, so the batch dim is just a marker there. You will also see the search beam information. This keeps track of whether the batch originates from a beam search (basically any ChoiceLayer), whether it has a beam, and the beam size. Now, in the code, check SearchChoices.translate_to_common_search_beam and its usages. When you follow the code, you will see SelectSearchSourcesLayer, and effectively your case will end up with output.copy_extend_with_beam(search_choices.get_beam_info()).
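Conceptually, when a layer combines an input that carries a search beam with one that carries none, the beam-less input is extended so that both have B * beam entries. The following pure-Python sketch only illustrates that idea. The names Tensorish, extend_with_beam and combine are hypothetical, this is not how translate_to_common_search_beam or copy_extend_with_beam are implemented, and storing the beam entries of one sequence consecutively is an assumption made for illustration:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Tensorish:
    rows: List[str]           # one entry per (batch[, beam]) row
    beam_size: Optional[int]  # None = no search beam attached


def extend_with_beam(t: Tensorish, beam_size: int) -> Tensorish:
    """Hypothetical: repeat every batch entry beam_size times (assumed beam-inner layout)."""
    assert t.beam_size is None
    return Tensorish([r for r in t.rows for _ in range(beam_size)], beam_size)


def combine(a: Tensorish, b: Tensorish) -> List[str]:
    """Hypothetical: pair the rows of a and b, extending whichever input has no beam."""
    if a.beam_size is None and b.beam_size is not None:
        a = extend_with_beam(a, b.beam_size)
    elif b.beam_size is None and a.beam_size is not None:
        b = extend_with_beam(b, a.beam_size)
    assert len(a.rows) == len(b.rows)
    return [f"{x}|{y}" for x, y in zip(a.rows, b.rows)]


encoder = Tensorish(["enc-0", "enc-1"], beam_size=None)
hyps = Tensorish(["hyp-0-0", "hyp-0-1", "hyp-1-0", "hyp-1-1"], beam_size=2)
print(combine(encoder, hyps))
# ['enc-0|hyp-0-0', 'enc-0|hyp-0-1', 'enc-1|hyp-1-0', 'enc-1|hyp-1-1']
```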

Edit: To repeat, this is done automatically. You do not need to call copy_extend_with_beam manually.

If you expect to get the top-k list from the teacher, you are also likely doing it wrong, as I see that you used "teacher_decision" as input. I guess this comes from a DecisionLayer? In that case, it has already taken only the first-best from the top-k beam.
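The following pure-Python sketch shows the difference between taking the decision and keeping the whole beam. The scores are made up and the data structure is hypothetical. It only illustrates that a first-best decision keeps one hypothesis per input sequence, whereas the beam keeps k:

```python
# Hypothetical beam of k=3 hypotheses per input sequence: (log-prob score, tokens). Scores are made up.
beam = {
    "seq-0": [(-1.2, "a b c"), (-0.4, "a b d"), (-2.0, "a c")],
    "seq-1": [(-0.9, "x y"), (-1.1, "x z"), (-0.3, "x y z")],
}

# Decision-style: only the best-scoring hypothesis per input sequence survives.
first_best = {seq: max(hyps)[1] for seq, hyps in beam.items()}

# Top-k style: all k hypotheses per input sequence, best first.
top_k = {seq: [h for _, h in sorted(hyps, reverse=True)] for seq, hyps in beam.items()}

print(first_best)      # {'seq-0': 'a b d', 'seq-1': 'x y z'}
print(top_k["seq-0"])  # ['a b d', 'a b c', 'a c']
```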

Edit: Now I understand that you are ignoring this and instead want to take only the first-best hypothesis and then repeat it. I would recommend not doing that. You are making it unnecessarily complicated, and you are kind of fighting against RETURNN, which knows what the batch dim should be and will get confused. (You can make it work with what I wrote above, but really, this is just unnecessarily complicated.)

Btw, there is no point in setting an EvalLayer to "trainable": False . That has no effect. The eval layer has no parameters anyway.

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 