1 Star 0 Fork 246

是大伦伦呀/OPTIMAL_KNN_MNIST_QUESTION_2

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
pinecone_example.py 5.52 KB
一键复制 编辑 原始数据 按行查看 历史
是大伦伦呀 提交于 2024-09-18 08:18 . s
import logging
from pinecone import Pinecone,ServerlessSpec
# 从 scikit-learn 库中导入 load_digits 函数
# 这个函数用于加载著名的手写数字数据集 MNIST
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt # 用于绘图
import numpy as np # 用于数值计算
from collections import Counter # 用于计数
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
def create_index():
# 连接到 Pinecone 服务器
pinecone = Pinecone(api_key="d2e52f4b-93f3-4b8f-9369-1c00dbc2052d")
# 索引名称
index_name = "mnist-index"
# 获取现有索引列表
existing_indexes = pinecone.list_indexes()
# 检查索引是否存在,如果存在就删除
# 这个if是否需要,要看情况而定
# 比如有的时候,如果不想要重复删除在创建,这个if就可以不要
if any(index['name'] == index_name for index in existing_indexes):
logging.info(f"索引 '{index_name}' 已存在,正在删除...")
pinecone.delete_index(index_name)
logging.info(f"索引 '{index_name}' 已成功删除。")
else:
logging.info(f"索引 '{index_name}' 不存在,将创建新索引。")
# 创建新索引
logging.info(f"正在创建新索引 '{index_name}'...")
pinecone.create_index(
name=index_name,
dimension=64, # MNIST 每个图像展平后是一个 64 维向量
metric="euclidean", # 使用欧氏距离
spec=ServerlessSpec(
cloud="aws",
region="us-east-1"
)
)
logging.info(f"索引 '{index_name}' 创建成功。")
def load_pinecone():
# 连接到 Pinecone 服务器
pinecone = Pinecone(api_key="d2e52f4b-93f3-4b8f-9369-1c00dbc2052d")
# 索引名称
index_name = "mnist-index"
# 连接到索引
index = pinecone.Index(index_name)
logging.info(f"已成功连接到索引 '{index_name}'。")
return index
create_index()
index = load_pinecone()
# 使用 load_digits 函数加载 MNIST 数据集
# n_class=10 表示加载全部 10 个数字类别(0-9)
digits = load_digits(n_class=10)
# 获取数据集中的特征数据
# X 是一个二维数组,每行代表一个样本,每个样本是一个 64 维的向量(8x8 像素展平)
X = digits.data
# 获取数据集中的标签
# y 是一个一维数组,包含每个样本对应的真实数字标签(0-9)
y = digits.target
# 初始化一个空列表,用于存储转换后的向量数据
vectors = []
# 遍历所有样本,将数据转换为 Pinecone 可接受的格式
for i in range(len(X)):
# 使用样本的索引作为向量的唯一标识符
vector_id = str(i)
# 将 NumPy 数组转换为 Python 列表
# Pinecone 要求输入数据为 Python 列表格式
vector_values = X[i].tolist()
# 创建元数据字典,包含该样本的真实标签
# 将标签转换为整数类型,确保数据类型的一致性
metadata = {"label": int(y[i])}
# 将转换后的数据(ID、向量值、元数据)作为元组添加到 vectors 列表中
vectors.append((vector_id, vector_values, metadata))
# 定义批处理大小,每批最多包含 1000 个向量
# 这是为了避免一次性向 Pinecone 发送过多数据,可能导致请求超时或失败
batch_size = 1000
# 使用步长为 batch_size 的 range 函数,实现分批处理
for i in range(0, len(vectors), batch_size):
# 从 vectors 列表中切片获取一批数据
batch = vectors[i:i + batch_size]
# 使用 upsert 方法将这批数据上传到 Pinecone 索引中
# upsert 操作会插入新的向量或更新已存在的向量
index.upsert(batch)
# 创建一个手写数字 3 的图像
# 使用 numpy 数组表示一个 8x8 的二维图像
# 255 表示白色像素,0 表示黑色像素
digit_3 = np.array(
[[0, 0, 255, 255, 255, 255, 0, 0],
[0, 0, 0, 0, 0, 255, 0, 0],
[0, 0, 0, 0, 0, 255, 0, 0],
[0, 0, 0, 255, 255, 255, 0, 0],
[0, 0, 0, 0, 0, 255, 0, 0],
[0, 0, 0, 0, 0, 255, 0, 0],
[0, 0, 0, 0, 0, 255, 0, 0],
[0, 0, 255, 255, 255, 255, 0, 0]]
)
# 将图像像素值从 0-255 的范围缩放到 0-16 的范围
# 这是为了匹配 MNIST 数据集中使用的像素值范围
digit_3_flatten = (digit_3 / 255.0) * 16
# 将二维图像数组展平成一维列表
# 这是因为 Pinecone 要求输入向量是一维的
query_data = digit_3_flatten.ravel().tolist()
# 使用准备好的查询向量在 Pinecone 索引中执行搜索
results = index.query(
vector=query_data,
top_k=11, # 返回距离最近的 11 个结果
include_metadata=True # 同时返回每个向量的元数据(包括标签)
)
# 从搜索结果中提取每个匹配项的标签
labels = [match['metadata']['label'] for match in results['matches']]
# 打印每个匹配结果的详细信息
for match, label in zip(results['matches'], labels):
print(f"id: {match['id']}, distance: {match['score']}, label: {label}")
# 使用投票机制确定最终的分类结果
# Counter().most_common(1) 返回出现次数最多的元素
# [0][0] 获取该元素的值(即预测的数字)
final_prediction = Counter(labels).most_common(1)[0][0]
# 使用 matplotlib 显示查询图像和预测结果
plt.imshow(digit_3, cmap='gray') # 显示灰度图像
plt.title(f"Predicted digit: {final_prediction}", size=15) # 设置标题,显示预测结果
plt.axis('off') # 关闭坐标轴
plt.show() # 展示图像
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/its-da-lunlun/optimal_knn_mnist_question_2.git
git@gitee.com:its-da-lunlun/optimal_knn_mnist_question_2.git
its-da-lunlun
optimal_knn_mnist_question_2
OPTIMAL_KNN_MNIST_QUESTION_2
main

搜索帮助