What is the difference between fit() and fit_predict() in SpectralClustering

I am trying to understand and use spectral clustering from sklearn. Say we have an input matrix X and create a SpectralClustering object as follows:

clustering = SpectralClustering(n_clusters=2,
         assign_labels="discretize",
         random_state=0)

Then we call fit_predict on the SpectralClustering object:

clusters = clustering.fit_predict(X)

What confuses me is when 'the affinity matrix for X using the selected affinity' is created. According to the documentation, the fit_predict() method 'Performs clustering on X and returns cluster labels.' It doesn't explicitly say that it also computes 'the affinity matrix for X using the selected affinity' before clustering.

I appreciate any help or tips.

Looking at the fit_predict() source code, it appears to be just a convenience method: it simply calls fit() and returns the labels from the object.
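The pattern this answer describes can be illustrated with a minimal sketch. ToyClusterer below is a hypothetical stand-in written for this example, not scikit-learn's actual code, and its "clustering" rule (thresholding the first feature at its mean) is an arbitrary placeholder. The point is only that fit does all the work and stores labels_, while fit_predict delegates to fit and returns that attribute.

```python
import numpy as np


class ToyClusterer:
    """Hypothetical estimator illustrating the fit / fit_predict pattern.

    This is NOT scikit-learn's implementation; it only mimics the shape
    of the API described in the answer above.
    """

    def fit(self, X):
        X = np.asarray(X, dtype=float)
        # All the real work happens here. As a placeholder rule, split
        # the samples on whether the first feature exceeds its mean.
        self.labels_ = (X[:, 0] > X[:, 0].mean()).astype(int)
        return self

    def fit_predict(self, X):
        # Convenience method: fit, then return the stored labels.
        return self.fit(X).labels_


X = np.array([[1, 2], [1, 4], [10, 0],
              [10, 2], [10, 4], [1, 0]])

toy = ToyClusterer()
labels = toy.fit_predict(X)
print(labels)
print(np.array_equal(labels, toy.labels_))
```

Under this pattern, anything fit computes as a side effect is also available after fit_predict, because fit_predict never bypasses fit.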

As already implied in another answer, fit_predict is just a convenience method that returns the cluster labels. According to the documentation, fit

Creates an affinity matrix for X using the selected affinity, then applies spectral clustering to this affinity matrix.

while fit_predict

Performs clustering on X and returns cluster labels.

Here, 'Performs clustering on X' should be understood as what is described for fit, i.e. 'Creates an affinity matrix [...]'.
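One way to check this directly is to inspect the fitted affinity_matrix_ attribute after calling only fit_predict. The sketch below assumes the default affinity ("rbf") and compares the stored matrix against sklearn.metrics.pairwise.rbf_kernel evaluated with the estimator's own gamma; the variable names are chosen for this example.

```python
import numpy as np
from sklearn.cluster import SpectralClustering
from sklearn.metrics.pairwise import rbf_kernel

X = np.array([[1, 2], [1, 4], [10, 0],
              [10, 2], [10, 4], [1, 0]])

# Default affinity is "rbf"; we never call fit() explicitly here.
model = SpectralClustering(n_clusters=2,
                           assign_labels="discretize",
                           random_state=0)
labels = model.fit_predict(X)

# The affinity matrix exists after fit_predict alone: one row and one
# column per sample.
affinity = model.affinity_matrix_
print(affinity.shape)

# Assumption for this check: with affinity="rbf", the stored matrix
# should match an RBF kernel computed with the same gamma.
expected = rbf_kernel(X, gamma=model.gamma)
print(np.allclose(affinity, expected))

# The returned labels are the same ones stored on the estimator.
print(np.array_equal(labels, model.labels_))
```

If a different affinity is selected (for example "nearest_neighbors"), affinity_matrix_ is still populated by the same call, but the comparison against rbf_kernel no longer applies.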

It is not difficult to verify that calling fit_predict is equivalent to reading the labels_ attribute from the object after fit; using some dummy data, we have

from sklearn.cluster import SpectralClustering
import numpy as np

X = np.array([[1, 2], [1, 4], [10, 0],
               [10, 2], [10, 4], [1, 0]])

# 1st way - use fit and get the labels_
clustering = SpectralClustering(n_clusters=2,
     assign_labels="discretize",
     random_state=0)

clustering.fit(X)
clustering.labels_
# array([1, 1, 0, 0, 0, 1])

# 2nd way - using fit_predict
clustering2 = SpectralClustering(n_clusters=2,
     assign_labels="discretize",
     random_state=0)

clustering2.fit_predict(X)
# array([1, 1, 0, 0, 0, 1])

np.array_equal(clustering.labels_, clustering2.fit_predict(X))
# True
