From 92d5bdd7de3602ed314efec37a2b12343e2f7ed1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=A5=BF=E8=B4=9D1?=
 <12701564+xibei-1@user.noreply.gitee.com>
Date: Sun, 5 Nov 2023 10:19:06 +0000
Subject: [PATCH] update optimal_knn.py.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: 西贝1 <12701564+xibei-1@user.noreply.gitee.com>
---
 optimal_knn.py | 51 ++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 43 insertions(+), 8 deletions(-)

diff --git a/optimal_knn.py b/optimal_knn.py
index 53200c3..99550ed 100644
--- a/optimal_knn.py
+++ b/optimal_knn.py
@@ -1,15 +1,50 @@
-# TODO: 导入必要的库和模块
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.datasets import load_digits
+from sklearn.model_selection import train_test_split
+from sklearn.neighbors import KNeighborsClassifier
+import pickle
+from tqdm import tqdm # progress bar for the k sweep
 
-# TODO: 加载数字数据集
+# Load the digits dataset
+digits = load_digits()
 
-# TODO: 将数据集划分为训练集和测试集
+# Split the dataset into training and test sets (test_size=0.2, fixed random_state)
+X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.2, random_state=42)
 
-# TODO: 初始化变量以存储最佳准确率,相应的k值和最佳knn模型
+# Track the best accuracy seen so far, the k that produced it, and the fitted model
+best_accuracy = 0
+best_k = 0
+best_knn = None
 
-# TODO: 初始化一个列表以存储每个k值的准确率
+# Accuracy for every k value tried, appended in order of k
+accuracies = []
 
-# TODO: 尝试从1到40的k值,对于每个k值,训练knn模型,保存最佳准确率,k值和knn模型
+# Try k from 1 to 40: fit a KNN model for each k and keep the best accuracy, k, and model
+for k in tqdm(range(1, 41)): # tqdm wraps the range to show a progress bar
+    knn = KNeighborsClassifier(n_neighbors=k)
+    knn.fit(X_train, y_train)
+    accuracy = knn.score(X_test, y_test)  # NOTE(review): k is selected on the test split, so the reported best accuracy is likely optimistic
+    accuracies.append(accuracy)
+    if accuracy > best_accuracy:  # strict '>' keeps the smallest k on ties
+        best_accuracy = accuracy
+        best_k = k
+        best_knn = knn
 
-# TODO: 将最佳KNN模型保存到二进制文件
+# Save the best KNN model to a binary file with pickle
+with open('best_knn_model.pkl', 'wb') as file: # file name uses the .pkl suffix
+    pickle.dump(best_knn, file)
 
-# TODO: 打印最佳准确率和相应的k值
\ No newline at end of file
+# Print the best accuracy and the corresponding k value
+print("Best Accuracy:", best_accuracy)
+print("Best K:", best_k)
+
+# Plot accuracy against k, mark the best k, save the figure as a PDF, then show it
+plt.plot(range(1, 41), accuracies)
+plt.xlabel('K')
+plt.ylabel('Accuracy')
+plt.title('Accuracy vs K')
+plt.axvline(x=best_k, color='red', linestyle='--')
+plt.text(best_k, best_accuracy, f'({best_k}, {best_accuracy:.2f})', verticalalignment='bottom', horizontalalignment='right')
+plt.savefig('accuracy_plot.pdf') 
+plt.show()
\ No newline at end of file
-- 
Gitee