1 Star 2 Fork 1

panzhihui/autotest

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
autotest.py 9.98 KB
一键复制 编辑 原始数据 按行查看 历史
panzhihui 提交于 2024-01-17 03:16 . Only import junitparser when needed
import argparse
import itertools
import os
import re
import shlex
import subprocess
import sys
import time
from datetime import datetime
from functools import reduce
from multiprocessing import Process
# Working directory at start-up; wheels are looked up / downloaded here.
PWD = os.getcwd()
# Device names in bit order of --target: bit 0 = Ascend, bit 1 = CPU, bit 2 = GPU.
DEVICES = ['Ascend', 'CPU', 'GPU']
ASCEND, CPU, GPU = DEVICES
MODES = ['PYNATIVE', 'GRAPH']
# Root folder of the Ascend device logs followed by print_dlog().
DEVICE_LOG_PATH = '/root/ascend/log/debug'
# DEVICE_ID selects the per-device log sub-folder; unset or empty means 0.
device_id = os.getenv("DEVICE_ID") or 0
dlog_path = os.path.join(DEVICE_LOG_PATH, f'device-{device_id}')
# Run timestamp, set by make_parser(); used to name this run's JUnit XML files.
timetag = None
# Lower-cased XML run tags collected by setup_xml() and merged by merge_xml().
# Defined here (not only under __main__) so those functions also work when the
# module is imported.
runtags = set()
def make_parser():
    """Build the autotest command-line parser.

    Side effect: sets the module-level ``timetag`` to the current local time
    formatted as ``%m%d%H%M%S``; setup_xml() and merge_xml() use it to name
    and find this run's JUnit XML files.

    Returns:
        argparse.ArgumentParser: the parser. The ``__main__`` block calls
        ``parse_known_args`` on it and hands the unrecognised arguments to
        pytest.
    """
    global timetag
    timetag = datetime.now().strftime('%m%d%H%M%S')
    parser = argparse.ArgumentParser()
    parser.add_argument('--whl',
                        help='Install the wheel file at the given path, or '
                             'download it first if a URL is provided.')
    parser.add_argument('--target',
                        type=int,
                        default=1,
                        help='Device target. 1=ascend, 2=cpu, 4=gpu. '
                             'Add values to combine, e.g. 3=ascend+cpu.')
    parser.add_argument('--mode',
                        choices=['pynative', 'graph', 'all'],
                        default='pynative',
                        help='Execution mode; "all" runs both pynative and graph.')
    parser.add_argument('--mslog',
                        type=int,
                        default=None,
                        metavar="LEVEL",
                        help='Set MindSpore log level (0-4, exported as GLOG_v).')
    parser.add_argument('--ascendlog',
                        type=int,
                        default=None,
                        metavar="LEVEL",
                        help='Set Ascend log level (0-4); Ascend logs are also '
                             'printed to stdout.')
    # Both flags share dest="dlog". argparse takes the default from the first
    # action registered for a dest, so the effective default is False (off);
    # it is spelled out here and the help text states it correctly.
    parser.add_argument('--dlog',
                        action='store_true',
                        default=False,
                        help='Print Ascend device error log to stdout. Off by default.')
    parser.add_argument('--no-dlog', dest='dlog', action='store_false',
                        help='Do not print the Ascend device error log (default).')
    parser.add_argument('--ge',
                        action='store_true',
                        help='GE mode on ascend.')
    parser.add_argument('--ge-graph',
                        action='store_true',
                        help='Store GE graphs to folder "ge_graphs".')
    parser.add_argument('--xml',
                        type=str,
                        default=None,
                        metavar="DIR",
                        help='Dump JUnit XML of the tests into the given directory.')
    parser.add_argument('--xml-prefix',
                        type=str,
                        default="",
                        metavar="PREFIX",
                        help='Prepend to xml file name.')
    parser.add_argument('--xml-suffix',
                        type=str,
                        default="",
                        metavar="SUFFIX",
                        help='Append to xml file name.')
    parser.add_argument('--file',
                        type=str,
                        metavar="FILE",
                        help='Load test names from FILE, one per line; lines '
                             'starting with "#" are ignored.')
    parser.add_argument('--dump',
                        type=str,
                        metavar="NET",
                        dest='dump_net',
                        default=None,
                        help='Dump network NET. All kernels are dumped, in npy '
                             'format unless --dump-bin is given.')
    parser.add_argument('--dump-bin',
                        action='store_true',
                        help='Dump in bin format instead of npy (with --dump).')
    return parser
