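# PyTorch implementation of a spatio-temporal graph convolutional network
# (STGCN); the layer structure follows the architecture described in
# Yu et al., "Spatio-Temporal Graph Convolutional Networks" (IJCAI 2018).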
import math
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
class align(nn.Module):
    """Match channel counts so residual connections can be added."""
    def __init__(self, c_in, c_out):
        super(align, self).__init__()
        self.c_in = c_in
        self.c_out = c_out
        if c_in > c_out:
            # Reduce channels with a 1x1 convolution.
            self.conv1x1 = nn.Conv2d(c_in, c_out, 1)

    def forward(self, x):
        if self.c_in > self.c_out:
            return self.conv1x1(x)
        if self.c_in < self.c_out:
            # Zero-pad the channel dimension up to c_out.
            return F.pad(x, [0, 0, 0, 0, 0, self.c_out - self.c_in, 0, 0])
        return x
class temporal_conv_layer(nn.Module):
    """Gated temporal convolution along the time axis with a residual connection."""
    def __init__(self, kt, c_in, c_out, act="relu"):
        super(temporal_conv_layer, self).__init__()
        self.kt = kt
        self.act = act
        self.c_out = c_out
        self.align = align(c_in, c_out)
        if self.act == "GLU":
            # GLU needs twice the channels: one half for values, one for gates.
            self.conv = nn.Conv2d(c_in, c_out * 2, (kt, 1), 1)
        else:
            self.conv = nn.Conv2d(c_in, c_out, (kt, 1), 1)

    def forward(self, x):
        # Align channels and crop the first kt - 1 time steps so the residual
        # matches the output length of the valid (unpadded) convolution.
        x_in = self.align(x)[:, :, self.kt - 1:, :]
        if self.act == "GLU":
            x_conv = self.conv(x)
            # (P + x_in) * sigmoid(Q): gated linear unit with a residual on P.
            return (x_conv[:, :self.c_out, :, :] + x_in) * torch.sigmoid(x_conv[:, self.c_out:, :, :])
        if self.act == "sigmoid":
            return torch.sigmoid(self.conv(x) + x_in)
        return torch.relu(self.conv(x) + x_in)
class spatio_conv_layer(nn.Module):
    """Chebyshev graph convolution over the node axis with a residual connection."""
    def __init__(self, ks, c, Lk):
        super(spatio_conv_layer, self).__init__()
        self.Lk = Lk  # stack of graph convolution kernels, shape (ks, n, n)
        self.theta = nn.Parameter(torch.FloatTensor(c, c, ks))
        self.b = nn.Parameter(torch.FloatTensor(1, c, 1, 1))
        self.reset_parameters()

    def reset_parameters(self):
        # Same initialization scheme nn.Linear uses for its weight and bias.
        init.kaiming_uniform_(self.theta, a=math.sqrt(5))
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.theta)
        bound = 1 / math.sqrt(fan_in)
        init.uniform_(self.b, -bound, bound)

    def forward(self, x):
        # Aggregate each node's neighborhood once per kernel order k.
        x_c = torch.einsum("knm,bitm->bitkn", self.Lk, x)
        # Mix input channels and kernel orders with theta, then add the bias.
        x_gc = torch.einsum("iok,bitkn->botn", self.theta, x_c) + self.b
        return torch.relu(x_gc + x)
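# Hypothetical shape walkthrough for the two einsums above: with ks = 3,
# n = 228 nodes, batch b = 32, channels c = 64, and T = 10 time steps,
# Lk is (3, 228, 228) and x is (32, 64, 10, 228); "knm,bitm->bitkn" yields
# x_c of shape (32, 64, 10, 3, 228), and "iok,bitkn->botn" contracts the
# input-channel and kernel-order axes back down to (32, 64, 10, 228).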
class st_conv_block(nn.Module):
    """Sandwich block: temporal conv -> spatial graph convs -> temporal conv,
    followed by layer normalization and dropout."""
    def __init__(self, ks, kt, n, c, p, Lk):
        super(st_conv_block, self).__init__()
        self.tconv1 = temporal_conv_layer(kt, c[0], c[1], "GLU")
        self.sconv = spatio_conv_layer(ks, c[1], Lk)
        self.sconv1 = spatio_conv_layer(ks, c[1], Lk)
        # self.sconv2 = spatio_conv_layer(ks, c[1], Lk)
        # self.sconv3 = spatio_conv_layer(ks, c[1], Lk)
        # self.sconv4 = spatio_conv_layer(ks, c[1], Lk)
        self.tconv2 = temporal_conv_layer(kt, c[1], c[2])
        self.ln = nn.LayerNorm([n, c[2]])
        self.dropout = nn.Dropout(p)

    def forward(self, x):
        x_t1 = self.tconv1(x)
        x_s = self.sconv(x_t1)
        x_s1 = self.sconv1(x_s)
        # x_s2 = self.sconv1(x_s1)
        # x_s3 = self.sconv1(x_s2)
        # x_s4 = self.sconv1(x_s3)
        x_t2 = self.tconv2(x_s1)
        # LayerNorm expects (..., n, c), so permute channels last and back.
        x_ln = self.ln(x_t2.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return self.dropout(x_ln)
class fully_conv_layer(nn.Module):
    """1x1 convolution mapping c channels down to a single output channel."""
    def __init__(self, c):
        super(fully_conv_layer, self).__init__()
        self.conv = nn.Conv2d(c, 1, 1)

    def forward(self, x):
        return self.conv(x)
class output_layer(nn.Module):
    """Collapse the remaining T time steps and map to one prediction per node."""
    def __init__(self, c, T, n):
        super(output_layer, self).__init__()
        # A temporal conv with kernel length T reduces the time axis to 1.
        self.tconv1 = temporal_conv_layer(T, c, c, "GLU")
        self.ln = nn.LayerNorm([n, c])
        self.tconv2 = temporal_conv_layer(1, c, c, "sigmoid")
        self.fc = fully_conv_layer(c)

    def forward(self, x):
        x_t1 = self.tconv1(x)
        x_ln = self.ln(x_t1.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        x_t2 = self.tconv2(x_ln)
        return self.fc(x_t2)
class STGCN(nn.Module):
    """Full model: two ST-Conv blocks followed by the output layer."""
    def __init__(self, ks, kt, bs, T, n, Lk, p):
        super(STGCN, self).__init__()
        self.st_conv1 = st_conv_block(ks, kt, n, bs[0], p, Lk)
        self.st_conv2 = st_conv_block(ks, kt, n, bs[1], p, Lk)
        # Each block holds two temporal convs, so T shrinks by 4 * (kt - 1) overall.
        self.output = output_layer(bs[1][2], T - 4 * (kt - 1), n)

    def forward(self, x):
        x_st1 = self.st_conv1(x)
        x_st2 = self.st_conv2(x_st1)
        return self.output(x_st2)
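# A minimal smoke test (not part of the original module): the kernel sizes,
# channel widths, and the random Lk below are illustrative assumptions; a real
# run would build Lk from Chebyshev polynomials of the scaled graph Laplacian.
if __name__ == "__main__":
    ks, kt, T, n, p = 3, 3, 12, 228, 0.1   # hypothetical hyperparameters
    bs = [[1, 32, 64], [64, 32, 128]]       # channel widths per ST-Conv block
    Lk = torch.randn(ks, n, n)              # placeholder graph kernels, shape (ks, n, n)
    model = STGCN(ks, kt, bs, T, n, Lk, p)
    x = torch.randn(32, 1, T, n)            # (batch, channels, time, nodes)
    y = model(x)
    print(y.shape)                          # expected: torch.Size([32, 1, 1, 228])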