There is a code snippet in the book "Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow":
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin

# column indices
rooms_ix, bedrooms_ix, population_ix, households_ix = 3, 4, 5, 6

class CombinedAttributesAdder(BaseEstimator, TransformerMixin):
    def __init__(self, add_bedrooms_per_room=True):  # no *args or **kargs
        self.add_bedrooms_per_room = add_bedrooms_per_room

    def fit(self, X, y=None):
        return self  # nothing else to do

    def transform(self, X):
        rooms_per_household = X[:, rooms_ix] / X[:, households_ix]
        population_per_household = X[:, population_ix] / X[:, households_ix]
        if self.add_bedrooms_per_room:
            bedrooms_per_room = X[:, bedrooms_ix] / X[:, rooms_ix]
            return np.c_[X, rooms_per_household, population_per_household,
                         bedrooms_per_room]
        else:
            return np.c_[X, rooms_per_household, population_per_household]

attr_adder = CombinedAttributesAdder(add_bedrooms_per_room=False)
housing_extra_attribs = attr_adder.transform(housing.values)
The class combines existing attributes to create additional ones. I don't understand the purpose of the second method (def fit). Since it doesn't perform any action, is it redundant?
In scikit-learn, the estimator is the base abstraction on which everything else is built, and every estimator is expected to implement a .fit method. The scikit-learn docs describe this contract.
In general, the .fit method operates on the training data and stores some resulting state on the object.
In this particular case, the .transform method can operate on any data without any state learned from .fit, so .fit doesn't need to do anything. You still need to define it, though, because scikit-learn expects all estimators to have one, even if it does nothing. For example, the fit_transform method supplied by TransformerMixin calls .fit and then .transform, and a Pipeline fits each of its steps, so a transformer without .fit would fail there.
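To make the distinction concrete, here is a minimal sketch contrasting a stateless transformer with a stateful one inside a Pipeline. The classes RatioAdder and MeanCenterer and the toy array are made up for illustration; they are not from the book.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline


class RatioAdder(BaseEstimator, TransformerMixin):
    """Hypothetical stateless transformer: appends column 0 / column 1."""

    def fit(self, X, y=None):
        return self  # nothing to learn, but fit_transform and Pipeline call it

    def transform(self, X):
        return np.c_[X, X[:, 0] / X[:, 1]]


class MeanCenterer(BaseEstimator, TransformerMixin):
    """Hypothetical stateful transformer: fit learns the column means."""

    def fit(self, X, y=None):
        self.mean_ = X.mean(axis=0)  # state stored on the object
        return self

    def transform(self, X):
        return X - self.mean_


X_train = np.array([[6.0, 2.0], [8.0, 4.0]])

# The pipeline fits every step, so both classes need a fit method.
pipeline = Pipeline([("ratio", RatioAdder()), ("center", MeanCenterer())])
result = pipeline.fit_transform(X_train)
```

Both classes expose the same interface, which is what lets the pipeline treat them uniformly; only MeanCenterer actually uses fit to remember something about the training data.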
Finally, this code is fragile because the column indices are defined outside of the object. I'd recommend passing them to __init__ so you can refer to them as self.rooms_ix and so on. That way they are always kept with the object (e.g., if it is saved and reloaded).
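One possible way to do that is sketched below. The extra __init__ parameters and their defaults (taken from the module-level constants in the question) are my own choice, not the book's code, and the toy array only stands in for housing.values.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin, clone


class CombinedAttributesAdder(BaseEstimator, TransformerMixin):
    # Assumption: defaults mirror the constants from the question (3, 4, 5, 6).
    def __init__(self, add_bedrooms_per_room=True, rooms_ix=3, bedrooms_ix=4,
                 population_ix=5, households_ix=6):
        self.add_bedrooms_per_room = add_bedrooms_per_room
        self.rooms_ix = rooms_ix
        self.bedrooms_ix = bedrooms_ix
        self.population_ix = population_ix
        self.households_ix = households_ix

    def fit(self, X, y=None):
        return self  # nothing to learn

    def transform(self, X):
        rooms_per_household = X[:, self.rooms_ix] / X[:, self.households_ix]
        population_per_household = (
            X[:, self.population_ix] / X[:, self.households_ix]
        )
        if self.add_bedrooms_per_room:
            bedrooms_per_room = X[:, self.bedrooms_ix] / X[:, self.rooms_ix]
            return np.c_[X, rooms_per_household, population_per_household,
                         bedrooms_per_room]
        return np.c_[X, rooms_per_household, population_per_household]


# Toy stand-in for housing.values: 2 rows, 7 numeric columns.
X = np.arange(1, 15, dtype=float).reshape(2, 7)

attr_adder = CombinedAttributesAdder(add_bedrooms_per_room=False)
extra = attr_adder.transform(X)

# The indices are now ordinary hyperparameters, so they survive cloning.
copied = clone(attr_adder)
```

Because the indices are stored as attributes with the same names as the __init__ arguments, get_params, clone and pickling all carry them along with the object.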