KunCheng-He/TIMIT-Conv-TasNet
model.py

import torch
from torch import nn
from torch.autograd import Variable
import numpy as np


class cLN(nn.Module):
    """Cumulative layer normalization for causal processing: at each time
    step, statistics are computed over all channels and past frames only.
    """
    def __init__(self, dimension, eps=1e-8, trainable=True):
        super(cLN, self).__init__()

        self.eps = eps
        if trainable:
            self.gain = nn.Parameter(torch.ones(1, dimension, 1))
            self.bias = nn.Parameter(torch.zeros(1, dimension, 1))
        else:
            self.gain = Variable(torch.ones(1, dimension, 1), requires_grad=False)
            self.bias = Variable(torch.zeros(1, dimension, 1), requires_grad=False)

    def forward(self, input):
        # input size: (Batch, Freq, Time)
        # cumulative mean for each time step
        batch_size = input.size(0)
        channel = input.size(1)
        time_step = input.size(2)

        step_sum = input.sum(1)  # B, T
        step_pow_sum = input.pow(2).sum(1)  # B, T
        cum_sum = torch.cumsum(step_sum, dim=1)  # B, T
        cum_pow_sum = torch.cumsum(step_pow_sum, dim=1)  # B, T

        entry_cnt = np.arange(channel, channel * (time_step + 1), channel)
        entry_cnt = torch.from_numpy(entry_cnt).type(input.type())
        entry_cnt = entry_cnt.view(1, -1).expand_as(cum_sum)

        cum_mean = cum_sum / entry_cnt  # B, T
        # Var[x] = E[x^2] - E[x]^2, rewritten with cumulative sums:
        # (sum(x^2) - 2*mean*sum(x)) / n + mean^2
        cum_var = (cum_pow_sum - 2 * cum_mean * cum_sum) / entry_cnt + cum_mean.pow(2)  # B, T
        cum_std = (cum_var + self.eps).sqrt()  # B, T

        cum_mean = cum_mean.unsqueeze(1)
        cum_std = cum_std.unsqueeze(1)

        x = (input - cum_mean.expand_as(input)) / cum_std.expand_as(input)
        return x * self.gain.expand_as(x).type(x.type()) + self.bias.expand_as(x).type(x.type())
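
# Added sketch (not part of the original file): a quick sanity check for cLN.
# At the final time step the cumulative statistics cover the whole sequence,
# so the cumulative mean should match the global mean over (Freq, Time).
def _demo_cln():
    norm = cLN(dimension=64)
    x = torch.randn(2, 64, 100)  # (Batch, Freq, Time)
    y = norm(x)
    assert y.shape == x.shape
    last_cum_mean = x.sum(1).cumsum(1)[:, -1] / (64 * 100)
    assert torch.allclose(last_cum_mean, x.mean(dim=(1, 2)), atol=1e-4)
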
class DepthConv1d(nn.Module):
    """One temporal conv block: 1x1 conv -> PReLU -> norm -> depthwise conv
    -> PReLU -> norm, with a residual output and an optional skip output.
    """
    def __init__(self, input_channel, hidden_channel, kernel, padding,
                 dilation=1, skip=True, causal=False):
        super(DepthConv1d, self).__init__()

        self.causal = causal
        self.skip = skip

        self.conv1d = nn.Conv1d(input_channel, hidden_channel, 1)
        if self.causal:
            self.padding = (kernel - 1) * dilation
        else:
            self.padding = padding
        self.dconv1d = nn.Conv1d(hidden_channel, hidden_channel, kernel,
                                 dilation=dilation, groups=hidden_channel,
                                 padding=self.padding)
        self.res_out = nn.Conv1d(hidden_channel, input_channel, 1)
        self.nonlinearity1 = nn.PReLU()
        self.nonlinearity2 = nn.PReLU()
        if self.causal:
            self.reg1 = cLN(hidden_channel, eps=1e-08)
            self.reg2 = cLN(hidden_channel, eps=1e-08)
        else:
            self.reg1 = nn.GroupNorm(1, hidden_channel, eps=1e-08)
            self.reg2 = nn.GroupNorm(1, hidden_channel, eps=1e-08)

        if self.skip:
            self.skip_out = nn.Conv1d(hidden_channel, input_channel, 1)

    def forward(self, input):
        output = self.reg1(self.nonlinearity1(self.conv1d(input)))
        if self.causal:
            # the symmetric padding added by dconv1d leaks future frames, so
            # trim the trailing (kernel - 1) * dilation frames to stay causal
            output = self.reg2(self.nonlinearity2(self.dconv1d(output)[:, :, :-self.padding]))
        else:
            output = self.reg2(self.nonlinearity2(self.dconv1d(output)))
        residual = self.res_out(output)
        if self.skip:
            skip = self.skip_out(output)
            return residual, skip
        else:
            return residual
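
# Added sketch (not part of the original file): with kernel=3 and
# padding=dilation, the block preserves the time dimension in both modes.
def _demo_depthconv():
    x = torch.randn(2, 128, 50)
    block = DepthConv1d(128, 256, kernel=3, padding=2, dilation=2)
    residual, skip = block(x)
    assert residual.shape == skip.shape == x.shape
    causal_block = DepthConv1d(128, 256, kernel=3, padding=2, dilation=2, causal=True)
    residual, skip = causal_block(x)
    assert residual.shape == x.shape
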
class TCN(nn.Module):
    def __init__(self, input_dim, output_dim, BN_dim, hidden_dim,
                 layer, stack, kernel=3, skip=True,
                 causal=False, dilated=True):
        super(TCN, self).__init__()

        # input is a sequence of features of shape (B, N, L)

        # normalization
        if not causal:
            self.LN = nn.GroupNorm(1, input_dim, eps=1e-8)
        else:
            self.LN = cLN(input_dim, eps=1e-8)

        self.BN = nn.Conv1d(input_dim, BN_dim, 1)

        # TCN for feature extraction
        self.receptive_field = 0
        self.dilated = dilated

        self.TCN = nn.ModuleList([])
        for s in range(stack):
            for i in range(layer):
                if self.dilated:
                    self.TCN.append(DepthConv1d(BN_dim, hidden_dim, kernel,
                                                dilation=2**i, padding=2**i,
                                                skip=skip, causal=causal))
                else:
                    self.TCN.append(DepthConv1d(BN_dim, hidden_dim, kernel,
                                                dilation=1, padding=1,
                                                skip=skip, causal=causal))
                if i == 0 and s == 0:
                    self.receptive_field += kernel
                else:
                    if self.dilated:
                        self.receptive_field += (kernel - 1) * 2**i
                    else:
                        self.receptive_field += (kernel - 1)

        # print("Receptive field: {:3d} frames.".format(self.receptive_field))

        # output layer
        self.output = nn.Sequential(nn.PReLU(),
                                    nn.Conv1d(BN_dim, output_dim, 1)
                                    )

        self.skip = skip

    def forward(self, input):
        # input shape: (B, N, L)

        # normalization
        output = self.BN(self.LN(input))

        # pass to TCN
        if self.skip:
            skip_connection = 0.
            for i in range(len(self.TCN)):
                residual, skip = self.TCN[i](output)
                output = output + residual
                skip_connection = skip_connection + skip
        else:
            for i in range(len(self.TCN)):
                residual = self.TCN[i](output)
                output = output + residual

        # output layer
        if self.skip:
            output = self.output(skip_connection)
        else:
            output = self.output(output)

        return output
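
# Added sketch (not part of the original file): the separator preserves the
# frame axis and widens the channel axis, so TasNet below can split its output
# into one mask per speaker. With layer=8, stack=3, kernel=3 the receptive
# field accumulates to 3 + 2*(2+...+128) + 2*2*(1+...+128) = 1531 frames.
def _demo_tcn():
    sep = TCN(input_dim=512, output_dim=512 * 2, BN_dim=128, hidden_dim=512,
              layer=8, stack=3)
    x = torch.randn(2, 512, 200)
    assert sep(x).shape == (2, 512 * 2, 200)
    print("receptive field: {} frames".format(sep.receptive_field))
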
class TasNet(nn.Module):
    """Conv-TasNet reimplementation.

    enc_dim: number of encoder basis filters (N in the paper)
    feature_dim: bottleneck channels of the separator (B)
    sr: audio sampling rate in Hz
    win: encoder window length in milliseconds
    layer: number of conv blocks per stack (X)
    stack: number of stacks (R)
    kernel: kernel size of the depthwise convolutions (P)
    num_spk: number of sources to separate
    causal: use causal convolutions and cumulative layer norm
    """
    def __init__(self, enc_dim=512, feature_dim=128, sr=16000, win=2, layer=8,
                 stack=3, kernel=3, num_spk=2, causal=False):
        super().__init__()

        self.num_spk = num_spk  # number of sources to split out
        self.enc_dim = enc_dim  # encoder output dimension
        self.feature_dim = feature_dim
        self.win = int(sr * win / 1000)  # window length in samples (2 ms -> 32 samples at 16 kHz)
        self.stride = self.win // 2  # 50% overlap between adjacent windows
        self.layer = layer
        self.stack = stack
        self.kernel = kernel
        self.causal = causal

        # input encoder
        self.encoder = nn.Conv1d(1, self.enc_dim, self.win, bias=False, stride=self.stride)

        # TCN separator
        self.TCN = TCN(self.enc_dim, self.enc_dim * self.num_spk, self.feature_dim, self.feature_dim * 4,
                       self.layer, self.stack, self.kernel, causal=self.causal)
        self.receptive_field = self.TCN.receptive_field

        # output decoder
        self.decoder = nn.ConvTranspose1d(self.enc_dim, 1, self.win, bias=False, stride=self.stride)

    def pad_signal(self, input):
        """Zero-pad the input so it divides evenly into overlapping windows."""
        # input: (B, T) or (B, 1, T), B = batch size, T = time-domain samples
        if input.dim() not in [2, 3]:
            raise RuntimeError("Input must be a 2-D or 3-D tensor")
        if input.dim() == 2:
            input = input.unsqueeze(1)  # expand to (B, 1, T)
        batch_size = input.shape[0]
        nsample = input.shape[2]
        rest = self.win - (self.stride + nsample % self.win) % self.win
        if rest > 0:
            pad = Variable(torch.zeros(batch_size, 1, rest)).type(input.type())
            input = torch.cat([input, pad], 2)
        # half a window of zeros on each side so every sample is fully covered
        pad_aux = Variable(torch.zeros(batch_size, 1, self.stride)).type(input.type())
        input = torch.cat([pad_aux, input, pad_aux], 2)

        return input, rest
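    # Worked example (added note, not in the original file): with the defaults,
    # self.win = 32 and self.stride = 16. For nsample = 32123:
    #   rest = 32 - (16 + 32123 % 32) % 32 = 32 - (16 + 27) % 32 = 21
    # so 21 zeros are appended plus 16 on each side, for a padded length of
    # 16 + 32123 + 21 + 16 = 32176 samples; forward() trims the decoder output
    # back to exactly 32123 samples.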
    def forward(self, input):
        # padding
        output, rest = self.pad_signal(input)
        batch_size = output.size(0)

        # waveform encoder
        enc_output = self.encoder(output)  # B, N, L

        # generate masks
        masks = torch.sigmoid(self.TCN(enc_output)).view(batch_size, self.num_spk, self.enc_dim, -1)  # B, C, N, L
        masked_output = enc_output.unsqueeze(1) * masks  # B, C, N, L

        # waveform decoder
        output = self.decoder(masked_output.view(batch_size * self.num_spk, self.enc_dim, -1))  # B*C, 1, L
        output = output[:, :, self.stride:-(rest + self.stride)].contiguous()  # B*C, 1, L
        output = output.view(batch_size, self.num_spk, -1)  # B, C, T

        return output
if __name__ == "__main__":
    x = torch.rand(1, 32123)
    nnet = TasNet()
    x = nnet(x)
    print(x.shape)  # torch.Size([1, 2, 32123]): one waveform per speaker
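    # The _demo_* functions above are added sketches; they can be run manually,
    # e.g. with:
    #   python -c "import model; model._demo_cln(); model._demo_depthconv(); model._demo_tcn()"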