from tqdm.auto import tqdm
import json
import torch
from models.gpt2 import get_gpt2
# from ..models.visual_prefix import get_visual_prefix
from utils.random_seeds import same_seeds
from data.vqa import get_vqa


def get_nshot_examples_prefix(vqa, tokenizer, text_embedder, vp, n, device, ques_type):
    """Build the embedding prefix of n in-context (image, question, answer) examples."""
    if n <= 0:
        return None
    examples_prefix = None
    ques_ids = vqa.getQuesIds(quesTypes=[ques_type])
    with torch.no_grad():
        for i in range(n):
            img, question, answer = vqa.get_one_vqa_item(ques_ids[i])
            img = img.to(device)
            img_embeds = vp(img)
            ques_text = '<|endoftext|>Question: ' + question + ' Answer: ' + answer + '<|endoftext|>'
            ques_embeds = text_embedder(tokenizer(ques_text, return_tensors="pt")['input_ids'].to(device))
            if examples_prefix is None:
                examples_prefix = torch.cat([img_embeds, ques_embeds], dim=1)
            else:
                examples_prefix = torch.cat([examples_prefix, img_embeds, ques_embeds], dim=1)
    return examples_prefix
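
# A rough sketch of the prompt layout the prefix above produces, with
# hypothetical shapes (assuming the visual prefix maps each image to k
# embedding vectors of GPT-2 hidden size 768):
#
#   [img_1: (1, k, 768)] [Q_1 ... A_1 tokens] ... [img_n: (1, k, 768)] [Q_n ... A_n tokens]
#
# so for n examples the returned tensor has shape (1, total prefix length, 768),
# and test() below concatenates it in front of the query's own image and question.
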
def test(vqa, decoder, tokenizer, text_embedder, vp, n_shot, device, max_len=10):
    ques_ids = vqa.getQuesIds()
    vp.eval()
    examples_prefix = dict()  # cache the n-shot prefix per question type
    answers = []
    with torch.no_grad():
        for ques_id in tqdm(ques_ids):
            img, question = vqa.get_one_test_item(ques_id)
            img = img.to(device)
            img_embeds = vp(img)
            ques_text = '<|endoftext|>Question: ' + question + ' Answer: '
            ques_embeds = text_embedder(tokenizer(ques_text, return_tensors="pt")['input_ids'].to(device))
            inputs_embeds = torch.cat([img_embeds, ques_embeds], dim=1)
            # n-shot support: prepend cached in-context examples of the same question type
            ques_type = vqa.qa[ques_id]['question_type']
            if n_shot > 0 and ques_type != 'none of the above':
                if ques_type not in examples_prefix:
                    examples_prefix[ques_type] = get_nshot_examples_prefix(
                        vqa, tokenizer, text_embedder, vp, n_shot, device, ques_type)
                inputs_embeds = torch.cat([examples_prefix[ques_type], inputs_embeds], dim=1)
            # generate the answer with greedy decoding
            generated = []
            past_key_values = None  # no cached key/value tuples yet before the first step
            for i in range(max_len):
                if i == 0:
                    # first step: feed the whole prefix as input embeddings
                    output = decoder(inputs_embeds=inputs_embeds, past_key_values=past_key_values)
                else:
                    # later steps: feed only the newly generated token; the cache covers the rest
                    output = decoder(context, past_key_values=past_key_values)
                past_key_values = output.past_key_values
                # Take the argmax of the logits at the last position: that index is the next
                # token generated in this step. On the first step the outputs cover the whole
                # prefix, shape (batch_size, seq_len, ...); from the second step on only the
                # new token is processed, shape (batch_size, 1, ...) (hidden size
                # all_head_size = n_state = nx = 768 for gpt2).
                token = torch.argmax(output.logits[..., -1, :])
                if token == tokenizer.eos_token_id:
                    break
                # reshape the generated token into a 2-D (1, 1) tensor so it can serve as
                # the context (input_ids) for the next decoding step
                context = token.view(1, 1)
                # record the index of the generated token
                generated.append(token.item())
            # decode the generated token indices back into text
            answer = tokenizer.decode(generated).strip()
            answers.append({'answer': answer, 'question_id': ques_id})
    with open('results-' + str(n_shot) + 'shot.json', 'w') as f:
        json.dump(answers, f)
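
# A minimal, self-contained sketch of the same past_key_values greedy-decoding
# pattern used in test() above, on plain text without the visual prefix. The
# model name, prompt, and step count here are illustrative only.
def _demo_kv_cache_greedy_decode():
    from transformers import GPT2LMHeadModel, GPT2Tokenizer
    tok = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2').eval()
    ids = tok('Question: What color is the sky? Answer:', return_tensors='pt')['input_ids']
    past, generated = None, []
    with torch.no_grad():
        for _ in range(10):
            out = model(input_ids=ids, past_key_values=past)
            past = out.past_key_values                 # cache covers every token seen so far
            next_id = torch.argmax(out.logits[..., -1, :])
            if next_id == tok.eos_token_id:
                break
            generated.append(next_id.item())
            ids = next_id.view(1, 1)                   # feed only the new token on the next step
    return tok.decode(generated)
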
if __name__ == '__main__':
    same_seeds(2021)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    decoder, tokenizer, text_embedder = get_gpt2(device=device, model_name='gpt2')  # these components have been frozen
    total = sum(param.nelement() for param in decoder.parameters())
    print("Number of parameters: %.2fM" % (total / 1e6))
    vp = torch.load('checkpoints/mocov2_1epoch_for_gpt2.pt', map_location=device)
    vqa = get_vqa()
    n_shots = [0, 1, 4]
    for n_shot in n_shots:
        test(vqa, decoder, tokenizer, text_embedder, vp, n_shot, device)
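    # The saved results-*shot.json files follow the standard VQA results format
    # ([{"answer": ..., "question_id": ...}, ...]), so they can in principle be
    # scored with the official VQA evaluation code, roughly as below
    # (hypothetical names and paths, assuming the vqa object wraps the official
    # VQA Python API):
    #   res = vqa.loadRes('results-0shot.json', question_file)
    #   VQAEval(vqa, res, n=2).evaluate()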