1 Star 1 Fork 2

syb/多模态医学诊断

forked from Bin/多模态医学诊断 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
model.py 2.28 KB
一键复制 编辑 原始数据 按行查看 历史
Bin 提交于 2022-09-27 16:34 . restart
"""Components of the model.
"""
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
def xavier_init(m):
    """Initialise a linear layer in place; leave every other module untouched.

    Meant to be passed to ``nn.Module.apply``. For ``nn.Linear`` modules
    (including subclasses), the weight is drawn with Xavier/Glorot normal
    initialisation and the bias, when present, is set to zero.

    Args:
        m: the module to (possibly) initialise.
    """
    # isinstance (not type ==) so subclasses of nn.Linear are initialised too.
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_normal_(m.weight)
    if m.bias is not None:
        nn.init.zeros_(m.bias)
class LinearLayer(nn.Module):
    """A single fully connected layer initialised with ``xavier_init``.

    The ``nn.Linear`` is held inside an ``nn.Sequential`` stored as ``clf``,
    so its parameters are registered under ``clf.0.weight`` / ``clf.0.bias``.

    Args:
        in_dim: number of input features.
        out_dim: number of output features.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        linear = nn.Linear(in_dim, out_dim)
        self.clf = nn.Sequential(linear)
        # Re-initialise the wrapped layer with the module-level helper.
        self.clf.apply(xavier_init)

    def forward(self, x):
        """Return ``x`` passed through the wrapped linear layer."""
        return self.clf(x)
class MMDynamic(nn.Module):
    """Multimodal classifier that fuses per-view features by concatenation.

    Each view (modality) is encoded by its own ``LinearLayer`` followed by
    ReLU and dropout. The encoded features of all views are concatenated along
    dim 1 and passed through ``MMClasifier``: optional fused hidden layers
    (``LinearLayer`` + ReLU + dropout each) and a final ``LinearLayer`` that
    produces ``num_class`` logits.

    Args:
        in_dim: sequence of input feature widths, one entry per view.
        hidden_dim: sequence of hidden widths. ``hidden_dim[0]`` is the output
            width of every per-view encoder; ``hidden_dim[1:-1]`` are the
            output widths of the fused hidden layers, applied in order. The
            final layer's input width is taken from whatever layer precedes
            it, so each layer's input always matches the previous output.
        num_class: number of output classes (logit width).
        dropout: dropout probability applied after each per-view encoder and
            after each fused hidden layer.
    """

    def __init__(self, in_dim, hidden_dim, num_class, dropout):
        super().__init__()
        self.views = len(in_dim)
        self.classes = num_class
        self.dropout = dropout
        # Per-view fully connected encoders: in_dim[view] -> hidden_dim[0].
        self.FeatureEncoder = nn.ModuleList(
            [LinearLayer(in_dim[view], hidden_dim[0]) for view in range(self.views)]
        )

        # Fused classifier. Track the running feature width so every layer's
        # input matches the previous layer's output; the concatenated
        # per-view features are views * hidden_dim[0] wide. (Previously every
        # hidden layer assumed that concatenated width, and the last layer
        # was sized from hidden_dim[-1], which broke any non-uniform config.)
        layers = []
        width = self.views * hidden_dim[0]
        for out_width in hidden_dim[1:-1]:
            layers.append(LinearLayer(width, out_width))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(p=dropout))
            width = out_width
        layers.append(LinearLayer(width, num_class))
        self.MMClasifier = nn.Sequential(*layers)

    def forward(self, data_list, label=None, infer=False):
        """Compute class logits and, unless ``infer`` is set, the training loss.

        Args:
            data_list: sequence of per-view input tensors indexed by view;
                presumably each is (batch, in_dim[view]) — TODO confirm.
            label: target class indices for the cross-entropy loss; required
                when ``infer`` is False.
            infer: when True, skip the loss and return only the logits.

        Returns:
            ``MMlogit`` if ``infer`` is True, otherwise ``(MMLoss, MMlogit)``
            where ``MMLoss`` is the batch mean of per-sample cross-entropy.

        Raises:
            ValueError: if ``label`` is None while ``infer`` is False.
        """
        features = []  # encoded features, one tensor per view
        for view in range(self.views):
            feat = F.relu(self.FeatureEncoder[view](data_list[view]))
            feat = F.dropout(feat, self.dropout, training=self.training)
            features.append(feat)
        MMfeature = torch.cat(features, dim=1)  # concatenate the view features
        MMlogit = self.MMClasifier(MMfeature)
        if infer:
            return MMlogit
        if label is None:
            raise ValueError("label is required unless infer=True")
        # Per-sample losses averaged over the batch.
        MMLoss = torch.mean(F.cross_entropy(MMlogit, label, reduction='none'))
        return MMLoss, MMlogit

    def infer(self, data_list):
        """Return class logits for ``data_list`` without computing a loss."""
        MMlogit = self.forward(data_list, infer=True)
        return MMlogit
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/syb1121/multimodal-medical-diagnosis.git
git@gitee.com:syb1121/multimodal-medical-diagnosis.git
syb1121
multimodal-medical-diagnosis
多模态医学诊断
master

搜索帮助