1 Star 0 Fork 246

韩付坤/OPTIMAL_KNN_MNIST_QUESTION_2

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
pinecone_example.py 2.70 KB
一键复制 编辑 原始数据 按行查看 历史
韩付坤 提交于 2024-09-29 14:26 . add pinecone_example.py.
from pinecone import Pinecone, ServerlessSpec
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt
import numpy as np
from collections import Counter
# 初始化 Pinecone 客户端
api_key = "bf3d68e8-9501-461a-90fa-529c0da18661"
pinecone = Pinecone(api_key=api_key)
# 索引名称
INDEX_NAME = "quickstart"
# 检查并删除索引(根据情况决定是否保留此部分)
existing_indexes = pinecone.list_indexes()
if any(index['name'] == INDEX_NAME for index in existing_indexes):
print(f"索引 '{INDEX_NAME}' 已存在,正在删除...")
pinecone.delete_index(INDEX_NAME)
print(f"索引 '{INDEX_NAME}' 已成功删除。")
else:
print(f"索引 '{INDEX_NAME}' 不存在,将创建新索引。")
# 创建新索引
print(f"正在创建新索引 '{INDEX_NAME}'...")
pinecone.create_index(
name=INDEX_NAME,
dimension=64,
metric="euclidean",
spec=ServerlessSpec(
cloud="aws",
region="us-east-1"
)
)
print(f"索引 '{INDEX_NAME}' 创建成功。")
# 连接到索引
index = pinecone.Index(INDEX_NAME)
print(f"已成功连接到索引 '{INDEX_NAME}'。")
# 加载 MNIST 数据集
digits = load_digits(n_class=10)
X = digits.data
y = digits.target
# 转换数据为 Pinecone 可接受的格式
vectors = []
for i, (sample, label) in enumerate(zip(X, y)):
vector_id = str(i)
vector_values = sample.tolist()
metadata = {"label": int(label)}
vectors.append((vector_id, vector_values, metadata))
# 定义批处理大小并上传数据到 Pinecone 索引
BATCH_SIZE = 1000
for i in range(0, len(vectors), BATCH_SIZE):
batch = vectors[i:i + BATCH_SIZE]
index.upsert(batch)
# 创建查询图像
digit_3 = np.array(
[[0, 0, 255, 255, 255, 255, 0, 0],
[0, 0, 0, 0, 0, 255, 0, 0],
[0, 0, 0, 0, 0, 255, 0, 0],
[0, 0, 0, 255, 255, 255, 0, 0],
[0, 0, 0, 0, 0, 255, 0, 0],
[0, 0, 0, 0, 0, 255, 0, 0],
[0, 0, 0, 0, 0, 255, 0, 0],
[0, 0, 255, 255, 255, 255, 0, 0]]
)
digit_3_flatten = (digit_3 / 255.0) * 16
query_data = digit_3_flatten.ravel().tolist()
# 执行查询
results = index.query(
vector=query_data,
top_k=11,
include_metadata=True
)
# 提取标签并确定最终预测
labels = [match['metadata']['label'] for match in results['matches']]
print(f"Labels: {labels}")
if labels:
final_prediction = Counter(labels).most_common(1)[0][0]
else:
final_prediction = None
# 显示图像和预测结果
plt.imshow(digit_3, cmap='gray')
if final_prediction is not None:
plt.title(f"Predicted digit: {final_prediction}", size=15)
else:
plt.title("No prediction available", size=15)
plt.axis('off')
plt.show()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/fukun-han/optimal_knn_mnist_question_2.git
git@gitee.com:fukun-han/optimal_knn_mnist_question_2.git
fukun-han
optimal_knn_mnist_question_2
OPTIMAL_KNN_MNIST_QUESTION_2
main

搜索帮助