diff --git a/MindChemistry/applications/cdvae/README.md b/MindChemistry/applications/cdvae/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..19a526608e36fd6e8c0a4e1631743457419e8d26
--- /dev/null
+++ b/MindChemistry/applications/cdvae/README.md
@@ -0,0 +1,123 @@
+# Model Name
+
+> CDVAE
+
+## Introduction
+
+> The Crystal Diffusion Variational AutoEncoder (CDVAE) is a SOTA model for generating the periodic structures of materials; the paper was published at ICLR. It supports three tasks:
+
+1. Reconstruction: reconstruct a stable crystal structure from an input structure.
+2. Generation: randomly generate stable crystal structures.
+3. Optimization: generate crystal structures by optimizing a specific material property.
+
+## Datasets
+
+> Three datasets are provided:
+
+1. Perov_5 (Castelli et al., 2012): nearly 19,000 perovskite crystals that are structurally similar but differ in composition. Download: [Perov_5](https://figshare.com/articles/dataset/Perov5/22705189).
+2. Carbon_24 (Pickard, 2020): 10,000 crystals containing only carbon atoms, so they share one composition but differ in structure. Download: [Carbon_24](https://figshare.com/articles/dataset/Carbon24/22705192).
+3. MP_20 (Jain et al., 2013): 45,000 inorganic material structures, covering most experimentally known materials with no more than 20 atoms per unit cell. Download: [mp_20](https://figshare.com/articles/dataset/mp_20/25563693).
+
+The first two datasets can be placed directly under the ./data directory after download. For MP_20, run `python ./src/dataloader/mp_20_process.py --init_path ./data/mp_20.json --data_path ./data/mp_20` after download, where init_path is the location of the downloaded JSON dataset and data_path is the directory where the processed dataset is stored.
+
+## Requirements
+
+> 1. Install dependencies: `pip install -r requirements.txt`
+
+## Quick Start
+
+> Training command: `python train.py --dataset 'perov_5'`
+
+## Script Description
+
+### Directory Structure
+
+```txt
+└─cdvae
+  │  README.md                this file
+  │  train.py                 training entry script
+  │  evaluation.py            inference entry script
+  │  compute_metrics.py       metrics computation script
+  │  create_dataset.py        dataset creation
+  │  mp_20_process.py         preprocessing for the mp_20 dataset
+  │
+  └─src
+    │  evaluate_utils.py      inference result generation
+    │  metrics_utils.py       metrics computation utilities
+    │
+    └─dataloader              data processing
+    │  dataloader.py          feeds the dataset to the network
+    │  mp_20_process.py       preprocessing for the mp_20 dataset
+    └─conf                    parameter configuration
+    │  configs.yaml           network parameters
+    └─data                    dataset parameters
+```
+
+## Training Process
+
+### Training
+
+Direct training:
+
+```txt
+python train.py --dataset 'perov_5' (replace perov_5 with the name of another dataset if needed)
+python train.py --dataset 'perov_5' --device_id 7 (use device_id to select the NPU)
+python train.py --dataset 'perov_5' --num_samples_train 300 (the first training run on a dataset generates the dataset files automatically; num_samples_train/val/test control the dataset sizes, and -1 creates the full dataset)
+python train.py --dataset 'perov_5' --name_ckpt './loss/loss.ckpt' (specify where the checkpoint is saved)
+```
+
+Training log:
+
+```log
+INFO:Creating dataset......
+100%|████████████████████████████████████████████████████████████████████████| 10000/10000 [00:40<00:00, 246.31it/s]
+100%|████████████████████████████████████████████████████████████████████████████| 300/300 [00:01<00:00, 178.99it/s]
+100%|████████████████████████████████████████████████████████████████████████████| 300/300 [00:01<00:00, 172.23it/s]
+.
+.
+.
+INFO:Train Epoch: 0 [0] Loss: 112.65778, time_step: 94.30673432350159
+INFO:Train Epoch: 0 [10] Loss: 67.35183, time_step: 0.5614199638366699
+INFO:Train Epoch: 0 [20] Loss: 62.0342, time_step: 0.5528759956359863
+INFO:Train Epoch: 0 [30] Loss: 54.571114, time_step: 0.5621426105499268
+INFO:Train Epoch: 0 [40] Loss: 50.179485, time_step: 0.5850610733032227
+INFO:Train Epoch: 0 [50] Loss: 49.987503, time_step: 0.622471809387207
+INFO:Train Epoch: 0 [60] Loss: 40.30918, time_step: 0.5863668918609619
+INFO:Train Epoch: 0 [70] Loss: 43.979298, time_step: 0.5797436237335205
+INFO:Val Loss: 36.698123931884766
+INFO:Updata best acc: 36.698124
+.
+.
+.
+INFO:Finished Training
+```
+
+## Inference and Evaluation
+
+### Inference
+
+```txt
+1. Place the checkpoint file under the `/loss/` directory (the default load path).
+2. Run the inference script:
+   reconstruction task:
+    python evaluation.py --dataset perov_5 --tasks 'recon' (sets dataset to perov_5)
+   generation task:
+    python evaluation.py --dataset perov_5 --tasks 'gen'
+   optimization task (to use optimization, set predict_property to True in configs.yaml before training):
+    python evaluation.py --dataset perov_5 --tasks 'opt'
+   To specify the checkpoint path and device id: python evaluation.py --dataset perov_5 --tasks 'opt' --model_path './loss/loss.ckpt' --device_id 1
+```
+
+Inference results:
+
+```txt
+The output files are written to the `/eval_result/` directory.
+reconstruction writes eval_recon.npy and gt_recon.npy, containing the reconstructed crystal structures and the ground-truth crystal structures, respectively;
+generation writes eval_gen.npy, containing the randomly generated crystal structures;
+optimization writes eval_opt.npy, containing the crystal structures optimized for the target property.
+```
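+
+The snippet below is a minimal sketch (not part of the shipped scripts) showing how the saved `.npy` results can be inspected; the key names follow the result dictionaries written by `evaluation.py`:
+
+```python
+import numpy as np
+
+# each output file stores a pickled dict, hence allow_pickle=True and .item()
+result = np.load("./eval_result/eval_gen.npy", allow_pickle=True).item()
+for key in ["frac_coords", "num_atoms", "atom_types", "lengths", "angles"]:
+    print(key, type(result[key]), len(result[key]))
+```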
+
+### Metrics
+
+```txt
+Run python compute_metrics.py --eval_path './eval_result' --dataset 'perov_5' --tasks recon. The metrics are saved to eval_metrics.json under the directory given by --eval_path (currently the recon and gen tasks are supported).
+```
\ No newline at end of file
diff --git a/MindChemistry/applications/cdvae/compute_metrics.py b/MindChemistry/applications/cdvae/compute_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..78bfe29d192854a9b5297a4968f865b74d0884f2
--- /dev/null
+++ b/MindChemistry/applications/cdvae/compute_metrics.py
@@ -0,0 +1,323 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Compute metrics
+"""
+from collections import Counter
+import logging
+import argparse
+import os
+import json
+from pathlib import Path
+
+import numpy as np
+from tqdm import tqdm
+from p_tqdm import p_map
+from scipy.stats import wasserstein_distance
+from pymatgen.core.structure import Structure
+from pymatgen.core.composition import Composition
+from pymatgen.core.lattice import Lattice
+from pymatgen.analysis.structure_matcher import StructureMatcher
+from matminer.featurizers.site.fingerprint import CrystalNNFingerprint
+from matminer.featurizers.composition.composite import ElementProperty
+from mindchemistry.cell.cdvae.data_utils import StandardScaler
+from src.metrics_utils import (
+    smact_validity, structure_validity, get_fp_pdist,
+    get_crystals_list, compute_cov)
+
+CRYSTALNNFP = CrystalNNFingerprint.from_preset("ops")
+COMPFP = ElementProperty.from_preset("magpie")
+
+COV_CUTOFFS = {
+    "mp_20": {"struct": 0.4, "comp": 10.},
+    "carbon_24": {"struct": 0.2, "comp": 4.},
+    "perov_5": {"struct": 0.2, "comp": 4.},
+}
+# thresholds for the coverage metrics: only crystals whose struct distance and
+# comp distance are both smaller than the cutoffs are counted as covered.
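+# Illustrative sketch of how the cutoffs are applied (the actual logic lives
+# in src.metrics_utils.compute_cov; the names below are for exposition only):
+#
+#     covered = struct_dist < cutoffs["struct"] and comp_dist < cutoffs["comp"]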
+
+
+class Crystal():
+    """Crystal built from arrays, with validity and fingerprint information."""
+
+    def __init__(self, crys_array_dict):
+        self.frac_coords = crys_array_dict["frac_coords"]
+        self.atom_types = crys_array_dict["atom_types"]
+        self.lengths = crys_array_dict["lengths"]
+        self.angles = crys_array_dict["angles"]
+        self.dict = crys_array_dict
+
+        self.get_structure()
+        self.get_composition()
+        self.get_validity()
+        self.get_fingerprints()
+
+    def get_structure(self):
+        """build a pymatgen Structure, recording why construction failed if it did"""
+        if min(self.lengths.tolist()) < 0:
+            self.constructed = False
+            self.invalid_reason = "non_positive_lattice"
+        else:
+            try:
+                self.structure = Structure(
+                    lattice=Lattice.from_parameters(
+                        *(self.lengths.tolist() + self.angles.tolist())),
+                    species=self.atom_types, coords=self.frac_coords, coords_are_cartesian=False)
+                self.constructed = True
+            except (ValueError, AttributeError, TypeError):
+                self.constructed = False
+                self.invalid_reason = "construction_raises_exception"
+            # guard on self.constructed: self.structure does not exist if
+            # construction raised an exception above.
+            if self.constructed and self.structure.volume < 0.1:
+                self.constructed = False
+                self.invalid_reason = "unrealistically_small_lattice"
+
+    def get_composition(self):
+        elem_counter = Counter(self.atom_types)
+        composition = [(elem, elem_counter[elem])
+                       for elem in sorted(elem_counter.keys())]
+        elems, counts = list(zip(*composition))
+        counts = np.array(counts)
+        counts = counts / np.gcd.reduce(counts)
+        self.elems = elems
+        self.comps = tuple(counts.astype("int").tolist())
+
+    def get_validity(self):
+        self.comp_valid = smact_validity(self.elems, self.comps)
+        if self.constructed:
+            self.struct_valid = structure_validity(self.structure)
+        else:
+            self.struct_valid = False
+        self.valid = self.comp_valid and self.struct_valid
+
+    def get_fingerprints(self):
+        """compute the composition fingerprint and averaged per-site structure fingerprint"""
+        elem_counter = Counter(self.atom_types)
+        comp = Composition(elem_counter)
+        self.comp_fp = COMPFP.featurize(comp)
+        try:
+            site_fps = [CRYSTALNNFP.featurize(
+                self.structure, i) for i in range(len(self.structure))]
+        except (ValueError, AttributeError, TypeError):
+            # counts crystal as invalid if fingerprint cannot be constructed.
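+            # (CrystalNN featurization needs a constructed, chemically sensible
+            # structure; failures here therefore exclude the crystal from metrics.)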
+ self.valid = False + self.comp_fp = None + self.struct_fp = None + return + self.struct_fp = np.array(site_fps).mean(axis=0) + + +class RecEval(): + """reconstruction eval""" + + def __init__(self, pred_crys, gt_crys, stol=0.5, angle_tol=10, ltol=0.3): + assert len(pred_crys) == len(gt_crys) + self.matcher = StructureMatcher( + stol=stol, angle_tol=angle_tol, ltol=ltol) + self.preds = pred_crys + self.gts = gt_crys + + def get_match_rate_and_rms(self): + """get match rate and rms""" + def process_one(pred, gt, is_valid): + if not is_valid: + return None + try: + rms_dist = self.matcher.get_rms_dist( + pred.structure, gt.structure) + rms_dist = None if rms_dist is None else rms_dist[0] + return rms_dist + except (ValueError, AttributeError, TypeError): + return None + validity = [c.valid for c in self.preds] + + rms_dists = [] + for i in tqdm(range(len(self.preds))): + rms_dists.append(process_one( + self.preds[i], self.gts[i], validity[i])) + rms_dists = np.array(rms_dists) + match_rate = sum(x is not None for x in rms_dists) / len(self.preds) + mean_rms_dist = np.array( + [x for x in rms_dists if x is not None]).mean() + return {"match_rate": match_rate, + "rms_dist": mean_rms_dist} + + def get_metrics(self): + return self.get_match_rate_and_rms() + + +class GenEval(): + """Gen Eval""" + + def __init__(self, pred_crys, gt_crys, comp_scaler, n_samples=10, eval_model_name=None): + self.crys = pred_crys + self.gt_crys = gt_crys + self.n_samples = n_samples + self.eval_model_name = eval_model_name + self.comp_scaler = comp_scaler + + valid_crys = [c for c in pred_crys if c.valid] + if len(valid_crys) >= n_samples: + sampled_indices = np.random.choice( + len(valid_crys), n_samples, replace=False) + self.valid_samples = [valid_crys[i] for i in sampled_indices] + else: + raise Exception( + f"not enough valid crystals in the predicted set: {len(valid_crys)}/{n_samples}") + + def get_validity(self): + comp_valid = np.array([c.comp_valid for c in self.crys]).mean() + struct_valid = np.array([c.struct_valid for c in self.crys]).mean() + valid = np.array([c.valid for c in self.crys]).mean() + return {"comp_valid": comp_valid, + "struct_valid": struct_valid, + "valid": valid} + + def get_comp_diversity(self): + comp_fps = [c.comp_fp for c in self.valid_samples] + comp_fps = self.comp_scaler.transform(comp_fps) + comp_div = get_fp_pdist(comp_fps) + return {"comp_div": comp_div} + + def get_struct_diversity(self): + return {"struct_div": get_fp_pdist([c.struct_fp for c in self.valid_samples])} + + def get_density_wdist(self): + pred_densities = [c.structure.density for c in self.valid_samples] + gt_densities = [c.structure.density for c in self.gt_crys] + wdist_density = wasserstein_distance(pred_densities, gt_densities) + return {"wdist_density": wdist_density} + + def get_num_elem_wdist(self): + pred_nelems = [len(set(c.structure.species)) + for c in self.valid_samples] + gt_nelems = [len(set(c.structure.species)) for c in self.gt_crys] + wdist_num_elems = wasserstein_distance(pred_nelems, gt_nelems) + return {"wdist_num_elems": wdist_num_elems} + + def get_coverage(self): + cutoff_dict = COV_CUTOFFS[self.eval_model_name] + (cov_metrics_dict, _) = compute_cov( + self.crys, self.gt_crys, self.comp_scaler, + struc_cutoff=cutoff_dict["struct"], + comp_cutoff=cutoff_dict["comp"]) + return cov_metrics_dict + + def get_metrics(self): + metrics = {} + metrics.update(self.get_validity()) + metrics.update(self.get_comp_diversity()) + metrics.update(self.get_struct_diversity()) + 
metrics.update(self.get_density_wdist())
+        metrics.update(self.get_num_elem_wdist())
+        metrics.update(self.get_coverage())
+        return metrics
+
+
+def get_crystal_array_list(data, gt_data=None, ground_truth=False):
+    """assemble per-crystal array dicts from the saved evaluation output"""
+    crys_array_list = get_crystals_list(
+        np.concatenate(data["frac_coords"], axis=1).squeeze(0),
+        np.concatenate(data["atom_types"], axis=1).squeeze(0),
+        np.concatenate(data["lengths"], axis=1).squeeze(0),
+        np.concatenate(data["angles"], axis=1).squeeze(0),
+        np.concatenate(data["num_atoms"], axis=1).squeeze(0))
+
+    if ground_truth:
+        true_crystal_array_list = get_crystals_list(
+            np.concatenate(gt_data["frac_coords"], axis=0).squeeze(),
+            np.concatenate(gt_data["atom_types"], axis=0).squeeze(),
+            np.concatenate(gt_data["lengths"],
+                           axis=0).squeeze().reshape(-1, 3),
+            np.concatenate(gt_data["angles"], axis=0).squeeze().reshape(-1, 3),
+            np.concatenate(gt_data["num_atoms"], axis=0).squeeze())
+    else:
+        true_crystal_array_list = None
+
+    return crys_array_list, true_crystal_array_list
+
+
+def main(args):
+    all_metrics = {}
+    eval_model_name = args.dataset
+
+    if "recon" in args.tasks:
+        out_data = np.load(args.eval_path+"/eval_recon.npy",
+                           allow_pickle=True).item()
+        gt_data = np.load(args.eval_path+"/gt_recon.npy",
+                          allow_pickle=True).item()
+        crys_array_list, true_crystal_array_list = get_crystal_array_list(
+            out_data, gt_data, ground_truth=True)
+        pred_crys = p_map(Crystal, crys_array_list)
+        gt_crys = p_map(Crystal, true_crystal_array_list)
+
+        rec_evaluator = RecEval(pred_crys, gt_crys)
+        recon_metrics = rec_evaluator.get_metrics()
+        all_metrics.update(recon_metrics)
+
+    if "gen" in args.tasks:
+        out_data = np.load(args.eval_path+"/eval_gen.npy",
+                           allow_pickle=True).item()
+        gt_data = np.load(args.eval_path+"/gt_recon.npy",
+                          allow_pickle=True).item()
+        crys_array_list, true_crystal_array_list = get_crystal_array_list(
+            out_data, gt_data, ground_truth=True)
+
+        gen_crys = p_map(Crystal, crys_array_list)
+        gt_crys = p_map(Crystal, true_crystal_array_list)
+        gt_comp_fps = [c.comp_fp for c in gt_crys]
+        gt_fp_np = np.array(gt_comp_fps)
+        comp_scaler = StandardScaler(replace_nan_token=0.)
+        comp_scaler.fit(gt_fp_np)
+
+        gen_evaluator = GenEval(
+            gen_crys, gt_crys, comp_scaler, eval_model_name=eval_model_name)
+        gen_metrics = gen_evaluator.get_metrics()
+        all_metrics.update(gen_metrics)
+
+    logging.info(all_metrics)
+
+    if args.label == "":
+        metrics_out_file = "eval_metrics.json"
+    else:
+        metrics_out_file = f"eval_metrics_{args.label}.json"
+    metrics_out_file = os.path.join(args.eval_path, metrics_out_file)
+
+    # only overwrite metrics computed in the new run.
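+    # (if the existing file holds a JSON dict, the new metrics are merged into
+    # it; any other content is simply replaced.)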
+ if Path(metrics_out_file).exists(): + with open(metrics_out_file, "r") as f: + written_metrics = json.load(f) + if isinstance(written_metrics, dict): + written_metrics.update(all_metrics) + else: + with open(metrics_out_file, "w") as f: + json.dump(all_metrics, f) + if isinstance(written_metrics, dict): + with open(metrics_out_file, "w") as f: + json.dump(written_metrics, f) + else: + with open(metrics_out_file, "w") as f: + json.dump(all_metrics, f) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dataset", default="perov_5") + parser.add_argument("--eval_path", default="./eval_result") + parser.add_argument("--label", default="") + parser.add_argument("--tasks", nargs="+", default=["recon"]) + main_args = parser.parse_args() + logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO) + main(main_args) diff --git a/MindChemistry/applications/cdvae/conf/configs.yaml b/MindChemistry/applications/cdvae/conf/configs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..114682a534527f41b3ae4a9f5d5b77b61534b7cc --- /dev/null +++ b/MindChemistry/applications/cdvae/conf/configs.yaml @@ -0,0 +1,65 @@ +hidden_dim: 256 +latent_dim: 256 +fc_num_layers: 1 +max_atoms: 20 +cost_natom: 1. +cost_coord: 10. +cost_type: 1. +cost_lattice: 10. +cost_composition: 1. +cost_edge: 10. +cost_property: 1. +beta: 0.01 +max_neighbors: 20 +radius: 7. +sigma_begin: 10. +sigma_end: 0.01 +type_sigma_begin: 5. +type_sigma_end: 0.01 +num_noise_level: 50 +teacher_forcing_lattice: True +predict_property: True + +Encoder: + hidden_channels: 128 + num_blocks: 4 + int_emb_size: 64 + basis_emb_size: 8 + out_emb_channels: 256 + num_spherical: 7 + num_radial: 6 + cutoff: 7.0 + max_num_neighbors: 20 + envelope_exponent: 5 + num_before_skip: 1 + num_after_skip: 2 + num_output_layers: 3 + +Decoder: + hidden_dim: 128 + +Optimizer: + learning_rate: 0.001 + factor: 0.6 + patience: 30 + cooldown: 10 + min_lr: 0.0001 + +Scaler: + TripInteraction_1_had_rbf: 18.873615264892578 + TripInteraction_1_sum_cbf: 7.996850490570068 + AtomUpdate_1_sum: 1.220463752746582 + TripInteraction_2_had_rbf: 16.10817527770996 + TripInteraction_2_sum_cbf: 7.614634037017822 + AtomUpdate_2_sum: 0.9690994620323181 + TripInteraction_3_had_rbf: 15.01930046081543 + TripInteraction_3_sum_cbf: 7.025179862976074 + AtomUpdate_3_sum: 0.8903237581253052 + OutBlock_0_sum: 1.6437848806381226 + OutBlock_0_had: 16.161039352416992 + OutBlock_1_sum: 1.1077653169631958 + OutBlock_1_had: 13.54678726196289 + OutBlock_2_sum: 0.9477927684783936 + OutBlock_2_had: 12.754337310791016 + OutBlock_3_sum: 0.9059251546859741 + OutBlock_3_had: 13.484951972961426 diff --git a/MindChemistry/applications/cdvae/conf/data/carbon_24.yaml b/MindChemistry/applications/cdvae/conf/data/carbon_24.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8a7c7093b586db54bbee2afc09bb05da2b5e31fa --- /dev/null +++ b/MindChemistry/applications/cdvae/conf/data/carbon_24.yaml @@ -0,0 +1,12 @@ +prop: energy_per_atom +num_targets: 1 +niggli: true +primitive: false +graph_method: crystalnn +lattice_scale_method: scale_length +preprocess_workers: 30 +readout: mean +max_atoms: 24 +otf_graph: false +eval_model_name: carbon +batch_size: 50 diff --git a/MindChemistry/applications/cdvae/conf/data/mp_20.yaml b/MindChemistry/applications/cdvae/conf/data/mp_20.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f43fb448daecec40d3ce5b1f1237add9ecd5c743 --- /dev/null +++ 
b/MindChemistry/applications/cdvae/conf/data/mp_20.yaml @@ -0,0 +1,12 @@ +prop: formation_energy_per_atom +num_targets: 1 +niggli: true +primitive: False +graph_method: crystalnn +lattice_scale_method: scale_length +preprocess_workers: 30 +readout: mean +max_atoms: 20 +otf_graph: false +eval_model_name: mp20 +batch_size: 50 diff --git a/MindChemistry/applications/cdvae/conf/data/perov_5.yaml b/MindChemistry/applications/cdvae/conf/data/perov_5.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f25a93abd529484492d46d02156591f150b5d656 --- /dev/null +++ b/MindChemistry/applications/cdvae/conf/data/perov_5.yaml @@ -0,0 +1,12 @@ +prop: heat_ref +num_targets: 1 +niggli: true +primitive: false +graph_method: crystalnn +lattice_scale_method: scale_length +preprocess_workers: 24 +readout: mean +max_atoms: 20 +otf_graph: false +eval_model_name: perovskite +batch_size: 128 diff --git a/MindChemistry/applications/cdvae/create_dataset.py b/MindChemistry/applications/cdvae/create_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..67f9c190affe1bb1c33072bdd36011b7a17f76c2 --- /dev/null +++ b/MindChemistry/applications/cdvae/create_dataset.py @@ -0,0 +1,332 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""create_dataset""" + +import os +import logging +import argparse +import numpy as np +import pandas as pd +from p_tqdm import p_umap +from pymatgen.core.structure import Structure +from pymatgen.core.lattice import Lattice +from pymatgen.analysis.graphs import StructureGraph +from pymatgen.analysis import local_env + +from mindchemistry.utils.load_config import load_yaml_config_from_path +from mindchemistry.cell.gemnet.data_utils import get_scaler_from_data_list +from mindchemistry.cell.gemnet.data_utils import lattice_params_to_matrix +from mindchemistry.cell.dimenet.preprocess import PreProcess +logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO) + + +class CreateDataset: + """Create Dataset for crystal structures + + Args: + name (str): Name of the dataset + path (str): Path to the dataset + prop (str): Property to predict + niggli (bool): Whether to convert to Niggli reduced cell + primitive (bool): Whether to convert to primitive cell + graph_method (str): Method to create graph + preprocess_workers (int): Number of workers for preprocessing + lattice_scale_method (str): Method to scale lattice + num_samples (int): Number of samples to use, if None use all + """ + + def __init__(self, name, path, + prop, niggli, primitive, + graph_method, preprocess_workers, + lattice_scale_method, num_samples=None): + super().__init__() + self.path = path + self.name = name + self.num_samples = num_samples + self.prop = prop + self.niggli = niggli + self.primitive = primitive + self.graph_method = graph_method + self.lattice_scale_method = lattice_scale_method + self.preprocess = PreProcess( + num_spherical=7, num_radial=6, envelope_exponent=5, + otf_graph=False, cutoff=7.0, max_num_neighbors=20,) + + self.cached_data = data_preprocess( + self.path, + preprocess_workers, + niggli=self.niggli, + primitive=self.primitive, + graph_method=self.graph_method, + prop_list=[prop], + num_samples=self.num_samples + )[:self.num_samples] + add_scaled_lattice_prop(self.cached_data, lattice_scale_method) + self.lattice_scaler = None + self.scaler = None + + def __len__(self) -> int: + return len(self.cached_data) + + def __getitem__(self, index): + data = self.cached_data[index] + + # scaler is set in DataModule set stage + prop = self.scaler.transform(data[self.prop]) + (frac_coords, atom_types, lengths, angles, edge_indices, + to_jimages, num_atoms) = data["graph_arrays"] + data_res = self.preprocess.data_process(angles.reshape(1, -1), lengths.reshape(1, -1), + np.array( + [num_atoms]), edge_indices.T, frac_coords, + edge_indices.shape[0], to_jimages, atom_types, prop) + return data_res + + def __repr__(self) -> str: + return f"CrystDataset({self.name}, {self.path})" + + def get_dataset_size(self): + return len(self.cached_data) + + +# match element with its chemical symbols +chemical_symbols = [ + # 0 + "X", + # 1 + "H", "He", + # 2 + "Li", "Be", "B", "C", "N", "O", "F", "Ne", + # 3 + "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", + # 4 + "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", + "Ga", "Ge", "As", "Se", "Br", "Kr", + # 5 + "Rb", "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", + "In", "Sn", "Sb", "Te", "I", "Xe", + # 6 + "Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", + "Ho", "Er", "Tm", "Yb", "Lu", + "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", + "Po", "At", "Rn", + # 7 + "Fr", "Ra", "Ac", "Th", "Pa", "U", "Np", 
"Pu", "Am", "Cm", "Bk", + "Cf", "Es", "Fm", "Md", "No", "Lr", + "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", "Rg", "Cn", "Nh", "Fl", "Mc", + "Lv", "Ts", "Og" +] + +# used for crystal matching +CRYSTALNN = local_env.CrystalNN( + distance_cutoffs=None, x_diff_weight=-1, porous_adjustment=False) + + +def build_crystal(crystal_str, niggli=True, primitive=False): + """Build crystal from cif string.""" + crystal = Structure.from_str(crystal_str, fmt="cif") + + if primitive: + crystal = crystal.get_primitive_structure() + + if niggli: + crystal = crystal.get_reduced_structure() + + canonical_crystal = Structure( + lattice=Lattice.from_parameters(*crystal.lattice.parameters), + species=crystal.species, + coords=crystal.frac_coords, + coords_are_cartesian=False, + ) + # match is gaurantteed because cif only uses lattice params & frac_coords + assert canonical_crystal.matches(crystal) + return canonical_crystal + + +def build_crystal_graph(crystal, graph_method="crystalnn"): + """build crystal graph""" + + if graph_method == "crystalnn": + crystal_graph = StructureGraph.with_local_env_strategy( + crystal, CRYSTALNN) + elif graph_method == "none": + pass + else: + raise NotImplementedError + + frac_coords = crystal.frac_coords + atom_types = crystal.atomic_numbers + lattice_parameters = crystal.lattice.parameters + lengths = lattice_parameters[:3] + angles = lattice_parameters[3:] + + assert np.allclose(crystal.lattice.matrix, + lattice_params_to_matrix(*lengths, *angles)) + + edge_indices, to_jimages = [], [] + if graph_method != "none": + for i, j, to_jimage in crystal_graph.graph.edges(data="to_jimage"): + edge_indices.append([j, i]) + to_jimages.append(to_jimage) + edge_indices.append([i, j]) + to_jimages.append(tuple(-tj for tj in to_jimage)) + + atom_types = np.array(atom_types) + lengths, angles = np.array(lengths), np.array(angles) + edge_indices = np.array(edge_indices) + to_jimages = np.array(to_jimages) + num_atoms = atom_types.shape[0] + + return frac_coords, atom_types, lengths, angles, edge_indices, to_jimages, num_atoms + + +def save_data(dataset, is_train, dataset_name): + """save created dataset to npy""" + processed_data = dict() + data_parameters = ["atom_types", "dist", "angle", "idx_kj", "idx_ji", + "edge_j", "edge_i", "pos", "batch", "lengths", + "num_atoms", "angles", "frac_coords", + "num_bonds", "num_triplets", "sbf", "y"] + for j, name in enumerate(data_parameters): + if j == 16: + processed_data[name] = [i[j].astype(np.float32) for i in dataset] + elif j == 14: + processed_data[name] = [i[j].sum() for i in dataset] + else: + processed_data[name] = [i[j] for i in dataset] + + if not os.path.exists(f"./data/{dataset_name}/{is_train}"): + os.makedirs(f"./data/{dataset_name}/{is_train}") + logging.info("%s has been created", + f"./data/{dataset_name}/{is_train}") + if is_train == "train": + np.savetxt(f"./data/{dataset_name}/{is_train}/scaler_mean.csv", + dataset.scaler.means.reshape(-1)) + np.savetxt(f"./data/{dataset_name}/{is_train}/scaler_std.csv", + dataset.scaler.stds.reshape(-1)) + np.savetxt( + f"./data/{dataset_name}/{is_train}/lattice_scaler_mean.csv", dataset.lattice_scaler.means) + np.savetxt( + f"./data/{dataset_name}/{is_train}/lattice_scaler_std.csv", dataset.lattice_scaler.stds) + np.save( + f"./data/{dataset_name}/{is_train}/processed_data.npy", processed_data) + + +def process_one(row, niggli, primitive, graph_method, prop_list): + """process one one sample""" + crystal_str = row["cif"] + crystal = build_crystal( + crystal_str, niggli=niggli, primitive=primitive) 
+ graph_arrays = build_crystal_graph(crystal, graph_method) + properties = {k: row[k] for k in prop_list if k in row.keys()} + result_dict = { + "mp_id": row["material_id"], + "cif": crystal_str, + "graph_arrays": graph_arrays, + } + result_dict.update(properties) + return result_dict + + +def data_preprocess(input_file, num_workers, niggli, primitive, graph_method, prop_list, num_samples): + """process data""" + df = pd.read_csv(input_file)[:num_samples] + + unordered_results = p_umap( + process_one, + [df.iloc[idx] for idx in range(len(df))], + [niggli] * len(df), + [primitive] * len(df), + [graph_method] * len(df), + [prop_list] * len(df), + num_cpus=num_workers) + + mpid_to_results = {result["mp_id"]: result for result in unordered_results} + ordered_results = [mpid_to_results[df.iloc[idx]["material_id"]] + for idx in range(len(df))] + + return ordered_results + + +def add_scaled_lattice_prop(data_list, lattice_scale_method): + """add scaled lattice prop to dataset""" + for data in data_list: + graph_arrays = data["graph_arrays"] + # the indexes are brittle if more objects are returned + lengths = graph_arrays[2] + angles = graph_arrays[3] + num_atoms = graph_arrays[-1] + assert lengths.shape[0] == angles.shape[0] == 3 + assert isinstance(num_atoms, int) + + if lattice_scale_method == "scale_length": + lengths = lengths / float(num_atoms)**(1 / 3) + + data["scaled_lattice"] = np.concatenate([lengths, angles]) + + +def create_dataset(args): + """create dataset""" + config_data_path = f"./conf/data/{args.dataset}.yaml" + config_data = load_yaml_config_from_path(config_data_path) + prop = config_data.get("prop") + niggli = config_data.get("niggli") + primitive = config_data.get("primitive") + graph_method = config_data.get("graph_method") + lattice_scale_method = config_data.get("lattice_scale_method") + preprocess_workers = config_data.get("preprocess_workers") + path_train = f"./data/{args.dataset}/train.csv" + train_dataset = CreateDataset("Formation energy train", path_train, prop, + niggli, primitive, graph_method, + preprocess_workers, lattice_scale_method, args.num_samples_train) + lattice_scaler = get_scaler_from_data_list( + train_dataset.cached_data, + key="scaled_lattice") + scaler = get_scaler_from_data_list( + train_dataset.cached_data, + key=train_dataset.prop) + train_dataset.lattice_scaler = lattice_scaler + train_dataset.scaler = scaler + save_data(train_dataset, "train", args.dataset) + + path_val = f"./data/{args.dataset}/val.csv" + val_dataset = CreateDataset("Formation energy val", path_val, prop, + niggli, primitive, graph_method, + preprocess_workers, lattice_scale_method, args.num_samples_val) + val_dataset.lattice_scaler = lattice_scaler + val_dataset.scaler = scaler + save_data(val_dataset, "val", args.dataset) + + path_test = f"./data/{args.dataset}/test.csv" + test_dataset = CreateDataset("Formation energy test", path_test, prop, + niggli, primitive, graph_method, + preprocess_workers, lattice_scale_method, + args.num_samples_test) + test_dataset.lattice_scaler = lattice_scaler + test_dataset.scaler = scaler + save_data(test_dataset, "test", args.dataset) + + +def main(args): + create_dataset(args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dataset", default="perov_5") + parser.add_argument("--num_samples_train", default=300, type=int) + parser.add_argument("--num_samples_val", default=300, type=int) + parser.add_argument("--num_samples_test", default=300, type=int) + main_args = parser.parse_args() + 
main(main_args)
diff --git a/MindChemistry/applications/cdvae/evaluation.py b/MindChemistry/applications/cdvae/evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..84efbe47b8901e9b4a065deffd2e637f37bb4a1f
--- /dev/null
+++ b/MindChemistry/applications/cdvae/evaluation.py
@@ -0,0 +1,192 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Evaluation
+"""
+
+import os
+import time
+import logging
+from types import SimpleNamespace
+import argparse
+import mindspore as ms
+import numpy as np
+
+from mindchemistry.cell.cdvae import CDVAE
+from mindchemistry.cell.gemnet.data_utils import StandardScaler
+from src.dataloader.dataloader import DataLoaderBaseCDVAE
+from src.evaluate_utils import reconstruction, generation, optimization
+
+
+def task_reconstruction(model, ld_kwargs, graph_dataset, recon_args):
+    """Evaluate model on the reconstruction task."""
+    logging.info("Evaluate model on the reconstruction task.")
+    (frac_coords, num_atoms, atom_types, lengths, angles,
+     gt_frac_coords, gt_num_atoms, gt_atom_types,
+     gt_lengths, gt_angles) = reconstruction(
+        graph_dataset, model, ld_kwargs, recon_args.num_evals,
+        recon_args.force_num_atoms, recon_args.force_atom_types)
+
+    if recon_args.label == "":
+        recon_out_name = "eval_recon.npy"
+    else:
+        recon_out_name = f"eval_recon_{recon_args.label}.npy"
+
+    result = {
+        "eval_setting": recon_args,
+        "frac_coords": frac_coords,
+        "num_atoms": num_atoms,
+        "atom_types": atom_types,
+        "lengths": lengths,
+        "angles": angles,
+    }
+    # save result as numpy
+    np.save("./eval_result/" + recon_out_name, result)
+    groundtruth = {
+        "frac_coords": gt_frac_coords,
+        "num_atoms": gt_num_atoms,
+        "atom_types": gt_atom_types,
+        "lengths": gt_lengths,
+        "angles": gt_angles,
+    }
+    # save ground truth as numpy
+    np.save("./eval_result/gt_recon.npy", groundtruth)
+
+
+def task_generation(model, ld_kwargs, gen_args):
+    """Evaluate model on the generation task."""
+    logging.info("Evaluate model on the generation task.")
+
+    (frac_coords, num_atoms, atom_types, lengths, angles,
+     all_frac_coords_stack, all_atom_types_stack) = generation(
+        model, ld_kwargs, gen_args.num_batches_to_samples, gen_args.num_evals,
+        gen_args.batch_size, gen_args.down_sample_traj_step)
+
+    if gen_args.label == "":
+        gen_out_name = "eval_gen.npy"
+    else:
+        gen_out_name = f"eval_gen_{gen_args.label}.npy"
+
+    result = {
+        "eval_setting": gen_args,
+        "frac_coords": frac_coords,
+        "num_atoms": num_atoms,
+        "atom_types": atom_types,
+        "lengths": lengths,
+        "angles": angles,
+        "all_frac_coords_stack": all_frac_coords_stack,
+        "all_atom_types_stack": all_atom_types_stack,
+    }
+    # save result as numpy
+    np.save("./eval_result/" + gen_out_name, result)
+
+
+def task_optimization(model, ld_kwargs, graph_dataset, opt_args):
+    """Evaluate model on the property optimization task."""
+    logging.info("Evaluate model on the property optimization task.")
+    if opt_args.start_from == "data":
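+        # start optimization from latent codes encoded from the test data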
+        loader = graph_dataset
+    else:
+        loader = None
+    optimized_crystals = optimization(model, ld_kwargs, loader)
+    if opt_args.label == "":
+        gen_out_name = "eval_opt.npy"
+    else:
+        gen_out_name = f"eval_opt_{opt_args.label}.npy"
+    # save result as numpy
+    np.save("./eval_result/" + gen_out_name, optimized_crystals)
+
+
+def main(args):
+    # check whether the paths exist; create the directories if not
+    folder_path = os.path.dirname(args.model_path)
+    if not os.path.exists(folder_path):
+        os.makedirs(folder_path)
+        logging.info("%s has been created", folder_path)
+    result_path = "./eval_result/"
+    if not os.path.exists(result_path):
+        os.makedirs(result_path)
+        logging.info("%s has been created", result_path)
+    config_path = "./conf/configs.yaml"
+    data_config_path = f"./conf/data/{args.dataset}.yaml"
+    # load model
+    model = CDVAE(config_path, data_config_path)
+    # load mindspore checkpoint
+    param_dict = ms.load_checkpoint(args.model_path)
+    param_not_load, _ = ms.load_param_into_net(model, param_dict)
+    logging.info("parameters not loaded: %s.", param_not_load)
+    model.set_train(False)
+
+    ld_kwargs = SimpleNamespace(n_step_each=args.n_step_each,
+                                step_lr=args.step_lr,
+                                min_sigma=args.min_sigma,
+                                save_traj=args.save_traj,
+                                disable_bar=args.disable_bar)
+    # load dataset
+    processed_data = np.load(
+        f"./data/{args.dataset}/test/processed_data.npy", allow_pickle=True).item()
+    graph_dataset = DataLoaderBaseCDVAE(
+        args.batch_size, processed_data, shuffle_dataset=False)
+    # load scaler
+    lattice_scaler_mean = ms.Tensor(np.loadtxt(
+        f"./data/{args.dataset}/train/lattice_scaler_mean.csv"), ms.float32)
+    lattice_scaler_std = ms.Tensor(np.loadtxt(
+        f"./data/{args.dataset}/train/lattice_scaler_std.csv"), ms.float32)
+    scaler_std = ms.Tensor(np.loadtxt(
+        f"./data/{args.dataset}/train/scaler_std.csv"), ms.float32)
+    scaler_mean = ms.Tensor(np.loadtxt(
+        f"./data/{args.dataset}/train/scaler_mean.csv"), ms.float32)
+    lattice_scaler = StandardScaler(
+        lattice_scaler_mean, lattice_scaler_std).to_mindspore()
+    scaler = StandardScaler(scaler_mean, scaler_std).to_mindspore()
+    model.lattice_scaler = lattice_scaler
+    model.scaler = scaler
+
+    start_time_eval = time.time()
+    if "recon" in args.tasks:
+        task_reconstruction(model, ld_kwargs, graph_dataset, args)
+    if "gen" in args.tasks:
+        task_generation(model, ld_kwargs, args)
+    if "opt" in args.tasks:
+        task_optimization(model, ld_kwargs, graph_dataset, args)
+    logging.info("end evaluation, time: %f s.", time.time() - start_time_eval)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--device_target", default="Ascend")
+    parser.add_argument("--device_id", default=7, type=int)
+    parser.add_argument("--model_path", default="./loss/loss.ckpt")
+    parser.add_argument("--dataset", default="perov_5")
+    parser.add_argument("--tasks", nargs="+", default=["gen"])
+    parser.add_argument("--n_step_each", default=1, type=int)
+    parser.add_argument("--step_lr", default=1e-3, type=float)
+    parser.add_argument("--min_sigma", default=0, type=float)
+    # store_true avoids the argparse type=bool pitfall (any non-empty string,
+    # including "False", would otherwise parse as True)
+    parser.add_argument("--save_traj", action="store_true")
+    parser.add_argument("--disable_bar", action="store_true")
+    parser.add_argument("--num_evals", default=1, type=int)
+    parser.add_argument("--num_batches_to_samples", default=1, type=int)
+    parser.add_argument("--start_from", default="data", type=str)
+    parser.add_argument("--batch_size", default=128, type=int)
+    parser.add_argument("--force_num_atoms", action="store_true")
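+    # the force_* flags reuse ground-truth atom counts/types during reconstruction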
parser.add_argument("--force_atom_types", action="store_true") + parser.add_argument("--down_sample_traj_step", default=10, type=int) + parser.add_argument("--label", default="") + + main_args = parser.parse_args() + logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO) + ms.context.set_context(device_target=main_args.device_target) + ms.context.set_context(device_id=main_args.device_id) + ms.context.set_context(mode=1) + main(main_args) diff --git a/MindChemistry/applications/cdvae/mp_20_process.py b/MindChemistry/applications/cdvae/mp_20_process.py new file mode 100644 index 0000000000000000000000000000000000000000..a4a4a0c0cce93b2b8572a1e0feff6c1d34b799ad --- /dev/null +++ b/MindChemistry/applications/cdvae/mp_20_process.py @@ -0,0 +1,67 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" script used for generate mp_20 dataset from raw data""" +import os +import logging +import argparse +import pandas as pd +from pymatgen.core.structure import Structure +from pymatgen.core.lattice import Lattice +from pymatgen.io.cif import CifWriter + + +logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO) +parser = argparse.ArgumentParser() +parser.add_argument("--init_path", default="./data/mp_20.json") +parser.add_argument("--data_path", default="./data/mp_20") +args = parser.parse_args() + +# read json file and transfer to pandasframe +if not os.path.exists(args.data_path): + os.makedirs(args.data_path) + logging.info("%s has been created", args.data_path) + +df = pd.read_json(args.init_path) +df = df[["id", "formation_energy_per_atom", "band_gap", "pretty_formula", + "e_above_hull", "elements", "atoms", "spacegroup_number"]] +struct_list = [] +element_list = [] +# generate Structure from its df["atoms"] for each samples +for struct in df["atoms"]: + lattice = Lattice(struct["lattice_mat"], (False, False, False)) + pos = struct["coords"] + species = struct["elements"] + structure = Structure(lattice, species, pos) + # save cif from Structure + cif = CifWriter(structure) + struct_list.append(cif.__str__()) + element_list.append(struct["elements"]) + +# add cif to df +df.insert(7, "cif", struct_list) +df = df.drop("atoms", axis=1) +df["elements"] = element_list + + +# save to csv file +tot_len = len(df) +# solit the dataset to train:val:test = 6:2:2 +train_df = df.iloc[:int(0.6 * len(df))] +val_df = df.iloc[int(0.6 * len(df)):int(0.8 * len(df))] +test_df = df.iloc[int(0.8 * len(df)):] +train_df.to_csv(args.data_path+"/train.csv", index=False) +val_df.to_csv(args.data_path+"/val.csv", index=False) +test_df.to_csv(args.data_path+"/test.csv", index=False) +logging.info("Finished!") diff --git a/MindChemistry/applications/cdvae/requirements.txt b/MindChemistry/applications/cdvae/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..8cac7af853e1281c606c9a5aa325e6453394feee --- /dev/null +++ 
b/MindChemistry/applications/cdvae/requirements.txt @@ -0,0 +1,12 @@ +matminer==0.7.3 +mindchemistry_ascend==0.1.0 +mindspore==2.3.0.20240411 +numpy==1.26.4 +p_tqdm==1.4.0 +pandas==2.2.2 +pymatgen==2023.8.10 +sciai==0.1.0 +scipy==1.13.1 +SMACT==2.2.1 +sympy==1.12 +tqdm==4.66.2 diff --git a/MindChemistry/applications/cdvae/src/__init__.py b/MindChemistry/applications/cdvae/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e2278ece729a571b5c56d3a687e37951a94127b9 --- /dev/null +++ b/MindChemistry/applications/cdvae/src/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""src""" diff --git a/MindChemistry/applications/cdvae/src/dataloader/__init__.py b/MindChemistry/applications/cdvae/src/dataloader/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e2278ece729a571b5c56d3a687e37951a94127b9 --- /dev/null +++ b/MindChemistry/applications/cdvae/src/dataloader/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""src""" diff --git a/MindChemistry/applications/cdvae/src/dataloader/dataloader.py b/MindChemistry/applications/cdvae/src/dataloader/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..9c9ed6d16d89f85dbf2e0347fb05242f96a5f264 --- /dev/null +++ b/MindChemistry/applications/cdvae/src/dataloader/dataloader.py @@ -0,0 +1,220 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""dataloader +""" +import random +import numpy as np +from mindspore import Tensor +import mindspore as ms + + +class DataLoaderBaseCDVAE: + r""" + DataLoader for CDVAE + """ + + def __init__(self, + batch_size, + dataset, + shuffle_dataset=True): + self.atom_types = dataset["atom_types"] + self.dist = dataset["dist"] + self.angle = dataset["angle"] + self.idx_kj = dataset["idx_kj"] + self.idx_ji = dataset["idx_ji"] + self.edge_j = dataset["edge_j"] + self.edge_i = dataset["edge_i"] + self.pos = dataset["pos"] + self.batch = dataset["batch"] + self.lengths = dataset["lengths"] + self.num_atoms = dataset["num_atoms"] + self.angles = dataset["angles"] + self.frac_coords = dataset["frac_coords"] + self.y = dataset["y"] + self.num_bonds = dataset["num_bonds"] + self.num_triplets = dataset["num_triplets"] + self.sbf = dataset["sbf"] + self.edge_attr = self.edge_j + self.batch_size = batch_size + self.index = 0 + self.step = 0 + self.shuffle_dataset = shuffle_dataset + + # can be customized to specific dataset + self.label = self.num_atoms + self.node_attr = self.atom_types + self.sample_num = len(self.node_attr) + + self.max_start_sample = self.sample_num - self.batch_size + 1 + + def get_dataset_size(self): + return self.sample_num + + def __iter__(self): + if self.shuffle_dataset: + self.shuffle() + else: + self.restart() + while self.index < self.max_start_sample: + # can be customized to generate different attributes or labels according to specific dataset + num_bonds_step = self.gen_global_attr( + self.num_bonds, self.batch_size).astype(np.int32) + num_atoms_step = self.gen_global_attr( + self.num_atoms, self.batch_size).squeeze().astype(np.int32) + num_triplets_step = self.gen_global_attr( + self.num_triplets, self.batch_size).astype(np.int32) + atom_types_step = self.gen_node_attr( + self.atom_types, self.batch_size).astype(np.int32) + dist_step = self.gen_edge_attr( + self.dist, self.batch_size).astype(np.float32) + angle_step = self.gen_triplet_attr( + self.angle, self.batch_size).astype(np.float32) + idx_kj_step = self.gen_triplet_attr(self.idx_kj, self.batch_size) + idx_kj_step = self.add_index_offset( + idx_kj_step, num_bonds_step, num_triplets_step).astype(np.int32) + idx_ji_step = self.gen_triplet_attr(self.idx_ji, self.batch_size) + idx_ji_step = self.add_index_offset( + idx_ji_step, num_bonds_step, num_triplets_step).astype(np.int32) + edge_j_step = self.gen_edge_attr(self.edge_j, self.batch_size) + edge_j_step = self.add_index_offset( + edge_j_step, num_atoms_step, num_bonds_step).astype(np.int32) + edge_i_step = self.gen_edge_attr(self.edge_j, self.batch_size) + edge_i_step = self.add_index_offset( + edge_i_step, num_atoms_step, num_bonds_step).astype(np.int32) + pos_step = self.gen_node_attr( + self.pos, self.batch_size).astype(np.float32) + batch_step = np.repeat( + np.arange(num_atoms_step.shape[0],), num_atoms_step, axis=0).astype(np.int32) + lengths_step = self.gen_crystal_attr( + self.lengths, self.batch_size).astype(np.float32) + angles_step = self.gen_crystal_attr( + self.angles, self.batch_size).astype(np.float32) + frac_coords_step = self.gen_node_attr( + self.frac_coords, self.batch_size).astype(np.float32) + y_step = self.gen_global_attr( + self.y, self.batch_size).astype(np.float32) + sbf_step = self.gen_triplet_attr( + self.sbf, self.batch_size).astype(np.float32) + total_atoms = num_atoms_step.sum().item() + self.add_step_index(self.batch_size) + + ############## change to 
mindspore Tensor #############
+            atom_types_step = Tensor(atom_types_step, ms.int32)
+            dist_step = Tensor(dist_step, ms.float32)
+            angle_step = Tensor(angle_step, ms.float32)
+            idx_kj_step = Tensor(idx_kj_step, ms.int32)
+            idx_ji_step = Tensor(idx_ji_step, ms.int32)
+            edge_j_step = Tensor(edge_j_step, ms.int32)
+            edge_i_step = Tensor(edge_i_step, ms.int32)
+            pos_step = Tensor(pos_step, ms.float32)
+            batch_step = Tensor(batch_step, ms.int32)
+            lengths_step = Tensor(lengths_step, ms.float32)
+            num_atoms_step = Tensor(num_atoms_step, ms.int32)
+            angles_step = Tensor(angles_step, ms.float32)
+            frac_coords_step = Tensor(frac_coords_step, ms.float32)
+            y_step = Tensor(y_step, ms.float32)
+            sbf_step = Tensor(sbf_step, ms.float32)
+
+            yield (atom_types_step, dist_step, angle_step, idx_kj_step,
+                   idx_ji_step, edge_j_step, edge_i_step, batch_step,
+                   lengths_step, num_atoms_step, angles_step, frac_coords_step,
+                   y_step, self.batch_size, sbf_step, total_atoms)
+
+    def add_index_offset(self, edge_index, num_atoms, num_bonds):
+        """shift per-sample edge indices so they index into the batched graph"""
+        index_offset = (
+            np.cumsum(num_atoms, axis=0) - num_atoms
+        )
+
+        index_offset_expand = np.repeat(
+            index_offset, num_bonds
+        )
+        edge_index += index_offset_expand
+        return edge_index
+
+    def shuffle_index(self):
+        """shuffle_index"""
+        indices = list(range(self.sample_num))
+        random.shuffle(indices)
+        return indices
+
+    def shuffle(self):
+        """shuffle"""
+        self.shuffle_action()
+        self.step = 0
+        self.index = 0
+
+    def shuffle_action(self):
+        """shuffle_action"""
+        indices = self.shuffle_index()
+        self.atom_types = [self.atom_types[i] for i in indices]
+        self.dist = [self.dist[i] for i in indices]
+        self.angle = [self.angle[i] for i in indices]
+        self.idx_kj = [self.idx_kj[i] for i in indices]
+        self.idx_ji = [self.idx_ji[i] for i in indices]
+        self.edge_j = [self.edge_j[i] for i in indices]
+        self.edge_i = [self.edge_i[i] for i in indices]
+        self.pos = [self.pos[i] for i in indices]
+        self.batch = [self.batch[i] for i in indices]
+        self.lengths = [self.lengths[i] for i in indices]
+        self.num_atoms = [self.num_atoms[i] for i in indices]
+        self.angles = [self.angles[i] for i in indices]
+        self.frac_coords = [self.frac_coords[i] for i in indices]
+        self.y = [self.y[i] for i in indices]
+        self.num_bonds = [self.num_bonds[i] for i in indices]
+        self.num_triplets = [self.num_triplets[i] for i in indices]
+        self.sbf = [self.sbf[i] for i in indices]
+
+    def restart(self):
+        """restart"""
+        self.step = 0
+        self.index = 0
+
+    def gen_node_attr(self, node_attr, batch_size):
+        """gen_node_attr"""
+        node_attr_step = np.concatenate(
+            node_attr[self.index:self.index + batch_size], 0)
+        return node_attr_step
+
+    def gen_edge_attr(self, edge_attr, batch_size):
+        """gen_edge_attr"""
+        edge_attr_step = np.concatenate(
+            edge_attr[self.index:self.index + batch_size], 0)
+
+        return edge_attr_step
+
+    def gen_global_attr(self, global_attr, batch_size):
+        """gen_global_attr"""
+        global_attr_step = np.stack(
+            global_attr[self.index:self.index + batch_size], 0)
+
+        return global_attr_step
+
+    def gen_crystal_attr(self, global_attr, batch_size):
+        """gen_crystal_attr"""
+        global_attr_step = np.stack(
+            global_attr[self.index:self.index + batch_size], 0).squeeze()
+        return global_attr_step
+
+    def gen_triplet_attr(self, triplet_attr, batch_size):
+        """gen_triplet_attr"""
+        global_attr_step = np.concatenate(
+            triplet_attr[self.index:self.index + batch_size], 0)
+
+        return global_attr_step
+
+    def add_step_index(self, batch_size):
+        """add_step_index"""
+        self.index = self.index + batch_size
+        self.step += 1
diff --git a/MindChemistry/applications/cdvae/src/dataloader/mp_20_process.py b/MindChemistry/applications/cdvae/src/dataloader/mp_20_process.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4a4a0c0cce93b2b8572a1e0feff6c1d34b799ad
--- /dev/null
+++ b/MindChemistry/applications/cdvae/src/dataloader/mp_20_process.py
@@ -0,0 +1,67 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""script used to generate the mp_20 dataset from raw data"""
+import os
+import logging
+import argparse
+import pandas as pd
+from pymatgen.core.structure import Structure
+from pymatgen.core.lattice import Lattice
+from pymatgen.io.cif import CifWriter
+
+
+logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
+parser = argparse.ArgumentParser()
+parser.add_argument("--init_path", default="./data/mp_20.json")
+parser.add_argument("--data_path", default="./data/mp_20")
+args = parser.parse_args()
+
+# read the json file and convert it to a pandas DataFrame
+if not os.path.exists(args.data_path):
+    os.makedirs(args.data_path)
+    logging.info("%s has been created", args.data_path)
+
+df = pd.read_json(args.init_path)
+df = df[["id", "formation_energy_per_atom", "band_gap", "pretty_formula",
+         "e_above_hull", "elements", "atoms", "spacegroup_number"]]
+struct_list = []
+element_list = []
+# generate a Structure from df["atoms"] for each sample
+for struct in df["atoms"]:
+    lattice = Lattice(struct["lattice_mat"], (False, False, False))
+    pos = struct["coords"]
+    species = struct["elements"]
+    structure = Structure(lattice, species, pos)
+    # serialize the Structure as cif
+    cif = CifWriter(structure)
+    struct_list.append(str(cif))
+    element_list.append(struct["elements"])
+
+# add cif to df
+df.insert(7, "cif", struct_list)
+df = df.drop("atoms", axis=1)
+df["elements"] = element_list
+
+
+# save to csv files
+# split the dataset into train:val:test = 6:2:2
+train_df = df.iloc[:int(0.6 * len(df))]
+val_df = df.iloc[int(0.6 * len(df)):int(0.8 * len(df))]
+test_df = df.iloc[int(0.8 * len(df)):]
+train_df.to_csv(args.data_path+"/train.csv", index=False)
+val_df.to_csv(args.data_path+"/val.csv", index=False)
+test_df.to_csv(args.data_path+"/test.csv", index=False)
+logging.info("Finished!")
diff --git a/MindChemistry/applications/cdvae/src/evaluate_utils.py b/MindChemistry/applications/cdvae/src/evaluate_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ba8d074c72d3e41d9877ca88657b3614a812ba3
--- /dev/null
+++ b/MindChemistry/applications/cdvae/src/evaluate_utils.py
@@ -0,0 +1,191 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""evaluate_utils"""
+import logging
+import mindspore as ms
+import mindspore.mint as mint
+from mindspore.nn import Adam
+from tqdm import tqdm
+
+logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
+
+
+def reconstruction(loader, model, ld_kwargs, num_evals,
+                   force_num_atoms=False, force_atom_types=False):
+    """
+    reconstruct the crystals in the loader.
+    """
+    result_frac_coords = []
+    result_num_atoms = []
+    result_atom_types = []
+    result_lengths = []
+    result_angles = []
+    gt_frac_coords = []
+    groundtruth_num_atoms = []
+    groundtruth_atom_types = []
+    gt_lengths = []
+    gt_angles = []
+    for idx, data in enumerate(loader):
+        logging.info("Reconstructing %d", int(idx * data[-3]))
+        batch_frac_coords, batch_num_atoms, batch_atom_types = [], [], []
+        batch_lengths, batch_angles = [], []
+
+        # only sample one z; run multiple evals for stochasticity in the langevin dynamics
+        (atom_types, dist, _, idx_kj, idx_ji,
+         edge_j, edge_i, batch, lengths, num_atoms,
+         angles, frac_coords, _, batch_size, sbf,
+         total_atoms) = data
+        gt_frac_coords.append(frac_coords.asnumpy())
+        gt_angles.append(angles.asnumpy())
+        gt_lengths.append(lengths.asnumpy())
+        groundtruth_atom_types.append(atom_types.asnumpy())
+        groundtruth_num_atoms.append(num_atoms.asnumpy())
+        _, _, z = model.encode(atom_types, dist,
+                               idx_kj, idx_ji, edge_j, edge_i,
+                               batch, total_atoms, batch_size, sbf)
+        for _ in range(num_evals):
+            gt_num_atoms = num_atoms if force_num_atoms else None
+            gt_atom_types = atom_types if force_atom_types else None
+            outputs = model.langevin_dynamics(
+                z, ld_kwargs, batch_size, total_atoms, gt_num_atoms, gt_atom_types)
+            # collect sampled crystals in this batch.
+            batch_frac_coords.append(outputs["frac_coords"].asnumpy())
+            batch_num_atoms.append(outputs["num_atoms"].asnumpy())
+            batch_atom_types.append(outputs["atom_types"].asnumpy())
+            batch_lengths.append(outputs["lengths"].asnumpy())
+            batch_angles.append(outputs["angles"].asnumpy())
+        # collect sampled crystals for this z.
+        result_frac_coords.append(batch_frac_coords)
+        result_num_atoms.append(batch_num_atoms)
+        result_atom_types.append(batch_atom_types)
+        result_lengths.append(batch_lengths)
+        result_angles.append(batch_angles)
+
+    return (
+        result_frac_coords, result_num_atoms, result_atom_types,
+        result_lengths, result_angles,
+        gt_frac_coords, groundtruth_num_atoms, groundtruth_atom_types,
+        gt_lengths, gt_angles)
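+
+
+# ld_kwargs is a SimpleNamespace assembled in evaluation.py; a minimal sketch
+# of the fields consumed by model.langevin_dynamics:
+#
+#     from types import SimpleNamespace
+#     ld_kwargs = SimpleNamespace(n_step_each=1, step_lr=1e-3, min_sigma=0,
+#                                 save_traj=False, disable_bar=False)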
+    """
+    all_frac_coords_stack = []
+    all_atom_types_stack = []
+    result_frac_coords = []
+    result_num_atoms = []
+    result_atom_types = []
+    result_lengths = []
+    result_angles = []
+
+    for _ in range(num_batches_to_sample):
+        batch_all_frac_coords = []
+        batch_all_atom_types = []
+        batch_frac_coords, batch_num_atoms, batch_atom_types = [], [], []
+        batch_lengths, batch_angles = [], []
+
+        z = ms.ops.randn(batch_size, model.hidden_dim)
+
+        for _ in range(num_samples_per_z):
+            samples = model.langevin_dynamics(z, ld_kwargs, batch_size)
+
+            # collect sampled crystals in this batch.
+            batch_frac_coords.append(samples["frac_coords"].asnumpy())
+            batch_num_atoms.append(samples["num_atoms"].asnumpy())
+            batch_atom_types.append(samples["atom_types"].asnumpy())
+            batch_lengths.append(samples["lengths"].asnumpy())
+            batch_angles.append(samples["angles"].asnumpy())
+            if ld_kwargs.save_traj:
+                batch_all_frac_coords.append(
+                    samples["all_frac_coords"][::down_sample_traj_step].asnumpy())
+                batch_all_atom_types.append(
+                    samples["all_atom_types"][::down_sample_traj_step].asnumpy())
+
+        # collect sampled crystals for this z.
+        result_frac_coords.append(batch_frac_coords)
+        result_num_atoms.append(batch_num_atoms)
+        result_atom_types.append(batch_atom_types)
+        result_lengths.append(batch_lengths)
+        result_angles.append(batch_angles)
+        if ld_kwargs.save_traj:
+            all_frac_coords_stack.append(
+                batch_all_frac_coords)
+            all_atom_types_stack.append(
+                batch_all_atom_types)
+
+    return (result_frac_coords, result_num_atoms, result_atom_types,
+            result_lengths, result_angles,
+            all_frac_coords_stack, all_atom_types_stack)
+
+
+def optimization(model, ld_kwargs, data_loader,
+                 num_starting_points=128, num_gradient_steps=5000,
+                 lr=1e-3, num_saved_crys=10):
+    """
+    Optimize structures with respect to a specific property.
+    """
+    model.set_train(True)
+    if data_loader is not None:
+        data = next(iter(data_loader))
+        (atom_types, dist, _, idx_kj, idx_ji,
+         edge_j, edge_i, batch, _, num_atoms,
+         _, _, _, batch_size, sbf,
+         total_atoms) = data
+        _, _, z = model.encode(atom_types, dist,
+                               idx_kj, idx_ji, edge_j, edge_i,
+                               batch, total_atoms, batch_size, sbf)
+        z = mint.narrow(z, 0, 0, num_starting_points)
+        z = ms.Parameter(z, requires_grad=True)
+    else:
+        z = mint.randn(num_starting_points, model.hparams.hidden_dim)
+        z = ms.Parameter(z, requires_grad=True)
+
+    opt = Adam([z], learning_rate=lr)
+    freeze_model(model)
+
+    loss_fn = model.fc_property
+
+    def forward_fn(data):
+        loss = loss_fn(data)
+        return loss
+    grad_fn = ms.value_and_grad(forward_fn, None, opt.parameters)
+
+    def train_step(data):
+        loss, grads = grad_fn(data)
+        opt(grads)
+        return loss
+
+    all_crystals = []
+    total_atoms = mint.sum(mint.narrow(
+        num_atoms, 0, 0, num_starting_points)).item()
+    # save intermediate crystals at evenly spaced optimization steps
+    interval = num_gradient_steps // (num_saved_crys - 1)
+    for i in tqdm(range(num_gradient_steps)):
+        loss = mint.mean(train_step(z))
+        logging.info("Task opt step: %d, loss: %f", i, loss)
+        if i % interval == 0 or i == (num_gradient_steps - 1):
+            crystals = model.langevin_dynamics(
+                z, ld_kwargs, batch_size, total_atoms)
+            all_crystals.append(crystals)
+    return {k: mint.cat([d[k] for d in all_crystals]).unsqueeze(0).asnumpy() for k in
+            ["frac_coords", "atom_types", "num_atoms", "lengths", "angles"]}
+
+
+def freeze_model(model):
+    """Freeze the model parameters so that only z is optimized."""
+    for param in model.get_parameters():
+        param.requires_grad = False
diff --git a/MindChemistry/applications/cdvae/src/metrics_utils.py b/MindChemistry/applications/cdvae/src/metrics_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0748aac2c8d7072f60e7a66d9bef2f2804f7c86f
--- /dev/null
+++ b/MindChemistry/applications/cdvae/src/metrics_utils.py
@@ -0,0 +1,191 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Utilities for computing evaluation metrics."""
+import itertools
+import numpy as np
+
+from scipy.spatial.distance import pdist
+from scipy.spatial.distance import cdist
+
+import smact
+from smact.screening import pauling_test
+
+from create_dataset import chemical_symbols
+
+
+def get_crystals_list(
+        frac_coords, atom_types, lengths, angles, num_atoms):
+    """
+    Split batched arrays into a list of per-crystal dicts.
+
+    Args:
+        frac_coords: (num_atoms, 3)
+        atom_types: (num_atoms)
+        lengths: (num_crystals)
+        angles: (num_crystals)
+        num_atoms: (num_crystals)
+    """
+    assert frac_coords.shape[0] == atom_types.shape[0] == num_atoms.sum()
+    assert lengths.shape[0] == angles.shape[0] == num_atoms.shape[0]
+
+    start_idx = 0
+    crystal_array_list = []
+    for batch_idx, num_atom in enumerate(num_atoms.tolist()):
+        cur_frac_coords = frac_coords[start_idx:start_idx+num_atom]
+        cur_atom_types = atom_types[start_idx:start_idx+num_atom]
+        cur_lengths = lengths[batch_idx]
+        cur_angles = angles[batch_idx]
+
+        crystal_array_list.append({
+            "frac_coords": cur_frac_coords,
+            "atom_types": cur_atom_types,
+            "lengths": cur_lengths,
+            "angles": cur_angles,
+        })
+        start_idx = start_idx + num_atom
+    return crystal_array_list
+
+
+def smact_validity(comp, count,
+                   use_pauling_test=True,
+                   include_alloys=True):
+    """Compute SMACT validity (charge neutrality and electronegativity)."""
+    elem_symbols = tuple([chemical_symbols[elem] for elem in comp])
+    space = smact.element_dictionary(elem_symbols)
+    smact_elems = [e[1] for e in space.items()]
+    electronegs = [e.pauling_eneg for e in smact_elems]
+    ox_combos = [e.oxidation_states for e in smact_elems]
+    if len(set(elem_symbols)) == 1:
+        return True
+    if include_alloys:
+        is_metal_list = [elem_s in smact.metals for elem_s in elem_symbols]
+        if all(is_metal_list):
+            return True
+
+    threshold = np.max(count)
+    compositions = []
+    for ox_states in itertools.product(*ox_combos):
+        stoichs = [(c,) for c in count]
+        # Test for charge balance
+        cn_e, cn_r = smact.neutral_ratios(
+            ox_states, stoichs=stoichs, threshold=threshold)
+        # Electronegativity test
+        if cn_e:
+            if use_pauling_test:
+                try:
+                    electroneg_pass = pauling_test(ox_states, electronegs)
+                except TypeError:
+                    # if no electronegativity data, assume it is okay
+                    electroneg_pass = True
+            else:
+                electroneg_pass = True
+            if electroneg_pass:
+                for ratio in cn_r:
+                    compositions.append(
+                        tuple([elem_symbols, ox_states, ratio]))
+    compositions = [(i[0], i[2]) for i in compositions]
+    compositions = list(set(compositions))
+    res = bool(compositions)
+    return res
+
+
+def structure_validity(crystal, cutoff=0.5):
+    """Compute structure validity from interatomic distances and cell volume."""
+    dist_mat = crystal.distance_matrix
+    # Pad diagonal with a large number
+    dist_mat = dist_mat + np.diag(
+        np.ones(dist_mat.shape[0]) * (cutoff + 10.))
+    if dist_mat.min() < cutoff or crystal.volume < 0.1:
+        return False
+    return True
+
+
+def get_fp_pdist(fp_array):
+    """Mean pairwise distance between fingerprints."""
+    if isinstance(fp_array, list):
+        fp_array = np.array(fp_array)
+    fp_pdists = pdist(fp_array)
+    return fp_pdists.mean()
+
+
+def filter_fps(struc_fps, comp_fps):
+    """Drop entries whose structure or composition fingerprint is missing."""
+    assert len(struc_fps) == len(comp_fps)
+
+    filtered_struc_fps, filtered_comp_fps = [], []
+
+    for struc_fp, comp_fp in zip(struc_fps, comp_fps):
+        if struc_fp is not None and comp_fp is not None:
+            filtered_struc_fps.append(struc_fp)
+            filtered_comp_fps.append(comp_fp)
+    return filtered_struc_fps, filtered_comp_fps
+
+
+def compute_cov(crys, gt_crys, comp_scaler,
+                struc_cutoff, comp_cutoff, num_gen_crystals=None):
+    """Compute coverage (COV) between generated and ground-truth crystals."""
+    struc_fps = [c.struct_fp for c in crys]
+    comp_fps = [c.comp_fp for c in crys]
+    gt_struc_fps = [c.struct_fp for c in gt_crys]
+    gt_comp_fps = [c.comp_fp for c in gt_crys]
+
+    assert len(struc_fps) == len(comp_fps)
+    assert len(gt_struc_fps) == len(gt_comp_fps)
+
+    # Use the number of crystals before filtering to compute COV
+    if num_gen_crystals is None:
+        num_gen_crystals = len(struc_fps)
+
+    struc_fps, comp_fps = filter_fps(struc_fps, comp_fps)
+
+    comp_fps = comp_scaler.transform(comp_fps)
+    gt_comp_fps = comp_scaler.transform(gt_comp_fps)
+
+    struc_fps = np.array(struc_fps)
+    gt_struc_fps = np.array(gt_struc_fps)
+    comp_fps = np.array(comp_fps)
+    gt_comp_fps = np.array(gt_comp_fps)
+
+    struc_pdist = cdist(struc_fps, gt_struc_fps)
+    comp_pdist = cdist(comp_fps, gt_comp_fps)
+
+    struc_recall_dist = struc_pdist.min(axis=0)
+    struc_precision_dist = struc_pdist.min(axis=1)
+    comp_recall_dist = comp_pdist.min(axis=0)
+    comp_precision_dist = comp_pdist.min(axis=1)
+
+    cov_recall = np.mean(np.logical_and(
+        struc_recall_dist <= struc_cutoff,
+        comp_recall_dist <= comp_cutoff))
+    cov_precision = np.sum(np.logical_and(
+        struc_precision_dist <= struc_cutoff,
+        comp_precision_dist <= comp_cutoff)) / num_gen_crystals
+
+    metrics_dict = {
+        "cov_recall": cov_recall,
+        "cov_precision": cov_precision,
+        "amsd_recall": np.mean(struc_recall_dist),
+        "amsd_precision": np.mean(struc_precision_dist),
+        "amcd_recall": np.mean(comp_recall_dist),
+        "amcd_precision": np.mean(comp_precision_dist),
+    }
+
+    combined_dist_dict = {
+        "struc_recall_dist": struc_recall_dist.tolist(),
+        "struc_precision_dist": struc_precision_dist.tolist(),
+        "comp_recall_dist": comp_recall_dist.tolist(),
+        "comp_precision_dist": comp_precision_dist.tolist(),
+    }
+
+    return metrics_dict, combined_dist_dict
diff --git a/MindChemistry/applications/cdvae/train.py b/MindChemistry/applications/cdvae/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..71ecd8d7cf4516ed1980570ba82abe730a2daa3f
--- /dev/null
+++ b/MindChemistry/applications/cdvae/train.py
@@ -0,0 +1,176 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Train CDVAE.
+"""
+
+import os
+import logging
+import argparse
+import time
+import numpy as np
+import mindspore as ms
+from mindspore.experimental import optim
+from mindchemistry.utils.load_config import load_yaml_config_from_path
+from mindchemistry.cell.cdvae import CDVAE
+from mindchemistry.cell.gemnet.data_utils import StandardScalerMindspore
+from create_dataset import create_dataset
+from src.dataloader.dataloader import DataLoaderBaseCDVAE
+
+
+def train_epoch(epoch, model, optimizer, scheduler, train_dataset):
+    """Train the model for one epoch."""
+    model.set_train()
+    # Define forward function
+
+    def forward_fn(data):
+        (atom_types, dist, _, idx_kj, idx_ji,
+         edge_j, edge_i, batch, lengths, num_atoms,
+         angles, frac_coords, y, batch_size, sbf, total_atoms) = data
+        loss = model(atom_types, dist, idx_kj, idx_ji, edge_j, edge_i,
+                     batch, lengths, num_atoms, angles, frac_coords,
+                     y, batch_size, sbf, total_atoms, True, True)
+        return loss
+    # Get gradient function
+    grad_fn = ms.value_and_grad(
+        forward_fn, None, optimizer.parameters, has_aux=False)
+    # Define function of one-step training
+
+    def train_step(data):
+        loss, grads = grad_fn(data)
+        scheduler.step(loss)
+        optimizer(grads)
+        return loss
+
+    start_time_step = time.time()
+    for batch, data in enumerate(train_dataset):
+        loss = train_step(data)
+        time_step = time.time() - start_time_step
+        start_time_step = time.time()
+        if batch % 10 == 0:
+            logging.info("Train Epoch: %d [%d]\tLoss: %4f,\t time_step: %4f",
+                         epoch, batch, loss, time_step)
+
+
+def test_epoch(model, val_dataset):
+    """Evaluate on the validation set for one epoch."""
+    model.set_train(False)
+    test_loss = 0
+    i = 1
+    for i, data in enumerate(val_dataset):
+        (atom_types, dist, _, idx_kj, idx_ji,
+         edge_j, edge_i, batch, lengths, num_atoms,
+         angles, frac_coords, y, batch_size, sbf, total_atoms) = data
+        output = model(atom_types, dist,
+                       idx_kj, idx_ji, edge_j, edge_i,
+                       batch, lengths, num_atoms,
+                       angles, frac_coords, y, batch_size,
+                       sbf, total_atoms, False, True)
+        test_loss += float(output)
+    test_loss /= (i+1)
+    logging.info("Val Loss: %4f", test_loss)
+    return test_loss
+
+
+def train_net(args):
+    """Training process."""
+    folder_path = os.path.dirname(args.name_ckpt)
+    if not os.path.exists(folder_path):
+        os.makedirs(folder_path)
+        logging.info("%s has been created", folder_path)
+    config_path = "./conf/configs.yaml"
+    data_config_path = f"./conf/data/{args.dataset}.yaml"
+
+    model = CDVAE(config_path, data_config_path)
+
+    if args.load_ckpt:
+        ### load checkpoint ###
+        # name_ckpt already holds the full checkpoint path
+        model_path = args.name_ckpt
+        param_dict = ms.load_checkpoint(model_path)
+        param_not_load, _ = ms.load_param_into_net(model, param_dict)
+        logging.info("Parameters not loaded: %s", param_not_load)
+
+    ### create the dataset on the first run or when it does not exist ###
+    if args.create_dataset or not os.path.exists(f"./data/{args.dataset}/train/processed_data.npy"):
+        logging.info("Creating dataset......")
+        create_dataset(args)
+
+    ### read dataset from processed_data.npy ###
+    batch_size = 128
+    processed_data = np.load(
+        f"./data/{args.dataset}/train/processed_data.npy", allow_pickle=True).item()
+    train_dataset = DataLoaderBaseCDVAE(
+        batch_size, processed_data, shuffle_dataset=True)
+
+    processed_data = np.load(
+        f"./data/{args.dataset}/val/processed_data.npy", allow_pickle=True).item()
+    val_dataset = DataLoaderBaseCDVAE(
+        batch_size, processed_data, shuffle_dataset=False)
+
+    lattice_scaler_mean = ms.Tensor(np.loadtxt(
+        f"./data/{args.dataset}/train/lattice_scaler_mean.csv"), ms.float32)
+    lattice_scaler_std = ms.Tensor(np.loadtxt(
+        f"./data/{args.dataset}/train/lattice_scaler_std.csv"), ms.float32)
+    scaler_std = ms.Tensor(np.loadtxt(
+        f"./data/{args.dataset}/train/scaler_std.csv"), ms.float32)
+    scaler_mean = ms.Tensor(np.loadtxt(
+        f"./data/{args.dataset}/train/scaler_mean.csv"), ms.float32)
+    lattice_scaler = StandardScalerMindspore(
+        lattice_scaler_mean, lattice_scaler_std)
+    scaler = StandardScalerMindspore(scaler_mean, scaler_std)
+    model.lattice_scaler = lattice_scaler
+    model.scaler = scaler
+
+    config_opt = load_yaml_config_from_path(config_path).get("Optimizer")
+    learning_rate = config_opt.get("learning_rate")
+    min_lr = config_opt.get("min_lr")
+    factor = config_opt.get("factor")
+    patience = config_opt.get("patience")
+
+    optimizer = optim.Adam(model.trainable_params(), learning_rate)
+    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer, 'min', factor=factor, patience=patience, min_lr=min_lr)
+
+    min_test_loss = float("inf")
+    for epoch in range(args.epoch_num):
+        train_epoch(epoch, model, optimizer, scheduler, train_dataset)
+        test_loss = test_epoch(model, val_dataset)
+
+        if test_loss < min_test_loss:
+            min_test_loss = test_loss
+            ms.save_checkpoint(model, args.name_ckpt)
+            logging.info("Update best val loss: %f", test_loss)
+
+    logging.info("Finished Training")
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dataset", default="perov_5")
+    # boolean flags use store_true: argparse's type=bool would parse any
+    # non-empty string (including "False") as True
+    parser.add_argument("--create_dataset", action="store_true")
+    parser.add_argument("--num_samples_train", default=500, type=int)
+    parser.add_argument("--num_samples_val", default=300, type=int)
+    parser.add_argument("--num_samples_test", default=300, type=int)
+    parser.add_argument("--name_ckpt", default="./loss/loss.ckpt")
+    parser.add_argument("--load_ckpt", action="store_true")
+    parser.add_argument("--device_target", default="Ascend")
+    parser.add_argument("--device_id", default=3, type=int)
+    parser.add_argument("--epoch_num", default=100, type=int)
+    main_args = parser.parse_args()
+    logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
+    ms.context.set_context(device_target=main_args.device_target)
+    ms.context.set_context(device_id=main_args.device_id)
+    ms.context.set_context(mode=ms.PYNATIVE_MODE)
+    train_net(main_args)
diff --git a/MindChemistry/mindchemistry/cell/dimenet/dimenet_wrap.py b/MindChemistry/mindchemistry/cell/dimenet/dimenet_wrap.py
index f2bc8a2bf9d285fe68e1995bc2005b87d04cc58b..fe022df31eff03db7502df180dedfa423c58e2b3 100644
--- a/MindChemistry/mindchemistry/cell/dimenet/dimenet_wrap.py
+++ b/MindChemistry/mindchemistry/cell/dimenet/dimenet_wrap.py
@@ -115,7 +115,8 @@ class DimeNetWrap:
             readout=data_config.get("readout")
         )
 
-    def evaluation(self, angles, lengths, num_atoms, edge_index, frac_coords, num_bonds, to_jimages, atom_types):
+    def evaluation(self, angles, lengths, num_atoms, edge_index,
+                   frac_coords, num_bonds, to_jimages, atom_types, y):
         """
         Perform evaluation using the DimeNet model.
""" @@ -124,7 +125,7 @@ class DimeNetWrap: (atom_types, dist, idx_kj, idx_ji, edge_j, edge_i, batch, sbf) = self.preprocess.data_process(angles, lengths, num_atoms, edge_index, frac_coords, num_bonds, - to_jimages, atom_types) + to_jimages, atom_types, y) energy = self.dimenet(atom_types, dist, idx_kj, idx_ji, edge_i, edge_j, batch, total_atoms, batch_size, sbf) return energy diff --git a/MindChemistry/mindchemistry/cell/dimenet/preprocess.py b/MindChemistry/mindchemistry/cell/dimenet/preprocess.py index 905ce82c1fa17826ff68b3675427b622532e21b2..5f2696f00c95293f712cc615d159341dc0b93a8b 100644 --- a/MindChemistry/mindchemistry/cell/dimenet/preprocess.py +++ b/MindChemistry/mindchemistry/cell/dimenet/preprocess.py @@ -46,7 +46,7 @@ class PreProcess: self.task = task def data_process(self, angles, lengths, num_atoms, edge_index, frac_coords, - num_bonds, to_jimages, atom_types): + num_bonds, to_jimages, atom_types, y): r""" Process the input data. @@ -59,6 +59,7 @@ class PreProcess: num_bonds (np.ndarray): The shape of tensor is :math:`(batch\_size,)`. to_jimages (np.ndarray): The shape of tensor is :math:`(total\_edges,)`. atom_types (np.ndarray): The shape of tensor is :math:`(total\_atoms,)`. + y (np.ndarray): The shape of tensor is :math:`(batch\_size,)`. Returns: Tuple[Union[np.ndarray, ms.Tensor]]: A tuple containing the following processed data: @@ -79,6 +80,7 @@ class PreProcess: - **num_triplets** (Union[np.ndarray, ms.Tensor]) - The shape of tensor is :math:`(batch\_size,)`. - **sbf** (Union[np.ndarray, ms.Tensor]) - The shape of tensor is :math:`(total\_triplets, num\_spherical * num\_radial)`. + - **y** (Union[np.ndarray, ms.Tensor]) - The shape of tensor is :math:`(batch\_size,)`. """ num_atoms = num_atoms.reshape(-1) batch = np.repeat(np.arange(num_atoms.shape[0],), num_atoms, axis=0) @@ -135,7 +137,7 @@ class PreProcess: return (atom_types, dist, idx_kj, idx_ji, edge_j, edge_i, batch, sbf) return (atom_types, dist, angle, idx_kj, idx_ji, edge_j, edge_i, pos, - batch, lengths, num_atoms, angles, frac_coords, num_bonds, num_triplets, sbf) + batch, lengths, num_atoms, angles, frac_coords, num_bonds, num_triplets, sbf, y) @staticmethod def triplets(edge_index): diff --git a/tests/st/mindchemistry/cell/test_cdvae/test_cdvae.py b/tests/st/mindchemistry/cell/test_cdvae/test_cdvae.py index a1afd77aefc60444421ffb26fd78336d40e982d4..bb27e950c7b4b6e351feeeaf05e736cd5c4c6a4c 100644 --- a/tests/st/mindchemistry/cell/test_cdvae/test_cdvae.py +++ b/tests/st/mindchemistry/cell/test_cdvae/test_cdvae.py @@ -112,9 +112,10 @@ def test_dimenet(): to_jimages = np.zeros((edge_index.shape[1], 3), np.int32) num_bonds = np.array([4, 4], np.int32) num_atoms = np.array([2, 2], np.int32) + y = Tensor([0.08428, 0.01353], ms.float32) out = dimenet.evaluation(angles, lengths, num_atoms, edge_index, frac_coords, - num_bonds, to_jimages, atom_types) + num_bonds, to_jimages, atom_types, y) assert out.shape == (2, dimenet.latent_dim), f"For `DimeNetPlusPlus`, the output shape should be\ (2, {dimenet.latent_dim}), but got {out.shape}." assert mint.isclose(out.sum(), ms.Tensor(0.0, dtype=ms.float32)), f"For `CDVAE`, the summary output\