def shell_run(cmd, *, check=True, timeout=None, shell=True):
    """Run a command, by default through the shell.

    Args:
        cmd: command string, or a list of pieces. With ``shell=True`` the
            pieces are joined with single spaces and are NOT quoted here, so
            callers must ``shlex.quote`` anything that may contain spaces or
            shell metacharacters. With ``shell=False`` a list is passed to
            subprocess unchanged as the argv.
        check: raise subprocess.CalledProcessError on a non-zero exit status.
        timeout: seconds before the command is killed. A timeout is reported
            on stdout instead of being raised.
        shell: passed through to subprocess.run.

    Returns:
        subprocess.CompletedProcess for the command, or None if it timed out.
    """
    if isinstance(cmd, list) and shell:
        cmd = ' '.join(cmd)
    try:
        return subprocess.run(cmd, check=check, timeout=timeout, shell=shell)
    except subprocess.TimeoutExpired:
        print(f"!Timeout for cmd: {cmd}")
        return None
def download_whl(url):
    """Download a wheel from *url* into PWD, replacing any local copy.

    The file is saved as PWD joined with the last '/'-separated segment of
    *url*.

    Returns:
        str: path of the downloaded file.
    """
    whl_file = os.path.join(PWD, url.split('/')[-1])
    if os.path.isfile(whl_file):
        # Never reuse a previously downloaded wheel.
        os.remove(whl_file)
    # Quote both operands: URLs often contain '&', '?' or '=' that the shell
    # would otherwise interpret. -O pins the output to exactly whl_file.
    # NOTE(review): --no-check-certificate disables TLS verification.
    shell_run(["wget --no-check-certificate", "-O", shlex.quote(whl_file),
               shlex.quote(url)])
    return whl_file
def install_whl(whl):
    """(Re)install a wheel given as a local path or a URL.

    Args:
        whl: path (absolute, or relative to PWD) of a .whl file, or a URL
            that download_whl() fetches into PWD. A falsy value is a no-op.
    """
    if not whl:
        return
    local_file = os.path.join(PWD, whl)
    if os.path.isfile(local_file):
        whl_file = local_file
    else:
        download_whl(whl)
        whl_file = os.path.join(PWD, whl.split('/')[-1])
    quoted = shlex.quote(whl_file)
    # Uninstall first, presumably so that a rebuilt wheel with an unchanged
    # version number is really reinstalled. The wheel is given exactly once
    # to each pip command.
    shell_run(f"pip uninstall -y {quoted} && pip install {quoted}")
def print_dlog(pid, max_wait=20):
    """Stream the "[ERROR]" lines of the device log that belongs to *pid*.

    Polls ``dlog_path`` once per second, up to *max_wait* times, for a file
    whose name contains *pid*. It then follows that file with
    ``tail --pid <pid> -f`` (which stops once process *pid* exits) and prints
    only lines containing "[ERROR]".

    Args:
        pid: process id to look for in log file names.
        max_wait: number of one-second polls before giving up.
    """
    pid = str(pid)
    dlog_name = None
    for _ in range(max_wait):  # wait until the log file is generated
        try:
            names = os.listdir(dlog_path)
        except FileNotFoundError:
            # The log folder may not exist until the device writes its first log.
            names = []
        # NOTE(review): substring match, so pid "12" also matches "123"; the
        # exact log file naming scheme is not visible here.
        dlog_name = next((name for name in names if pid in name), None)
        if dlog_name:
            break
        time.sleep(1)
    if not dlog_name:
        print("***************************************")
        print(f"device log for pid {pid} not found in {dlog_path}")
        print("***************************************")
        return
    dlog_file = os.path.join(dlog_path, dlog_name)
    print(f"\ndevice log: {dlog_file}\n")
    shell_run(f'tail --pid {pid} -f {shlex.quote(dlog_file)} | grep "\\[ERROR\\]"')
