代码拉取完成,页面将自动刷新
同步操作将从 Bin/多模态医学诊断 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
""" Componets of the model
"""
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
def xavier_init(m):
if type(m) == nn.Linear:
nn.init.xavier_normal_(m.weight)
if m.bias is not None:
m.bias.data.fill_(0.0)
class LinearLayer(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.clf = nn.Sequential(nn.Linear(in_dim, out_dim))
self.clf.apply(xavier_init)
def forward(self, x):
x = self.clf(x)
return x
class MMDynamic(nn.Module):
def __init__(self, in_dim, hidden_dim, num_class, dropout):
super().__init__()
self.views = len(in_dim)
self.classes = num_class
self.dropout = dropout
# 使用线性全连接网络在每一个模态中提取特征
self.FeatureEncoder = nn.ModuleList([LinearLayer(in_dim[view], hidden_dim[0]) for view in range(self.views)])
self.MMClasifier = []
for layer in range(1, len(hidden_dim)-1):
self.MMClasifier.append(LinearLayer(self.views*hidden_dim[0], hidden_dim[layer]))
self.MMClasifier.append(nn.ReLU())
self.MMClasifier.append(nn.Dropout(p=dropout))
if len(self.MMClasifier):
self.MMClasifier.append(LinearLayer(hidden_dim[-1], num_class))
else:
self.MMClasifier.append(LinearLayer(self.views*hidden_dim[-1], num_class))
self.MMClasifier = nn.Sequential(*self.MMClasifier)
def forward(self, data_list, label=None, infer=False):
criterion = torch.nn.CrossEntropyLoss(reduction='none')
feature = dict() # 提取到的特征
for view in range(self.views):
feature[view] = self.FeatureEncoder[view](data_list[view])
feature[view] = F.relu(feature[view])
feature[view] = F.dropout(feature[view], self.dropout, training=self.training)
MMfeature = torch.cat([i for i in feature.values()], dim=1) # 拼接特征
MMlogit = self.MMClasifier(MMfeature)
if infer:
return MMlogit
MMLoss = torch.mean(criterion(MMlogit, label))
return MMLoss, MMlogit
def infer(self, data_list):
MMlogit = self.forward(data_list, infer=True)
return MMlogit
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。