
Python: Why would a **kwarg of a subclass method be passed as a parameter in the __call__() signature of the base class? Results in TypeError

Sorry for the vague title; the title field doesn't allow many characters.

Brief exposition:
I am implementing an autoencoder CNN architecture for an image analysis program that requires custom loss functions that don't exist in the Keras backend or anywhere in TensorFlow. That is no big deal; I like numerical computing and will gladly brush up my OOP skills. For my program to run smoothly, I need my loss functions to be callable objects (basically how Keras implements their loss objects, so I modeled my loss classes on the Keras Loss source code). It has been great and I've learned tons so far. I have a class ReconstructionLoss(LossWrapper), a subclass of LossWrapper(Loss), which is in turn a subclass of the parent Loss(object). Specifically, instances of ReconstructionLoss need a tensor kwarg, 'DecodeOut', but other loss classes of different types don't need this kwarg tensor. They all need the y_true and EncodeOut args.

Problem:
ReconstructionLoss instances should be able to take a **kwarg (always DecodeOut=someTensor). BUT somehow the kwarg is getting passed to the __call__() signature in the parent class Loss, resulting in TypeError: __call__() got an unexpected keyword argument 'DecodeOut'. I think this happens when the LossWrapper call() method is used in the body of __call__(self, arg1, arg2) in the Loss class. It will just become more convoluted (haha) if I keep talking about it, so I will let you look at the code to see the rest of the details yourself. I'm sure I'm making a very subtle rookie mistake, as it has been a minute since I've had to do any OOP. Apologies for any weird indentation; it was a hassle copying and pasting it into the code block here. (I'm only showing the relevant modules as well):

from __future__ import absolute_import, division, print_function, unicode_literals

import sys, os
import numpy as np  # needed below: reconstruction_loss() calls np.sum()
from pathlib import Path
from matplotlib import pyplot, cm
import tensorflow as tf
import tensorflow_io as tfio
import keras
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as KB
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops.losses import util as tf_losses_util
from tensorflow.python.keras.utils import losses_utils

class Loss(object):

    def __init__(self): 
        print('Loss Object Instantiated')
    #HERE IS WHERE I GET THE ERROR once compilation gets here. 
    #kwarg 'DecodeOut' is being passed as a parameter
    # in __call__(self, y_true, EncodeOut, sample_weight = None).
    #It should only be called in the *return value* of the call() method
    # defined in LossWrapper class.
    def __call__(self, y_true, EncodeOut, sample_weight = None):#boom, error
    
        graph_ctx = tf_utils.graph_context_for_symbolic_tensors(
            y_true,EncodeOut, sample_weight)
        with KB.name_scope(self.__class__.__name__), graph_ctx:

            losses = self.call(y_true,EncodeOut) 
            return losses_utils.compute_weighted_loss(losses, sample_weight)


    def call(self, y_true, EncodeOut):
        """Invokes the `Loss` instance.

        Args:
          y_true: Ground truth values, with the same shape as `EncodeOut`.
          EncodeOut: The predicted values.
        """
        raise NotImplementedError('Must be implemented in subclasses.')
    
class LossWrapper(Loss):
    def __init__(self, func, **kwargs):
        super(LossWrapper, self).__init__()
        self.func = func
        self.func_kwargs = kwargs

    # See below: call(self, y_true, EncodeOut) RETURNS
    # reconstruction_loss(y_true, EncodeOut, **self.func_kwargs),
    # but the kwarg 'DecodeOut' is getting passed in the __call__() signature when I implement:
    #   recon_loss_obj = ReconstructionLoss()
    #   reconLoss = recon_loss_obj(trueLabels, someTensor, DecodeOut=someOtherTensor)

    def call(self, y_true, EncodeOut):
        return self.func(y_true, EncodeOut, **self.func_kwargs)
    

class ReconstructionLoss(LossWrapper):
    def __init__(self, DecodeOut=[0]):
        super(ReconstructionLoss, self).__init__(
            reconstruction_loss,  # this is the value of 'func' in all instances;
                                  # the function is defined below
            DecodeOut=DecodeOut)
        self.DecodeOut = DecodeOut

def reconstruction_loss(y_true, residual, DecodeOut=[0]):
    print(DecodeOut.shape)
    K = tf.size(residual[0, :, :, :]).numpy()
    L1_norm_batches = tf.norm(residual - DecodeOut, ord=1, axis=[-3, -2])
    reconstruction_loss = np.sum(L1_norm_batches.numpy()) / K
    return reconstruction_loss

Implementation Example:

      loss_object = ReconstructionLoss()
      a=loss_object(y_true_train, conv11, DecodeOut = select) 

      Loss Object Instantiated
      Traceback (most recent call last):

         File "C:\Python File\Tamper.py", line 794, in <module>
         a=loss_object(y_true_train, conv11, DecodeOut = select)

         TypeError: __call__() got an unexpected keyword argument 'DecodeOut'

I am failing to see why 'DecodeOut' is being passed to the __call__() signature to begin with, when it should only be passed to self.func(y_true, EncodeOut, **kwargs) (in which case self.func = reconstruction_loss whenever a ReconstructionLoss() object is instantiated).

I know it is a lot of info for what will probably be a very obvious mistake, but I'm trying to be as exhaustive as possible. If you are wondering why I am taking this approach, it is because I have several other loss objects (reconstruction loss vs. activation loss, for instance) that I need to be able to call, AND I'm doing this to learn. Furthermore, I am confused because this implementation is directly analogous to the Keras Loss source code... they take the EXACT same approach as far as I can tell. Please look at their code if you'd like, specifically just Loss, LossFunctionWrapper, and then the BinaryCrossentropy class and the binary_crossentropy function:

    https://keras-gym.readthedocs.io/en/stable/_modules/tensorflow/python/keras/losses.html

I'm pretty sure I'm missing something subtle that those devs did that I don't know about, OR I have a severe misunderstanding of what is happening with the inheritance. I have already tried defining all self attributes before super() is called, with no success... I'm not even sure that is appropriate here. For those wondering why this is designed this way, here is an example snippet of the Keras code with the Loss class, the wrapper, a loss-type class, and the loss function with kwargs. The Keras code is a bit more crowded, as they have several more optional arguments like reduction, name, etc.:

class Loss(object):


    def __init__(self, reduction=losses_utils.ReductionV2.AUTO,name=None):
         losses_utils.ReductionV2.validate(reduction)
         self.reduction = reduction
         self.name = name

    def __call__(self, y_true, y_pred, sample_weight=None):

        # If we are wrapping a lambda function strip '<>' from the name as it is not
        # accepted in scope name.
        scope_name = 'lambda' if self.name == '<lambda>' else self.name
        graph_ctx = tf_utils.graph_context_for_symbolic_tensors(
            y_true, y_pred, sample_weight)
        with K.name_scope(scope_name or self.__class__.__name__), graph_ctx:
            losses = self.call(y_true, y_pred)
            return losses_utils.compute_weighted_loss(
                losses, sample_weight, reduction=self._get_reduction())

class LossFunctionWrapper(Loss):

    def __init__(self,
                 fn,
                 reduction=losses_utils.ReductionV2.AUTO,
                 name=None,
                 **kwargs):
        super(LossFunctionWrapper, self).__init__(reduction=reduction, name=name)
        self.fn = fn
        self._fn_kwargs = kwargs

    def call(self, y_true, y_pred):
        if tensor_util.is_tensor(y_pred) and tensor_util.is_tensor(y_true):
            y_pred, y_true = tf_losses_util.squeeze_or_expand_dimensions(
                y_pred, y_true)
        return self.fn(y_true, y_pred, **self._fn_kwargs)

class BinaryCrossentropy(LossFunctionWrapper):

    def __init__(self,
                 from_logits=False,   # a kwarg NOT passed in __call__() of Loss(); passed when you instantiate BinaryCrossentropy()
                 label_smoothing=0,   # same
                 reduction=losses_utils.ReductionV2.AUTO,
                 name='binary_crossentropy'):
        super(BinaryCrossentropy, self).__init__(
            binary_crossentropy,
            name=name,
            reduction=reduction,
            from_logits=from_logits,
            label_smoothing=label_smoothing)
        self.from_logits = from_logits

def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0):  # pylint: disable=missing-docstring
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = math_ops.cast(y_true, y_pred.dtype)
    label_smoothing = ops.convert_to_tensor(label_smoothing, dtype=K.floatx())

    def _smooth_labels():
        return y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing

    y_true = smart_cond.smart_cond(label_smoothing,
                                   _smooth_labels, lambda: y_true)
    return K.mean(
        K.binary_crossentropy(y_true, y_pred, from_logits=from_logits), axis=-1)
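
For reference (and as my comment on from_logits above says), the kwargs in this Keras snippet are bound at instantiation and end up in _fn_kwargs; __call__() only ever receives y_true, y_pred, and sample_weight. A minimal usage sketch of the snippet above (the tensor names here are placeholders):

    bce = BinaryCrossentropy(from_logits=True, label_smoothing=0.1)  # kwargs bound at construction
    loss = bce(y_true, y_pred)  # no extra kwargs supplied at call time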

~Thanks

There is nothing wrong with the code; this was a lazy understanding of OOP, and I implemented the calls incorrectly. The kwargs should be passed when you instantiate the object, not when you call the instance.
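
The reason: writing loss_obj(args, kwargs_here) dispatches to Loss.__call__, so every keyword argument supplied at call time must appear in __call__()'s signature; it never reaches self.func_kwargs, which is only populated by __init__. A minimal standalone sketch (the names Demo/extra are just illustrative):

    class Demo(object):
        def __init__(self, **kwargs):
            self.kwargs = kwargs   # construction-time kwargs land here

        def __call__(self, a, b):  # call-time arguments are matched against THIS signature
            return a + b

    d = Demo(extra=3)   # fine: 'extra' goes to __init__ and is stored
    d(1, 2)             # fine: a=1, b=2
    d(1, 2, extra=3)    # TypeError: __call__() got an unexpected keyword argument 'extra'

So the usage should be: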

loss_obj = ReconstructionLoss(kwargs_here)
value = loss_obj(args)

instead of:

loss_obj = ReconstructionLoss()
value = loss_obj(args, kwargs_here)
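
Concretely, with the names from the failing example in the question:

    # DecodeOut is bound at construction; LossWrapper stores it in self.func_kwargs
    # and forwards it to reconstruction_loss() inside call().
    loss_object = ReconstructionLoss(DecodeOut=select)
    a = loss_object(y_true_train, conv11)  # __call__() now sees only its declared arguments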
