from easydict import EasyDict as edict
import numpy as np
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import *
from data_utils import AmazonDataset
class KnowledgeEmbedding(nn.Module):
def __init__(self,dataset,args):
super(KnowledgeEmbedding,self).__init__()
self.embed_size=args.embed_size #100
self.num_neg_samples=args.num_neg_samples #5
self.device=args.device
self.l2_lambda=args.l2_lambda
# Initialize entity embeddings
self.entities=edict(
product=edict(vocab_size=dataset.product.vocab_size),
word=edict(vocab_size=dataset.word.vocab_size),
brand=edict(vocab_size=dataset.brand.vocab_size),
category=edict(vocab_size=dataset.category.vocab_size),
user=edict(vocab_size=dataset.user.vocab_size),
)
for e in self.entities:
embed=self._entity_embedding(self.entities[e].vocab_size)
# setattr registers each entity embedding as a module attribute, e.g. self.product
setattr(self,e,embed)
# Initialize relation embeddings and relation biases
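# Each relation entry records its tail entity type ('et') and a distribution over
# tail entities ('et_distrib') from which negative samples are drawn.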
self.relations = edict(
describe_as=edict(
et='word',
et_distrib=self._make_distrib(dataset.review.word_distrib)),
produced_by=edict(
et='brand',
et_distrib=self._make_distrib(dataset.produced_by.et_distrib)),
belong_to=edict(
et='category',
et_distrib=self._make_distrib(dataset.belong_to.et_distrib)),
comp=edict(
et='product',
et_distrib=self._make_distrib(dataset.comp.et_distrib)),
sub=edict(
et='product',
et_distrib=self._make_distrib(dataset.sub.et_distrib)),
purchase=edict(
et='user',
et_distrib=self._make_distrib(dataset.purchase.et_distrib)),
)
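# Register one translation vector per relation (self.<relation>) plus a scalar bias
# per tail entity (self.<relation>_bias).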
for r in self.relations:
embed=self._relation_embedding()
setattr(self,r,embed)
bias=self._relation_bias(len(self.relations[r].et_distrib))
setattr(self,r+'_bias',bias)
# Initialize an entity embedding of size [vocab_size+1, embed_size]; the extra last row is the padding entry (padding_idx=-1).
def _entity_embedding(self,vocab_size):
embed=nn.Embedding(vocab_size+1,self.embed_size,padding_idx=-1,sparse=False)
initrange=0.5/self.embed_size
weight=torch.FloatTensor(vocab_size+1,self.embed_size).uniform_(-initrange,initrange)
embed.weight=nn.Parameter(weight)
return embed
# Initialize a relation embedding: each relation is a single translation vector of shape [1, embed_size].
def _relation_embedding(self):
initrange=0.5/self.embed_size
weight=torch.FloatTensor(1,self.embed_size).uniform_(-initrange,initrange)
embed=nn.Parameter(weight)
return embed
def _relation_bias(self,vocab_size):
bias=nn.Embedding(vocab_size+1,1,padding_idx=-1,sparse=False)
bias.weight=nn.Parameter(torch.zeros(vocab_size+1,1))
return bias
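# Smooth a raw count distribution with the 0.75 power (as in word2vec negative
# sampling) and normalize it into a probability vector on the target device.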
def _make_distrib(self,distrib):
distrib=np.power(np.array(distrib,dtype=np.float64),0.75)
distrib=distrib/distrib.sum()
distrib=torch.FloatTensor(distrib).to(self.device)
return distrib
def forward(self,batch_idxs):
loss=self.compute_loss(batch_idxs)
return loss
def compute_loss(self,batch_idxs):
# Each row of batch_idxs holds (p_id, w_id, b_id, c_id, rp1_id, rp2_id, u_id) at columns 0..6
product_idxs=batch_idxs[:,0]
word_idxs=batch_idxs[:,1]
brand_idxs=batch_idxs[:,2]
category_idxs=batch_idxs[:,3]
product1_idxs=batch_idxs[:,4]
product2_idxs=batch_idxs[:,5]
uid=batch_idxs[:,6]
regularizations=[]
loss=0
# product + describe_as -> word
pw_loss, pw_embeds = self.neg_loss('product', 'describe_as', 'word', product_idxs, word_idxs)
if pw_loss is not None:
regularizations.extend(pw_embeds)
loss += pw_loss
# product + produced_by -> brand
pb_loss, pb_embeds = self.neg_loss('product', 'produced_by', 'brand', product_idxs, brand_idxs)
if pb_loss is not None:
regularizations.extend(pb_embeds)
loss += pb_loss
# product + belong_to -> category
pc_loss, pc_embeds = self.neg_loss('product', 'belong_to', 'category', product_idxs, category_idxs)
if pc_loss is not None:
regularizations.extend(pc_embeds)
loss += pc_loss
# product + comp -> product
pr1_loss, pr1_embeds = self.neg_loss('product', 'comp', 'product', product_idxs, product1_idxs)
if pr1_loss is not None:
regularizations.extend(pr1_embeds)
loss += pr1_loss
# product + sub -> product
pr2_loss, pr2_embeds = self.neg_loss('product', 'sub', 'product', product_idxs, product2_idxs)
if pr2_loss is not None:
regularizations.extend(pr2_embeds)
loss += pr2_loss
# product + purchase -> user
ru_loss, ru_embeds = self.neg_loss('product', 'purchase', 'user', product_idxs, uid)
if ru_loss is not None:
regularizations.extend(ru_embeds)
loss += ru_loss
# l2 regularization
if self.l2_lambda > 0:
l2_loss = 0.0
for term in regularizations:
l2_loss += torch.norm(term)
loss += self.l2_lambda * l2_loss
return loss
# Compute the negative-sampling loss for triples (e1, r, e2) of one relation type
def neg_loss(self, entity_head, relation, entity_tail, entity_head_idxs, entity_tail_idxs):
# Entity tail indices can be -1. Remove these indices. Batch size may be changed!
mask = entity_tail_idxs >= 0
fixed_entity_head_idxs = entity_head_idxs[mask]
fixed_entity_tail_idxs = entity_tail_idxs[mask]
if fixed_entity_head_idxs.size(0) <= 0:
return None, []
entity_head_embedding = getattr(self, entity_head) # nn.Embedding
entity_tail_embedding = getattr(self, entity_tail) # nn.Embedding
relation_vec = getattr(self, relation) # [1, embed_size]
relation_bias_embedding = getattr(self, relation + '_bias') # nn.Embedding
entity_tail_distrib = self.relations[relation].et_distrib # [vocab_size]
return kg_neg_loss(entity_head_embedding, entity_tail_embedding,
fixed_entity_head_idxs, fixed_entity_tail_idxs,
relation_vec, relation_bias_embedding, self.num_neg_samples, entity_tail_distrib)
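# TransE-style translational scoring with negative sampling: a triple (h, r, t) is
# scored as dot(h_vec + r_vec, t_vec) + bias_t; the observed tail gets a high score
# while num_samples tails drawn from `distrib` are pushed toward low scores.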
def kg_neg_loss(entity_head_embed, entity_tail_embed, entity_head_idxs, entity_tail_idxs,
relation_vec, relation_bias_embed, num_samples, distrib):
batch_size = entity_head_idxs.size(0)
entity_head_vec = entity_head_embed(entity_head_idxs) # [batch_size, embed_size]
example_vec = entity_head_vec + relation_vec # [batch_size, embed_size]
example_vec = example_vec.unsqueeze(2) # [batch_size, embed_size, 1]
# Positive term: score the observed tail entity
entity_tail_vec = entity_tail_embed(entity_tail_idxs) # [batch_size, embed_size]
pos_vec = entity_tail_vec.unsqueeze(1) # [batch_size, 1, embed_size]
relation_bias = relation_bias_embed(entity_tail_idxs).squeeze(1) # [batch_size]
pos_logits = torch.bmm(pos_vec, example_vec).squeeze() + relation_bias # [batch_size]
pos_loss = -pos_logits.sigmoid().log() # [batch_size]
# Negative term: draw num_samples tail entities from `distrib` (shared across the batch)
neg_sample_idx = torch.multinomial(distrib, num_samples, replacement=True).view(-1)
neg_vec = entity_tail_embed(neg_sample_idx) # [num_samples, embed_size]
neg_logits = torch.mm(example_vec.squeeze(2), neg_vec.transpose(1, 0).contiguous())
neg_logits += relation_bias.unsqueeze(1) # [batch_size, num_samples]
neg_loss = -neg_logits.neg().sigmoid().log().sum(1) # [batch_size]
loss = (pos_loss + neg_loss).mean()
return loss, [entity_head_vec, entity_tail_vec, neg_vec]
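# Minimal usage sketch (illustrative only; the argument values below are assumptions,
# not defaults defined in this file):
#
#   args = edict(embed_size=100, num_neg_samples=5, device='cpu', l2_lambda=0.0)
#   dataset = AmazonDataset(...)  # must expose the vocab sizes / distributions used above
#   model = KnowledgeEmbedding(dataset, args).to(args.device)
#   batch_idxs = torch.LongTensor(batch)  # [batch_size, 7] rows of (p, w, b, c, rp1, rp2, u)
#   loss = model(batch_idxs.to(args.device))
#   loss.backward()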