56 lines
1.5 KiB
Python
56 lines
1.5 KiB
Python
|
import numpy as np
|
||
|
import pandas as pd
|
||
|
import matplotlib.pyplot as plt
|
||
|
|
||
|
from k_means import KMeans
|
||
|
|
||
|
|
||
|
|
||
|
# Names of the two columns used as plotting axes and as clustering features.
x_axis, y_axis = 'petal_length', 'petal_width'

# Labels that the 'class' column is compared against when plotting by species.
iris_types = ['SETOSA', 'VERSICOLOR', 'VIRGINICA']

# Load the iris dataset; the path is resolved relative to the current
# working directory.
data = pd.read_csv('../data/iris.csv')
|
||
|
|
||
|
# Figure 1: the raw data, shown with and without the known class labels.
plt.figure(figsize=(12, 5))

# Left panel: one scatter series per iris type, labelled for the legend.
plt.subplot(1, 2, 1)
for iris_type in iris_types:
    # Build the row mask once and reuse it for both coordinates.
    is_type = data['class'] == iris_type
    plt.scatter(data.loc[is_type, x_axis], data.loc[is_type, y_axis], label=iris_type)
plt.title('label known')
plt.legend()

# Right panel: the same points plotted as a single unlabelled series.
plt.subplot(1, 2, 2)
plt.scatter(data[x_axis], data[y_axis])
plt.title('label unknown')
plt.show()
|
||
|
|
||
|
# Parameters required for training.
num_clusters = 3
max_iteritions = 50

# Training matrix: one row per example, two columns (x_axis, y_axis).
num_examples = len(data)
features = data[[x_axis, y_axis]]
x_train = features.values.reshape(num_examples, 2)

# Fit k-means. The result of train() is unpacked into the centroids and the
# closest-centroid ids (meaning inferred from the names; confirm against
# k_means.KMeans).
k_means = KMeans(x_train, num_clusters)
centroids, closest_centroids_ids = k_means.train(max_iteritions)
|
||
|
# Compare the known labels with the k-means result.
plt.figure(figsize=(12, 5))

# Left panel: one scatter series per iris type, labelled for the legend.
plt.subplot(1, 2, 1)
for iris_type in iris_types:
    # Build the row mask once and reuse it for both coordinates.
    is_type = data['class'] == iris_type
    plt.scatter(data.loc[is_type, x_axis], data.loc[is_type, y_axis], label=iris_type)
plt.title('label known')
plt.legend()

# Right panel: one scatter series per cluster id found by k-means.
plt.subplot(1, 2, 2)
for centroid_id, _ in enumerate(centroids):
    # Boolean mask of the examples assigned to this centroid.
    current_examples_index = (closest_centroids_ids == centroid_id).flatten()
    plt.scatter(
        data.loc[current_examples_index, x_axis],
        data.loc[current_examples_index, y_axis],
        label=centroid_id,
    )

# Mark each centroid with a black cross, in a separate pass after all
# cluster points have been added.
for centroid in centroids:
    plt.scatter(centroid[0], centroid[1], c='black', marker='x')
plt.legend()
plt.title('label kmeans')
plt.show()
|
||
|
|
||
|
|