Cppowboy/video_caption

train.py 1.19 KB
panyinxu committed on 2017-11-23 15:52: project _h, c, x, h into dim_embed vec
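The commit message mentions projecting several vectors into a `dim_embed`-sized vector. The `prev2out` and `ctx2out` flags passed to `CaptionGenerator` in the script suggest a Show, Attend and Tell style deep-output layer: the decoder hidden state and the attended context are each projected into the word-embedding space, added to the previous word's embedding, and only then mapped to vocabulary scores. The NumPy sketch below illustrates that idea under this assumption. `deep_output_logits`, the `params` keys, the batch size and the vocabulary size are hypothetical and not taken from `core.model`; it covers only `h`, the context and `x`, and dropout is omitted.

```python
import numpy as np


def deep_output_logits(h, context, x, params, prev2out=True, ctx2out=True):
    """Hypothetical deep-output layer in the Show, Attend and Tell style.

    h:       (N, H) decoder hidden state
    context: (N, D) attention-weighted feature vector
    x:       (N, M) embedding of the previous word, M = dim_embed
    Returns (N, V) unnormalised vocabulary scores.
    """
    # Project the hidden state into the M-dim word-embedding space.
    h_logits = h @ params["w_h"] + params["b_h"]
    if ctx2out:
        # Project the context vector into the same space and add it.
        h_logits = h_logits + context @ params["w_ctx2out"]
    if prev2out:
        # The previous word embedding is already M-dimensional.
        h_logits = h_logits + x
    h_logits = np.tanh(h_logits)
    # Map from the embedding space to one score per vocabulary word.
    return h_logits @ params["w_out"] + params["b_out"]


# Sizes from train.py (H=2000, D=2048, M=512); N and V are made up.
N, H, D, M, V = 2, 2000, 2048, 512, 1000
rng = np.random.default_rng(0)
params = {
    "w_h": rng.normal(scale=0.01, size=(H, M)),
    "b_h": np.zeros(M),
    "w_ctx2out": rng.normal(scale=0.01, size=(D, M)),
    "w_out": rng.normal(scale=0.01, size=(M, V)),
    "b_out": np.zeros(V),
}
h = rng.normal(size=(N, H))
context = rng.normal(size=(N, D))
x = rng.normal(size=(N, M))
logits = deep_output_logits(h, context, x, params)
print(logits.shape)  # (2, 1000)
```

Projecting everything into the embedding space first keeps the final output matrix at `dim_embed x V` rather than `dim_hidden x V`, which is one plausible reason for the design.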
from core.solver import CaptioningSolver
from core.model import CaptionGenerator
from core.utils import load_coco_data
from core.data import DataEngine


def main():
    # load train dataset
    # data = load_coco_data(data_path='./data', split='train')
    # word_to_idx = data['word_to_idx']
    # load val dataset to print out bleu scores every epoch
    # val_data = load_coco_data(data_path='./data', split='val')
    engine = DataEngine()
    word_to_idx, _ = engine.wordidx()
    data, val_data, test_data = engine.get_data()
    model = CaptionGenerator(word_to_idx, dim_feature=[28, 2048], dim_embed=512,
                             dim_hidden=2000, n_time_step=16, prev2out=True,
                             ctx2out=True, alpha_c=1.0, selector=True, dropout=True)
    solver = CaptioningSolver(model, data, val_data, n_epochs=100, batch_size=32, update_rule='adam',
                              learning_rate=0.001, print_every=1000, save_every=1, image_path='./image/',
                              pretrained_model=None, model_path='model/lstm/', test_model='model/lstm/model-10',
                              print_bleu=True, log_path='log/')
    solver.train()


if __name__ == "__main__":
    main()
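Two of the `CaptionGenerator` arguments can be read against the Show, Attend and Tell model: `dim_feature=[28, 2048]` as L=28 feature vectors of D=2048 dimensions per sample, and `alpha_c=1.0` as the weight of the doubly stochastic attention penalty `alpha_c * sum_i (1 - sum_t alpha[t, i])**2`. Both readings are assumptions, since `core.model` and `core.solver` are not shown here. The NumPy sketch below shows soft attention over the L vectors and that penalty; `soft_attention`, `doubly_stochastic_penalty`, the weight names, the batch size and the attention size `A` are hypothetical. The `target` default of 1 follows the paper; implementations sometimes rescale it, and this repository's choice is not visible in the script.

```python
import numpy as np


def soft_attention(features, h, w_f, w_h, w_att):
    """Hypothetical soft attention over L feature vectors.

    features: (N, L, D), h: (N, H); returns context (N, D) and alpha (N, L).
    """
    # Score every feature vector against the current hidden state.
    proj = np.tanh(features @ w_f + (h @ w_h)[:, None, :])  # (N, L, A)
    scores = (proj @ w_att)[..., 0]  # (N, L)
    # Softmax over the L positions (shifted for numerical stability).
    scores = scores - scores.max(axis=1, keepdims=True)
    alpha = np.exp(scores)
    alpha = alpha / alpha.sum(axis=1, keepdims=True)
    # The context vector is the alpha-weighted average of the features.
    context = (features * alpha[:, :, None]).sum(axis=1)  # (N, D)
    return context, alpha


def doubly_stochastic_penalty(alphas, alpha_c=1.0, target=1.0):
    """alphas: (N, T, L) attention weights collected over T decoding steps.

    Penalises positions whose total attention over time drifts from `target`.
    """
    alpha_sum = alphas.sum(axis=1)  # (N, L)
    return alpha_c * np.sum((target - alpha_sum) ** 2)


# L=28, D=2048, H=2000 follow train.py; N=2 and A=512 are made up.
N, L, D, H, A = 2, 28, 2048, 2000, 512
rng = np.random.default_rng(0)
features = rng.normal(size=(N, L, D))
h = rng.normal(size=(N, H))
w_f = rng.normal(scale=0.01, size=(D, A))
w_h = rng.normal(scale=0.01, size=(H, A))
w_att = rng.normal(scale=0.01, size=(A, 1))
context, alpha = soft_attention(features, h, w_f, w_h, w_att)
print(context.shape, alpha.shape)  # (2, 2048) (2, 28)
```

During training, the `alpha` from each of the `n_time_step` decoding steps would be stacked into an `(N, T, L)` array and passed to `doubly_stochastic_penalty`, whose value is added to the caption loss.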
https://gitee.com/cppowboy/video_caption.git
git@gitee.com:cppowboy/video_caption.git