繁体   English   中英

Scikit-learn:理解 MeanShift fit_predict()

[英]Scikit-learn: Understanding MeanShift fit_predict()

我正在使用 scikit-learn 的 Mean Shift 算法来执行图像分割。 我有以下代码:

import cv2
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets.samples_generator import make_blobs
import matplotlib.pyplot as plt
from itertools import cycle
from PIL import Image

image = Image.open('sample_images/fruit.png').convert('RGB')
image = np.array(image)

red = image[:,:,0]
green = image[:,:,1]
blue = image[:,:,2]

flat_red = red.flatten()
flat_green = green.flatten()
flat_blue = blue.flatten()

flattened = np.stack((flat_red, flat_green, flat_blue))

ms_clf = MeanShift(bin_seeding=True)
ms_labels = ms_clf.fit_predict(flattened)
plt.imshow(np.reshape(ms_labels, [1001, 994]))

我有一个尺寸为 3x994994 的扁平颜色矩阵,所以总共有 2984982 个样本。

print(flattened.shape)
(3, 994994)

print(flattened)
[[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]]

这个扁平矩阵用作 MeanShift fit_predict() 函数的输入。 当我尝试打印 fit_predict() 返回的标签数组时,我得到以下输出:

print(ms_labels)
[0 1 2]

fit_predict() 函数不是为每个数据样本返回一个标签吗? 为什么我只得到一个包含 3 个元素的数组? 任何见解表示赞赏。

我对MeanShift并不是很熟悉,但是fit_predict()的文档说它以形状X(n_samples,n_features)作为输入,并返回形状标签(n_samples,)。 由于您正在输入3x994994数组,其中n_samples = 3,n_features = 994994,因此,如您所见,这意味着标签将是(3,)数组。 本质上是将每个图像通道视为一个数据扁平化。 您是否要让它标记每个功能?

fit_predict() 的文档说它以 X 的形状 (n_samples, n_features) 作为输入,并返回形状 (n_samples,) 的标签。 由于您正在输入一个 3x994994 数组,其中 n_samples=3 和 n_features=994994,这意味着标签将是一个 (3,) 数组,如您所见。 它本质上是将“扁平化”中的每个图像通道视为一个数据。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM