
Overwriting methods via mixin pattern does not work as intended

I am trying to introduce a mod/mixin for a problem. In particular I am focusing here on a SpeechRecognitionProblem. I intend to modify this problem, and therefore I seek to do the following:

from tensor2tensor.data_generators import speech_recognition

class SpeechRecognitionProblemMod(speech_recognition.SpeechRecognitionProblem):

    def hparams(self, defaults, model_hparams):
        # Run the base implementation first, then adjust some values.
        speech_recognition.SpeechRecognitionProblem.hparams(self, defaults, model_hparams)
        vocab_size = self.feature_encoders(model_hparams.data_dir)['targets'].vocab_size
        p = defaults
        p.vocab_size['targets'] = vocab_size

    def feature_encoders(self, data_dir):
        # ...

So this one does not do much. It calls the hparams() function from the base class and then changes some values.

Now, there are already some ready-to-go problems, e.g. LibriSpeech:

@registry.register_problem()
class Librispeech(speech_recognition.SpeechRecognitionProblem):
    # ..

However, in order to apply my modifications I am doing this:

@registry.register_problem()
class LibrispeechMod(SpeechRecognitionProblemMod, Librispeech):
    # ..

This should, if I am not mistaken, overwrite everything in Librispeech (with identical signatures) and instead call the functions of SpeechRecognitionProblemMod.
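For reference, this relies on Python's method resolution order (MRO): because SpeechRecognitionProblemMod comes first in the base-class list, its methods shadow those of Librispeech. A minimal, self-contained sketch with toy classes (not the actual tensor2tensor ones) that demonstrates the lookup:

    class Base:
        def hparams(self):
            return 'Base.hparams'

    class Mixin(Base):
        def hparams(self):
            return 'Mixin.hparams'

    class Concrete(Base):
        pass

    class Combined(Mixin, Concrete):
        pass

    print(Combined.__mro__)      # (Combined, Mixin, Concrete, Base, object)
    print(Combined().hparams())  # 'Mixin.hparams' -- the mixin-style class wins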

Since I was able to train a model with this code, I am assuming that it is working as intended so far.

Now here comes my problem:

After training I want to serialize the model. This usually works. However, it does not work with my mod, and I actually know why:

At a certain point hparams() gets called. Debugging to that point shows me the following:

self                  # {LibrispeechMod}
self.hparams          # <bound method SpeechRecognitionProblem.hparams of ..>
self.feature_encoders # <bound method SpeechRecognitionProblemMod.feature_encoders of ..>

self.hparams should be <bound method SpeechRecognitionProblemMod.hparams of ..>! It would seem that for some reason hparams() of SpeechRecognitionProblem gets called directly instead of SpeechRecognitionProblemMod's. But please note that it is the correct type for feature_encoders()!
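One way to verify at runtime which implementation a method resolves to is to inspect the MRO and the bound methods' qualified names, e.g. (a small debugging sketch, not part of my actual code):

    # Inspect which class each bound method actually comes from.
    print(type(self).__mro__)
    print(self.hparams.__func__.__qualname__)           # expected: 'SpeechRecognitionProblemMod.hparams'
    print(self.feature_encoders.__func__.__qualname__)  # 'SpeechRecognitionProblemMod.feature_encoders'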

The thing is that I know this works during training. I can see that the hyper-parameters (hparams) are applied accordingly, simply because the model's graph node names change through my modifications.

There is one peculiarity I need to point out. tensor2tensor allows one to dynamically load a t2t_usr_dir: additional Python modules which get loaded by import_usr_dir. I make use of that function in my serialization script as well:

import logging
from tensor2tensor.utils.usr_dir import import_usr_dir

if usr_dir:
    logging.info('Loading user dir %s' % usr_dir)
    import_usr_dir(usr_dir)

This could be the only culprit I can see at the moment, although I would not be able to tell why it would cause the problem.
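For context, import_usr_dir essentially imports the given directory as a regular Python module, so that its @registry.register_problem decorators run as a side effect at import time. A simplified sketch of the idea (my own approximation, not tensor2tensor's actual implementation):

    import importlib
    import os
    import sys

    def import_usr_dir_sketch(usr_dir):
        # Make the parent directory importable, import the module (which
        # triggers the registry decorators), then restore sys.path.
        containing_dir, module_name = os.path.split(os.path.normpath(usr_dir))
        sys.path.insert(0, containing_dir)
        try:
            importlib.import_module(module_name)
        finally:
            sys.path.pop(0)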

If anybody sees something I do not, I'd be glad to get a hint about what I'm doing wrong here.


So what is the error you're getting?

For the sake of completeness, this is the result of the wrong hparams() method being called:

NotFoundError (see above for traceback): Restoring from checkpoint failed.
Key transformer/symbol_modality_256_256/softmax/weights_0 not found in checkpoint

symbol_modality_256_256 is wrong. It should be symbol_modality_<vocab-size>_256, where <vocab-size> is the vocabulary size which gets set in SpeechRecognitionProblemMod.hparams.
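To make the naming scheme concrete (the 1024 below is a hypothetical vocabulary size for illustration, not the value from my setup):

    # Hypothetical example: the modality name embeds vocab size and hidden size.
    vocab_size, hidden_size = 1024, 256  # assumed values for illustration
    key = 'transformer/symbol_modality_%d_%d/softmax/weights_0' % (vocab_size, hidden_size)
    print(key)  # transformer/symbol_modality_1024_256/softmax/weights_0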

So, this weird behavior came from the fact that I was remote debugging and the source files of the usr_dir were not correctly synchronized. Everything works as intended, but the source files were not matching.

Case closed. 案件结案。
