代码拉取完成,页面将自动刷新
同步操作将从 付昌陇/MMBSSL 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import random
from operator import itemgetter
import numpy as np
import torch
def set_random_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs and request deterministic cuDNN.

    Args:
        seed: Seed value passed to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Spelling matters here: assigning a misspelled name (e.g. "determinstic")
    # silently creates an unrelated attribute and leaves cuDNN nondeterministic.
    torch.backends.cudnn.deterministic = True
def process_idx(loader):
    """Attach ``clique_batch`` to every batch produced by *loader*.

    Each batch is passed through ``add_Inbatch``, which sets
    ``batch.clique_batch`` and returns the batch.

    Args:
        loader: Iterable of batch objects accepted by ``add_Inbatch``.

    Returns:
        list: The annotated batches, materialised in iteration order.
    """
    # Delegate so the per-graph offset logic lives in exactly one place.
    return [add_Inbatch(batch) for batch in loader]
def add_Inbatch(batch):
    """Attach batch-global clique node indices to *batch* as ``clique_batch``.

    For every graph ``t`` in the batch, each row of ``batch[t].clique_idx`` is
    truncated to the length given by the matching entry of
    ``batch[t].clique_slice``. It is then shifted by the total number of nodes
    (``x.size(0)``) in graphs ``0 .. t-1``, so per-graph node indices become
    indices into the concatenated batch.

    The shifted indices are new tensors. The graphs' ``clique_idx`` storage is
    not modified, so calling this more than once yields the same result.

    Args:
        batch: Sized, indexable collection of graphs that accepts attribute
            assignment (presumably a PyG ``Batch`` -- confirm with caller).
            Each graph must expose ``x``, ``clique_idx`` and ``clique_slice``.

    Returns:
        The same *batch* object, with ``batch.clique_batch`` set to a flat
        list holding one index tensor per clique, in graph order.
    """
    clique_batch = []
    offset = 0  # number of nodes contributed by the graphs already processed
    # Access graphs via len()/[] only; the batch type's own iteration protocol
    # is deliberately not relied upon.
    for t in range(len(batch)):
        graph = batch[t]
        for members, size in zip(graph.clique_idx, graph.clique_slice):
            # Out-of-place add: ``members[:size]`` is a view, so ``add_`` would
            # also rewrite the graph's stored clique_idx (and double-shift on
            # a repeated call).
            clique_batch.append(members[:size] + offset)
        offset += graph.x.size(0)
    batch.clique_batch = clique_batch
    return batch
def get_dict(graphs, clique_dict, all_edges):
    """Select the clique and edge entries belonging to each graph.

    Args:
        graphs: Iterable of objects with an ``id`` attribute. The id is used
            as the key/index into both *clique_dict* and *all_edges*.
        clique_dict: Indexable container of per-graph clique lists.
        all_edges: Indexable container of per-graph edge lists.

    Returns:
        tuple: ``(cli_dict, edge_dict)``, two tuples aligned with *graphs*.
        Both are always tuples, even for zero or one graph.
        ``operator.itemgetter(*ids)`` would return the bare item for a single
        id and raise ``TypeError`` for none.
    """
    ids = [graph.id for graph in graphs]
    cli_dict = tuple(clique_dict[i] for i in ids)
    edge_dict = tuple(all_edges[i] for i in ids)
    return cli_dict, edge_dict
