1 Star 3 Fork 8

kangchi/pytorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
aten.bzl 2.94 KB
一键复制 编辑 原始数据 按行查看 历史
load("@bazel_skylib//lib:paths.bzl", "paths")
load("@rules_cc//cc:defs.bzl", "cc_library")
# CPU capabilities each ATen CPU kernel is built for. intern_build_aten_ops
# emits one copy of every kernel source and one cc_library per name here.
CPU_CAPABILITY_NAMES = ["DEFAULT", "AVX2"]

# Extra compiler flags appended per capability (keys match the names above).
# "DEFAULT" adds no flags.
CAPABILITY_COMPILER_FLAGS = {
    "AVX2": [
        "-mavx2",
        "-mfma",
    ],
    "DEFAULT": [],
}

# Root of the ATen source tree, and the native/ kernel directory beneath it.
EXTRA_PREFIX = "aten/src/ATen/"
PREFIX = EXTRA_PREFIX + "native/"
def _copy_for_capability(impl, prefix, cpu_capability):
    """Declare a genrule copying `impl` to a capability-suffixed .cpp file.

    Args:
      impl: path of the source file to copy.
      prefix: directory prefix stripped from `impl` to form the target name
        and re-prepended to form the output path.
      cpu_capability: one of CPU_CAPABILITY_NAMES; appended to both the
        target name and the output file name so each capability gets its
        own copy.

    Returns:
      The path of the declared output file, to be used in `srcs`.
    """

    # Strip only a *leading* prefix; `str.replace` would also rewrite any later
    # occurrence of the prefix text inside the path.
    name = impl[len(prefix):] if impl.startswith(prefix) else impl
    out = prefix + name + "." + cpu_capability + ".cpp"
    native.genrule(
        name = name + "_" + cpu_capability + "_cp",
        srcs = [impl],
        outs = [out],
        cmd = "cp $< $@",
    )
    return out

def intern_build_aten_ops(copts, deps, extra_impls):
    """Declare per-CPU-capability ATen kernel libraries plus an aggregate one.

    For every name in CPU_CAPABILITY_NAMES:
      * each source matched by PREFIX + "cpu/*.cpp" and
        PREFIX + "quantized/cpu/kernels/*.cpp", and each file in
        `extra_impls`, is copied to "<path>.<CAPABILITY>.cpp";
      * those copies are compiled into cc_library "ATen_CPU_<CAPABILITY>".
        The library gets `copts`, the defines -DCPU_CAPABILITY=<CAPABILITY>
        and -DCPU_CAPABILITY_<CAPABILITY>, and the flags listed in
        CAPABILITY_COMPILER_FLAGS.
    Finally, cc_library "ATen_CPU" depends on all capability libraries.

    Args:
      copts: compiler options shared by every capability library.
      deps: dependencies of every capability library.
      extra_impls: additional source paths, expected under EXTRA_PREFIX, to
        compile once per capability.
    """

    # The glob result does not depend on the capability; evaluate it once.
    kernel_srcs = native.glob(
        [
            PREFIX + "cpu/*.cpp",
            PREFIX + "quantized/cpu/kernels/*.cpp",
        ],
    )

    for cpu_capability in CPU_CAPABILITY_NAMES:
        srcs = []
        for impl in kernel_srcs:
            srcs.append(_copy_for_capability(impl, PREFIX, cpu_capability))
        for impl in extra_impls:
            srcs.append(_copy_for_capability(impl, EXTRA_PREFIX, cpu_capability))

        cc_library(
            name = "ATen_CPU_" + cpu_capability,
            srcs = srcs,
            copts = copts + [
                "-DCPU_CAPABILITY=" + cpu_capability,
                "-DCPU_CAPABILITY_" + cpu_capability,
            ] + CAPABILITY_COMPILER_FLAGS[cpu_capability],
            deps = deps,
            linkstatic = 1,
        )

    cc_library(
        name = "ATen_CPU",
        deps = [":ATen_CPU_" + cpu_capability for cpu_capability in CPU_CAPABILITY_NAMES],
        linkstatic = 1,
    )
def generate_aten_impl(ctx):
    """Run the ATen code generator once and expose everything it produces.

    Declares "aten/src/ATen/ops" (relative to this package's output
    directory) as a directory output, plus every file listed in `outs`. It
    then runs `generator` with --source-path, --per-operator-headers and
    --install_dir arguments to produce them.

    Args:
      ctx: rule context providing `generator` (executable), `outs` (declared
        output files) and `srcs` (generator inputs).

    Returns:
      A list with a DefaultInfo whose files are the ops directory plus all
      files in `outs`.
    """

    # Declare the entire ATen/ops/ directory as an output. Presumably the
    # per-operator file set is not known at analysis time, hence a tree
    # artifact rather than individual files.
    ops_dir = ctx.actions.declare_directory("aten/src/ATen/ops")
    outputs = [ops_dir] + ctx.outputs.outs

    # The generator is handed the parent of the declared ops directory
    # (presumably it writes into <install_dir>/ops -- confirm).
    install_dir = paths.dirname(ops_dir.path)

    # Invoke the generator directly instead of through `run_shell`. When the
    # executable comes from `ctx.executable`, Bazel attaches its runfiles
    # automatically. That removes the deprecated `ctx.resolve_tools` call, and
    # it avoids the unquoted `$@` that word-split arguments containing spaces.
    ctx.actions.run(
        outputs = outputs,
        inputs = ctx.files.srcs,
        executable = ctx.executable.generator,
        arguments = [
            "--source-path",
            # Execroot-relative; assumes the ATen sources are in the main
            # repository (TODO confirm if used as an external repository).
            "aten/src/ATen",
            "--per-operator-headers",
            "--install_dir",
            install_dir,
        ],
        use_default_shell_env = True,
        mnemonic = "GenerateAten",
    )
    return [DefaultInfo(files = depset(outputs))]
# Rule backed by generate_aten_impl: runs `generator` over `srcs` and emits
# the aten/src/ATen/ops directory together with any files named in `outs`.
generate_aten = rule(
    implementation = generate_aten_impl,
    attrs = dict(
        # Code-generator executable, configured for the execution platform.
        generator = attr.label(
            mandatory = True,
            executable = True,
            allow_files = True,
            cfg = "exec",
        ),
        # Individually named generated files, declared alongside the ops dir.
        outs = attr.output_list(),
        # Files the generator action takes as inputs.
        srcs = attr.label_list(allow_files = True),
    ),
)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/kangchi/pytorch.git
git@gitee.com:kangchi/pytorch.git
kangchi
pytorch
pytorch
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385