I have the following class:
class MyKMeans:
    def __init__(self, max_iter=300):
        self.max_iter = max_iter
        # Directly access
        self.centroids = None
        self.clusters = None

    def fit(self, X, k):
        """
        """
        # each point is assigned to a cluster
        clusters = np.zeros(X.shape[0])
        # select k random centroids
        random_idxs = np.random.choice(len(X), size=k, replace=False)
        centroids = X[random_idxs, :]
        # iterate until no change occurs in centroids
        while True:
            # for each point
            for i, point in enumerate(X):
                min_d = float('inf')
                # find the closest centroid to the point
                for idx, centroid in enumerate(centroids):
                    d = euclidean_dist(centroid, point)
                    if d < min_d:
                        min_d = d
                        clusters[i] = idx
            # update the new centroids by averaging the points in each cluster
            new_centroids = pd.DataFrame(X).groupby(by=clusters).mean().values
            # if the centroids didn't change, then stop
            if np.count_nonzero(centroids-new_centroids) == 0:
                break
            # otherwise, update the centroids
            else:
                centroids = new_centroids
        self.centroids = centroids
        self.clusters = clusters
and run it using
k = 4
kmeans = MyKMeans()
kmeans.fit(X, k)
centroids, clusters = kmeans.centroids, kmeans.clusters
However, this usually takes 5 seconds to run. On the other hand, if I move the method to a new function,
def fit(X, k):
    """
    """
    # each point is assigned to a cluster
    clusters = np.zeros(X.shape[0])
    # select k random centroids
    random_idxs = np.random.choice(len(X), size=k, replace=False)
    centroids = X[random_idxs, :]
    # iterate until no change occurs in centroids
    while True:
        # for each point
        for i, point in enumerate(X):
            min_d = float('inf')
            # find the closest centroid to the point
            for idx, centroid in enumerate(centroids):
                d = euclidean_dist(centroid, point)
                if d < min_d:
                    min_d = d
                    clusters[i] = idx
        # update the new centroids by averaging the points in each cluster
        new_centroids = pd.DataFrame(X).groupby(by=clusters).mean().values
        # if the centroids didn't change, then stop
        if np.count_nonzero(centroids-new_centroids) == 0:
            break
        # otherwise, update the centroids
        else:
            centroids = new_centroids
        return centroids, clusters
and get the same variables by calling centroids, clusters = fit(X, k), the runtime is around 0.5-1 second, which is a big difference.
Is there a reason why simply having a class method instead of a function causes such a big difference in runtime, and is there any way to improve the runtime while still being able to use the class?
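The return statement in the non-class version is inside the while loop, so it exits the loop early. The function therefore returns after the first pass, whereas the class method keeps iterating until the centroids stop changing, so the two versions are not doing the same amount of work.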
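To make the effect of the misplaced return concrete, here is a minimal sketch that counts loop passes for both placements. It is not the asker's code: the helper names (assign, update, fit_full, fit_early_return), the toy data, and the fixed starting centroids are all assumptions chosen so the result is deterministic, and the distance step is vectorized with NumPy instead of calling euclidean_dist.

```python
import numpy as np


def assign(X, centroids):
    # Squared Euclidean distance from every point to every centroid,
    # then the index of the nearest centroid for each point.
    d = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return d.argmin(axis=1)


def update(X, labels, centroids):
    # Mean of the points in each cluster; an empty cluster keeps its old centroid.
    return np.array([
        X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
        for j in range(len(centroids))
    ])


def fit_full(X, centroids):
    # Return placed after the loop: runs until the centroids stop changing.
    iterations = 0
    while True:
        iterations += 1
        labels = assign(X, centroids)
        new_centroids = update(X, labels, centroids)
        if np.array_equal(new_centroids, centroids):
            break
        centroids = new_centroids
    return centroids, labels, iterations


def fit_early_return(X, centroids):
    # Return placed inside the loop: leaves after the first pass
    # unless the centroids happened to converge immediately.
    iterations = 0
    while True:
        iterations += 1
        labels = assign(X, centroids)
        new_centroids = update(X, labels, centroids)
        if np.array_equal(new_centroids, centroids):
            break
        centroids = new_centroids
        return centroids, labels, iterations
    return centroids, labels, iterations


# Hypothetical toy data: two obvious groups on a line, with a poor starting guess.
X = np.array([[0.0], [1.0], [2.0], [10.0], [11.0], [12.0]])
init = X[[0, 1]].copy()

full_centroids, _, n_full = fit_full(X, init)
early_centroids, _, n_early = fit_early_return(X, init)
print(n_full, n_early)  # 3 1
```

On this data the early-return version stops with centroids that have not converged, which is why it looks faster. Moving the return out of the loop (or storing the results on self after the loop, as the class does) makes the timings comparable.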