liuqiang123456789/CN-DPM

train.py 3.49 KB
Soochan Lee committed on 2020-05-17 11:34 · Remove checkpoint saving (#5)
import os
import pickle

import torch
from tensorboardX import SummaryWriter

from models import NdpmModel
from data import DataScheduler

def _write_summary(summary, writer: SummaryWriter, step):
    for summary_type, summary_dict in summary.items():
        if summary_type == 'scalar':
            write_fn = writer.add_scalar
        elif summary_type == 'image':
            write_fn = writer.add_image
        elif summary_type == 'histogram':
            write_fn = writer.add_histogram
        else:
            raise RuntimeError('Unsupported summary type: %s' % summary_type)

        for tag, value in summary_dict.items():
            write_fn(tag, value, step)

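# Illustrative note (not part of the original file): _write_summary expects the
# summary to be nested as {summary_type: {tag: value}}. For example,
#     summary = {
#         'scalar': {'loss/total': 0.42},
#         'image': {'recon/0': image_tensor},  # image_tensor shaped (C, H, W)
#     }
#     _write_summary(summary, writer, step)
# dispatches to writer.add_scalar and writer.add_image with the matching tags.
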
def _make_collage(samples, config, grid_h, grid_w):
    s = samples.view(
        grid_h, grid_w,
        config['x_c'], config['x_h'], config['x_w']
    )
    collage = s.permute(2, 0, 3, 1, 4).contiguous().view(
        config['x_c'],
        config['x_h'] * grid_h,
        config['x_w'] * grid_w
    )
    return collage

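# Shape note (illustrative, not from the repo): with config x_c=1, x_h=28,
# x_w=28 and a 4x4 sample grid, `samples` has shape (16, 1, 28, 28) and the
# returned collage has shape (1, 112, 112), tiling the samples row by row.
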
def train_model(config, model: NdpmModel,
                scheduler: DataScheduler,
                writer: SummaryWriter):
    for step, (x, y, t) in enumerate(scheduler):
        step += 1

        if isinstance(model, NdpmModel):
            print('\r[Step {:4}] STM: {:5}/{} | #Expert: {}'.format(
                step,
                len(model.ndpm.stm_x), config['stm_capacity'],
                len(model.ndpm.experts) - 1
            ), end='')
        else:
            print('\r[Step {:4}]'.format(step), end='')

        summarize = step % config['summary_step'] == 0
        summarize_experts = summarize and isinstance(model, NdpmModel)
        summarize_samples = summarize and config['summarize_samples']

        # Learn the model
        model.learn(x, y, t, step)

        # Evaluate the model
        evaluatable = (
            not isinstance(model, NdpmModel) or len(model.ndpm.experts) > 1
        )
        if evaluatable and step % config['eval_step'] == 0:
            scheduler.eval(model, writer, step, 'model')

        # Evaluate experts of the model's DPMoE
        if summarize_experts:
            writer.add_scalar('num_experts', len(model.ndpm.experts) - 1, step)

        # Summarize samples
        if summarize_samples:
            is_ndpm = isinstance(model, NdpmModel)
            comps = [e.g for e in model.ndpm.experts[1:]] \
                if is_ndpm else [model.component]
            if len(comps) == 0:
                continue
            grid_h, grid_w = config['sample_grid']
            total_samples = []

            # Sample from each expert
            for i, expert in enumerate(comps):
                with torch.no_grad():
                    samples = expert.sample(grid_h * grid_w)
                total_samples.append(samples)
                collage = _make_collage(samples, config, grid_h, grid_w)
                writer.add_image('samples/{}'.format(i + 1), collage, step)

            if is_ndpm:
                counts = model.ndpm.prior.counts[1:]
                expert_w = counts / counts.sum()
                num_samples = torch.distributions.multinomial.Multinomial(
                    grid_h * grid_w, probs=expert_w).sample().type(torch.int)
                to_collage = []
                for i, samples in enumerate(total_samples):
                    to_collage.append(samples[:num_samples[i]])
                to_collage = torch.cat(to_collage, dim=0)
                collage = _make_collage(to_collage, config, grid_h, grid_w)
                writer.add_image('samples/ndpm', collage, step)
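
A minimal sketch of how train_model might be driven (an assumption for illustration, not part of the repository; the real entry point and the exact constructor signatures of NdpmModel and DataScheduler live elsewhere in the repo). Only the config keys referenced above are shown.

from tensorboardX import SummaryWriter
from models import NdpmModel
from data import DataScheduler
from train import train_model

config = {
    'stm_capacity': 1000,       # short-term-memory size shown in the progress line
    'summary_step': 250,        # write TensorBoard summaries every N steps
    'summarize_samples': True,  # enable the per-expert sample collages
    'eval_step': 250,           # call scheduler.eval(...) every N steps
    'sample_grid': (4, 4),      # (grid_h, grid_w) used for the collages
    'x_c': 1, 'x_h': 28, 'x_w': 28,  # sample shape used by _make_collage
    # ... plus whatever NdpmModel / DataScheduler themselves require
}

writer = SummaryWriter('runs/cn-dpm')
model = NdpmModel(config, writer)   # hypothetical constructor arguments
scheduler = DataScheduler(config)   # hypothetical constructor arguments
train_model(config, model, scheduler, writer)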