1 Star 0 Fork 246

是大伦伦呀/OPTIMAL_KNN_MNIST_QUESTION_2

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
pinecone_train.py 3.76 KB
一键复制 编辑 原始数据 按行查看 历史
是大伦伦呀 提交于 2024-09-18 08:18 . s
# 从 scikit-learn 库中导入 load_digits 函数
# 这个函数用于加载著名的手写数字数据集 MNIST
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from collections import Counter
from tqdm import tqdm
import time
import logging
from pinecone_example import load_pinecone,create_index
# 配置 logging,添加时间戳
logging.basicConfig(format='%(asctime)s - %(levelname)s-%(message)s',
level=logging.INFO)
# 加载数据集并划分训练集和测试集
def load_data():
# 加载数字数据集
digits = load_digits()
# 获取数据集中的特征数据
# X 是一个二维数组,每行代表一个样本,每个样本是一个 64 维的向量(8x8 像素展平)
X = digits.data
# 获取数据集中的标签
# y 是一个一维数组,包含每个样本对应的真实数字标签(0-9)
y = digits.target
# 划分数据集为训练集(80%)和测试集(20%)
X_train, X_test, y_train, y_test = train_test_split(X, y,
test_size=0.2, random_state=42)
return X_train, X_test, y_train, y_test
# 上传数据到 Pinecone 索引
def upload_data(index):
X,y = load_data()[0],load_data()[2]
# 初始化一个空列表,用于存储转换后的向量数据
vectors = []
# 遍历所有样本,将数据转换为 Pinecone 可接受的格式
for i in range(len(X)):
# 使用样本的索引作为向量的唯一标识符
vector_id = str(i)
# 将 NumPy 数组转换为 Python 列表
# Pinecone 要求输入数据为 Python 列表格式
vector_values = X[i].tolist()
# 创建元数据字典,包含该样本的真实标签
# 将标签转换为整数类型,确保数据类型的一致性
metadata = {"label": int(y[i])}
# 将转换后的数据(ID、向量值、元数据)作为元组添加到 vectors 列表中
vectors.append((vector_id, vector_values, metadata))
# 定义批处理大小,每批最多包含 1000 个向量
# 这是为了避免一次性向 Pinecone 发送过多数据,可能导致请求超时或失败
batch_size = 1000
with tqdm(total=2, desc="上传数据到Pinecone") as pbar:
# 使用步长为 batch_size 的 range 函数,实现分批处理
for i in range(0, len(vectors), batch_size):
# 从 vectors 列表中切片获取一批数据
batch = vectors[i:i + batch_size]
# 使用 upsert 方法将这批数据上传到 Pinecone 索引中
# upsert 操作会插入新的向量或更新已存在的向量
index.upsert(batch)
time.sleep(0.1)
pbar.update(1)
logging.info(f"{len(vectors)}数据上传完成。")
# 查询并计算准确率
def evaluate_model(index):
X_test, y_test = load_data()[1], load_data()[3]
correct_predictions = 0
total_samples = len(X_test)
for i in tqdm(range(total_samples), desc="测试k=11时的准确率", unit="个样本"):
query_data = X_test[i].tolist()
results = index.query(vector=query_data, top_k=11, include_metadata=True)
# 从搜索结果中提取每个匹配项的标签
labels = [match['metadata']['label'] for match in results['matches']]
if labels:
final_prediction = Counter(labels).most_common(1)[0][0]
if final_prediction == y_test[i]:
correct_predictions += 1
accuracy = correct_predictions / total_samples * 100
logging.info(f"k=11时的准确率: {accuracy:.2f}%")
def main():
index = load_pinecone()
upload_data(index)
evaluate_model(index)
if __name__ == '__main__':
create_index()
main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/its-da-lunlun/optimal_knn_mnist_question_2.git
git@gitee.com:its-da-lunlun/optimal_knn_mnist_question_2.git
its-da-lunlun
optimal_knn_mnist_question_2
OPTIMAL_KNN_MNIST_QUESTION_2
main

搜索帮助