How to plot a pie chart of the KMeans clusters of image colors
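I'm new to computer vision and I'm trying to make a plot that summarizes the colors recognized in an image. I use KMeans to find the pixel values to detect, but I would like to improve the code with a visualization such as a color palette or a pie chart. I think I should count the pixel values that fall into each cluster, but I'm stuck and don't know how to do it.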
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import numpy as np
from collections import Counter
from skimage import io, color

def kmeans_segementation(image, n_clusters, random_state=0):
    rows, cols, channels = image.shape  # get the image shape
    X = image.reshape(rows * cols, channels)
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state).fit(X)
    labels = kmeans.labels_.reshape(rows, cols)  # cluster index of every pixel
    centers = kmeans.cluster_centers_  # get group of colors
    ordered_colors = [centers[key] for key, value in enumerate(centers)]
    print(ordered_colors)

img2_rgb = io.imread('frame41.png')
n_clusters = 15
kmeans_segementation(img2_rgb, n_clusters, random_state=0)
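Counting the pixels that fall into each cluster comes down to counting how often each label occurs in `kmeans.labels_`. The sketch below does that with `np.bincount` and sorts the clusters from most to least frequent. It is a minimal sketch under stated assumptions: `cluster_proportions` is a hypothetical helper, and a small synthetic image stands in for `frame41.png` so the code runs on its own.

```python
import numpy as np
from sklearn.cluster import KMeans


def cluster_proportions(image, n_clusters, random_state=0):
    """Hypothetical helper: return the cluster centers and the fraction of
    pixels in each cluster, sorted from most to least frequent."""
    rows, cols, channels = image.shape
    X = image.reshape(rows * cols, channels).astype(float)
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=random_state).fit(X)
    # minlength gives every cluster a slot, even if no pixel ended up in it
    counts = np.bincount(kmeans.labels_, minlength=n_clusters)
    order = np.argsort(counts)[::-1]  # most frequent cluster first
    return kmeans.cluster_centers_[order], counts[order] / counts.sum()


# Synthetic 10x10 RGB image (assumption): 70 red pixels and 30 blue pixels
image = np.zeros((10, 10, 3), dtype=np.uint8)
image[:7, :, 0] = 255  # top 7 rows red
image[7:, :, 2] = 255  # bottom 3 rows blue

centers, proportions = cluster_proportions(image, n_clusters=2)
print(np.round(centers).astype(int))  # dominant color first
print(proportions)
```

The sorted `centers` and `proportions` can be passed straight to `plt.pie(proportions, colors=centers / 255)`; the division is needed because `io.imread` returns 0-255 integers, while Matplotlib expects RGB values in [0, 1].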
You can plot the cluster sizes as a pie chart or as a histogram (bar chart), as follows:
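import urllib.request
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import numpy as np
# image_name = 'motorcycle_left.png'
# image_name = 'coffee.png'
image_name = 'astronaut.png'
url = 'https://raw.githubusercontent.com/scikit-image/scikit-image/main/skimage/data/' + image_name
# recent Matplotlib versions no longer accept a URL string in plt.imread, so open the URL first;
# for PNG files plt.imread returns floats in [0, 1], which Matplotlib accepts directly as colors
with urllib.request.urlopen(url) as f:
    image = plt.imread(f, format='png')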
n_clusters = 7
rows, cols, channels = image.shape # get the image shape
X = image.reshape(rows * cols, channels)
kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(X)
labels = kmeans.labels_
centers = kmeans.cluster_centers_ # get group of colors
unique_labels, counts = np.unique(labels, return_counts=True) # count the labels
label_order = np.argsort(counts)[::-1] # descending order
unique_labels = unique_labels[label_order]
counts = counts[label_order]
percentages = counts / counts.sum() * 100
fig, (ax1, ax2, ax3) = plt.subplots(ncols=3, figsize=(16, 5))
ax1.imshow(image)
ax1.axis('off')
color_labels = [f'{label}:\n{perc:.1f} %' for label, perc in zip(unique_labels, percentages)]
ax2.pie(counts, labels=color_labels, colors=centers[unique_labels])
bars = ax3.bar(unique_labels.astype(str), counts, color=centers[unique_labels], ec='black')
ax3.bar_label(bars, [f'{perc:.1f} %' for perc in percentages])
for spine in ['top', 'right']:
    ax3.spines[spine].set_visible(False)
plt.tight_layout()
plt.show()
Here is an example with 7 colors for the flag of South Africa.
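The question also mentions a color palette. One option is a palette strip in which each cluster color takes a horizontal share proportional to its pixel count; this can be built with plain NumPy and shown with `imshow`. This is a minimal sketch: `palette_strip` is a hypothetical helper, and the `centers`/`proportions` values are made up for illustration. With the answer's variables you would pass `centers[unique_labels]` and `percentages / 100` instead.

```python
import numpy as np
import matplotlib.pyplot as plt


def palette_strip(centers, proportions, width=300, height=50):
    """Hypothetical helper: build a (height, width, channels) image where each
    color fills a horizontal share proportional to its pixel fraction."""
    centers = np.asarray(centers, dtype=float)
    # cumulative boundaries of each color block, rounded to whole pixel columns
    edges = np.round(np.cumsum(np.concatenate([[0.0], proportions])) * width).astype(int)
    edges[-1] = width  # guard against rounding leaving the last column empty
    strip = np.zeros((height, width, centers.shape[1]))
    for color, start, stop in zip(centers, edges[:-1], edges[1:]):
        strip[:, start:stop] = color
    return strip


# Illustrative values in the [0, 1] range that plt.imread returns for PNG files
centers = np.array([[0.0, 0.45, 0.2], [1.0, 1.0, 1.0], [0.9, 0.1, 0.1]])
proportions = np.array([0.5, 0.3, 0.2])

strip = palette_strip(centers, proportions)
fig, ax = plt.subplots(figsize=(6, 1))
ax.imshow(strip)
ax.axis('off')
plt.show()
```

Passing the clusters in descending order of size, as the answer already does, puts the dominant color on the left of the strip.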