1 Star 0 Fork 0

34bunny/ferminet-ms

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train.py 17.30 KB
一键复制 编辑 原始数据 按行查看 历史
34bunny 提交于 2021-06-07 16:55 . add several files
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import math
import os
import random
import datetime
import ast
import sys
import numpy as np
import mindspore
from mindspore import Tensor, context
from mindspore.train.model import Model, ParallelMode
from mindspore import log as logger
from src.utils import system
from src import definitions
# Seed every RNG the script uses at import time, regardless of --determinism.
# The original code read the seed from ``cfg.random_seed``, but no ``cfg`` is
# defined or imported in this file, so importing the module raised NameError.
# The value below matches the ``--random_seed`` flag's default; main() re-seeds
# with the user-supplied ``--random_seed`` when determinism mode is active.
_DEFAULT_RANDOM_SEED = 1
mindspore.common.set_seed(_DEFAULT_RANDOM_SEED)
random.seed(_DEFAULT_RANDOM_SEED)
np.random.seed(_DEFAULT_RANDOM_SEED)
# ----------------------------------------------------------------------------
# Command-line flags.
#
# NOTE(review): the boolean flags declared with ``action='store_false'``
# (--distributed, --use_kfac, --kfac_adapt_damping, --backflow,
# --build_backflow, --jastrow_en, --jastrow_ee, --jastrow_een, --log_walkers,
# --log_local_energies, --log_wavefunction, --check_loss, --determinism)
# default to True and are switched OFF when passed.  The ``store_true`` flags
# --use_envelope and --residual default to False even though their help text
# says they should be enabled.  These semantics are kept unchanged so existing
# command lines behave the same; confirm the intended defaults before relying
# on them.
# ----------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description='Training configuration', add_help=False)
# High-level training flags
parser.add_argument('--GPU', action='store_true', default=False,
                    help='use GPU for training (default: False)')
parser.add_argument('--distributed', action='store_false',
                    help='Pass to DISABLE distributed data-parallel training, '
                         'which is on by default and requires --GPU.')
parser.add_argument('--batch_size', type=int, default=4096,
                    help='number of walkers')
# Pretrain flags
# type=int: without it a value given on the command line arrives as a str.
parser.add_argument('--pretrain_iterations', type=int, default=1000,
                    help='number of iterations for which to pretrain the network to match Hartree-Fock orbitals')
parser.add_argument('--pretrain_basis', type=str,
                    default='sto-3g', help='basis set used to run Hartree-Fock calculation in PySCF')
# Optimization flags
parser.add_argument('--iterations', type=int,
                    default=1000000, help='number of iterations')
parser.add_argument('--clip_el', type=float,
                    default=5.0, help='if not none, scale at which to clip local energy')
# Learning rate flags
parser.add_argument('--learning_rate', type=float,
                    default=1.e-4, help='learning rate')
parser.add_argument('--learning_rate_decay', type=float,
                    default=1.0, help='exponent of learning rate decay')
parser.add_argument('--learning_rate_delay', type=float,
                    default=10000.0, help='set the scale of the rate decay')
# KFAC flags
parser.add_argument('--use_kfac', action="store_false",
                    help='if false, use ADAM, else use KFAC as optimizer')
parser.add_argument('--kfac_invert_every', type=int,
                    default=1, help='See KFAC documentation')
parser.add_argument('--kfac_cov_update_every', type=int,
                    default=1, help='See KFAC documentation')
parser.add_argument('--kfac_damping', type=float,
                    default=0.001, help='See KFAC documentation')
parser.add_argument('--kfac_cov_ema_decay', type=float,
                    default=0.95, help='See KFAC documentation')
parser.add_argument('--kfac_momentum', type=float,
                    default=0.0, help='See KFAC documentation')
parser.add_argument('--kfac_momentum_type', type=str,
                    default="regular", help='See KFAC documentation')
parser.add_argument('--kfac_adapt_damping', action="store_false",
                    help='See KFAC documentation')
parser.add_argument('--kfac_damping_adaptation_decay', type=float,
                    default=0.9, help='See KFAC documentation')
parser.add_argument('--kfac_damping_adaptation_interval', type=int,
                    default=5, help='See KFAC documentation')
parser.add_argument('--kfac_min_damping', type=float,
                    default=1.e-4, help='See KFAC documentation')
parser.add_argument('--kfac_norm_constraint', type=float,
                    default=0.001, help='See KFAC documentation')
# System Flags
parser.add_argument('--system_type', type=str,
                    default='molecule', help='function to be called to create the system')
parser.add_argument('--system', type=str,
                    default='LiH', help='If system_type is "molecule",'
                                        ' the name of the molecule. If system_type is "atom",'
                                        ' the atomic symbol. If system_type is "hn", the number'
                                        ' of atoms in the hydrogen chain.')
