代码拉取完成,页面将自动刷新
同步操作将从 mynameisi/OPTIMAL_KNN_MNIST_QUESTION 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
from pinecone import Pinecone, ServerlessSpec
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt
import numpy as np
from collections import Counter
# 初始化 Pinecone 客户端
pinecone = Pinecone(api_key="fd2bfa26-9d55-4d2b-9605-59814f89c7d2")
index_name = "mnist-index"
# 获取现有索引列表
existing_indexes = pinecone.list_indexes()
# 检查索引是否存在并删除
if any(index['name'] == index_name for index in existing_indexes):
print(f"索引 '{index_name}' 已存在,正在删除...")
pinecone.delete_index(index_name)
print(f"索引 '{index_name}' 已成功删除。")
else:
print(f"索引 '{index_name}' 不存在,将创建新索引。")
# 创建新索引
print(f"正在创建新索引 '{index_name}'...")
pinecone.create_index(
name=index_name,
dimension=64, # MNIST 每个图像展平后是一个 64 维向量
metric="euclidean", # 使用欧氏距离
spec=ServerlessSpec(
cloud="aws",
region="us-east-1"
)
)
print(f"索引 '{index_name}' 创建成功。")
# 连接到索引
index = pinecone.Index(index_name)
print(f"已成功连接到索引 '{index_name}'。")
# 加载 MNIST 数据集
digits = load_digits(n_class=10)
X = digits.data # 特征数据
y = digits.target # 标签数据
# 转换数据为 Pinecone 可接受的格式
vectors = []
for i in range(len(X)):
vector_id = str(i) # 使用索引作为唯一标识符
vector_values = X[i].tolist() # 转换为列表
metadata = {"label": int(y[i])} # 元数据包含标签
vectors.append((vector_id, vector_values, metadata))
# 定义批处理大小
batch_size = 1000
for i in range(0, len(vectors), batch_size):
batch = vectors[i:i + batch_size] # 获取一批数据
index.upsert(batch) # 上传到 Pinecone
# 创建手写数字 3 的图像
digit_3 = np.array(
[[0, 0, 255, 255, 255, 255, 0, 0],
[0, 0, 0, 0, 0, 255, 0, 0],
[0, 0, 0, 0, 0, 255, 0, 0],
[0, 0, 0, 255, 255, 255, 0, 0],
[0, 0, 0, 0, 0, 255, 0, 0],
[0, 0, 0, 0, 0, 255, 0, 0],
[0, 0, 0, 0, 0, 255, 0, 0],
[0, 0, 255, 255, 255, 255, 0, 0]]
)
# 将图像像素值缩放并展平
digit_3_flatten = (digit_3 / 255.0) * 16
query_data = digit_3_flatten.ravel().tolist() # 转换为一维列表
# 在 Pinecone 索引中执行搜索
results = index.query(
vector=query_data,
top_k=11, # 返回距离最近的 11 个结果
include_metadata=True # 包含元数据
)
# 提取匹配项的标签
labels = [match['metadata']['label'] for match in results['matches']]
for match, label in zip(results['matches'], labels):
print(f"id: {match['id']}, distance: {match['score']}, label: {label}")
# 使用投票机制确定最终分类结果
# 检查 labels 是否为空
if labels:
final_prediction = Counter(labels).most_common(1)[0][0]
else:
final_prediction = "没有找到匹配项"
# 显示查询图像和预测结果
plt.imshow(digit_3, cmap='gray')
plt.title(f"Predicted digit: {final_prediction}", size=15)
plt.axis('off')
plt.show()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。