diff --git a/2_knn/knn_classification.ipynb b/2_knn/knn_classification.ipynb index f646626c1e12d839457f69d857f8828c93904ee4..be225ab26bb2e74eb31880351f829354f5599c82 100644 --- a/2_knn/knn_classification.ipynb +++ b/2_knn/knn_classification.ipynb @@ -176,14 +176,26 @@ "\n", "\n", "# 绘制结果\n", - "plt.scatter(x_train[:,0], x_train[:,1], c=y_train, marker='.')\n", - "plt.title(\"train data\")\n", - "plt.savefig(\"knn_train_data.pdf\")\n", + "for i in range(split_index):\n", + " if y_train[i] == 0:\n", + " plt.scatter(x_train[i,0],x_train[i,1],c = 0, marker='.')\n", + " else:\n", + " plt.scatter(x_train[i,0],x_train[i,1],c = 1, marker='^') \n", + "plt.rcParams['figure.figsize']=(12.0, 8.0)\n", + "mpl.rcParams['font.family'] = 'SimHei'\n", + "plt.title(\"训练数据\")\n", + "plt.savefig(\"fig-res-train.pdf\")\n", "plt.show()\n", - "plt.scatter(x_test[:,0], x_test[:,1], c=y_test, marker='.')\n", - "plt.title(\"test data\")\n", - "plt.savefig(\"knn_test_data.pdf\")\n", - "plt.show()\n" + "\n", + "for i in range(data_size_all - split_index):\n", + " if y_test[i] == 0:\n", + " plt.scatter(x_test[i,0],x_test[i,1],c = 0, marker='.')\n", + " else:\n", + " plt.scatter(x_test[i,0],x_test[i,1],c = 1, marker='^')\n", + "plt.rcParams['figure.figsize']=(12.0, 8.0)\n", + "plt.title(\"测试数据\")\n", + "plt.savefig(\"fig-res-test.pdf\")\n", + "plt.show()" ] }, { @@ -475,7 +487,9 @@ "for i in range(nplot):\n", " img = X_digits[i].reshape(8, 8)\n", " axes[i].imshow(img)\n", - " axes[i].set_title(y_digits[i])\n" + " axes[i].set_title(y_digits[i])\n", + "fig.set_size_inches(16,9)\n", + "fig.savefig('fig-res-digits.pdf')" ] }, { @@ -574,7 +588,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.5.4" + "version": "3.8.5" } }, "nbformat": 4,