NotImplementedError: shrinkage not supported

Hello, I'm running this code and I get an error from the fit function:

import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
lda = LinearDiscriminantAnalysis(shrinkage='auto')
# np.random.randint((1,1,1)) draws each label from [0, 1), so y is [0, 0, 0]
lda.fit(np.random.rand(3,2),np.random.randint((1,1,1)))
LinearDiscriminantAnalysis()

Here is the error:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-34-ec552dd1faa1> in <module>
      1 lda = LinearDiscriminantAnalysis(shrinkage='auto')
----> 2 lda.fit(np.random.rand(3,2),np.random.randint((1,1,1)))
      3 LinearDiscriminantAnalysis()

~/anaconda3/lib/python3.8/site-packages/sklearn/discriminant_analysis.py in fit(self, X, y)
    581         if self.solver == "svd":
    582             if self.shrinkage is not None:
--> 583                 raise NotImplementedError("shrinkage not supported")
    584             if self.covariance_estimator is not None:
    585                 raise ValueError(

NotImplementedError: shrinkage not supported

Does anyone have an idea how to fix it? (I got the same error after upgrading scikit-learn, and also on Google Colab.)

Thanks!

shrinkage is not supported with the 'svd' solver, which is the default solver of LinearDiscriminantAnalysis. Use the parameter with one of the other solvers, 'lsqr' or 'eigen', for example:

LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto').fit(X_train, y_train)
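A minimal self-contained sketch of the fix, assuming scikit-learn and NumPy are installed. The toy data (rng, X, y) is hypothetical and stands in for X_train/y_train; it has two classes because np.random.randint((1,1,1)) in the question returns only zeros, i.e. a single class. The first fit reproduces the error with the default 'svd' solver (the exact message wording varies between scikit-learn versions); the loop then fits the 'lsqr' and 'eigen' solvers with shrinkage='auto'.

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Hypothetical toy data: two Gaussian blobs of 20 samples each, 2 features.
rng = np.random.default_rng(0)
X = np.vstack([
    rng.normal(loc=0.0, scale=1.0, size=(20, 2)),  # class 0
    rng.normal(loc=3.0, scale=1.0, size=(20, 2)),  # class 1
])
y = np.array([0] * 20 + [1] * 20)

# The default solver is 'svd', which rejects any non-None shrinkage.
try:
    LinearDiscriminantAnalysis(shrinkage="auto").fit(X, y)
except NotImplementedError as exc:
    print("svd solver:", exc)

# 'lsqr' and 'eigen' accept shrinkage='auto' (Ledoit-Wolf) or a float in [0, 1].
scores = {}
for solver in ("lsqr", "eigen"):
    lda = LinearDiscriminantAnalysis(solver=solver, shrinkage="auto").fit(X, y)
    scores[solver] = lda.score(X, y)  # training accuracy
    print(solver, scores[solver])
```

Which of the two to pick depends on the use: 'lsqr' only supports classification, while 'eigen' also supports transform() if you need LDA for dimensionality reduction.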

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint this content, please cite the site URL or the original address. For any questions, contact: yoyou2525@163.com.