def call_pytest(args):
    """Run pytest with *args* through shell_run(), with a 15-minute timeout.

    ``--disable-warnings`` and ``-s`` are appended unless already present.
    Every argument is shell-quoted, so values with spaces or glob characters
    (``-k "a and b"``, ``test_x[1-2]``) reach pytest unchanged. A failing
    pytest run is not raised (check=False).

    Args:
        args: pytest command-line arguments; the list itself is not modified.

    Returns:
        Whatever shell_run() returns for the command.
    """
    argv = list(args)
    for flag in ('--disable-warnings', '-s'):
        if flag not in argv:
            argv.append(flag)
    cmd = ' '.join(shlex.quote(arg) for arg in ['pytest', *argv])
    return shell_run(cmd, check=False, timeout=15 * 60)
def call_pytest_and_printing_dlog(pytest_args):
    """Run pytest while a helper process streams device error-log lines.

    The helper runs print_dlog() in a separate multiprocessing.Process. It
    is always terminated and reaped afterwards, even if pytest is
    interrupted.

    NOTE(review): the pid handed to print_dlog() is this autotest process's
    own pid (os.getpid()), not the pid of the pytest child started by
    call_pytest(); confirm that device log file names contain this pid.

    Args:
        pytest_args: arguments forwarded to call_pytest().

    Returns:
        Whatever call_pytest() returns.
    """
    pid = os.getpid()
    log_task = Process(target=print_dlog, args=(pid,))
    log_task.start()
    try:
        return call_pytest(pytest_args)
    finally:
        # Stop the follower on every exit path (including Ctrl-C) and wait
        # for it so it does not linger as a zombie.
        log_task.terminate()
        log_task.join()
def decode_device(device_code):
    """Expand a --target bit mask into the device names it selects.

    Bit ``i`` of *device_code* selects ``DEVICES[i]`` (for i in 0..2).

    Returns:
        A lazy iterator over the selected names, in DEVICES order.
    """
    selectors = ((device_code >> bit) & 1 for bit in range(3))
    return itertools.compress(DEVICES, selectors)
def setup_env(device, mode):
    """Export the device target and execution mode for the next test run.

    Writes CONTEXT_DEVICE_TARGET and CONTEXT_MODE into os.environ, which
    subprocesses started afterwards (the pytest run) inherit.
    """
    os.environ.update(CONTEXT_DEVICE_TARGET=device, CONTEXT_MODE=mode)
def setup_xml(device, mode, args, pytest_args, test: str = None):
    '''Setup final xml file name and pass it to pytest.

    Builds "<xml dir>/.<timetag>_<runtag>[_<testname>].xml", where runtag is
    "[<prefix>_]<device>_<mode>[_ge][_<suffix>]", lower-cased. The runtag is
    recorded in the module-level ``runtags`` set for merge_xml(), and
    ``--junit-xml <file>`` is appended to *pytest_args* in place.

    Args:
        device: current device target name.
        mode: current run mode name.
        args: parsed CLI namespace (reads xml, ge, xml_prefix, xml_suffix).
        pytest_args: pytest argument list, extended in place.
        test: optional test id from the --file list, e.g.
            "dir/test_foo.py::test_bar".

    Raises:
        ValueError: if args.xml is not an existing directory.
    '''
    xml_dir = os.path.abspath(args.xml)
    if not os.path.isdir(xml_dir):
        raise ValueError(f"{args.xml} is not a folder.")
    xml = f"{device}_{mode}"
    if args.ge:
        xml += '_ge'
    if args.xml_prefix:
        xml = f"{args.xml_prefix}_{xml}"
    if args.xml_suffix:
        xml = f"{xml}_{args.xml_suffix}"
    runtags.add(xml.lower())
    if test:
        xml += '_' + _xml_test_name(test)
    xml = f'.{timetag}_{xml.lower()}.xml'
    pytest_args.extend(['--junit-xml', os.path.join(xml_dir, xml)])


def _xml_test_name(test):
    '''Turn a test id into a file-name-safe tag that is unique per test id.'''
    path, _, node = test.partition('::')
    # Base name only: a directory part would place the XML in a sub-folder
    # that merge_xml() never lists.
    stem = os.path.basename(path).split('.')[0]
    if stem.startswith('test_'):
        stem = stem[5:]
    if node:
        # Keep the node part so several ids from one file do not overwrite
        # each other's XML report.
        stem += '_' + re.sub(r'\W+', '_', node)
    return stem