def readout(clique_dict, all_edges, graphs):
    """Convert each graph, in place, into its clique-level graph.

    For the ``index``-th graph ``g``:

    * ``g.x`` becomes one row per clique in ``clique_dict[index]``. Each row is
      the sum of the rows of the old ``g.x`` selected by that clique's node
      indices. The result is a CPU ``float32`` tensor, and an empty clique
      yields a zero row.
    * ``g.edge_index`` is built from ``all_edges[index]``, a sequence of
      ``(u, v)`` pairs of clique indices, as a ``(2, num_edges)`` ``long``
      tensor. An empty edge list gives shape ``(2, 0)``.
    * ``g.edge_attr`` is a ``(num_edges, 2)`` ``long`` tensor whose every row
      is ``[6, 3]``. It is always overwritten, so no stale atom-level edge
      features survive.
    * ``g.num_nodes`` is set to the number of cliques.

    Args:
        clique_dict: Sequence aligned with *graphs*. Each entry is an iterable
            of cliques, and each clique is an iterable of row indices into
            that graph's ``x``.
        all_edges: Sequence aligned with *graphs* of per-graph edge lists.
        graphs: Objects exposing a tensor ``x``. They are mutated in place.

    Returns:
        list: The same graph objects, in input order.
    """
    graphx = []
    for index, g in enumerate(graphs):
        atom_x = g.x.detach()
        # Start the sum from a zero vector (not Python's int 0) so an empty
        # clique produces a zero row instead of an int with no ``.cpu()``.
        zero = torch.zeros(atom_x.shape[1:], dtype=atom_x.dtype, device=atom_x.device)
        rows = [sum((atom_x[a] for a in clique), zero) for clique in clique_dict[index]]
        if rows:
            clique_x = torch.stack(rows)
        else:
            clique_x = atom_x.new_zeros((0, *atom_x.shape[1:]))
        g.x = clique_x.cpu().float()

        # (E, 2) pairs -> (2, E); reshape keeps the empty case two-dimensional.
        g.edge_index = (
            torch.tensor(all_edges[index], dtype=torch.long).reshape(-1, 2).t().contiguous()
        )
        # Every clique-graph edge gets the same constant feature row [6, 3].
        # NOTE(review): the meaning of 6 / 3 is not established here -- confirm
        # against the edge-feature vocabulary used by the model.
        g.edge_attr = torch.tensor([6, 3], dtype=torch.long).repeat(g.edge_index.size(1), 1)

        g.num_nodes = g.x.shape[0]
        graphx.append(g)
    return graphx
def read_book(cli_path, edge_path):
    """Load per-molecule clique lists and clique-graph edge lists from text files.

    Edge file (*edge_path*): one line per molecule, holding a Python literal
    sequence of ``(u, v)`` pairs. Each pair is converted to a ``[u, v]`` list.

    Clique file (*cli_path*): one line per clique, in the form
    ``"<mol> <clique_no> /<literal>"``. Cliques are grouped per molecule.
    A new group starts whenever ``<mol>`` differs from the previous line's
    value, and the running value starts at 0. Lines are therefore expected to
    be ordered by molecule, and a leading empty group is emitted if the first
    molecule id is not 0.

    Blank lines in either file are skipped.

    Args:
        cli_path: Path to the clique file (UTF-8).
        edge_path: Path to the edge file (UTF-8).

    Returns:
        tuple: ``(cliques, all_edge)``.
        ``cliques`` is a list of per-molecule lists of evaluated clique
        literals. ``all_edge`` is a list of per-molecule lists of ``[u, v]``
        pairs.
    """
    # NOTE(review): eval() executes arbitrary code from these files. This is
    # only acceptable while they are trusted, locally generated artefacts; if
    # they can come from elsewhere, switch to ast.literal_eval after checking
    # it parses every literal the generator writes.
    all_edge = []
    with open(edge_path, 'r', encoding='utf-8') as edge_file:
        for line in edge_file:
            line = line.strip()
            if not line:
                continue
            # Convert each (u, v) tuple into a two-element list.
            all_edge.append([[x, y] for x, y in eval(line)])

    cliques = []
    clique = []
    mol_index = 0
    with open(cli_path, 'r', encoding='utf-8') as cli_file:
        for line in cli_file:
            line = line.strip()
            if not line:
                continue
            header, _, literal = line.partition(' /')
            # Both ids must be integers; the clique number itself is unused.
            mol, _ = map(int, header.split())
            if mol != mol_index:
                # A different molecule id closes the current group.
                cliques.append(clique)
                clique = []
                mol_index = mol
            clique.append(eval(literal))
    cliques.append(clique)
    return cliques, all_edge
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。