"""test model"""
import argparse
import time
from functools import partial

import numpy as np
from mindspore import Model
from mindspore import dtype as mstype
from mindspore import set_context, Tensor

from glm.model.glm_130b_model import GLM130B
from glm.transformer.op_parallel_config import default_dpmp_config
from glm.transformer.vocab_embedding import default_embedding_parallel_config
from text_generate.model import batch_filling_sequence
from text_generate.strategies import BaseStrategy
from text_generate.generate import get_masks_and_position_ids


def set_args():
    """Build a small dummy GLM-130B configuration so the tests run quickly."""
    args = argparse.Namespace()
    args.num_layers = 2
    args.hidden_size = 128
    args.batch_size = 1
    args.inner_hidden_size = 128
    args.num_attention_heads = 8
    args.model_parallel_size = 1
    args.fp16 = True
    args.vocab_size = 200
    args.max_sequence_length = 128
    args.hidden_dropout = 0.0
    args.attention_dropout = 0.0
    args.hidden_size_per_attention_head = 16  # default None means hidden_size / num_attention_heads
    args.checkpoint_activations = False
    args.checkpoint_num_layers = 1
    args.layernorm_order = "post"
    args.skip_init = True
    args.use_gpu_initialization = False
    args.op_parallel_config = default_dpmp_config
    args.embed_parallel_config = default_embedding_parallel_config
    args.use_past = True
    return args
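

# Quick sanity check (illustrative, not part of the original tests): the explicit
# per-head size above matches the documented default of
# hidden_size / num_attention_heads, i.e. 128 // 8 == 16.
def check_head_dim(args):
    assert args.hidden_size_per_attention_head == args.hidden_size // args.num_attention_heads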


def test_glm_130b_ms():
    """Run one full forward pass, then one incremental (use_past) step."""
    args = set_args()
    set_context(mode=0, device_target="CPU")  # mode=0 is graph mode, mode=1 is pynative
    # set_context(mode=1, device_target="CPU")
    # set_context(mode=1, device_target="Ascend", device_id=1)
    # set_context(mode=0, device_target="Ascend", device_id=2)
    params_dtype = mstype.float16 if args.fp16 else mstype.float32
    model = GLM130B(args, params_dtype=params_dtype)
    # Inspect the model and its parameters if needed:
    # print(model)
    # for k, v in model.parameters_and_names():
    #     print(k, v.dtype, v.shape)

    # Random token ids in [0, vocab_size).
    input_ids = args.vocab_size * np.random.random((args.batch_size, args.max_sequence_length))
    input_ids = Tensor(input_ids.astype(np.int32), dtype=mstype.int32)
    position_ids = np.expand_dims(np.arange(args.max_sequence_length).astype(np.int32), 0)
    position_ids = Tensor(position_ids, mstype.int32)
    # Lower-triangular (causal) attention mask.
    attention_mask = np.tril(np.ones((args.batch_size, args.max_sequence_length, args.max_sequence_length)))
    attention_mask = Tensor(attention_mask, dtype=mstype.int32)

    # First iteration: feed the full sequence and populate the key/value caches.
    init_set = False
    valid_length = Tensor([128], dtype=mstype.int32)
    logit = model(input_ids, position_ids, attention_mask, init_set, valid_length)
    print(logit.shape, logit)
    for name, param in model.parameters_and_names():
        print(name)
        print(param.shape)
        if "_past" in name:
            print(param)

    # Second iteration: feed only the last token and reuse the cached states.
    init_set = True
    model.add_flags_recursive(is_first_iteration=False)
    input_ids = input_ids[:, -1:]
    position_ids = position_ids[:, -1:]
    attention_mask = attention_mask[:, -1:, :]
    print(input_ids.shape, position_ids.shape, attention_mask.shape)
    logit2 = model(input_ids, position_ids, attention_mask, init_set, valid_length)
    print(logit2.shape, logit2)
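

# A minimal greedy-decoding sketch built on the two-phase pattern above: one full
# pass to fill the caches, then single-token steps. The call signature is the one
# used in test_glm_130b_ms; the [batch, seq, vocab] logit layout and argmax
# sampling are assumptions here, not the project's BaseStrategy.
def greedy_decode_sketch(model, input_ids, position_ids, attention_mask, valid_length, steps=4):
    logits = model(input_ids, position_ids, attention_mask, False, valid_length)
    model.add_flags_recursive(is_first_iteration=False)
    tokens = []
    for _ in range(steps):
        # Pick the most likely token at the last position (assumed logit layout).
        next_id = int(np.argmax(logits.asnumpy()[:, -1, :], axis=-1)[0])
        tokens.append(next_id)
        input_ids = Tensor([[next_id]], dtype=mstype.int32)
        position_ids = position_ids[:, -1:] + 1
        attention_mask = attention_mask[:, -1:, :]
        logits = model(input_ids, position_ids, attention_mask, True, valid_length)
    return tokens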


def test_batch_filling_sequence():
    """Generate up to max_sequence_length tokens from a 16-token prompt and time it."""
    # set_context(mode=0, device_target="CPU")
    set_context(mode=0, device_target="Ascend", device_id=0)
    args = set_args()
    params_dtype = mstype.float16 if args.fp16 else mstype.float32
    glm = GLM130B(args, params_dtype=params_dtype)
    glm.set_train(False)
    model = Model(glm)
    # Random 16-token prompt with ids in [0, vocab_size).
    seq = args.vocab_size * np.random.random((args.batch_size, 16))
    seq = Tensor(seq.astype(np.int32), dtype=mstype.int32)
    strategy = BaseStrategy(args.batch_size)
    start_time = time.time()
    out = batch_filling_sequence(model, seq, strategy=strategy,
                                 get_masks_and_position_ids=partial(
                                     get_masks_and_position_ids,
                                     mask_position=16,
                                     max_gen_length=args.max_sequence_length - seq.shape[-1],
                                     gmask=False))
    end_time = time.time()  # stop the clock before printing so the timing covers inference only
    print('final out', out.shape)
    print('infer cost time: ', end_time - start_time)
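

# A small benchmarking helper (sketch, not part of the original tests): averaging
# several runs separates one-off graph-compilation cost from steady-state latency.
# run_fn is any zero-argument callable, e.g. test_batch_filling_sequence.
def time_runs(run_fn, n=3):
    times = []
    for _ in range(n):
        start = time.time()
        run_fn()
        times.append(time.time() - start)
    return sum(times) / len(times)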


if __name__ == "__main__":
    # test_glm_130b_ms()
    test_batch_filling_sequence()