
Python inheritance and keyword arguments

I'm writing a wrapper for pytorch transformers. To keep it simple, I'll include a minimal example: a class Parent that will be an abstract class for the classes My_bert(Parent) and My_gpt2(Parent). Because the LM models model_Bert and model_gpt2 are both included in pytorch transformers, they have many similar functions, so I want to minimize code redundancy by implementing the otherwise identical functions once in Parent.

My_bert and My_gpt2 differ essentially only in how the model is initialized and in one argument passed to the model; 99% of the functions use both models in exactly the same way.

The problem is the "model" call itself, which accepts different keyword arguments:

  • for model_Bert it is defined as model(input_ids, masked_lm_labels)
  • for model_gpt2 it is defined as model(input_ids, labels)

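To see why one call site cannot serve both signatures, here is a minimal sketch using two hypothetical stand-in functions (bert_forward and gpt2_forward are illustrative names, not PyTorch APIs). Python raises TypeError when a keyword argument does not match any declared parameter, so calling the GPT-2 signature with BERT's keyword names fails.

```python
# Hypothetical stand-ins for the two model signatures listed above
def bert_forward(input_ids, masked_lm_labels):
    return ("bert", input_ids, masked_lm_labels)

def gpt2_forward(input_ids, labels):
    return ("gpt2", input_ids, labels)

def call_with_bert_kwargs(fn, text):
    """Call fn the way Parent.compute_model does: with BERT's keyword names."""
    try:
        return fn(input_ids=text, masked_lm_labels=text)
    except TypeError as exc:
        # gpt2_forward declares no 'masked_lm_labels' parameter
        return f"TypeError: {exc}"

print(call_with_bert_kwargs(bert_forward, "bar"))   # ('bert', 'bar', 'bar')
print(call_with_bert_kwargs(gpt2_forward, "rawr"))  # a TypeError message mentioning 'masked_lm_labels'
```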
Minimal code example:

class Parent():
    """ My own class that is an abstract class for My_bert and My_gpt2 """
    def __init__(self):
        pass

    def fancy_arithmetic(self, text):
        print("do_fancy_stuff_that_works_identically_for_both_models(text=text)")

    def compute_model(self, text):
        return self.model(input_ids=text, masked_lm_labels=text) #this line works for My_Bert
        #return self.model(input_ids=text, labels=text) #I'd need this line for My_gpt2

class My_bert(Parent): 
    """ My own My_bert class that is initialized with BERT pytorch 
    model (here model_bert), and uses methods from Parent """
    def __init__(self):
        self.model = model_bert()

class My_gpt2(Parent):
    """ My own My_gpt2 class that is initialized with gpt2 pytorch model (here model_gpt2), and uses methods from Parent """
    def __init__(self):
        self.model = model_gpt2()

class model_gpt2:
    """ This class mocks pytorch transformers gpt2 model, thus I'm writing just bunch of code that allows you run this example"""
    def __init__(self):
        pass

    def __call__(self,*input, **kwargs):
        return self.model( *input, **kwargs)

    def model(self, input_ids, labels):
        print("gpt2")

class model_bert:
    """ This class mocks pytorch transformers bert model"""
    def __init__(self):
        pass

    def __call__(self, *input, **kwargs):
        return self.model(*input, **kwargs)  # return the result, like model_gpt2.__call__ does

    def model(self, input_ids, masked_lm_labels):
        print("bert")


foo = My_bert()
foo.compute_model("bar")  # this works
bar = My_gpt2()
#bar.compute_model("rawr") #this does not work.

I know I can override Parent.compute_model inside the My_bert and My_gpt2 classes.

BUT since both "model" methods are so similar, I wonder if there is a way to say: "I'll pass you three arguments; use the ones you know":

def compute_model(self, text):
    return self.model(input_ids=text, masked_lm_labels=text, labels=text) # ignore the arguments you don't know
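Python will not silently drop unknown keyword arguments, but "use the ones you know" can be approximated by inspecting the callee with the standard library's inspect.signature and forwarding only the names it declares. This is a minimal sketch with hypothetical stand-in models; it assumes the callable declares its parameters explicitly. A real PyTorch nn.Module's __call__ accepts *args/**kwargs, so for a real model you would likely need to inspect its forward method instead.

```python
import inspect

def call_with_known_kwargs(fn, **kwargs):
    """Call fn with only the keyword arguments its signature declares.

    If fn itself accepts **kwargs, everything is passed through unchanged.
    """
    params = inspect.signature(fn).parameters
    if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return fn(**kwargs)
    known = {name: value for name, value in kwargs.items() if name in params}
    return fn(**known)

# Hypothetical stand-ins for the two model signatures
def bert_model(input_ids, masked_lm_labels):
    return ("bert", input_ids, masked_lm_labels)

def gpt2_model(input_ids, labels):
    return ("gpt2", input_ids, labels)

text = "bar"
for model in (bert_model, gpt2_model):
    # Pass all three names; each model receives only the ones it declares
    print(call_with_known_kwargs(model, input_ids=text, masked_lm_labels=text, labels=text))
# ('bert', 'bar', 'bar')
# ('gpt2', 'bar', 'bar')
```

The trade-off is that a misspelled keyword is silently dropped instead of raising TypeError, which can make bugs harder to track down.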

*args and **kwargs should take care of the issue you are running into.

In your code, modify compute_model to accept arbitrary arguments:

def compute_model(self, *args, **kwargs):
    return self.model(*args, **kwargs)

Now the accepted arguments are defined by the model method of each class; compute_model simply forwards whatever it receives.

With this change, the following should work:

foo = My_bert()
foo.compute_model("bar", "baz")
bar = My_gpt2()
bar.compute_model("rawr", "baz")
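With this approach the caller has to know each model's keyword names. If you would rather keep compute_model's signature fixed in Parent, one alternative (not part of the answer above) is to let each subclass declare only the name that differs. This is a minimal sketch: label_kwarg is a hypothetical class attribute, and MockBert/MockGPT2 are simplified mocks, not the real pytorch models.

```python
class Parent:
    """Shared logic; subclasses only declare which keyword their model expects."""
    label_kwarg = None  # hypothetical attribute, set by each subclass

    def compute_model(self, text):
        # Build the keyword dict dynamically so one call site serves both models
        return self.model(input_ids=text, **{self.label_kwarg: text})

class MockBert:
    def __call__(self, input_ids, masked_lm_labels):
        return "bert"

class MockGPT2:
    def __call__(self, input_ids, labels):
        return "gpt2"

class My_bert(Parent):
    label_kwarg = "masked_lm_labels"

    def __init__(self):
        self.model = MockBert()

class My_gpt2(Parent):
    label_kwarg = "labels"

    def __init__(self):
        self.model = MockGPT2()

print(My_bert().compute_model("bar"))   # bert
print(My_gpt2().compute_model("rawr"))  # gpt2
```

The shared logic stays in Parent, and each subclass states only the one thing that differs.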

If you are not familiar with *args and **kwargs, they allow you to pass arbitrary arguments to a function. *args collects positional (unnamed) arguments and passes them on in the order they were received; **kwargs (keyword arguments) collects named arguments and passes each one to the parameter with the matching name. So the following will also work:

foo = My_bert()
foo.compute_model(input_ids="bar", masked_lm_labels="baz") 
bar = My_gpt2()
bar.compute_model(input_ids="rawr", labels="baz") 
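To make the collecting and forwarding described above concrete, here is a small self-contained sketch (show, target and forward are illustrative names): inside the function, args is a tuple and kwargs is a dict, and the same * and ** syntax unpacks them again when calling another function.

```python
def show(*args, **kwargs):
    # args is a tuple of positional arguments, in call order
    # kwargs is a dict mapping keyword names to values
    return args, kwargs

def target(input_ids, labels):
    return f"input_ids={input_ids}, labels={labels}"

def forward(*args, **kwargs):
    # Unpack both again when calling another function
    return target(*args, **kwargs)

print(show("bar", "baz"))                   # (('bar', 'baz'), {})
print(show(input_ids="bar", labels="baz"))  # ((), {'input_ids': 'bar', 'labels': 'baz'})
print(forward("rawr", labels="baz"))        # input_ids=rawr, labels=baz
```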

Just a note: the names args and kwargs have no special meaning, and you can name them anything; only the * and ** prefixes matter. The typical convention, though, is args and kwargs.
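For example, this sketch behaves exactly like a *args/**kwargs version; collect, inputs and options are arbitrary illustrative names.

```python
def collect(*inputs, **options):
    # Same behaviour as *args/**kwargs: only the * and ** prefixes matter
    return inputs, options

print(collect("bar", labels="baz"))  # (('bar',), {'labels': 'baz'})
```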
