Hello, I'm running this code and I get an error from the fit function:
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
lda = LinearDiscriminantAnalysis(shrinkage='auto')
lda.fit(np.random.rand(3,2),np.random.randint((1,1,1)))
LinearDiscriminantAnalysis()
Here is the error:
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
<ipython-input-34-ec552dd1faa1> in <module>
1 lda = LinearDiscriminantAnalysis(shrinkage='auto')
----> 2 lda.fit(np.random.rand(3,2),np.random.randint((1,1,1)))
3 LinearDiscriminantAnalysis()
~/anaconda3/lib/python3.8/site-packages/sklearn/discriminant_analysis.py in fit(self, X, y)
581 if self.solver == "svd":
582 if self.shrinkage is not None:
--> 583 raise NotImplementedError("shrinkage not supported")
584 if self.covariance_estimator is not None:
585 raise ValueError(
NotImplementedError: shrinkage not supported
Does anyone have an idea how to fix it? (I got the same error after upgrading scikit-learn, and also on Google Colab.)
Thanks!
shrinkage is not supported with the svd solver, which is the default solver of LinearDiscriminantAnalysis. You can use this parameter with the other solvers, eigen or lsqr, as follows:
LinearDiscriminantAnalysis(solver='lsqr',shrinkage='auto').fit(X_train, y_train)
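Here is a minimal, self-contained sketch that reproduces the error with the default solver and then fits with shrinkage='auto' (Ledoit-Wolf shrinkage) using lsqr and eigen. The toy data below (20 random samples, two balanced classes, seeded with np.random.default_rng) is an assumption for illustration, not the asker's data. It deliberately avoids the question's labels: np.random.randint((1,1,1)) draws each label from [0, 1), so y there is always array([0, 0, 0]), a single class.

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Hypothetical toy data: 20 samples, 2 features, two balanced classes.
rng = np.random.default_rng(0)
X = rng.random((20, 2))
y = np.repeat([0, 1], 10)

# The default solver is 'svd', which rejects any non-None shrinkage at fit time.
svd_error = None
try:
    LinearDiscriminantAnalysis(shrinkage="auto").fit(X, y)
except NotImplementedError as exc:
    svd_error = exc
print("svd raised:", repr(svd_error))

# 'lsqr' and 'eigen' both accept shrinkage='auto'.
scores = {}
for solver in ("lsqr", "eigen"):
    lda = LinearDiscriminantAnalysis(solver=solver, shrinkage="auto").fit(X, y)
    scores[solver] = lda.score(X, y)  # training accuracy, just to show fit worked
print(scores)
```

With real data you would pass your own X_train and y_train instead; the only change needed to fix the original error is setting solver to 'lsqr' or 'eigen'.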