代码拉取完成,页面将自动刷新
from tqdm.auto import tqdm
import argparse
import time, datetime
import torch
import torchvision
from models.gpt2 import get_gpt2
from models.visual_prefix import get_visual_prefix
from utils.random_seeds import same_seeds
from data.dataloader import get_inf_dataset
def test(decoder, prompt_embds, vp, img):
vp.eval()
with torch.no_grad():
img_embeds = vp(img)
inputs_embeds = torch.cat([img_embeds, prompt_embds], dim=1)
generated = []
past_key_values = None # 第一次迭代时还无past_key_values元组.
for i in range(30):
if i == 0:
output = decoder(inputs_embeds=inputs_embeds, past_key_values=past_key_values)
else:
output = decoder(context, past_key_values=past_key_values)
past_key_values = output.past_key_values
# 此时获取GPT2模型计算的输出结果hidden_states张量中第二维度最后一个元素的argmax值, 得出的argmax值即为此次GPT2模型迭代
# 计算生成的下一个token. 注意, 此时若是第一次迭代, 输出结果hidden_states张量的形状为(batch_size, sel_len, n_state);
# 此时若是第二次及之后的迭代, 输出结果hidden_states张量的形状为(batch_size, 1, n_state), all_head_size=n_state=nx=768.
token = torch.argmax(output.logits[..., -1, :])
if token == tokenizer.eos_token_id:
break
# 将本次迭代生成的token的张量变为二维张量, 以作为下一次GPT2模型迭代计算的上下文context.
context = token.unsqueeze(0)
# 将本次迭代计算生成的token的序列索引变为列表存入generated
generated += [token.tolist()]
# 将generated中所有的token的索引转化为token字符.
sequence = tokenizer.decode(generated)
# sequence = sequence.split(".")[:-1]
print('Pred: ' + sequence)
if __name__ == '__main__':
same_seeds(2021)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
decoder, tokenizer, text_embedder = get_gpt2(device=device, model_name='gpt2-xl') # these components has been frozen
input_ids = tokenizer("The Manhattan bridge", return_tensors="pt")['input_ids'].to(device)
total = sum([param.nelement() for param in decoder.parameters()])
print("Number of parameter: %.2fM" % (total/1e6))
prompt_embds = text_embedder(tokenizer('<|endoftext|>', return_tensors="pt")['input_ids'].to(device))
vp = torch.load('checkpoints/vp_1epoch_for_gpt2_xl.pt', map_location=device)
# vp = get_visual_prefix().to(device)
train_set, val_set = get_inf_dataset(224)
selected_idx = [5, 105, 335, 455]
for idx in selected_idx:
img, caption = train_set[idx]
# Load all data into GPU
img = img.unsqueeze(0)
torchvision.utils.save_image(img, str(idx) + '.jpg')
img = img.to(device)
print('gt: ' + caption)
test(decoder, prompt_embds, vp, img)
# break
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。