import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
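# Contrastive losses for patched time-series representations. This file provides
# an NT-Xent loss with a poly-1 term computed per patch on
# (batch, patch, seq_length) inputs (NTXentLoss_poly_4D), the standard
# per-sample variant on (batch, seq_length) inputs (NTXentLoss_poly_2D), and a
# temporal contrastive loss that contrasts each timestamp against the other
# timestamps of the same sample (Triple_loss).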
class NTXentLoss_poly_4D(torch.nn.Module):
    def __init__(self, device, batch_size, temperature, patch, use_cosine_similarity):
super(NTXentLoss_poly_4D, self).__init__()
self.batch_size = batch_size
self.temperature = temperature
self.device = device
self.patch = patch
self.softmax = torch.nn.Softmax(dim=-1)
self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
self.similarity_function = self._get_similarity_function(use_cosine_similarity)
self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")
    def _get_similarity_function(self, use_cosine_similarity):
        if use_cosine_similarity:
            self._cos = torch.nn.CosineSimilarity(dim=-1)
            return self._cosine_similarity
        else:
            return self._dot_similarity
def _get_correlated_mask(self):
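        # True marks a negative pair: the main diagonal (self-similarity) and the
        # +/- batch_size off-diagonals (the positive pairs) are masked out below.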
diag = np.eye(2 * self.batch_size)
l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size)
l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size)
mask = torch.from_numpy((diag + l1 + l2))
mask = (1 - mask).type(torch.bool)
return mask.to(self.device)
    @staticmethod
    def _dot_similarity(x, y):
        # x: (N, 1, C), y^T: (1, C, 2N) -> v: (N, 2N)
        v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
        return v
    def _cosine_similarity(self, x, y):
        # x: (N, 1, C), y: (1, 2N, C) -> v: (N, 2N) via broadcasting
        v = self._cos(x.unsqueeze(1), y.unsqueeze(0))
        return v
    def forward(self, zis, zjs):
        # zis, zjs: batch * patch * seq_length (the feature dim is assumed to be
        # squeezed out upstream); previously the inputs were batch * seq_length.
        representations = torch.cat([zjs, zis], dim=0)  # (2 * batch) * patch * seq_length
        # Similarities are computed patch by patch, so the result is
        # (2 * batch) * patch * (2 * batch).
        similarity_matrix = torch.empty(2 * self.batch_size, self.patch, 2 * self.batch_size, device=self.device)
        for i in range(self.patch):
            similarity_matrix[:, i, :] = self.similarity_function(representations[:, i, :], representations[:, i, :])
        # Pull out the positive-pair scores (the +/- batch_size diagonals) per patch.
        positives = torch.empty(2 * self.batch_size, self.patch, 1, device=self.device)
        for i in range(self.patch):
            l_pos = torch.diag(similarity_matrix[:, i, :], self.batch_size)
            r_pos = torch.diag(similarity_matrix[:, i, :], -self.batch_size)
            positives[:, i, :] = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)
        # Negative scores: everything except self-pairs and the positive pairs.
        negatives = torch.empty(2 * self.batch_size, self.patch, 2 * self.batch_size - 2, device=self.device)
        for i in range(self.patch):
            t_neg = similarity_matrix[:, i, :]
            negatives[:, i, :] = t_neg[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)
        # [positive, negatives] per patch: (2 * batch) * patch * (2 * batch - 1)
        logits = torch.cat((positives, negatives), dim=2)
        logits /= self.temperature
        # The criterion one-hot encodes internally: the positive sits at index 0, so
        # all targets are zero. CrossEntropyLoss expects class scores on dim 1, hence
        # the permute to (2 * batch) * (2 * batch - 1) * patch.
        # (Original author's note: try whether batch_size or batch_size * patch
        # is the better choice here.)
        labels = torch.zeros(2 * self.batch_size, self.patch, device=self.device).long()
        CE = self.criterion(logits.permute(0, 2, 1), labels)
        onehot_label = torch.cat((torch.ones(2 * self.batch_size, self.patch, 1),
                                  torch.zeros(2 * self.batch_size, self.patch, negatives.shape[-1])), dim=-1).to(self.device)
        # Poly-1 term: pt keeps only the softmax mass that falls on the positives.
        pt = torch.mean(onehot_label * torch.nn.functional.softmax(logits, dim=-1))
        epsilon = self.batch_size
        loss = CE / (2 * self.batch_size) + epsilon * (1 / self.batch_size - pt)  # 1 replaced by 1 / batch_size
        return loss
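# Illustrative usage of NTXentLoss_poly_4D (the sizes here are assumptions):
#   loss_fn = NTXentLoss_poly_4D(device, batch_size=4, temperature=0.2,
#                                patch=8, use_cosine_similarity=True)
#   loss = loss_fn(zis, zjs)  # zis, zjs: (4, 8, seq_length) tensors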
class NTXentLoss_poly_2D(torch.nn.Module):
    def __init__(self, device, batch_size, temperature, patch, use_cosine_similarity):
super(NTXentLoss_poly_2D, self).__init__()
self.batch_size = batch_size
self.temperature = temperature
self.device = device
self.patch = patch
self.softmax = torch.nn.Softmax(dim=-1)
self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
self.similarity_function = self._get_similarity_function(use_cosine_similarity)
self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")
    def _get_similarity_function(self, use_cosine_similarity):
        if use_cosine_similarity:
            self._cos = torch.nn.CosineSimilarity(dim=-1)
            return self._cosine_similarity
        else:
            return self._dot_similarity
def _get_correlated_mask(self):
diag = np.eye(2 * self.batch_size)
l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size)
l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size)
mask = torch.from_numpy((diag + l1 + l2))
mask = (1 - mask).type(torch.bool)
return mask.to(self.device)
    @staticmethod
    def _dot_similarity(x, y):
        # x: (N, 1, C), y^T: (1, C, 2N) -> v: (N, 2N)
        v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
        return v
    def _cosine_similarity(self, x, y):
        # x: (N, 1, C), y: (1, 2N, C) -> v: (N, 2N) via broadcasting
        v = self._cos(x.unsqueeze(1), y.unsqueeze(0))
        return v
    def forward(self, zis, zjs):
        # Squeeze singleton dims so zis, zjs end up as batch * seq_length.
        zis = torch.squeeze(zis)
        zjs = torch.squeeze(zjs)
        representations = torch.cat([zjs, zis], dim=0)  # (2 * batch) * seq_length
        similarity_matrix = self.similarity_function(representations, representations)  # (2 * batch) * (2 * batch)
        # Pull out the positive-pair scores (the +/- batch_size diagonals).
        l_pos = torch.diag(similarity_matrix, self.batch_size)
        r_pos = torch.diag(similarity_matrix, -self.batch_size)
        positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)
        # Negative scores: everything except self-pairs and the positive pairs.
        negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)
        # [positive, negatives]: (2 * batch) * (2 * batch - 1)
        logits = torch.cat((positives, negatives), dim=1)
        logits /= self.temperature
        # The criterion one-hot encodes internally: the positive sits at index 0,
        # so all targets are zero.
        labels = torch.zeros(2 * self.batch_size).to(self.device).long()
        CE = self.criterion(logits, labels)
        onehot_label = torch.cat((torch.ones(2 * self.batch_size, 1),
                                  torch.zeros(2 * self.batch_size, negatives.shape[-1])), dim=-1).to(self.device)
        # Poly-1 term: pt keeps only the softmax mass that falls on the positives.
        pt = torch.mean(onehot_label * torch.nn.functional.softmax(logits, dim=-1))
        epsilon = self.batch_size
        loss = CE / (2 * self.batch_size) + epsilon * (1 / self.batch_size - pt)  # 1 replaced by 1 / batch_size
        return loss
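# Illustrative usage of NTXentLoss_poly_2D (the sizes here are assumptions):
#   loss_fn = NTXentLoss_poly_2D(device, batch_size=4, temperature=0.2,
#                                patch=8, use_cosine_similarity=True)
#   loss = loss_fn(zis, zjs)  # zis, zjs squeeze down to (4, seq_length)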
class Triple_loss(torch.nn.Module):
    def __init__(self, device, batch_size, temperature, patch, use_cosine_similarity):
        super(Triple_loss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.patch = patch
        self.use_cosine_similarity = use_cosine_similarity
    def forward(self, zis, zjs):
        B, T = zis.size(0), zis.size(1)
        if T == 1:
            return zis.new_tensor(0.)
        # Concatenate the two views along the time axis: z is (B, 2T, C).
        z = torch.cat([zis, zjs], dim=1)
        sim = torch.matmul(z, z.transpose(1, 2))  # (B, 2T, 2T)
        # Drop the diagonal (self-similarity): each row keeps its 2T - 1 other timestamps.
        logits = torch.tril(sim, diagonal=-1)[:, :, :-1]
        logits += torch.triu(sim, diagonal=1)[:, :, 1:]
        logits = -F.log_softmax(logits, dim=-1)
        t = torch.arange(T, device=zis.device)
        # The positive for timestamp t of one view is timestamp t of the other view;
        # removing the diagonal shifts its column index by one in the first half.
        loss = (logits[:, t, T + t - 1].mean() + logits[:, T + t, t].mean()) / 2  # TODO: revisit this indexing
        return loss
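if __name__ == "__main__":
    # Minimal smoke test; the shapes below are arbitrary assumptions chosen only
    # to exercise the three losses, not values prescribed by any training code.
    device = torch.device("cpu")
    B, P, L = 4, 8, 16
    loss_4d = NTXentLoss_poly_4D(device, B, temperature=0.2, patch=P, use_cosine_similarity=True)
    print("4D loss:", loss_4d(torch.randn(B, P, L), torch.randn(B, P, L)).item())
    loss_2d = NTXentLoss_poly_2D(device, B, temperature=0.2, patch=P, use_cosine_similarity=True)
    print("2D loss:", loss_2d(torch.randn(B, L), torch.randn(B, L)).item())
    triple = Triple_loss(device, B, temperature=0.2, patch=P, use_cosine_similarity=True)
    print("temporal loss:", triple(torch.randn(B, P, L), torch.randn(B, P, L)).item())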