Explanation for a code snippet related to classes in Python

There is a code snippet in the book "Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow":

    import numpy as np
    from sklearn.base import BaseEstimator, TransformerMixin

    # column indices of the attributes used below
    rooms_ix, bedrooms_ix, population_ix, households_ix = 3, 4, 5, 6

    class CombinedAttributesAdder(BaseEstimator, TransformerMixin):
        def __init__(self, add_bedrooms_per_room=True):  # no *args or **kwargs
            self.add_bedrooms_per_room = add_bedrooms_per_room
        def fit(self, X, y=None):
            return self  # nothing else to do
        def transform(self, X):
            rooms_per_household = X[:, rooms_ix] / X[:, households_ix]
            population_per_household = X[:, population_ix] / X[:, households_ix]
            if self.add_bedrooms_per_room:
                bedrooms_per_room = X[:, bedrooms_ix] / X[:, rooms_ix]
                return np.c_[X, rooms_per_household, population_per_household,
                             bedrooms_per_room]
            else:
                return np.c_[X, rooms_per_household, population_per_household]

    # `housing` is the DataFrame loaded earlier in the book (not shown here)
    attr_adder = CombinedAttributesAdder(add_bedrooms_per_room=False)
    housing_extra_attribs = attr_adder.transform(housing.values)

The class combines existing attributes to create additional ones. I don't understand the purpose of the second method (def fit). Since it doesn't perform any action, is it redundant?

In scikit-learn, the estimator is the basic interface on which everything else is built, and every estimator is expected to implement a .fit method. Here's the reference in the docs.

In general, the .fit method operates on the training data and stores the resulting state on the object.
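For contrast with the class in the question, here is a minimal sketch of a transformer whose .fit does have something to learn. The class name MeanCenterer and the toy arrays are made up for illustration and are not from the book:

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin

class MeanCenterer(BaseEstimator, TransformerMixin):
    """Hypothetical transformer that subtracts the per-column training mean."""
    def fit(self, X, y=None):
        # State learned from the training data. The trailing underscore is the
        # scikit-learn naming convention for attributes set during fit.
        self.mean_ = np.asarray(X, dtype=float).mean(axis=0)
        return self
    def transform(self, X):
        # Reuses the state stored by fit, even on data it has never seen.
        return np.asarray(X, dtype=float) - self.mean_

# Toy data, invented for this example
X_train = np.array([[1.0, 10.0], [3.0, 30.0]])
X_new = np.array([[2.0, 20.0], [4.0, 40.0]])

centerer = MeanCenterer().fit(X_train)
centered_new = centerer.transform(X_new)
```

Here .transform would fail without a prior .fit, because self.mean_ would not exist yet. That dependency is exactly what CombinedAttributesAdder does not have.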

In this particular case, the .transform method can operate on any data without any state learned from .fit, so .fit doesn't need to do anything. You still need to define it, though, since scikit-learn expects all estimators to have one, even if it is a no-op.
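To see why the no-op .fit still matters, note that the fit_transform method inherited from TransformerMixin calls .fit and then .transform, and Pipeline relies on those methods when it fits its steps. A rough sketch, where RatioAdder and the toy array are hypothetical and only there for illustration:

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

class RatioAdder(BaseEstimator, TransformerMixin):
    """Hypothetical stateless transformer: appends column 0 / column 1."""
    def fit(self, X, y=None):
        return self  # nothing to learn, but the method has to exist
    def transform(self, X):
        return np.c_[X, X[:, 0] / X[:, 1]]

# Toy data, invented for this example
X = np.array([[6.0, 2.0], [8.0, 4.0]])

# fit_transform is not defined above; it comes from TransformerMixin,
# which calls self.fit(...) and then self.transform(...).
with_ratio = RatioAdder().fit_transform(X)

# The same transformer used as a pipeline step, followed by a scaler.
pipeline = Pipeline([("ratio", RatioAdder()), ("scale", StandardScaler())])
scaled = pipeline.fit_transform(X)
```

If you delete the fit method from RatioAdder, the inherited fit_transform raises an AttributeError, since it has no .fit to call.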

Finally, this code is fragile because the column indices are defined outside of the object. I'd recommend passing them to __init__ so they are always tracked with the object. Then you can refer to them as self.rooms_ix, self.bedrooms_ix, and so on, and they are always kept with the object (e.g., if it is saved and reloaded).
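One possible way to do that is sketched below. The default index values are copied from the question; the toy array is made up, and np.column_stack is used here in place of np.c_ only to make the optional column easier to append:

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin, clone

class CombinedAttributesAdder(BaseEstimator, TransformerMixin):
    def __init__(self, add_bedrooms_per_room=True, rooms_ix=3, bedrooms_ix=4,
                 population_ix=5, households_ix=6):
        # Store every constructor argument unchanged under the same name,
        # so that get_params() and clone() can see it.
        self.add_bedrooms_per_room = add_bedrooms_per_room
        self.rooms_ix = rooms_ix
        self.bedrooms_ix = bedrooms_ix
        self.population_ix = population_ix
        self.households_ix = households_ix

    def fit(self, X, y=None):
        return self  # still nothing to learn

    def transform(self, X):
        extra = [
            X[:, self.rooms_ix] / X[:, self.households_ix],       # rooms per household
            X[:, self.population_ix] / X[:, self.households_ix],  # population per household
        ]
        if self.add_bedrooms_per_room:
            extra.append(X[:, self.bedrooms_ix] / X[:, self.rooms_ix])
        # column_stack appends each 1-D array as a new column (swapped in for np.c_)
        return np.column_stack([X, *extra])

# Toy array with 7 columns, invented for this example (not the housing data)
X = np.arange(1, 15, dtype=float).reshape(2, 7)

attr_adder = CombinedAttributesAdder(add_bedrooms_per_room=False)
extra_attribs = attr_adder.transform(X)

# The indices are now hyperparameters of the object itself.
params = attr_adder.get_params()
attr_adder_copy = clone(attr_adder)
```

Because the indices are ordinary constructor parameters, they show up in get_params() and survive clone(), which scikit-learn uses internally in tools such as cross-validation and grid search.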
