1 Star 0 Fork 0

Guikun Chen/Frozen

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
inference.py 3.04 KB
一键复制 编辑 原始数据 按行查看 历史
Guikun Chen 提交于 2021-12-16 11:06 . initial commit
from tqdm.auto import tqdm
import argparse
import time, datetime
import torch
import torchvision
from models.gpt2 import get_gpt2
from models.visual_prefix import get_visual_prefix
from utils.random_seeds import same_seeds
from data.dataloader import get_inf_dataset
def test(decoder, prompt_embds, vp, img):
vp.eval()
with torch.no_grad():
img_embeds = vp(img)
inputs_embeds = torch.cat([img_embeds, prompt_embds], dim=1)
generated = []
past_key_values = None # 第一次迭代时还无past_key_values元组.
for i in range(30):
if i == 0:
output = decoder(inputs_embeds=inputs_embeds, past_key_values=past_key_values)
else:
output = decoder(context, past_key_values=past_key_values)
past_key_values = output.past_key_values
# 此时获取GPT2模型计算的输出结果hidden_states张量中第二维度最后一个元素的argmax值, 得出的argmax值即为此次GPT2模型迭代
# 计算生成的下一个token. 注意, 此时若是第一次迭代, 输出结果hidden_states张量的形状为(batch_size, sel_len, n_state);
# 此时若是第二次及之后的迭代, 输出结果hidden_states张量的形状为(batch_size, 1, n_state), all_head_size=n_state=nx=768.
token = torch.argmax(output.logits[..., -1, :])
if token == tokenizer.eos_token_id:
break
# 将本次迭代生成的token的张量变为二维张量, 以作为下一次GPT2模型迭代计算的上下文context.
context = token.unsqueeze(0)
# 将本次迭代计算生成的token的序列索引变为列表存入generated
generated += [token.tolist()]
# 将generated中所有的token的索引转化为token字符.
sequence = tokenizer.decode(generated)
# sequence = sequence.split(".")[:-1]
print('Pred: ' + sequence)
if __name__ == '__main__':
same_seeds(2021)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
decoder, tokenizer, text_embedder = get_gpt2(device=device, model_name='gpt2-xl') # these components has been frozen
input_ids = tokenizer("The Manhattan bridge", return_tensors="pt")['input_ids'].to(device)
total = sum([param.nelement() for param in decoder.parameters()])
print("Number of parameter: %.2fM" % (total/1e6))
prompt_embds = text_embedder(tokenizer('<|endoftext|>', return_tensors="pt")['input_ids'].to(device))
vp = torch.load('checkpoints/vp_1epoch_for_gpt2_xl.pt', map_location=device)
# vp = get_visual_prefix().to(device)
train_set, val_set = get_inf_dataset(224)
selected_idx = [5, 105, 335, 455]
for idx in selected_idx:
img, caption = train_set[idx]
# Load all data into GPU
img = img.unsqueeze(0)
torchvision.utils.save_image(img, str(idx) + '.jpg')
img = img.to(device)
print('gt: ' + caption)
test(decoder, prompt_embds, vp, img)
# break
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/guikunchen/frozen.git
git@gitee.com:guikunchen/frozen.git
guikunchen
frozen
Frozen
master

搜索帮助