def run_test(args, _pytest_args, test=None):
    """Run pytest once for every selected (device, mode) combination.

    Does nothing when there are neither pytest arguments nor a test.

    Args:
        args: parsed CLI namespace (reads target, mode, xml, dlog, and
            whatever setup_xml() reads).
        _pytest_args: extra pytest arguments; never modified.
        test: optional single test id appended to the pytest arguments.

    Raises:
        SystemExit: if args.target is outside 0..7.
    """
    if not _pytest_args and not test:
        return
    if args.target > 7 or args.target < 0:
        sys.exit(f"--target argument invalid:{args.target}")
    devices = decode_device(args.target)
    modes = MODES if args.mode == 'all' else [args.mode.upper()]
    base_args = list(_pytest_args)
    if test:
        base_args.append(test)
    for device in devices:
        for mode in modes:
            print("******************************************************************")
            print(f"device: {device}, mode: {mode}")
            print("******************************************************************")
            setup_env(device, mode)
            # Fresh copy per run: setup_xml() appends its own --junit-xml
            # option, which must not pile up across device/mode runs.
            pytest_args = base_args.copy()
            if args.xml:
                setup_xml(device, mode, args, pytest_args, test)
            if args.dlog and device == ASCEND:
                call_pytest_and_printing_dlog(pytest_args)
            else:
                call_pytest(pytest_args)
def load_test_names(path):
    '''Load test names from file.

    The file holds one test id per line. Surrounding whitespace is stripped,
    and blank lines and lines starting with "#" (after stripping) are
    skipped.

    Args:
        path: path of a UTF-8 text file.

    Returns:
        list[str]: the test ids, in file order.
    '''
    names = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            case = line.strip()
            if case and not case.startswith('#'):
                names.append(case)
    return names
def run_tests(args, pytest_args):
    """Run the requested tests.

    Without --file, runs the pytest arguments once via run_test(). With
    --file, runs each listed test separately and then calls merge_xml() on
    args.xml (which does nothing when no XML directory was requested).

    Raises:
        ValueError: if args.file does not name an existing file.
    """
    if not args.file:
        run_test(args, pytest_args)
        return
    fname = os.path.abspath(args.file)
    if not os.path.isfile(fname):
        raise ValueError(f"--file {args.file!r} doesn't refer to a valid file.")
    for test in load_test_names(fname):
        run_test(args, pytest_args, test)
    merge_xml(args.xml)
def ms_log_setup(level):
    """Export the requested MindSpore log level as GLOG_v.

    Args:
        level: integer in 0..4, or None to leave the environment untouched.

    Raises:
        ValueError: if *level* is outside 0..4.
    """
    if level is not None:
        if not (0 <= level <= 4):
            raise ValueError(f"Not supported mindspore log level: {level}")
        os.environ['GLOG_v'] = str(level)
def ascend_log_setup(level):
    """Route Ascend logs to stdout at the requested global level.

    Args:
        level: integer in 0..4, or None to leave the environment untouched.

    Raises:
        ValueError: if *level* is outside 0..4.
    """
    if level is None:
        return
    valid = 0 <= level <= 4
    if not valid:
        raise ValueError(f"Not supported ascend log level: {level}")
    os.environ.update({
        'ASCEND_SLOG_PRINT_TO_STDOUT': '1',
        'ASCEND_GLOBAL_LOG_LEVEL': str(level),
    })
def ge_setup(ge_mode, save_graph):
    """Export environment variables for GE runs and GE graph dumps.

    Args:
        ge_mode: when true, set MS_ENABLE_GE, MS_GE_TRAIN, MS_ENABLE_REF_MODE
            and MS_DEV_FORCE_ACL to "1".
        save_graph: when true, set DUMP_GRAPH_LEVEL and DUMP_GE_GRAPH to "2"
            and DUMP_GRAPH_PATH to "ge_graphs" (independently of ge_mode).
    """
    if ge_mode:
        os.environ['MS_ENABLE_GE'] = '1'
        # Was misspelled "MS_GE_TRIAN", so the intended switch was never set.
        os.environ['MS_GE_TRAIN'] = '1'
        os.environ['MS_ENABLE_REF_MODE'] = '1'
        os.environ['MS_DEV_FORCE_ACL'] = '1'
        # NOTE(review): this used to assign `args.device = 1` on the global
        # `args`. Nothing reads `args.device` (run_test reads `args.target`),
        # and the assignment raised NameError when the module was imported.
        # It was presumably meant to force the Ascend target — confirm.
    if save_graph:
        os.environ['DUMP_GRAPH_LEVEL'] = '2'
        os.environ['DUMP_GE_GRAPH'] = '2'
        os.environ['DUMP_GRAPH_PATH'] = 'ge_graphs'