parser.add_argument('--system_charge', type=int,
                    default=0, help='The overall charge of the system. Positive for cations '
                                    'and negative for anions.')
parser.add_argument('--system_dim', type=int,
                    default=3, help='Number of dimensions of the system. Change with care.')
parser.add_argument('--system_units', type=str,
                    default='bohr', help='Units of *input* coords of atoms. Either "bohr" or '
                                         '"angstrom". Internally work in a.u.; positions in '
                                         'Angstroms are converted to Bohr.')
# Flags related to diatomics, the hydrogen chain, and the hydrogen circle
parser.add_argument('--system_separation', type=float,
                    default=0.0, help='For the hydrogen chain and diatomic systems, the '
                                      'separation between nuclei. For the H4 circle, the radius '
                                      'of the circle. For diatomics, will default to the '
                                      'equilibrium bond length if set to 0.')
# Flags related to the hydrogen circle
parser.add_argument('--system_angle', type=float,
                    default=np.pi / 4.0, help='Angle from the x-axis for the H4 circle')
# Flags related to the MCMC chain
parser.add_argument('--mcmc_burn_in', type=int,
                    default=100, help='Number of burn in steps after pretraining. '
                                      'If zero do not burn in or reinitialize walkers.')
parser.add_argument('--mcmc_steps', type=int,
                    default=10, help='Number of MCMC steps to make between network updates.')
parser.add_argument('--mcmc_init_width', type=float,
                    default=0.8, help='Width of (atom-centred) Gaussian used to generate initial '
                                      'electron configurations.')
# main() forwards this as MCMCConfig(move_width=...); it was read there but
# never declared, so every run failed with AttributeError.
# NOTE(review): default 0.2 chosen here — confirm against MCMCConfig usage.
parser.add_argument('--mcmc_move_width', type=float,
                    default=0.2, help='Width of the Gaussian used for MCMC moves.')
parser.add_argument('--mcmc_init_means', nargs="+", type=float,
                    default=[], help='Iterable of 3*nelectrons giving the mean initial position '
                                     'of each electron. Configurations are drawn using Gaussians '
                                     'of width init_width at each 3D position. Alpha electrons '
                                     'are listed before beta electrons. If empty, electrons are '
                                     'assigned to atoms based upon the isolated atom spin '
                                     'configuration.')
# Flags related to the network architecture
parser.add_argument('--network_architecture', type=str, choices=['ferminet', 'slater'],
                    default='ferminet', help='The choice of architecture to run the calculation with. '
                                             'Either "ferminet" or "slater" for the Fermi Net and '
                                             'standard Slater determinant respectively.')
parser.add_argument('--hidden_units', type=str,
                    default='((256, 32), (256, 32), (256, 32), (256, 32))',
                    help='Number of hidden units in each layer of the network. If '
                         'the Fermi Net with one- and two-electron streams is used, '
                         'a tuple is provided for each layer, with the first '
                         'element giving the number of hidden units in the '
                         'one-electron stream and the second element giving the '
                         'number of units in the two-electron stream.')
parser.add_argument('--determinants', type=int,
                    default=16, help='number of determinants in the Fermi Net')
parser.add_argument('--r12_en_features', action='store_true',
                    help='include r12/distance features between electrons and nuclei')
parser.add_argument('--r12_ee_features', action='store_true',
                    help='include r12/distance features between pairs of electrons')
parser.add_argument('--pos_ee_features', action='store_true',
                    help='include electron-electron position features')
parser.add_argument('--use_envelope', action='store_true',
                    help='Include multiplicative exponentially-decaying envelope. '
                         'Calculations will not converge if set to False.')
parser.add_argument('--backflow', action='store_false',
                    help='Include backflow transformation in input coordinates. '
                         'Only for use if network_architecture == "slater". '
                         'Implies --build_backflow.')
parser.add_argument('--build_backflow', action='store_false',
                    help='Create backflow weights but do '
                         'not include backflow coordinate transformation. Use to '
                         'train a Slater-Jastrow architecture and then train a '
                         'Slater-Jastrow-Backflow architecture based on it.')
parser.add_argument('--residual', action='store_true',
                    help='Use residual connections. Recommended.')
# type=int: layer sizes are integers (main() casts each entry with int()).
parser.add_argument('--after_det', nargs="+", type=int,
                    default=[1], help='Space-separated configuration of neural network after the '
                                      'determinants. By default, just takes a weighted sum of '
                                      'determinants with no nonlinearity.')
# Flags related to the Jastrow factor
parser.add_argument('--jastrow_en', action='store_false',
                    help='include electron - nuclear Jastrow factor')
