簡體   English   中英

Scikit-learn:理解 MeanShift fit_predict()

[英]Scikit-learn: Understanding MeanShift fit_predict()

我正在使用 scikit-learn 的 Mean Shift 算法來執行圖像分割。 我有以下代碼:

import cv2
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets.samples_generator import make_blobs
import matplotlib.pyplot as plt
from itertools import cycle
from PIL import Image

image = Image.open('sample_images/fruit.png').convert('RGB')
image = np.array(image)

red = image[:,:,0]
green = image[:,:,1]
blue = image[:,:,2]

flat_red = red.flatten()
flat_green = green.flatten()
flat_blue = blue.flatten()

flattened = np.stack((flat_red, flat_green, flat_blue))

ms_clf = MeanShift(bin_seeding=True)
ms_labels = ms_clf.fit_predict(flattened)
plt.imshow(np.reshape(ms_labels, [1001, 994]))

我有一個尺寸為 3x994994 的扁平顏色矩陣,所以總共有 2984982 個樣本。

print(flattened.shape)
(3, 994994)

print(flattened)
[[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]]

這個扁平矩陣用作 MeanShift fit_predict() 函數的輸入。 當我嘗試打印 fit_predict() 返回的標簽數組時,我得到以下輸出:

print(ms_labels)
[0 1 2]

fit_predict() 函數不是為每個數據樣本返回一個標簽嗎? 為什么我只得到一個包含 3 個元素的數組? 任何見解表示贊賞。

我對MeanShift並不是很熟悉,但是fit_predict()的文檔說它以形狀X(n_samples,n_features)作為輸入,並返回形狀標簽(n_samples,)。 由於您正在輸入3x994994數組,其中n_samples = 3,n_features = 994994,因此,如您所見,這意味着標簽將是(3,)數組。 本質上是將每個圖像通道視為一個數據扁平化。 您是否要讓它標記每個功能?

fit_predict() 的文檔說它以 X 的形狀 (n_samples, n_features) 作為輸入,並返回形狀 (n_samples,) 的標簽。 由於您正在輸入一個 3x994994 數組,其中 n_samples=3 和 n_features=994994,這意味着標簽將是一個 (3,) 數組,如您所見。 它本質上是將“扁平化”中的每個圖像通道視為一個數據。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM