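# NOTE: web-page residue, not part of the module. Original message (translated):
# "Code pull complete; the page will refresh automatically."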
import os
import keras
import pytest
def pytest_addoption(parser):
    """Register the custom command-line flags used by this test suite.

    Four flags are added, in this order: `--run_large`,
    `--run_extra_large`, `--docstring_module` and `--check_gpu`.

    Args:
        parser: Object exposing `addoption`; every flag is registered on it
            with a single `addoption` call.
    """
    # Shared settings for the boolean switches: off unless passed explicitly.
    switch = {"action": "store_true", "default": False}
    option_specs = (
        ("--run_large", {**switch, "help": "run large tests"}),
        ("--run_extra_large", {**switch, "help": "run extra_large tests"}),
        (
            "--docstring_module",
            {
                # The only flag that carries a value; empty string by default.
                "action": "store",
                "default": "",
                "help": "restrict docs testing to modules whose name matches this flag",
            },
        ),
        ("--check_gpu", {**switch, "help": "fail if a gpu is not present"}),
    )
    for flag, settings in option_specs:
        parser.addoption(flag, **settings)
def pytest_configure(config):
    """Optionally require a GPU, then register this suite's custom markers.

    When `--check_gpu` is set, the active Keras backend is asked for GPU
    devices and the run is failed via `pytest.fail` if none are reported.
    The markers `large`, `extra_large`, `tf_only` and `kaggle_key_required`
    are registered afterwards.

    Args:
        config: Object exposing `getoption` and `addinivalue_line`.
    """
    if config.getoption("--check_gpu"):
        backend = keras.config.backend()
        if not _backend_sees_gpu(backend):
            pytest.fail(f"No GPUs discovered on the {backend} backend.")

    marker_descriptions = (
        "large: mark test as being slow or requiring a network",
        "extra_large: mark test as being too large to run continuously",
        "tf_only: mark test as a tf only test",
        "kaggle_key_required: mark test needing a kaggle key",
    )
    for description in marker_descriptions:
        config.addinivalue_line("markers", description)


def _backend_sees_gpu(backend):
    """Return True if the named backend reports at least one GPU device.

    Only "jax", "tensorflow" and "torch" are probed; any other backend name
    yields False. Framework imports are kept local so that only the active
    backend's package is imported.
    """
    if backend == "jax":
        import jax

        try:
            return bool(jax.devices("gpu"))
        except RuntimeError:
            # A RuntimeError from the device query is treated as "no GPU".
            return False
    if backend == "tensorflow":
        import tensorflow as tf

        return bool(tf.config.list_logical_devices("GPU"))
    if backend == "torch":
        import torch

        return bool(torch.cuda.device_count())
    return False
def pytest_collection_modifyitems(config, items):
    """Attach conditional skip markers to the collected test items.

    An item whose keywords contain one of the names below receives the
    matching `skipif` marker:

    - `large`: skipped unless `--run_large` or `--run_extra_large` is given.
    - `extra_large`: skipped unless `--run_extra_large` is given.
    - `tf_only`: skipped unless the Keras backend is "tensorflow".
    - `kaggle_key_required`: skipped unless both `KAGGLE_USERNAME` and
      `KAGGLE_KEY` are set to non-empty values in the environment.

    Args:
        config: Object exposing `getoption`.
        items: Iterable of collected test items; each is updated in place
            through `add_marker`.
    """
    want_extra_large = config.getoption("--run_extra_large")
    # Asking for extra-large tests also enables the large ones.
    want_large = config.getoption("--run_large") or want_extra_large

    on_tensorflow = keras.config.backend() == "tensorflow"
    has_kaggle_credentials = all(
        os.environ.get(variable, None)
        for variable in ("KAGGLE_USERNAME", "KAGGLE_KEY")
    )

    # (keyword, marker) pairs, applied to every item in this order.
    conditional_skips = (
        (
            "large",
            pytest.mark.skipif(
                not want_large,
                reason="need --run_large option to run",
            ),
        ),
        (
            "extra_large",
            pytest.mark.skipif(
                not want_extra_large,
                reason="need --run_extra_large option to run",
            ),
        ),
        (
            "tf_only",
            pytest.mark.skipif(
                not on_tensorflow,
                reason="tests only run on tf backend",
            ),
        ),
        (
            "kaggle_key_required",
            pytest.mark.skipif(
                not has_kaggle_credentials,
                reason="tests only run with a kaggle api key",
            ),
        ),
    )

    for item in items:
        for keyword, skip_marker in conditional_skips:
            if keyword in item.keywords:
                item.add_marker(skip_marker)
# Disable Keras traceback filtering for quicker debugging of test failures.
# This is a module-level statement, so it runs when this file is imported.
keras.config.disable_traceback_filtering()
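# NOTE: web-page residue, not part of the module. Original notice (translated):
# "This location may contain content unsuitable for display, so the page does
# not show it. You can review and modify it using the relevant editing
# functions. If you confirm the content does not involve inappropriate
# language / pure advertising / violence / vulgar or pornographic material /
# infringement / piracy / falsehood / worthless content, or content violating
# national laws and regulations, you may click submit to appeal and we will
# handle it as soon as possible."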