parser.add_argument('--jastrow_ee', action='store_false',
                    help='include electron-electron Jastrow factor')
parser.add_argument('--jastrow_een', action='store_false',
                    help='include electron-electron-nuclear Jastrow factor')
# Flags related to logging, checkpointing, and restoring
parser.add_argument('--stats_frequency', type=int,
                    default=1, help='iterations between logging of stats')
parser.add_argument('--save_frequency', type=float,
                    default=10.0, help='minutes between saving network params')
parser.add_argument('--result_folder', type=str,
                    default='.', help='Path to save results and checkpoints to. A new '
                                      'subdirectory will be created for every experiment. '
                                      'By default, save locally.')
parser.add_argument('--restore_path', type=str,
                    default='', help='path containing checkpoint to restore network from')
parser.add_argument('--log_walkers', action='store_false',
                    help='Whether or not to log the values of all walkers every '
                         'iteration. Use with caution!!! Produces a lot of data '
                         'very quickly.')
parser.add_argument('--log_local_energies', action='store_false',
                    help='Whether or not to log all local energies for each walker '
                         'at each step.')
parser.add_argument('--log_wavefunction', action='store_false',
                    help='Whether or not to log all values of wavefunction for '
                         'each walker at each step.')
# Flags related to debugging
parser.add_argument('--check_loss', action='store_false',
                    help='Apply gradient update only if the loss is not NaN. '
                         'If true, training could be slightly slower but the '
                         'checkpoint written out when a NaN is detected will be '
                         'with the network weights which led to '
                         'the NaN.')
parser.add_argument('--determinism', action='store_false',
                    help='CPU only mode that also enforces determinism. '
                         'Will run significantly slower if used.')
parser.add_argument('--random_seed', type=int, default=1,
                    help='Only works in determinism mode. '
                         'Set a random seed for the run.')
parser.add_argument('--graph_path', type=str, default='',
                    help='File to write graph to.')
def _configure_context(args):
    """Select the execution device and, if requested, set up data parallelism.

    Args:
        args: parsed command-line namespace (reads ``GPU`` and ``distributed``).

    Returns:
        ``(rank_id, rank_size)`` of this process; ``(0, 1)`` when not distributed.

    Raises:
        ValueError: if distributed training is requested without ``--GPU``.
    """
    context.set_context(mode=context.GRAPH_MODE)
    if args.GPU:
        # Honour --GPU for single-device runs too; previously the device target
        # was only set inside the distributed branch. It is also set before
        # init("nccl") below.
        context.set_context(device_target='GPU')
    if not args.distributed:
        return 0, 1
    if not args.GPU:
        raise ValueError("Only supported GPU training.")
    # These were called but never imported (NameError); only needed when
    # running distributed, so import them here.
    from mindspore.communication.management import get_group_size, get_rank, init
    init("nccl")
    context.reset_auto_parallel_context()
    rank_id = get_rank()
    rank_size = get_group_size()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                      gradients_mean=True, device_num=rank_size)
    return rank_id, rank_size


def _create_result_path(result_folder):
    """Create and return a new timestamped results directory inside ``result_folder``.

    Missing parent directories (including a nested ``result_folder``) are
    created. Raises FileExistsError if the timestamped directory already
    exists, e.g. two runs started within the same second (ctime has
    one-second resolution).
    """
    stamp = datetime.datetime.ctime(datetime.datetime.now()).replace(' ', '_')
    result_path = os.path.join(result_folder, 'ferminet_results_' + stamp)
    os.makedirs(result_path)
    return result_path


def _build_system(args):
    """Return ``(molecule, spins)`` for the system selected by the flags.

    Dispatches on ``args.system_type`` to the matching ``system`` factory.

    Raises:
        ValueError: if ``args.system_type`` is not one of ``molecule``,
            ``atom``, ``hn`` or ``h4_circle``.
    """
    logger.info('System Type: %s', args.system_type)
    logger.info('System: %s', args.system)
    if args.system_type == 'molecule':
        molecule, spins = system.molecule(args.system,
                                          bond_length=args.system_separation,
                                          units=args.system_units)
    elif args.system_type == 'atom':
        molecule, spins = system.atom(args.system, charge=args.system_charge)
    elif args.system_type == 'hn':
        molecule, spins = system.hn(int(args.system),
                                    args.system_separation,
                                    charge=args.system_charge,
                                    units=args.system_units)
    elif args.system_type == 'h4_circle':
        molecule, spins = system.h4_circle(args.system_separation,
                                           args.system_angle,
                                           units=args.system_units)
    else:
        raise ValueError('Not a recognized system type: %s' % args.system_type)
    return molecule, spins


