
Scikit-learn: Understanding MeanShift fit_predict()

I am using the Mean Shift algorithm from scikit-learn to perform image segmentation. I have the following code:

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.cluster import MeanShift

image = Image.open('sample_images/fruit.png').convert('RGB')
image = np.array(image)

red = image[:,:,0]
green = image[:,:,1]
blue = image[:,:,2]

flat_red = red.flatten()
flat_green = green.flatten()
flat_blue = blue.flatten()

flattened = np.stack((flat_red, flat_green, flat_blue))

ms_clf = MeanShift(bin_seeding=True)
ms_labels = ms_clf.fit_predict(flattened)
plt.imshow(np.reshape(ms_labels, [1001, 994]))

I have a flattened colour matrix with dimensions 3x994994, so in total there are 2984982 samples.

print(flattened.shape)
(3, 994994)

print(flattened)
[[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]]

This flattened matrix is used as input for the MeanShift fit_predict() function. When I try to print the array of labels returned by fit_predict() I get the following output:

print(ms_labels)
[0 1 2]

Doesn't the fit_predict() function return a label for each data sample? Why am I only getting an array with 3 elements in it? Any insights are appreciated.
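A small reproduction of this behaviour may help. It uses a hypothetical 4x5 image instead of fruit.png and an arbitrary bandwidth=50 (both are assumptions for illustration only). It shows that np.stack joins along axis 0 by default, so each channel becomes a row, and that fit_predict() returns one label per row of its input:

```python
import numpy as np
from sklearn.cluster import MeanShift

# Hypothetical 4x5 RGB "image" (a stand-in for fruit.png): the left three
# columns are black, the right two columns are a single red-ish colour.
image = np.zeros((4, 5, 3))
image[:, 3:] = [200, 50, 50]

channels = (image[:, :, 0].flatten(),
            image[:, :, 1].flatten(),
            image[:, :, 2].flatten())

# np.stack joins along axis=0 by default, so each CHANNEL becomes a row.
by_channel = np.stack(channels)           # shape (3, 20)
# Joining along the last axis makes each PIXEL a row instead.
by_pixel = np.stack(channels, axis=-1)    # shape (20, 3)

# fit_predict returns one label per row of X. bandwidth=50 is an arbitrary
# choice for this toy data, not a recommended value.
labels_by_channel = MeanShift(bandwidth=50).fit_predict(by_channel)
labels_by_pixel = MeanShift(bandwidth=50).fit_predict(by_pixel)

print(by_channel.shape, labels_by_channel.shape)  # (3, 20) (3,)
print(by_pixel.shape, labels_by_pixel.shape)      # (20, 3) (20,)
```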

I'm not super familiar with MeanShift, but the documentation for fit_predict() says it takes X of shape (n_samples, n_features) as input and returns labels of shape (n_samples,). Since you are passing in a 3x994994 array, n_samples=3 and n_features=994994, so the labels come back as a (3,) array, as you have seen. It's essentially treating each image channel in flattened as a single data point. If you want one label per pixel, each pixel needs to be a row instead, i.e. an array of shape (994994, 3), such as flattened.T or image.reshape(-1, 3).
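A minimal sketch of per-pixel segmentation, assuming the goal is one label per pixel: reshape the image to (n_pixels, 3), cluster, then reshape the labels back to the image grid. It uses a hypothetical synthetic 30x40 image so it runs without fruit.png, and bandwidth=30 is an assumption tuned to that toy data, not a value for real photos:

```python
import numpy as np
from sklearn.cluster import MeanShift

# Hypothetical 30x40 RGB image standing in for fruit.png: three flat colour
# blocks plus a little noise, so the example runs without any image file.
rng = np.random.default_rng(0)
image = np.zeros((30, 40, 3))
image[:, :15] = [220, 40, 40]    # red-ish block
image[:, 15:30] = [40, 200, 60]  # green-ish block
image[:, 30:] = [30, 30, 200]    # blue-ish block
image += rng.normal(scale=3.0, size=image.shape)

# One row per pixel, one column per channel: shape (n_pixels, 3).
h, w, c = image.shape
pixels = image.reshape(-1, c)

# bandwidth=30 is an assumption chosen for this synthetic data; for a real
# photo, estimate_bandwidth(pixels, quantile=0.2, n_samples=1000) from
# sklearn.cluster is one way to get a starting value.
ms = MeanShift(bandwidth=30, bin_seeding=True)
labels = ms.fit_predict(pixels)        # shape (n_pixels,)

# Reshape the per-pixel labels back into the image grid for display.
segmented = labels.reshape(h, w)
print(pixels.shape, labels.shape, segmented.shape)
```

With the real image, h and w would be 1001 and 994, and plt.imshow(segmented) would then show the segmentation. MeanShift on roughly a million pixels can be slow even with bin_seeding=True, so trying it on a downscaled copy of the image first may save time.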

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint this content, please cite the site URL or the original address. For questions, contact: yoyou2525@163.com.

 