代码拉取完成,页面将自动刷新
#%%
import pickle
import scipy.sparse as sp
import numpy as np
import os
import scipy.sparse
import networkx as nx
import matplotlib.pyplot as plt
import tensorflow as tf
from scipy.sparse.linalg.eigen.arpack import eigsh
def parse_test_index(path):
    """Read an index file that stores one integer per line.

    Args:
        path: Path of the index file (e.g. ``ind.cora.test.index``).

    Returns:
        list[int]: The parsed integers, in the order they appear in the file.

    Raises:
        ValueError: If a non-blank line cannot be parsed as an integer.
    """
    # The context manager guarantees the file handle is closed.
    with open(path, "rb") as index_file:
        stripped_lines = (line.strip() for line in index_file)
        # Blank lines (typically a trailing empty line) carry no index; skip
        # them rather than letting int() fail on an empty value.
        return [int(line) for line in stripped_lines if line]
def load_data(path, file_suf, val_size=500):
    """Load an ``ind.<file_suf>.*`` citation dataset (e.g. cora) from ``path``.

    Files read (the shapes quoted below are those of the cora dataset):

    * ``ind.cora.x``: feature vectors of the labelled training instances,
      a ``scipy.sparse`` CSR matrix, shape (140, 1433).
    * ``ind.cora.tx``: feature vectors of the test instances,
      shape (1000, 1433).
    * ``ind.cora.y``: one-hot labels of the training instances,
      a ``numpy.ndarray``, shape (140, 7).
    * ``ind.cora.ty``: one-hot labels of the test instances,
      a ``numpy.ndarray``, shape (1000, 7).
    * ``ind.cora.allx``: feature vectors of labelled and unlabelled training
      instances, a superset of ``x``, shape (1708, 1433).
    * ``ind.cora.ally``: one-hot labels matching ``allx``, shape (1708, 7).
    * ``ind.cora.graph``: the graph as a ``collections.defaultdict`` of the
      form ``{index: [index_of_neighbor_nodes]}``.
    * ``ind.cora.test.index``: ids of the test instances in ``tx``,
      one per line (1000 lines).

    ``allx`` (1708 nodes) plus ``tx`` (1000 nodes) give 2708 nodes in total.

    Args:
        path: Directory that contains the dataset files.
        file_suf: Dataset name used inside the file names, e.g. ``"cora"``.
        val_size: Number of instances, taken directly after the ``len(y)``
            training instances, that are flagged in the validation mask.
            Defaults to 500.

    Returns:
        tuple: ``(adj, feature, labels, train_mask, val_mask, test_mask)``.
        ``adj`` is the adjacency matrix produced by ``networkx``; ``feature``
        is a sparse CSR matrix with one row per node; ``labels`` is a dense
        array with one row per node; the three masks are 1-D float arrays of
        length ``labels.shape[0]`` holding 1 at selected nodes and 0 elsewhere.
    """
    parts = ("x", "y", "tx", "ty", "allx", "ally", "graph")
    loaded = []
    for part in parts:
        part_path = os.path.join(path, "ind.{}.{}".format(file_suf, part))
        with open(part_path, "rb") as f:
            # NOTE: unpickling can execute arbitrary code; only load dataset
            # files that come from a trusted source.
            loaded.append(pickle.load(f, encoding='latin1'))
    x, y, tx, ty, allx, ally, graph = loaded

    index_path = os.path.join(path, "ind.{}.test.index".format(file_suf))
    test_index = parse_test_index(index_path)
    test_index_range = sorted(test_index)

    # The test indices are stored unordered: after sorting them, rearrange
    # the feature and label rows according to the sorted order.
    # Row assignment is done on a LIL matrix, which supports it cheaply
    # whatever format vstack produced; the result is handed back as CSR.
    feature = sp.vstack((allx, tx)).tolil()
    feature[test_index, :] = feature[test_index_range, :]
    feature = feature.tocsr()

    labels = np.vstack((ally, ty))
    labels[test_index, :] = labels[test_index_range, :]

    # Build the adjacency matrix from the neighbour-list dictionary.
    g = nx.from_dict_of_lists(graph)
    adj = nx.adjacency_matrix(g)

    num_nodes = labels.shape[0]
    num_train = len(y)

    # Training nodes are the first len(y) rows.
    train_mask = np.zeros(num_nodes)
    train_mask[:num_train] = 1

    # Validation nodes are the val_size rows right after the training rows.
    val_mask = np.zeros(num_nodes)
    val_mask[num_train:num_train + val_size] = 1

    # Test nodes are the rows listed in the test index file.
    test_mask = np.zeros(num_nodes)
    test_mask[test_index_range] = 1

    return adj, feature, labels, train_mask, val_mask, test_mask
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。