def main():
    """Parse flags, build the system and all configs, then run training.

    Unknown command-line arguments are ignored (``parse_known_args``). The
    raw command line is saved to ``flags.txt`` in a new timestamped directory
    under ``--result_folder``, which is also passed to ``definitions.train``
    as the results location.
    """
    args, _ = parser.parse_known_args()
    rank_id, rank_size = _configure_context(args)
    logger.info('Running as rank %d of %d', rank_id, rank_size)

    if args.determinism:
        mindspore.common.set_seed(args.random_seed)
        random.seed(args.random_seed)
        np.random.seed(args.random_seed)
        logger.warning('Activating determinism mode. Expect slow performance.')

    result_path = _create_result_path(args.result_folder)
    # Save the command line arguments for reproducibility.
    with open(os.path.join(result_path, 'flags.txt'), 'w', encoding='utf-8') as f:
        f.write(' '.join(sys.argv[1:]) + '\n')

    molecule, spins = _build_system(args)

    network_config = definitions.NetworkConfig(
        architecture=args.network_architecture,
        # literal_eval (not eval): the flag may only contain Python literals.
        hidden_units=ast.literal_eval(args.hidden_units),
        determinants=args.determinants,
        r12_en_features=args.r12_en_features,
        r12_ee_features=args.r12_ee_features,
        pos_ee_features=args.pos_ee_features,
        use_envelope=args.use_envelope,
        backflow=args.backflow,
        build_backflow=args.build_backflow,
        residual=args.residual,
        after_det=tuple(int(x) for x in args.after_det),
        jastrow_en=args.jastrow_en,
        jastrow_ee=args.jastrow_ee,
        jastrow_een=args.jastrow_een,
    )
    pretrain_config = definitions.PretrainConfig(
        # int(): the value is a str if the flag is declared without type=int.
        iterations=int(args.pretrain_iterations),
        basis=args.pretrain_basis,
    )
    # NOTE(review): the configs below were previously built via ``train.*``,
    # but no ``train`` name exists in this module (this script is itself
    # train.py), so every run hit NameError. They are assumed to live in
    # ``definitions`` next to NetworkConfig / PretrainConfig / train — confirm
    # against src/definitions.py.
    optim_config = definitions.OptimConfig(
        iterations=args.iterations,
        learning_rate=args.learning_rate,
        learning_rate_decay=args.learning_rate_decay,
        learning_rate_delay=args.learning_rate_delay,
        clip_el=args.clip_el,
        use_kfac=args.use_kfac,
        check_loss=args.check_loss,
        deterministic=args.determinism,
    )
    kfac_config = definitions.KfacConfig(
        invert_every=args.kfac_invert_every,
        cov_update_every=args.kfac_cov_update_every,
        damping=args.kfac_damping,
        cov_ema_decay=args.kfac_cov_ema_decay,
        momentum=args.kfac_momentum,
        momentum_type=args.kfac_momentum_type,
        adapt_damping=args.kfac_adapt_damping,
        damping_adaptation_decay=args.kfac_damping_adaptation_decay,
        damping_adaptation_interval=args.kfac_damping_adaptation_interval,
        min_damping=args.kfac_min_damping,
        # NOTE(review): --kfac_norm_constraint is parsed but not forwarded;
        # pass it here once KfacConfig is confirmed to accept it.
    )
    mcmc_config = definitions.MCMCConfig(
        burn_in=args.mcmc_burn_in,
        steps=args.mcmc_steps,
        init_width=args.mcmc_init_width,
        move_width=args.mcmc_move_width,
        init_means=tuple(float(x) for x in args.mcmc_init_means),
    )
    logging_config = definitions.LoggingConfig(
        result_path=result_path,
        save_frequency=args.save_frequency,
        restore_path=args.restore_path,
        stats_frequency=args.stats_frequency,
        walkers=args.log_walkers,
        wavefunction=args.log_wavefunction,
        local_energy=args.log_local_energies,
        config={
            'system_type': args.system_type,
            'system': args.system,
            'system_units': args.system_units,
            'system_separation': args.system_separation,
            'system_charge': args.system_charge,
        },
    )
    definitions.train(
        molecule=molecule,
        spins=spins,
        batch_size=args.batch_size,
        network_config=network_config,
        pretrain_config=pretrain_config,
        optim_config=optim_config,
        kfac_config=kfac_config,
        mcmc_config=mcmc_config,
        logging_config=logging_config,
        multi_gpu=args.distributed,
        double_precision=False,
        graph_path=args.graph_path)
    logger.info('Fermi Net training run completed successfully.')


if __name__ == '__main__':
    main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/TFbunny/ferminet-ms.git
git@gitee.com:TFbunny/ferminet-ms.git
TFbunny
ferminet-ms
ferminet-ms
master

搜索帮助