#
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import json
import tempfile
from typing import List, Tuple
import torch
import numpy as np
import onnx
from onnx import shape_inference, numpy_helper
import onnx_graphsurgeon as gs
from polygraphy.backend.onnx.loader import fold_constants
from modules import sd_hijack, sd_unet
from datastructures import ProfileSettings

class UNetModel(torch.nn.Module):
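    """Wraps an A1111 UNet for ONNX export and TensorRT engine building.

    Provides the input/output names, dynamic axes, random sample inputs, and
    (min, opt, max) shape profiles that the exporter needs.
    """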
def __init__(
self, unet, embedding_dim: int, text_minlen: int = 77, is_xl: bool = False
) -> None:
super().__init__()
self.unet = unet
self.is_xl = is_xl
self.text_minlen = text_minlen
self.embedding_dim = embedding_dim
        self.num_xl_classes = 2816  # adm_in_channels for SDXL's class/label embedding ("y")
self.emb_chn = 1280
self.in_channels = self.unet.in_channels
        # Dynamic ONNX axes: "2B" is the CFG-doubled batch (cond + uncond),
        # "H"/"W" are the latent height/width, and "77N" is the prompt length
        # in multiples of 77 CLIP tokens.
        self.dyn_axes = {
            "sample": {0: "2B", 2: "H", 3: "W"},
            "encoder_hidden_states": {0: "2B", 1: "77N"},
            "timesteps": {0: "2B"},
            "latent": {0: "2B", 2: "H", 3: "W"},
            "y": {0: "2B"},
        }

    def apply_torch_model(self):
        def disable_checkpoint(module):
            # Activation checkpointing interferes with tracing during ONNX
            # export, so disable it recursively on every submodule.
            if getattr(module, "use_checkpoint", False):
                module.use_checkpoint = False
            if getattr(module, "checkpoint", False):
                module.checkpoint = False

        self.unet.apply(disable_checkpoint)
        self.set_unet("None")

def set_unet(self, ckpt: str):
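        """Select the sd_unet option `ckpt` and reapply model optimizations."""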
# TODO test if using this with TRT works
sd_unet.apply_unet(ckpt)
sd_hijack.model_hijack.apply_optimizations(ckpt)
def get_input_names(self) -> List[str]:
names = ["sample", "timesteps", "encoder_hidden_states"]
if self.is_xl:
names.append("y")
return names
def get_output_names(self) -> List[str]:
return ["latent"]
def get_dynamic_axes(self) -> dict:
io_names = self.get_input_names() + self.get_output_names()
dyn_axes = {name: self.dyn_axes[name] for name in io_names}
return dyn_axes
def get_sample_input(
self,
batch_size: int,
latent_height: int,
latent_width: int,
text_len: int,
device: str = "cuda",
dtype: torch.dtype = torch.float32,
    ) -> Tuple[torch.Tensor, ...]:
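        """Build random tensors shaped like the export inputs; "y" is None unless is_xl."""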
return (
torch.randn(
batch_size,
self.in_channels,
latent_height,
latent_width,
dtype=dtype,
device=device,
),
torch.randn(batch_size, dtype=dtype, device=device),
torch.randn(
batch_size,
text_len,
self.embedding_dim,
dtype=dtype,
device=device,
),
torch.randn(batch_size, self.num_xl_classes, dtype=dtype, device=device)
if self.is_xl
else None,
)
def get_input_profile(self, profile: ProfileSettings) -> dict:
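        """Return (min, opt, max) shapes per input for a TensorRT optimization profile.

        Hypothetical example: a fixed 512x512, batch-1 (CFG-doubled) profile for an
        SD 1.x UNet would map "sample" to [(2, 4, 64, 64)] * 3.
        """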
min_batch, opt_batch, max_batch = profile.get_a1111_batch_dim()
(
min_latent_height,
latent_height,
max_latent_height,
min_latent_width,
latent_width,
max_latent_width,
) = profile.get_latent_dim()
shape_dict = {
"sample": [
(min_batch, self.unet.in_channels, min_latent_height, min_latent_width),
(opt_batch, self.unet.in_channels, latent_height, latent_width),
(max_batch, self.unet.in_channels, max_latent_height, max_latent_width),
],
"timesteps": [(min_batch,), (opt_batch,), (max_batch,)],
"encoder_hidden_states": [
(min_batch, profile.t_min, self.embedding_dim),
(opt_batch, profile.t_opt, self.embedding_dim),
(max_batch, profile.t_max, self.embedding_dim),
],
}
if self.is_xl:
shape_dict["y"] = [
(min_batch, self.num_xl_classes),
(opt_batch, self.num_xl_classes),
(max_batch, self.num_xl_classes),
]
return shape_dict
# Helper utility for weights map
    def export_weights_map(self, onnx_opt_path: str, weights_map_path: str):
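        """Match UNet state_dict weights to ONNX initializers via content hashes.

        Writes a JSON pair of dicts (torch name -> initializer name, and torch
        name -> (initializer shape, is_transposed)) so a built engine can later
        be refitted with new checkpoint weights without re-exporting.
        """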
        onnx_opt_dir = os.path.dirname(onnx_opt_path)  # base_dir for external initializer data
state_dict = self.unet.state_dict()
onnx_opt_model = onnx.load(onnx_opt_path)
# Create initializer data hashes
def init_hash_map(onnx_opt_model):
initializer_hash_mapping = {}
for initializer in onnx_opt_model.graph.initializer:
initializer_data = numpy_helper.to_array(
initializer, base_dir=onnx_opt_dir
).astype(np.float16)
initializer_hash = hash(initializer_data.data.tobytes())
initializer_hash_mapping[initializer.name] = (
initializer_hash,
initializer_data.shape,
)
return initializer_hash_mapping
initializer_hash_mapping = init_hash_map(onnx_opt_model)
weights_name_mapping = {}
weights_shape_mapping = {}
# set to keep track of initializers already added to the name_mapping dict
initializers_mapped = set()
for wt_name, wt in state_dict.items():
# get weight hash
wt = wt.cpu().detach().numpy().astype(np.float16)
wt_hash = hash(wt.data.tobytes())
wt_t_hash = hash(np.transpose(wt).data.tobytes())
for initializer_name, (
initializer_hash,
initializer_shape,
) in initializer_hash_mapping.items():
# Due to constant folding, some weights are transposed during export
# To account for the transpose op, we compare the initializer hash to the
# hash for the weight and its transpose
if wt_hash == initializer_hash or wt_t_hash == initializer_hash:
                    # The assert below enforces a 1:1 mapping between PyTorch
                    # and ONNX weight names; if a 1:many mapping is ever hit,
                    # drop it and store a list under weights_name_mapping[wt_name].
                    assert initializer_name not in initializers_mapped
weights_name_mapping[wt_name] = initializer_name
initializers_mapped.add(initializer_name)
                    is_transpose = wt_hash != initializer_hash
weights_shape_mapping[wt_name] = (
initializer_shape,
is_transpose,
)
            # Sanity check: report any PyTorch weight that was not matched
if wt_name not in weights_name_mapping:
print(
f"[I] PyTorch weight {wt_name} not matched with any ONNX initializer"
)
print(
f"[I] UNet: {len(weights_name_mapping.keys())} PyTorch weights were matched with ONNX initializers"
)
assert weights_name_mapping.keys() == weights_shape_mapping.keys()
with open(weights_map_path, "w") as fp:
json.dump([weights_name_mapping, weights_shape_mapping], fp)
@staticmethod
def optimize(name, onnx_graph, verbose=False):
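        """Run the standard export pass: cleanup, constant folding, and shape inference."""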
opt = Optimizer(onnx_graph, verbose=verbose)
opt.info(name + ": original")
opt.cleanup()
opt.info(name + ": cleanup")
opt.fold_constants()
opt.info(name + ": fold constants")
opt.infer_shapes()
opt.info(name + ": shape inference")
onnx_opt_graph = opt.cleanup(return_onnx=True)
opt.info(name + ": finished")
return onnx_opt_graph
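
# A minimal export sketch (illustrative only: the loaded `unet`, dims, and file
# names below are hypothetical, and the extension's real export path may differ):
#
#   model = UNetModel(unet, embedding_dim=768, is_xl=False)
#   model.apply_torch_model()
#   torch.onnx.export(
#       model.unet,
#       model.get_sample_input(1, 64, 64, 77),
#       "unet.onnx",
#       input_names=model.get_input_names(),
#       output_names=model.get_output_names(),
#       dynamic_axes=model.get_dynamic_axes(),
#   )
#   onnx.save(UNetModel.optimize("unet", onnx.load("unet.onnx")), "unet.opt.onnx")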

class Optimizer:
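    """Thin wrapper around onnx-graphsurgeon and Polygraphy for ONNX graph cleanup passes."""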
def __init__(self, onnx_graph, verbose=False):
self.graph = gs.import_onnx(onnx_graph)
self.verbose = verbose
def info(self, prefix):
if self.verbose:
print(
f"{prefix} .. {len(self.graph.nodes)} nodes, {len(self.graph.tensors().keys())} tensors, {len(self.graph.inputs)} inputs, {len(self.graph.outputs)} outputs"
)
def cleanup(self, return_onnx=False):
self.graph.cleanup().toposort()
if return_onnx:
return gs.export_onnx(self.graph)
def select_outputs(self, keep, names=None):
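        """Keep only the graph outputs at the given indices, optionally renaming them."""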
self.graph.outputs = [self.graph.outputs[o] for o in keep]
if names:
for i, name in enumerate(names):
self.graph.outputs[i].name = name
def fold_constants(self, return_onnx=False):
onnx_graph = fold_constants(
gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True
)
self.graph = gs.import_onnx(onnx_graph)
if return_onnx:
return onnx_graph
def infer_shapes(self, return_onnx=False):
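        """Run ONNX shape inference; graphs over the 2 GB protobuf limit are inferred on disk."""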
onnx_graph = gs.export_onnx(self.graph)
        # protobuf caps a single file at 2 GB (2147483648 bytes); larger models
        # must round-trip through external data and on-disk shape inference.
        if onnx_graph.ByteSize() > 2147483648:
            temp_dir = tempfile.mkdtemp()
onnx_orig_path = os.path.join(temp_dir, "model.onnx")
onnx_inferred_path = os.path.join(temp_dir, "inferred.onnx")
onnx.save_model(
onnx_graph,
onnx_orig_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
convert_attribute=False,
)
onnx.shape_inference.infer_shapes_path(onnx_orig_path, onnx_inferred_path)
onnx_graph = onnx.load(onnx_inferred_path)
else:
onnx_graph = shape_inference.infer_shapes(onnx_graph)
self.graph = gs.import_onnx(onnx_graph)
if return_onnx:
return onnx_graph

    def clip_add_hidden_states(self, return_onnx=False):
        onnx_graph = gs.export_onnx(self.graph)
        # Find the index of the deepest text-encoder layer from the node names.
        hidden_layers = -1
        for node in onnx_graph.graph.node:
            for name in node.output:
                if "layers" in name:
                    hidden_layers = max(int(name.split(".")[1].split("/")[0]), hidden_layers)
        # Expose the penultimate encoder layer's residual-add output under the
        # stable name "hidden_states", rewriting both producers and consumers.
        target = "/text_model/encoder/layers.{}/Add_1_output_0".format(hidden_layers - 1)
        for node in onnx_graph.graph.node:
            for j, output_name in enumerate(node.output):
                if output_name == target:
                    node.output[j] = "hidden_states"
            for j, input_name in enumerate(node.input):
                if input_name == target:
                    node.input[j] = "hidden_states"
        if return_onnx:
            return onnx_graph
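
# A minimal sketch of driving the Optimizer by hand (mirrors UNetModel.optimize;
# file names are hypothetical):
#
#   opt = Optimizer(onnx.load("text_encoder.onnx"), verbose=True)
#   opt.cleanup()
#   opt.fold_constants()
#   opt.infer_shapes()
#   onnx_graph = opt.clip_add_hidden_states(return_onnx=True)
#   onnx.save(onnx_graph, "text_encoder.opt.onnx")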