From 0a3ae3edb8f4ab31cece2ec2cb1c3b27f5f51fa8 Mon Sep 17 00:00:00 2001
From: 李畅
Date: Wed, 4 Dec 2024 15:52:36 +0800
Subject: [PATCH 1/2] [bugfix] Fix accuracy anomaly when m4/m5 models are onboarded to the model repository
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
 .../llm_ptq/anti_outlier/anti_outlier.py      |  6 +++---
 .../llm_ptq/llm_ptq_tools/quant_tools.py      | 19 ++++++++++++++++---
 2 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/msmodelslim/msmodelslim/pytorch/llm_ptq/anti_outlier/anti_outlier.py b/msmodelslim/msmodelslim/pytorch/llm_ptq/anti_outlier/anti_outlier.py
index fa46e1314..e81d0289a 100644
--- a/msmodelslim/msmodelslim/pytorch/llm_ptq/anti_outlier/anti_outlier.py
+++ b/msmodelslim/msmodelslim/pytorch/llm_ptq/anti_outlier/anti_outlier.py
@@ -234,7 +234,7 @@ class AntiOutlier(object):
         self.org_model = model
 
         # Save the original weights before anti_outlier processing; keep them in host memory to avoid extra device-memory usage
-        if self.cfg.anti_method != "m6":
+        if self.cfg.anti_method not in ['m4', 'm5', 'm6']:
             states_dic = {}
             for key, value in self.org_model.state_dict().items():
                 states_dic[key] = copy.deepcopy(value).to('cpu')
@@ -265,10 +265,10 @@ class AntiOutlier(object):
             raise Exception("Please check your config, model and input!", e) from e
 
         # Save the original weights before anti_outlier processing as an attribute on the model
-        if self.cfg.anti_method != "m6":
+        if self.cfg.anti_method not in ['m4', 'm5', 'm6']:
             setattr(self.model, 'ori_state_dict', states_dic)
         else:
-            setattr(self.model, 'anti_method', 'm6')
+            setattr(self.model, 'anti_method', self.cfg.anti_method)
 
     def init_dag(self):
         dummy_input = input_to_cpu(self.calib_data[0][0])
diff --git a/msmodelslim/msmodelslim/pytorch/llm_ptq/llm_ptq_tools/quant_tools.py b/msmodelslim/msmodelslim/pytorch/llm_ptq/llm_ptq_tools/quant_tools.py
index a496e7585..7b0b1cb33 100644
--- a/msmodelslim/msmodelslim/pytorch/llm_ptq/llm_ptq_tools/quant_tools.py
+++ b/msmodelslim/msmodelslim/pytorch/llm_ptq/llm_ptq_tools/quant_tools.py
@@ -506,7 +506,16 @@ class Calibrator(object):
         if self.cfg.use_fa_quant:
             for attention_module_name in self.fa_module_param_dict:
                 self.set_fa_quant_safetensor(attention_module_name, safetensor_weight)
+
+        if not hasattr(self.model, 'ori_state_dict'):
+            keys_to_delete = [key for key in safetensor_weight.keys() if 'module.weight' in key]
+            for key in keys_to_delete:
+                del safetensor_weight[key]
+            keys_to_delete = [key for key in self.quant_model_json_description.quant_model_description.keys() if 'module.weight' in key]
+            for key in keys_to_delete:
+                del self.quant_model_json_description.quant_model_description[key]
+
         for key, item in safetensor_weight.items():
             safetensor_weight[key] = item.cpu().contiguous()
@@ -685,9 +694,13 @@ class Calibrator(object):
                anti_norm_bias = module.bias.cpu()
                anti_norm_name_weight = name + '.module.weight'
                anti_norm_name_bias = name + '.module.bias'
-               self.quant_param_dict[anti_norm_name_weight] = anti_norm_weight.clone().detach()
-               self.quant_param_dict[anti_norm_name_bias] = anti_norm_bias.clone().detach()
-               self.quantized_module_param_dict[name + '.weight'] = [anti_norm_name_weight, anti_norm_name_bias]
+               if not hasattr(self.model, 'ori_state_dict'):
+                   self.quant_param_dict[name + '.weight'] = anti_norm_weight.clone().detach()
+                   self.quantized_module_param_dict[anti_norm_name_weight] = [name + '.weight']
+               else:
+                   self.quant_param_dict[anti_norm_name_weight] = anti_norm_weight.clone().detach()
+                   self.quant_param_dict[anti_norm_name_bias] = anti_norm_bias.clone().detach()
+                   self.quantized_module_param_dict[name + '.weight'] = [anti_norm_name_weight, anti_norm_name_bias]
 
            # Handle Linear modules and their attached params such as scale and offset
            if isinstance(module, (LinearQuantizer, LinearSparseQuantizer, LowBitLinearQuantizer)):
--
Gitee

From 13d4caef4f681dcd900de27a7444e274c41351f8 Mon Sep 17 00:00:00 2001
From: 李畅
Date: Wed, 4 Dec 2024 16:01:04 +0800
Subject: [PATCH 2/2] Keep m4/m5 and m6 functionality independent of each other
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
 .../msmodelslim/pytorch/llm_ptq/llm_ptq_tools/quant_tools.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/msmodelslim/msmodelslim/pytorch/llm_ptq/llm_ptq_tools/quant_tools.py b/msmodelslim/msmodelslim/pytorch/llm_ptq/llm_ptq_tools/quant_tools.py
index 7b0b1cb33..d2cc8011b 100644
--- a/msmodelslim/msmodelslim/pytorch/llm_ptq/llm_ptq_tools/quant_tools.py
+++ b/msmodelslim/msmodelslim/pytorch/llm_ptq/llm_ptq_tools/quant_tools.py
@@ -507,7 +507,7 @@ class Calibrator(object):
             for attention_module_name in self.fa_module_param_dict:
                 self.set_fa_quant_safetensor(attention_module_name, safetensor_weight)
 
-        if not hasattr(self.model, 'ori_state_dict'):
+        if hasattr(self.model, "anti_method") and self.model.anti_method in ['m4', 'm5']:
             keys_to_delete = [key for key in safetensor_weight.keys() if 'module.weight' in key]
             for key in keys_to_delete:
                 del safetensor_weight[key]
@@ -694,7 +694,7 @@ class Calibrator(object):
                anti_norm_bias = module.bias.cpu()
                anti_norm_name_weight = name + '.module.weight'
                anti_norm_name_bias = name + '.module.bias'
-               if not hasattr(self.model, 'ori_state_dict'):
+               if hasattr(self.model, "anti_method") and self.model.anti_method in ['m4', 'm5']:
                    self.quant_param_dict[name + '.weight'] = anti_norm_weight.clone().detach()
                    self.quantized_module_param_dict[anti_norm_name_weight] = [name + '.weight']
                else:
--
Gitee
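
Taken together, the two patches change the hand-off between AntiOutlier and Calibrator: for anti_method 'm4'/'m5' (and 'm6') the original weights are no longer stashed on the model as ori_state_dict; instead an anti_method attribute is recorded, and when it is 'm4' or 'm5' the calibrator drops the fused '*.module.weight' entries from both the exported safetensor dict and the quant-description JSON, keeping only the plain '<name>.weight' parameter. The snippet below is a minimal, self-contained sketch of that pruning step, not the library code itself: the function name prune_module_weight_keys and the example layer names are made up for illustration, while the two dictionaries mirror the roles of safetensor_weight and quant_model_description in the diff.

from typing import Dict, List

import torch


def prune_module_weight_keys(safetensor_weight: Dict[str, torch.Tensor],
                             quant_model_description: Dict[str, str],
                             anti_method: str) -> None:
    """Sketch of the pruning the Calibrator does when anti_method is 'm4' or 'm5'."""
    if anti_method not in ('m4', 'm5'):
        return
    # Collect keys first, then delete, so the dict is not mutated while iterating.
    keys_to_delete: List[str] = [k for k in safetensor_weight if 'module.weight' in k]
    for key in keys_to_delete:
        del safetensor_weight[key]
    keys_to_delete = [k for k in quant_model_description if 'module.weight' in k]
    for key in keys_to_delete:
        del quant_model_description[key]


# Illustrative example: the fused norm entry is removed, the plain weight entry is kept.
weights = {
    'model.layers.0.input_layernorm.module.weight': torch.ones(4),
    'model.layers.0.input_layernorm.weight': torch.ones(4),
}
description = {
    'model.layers.0.input_layernorm.module.weight': 'FLOAT',
    'model.layers.0.input_layernorm.weight': 'FLOAT',
}
prune_module_weight_keys(weights, description, anti_method='m4')
print(sorted(weights))      # ['model.layers.0.input_layernorm.weight']
print(sorted(description))  # ['model.layers.0.input_layernorm.weight']

The two-pass collect-then-delete pattern matches the patch itself; deleting entries while iterating directly over the dict would raise a RuntimeError in Python 3.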