def merge_xml(xml_dir):
    '''Merge tests with same runtag.

    For each runtag recorded in the module-level ``runtags``, combines every
    "<xml_dir>/.<timetag>_<runtag>*" file of this run into
    "<xml_dir>/<runtag>.xml" and then deletes the per-test files. A falsy
    *xml_dir* is a no-op. The third-party ``junitparser`` package is imported
    lazily, so it is only needed when XML output was requested.
    '''
    if not xml_dir:
        return
    from junitparser import JUnitXml
    names = sorted(os.listdir(xml_dir))  # sorted for a deterministic merge order
    for runtag in runtags:
        prefix = f'.{timetag}_{runtag}'
        paths = [os.path.join(xml_dir, name) for name in names
                 if name.startswith(prefix)]
        if not paths:
            # No report was written for this runtag (e.g. every pytest run
            # timed out); reduce() on an empty sequence would raise TypeError.
            continue
        merged = reduce(lambda x, y: x + y,
                        (JUnitXml.fromfile(path) for path in paths))
        merged.write(os.path.join(xml_dir, runtag + '.xml'))
        # Remove the per-test files only after the merged report is written.
        for path in paths:
            os.remove(path)
def dump_setup(netname, dump_bin):
    '''Setup dump config

    When *netname* is given:
    - renders "dump_config_template.json" (looked up in this script's folder)
      into "<cwd>/dump_config.json";
    - creates "<cwd>/dump" as the dump target;
    - points MINDSPORE_DUMP_CONFIG at the rendered config.

    Args:
        netname: network to dump; a falsy value disables dumping, and then
            jinja2 is not required at all.
        dump_bin: render format "bin" instead of "npy".
    '''
    if not netname:
        return
    # Import lazily: jinja2 is only needed when a dump is requested.
    import jinja2
    env = jinja2.Environment(
        # abspath: dirname(__file__) is '' when the script runs from its own folder.
        loader=jinja2.FileSystemLoader(os.path.dirname(os.path.abspath(__file__))),
        autoescape=jinja2.select_autoescape(),
        trim_blocks=True,
        lstrip_blocks=True
    )
    template = env.get_template("dump_config_template.json")
    dump_dir = os.path.join(os.getcwd(), "dump")
    os.makedirs(dump_dir, exist_ok=True)
    dump_format = "bin" if dump_bin else "npy"
    config_text = template.render(path=dump_dir, netname=netname, format=dump_format)
    json_path = os.path.join(os.getcwd(), "dump_config.json")
    with open(json_path, 'w', encoding='utf-8') as f:
        f.write(config_text)
    os.environ['MINDSPORE_DUMP_CONFIG'] = json_path
if __name__ == "__main__":
    # Options known to make_parser() land in `args`; every other command-line
    # argument is collected in `pytest_args` and forwarded to pytest.
    parser = make_parser()
    args, pytest_args = parser.parse_known_args()
    # `runtags` and `args` are module globals on purpose: setup_xml() and
    # merge_xml() read `runtags` by name, and ge_setup() refers to `args`.
    runtags = set()
    # Prepare the environment (optional wheel install, log levels, GE and
    # dump settings) before any test is launched.
    install_whl(args.whl)
    ms_log_setup(args.mslog)
    ascend_log_setup(args.ascendlog)
    ge_setup(args.ge, args.ge_graph)
    dump_setup(args.dump_net, args.dump_bin)
    run_tests(args, pytest_args)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/panzhihui1/autotest.git
git@gitee.com:panzhihui1/autotest.git
panzhihui1
autotest
autotest
master

搜索帮助