zhuyuxiao/GLM-MS (forked from zyw_hw/GLM-MS)

test_model.py
huanglei authored 2023-03-07 11:26 · fix & basic complete incremental infer
"""test model"""
import argparse
from functools import partial
import numpy as np
import time
from mindspore import Model
from mindspore import dtype as mstype
from mindspore import set_context, Tensor
from glm.model.glm_130b_model import GLM130B
from glm.transformer.op_parallel_config import default_dpmp_config
from glm.transformer.vocab_embedding import default_embedding_parallel_config
from text_generate.model import batch_filling_sequence
from text_generate.strategies import BaseStrategy
from text_generate.generate import get_masks_and_position_ids
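# How to run (assuming a working MindSpore environment and that the repository
# root, which provides the glm and text_generate packages, is on PYTHONPATH):
#     python test_model.py
# Only the test invoked in the __main__ block at the bottom of the file runs;
# test_glm_130b_ms targets CPU, test_batch_filling_sequence targets Ascend.
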
def set_args():
    """Build a small debug configuration for GLM-130B."""
    args = argparse.Namespace()
    args.num_layers = 2
    args.hidden_size = 128
    args.batch_size = 1
    args.inner_hidden_size = 128
    args.num_attention_heads = 8
    args.model_parallel_size = 1
    # args.fp16 = False
    args.fp16 = True
    args.vocab_size = 200
    args.max_sequence_length = 128
    args.hidden_dropout = 0.0
    args.attention_dropout = 0.0
    args.hidden_size_per_attention_head = 16  # default None means hidden_size / num_attention_heads
    args.checkpoint_activations = False
    args.checkpoint_num_layers = 1
    args.layernorm_order = "post"
    args.skip_init = True
    args.use_gpu_initialization = False
    args.op_parallel_config = default_dpmp_config
    args.embed_parallel_config = default_embedding_parallel_config
    args.use_past = True
    return args
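

# test_glm_130b_ms exercises the incremental-inference path that the commit
# message refers to ("basic complete incremental infer"): with use_past=True, a
# first full-sequence forward pass presumably populates the model's cached
# key/value ("_past") states, and a second call then feeds only the last token
# with is_first_iteration=False so those cached states are reused.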
def test_glm_130b_ms():
    args = set_args()
    set_context(mode=0, device_target="CPU")
    # set_context(mode=1, device_target="CPU")
    # set_context(mode=1, device_target="Ascend", device_id=1)
    # set_context(mode=0, device_target="Ascend", device_id=2)
    params_dtype = mstype.float16 if args.fp16 else mstype.float32
    model = GLM130B(args, params_dtype=params_dtype)
    # print(model)
    # for k, v in model.parameters_and_names():
    #     print(k)
    #     print(v.dtype)
    #     print(v.shape)

    # Random token ids, monotonically increasing position ids and a
    # lower-triangular (causal) attention mask for a full sequence.
    input_ids = args.vocab_size * np.random.random((args.batch_size, args.max_sequence_length))
    input_ids = input_ids.astype(np.int32)
    input_ids = Tensor(input_ids, dtype=mstype.int32)
    position_ids = np.arange(args.max_sequence_length).astype(np.int32)
    position_ids = np.expand_dims(position_ids, 0)
    position_ids = Tensor(position_ids, mstype.int32)
    attention_mask = np.ones((args.batch_size, args.max_sequence_length, args.max_sequence_length))
    attention_mask = np.tril(attention_mask)
    attention_mask = Tensor(attention_mask, dtype=mstype.int32)

    # First iteration: full-sequence forward pass.
    init_set = False
    valid_length = Tensor([128,], dtype=mstype.int32)
    logit = model(input_ids, position_ids, attention_mask, init_set, valid_length)
    print(logit.shape, logit)
    for item in model.parameters_and_names():
        print(item[0])
        print(item[1].shape)
        if "_past" in item[0]:
            print(item[1])

    # Second iteration: incremental step, feeding only the last token.
    init_set = True
    model.add_flags_recursive(is_first_iteration=False)
    input_ids = input_ids[:, -1:]
    position_ids = position_ids[:, -1:]
    attention_mask = attention_mask[:, -1:, :]
    print(input_ids.shape, position_ids.shape, attention_mask.shape)
    logit2 = model(input_ids, position_ids, attention_mask, init_set, valid_length)
    print(logit2.shape, logit2)
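

# test_batch_filling_sequence drives end-to-end autoregressive generation via
# batch_filling_sequence: BaseStrategy supplies the token-selection strategy,
# and get_masks_and_position_ids (bound here with partial) builds the attention
# mask and position ids for the prompt and the generated tokens; mask_position=16
# matches the length of the random prompt.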
def test_batch_filling_sequence():
    # set_context(mode=0, device_target="CPU")
    set_context(mode=0, device_target="Ascend", device_id=0)
    # args
    args = set_args()
    # model
    params_dtype = mstype.float16 if args.fp16 else mstype.float32
    glm = GLM130B(args, params_dtype=params_dtype)
    glm.set_train(False)
    model = Model(glm)
    # init seq: a random prompt of 16 token ids
    seq = args.vocab_size * np.random.random((args.batch_size, 16))
    seq = Tensor(seq.astype(np.int32), dtype=mstype.int32)
    # strategy
    strategy = BaseStrategy(args.batch_size)
    start_time = time.time()
    out = batch_filling_sequence(model, seq, strategy=strategy,
                                 get_masks_and_position_ids=partial(
                                     get_masks_and_position_ids,
                                     mask_position=16,
                                     max_gen_length=args.max_sequence_length - seq.shape[-1],
                                     gmask=False,),)
    end_time = time.time()
    print('final out', out.shape)
    print('infer cost time: ', end_time - start_time)


if __name__ == "__main__":
    # test_glm_130b_ms()
    test_batch_filling_sequence()