1 Star 0 Fork 246

姚杰/OPTIMAL_KNN_MNIST_QUESTION_2

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
FaissKNeighbors.py .py 1.46 KB
一键复制 编辑 原始数据 按行查看 历史
姚杰 提交于 2024-10-08 09:42 . 0
import numpy as np
import faiss
class FaissKNeighbors:
"""
使用 FAISS 库实现的 K-近邻算法。
"""
def __init__(self, k: int = 1, res=None):
"""
初始化 FaissKNeighbors 类。
:param k: 最近邻个数。
:param res: FAISS GPU资源对象。
"""
self.index = None
self.y = None
self.k = k
self.res = res
def fit(self, X: np.ndarray, y: np.ndarray):
"""
训练模型。
:param X: 训练数据的特征。
:param y: 训练数据的标签。
"""
self.index = faiss.IndexFlatL2(self.res) if self.res is not None else faiss.IndexFlatL2()
self.y = y
self.index.add(X.astype(np.float32))
def predict(self, X: np.ndarray) -> np.ndarray:
"""
对新的数据集 X 进行分类预测。
:param X: 需要预测的数据。
:return: 预测的标签。
"""
distances, indices = self.index.search(X.astype(np.float32), self.k)
votes = self.y[indices]
predictions = np.array([np.argmax(np.bincount(vote)) for vote in votes])
return predictions
def score(self, X: np.ndarray, y_true: np.ndarray) -> float:
"""
计算预测准确率。
:param X: 需要预测的数据。
:param y_true: 真实的标签。
:return: 准确率。
"""
predictions = self.predict(X)
return np.mean(predictions == y_true)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/asdaxas/optimal_knn_mnist_question_2.git
git@gitee.com:asdaxas/optimal_knn_mnist_question_2.git
asdaxas
optimal_knn_mnist_question_2
OPTIMAL_KNN_MNIST_QUESTION_2
main

搜索帮助