Quick Start
Overview
This tutorial uses Allegro to predict molecular potential energies as a worked example.
Allegro is a state-of-the-art model built on equivariant graph neural networks; the corresponding paper has been published in Nature Communications. This case study demonstrates the effectiveness of Allegro for molecular potential energy prediction and its practical value.
This tutorial introduces the research background and technical approach of Allegro, and shows how to train the model and run fast inference with MindSpore Chemistry. For more details, see the paper.
Technical Approach
The workflow for solving the molecular potential energy prediction problem with MindSpore Chemistry is as follows:
Creating the dataset
Model construction
Loss function
Optimizer
Model training
Model inference
[1]:
import time
import random
import mindspore as ms
import numpy as np
from mindspore import nn
from mindspore.experimental import optim
The src package used below can be downloaded from allegro/src.
[13]:
from mindchemistry.cell.allegro import Allegro
from mindchemistry.utils.load_config import load_yaml_config_from_path
from src.allegro_embedding import AllegroEmbedding
from src.dataset import create_training_dataset, create_test_dataset
from src.potential import Potential
[3]:
ms.set_seed(123)
ms.dataset.config.set_seed(1)
np.random.seed(1)
random.seed(1)
The model hyperparameters, optimizer settings, and data configuration are given in the config file.
[4]:
configs = load_yaml_config_from_path("rmd.yaml")
ms.set_context(mode=ms.GRAPH_MODE)
ms.set_device("Ascend", 0)
Creating the Dataset
Download the training and validation datasets from the Revised MD17 dataset (rMD17) to the ./dataset/rmd17/npz_data/ directory. By default, the configuration file reads the dataset from dataset/rmd17/npz_data/rmd17_uracil.npz.
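Before building the dataset, you can optionally confirm that the file is in place and inspect the arrays it contains. This is a minimal sketch using NumPy; it prints whatever array names the rMD17 archive provides.
[ ]:
# Optional: verify the downloaded rMD17 file and list the arrays it contains.
import os
import numpy as np

npz_path = "dataset/rmd17/npz_data/rmd17_uracil.npz"
assert os.path.exists(npz_path), f"dataset not found: {npz_path}"
with np.load(npz_path) as data:
    for name in data.files:
        print(name, data[name].shape)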
[5]:
n_epoch = 5
batch_size = configs['BATCH_SIZE']
batch_size_eval = configs['BATCH_SIZE_EVAL']
learning_rate = configs['LEARNING_RATE']
is_profiling = configs['IS_PROFILING']
shuffle = configs['SHUFFLE']
split_random = configs['SPLIT_RANDOM']
lrdecay = configs['LRDECAY']
n_train = configs['N_TRAIN']
n_eval = configs['N_EVAL']
patience = configs['PATIENCE']
factor = configs['FACTOR']
parallel_mode = "NONE"
print("Loading data... ")
data_path = configs['DATA_PATH']
ds_train, edge_index, batch, ds_test, eval_edge_index, eval_batch, num_type = create_training_dataset(
config={
"path": data_path,
"batch_size": batch_size,
"batch_size_eval": batch_size_eval,
"n_train": n_train,
"n_val": n_eval,
"split_random": split_random,
"shuffle": shuffle
},
dtype=ms.float32,
pred_force=False,
parallel_mode=parallel_mode
)
Loading data...
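Optionally, you can peek at one training batch to confirm its shapes and dtypes. The sketch below assumes ds_train yields dictionaries with 'x', 'pos' and 'energy' fields, which are the same fields the _unpack helper reads in the training cell further down.
[ ]:
# Optional: inspect one training batch to check tensor shapes and dtypes.
sample = next(ds_train.create_dict_iterator())
for name in ("x", "pos", "energy"):
    print(name, sample[name].shape, sample[name].dtype)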
Model Construction
The Allegro model can be imported from the mindchemistry library; the embedding and potential-energy prediction modules can be imported from src.
[6]:
def build(num_type, configs):
""" Build Potential model
Args:
num_atom (int): number of atoms
Returns:
net (Potential): Potential model
"""
literal_hidden_dims = 'hidden_dims'
literal_activation = 'activation'
literal_weight_init = 'weight_init'
literal_uniform = 'uniform'
emb = AllegroEmbedding(
num_type=num_type,
cutoff=configs['CUTOFF']
)
model = Allegro(
l_max=configs['L_MAX'],
irreps_in={
"pos": "1x1o",
"edge_index": None,
"node_attrs": f"{num_type}x0e",
"node_features": f"{num_type}x0e",
"edge_embedding": f"{configs['NUM_BASIS']}x0e"
},
avg_num_neighbor=configs['AVG_NUM_NEIGHBOR'],
num_layers=configs['NUM_LAYERS'],
env_embed_multi=configs['ENV_EMBED_MULTI'],
two_body_kwargs={
literal_hidden_dims: configs['two_body_latent_mlp_latent_dimensions'],
literal_activation: 'silu',
literal_weight_init: literal_uniform
},
latent_kwargs={
literal_hidden_dims: configs['latent_mlp_latent_dimensions'],
literal_activation: 'silu',
literal_weight_init: literal_uniform
},
env_embed_kwargs={
literal_hidden_dims: configs['env_embed_mlp_latent_dimensions'],
literal_activation: None,
literal_weight_init: literal_uniform
},
enable_mix_precision=configs['enable_mix_precision'],
)
net = Potential(
embedding=emb,
model=model,
avg_num_neighbor=configs['AVG_NUM_NEIGHBOR'],
edge_eng_mlp_latent_dimensions=configs['edge_eng_mlp_latent_dimensions']
)
return net
[7]:
print("Initializing model... ")
model = build(num_type, configs)
Initializing model...
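As an optional check, you can report the model size before training. The sketch below assumes every entry of model.trainable_params() is a mindspore Parameter, whose size property gives its number of elements.
[ ]:
# Optional: report the number of trainable parameters of the Potential model.
num_params = sum(p.size for p in model.trainable_params())
print(f"trainable parameters: {num_params}")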
Loss Function
Allegro uses the mean squared error (MSE) as the training loss and the mean absolute error (MAE) as the evaluation metric.
[8]:
loss_fn = nn.MSELoss()
metric_fn = nn.MAELoss()
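For reference, here is a quick illustration of the two criteria on dummy tensors (not part of the training pipeline); both use the default 'mean' reduction.
[ ]:
# Quick illustration on dummy data:
# MSE = ((0.5)**2 + (1.0)**2) / 2 = 0.625, MAE = (0.5 + 1.0) / 2 = 0.75
example_pred = ms.Tensor([[1.0], [2.0]], ms.float32)
example_target = ms.Tensor([[1.5], [1.0]], ms.float32)
print(loss_fn(example_pred, example_target))
print(metric_fn(example_pred, example_target))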
Optimizer
The Adam optimizer is used, and the learning rate is adjusted with the ReduceLROnPlateau scheduler.
[9]:
optimizer = optim.Adam(params=model.trainable_params(), lr=learning_rate)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=factor, patience=patience)
Model Training
In this tutorial, we define custom train_step and test_step functions and train the model with them.
[10]:
# 1. Define forward function
def forward(x, pos, edge_index, batch, batch_size, energy):
pred = model(x, pos, edge_index, batch, batch_size)
loss = loss_fn(pred, energy)
if batch_size != 0:
square_atom_num = (x.shape[0] / batch_size) ** 2
else:
raise ValueError("batch_size should not be zero")
if square_atom_num != 0:
loss = loss / square_atom_num
else:
raise ValueError("square_atom_num should not be zero")
return loss
# 2. Get gradient function
backward = ms.value_and_grad(forward, None, optimizer.parameters)
# 3. Define function of one-step training and validation
@ms.jit
def train_step(x, pos, edge_index, batch, batch_size, energy):
loss_, grads_ = backward(x, pos, edge_index, batch, batch_size, energy)
optimizer(grads_)
return loss_
@ms.jit
def test_step(x, pos, edge_index, batch, batch_size):
return model(x, pos, edge_index, batch, batch_size)
def _unpack(data):
return (data['x'], data['pos']), data['energy']
def train_epoch(model, trainset, edge_index, batch, batch_size, loss_train: list):
size = trainset.get_dataset_size()
model.set_train()
total_train_loss = 0
loss_train_epoch = []
ti = time.time()
for current, data_dict in enumerate(trainset.create_dict_iterator()):
inputs, label = _unpack(data_dict)
loss = train_step(inputs[0], inputs[1], edge_index, batch, batch_size, label)
# AtomWise
loss = loss.asnumpy()
loss_train_epoch.append(loss)
if current % 10 == 0:
# pylint: disable=W1203
print(f"loss: {loss:.16f} [{current:>3d}/{size:>3d}]")
total_train_loss += loss
loss_train.append(loss_train_epoch)
if size != 0:
loss_train_avg = total_train_loss / size
else:
raise ValueError("size should not be zero")
t_now = time.time()
print('train loss: %.16f, time gap: %.4f' %(loss_train_avg, (t_now - ti)))
def test(model, dataset, edge_index, batch, batch_size, loss_fn, loss_eval: list, metric_fn, metric_list: list):
num_batches = dataset.get_dataset_size()
model.set_train(False)
test_loss = 0
metric = 0
for _, data_dict in enumerate(dataset.create_dict_iterator()):
inputs, label = _unpack(data_dict)
if batch_size != 0:
atom_num = inputs[0].shape[0] / batch_size
else:
raise ValueError("batch_size should not be zero")
square_atom_num = atom_num ** 2
pred = test_step(inputs[0], inputs[1], edge_index, batch, batch_size)
if square_atom_num != 0:
test_loss += loss_fn(pred, label).asnumpy() / square_atom_num
else:
raise ValueError("square_atom_num should not be zero")
if atom_num != 0:
metric += metric_fn(pred, label).asnumpy() / atom_num
else:
raise ValueError("atom_num should not be zero")
test_loss /= num_batches
metric /= num_batches
# AtomWise
loss_eval.append(test_loss)
metric_list.append(metric)
print("Test: mse loss: %.16f" %test_loss)
print("Test: mae metric: %.16f" %metric)
return test_loss
# == Training ==
if is_profiling:
print("Initializing profiler... ")
profiler = ms.Profiler(output_path="dump_output" + "/profiler_data", profile_memory=True)
print("Initializing train... ")
print("seed is: %d" %ms.get_seed())
loss_eval = []
loss_train = []
metric_list = []
for t in range(n_epoch):
print("Epoch %d\n-------------------------------" %(t + 1))
train_epoch(model, ds_train, edge_index, batch, batch_size, loss_train)
test_loss = test(
model, ds_test, eval_edge_index, eval_batch, batch_size_eval, loss_fn, loss_eval, metric_fn, metric_list
)
if lrdecay:
lr_scheduler.step(test_loss)
last_lr = optimizer.param_groups[0].get('lr').value()
print("lr: %.10f\n" %last_lr)
if (t + 1) % 50 == 0:
ms.save_checkpoint(model, "./model.ckpt")
if is_profiling:
profiler.analyse()
print("Training Done!")
Initializing train...
seed is: 123
Epoch 1
-------------------------------
loss: 468775552.0000000000000000 [ 0/190]
loss: 468785216.0000000000000000 [ 10/190]
loss: 468774752.0000000000000000 [ 20/190]
loss: 468760512.0000000000000000 [ 30/190]
loss: 468342560.0000000000000000 [ 40/190]
loss: 414531872.0000000000000000 [ 50/190]
loss: 435014.9062500000000000 [ 60/190]
loss: 132964368.0000000000000000 [ 70/190]
loss: 82096352.0000000000000000 [ 80/190]
loss: 12417458.0000000000000000 [ 90/190]
loss: 202487.4687500000000000 [100/190]
loss: 300066.6875000000000000 [110/190]
loss: 468295.9375000000000000 [120/190]
loss: 1230706.0000000000000000 [130/190]
loss: 487508.2812500000000000 [140/190]
loss: 242425.6406250000000000 [150/190]
loss: 841241.0000000000000000 [160/190]
loss: 84912.1328125000000000 [170/190]
loss: 1272812.5000000000000000 [180/190]
train loss: 139695450.3818462193012238, time gap: 101.1313
Test: mse loss: 328262.9555555555853061
Test: mae metric: 463.6971659342447083
lr: 0.0020000001
Epoch 2
-------------------------------
loss: 469823.8750000000000000 [ 0/190]
loss: 650901.1875000000000000 [ 10/190]
loss: 183339.9687500000000000 [ 20/190]
loss: 256283.4218750000000000 [ 30/190]
loss: 335927.1875000000000000 [ 40/190]
loss: 913293.2500000000000000 [ 50/190]
loss: 1257833.7500000000000000 [ 60/190]
loss: 630779.2500000000000000 [ 70/190]
loss: 1652336.1250000000000000 [ 80/190]
loss: 155349.2500000000000000 [ 90/190]
loss: 183506.0468750000000000 [100/190]
loss: 322167.0000000000000000 [110/190]
loss: 738248.0000000000000000 [120/190]
loss: 628022.9375000000000000 [130/190]
loss: 693525.0000000000000000 [140/190]
loss: 237971.7812500000000000 [150/190]
loss: 728099.2500000000000000 [160/190]
loss: 50060.8867187500000000 [170/190]
loss: 1229544.0000000000000000 [180/190]
train loss: 441576.3795847039436921, time gap: 26.0444
Test: mse loss: 366946.9187499999534339
Test: mae metric: 493.0364359537759924
lr: 0.0020000001
Epoch 3
-------------------------------
loss: 522800.6250000000000000 [ 0/190]
loss: 751694.5000000000000000 [ 10/190]
loss: 187226.3750000000000000 [ 20/190]
loss: 240447.8906250000000000 [ 30/190]
loss: 302177.7812500000000000 [ 40/190]
loss: 834946.6875000000000000 [ 50/190]
loss: 1170818.2500000000000000 [ 60/190]
loss: 596591.8750000000000000 [ 70/190]
loss: 1559648.0000000000000000 [ 80/190]
loss: 144896.1718750000000000 [ 90/190]
loss: 171495.7656250000000000 [100/190]
loss: 302823.1250000000000000 [110/190]
loss: 681209.3750000000000000 [120/190]
loss: 594635.8750000000000000 [130/190]
loss: 648062.7500000000000000 [140/190]
loss: 221139.5312500000000000 [150/190]
loss: 684927.0000000000000000 [160/190]
loss: 57762.1718750000000000 [170/190]
loss: 1197153.3750000000000000 [180/190]
train loss: 414760.5352384868310764, time gap: 25.4267
Test: mse loss: 337391.1312500000349246
Test: mae metric: 473.0654032389323334
lr: 0.0020000001
Epoch 4
-------------------------------
loss: 479449.6250000000000000 [ 0/190]
loss: 658094.1875000000000000 [ 10/190]
loss: 167701.1875000000000000 [ 20/190]
loss: 218166.0000000000000000 [ 30/190]
loss: 274594.4375000000000000 [ 40/190]
loss: 752581.0000000000000000 [ 50/190]
loss: 1039454.4375000000000000 [ 60/190]
loss: 581997.6250000000000000 [ 70/190]
loss: 1481623.0000000000000000 [ 80/190]
loss: 131388.5000000000000000 [ 90/190]
loss: 159510.3593750000000000 [100/190]
loss: 284669.9687500000000000 [110/190]
loss: 635406.0625000000000000 [120/190]
loss: 547961.0000000000000000 [130/190]
loss: 594799.6875000000000000 [140/190]
loss: 206942.5937500000000000 [150/190]
loss: 635930.1250000000000000 [160/190]
loss: 53651.5429687500000000 [170/190]
loss: 1120740.7500000000000000 [180/190]
train loss: 384702.7380550986854360, time gap: 25.3646
Test: mse loss: 312383.8472222221898846
Test: mae metric: 455.6222493489584053
lr: 0.0020000001
Epoch 5
-------------------------------
loss: 442066.5625000000000000 [ 0/190]
loss: 610366.8750000000000000 [ 10/190]
loss: 154912.0781250000000000 [ 20/190]
loss: 199916.3125000000000000 [ 30/190]
loss: 253701.6875000000000000 [ 40/190]
loss: 695447.4375000000000000 [ 50/190]
loss: 973856.6875000000000000 [ 60/190]
loss: 529174.5625000000000000 [ 70/190]
loss: 1359184.8750000000000000 [ 80/190]
loss: 120610.0546875000000000 [ 90/190]
loss: 145533.5312500000000000 [100/190]
loss: 253629.2500000000000000 [110/190]
loss: 602776.2500000000000000 [120/190]
loss: 479350.7187500000000000 [130/190]
loss: 522066.7812500000000000 [140/190]
loss: 197747.9687500000000000 [150/190]
loss: 585378.6875000000000000 [160/190]
loss: 39960.7265625000000000 [170/190]
loss: 1010730.2500000000000000 [180/190]
train loss: 355614.9552425986621529, time gap: 26.2478
Test: mse loss: 291521.1986111110891216
Test: mae metric: 440.5232421874999886
lr: 0.0020000001
Training Done!
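Note that the training loop above writes ./model.ckpt only every 50 epochs, so with n_epoch = 5 no checkpoint is saved. If you want the inference section below to load your own training result, save a checkpoint explicitly first:
[ ]:
# Save the trained weights so that the inference section can load ./model.ckpt.
ms.save_checkpoint(model, "./model.ckpt")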
Model Inference
A custom pred function runs model inference and returns the prediction results.
[11]:
def pred(configs, dtype=ms.float32):
"""Pred the model on the eval dataset."""
batch_size_eval = configs['BATCH_SIZE_EVAL']
n_eval = configs['N_EVAL']
print("Loading data... ")
data_path = configs['DATA_PATH']
_, _, _, ds_test, eval_edge_index, eval_batch, num_type = create_test_dataset(
config={
"path": data_path,
"batch_size_eval": batch_size_eval,
"n_val": n_eval,
},
dtype=dtype,
pred_force=False
)
# Define model
print("Initializing model... ")
model = build(num_type, configs)
# load checkpoint
ckpt_file = './model.ckpt'
ms.load_checkpoint(ckpt_file, model)
# Instantiate loss function and metric function
loss_fn = nn.MSELoss()
metric_fn = nn.MAELoss()
# == Evaluation ==
print("Initializing Evaluation... ")
print("seed is: %d" %ms.get_seed())
pred_list, test_loss, metric = evaluation(
model, ds_test, eval_edge_index, eval_batch, batch_size_eval, loss_fn, metric_fn
)
print("prediction saved")
print("Test: mse loss: %.16f" %test_loss)
print("Test: mae metric: %.16f" %metric)
print("Predict Done!")
return pred_list, test_loss, metric
def evaluation(model, dataset, edge_index, batch, batch_size, loss_fn, metric_fn):
"""evaluation"""
num_batches = dataset.get_dataset_size()
model.set_train(False)
test_loss = 0
metric = 0
pred_list = []
for _, data_dict in enumerate(dataset.create_dict_iterator()):
inputs, label = _unpack(data_dict)
if batch_size != 0:
atom_num = inputs[0].shape[0] / batch_size
else:
raise ValueError("batch_size should not be zero")
square_atom_num = atom_num ** 2
prediction = model(inputs[0], inputs[1], edge_index, batch, batch_size)
pred_list.append(prediction.asnumpy())
if square_atom_num != 0:
test_loss += loss_fn(prediction, label).asnumpy() / square_atom_num
else:
raise ValueError("square_atom_num should not be zero")
if atom_num != 0:
metric += metric_fn(prediction, label).asnumpy() / atom_num
else:
raise ValueError("atom_num should not be zero")
test_loss /= num_batches
metric /= num_batches
return pred_list, test_loss, metric
[14]:
pred(configs)
Loading data...
Initializing model...
Initializing Evaluation...
seed is: 123
prediction saved
Test: mse loss: 901.1434895833332348
Test: mae metric: 29.0822919209798201
Predict Done!
[14]:
([array([[-259531.56],
[-259377.28],
[-259534.83],
[-259243.62],
[-259541.62]], dtype=float32),
array([[-259516.4 ],
[-259519.81],
[-259545.69],
[-259428.45],
[-259527.28]], dtype=float32),
array([[-259508.94],
[-259521.22],
[-259533.28],
[-259465.56],
[-259523.88]], dtype=float32),
array([[-259533.56],
[-259303.9 ],
[-259509.53],
[-259369.22],
[-259514.4 ]], dtype=float32),
array([[-259368.25],
[-259487.45],
[-259545.94],
[-259379.47],
[-259494.19]], dtype=float32),
array([[-259533.64],
[-259453. ],
[-259542.69],
[-259451.9 ],
[-259213.11]], dtype=float32),
array([[-259562.5 ],
[-259531.6 ],
[-259526.5 ],
[-259530.3 ],
[-259389.12]], dtype=float32),
array([[-259515.03],
[-259530.69],
[-259476.9 ],
[-259267.77],
[-259535.11]], dtype=float32),
array([[-259548.77],
[-259530.8 ],
[-259401.7 ],
[-259542.12],
[-259419.86]], dtype=float32),
array([[-259386.81],
[-259291.75],
[-259419.61],
[-259488.25],
[-259334.34]], dtype=float32)],
901.1434895833332,
29.08229192097982)
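The call above returns the per-batch predictions together with the final MSE loss and MAE metric. A small usage sketch for collecting the predictions into a single array (note that calling pred again re-runs the evaluation):
[ ]:
# Gather the per-batch predictions into one array of predicted total energies.
pred_list, test_loss, metric = pred(configs)
all_pred = np.concatenate(pred_list, axis=0)
print(all_pred.shape, all_pred.min(), all_pred.max())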