Gitee 极速下载/sharpmask

This repository is a mirror intended to speed up downloads within mainland China, synced once per day. Original repository: https://github.com/facebookresearch/deepmask
TrainerSharpMask.lua · 4.39 KB
Pedro Pinheiro committed on 2016-06-22 01:49 · first commit
--[[----------------------------------------------------------------------------
Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
Training and testing loop for SharpMask
------------------------------------------------------------------------------]]
paths.dofile('trainMeters.lua')
local Trainer = torch.class('Trainer')
--------------------------------------------------------------------------------
-- function: init
function Trainer:__init(model, criterion, config)
  -- training params
  self.model = model
  self.criterion = criterion
  self.lr = config.lr

  -- allocate cuda tensors
  self.inputs, self.labels = torch.CudaTensor(), torch.CudaTensor()

  -- meters
  self.lossmeter = LossMeter()
  self.maskmeter = IouMeter(0.5, config.testmaxload*config.batch)

  -- log
  self.modelsv = {model=model:clone('weight', 'bias'), config=config}
  self.rundir = config.rundir
  self.log = torch.DiskFile(self.rundir .. '/log', 'rw'); self.log:seekEnd()
end
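
-- Illustrative sketch (an assumption, not part of the original file): how a
-- Trainer might be wired up. The config fields mirror the reads above (lr,
-- batch, testmaxload, rundir); the model/criterion are placeholders, and the
-- caller is assumed to have loaded torch, nn and cutorch, as this file does.
local function exampleTrainerSetup()
  local model = nn.Sequential()                     -- stand-in for the SharpMask model
  local criterion = nn.SoftMarginCriterion():cuda() -- stand-in criterion
  local config = {lr=0, batch=32, testmaxload=500, rundir='exps/sharpmask'}
  return Trainer(model, criterion, config)
end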
--------------------------------------------------------------------------------
-- function: train
function Trainer:train(epoch, dataloader)
  self.model:training()
  self:updateScheduler(epoch)
  self.lossmeter:reset()

  local timer = torch.Timer()

  for n, sample in dataloader:run() do
    -- copy samples to the GPU
    self:copySamples(sample)

    -- forward/backward
    local outputs = self.model:forward(self.inputs)
    local lossbatch = self.criterion:forward(outputs, self.labels)

    local gradOutputs = self.criterion:backward(outputs, self.labels)
    -- scale gradients by the batch size to undo the criterion's batch
    -- averaging before the plain SGD step below
    gradOutputs:mul(self.inputs:size(1))
    self.model:zeroGradParameters()
    self.model:backward(self.inputs, gradOutputs)
    self.model:updateParameters(self.lr)

    -- update loss
    self.lossmeter:add(lossbatch)
  end

  -- write log
  local logepoch =
    string.format('[train] | epoch %05d | s/batch %04.2f | loss: %07.5f ',
      epoch, timer:time().real/dataloader:size(), self.lossmeter:value())
  print(logepoch)
  self.log:writeString(string.format('%s\n', logepoch))
  self.log:synchronize()

  -- save model
  torch.save(string.format('%s/model.t7', self.rundir), self.modelsv)
  if epoch%50 == 0 then
    torch.save(string.format('%s/model_%d.t7', self.rundir, epoch),
      self.modelsv)
  end

  collectgarbage()
end
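
-- Illustrative sketch (an assumption, not part of this file): the outer loop
-- driving train/test lives in the repo's train.lua; the loaders and epoch
-- count below are placeholders.
local function exampleTrainingLoop(trainer, trainLoader, testLoader)
  for epoch = 1, 300 do
    trainer:train(epoch, trainLoader)
    if epoch % 2 == 0 then trainer:test(epoch, testLoader) end
  end
end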
--------------------------------------------------------------------------------
-- function: test
local maxacc = 0
function Trainer:test(epoch, dataloader)
  self.model:evaluate()
  self.maskmeter:reset()

  for n, sample in dataloader:run() do
    -- copy input and target to the GPU
    self:copySamples(sample)

    -- infer mask in batch
    local outputs = self.model:forward(self.inputs):float()
    cutorch.synchronize()

    self.maskmeter:add(outputs:view(sample.labels:size()), sample.labels)
  end
  self.model:training()

  -- check if best model so far, measured by success rate at IoU > 0.7
  local z = self.maskmeter:value('0.7')
  local bestmodel = false
  if z > maxacc then
    torch.save(string.format('%s/bestmodel.t7', self.rundir), self.modelsv)
    maxacc = z
    bestmodel = true
  end

  -- write log
  local logepoch =
    string.format('[test] | epoch %05d '..
      '| IoU: mean %06.2f median %06.2f suc@.5 %06.2f suc@.7 %06.2f '..
      '| bestmodel %s',
      epoch,
      self.maskmeter:value('mean'), self.maskmeter:value('median'),
      self.maskmeter:value('0.5'), self.maskmeter:value('0.7'),
      bestmodel and '*' or 'x')
  print(logepoch)
  self.log:writeString(string.format('%s\n', logepoch))
  self.log:synchronize()

  collectgarbage()
end
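
-- Note on the metrics logged above (an interpretation based on the meter
-- usage here, not stated in this file): 'mean' and 'median' summarize IoU
-- over all test masks, while suc@.5 / suc@.7 report how often a predicted
-- mask reaches IoU above 0.5 / 0.7; the suc@.7 score also selects bestmodel.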
--------------------------------------------------------------------------------
-- function: copy inputs/labels to CUDA tensor
function Trainer:copySamples(sample)
  self.inputs:resize(sample.inputs:size()):copy(sample.inputs)
  self.labels:resize(sample.labels:size()):copy(sample.labels)
end
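
-- Design note: reusing the two persistent CudaTensors allocated in __init and
-- resizing them per batch avoids reallocating GPU memory on every iteration;
-- copy() then uploads the CPU batch produced by the dataloader.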
--------------------------------------------------------------------------------
-- function: update training schedule according to epoch
function Trainer:updateScheduler(epoch)
  -- a config lr of 0 selects the hand-tuned schedule below; any other value
  -- fixes the learning rate for the whole run. Because the guard tests
  -- lr == 0, the table is consulted only until a rate has been assigned,
  -- i.e. once at startup (or on resume, where the row matching the starting
  -- epoch applies).
  if self.lr == 0 then
    local regimes = {
      {  1,  50, 1e-3},
      { 51,  80, 5e-4},
      { 81, 1e8, 1e-4}
    }

    for _, row in ipairs(regimes) do
      if epoch >= row[1] and epoch <= row[2] then
        self.lr = row[3]
      end
    end
  end
end
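
-- Worked example (illustrative): a run launched with config.lr = 0 matches
-- the {1, 50, 1e-3} row at epoch 1 and trains at 1e-3; a run resumed at
-- epoch 60 with lr still 0 would instead pick up 5e-4 from the {51, 80, 5e-4}
-- row. Any nonzero lr in the config bypasses the table entirely.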
return Trainer
Clone (HTTPS): https://gitee.com/mirrors/sharpmask.git
Clone (SSH): git@gitee.com:mirrors/sharpmask.git