1 Star 2 Fork 0

jacinth2006/机器学习常见算法及演示

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
utils.py 2.62 KB
一键复制 编辑 原始数据 按行查看 历史
jacinth2006 提交于 2021-09-05 00:19 . 图注意力网络进行节点分类
#%%
import pickle
import scipy.sparse as sp
import numpy as np
import os
import scipy.sparse
import networkx as nx
import matplotlib.pyplot as plt
import tensorflow as tf
from scipy.sparse.linalg.eigen.arpack import eigsh
def parse_test_index(path):
index=[]
for line in open(path,"rb"):
index.append(int(line.strip()))
return index
def load_data(path,file_suf):
"""
cora数据集格式
ind.cora.x:训练实例的特征向量,是scipy.sparse.csr.csr_matrix类对象,shape:(140, 1433)
ind.cora.tx:测试实例的特征向量,shape:(1000, 1433)
ind.cora.y:训练实例的标签,独热编码,numpy.ndarray类的实例,是numpy.ndarray对象,shape:(140, 7)
ind.cora.ty:测试实例的标签,独热编码,numpy.ndarray类的实例,shape:(1000, 7)
ind.cora.allx:有标签的+无无标签训练实例的特征向量,是ind.dataset_str.x的超集,shape:(1708, 1433)
ind.cora.ally:对应于ind.dataset_str.allx的标签,独热编码,shape:(1708, 7)
ind.cora.graph => 图数据,collections.defaultdict类的实例,格式为 {index:[index_of_neighbor_nodes]}
allx——1708个点,tx-1000个点,总共2708个点。
ind.cora.test.index => 测试实例tx的id,1000行
Args:
path ([type]): [description]
file_suf ([type]): [description]
Returns:
[type]: [description]
"""
objects=[]
flag=["x","y","tx","ty","allx","ally","graph"]
for fl in flag:
file_name="ind.{}.{}".format(file_suf,fl)
file_path=os.path.join(path,file_name)
with open(file_path,"rb") as f:
objects.append(pickle.load(f,encoding='latin1'))
x,y,tx,ty,allx,ally,graph=tuple(objects)
file_name="ind.{}.test.index".format(file_suf)
file_path=os.path.join(path,file_name)
test_index=parse_test_index(file_path)
test_index_range=sorted(test_index)
#test index是无序的,对序号排序后,将feature和label按排序的序号重新调整顺序排列
feature=sp.vstack((allx,tx))
feature[test_index]=feature[test_index_range]
labels=np.vstack((ally,ty))
labels[test_index]=labels[test_index_range]
#构造邻接矩阵
g=nx.from_dict_of_lists(graph)
adj=nx.adjacency_matrix(g)
train_mask=np.zeros(labels.shape[0])
train_mask[range(len(y))]=1
val_mask=np.zeros(labels.shape[0])
val_mask[range(len(y),len(y)+500)]=1
test_mask=np.zeros(labels.shape[0])
test_mask[test_index_range]=1
return adj,feature,labels,train_mask,val_mask,test_mask
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/jacinth2006/ML.git
git@gitee.com:jacinth2006/ML.git
jacinth2006
ML
机器学习常见算法及演示
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385