fit() vs fit_predict() methods in sklearn KMeans

There are two methods available when we build a model with sklearn.cluster.KMeans . The first is fit() and the other is fit_predict() . My understanding is that when we call fit() on a KMeans model, it sets an attribute labels_ , which records which observation belongs to which cluster. fit_predict() also has a labels_ attribute.

So my questions are:

  1. If fit() fulfills the need, then why is there a fit_predict() ?
  2. Are fit() and fit_predict() interchangeable while writing code?

KMeans is just one of the many models that sklearn has, and many share the same API. The basic functions are fit , which teaches the model using examples, and predict , which uses the knowledge obtained by fit to answer questions about potentially new values.

KMeans automatically predicts the cluster of every input point during training, because doing so is integral to the algorithm. It keeps those labels around for efficiency, because predicting the labels for the original dataset is very common. Thus, fit_predict adds very little: it is just a convenience method that calls fit , then returns the labels of the training dataset ( .labels_ ). ( fit_predict doesn't have a labels_ attribute; it just gives you the labels.)
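To make the equivalence concrete, here is a minimal sketch. The toy data, the variable names, and the choice of n_clusters=2 , n_init=10 and random_state=0 are all assumptions made for illustration; the same random_state is used in both models so that the two runs are comparable.

```python
import numpy as np
from sklearn.cluster import KMeans

# Toy data (made up for illustration): two well-separated groups of points.
X = np.array([[0.0, 0.0], [0.1, 0.2], [0.2, 0.1],
              [5.0, 5.0], [5.1, 5.2], [5.2, 5.1]])

# Variant 1: call fit, then read the labels stored on the model.
km_a = KMeans(n_clusters=2, n_init=10, random_state=0)
km_a.fit(X)
labels_a = km_a.labels_

# Variant 2: fit_predict fits the model and returns the labels directly.
km_b = KMeans(n_clusters=2, n_init=10, random_state=0)
labels_b = km_b.fit_predict(X)

# Compare the two label arrays; the fitted model in variant 2
# still carries labels_ as well.
print(np.array_equal(labels_a, labels_b))
print(np.array_equal(labels_b, km_b.labels_))
```

In this sketch the only practical difference is whether you want the labels as a return value or are happy to read them off the fitted model.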

However, if you want to train your model on one set of data and then use this to quickly (and without changing the established cluster boundaries) get an answer for a data point that was not in the original data, you would need to use predict , not fit_predict .
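A short sketch of that situation, assuming made-up training points and two hypothetical new points ( X_train and X_new are illustrative names, not from the question):

```python
import numpy as np
from sklearn.cluster import KMeans

# Made-up training data: two well-separated groups.
X_train = np.array([[0.0, 0.0], [0.1, 0.2], [0.2, 0.1],
                    [5.0, 5.0], [5.1, 5.2], [5.2, 5.1]])
# Hypothetical points that were not part of the training data.
X_new = np.array([[0.3, 0.3], [4.8, 5.3]])

km = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X_train)
centers_before = km.cluster_centers_.copy()

# predict assigns each new point to the nearest existing centroid;
# it does not refit the model.
new_labels = km.predict(X_new)

print(new_labels)
print(np.allclose(centers_before, km.cluster_centers_))
```

Calling km.fit_predict(X_new) here instead would refit the model on the two new points only, discarding the clusters learned from X_train .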

In other models (for example sklearn.neural_network.MLPClassifier ), training can be a very expensive operation, so you may not want to re-train the model every time you want to predict something; also, it is not a given that prediction results are generated as a part of the training. Or, as discussed above, you may simply not want to change the model in response to new data. In those cases, you cannot get predictions from the result of fit : you need to call predict with the data you want a prediction for.
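As a rough illustration with MLPClassifier , assuming tiny made-up data and arbitrary hyperparameters ( hidden_layer_sizes=(8,) , max_iter=2000 , random_state=0 are illustrative choices only):

```python
import numpy as np
from sklearn.neural_network import MLPClassifier

# Made-up labelled training data and two hypothetical new points.
X_train = np.array([[0.0, 0.0], [0.1, 0.2], [0.2, 0.1],
                    [5.0, 5.0], [5.1, 5.2], [5.2, 5.1]])
y_train = np.array([0, 0, 0, 1, 1, 1])
X_new = np.array([[0.3, 0.3], [4.8, 5.3]])

clf = MLPClassifier(hidden_layer_sizes=(8,), max_iter=2000, random_state=0)

# fit returns the estimator itself, not predictions.
fitted = clf.fit(X_train, y_train)
print(fitted is clf)

# This classifier has no fit_predict shortcut.
print(hasattr(clf, "fit_predict"))

# Train once, then call predict as often as needed without re-training.
y_new = clf.predict(X_new)
print(y_new.shape)
```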

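Also note that labels_ ends with an underscore. In scikit-learn, a trailing underscore marks an attribute that was estimated from the data during fit ; it is a leading underscore (as in _name ) that is the Python convention for "don't touch this, it's private" (in the absence of actual access control). Reading labels_ is therefore supported, but whenever possible you should prefer the established methods ( predict , fit_predict ) instead.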

In scikit-learn there are similar pairs of methods, such as fit and fit_transform .
For clustering, you need fit followed by either predict or labels_ to obtain the cluster assignments.
Thus fit_predict is just a more compact way to write this, and its result is the same as that of fit followed by predict on the same data (or of reading labels_ ).
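The fit / fit_transform analogy can be sketched the same way. StandardScaler and the toy matrix below are assumptions chosen only for illustration; any transformer with a fit_transform method would do.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Made-up data with two features on different scales.
X = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]])

# Two-step form: learn the scaling parameters, then apply them.
scaled_a = StandardScaler().fit(X).transform(X)

# One-step form: fit_transform does both on the same data.
scaled_b = StandardScaler().fit_transform(X)

print(np.allclose(scaled_a, scaled_b))
```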

In addition, a fitted clustering model is often used only once, to determine the cluster labels of the samples it was fitted on, which is why a combined fit_predict is convenient.
