I am using the Mean Shift algorithm from scikit-learn to perform image segmentation. I have the following code:
import numpy as np
from sklearn.cluster import MeanShift
import matplotlib.pyplot as plt
from PIL import Image
image = Image.open('sample_images/fruit.png').convert('RGB')
image = np.array(image)
red = image[:,:,0]
green = image[:,:,1]
blue = image[:,:,2]
flat_red = red.flatten()
flat_green = green.flatten()
flat_blue = blue.flatten()
flattened = np.stack((flat_red, flat_green, flat_blue))
ms_clf = MeanShift(bin_seeding=True)
ms_labels = ms_clf.fit_predict(flattened)
plt.imshow(np.reshape(ms_labels, [1001, 994]))
I have a flattened colour matrix with dimensions 3x994994 (994994 = 1001 x 994 pixels), so in total there are 2984982 samples.
print(flattened.shape)
(3, 994994)
print(flattened)
[[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]]
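The sketch below shows how the same pixel data can be laid out in two ways. It uses a tiny random RGB array as a hypothetical 4 x 5 stand-in for the real 1001 x 994 image. It compares the channel-per-row layout built with np.stack above against a pixel-per-row layout, which matches the (n_samples, n_features) convention that scikit-learn estimators use.

```python
import numpy as np

# Hypothetical stand-in for the RGB image: height 4, width 5, 3 channels.
h, w = 4, 5
image = np.random.default_rng(0).integers(0, 256, size=(h, w, 3), dtype=np.uint8)

# Layout used in the question: one row per colour channel -> shape (3, h*w)
channels_as_rows = np.stack((image[:, :, 0].flatten(),
                             image[:, :, 1].flatten(),
                             image[:, :, 2].flatten()))
print(channels_as_rows.shape)  # (3, 20)

# Layout scikit-learn expects: one row per pixel, one column per channel -> (h*w, 3)
pixels = image.reshape(-1, 3)
print(pixels.shape)  # (20, 3)

# Both arrays hold the same values; one is the transpose of the other.
print(np.array_equal(channels_as_rows.T, pixels))  # True
```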
This flattened matrix is used as input for the MeanShift fit_predict() function. When I print the array of labels returned by fit_predict(), I get the following output:
print(ms_labels)
[0 1 2]
Doesn't the fit_predict() function return a label for each data sample? Why am I only getting an array with 3 elements in it? Any insights are appreciated.
I'm not very familiar with MeanShift, but the documentation for fit_predict() says it takes X of shape (n_samples, n_features) as input and returns labels of shape (n_samples,). Because you are passing a 3x994994 array, n_samples=3 and n_features=994994, so the labels form a (3,) array, as you have seen. In other words, each colour channel in flattened is treated as a single sample. Are you trying to get a label for each pixel (each "feature" in your current layout)? If so, each pixel has to be a row: pass flattened.T (shape (994994, 3)) or image.reshape(-1, 3) instead, and reshape the resulting labels back to (1001, 994).
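The following is a minimal sketch of per-pixel MeanShift segmentation. It is not the asker's exact pipeline. It uses a hypothetical synthetic image (a noisy red half and a noisy green half) in place of 'sample_images/fruit.png'. The quantile=0.2 and n_samples=500 values passed to estimate_bandwidth are illustrative assumptions. According to the scikit-learn docs, estimate_bandwidth takes time at least quadratic in n_samples, so subsampling matters for a real image with about a million pixels. MeanShift itself may also be slow at that size.

```python
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth

# Hypothetical synthetic image standing in for the real photo:
# left half dark red, right half bright green, with a little Gaussian noise.
rng = np.random.default_rng(0)
h, w = 20, 30
image = np.zeros((h, w, 3), dtype=float)
image[:, : w // 2] = [150, 20, 20]
image[:, w // 2 :] = [20, 200, 20]
image += rng.normal(0, 5, size=image.shape)

# One row per pixel, one column per channel: (n_samples, n_features) = (h*w, 3)
pixels = image.reshape(-1, 3)

# Estimate the kernel bandwidth on a random subsample of pixels (assumed settings).
bandwidth = estimate_bandwidth(pixels, quantile=0.2, n_samples=500, random_state=0)

ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
labels = ms.fit_predict(pixels)   # one cluster label per pixel, shape (h*w,)

# Back to image layout, e.g. for plt.imshow(segmented)
segmented = labels.reshape(h, w)
print(labels.shape, segmented.shape, len(np.unique(labels)))
```

For the image in the question, the same idea would be image.reshape(-1, 3) in place of the np.stack call, and ms_labels.reshape(1001, 994) before plotting.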