# Name: cenet_model
# Author: Reacubeth
# Time: 2021/6/25 17:28
# Mail: noverfitting@gmail.com
# Site: www.omegaxyz.com
# *_*coding:utf-8 *_*
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from utils import *
import math
import copy
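
# `from utils import *` is expected to provide isListEmpty, to_device and
# (for the case-study path) get_entity_relation_set, all referenced below.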
"""
class Oracle(nn.Module):
def __init__(self, input_dim, out_dim):
super(Oracle, self).__init__()
self.linear = nn.Sequential(nn.Linear(input_dim, 2 * input_dim),
nn.Dropout(0.2),
nn.LeakyReLU(0.2),
nn.Linear(2 * input_dim, 2 * input_dim),
nn.Dropout(0.2),
nn.LeakyReLU(0.2),
nn.Linear(2 * input_dim, 2 * input_dim),
nn.Dropout(0.2),
nn.LeakyReLU(0.2),
nn.Linear(2 * input_dim, input_dim),
nn.Dropout(0.2),
nn.LeakyReLU(0.2),
nn.Linear(input_dim, out_dim),
)
def forward(self, x):
return self.linear(x)
"""


class Oracle(nn.Module):
    """Binary classifier that scores whether the missing entity of a query
    should be a historical entity (one seen in the query's history)."""

    def __init__(self, input_dim, out_dim):
        super(Oracle, self).__init__()
        self.linear = nn.Sequential(nn.Linear(input_dim, input_dim),
                                    nn.BatchNorm1d(input_dim),
                                    nn.Dropout(0.4),
                                    nn.LeakyReLU(0.2),
                                    nn.Linear(input_dim, out_dim),
                                    )

    def forward(self, x):
        return self.linear(x)
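
# Usage sketch (illustrative; the concrete dims are placeholders). In this model
# the oracle consumes the concatenation [entity_emb | relation_emb | frequency_hidden],
# i.e. input_dim = 3 * embedding_dim, and emits one logit per query:
#   oracle = Oracle(input_dim=3 * 200, out_dim=1)
#   logits = oracle(torch.randn(64, 600))   # -> shape (64, 1)
#   p_hist = torch.sigmoid(logits)          # P(answer is a historical entity)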


class CENET(nn.Module):
    def __init__(self, num_e, num_rel, num_t, args):
        super(CENET, self).__init__()
        # stats
        self.num_e = num_e
        self.num_t = num_t
        self.num_rel = num_rel
        self.args = args

        # entity and relation embeddings
        self.rel_embeds = nn.Parameter(torch.zeros(2 * num_rel, args.embedding_dim))
        nn.init.xavier_uniform_(self.rel_embeds, gain=nn.init.calculate_gain('relu'))
        self.entity_embeds = nn.Parameter(torch.zeros(self.num_e, args.embedding_dim))
        nn.init.xavier_uniform_(self.entity_embeds, gain=nn.init.calculate_gain('relu'))

        # projects the per-entity frequency vector (length num_e) into embedding space
        self.linear_frequency = nn.Linear(self.num_e, args.embedding_dim)

        self.contrastive_hidden_layer = nn.Linear(3 * args.embedding_dim, args.embedding_dim)
        self.contrastive_output_layer = nn.Linear(args.embedding_dim, args.embedding_dim)

        self.oracle_layer = Oracle(3 * args.embedding_dim, 1)
        self.oracle_layer.apply(self.weights_init)

        # two decoders per direction: *_1 biased toward historical entities, *_2 toward non-historical
        self.linear_pred_layer_s1 = nn.Linear(2 * args.embedding_dim, args.embedding_dim)
        self.linear_pred_layer_o1 = nn.Linear(2 * args.embedding_dim, args.embedding_dim)
        self.linear_pred_layer_s2 = nn.Linear(2 * args.embedding_dim, args.embedding_dim)
        self.linear_pred_layer_o2 = nn.Linear(2 * args.embedding_dim, args.embedding_dim)

        self.weights_init(self.linear_frequency)
        self.weights_init(self.linear_pred_layer_s1)
        self.weights_init(self.linear_pred_layer_o1)
        self.weights_init(self.linear_pred_layer_s2)
        self.weights_init(self.linear_pred_layer_o2)

        """
        # (disabled) sinusoidal positional encoding over time steps
        pe = torch.zeros(400, 3 * args.embedding_dim)
        position = torch.arange(0, 400, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, 3 * args.embedding_dim, 2).float() * (-math.log(10000.0) / (3 * args.embedding_dim)))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
        """

        self.dropout = nn.Dropout(args.dropout)
        self.logSoftmax = nn.LogSoftmax(dim=1)
        self.softmax = nn.Softmax(dim=1)
        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.crossEntropy = nn.BCELoss()
        self.oracle_mode = args.oracle_mode

        print('CENET Initiated')

    @staticmethod
    def weights_init(m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
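
    # Note on the relation-embedding layout: rows [:num_rel] serve subject-side
    # queries (s, r, ?) and rows [num_rel:] serve object-side queries (?, r, o);
    # forward() slices self.rel_embeds accordingly.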

    def forward(self, batch_block, mode_lk, total_data=None):
        quadruples, s_history_event_o, o_history_event_s, \
            s_history_label_true, o_history_label_true, s_frequency, o_frequency = batch_block

        # nothing to predict when a batch carries no history at all
        if isListEmpty(s_history_event_o) or isListEmpty(o_history_event_s):
            sub_rank, obj_rank, batch_loss = [None] * 3
            if mode_lk == 'Training':
                return batch_loss
            elif mode_lk in ['Valid', 'Test']:
                return sub_rank, batch_loss
            else:
                return None

        s = quadruples[:, 0]
        r = quadruples[:, 1]
        o = quadruples[:, 2]
        """
        # (disabled) look up the precomputed positional encoding by day
        t = (quadruples[:, 3] / 24.0).long()
        time_embedding = self.pe[t]
        """

        # bias tags added to the decoder logits: in the history tags, entities
        # seen in the query's history end up at +lambdax and unseen ones at
        # -lambdax; the non-history tags bias the second decoder the other way
        s_history_tag = copy.deepcopy(s_frequency)
        o_history_tag = copy.deepcopy(o_frequency)
        s_non_history_tag = copy.deepcopy(s_frequency)
        o_non_history_tag = copy.deepcopy(o_frequency)

        s_history_tag[s_history_tag != 0] = self.args.lambdax
        o_history_tag[o_history_tag != 0] = self.args.lambdax

        s_non_history_tag[s_history_tag == 1] = -self.args.lambdax
        s_non_history_tag[s_history_tag == 0] = self.args.lambdax

        o_non_history_tag[o_history_tag == 1] = -self.args.lambdax
        o_non_history_tag[o_history_tag == 0] = self.args.lambdax

        s_history_tag[s_history_tag == 0] = -self.args.lambdax
        o_history_tag[o_history_tag == 0] = -self.args.lambdax

        s_frequency = F.softmax(s_frequency, dim=1)
        o_frequency = F.softmax(o_frequency, dim=1)
        s_frequency_hidden = self.tanh(self.linear_frequency(s_frequency))
        o_frequency_hidden = self.tanh(self.linear_frequency(o_frequency))
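
        # s_frequency / o_frequency are per-query count vectors over all entities;
        # after the softmax they act as a distribution over historical candidates,
        # and their hidden projection feeds both the oracle and the contrastive loss.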

        if mode_lk == 'Training':
            s_nce_loss, _ = self.calculate_nce_loss(s, o, r, self.rel_embeds[:self.num_rel],
                                                    self.linear_pred_layer_s1, self.linear_pred_layer_s2,
                                                    s_history_tag, s_non_history_tag)
            o_nce_loss, _ = self.calculate_nce_loss(o, s, r, self.rel_embeds[self.num_rel:],
                                                    self.linear_pred_layer_o1, self.linear_pred_layer_o2,
                                                    o_history_tag, o_non_history_tag)

            s_spc_loss = self.calculate_spc_loss(s, r, self.rel_embeds[:self.num_rel],
                                                 s_history_label_true, s_frequency_hidden)
            o_spc_loss = self.calculate_spc_loss(o, r, self.rel_embeds[self.num_rel:],
                                                 o_history_label_true, o_frequency_hidden)

            nce_loss = (s_nce_loss + o_nce_loss) / 2.0
            spc_loss = (s_spc_loss + o_spc_loss) / 2.0
            # print('nce loss', nce_loss.item(), ' spc loss', spc_loss.item())
            return self.args.alpha * nce_loss + (1 - self.args.alpha) * spc_loss

        elif mode_lk in ['Valid', 'Test']:
            # collect, per query, the ids of entities that appear in its history
            s_history_oid = []
            o_history_sid = []
            for i in range(quadruples.shape[0]):
                s_history_oid.append([])
                o_history_sid.append([])
                for con_events in s_history_event_o[i]:
                    s_history_oid[-1] += con_events[:, 1].tolist()
                for con_events in o_history_event_s[i]:
                    o_history_sid[-1] += con_events[:, 1].tolist()

            s_nce_loss, s_preds = self.calculate_nce_loss(s, o, r, self.rel_embeds[:self.num_rel],
                                                          self.linear_pred_layer_s1, self.linear_pred_layer_s2,
                                                          s_history_tag, s_non_history_tag)
            o_nce_loss, o_preds = self.calculate_nce_loss(o, s, r, self.rel_embeds[self.num_rel:],
                                                          self.linear_pred_layer_o1, self.linear_pred_layer_o2,
                                                          o_history_tag, o_non_history_tag)

            s_ce_loss, s_pred_history_label, s_ce_all_acc = self.oracle_loss(s, r, self.rel_embeds[:self.num_rel],
                                                                             s_history_label_true, s_frequency_hidden)
            o_ce_loss, o_pred_history_label, o_ce_all_acc = self.oracle_loss(o, r, self.rel_embeds[self.num_rel:],
                                                                             o_history_label_true, o_frequency_hidden)

            # build the oracle mask: keep only historical candidates when the oracle
            # says the answer is historical, otherwise keep only non-historical ones
            s_mask = to_device(torch.zeros(quadruples.shape[0], self.num_e))
            o_mask = to_device(torch.zeros(quadruples.shape[0], self.num_e))
            for i in range(quadruples.shape[0]):
                if s_pred_history_label[i].item() > 0.5:
                    s_mask[i, s_history_oid[i]] = 1
                else:
                    s_mask[i, :] = 1
                    s_mask[i, s_history_oid[i]] = 0
                if o_pred_history_label[i].item() > 0.5:
                    o_mask[i, o_history_sid[i]] = 1
                else:
                    o_mask[i, :] = 1
                    o_mask[i, o_history_sid[i]] = 0
            if self.oracle_mode == 'soft':
                s_mask = F.softmax(s_mask, dim=1)
                o_mask = F.softmax(o_mask, dim=1)
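
            # Illustrative mask row, assuming num_e = 4 and history ids {2}:
            #   oracle says historical      -> [0, 0, 1, 0]
            #   oracle says non-historical  -> [1, 1, 0, 1]
            # 'soft' mode replaces the hard 0/1 rows with their softmax.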

            # ranking with the oracle mask applied
            s_total_loss1, sub_rank1 = self.link_predict(s_nce_loss, s_preds, s_ce_loss, s, o, r,
                                                         s_mask, total_data, 's', True)
            o_total_loss1, obj_rank1 = self.link_predict(o_nce_loss, o_preds, o_ce_loss, o, s, r,
                                                         o_mask, total_data, 'o', True)
            batch_loss1 = (s_total_loss1 + o_total_loss1) / 2.0

            # ranking without the oracle mask
            s_total_loss2, sub_rank2 = self.link_predict(s_nce_loss, s_preds, s_ce_loss, s, o, r,
                                                         s_mask, total_data, 's', False)
            o_total_loss2, obj_rank2 = self.link_predict(o_nce_loss, o_preds, o_ce_loss, o, s, r,
                                                         o_mask, total_data, 'o', False)
            batch_loss2 = (s_total_loss2 + o_total_loss2) / 2.0

            # ranking with the ground-truth mask (upper bound for the oracle)
            s_mask_gt = to_device(torch.zeros(quadruples.shape[0], self.num_e))
            o_mask_gt = to_device(torch.zeros(quadruples.shape[0], self.num_e))
            for i in range(quadruples.shape[0]):
                if o[i] in s_history_oid[i]:
                    s_mask_gt[i, s_history_oid[i]] = 1
                else:
                    s_mask_gt[i, :] = 1
                    s_mask_gt[i, s_history_oid[i]] = 0
                if s[i] in o_history_sid[i]:
                    o_mask_gt[i, o_history_sid[i]] = 1
                else:
                    o_mask_gt[i, :] = 1
                    o_mask_gt[i, o_history_sid[i]] = 0

            s_total_loss3, sub_rank3 = self.link_predict(s_nce_loss, s_preds, s_ce_loss, s, o, r,
                                                         s_mask_gt, total_data, 's', True)
            o_total_loss3, obj_rank3 = self.link_predict(o_nce_loss, o_preds, o_ce_loss, o, s, r,
                                                         o_mask_gt, total_data, 'o', True)
            batch_loss3 = (s_total_loss3 + o_total_loss3) / 2.0

            return sub_rank1, obj_rank1, batch_loss1, \
                sub_rank2, obj_rank2, batch_loss2, \
                sub_rank3, obj_rank3, batch_loss3, \
                (s_ce_all_acc + o_ce_all_acc) / 2

        elif mode_lk == 'Oracle':
            print('Oracle Training')
            s_ce_loss, _, _ = self.oracle_loss(s, r, self.rel_embeds[:self.num_rel],
                                               s_history_label_true, s_frequency_hidden)
            o_ce_loss, _, _ = self.oracle_loss(o, r, self.rel_embeds[self.num_rel:],
                                               o_history_label_true, o_frequency_hidden)
            return (s_ce_loss + o_ce_loss) / 2.0 + self.oracle_l1(0.01)

    def oracle_loss(self, actor1, r, rel_embeds, history_label, frequency_hidden):
        history_label_pred = torch.sigmoid(
            self.oracle_layer(torch.cat((self.entity_embeds[actor1], rel_embeds[r], frequency_hidden), dim=1)))
        tmp_label = torch.squeeze(history_label_pred).clone().detach()
        tmp_label[torch.where(tmp_label > 0.5)[0]] = 1
        tmp_label[torch.where(tmp_label < 0.5)[0]] = 0
        # print('# Bias Ratio', torch.sum(tmp_label).item() / tmp_label.shape[0])
        ce_correct = torch.sum(torch.eq(tmp_label, torch.squeeze(history_label)))
        ce_accuracy = 1. * ce_correct.item() / tmp_label.shape[0]
        print('# CE Accuracy', ce_accuracy)
        ce_loss = self.crossEntropy(torch.squeeze(history_label_pred), torch.squeeze(history_label))
        # the third return value is the raw correct count, so callers can aggregate accuracy
        return ce_loss, history_label_pred, ce_accuracy * tmp_label.shape[0]
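
    # For reference, the BCE objective computed above is
    #   ce_loss = -mean_i [ y_i * log(p_i) + (1 - y_i) * log(1 - p_i) ]
    # with p_i the oracle's sigmoid output and y_i the true historical label.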

    def calculate_nce_loss(self, actor1, actor2, r, rel_embeds, linear1, linear2, history_tag, non_history_tag):
        # decoder 1 is biased toward historical entities, decoder 2 toward non-historical ones
        preds_raw1 = self.tanh(linear1(
            self.dropout(torch.cat((self.entity_embeds[actor1], rel_embeds[r]), dim=1))))
        preds1 = F.softmax(preds_raw1.mm(self.entity_embeds.transpose(0, 1)) + history_tag, dim=1)

        preds_raw2 = self.tanh(linear2(
            self.dropout(torch.cat((self.entity_embeds[actor1], rel_embeds[r]), dim=1))))
        preds2 = F.softmax(preds_raw2.mm(self.entity_embeds.transpose(0, 1)) + non_history_tag, dim=1)

        # cro_entr_loss = self.criterion_link(preds1 + preds2, actor2)
        nce = torch.sum(torch.gather(torch.log(preds1 + preds2), 1, actor2.view(-1, 1)))
        nce /= -1. * actor2.shape[0]

        pred_actor2 = torch.argmax(preds1 + preds2, dim=1)  # predicted result
        correct = torch.sum(torch.eq(pred_actor2, actor2))
        accuracy = 1. * correct.item() / actor2.shape[0]
        print('# Batch accuracy', accuracy)
        return nce, preds1 + preds2
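
    # Equivalent per-sample view of the loss returned above:
    #   nce = -mean_i log( (preds1 + preds2)[i, actor2[i]] )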

    def link_predict(self, nce_loss, preds, ce_loss, actor1, actor2, r, trust_musk, all_triples, pred_known, oracle,
                     history_tag=None, case_study=False):
        if case_study:
            # case-study logging writes readable triples to a text file
            f = open("case_study.txt", "a+")
            entity2id, relation2id = get_entity_relation_set(self.args.dataset)

        if oracle:
            preds = torch.mul(preds, trust_musk)
            print('$Batch After Oracle accuracy:', end=' ')
        else:
            print('$Batch No Oracle accuracy:', end=' ')

        # compute the correct triples
        pred_actor2 = torch.argmax(preds, dim=1)  # predicted result
        correct = torch.sum(torch.eq(pred_actor2, actor2))
        accuracy = 1. * correct.item() / actor2.shape[0]
        print(accuracy)
        # print('Batch Error', 1 - accuracy)

        total_loss = nce_loss + ce_loss

        ranks = []
        for i in range(preds.shape[0]):
            cur_s = actor1[i]
            cur_r = r[i]
            cur_o = actor2[i]

            if case_study:
                in_history = torch.where(history_tag[i] > 0)[0]
                not_in_history = torch.where(history_tag[i] < 0)[0]
                print('---------------------------', file=f)
                for hh in range(in_history.shape[0]):
                    print('his:', entity2id[in_history[hh].item()], file=f)
                print(pred_known,
                      'Truth:', entity2id[cur_s.item()], '--', relation2id[cur_r.item()], '--', entity2id[cur_o.item()],
                      'Prediction:', entity2id[pred_actor2[i].item()], file=f)

            o_label = cur_o
            ground = preds[i, cur_o].clone().item()

            if self.args.filtering:
                # filtered setting: zero out other known true answers for this query
                if pred_known == 's':
                    s_id = torch.nonzero(all_triples[:, 0] == cur_s).view(-1)
                    idx = torch.nonzero(all_triples[s_id, 1] == cur_r).view(-1)
                    idx = s_id[idx]
                    idx = all_triples[idx, 2]
                else:
                    s_id = torch.nonzero(all_triples[:, 2] == cur_s).view(-1)
                    idx = torch.nonzero(all_triples[s_id, 1] == cur_r).view(-1)
                    idx = s_id[idx]
                    idx = all_triples[idx, 0]
                preds[i, idx] = 0
                preds[i, o_label] = ground

            # rank of the ground truth, averaging over ties
            ob_pred_comp1 = (preds[i, :] > ground).data.cpu().numpy()
            ob_pred_comp2 = (preds[i, :] == ground).data.cpu().numpy()
            ranks.append(np.sum(ob_pred_comp1) + ((np.sum(ob_pred_comp2) - 1.0) / 2) + 1)
        return total_loss, ranks

    def regularization_loss(self, reg_param):
        regularization_loss = torch.mean(self.rel_embeds.pow(2)) + torch.mean(self.entity_embeds.pow(2))
        return regularization_loss * reg_param

    def oracle_l1(self, reg_param):
        # L1 penalty over all oracle parameters
        reg = 0
        for param in self.oracle_layer.parameters():
            reg += torch.sum(torch.abs(param))
        return reg * reg_param

    # contrastive
    def freeze_parameter(self):
        # freeze everything except the oracle before the oracle-training stage
        self.rel_embeds.requires_grad_(False)
        self.entity_embeds.requires_grad_(False)
        self.linear_pred_layer_s1.requires_grad_(False)
        self.linear_pred_layer_o1.requires_grad_(False)
        self.linear_pred_layer_s2.requires_grad_(False)
        self.linear_pred_layer_o2.requires_grad_(False)
        self.linear_frequency.requires_grad_(False)
        self.contrastive_hidden_layer.requires_grad_(False)
        self.contrastive_output_layer.requires_grad_(False)

    def contrastive_layer(self, x):
        # Implement from the encoder E to the projection network P
        # x = F.normalize(x, dim=1)
        x = self.contrastive_hidden_layer(x)
        # x = F.relu(x)
        # x = self.contrastive_output_layer(x)
        # Normalize to unit hypersphere
        # x = F.normalize(x, dim=1)
        return x

    def calculate_spc_loss(self, actor1, r, rel_embeds, targets, frequency_hidden):
        # supervised contrastive loss over (entity, relation, frequency) projections,
        # with the historical/non-historical label as the class
        projections = self.contrastive_layer(
            torch.cat((self.entity_embeds[actor1], rel_embeds[r], frequency_hidden), dim=1))
        targets = torch.squeeze(targets)
        """
        # (disabled) occasionally dump projections/labels for visualisation
        if np.random.randint(0, 10) < 1 and 0.35 < torch.sum(targets) / targets.shape[0] < 0.65:
            np.savetxt("xx.tsv", projections.detach().cpu().numpy(), delimiter="\t")
            np.savetxt("yy.tsv", targets.detach().cpu().numpy(), delimiter="\t")
        """
        dot_product_tempered = torch.mm(projections, projections.T) / 1.0
        # Minus max for numerical stability with exponential. Same done in cross entropy. Epsilon added to avoid log(0).
        exp_dot_tempered = (
            torch.exp(dot_product_tempered - torch.max(dot_product_tempered, dim=1, keepdim=True)[0]) + 1e-5
        )
        mask_similar_class = to_device(targets.unsqueeze(1).repeat(1, targets.shape[0]) == targets)
        mask_anchor_out = to_device(1 - torch.eye(exp_dot_tempered.shape[0]))
        mask_combined = mask_similar_class * mask_anchor_out
        cardinality_per_samples = torch.sum(mask_combined, dim=1)

        log_prob = -torch.log(exp_dot_tempered / (torch.sum(exp_dot_tempered * mask_anchor_out, dim=1, keepdim=True)))
        supervised_contrastive_loss_per_sample = torch.sum(log_prob * mask_combined, dim=1) / cardinality_per_samples
        supervised_contrastive_loss = torch.mean(supervised_contrastive_loss_per_sample)
        if torch.any(torch.isnan(supervised_contrastive_loss)):
            # can happen when a class has no positive pair in the batch
            return 0
        return supervised_contrastive_loss
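

if __name__ == '__main__':
    # Smoke-test sketch: the hyper-parameter values below are placeholders, not
    # the project's defaults; real runs build `args` from the training script's
    # argument parser and take entity/relation counts from the dataset loader.
    from argparse import Namespace

    args = Namespace(embedding_dim=200, dropout=0.5, lambdax=2.0, alpha=0.2,
                     oracle_mode='soft', filtering=True, dataset='ICEWS18')
    model = CENET(num_e=100, num_rel=20, num_t=30, args=args)
    print(sum(p.numel() for p in model.parameters()), 'parameters')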