Quick Start
Overview
This tutorial takes the prediction of interatomic potentials with Allegro as an example.
Allegro is a state-of-the-art model built on equivariant graph neural networks; the related paper was published in Nature Communications. This case study demonstrates the effectiveness of Allegro in molecular potential energy prediction, a task of high practical value.
This tutorial introduces the research background and technical path of Allegro, and demonstrates how to train and perform fast inference with MindSpore Chemistry. More information can be found in the paper.
Technology Path
MindSpore Chemistry solves the problem as follows:
Data Construction.
Model Construction.
Loss Function.
Optimizer.
Model Training.
Model Prediction.
[1]:
import time
import random
import mindspore as ms
import numpy as np
from mindspore import nn
from mindspore.experimental import optim
The following src package can be downloaded from allegro/src.
[13]:
from mindchemistry.cell.allegro import Allegro
from mindchemistry.utils.load_config import load_yaml_config_from_path
from src.allegro_embedding import AllegroEmbedding
from src.dataset import create_training_dataset, create_test_dataset
from src.potential import Potential
[3]:
ms.set_seed(123)
ms.dataset.config.set_seed(1)
np.random.seed(1)
random.seed(1)
The model, data, and optimizer parameters can be read from the configuration file.
[4]:
configs = load_yaml_config_from_path("rmd.yaml")
ms.set_context(mode=ms.GRAPH_MODE)
ms.set_device("Ascend", 0)
Data Construction
This tutorial uses the revised MD17 dataset (rMD17). Download the dataset to the ./dataset/rmd17/npz_data/ directory. The default configuration file reads the dataset from dataset/rmd17/npz_data/rmd17_uracil.npz.
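Before training, you can quickly check that the dataset file is in place and inspect its contents. The following is a minimal sketch; it only assumes that the file is a standard NumPy .npz archive at the default path above.

import numpy as np

# Inspect the rMD17 uracil archive at the default path used by the config
data = np.load("dataset/rmd17/npz_data/rmd17_uracil.npz")
print(list(data.keys()))             # field names stored in the archive
for key in data.keys():
    print(key, data[key].shape)      # array shapes for each field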
[5]:
n_epoch = 5
batch_size = configs['BATCH_SIZE']
batch_size_eval = configs['BATCH_SIZE_EVAL']
learning_rate = configs['LEARNING_RATE']
is_profiling = configs['IS_PROFILING']
shuffle = configs['SHUFFLE']
split_random = configs['SPLIT_RANDOM']
lrdecay = configs['LRDECAY']
n_train = configs['N_TRAIN']
n_eval = configs['N_EVAL']
patience = configs['PATIENCE']
factor = configs['FACTOR']
parallel_mode = "NONE"
print("Loading data... ")
data_path = configs['DATA_PATH']
ds_train, edge_index, batch, ds_test, eval_edge_index, eval_batch, num_type = create_training_dataset(
    config={
        "path": data_path,
        "batch_size": batch_size,
        "batch_size_eval": batch_size_eval,
        "n_train": n_train,
        "n_val": n_eval,
        "split_random": split_random,
        "shuffle": shuffle
    },
    dtype=ms.float32,
    pred_force=False,
    parallel_mode=parallel_mode
)
Loading data...
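As a quick sanity check, you can confirm how many training and evaluation batches were produced. This short sketch only uses the get_dataset_size interface that the training loop below also relies on.

# Number of batches per epoch for training and evaluation
print("train batches:", ds_train.get_dataset_size())
print("eval batches:", ds_test.get_dataset_size())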
Model Construction
The Allegro model can be imported from the mindchemistry library, while the embedding and potential energy prediction modules can be imported from src.
[6]:
def build(num_type, configs):
    """Build the Potential model.

    Args:
        num_type (int): number of atom types.
        configs (dict): model configuration.

    Returns:
        net (Potential): Potential model.
    """
    literal_hidden_dims = 'hidden_dims'
    literal_activation = 'activation'
    literal_weight_init = 'weight_init'
    literal_uniform = 'uniform'
    emb = AllegroEmbedding(
        num_type=num_type,
        cutoff=configs['CUTOFF']
    )
    model = Allegro(
        l_max=configs['L_MAX'],
        irreps_in={
            "pos": "1x1o",
            "edge_index": None,
            "node_attrs": f"{num_type}x0e",
            "node_features": f"{num_type}x0e",
            "edge_embedding": f"{configs['NUM_BASIS']}x0e"
        },
        avg_num_neighbor=configs['AVG_NUM_NEIGHBOR'],
        num_layers=configs['NUM_LAYERS'],
        env_embed_multi=configs['ENV_EMBED_MULTI'],
        two_body_kwargs={
            literal_hidden_dims: configs['two_body_latent_mlp_latent_dimensions'],
            literal_activation: 'silu',
            literal_weight_init: literal_uniform
        },
        latent_kwargs={
            literal_hidden_dims: configs['latent_mlp_latent_dimensions'],
            literal_activation: 'silu',
            literal_weight_init: literal_uniform
        },
        env_embed_kwargs={
            literal_hidden_dims: configs['env_embed_mlp_latent_dimensions'],
            literal_activation: None,
            literal_weight_init: literal_uniform
        },
        enable_mix_precision=configs['enable_mix_precision'],
    )
    net = Potential(
        embedding=emb,
        model=model,
        avg_num_neighbor=configs['AVG_NUM_NEIGHBOR'],
        edge_eng_mlp_latent_dimensions=configs['edge_eng_mlp_latent_dimensions']
    )
    return net
[7]:
print("Initializing model... ")
model = build(num_type, configs)
Initializing model...
Loss Function
Allegro uses mean squared error (MSE) as the training loss and mean absolute error (MAE) as the evaluation metric.
[8]:
loss_fn = nn.MSELoss()
metric_fn = nn.MAELoss()
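As a quick illustration of how the two criteria behave, the following sketch evaluates them on small hand-made tensors; the inputs are for demonstration only.

# Toy tensors: two predictions with errors of 0.5 and 1.0
pred_demo = ms.Tensor(np.array([[1.0], [2.0]]), ms.float32)
label_demo = ms.Tensor(np.array([[1.5], [1.0]]), ms.float32)
print(loss_fn(pred_demo, label_demo))    # MSE: mean of squared errors, (0.25 + 1.0) / 2
print(metric_fn(pred_demo, label_demo))  # MAE: mean of absolute errors, (0.5 + 1.0) / 2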
Optimizer
The Adam optimizer is used, and the learning rate update strategy is ReduceLROnPlateau.
[9]:
optimizer = optim.Adam(params=model.trainable_params(), lr=learning_rate)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=factor, patience=patience)
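ReduceLROnPlateau lowers the learning rate by factor once the monitored validation loss has stopped improving for patience consecutive checks. The loop below is a standalone sketch of that behaviour with a made-up, non-improving metric; it builds a throwaway optimizer and scheduler so the real training schedule is untouched, and reuses the lr access pattern from the training loop later in this tutorial.

# Throwaway optimizer/scheduler pair, used only to demonstrate the lr decay behaviour
demo_opt = optim.Adam(params=model.trainable_params(), lr=learning_rate)
demo_sched = optim.lr_scheduler.ReduceLROnPlateau(demo_opt, 'min', factor=factor, patience=patience)
for step in range(patience + 3):
    demo_sched.step(1.0)  # made-up validation loss that never improves
    print(step, demo_opt.param_groups[0].get('lr').value())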
Model Training
In this tutorial, we customize the train_step and test_step functions and then run model training.
[10]:
# 1. Define forward function
def forward(x, pos, edge_index, batch, batch_size, energy):
    pred = model(x, pos, edge_index, batch, batch_size)
    loss = loss_fn(pred, energy)
    if batch_size != 0:
        square_atom_num = (x.shape[0] / batch_size) ** 2
    else:
        raise ValueError("batch_size should not be zero")
    if square_atom_num != 0:
        loss = loss / square_atom_num
    else:
        raise ValueError("square_atom_num should not be zero")
    return loss
# 2. Get gradient function
backward = ms.value_and_grad(forward, None, optimizer.parameters)
# 3. Define function of one-step training and validation
@ms.jit
def train_step(x, pos, edge_index, batch, batch_size, energy):
    loss_, grads_ = backward(x, pos, edge_index, batch, batch_size, energy)
    optimizer(grads_)
    return loss_
@ms.jit
def test_step(x, pos, edge_index, batch, batch_size):
    return model(x, pos, edge_index, batch, batch_size)

def _unpack(data):
    return (data['x'], data['pos']), data['energy']
def train_epoch(model, trainset, edge_index, batch, batch_size, loss_train: list):
    size = trainset.get_dataset_size()
    model.set_train()
    total_train_loss = 0
    loss_train_epoch = []
    ti = time.time()
    for current, data_dict in enumerate(trainset.create_dict_iterator()):
        inputs, label = _unpack(data_dict)
        loss = train_step(inputs[0], inputs[1], edge_index, batch, batch_size, label)
        # AtomWise
        loss = loss.asnumpy()
        loss_train_epoch.append(loss)
        if current % 10 == 0:
            # pylint: disable=W1203
            print(f"loss: {loss:.16f} [{current:>3d}/{size:>3d}]")
        total_train_loss += loss
    loss_train.append(loss_train_epoch)
    if size != 0:
        loss_train_avg = total_train_loss / size
    else:
        raise ValueError("size should not be zero")
    t_now = time.time()
    print('train loss: %.16f, time gap: %.4f' % (loss_train_avg, (t_now - ti)))
def test(model, dataset, edge_index, batch, batch_size, loss_fn, loss_eval: list, metric_fn, metric_list: list):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    test_loss = 0
    metric = 0
    for _, data_dict in enumerate(dataset.create_dict_iterator()):
        inputs, label = _unpack(data_dict)
        if batch_size != 0:
            atom_num = inputs[0].shape[0] / batch_size
        else:
            raise ValueError("batch_size should not be zero")
        square_atom_num = atom_num ** 2
        pred = test_step(inputs[0], inputs[1], edge_index, batch, batch_size)
        if square_atom_num != 0:
            test_loss += loss_fn(pred, label).asnumpy() / square_atom_num
        else:
            raise ValueError("square_atom_num should not be zero")
        if atom_num != 0:
            metric += metric_fn(pred, label).asnumpy() / atom_num
        else:
            raise ValueError("atom_num should not be zero")
    test_loss /= num_batches
    metric /= num_batches
    # AtomWise
    loss_eval.append(test_loss)
    metric_list.append(metric)
    print("Test: mse loss: %.16f" % test_loss)
    print("Test: mae metric: %.16f" % metric)
    return test_loss
# == Training ==
if is_profiling:
    print("Initializing profiler... ")
    profiler = ms.Profiler(output_path="dump_output" + "/profiler_data", profile_memory=True)
print("Initializing train... ")
print("seed is: %d" % ms.get_seed())
loss_eval = []
loss_train = []
metric_list = []
for t in range(n_epoch):
    print("Epoch %d\n-------------------------------" % (t + 1))
    train_epoch(model, ds_train, edge_index, batch, batch_size, loss_train)
    test_loss = test(
        model, ds_test, eval_edge_index, eval_batch, batch_size_eval, loss_fn, loss_eval, metric_fn, metric_list
    )
    if lrdecay:
        lr_scheduler.step(test_loss)
    last_lr = optimizer.param_groups[0].get('lr').value()
    print("lr: %.10f\n" % last_lr)
    if (t + 1) % 50 == 0:
        ms.save_checkpoint(model, "./model.ckpt")
if is_profiling:
    profiler.analyse()
print("Training Done!")
Initializing train...
seed is: 123
Epoch 1
-------------------------------
loss: 468775552.0000000000000000 [ 0/190]
loss: 468785216.0000000000000000 [ 10/190]
loss: 468774752.0000000000000000 [ 20/190]
loss: 468760512.0000000000000000 [ 30/190]
loss: 468342560.0000000000000000 [ 40/190]
loss: 414531872.0000000000000000 [ 50/190]
loss: 435014.9062500000000000 [ 60/190]
loss: 132964368.0000000000000000 [ 70/190]
loss: 82096352.0000000000000000 [ 80/190]
loss: 12417458.0000000000000000 [ 90/190]
loss: 202487.4687500000000000 [100/190]
loss: 300066.6875000000000000 [110/190]
loss: 468295.9375000000000000 [120/190]
loss: 1230706.0000000000000000 [130/190]
loss: 487508.2812500000000000 [140/190]
loss: 242425.6406250000000000 [150/190]
loss: 841241.0000000000000000 [160/190]
loss: 84912.1328125000000000 [170/190]
loss: 1272812.5000000000000000 [180/190]
train loss: 139695450.3818462193012238, time gap: 101.1313
Test: mse loss: 328262.9555555555853061
Test: mae metric: 463.6971659342447083
lr: 0.0020000001
Epoch 2
-------------------------------
loss: 469823.8750000000000000 [ 0/190]
loss: 650901.1875000000000000 [ 10/190]
loss: 183339.9687500000000000 [ 20/190]
loss: 256283.4218750000000000 [ 30/190]
loss: 335927.1875000000000000 [ 40/190]
loss: 913293.2500000000000000 [ 50/190]
loss: 1257833.7500000000000000 [ 60/190]
loss: 630779.2500000000000000 [ 70/190]
loss: 1652336.1250000000000000 [ 80/190]
loss: 155349.2500000000000000 [ 90/190]
loss: 183506.0468750000000000 [100/190]
loss: 322167.0000000000000000 [110/190]
loss: 738248.0000000000000000 [120/190]
loss: 628022.9375000000000000 [130/190]
loss: 693525.0000000000000000 [140/190]
loss: 237971.7812500000000000 [150/190]
loss: 728099.2500000000000000 [160/190]
loss: 50060.8867187500000000 [170/190]
loss: 1229544.0000000000000000 [180/190]
train loss: 441576.3795847039436921, time gap: 26.0444
Test: mse loss: 366946.9187499999534339
Test: mae metric: 493.0364359537759924
lr: 0.0020000001
Epoch 3
-------------------------------
loss: 522800.6250000000000000 [ 0/190]
loss: 751694.5000000000000000 [ 10/190]
loss: 187226.3750000000000000 [ 20/190]
loss: 240447.8906250000000000 [ 30/190]
loss: 302177.7812500000000000 [ 40/190]
loss: 834946.6875000000000000 [ 50/190]
loss: 1170818.2500000000000000 [ 60/190]
loss: 596591.8750000000000000 [ 70/190]
loss: 1559648.0000000000000000 [ 80/190]
loss: 144896.1718750000000000 [ 90/190]
loss: 171495.7656250000000000 [100/190]
loss: 302823.1250000000000000 [110/190]
loss: 681209.3750000000000000 [120/190]
loss: 594635.8750000000000000 [130/190]
loss: 648062.7500000000000000 [140/190]
loss: 221139.5312500000000000 [150/190]
loss: 684927.0000000000000000 [160/190]
loss: 57762.1718750000000000 [170/190]
loss: 1197153.3750000000000000 [180/190]
train loss: 414760.5352384868310764, time gap: 25.4267
Test: mse loss: 337391.1312500000349246
Test: mae metric: 473.0654032389323334
lr: 0.0020000001
Epoch 4
-------------------------------
loss: 479449.6250000000000000 [ 0/190]
loss: 658094.1875000000000000 [ 10/190]
loss: 167701.1875000000000000 [ 20/190]
loss: 218166.0000000000000000 [ 30/190]
loss: 274594.4375000000000000 [ 40/190]
loss: 752581.0000000000000000 [ 50/190]
loss: 1039454.4375000000000000 [ 60/190]
loss: 581997.6250000000000000 [ 70/190]
loss: 1481623.0000000000000000 [ 80/190]
loss: 131388.5000000000000000 [ 90/190]
loss: 159510.3593750000000000 [100/190]
loss: 284669.9687500000000000 [110/190]
loss: 635406.0625000000000000 [120/190]
loss: 547961.0000000000000000 [130/190]
loss: 594799.6875000000000000 [140/190]
loss: 206942.5937500000000000 [150/190]
loss: 635930.1250000000000000 [160/190]
loss: 53651.5429687500000000 [170/190]
loss: 1120740.7500000000000000 [180/190]
train loss: 384702.7380550986854360, time gap: 25.3646
Test: mse loss: 312383.8472222221898846
Test: mae metric: 455.6222493489584053
lr: 0.0020000001
Epoch 5
-------------------------------
loss: 442066.5625000000000000 [ 0/190]
loss: 610366.8750000000000000 [ 10/190]
loss: 154912.0781250000000000 [ 20/190]
loss: 199916.3125000000000000 [ 30/190]
loss: 253701.6875000000000000 [ 40/190]
loss: 695447.4375000000000000 [ 50/190]
loss: 973856.6875000000000000 [ 60/190]
loss: 529174.5625000000000000 [ 70/190]
loss: 1359184.8750000000000000 [ 80/190]
loss: 120610.0546875000000000 [ 90/190]
loss: 145533.5312500000000000 [100/190]
loss: 253629.2500000000000000 [110/190]
loss: 602776.2500000000000000 [120/190]
loss: 479350.7187500000000000 [130/190]
loss: 522066.7812500000000000 [140/190]
loss: 197747.9687500000000000 [150/190]
loss: 585378.6875000000000000 [160/190]
loss: 39960.7265625000000000 [170/190]
loss: 1010730.2500000000000000 [180/190]
train loss: 355614.9552425986621529, time gap: 26.2478
Test: mse loss: 291521.1986111110891216
Test: mae metric: 440.5232421874999886
lr: 0.0020000001
Training Done!
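Note that the training loop above saves a checkpoint only every 50 epochs, so a short run like this 5-epoch demonstration does not write ./model.ckpt by itself. If you want the prediction step below to load the weights you just trained, you can save them explicitly with the same save_checkpoint call used in the training loop.

# Optionally save the trained weights so the prediction step can load ./model.ckpt
ms.save_checkpoint(model, "./model.ckpt")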
Model Prediction
Define a custom pred function that runs model prediction and returns the prediction results.
[11]:
def pred(configs, dtype=ms.float32):
    """Run prediction on the evaluation dataset."""
    batch_size_eval = configs['BATCH_SIZE_EVAL']
    n_eval = configs['N_EVAL']
    print("Loading data... ")
    data_path = configs['DATA_PATH']
    _, _, _, ds_test, eval_edge_index, eval_batch, num_type = create_test_dataset(
        config={
            "path": data_path,
            "batch_size_eval": batch_size_eval,
            "n_val": n_eval,
        },
        dtype=dtype,
        pred_force=False
    )
    # Define model
    print("Initializing model... ")
    model = build(num_type, configs)
    # Load checkpoint
    ckpt_file = './model.ckpt'
    ms.load_checkpoint(ckpt_file, model)
    # Instantiate loss function and metric function
    loss_fn = nn.MSELoss()
    metric_fn = nn.MAELoss()
    # == Evaluation ==
    print("Initializing Evaluation... ")
    print("seed is: %d" % ms.get_seed())
    pred_list, test_loss, metric = evaluation(
        model, ds_test, eval_edge_index, eval_batch, batch_size_eval, loss_fn, metric_fn
    )
    print("prediction saved")
    print("Test: mse loss: %.16f" % test_loss)
    print("Test: mae metric: %.16f" % metric)
    print("Predict Done!")
    return pred_list, test_loss, metric
def evaluation(model, dataset, edge_index, batch, batch_size, loss_fn, metric_fn):
    """Evaluate the model and collect predictions."""
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    test_loss = 0
    metric = 0
    pred_list = []
    for _, data_dict in enumerate(dataset.create_dict_iterator()):
        inputs, label = _unpack(data_dict)
        if batch_size != 0:
            atom_num = inputs[0].shape[0] / batch_size
        else:
            raise ValueError("batch_size should not be zero")
        square_atom_num = atom_num ** 2
        prediction = model(inputs[0], inputs[1], edge_index, batch, batch_size)
        pred_list.append(prediction.asnumpy())
        if square_atom_num != 0:
            test_loss += loss_fn(prediction, label).asnumpy() / square_atom_num
        else:
            raise ValueError("square_atom_num should not be zero")
        if atom_num != 0:
            metric += metric_fn(prediction, label).asnumpy() / atom_num
        else:
            raise ValueError("atom_num should not be zero")
    test_loss /= num_batches
    metric /= num_batches
    return pred_list, test_loss, metric
[14]:
pred(configs)
Loading data...
Initializing model...
Initializing Evaluation...
seed is: 123
prediction saved
Test: mse loss: 901.1434895833332348
Test: mae metric: 29.0822919209798201
Predict Done!
[14]:
([array([[-259531.56],
[-259377.28],
[-259534.83],
[-259243.62],
[-259541.62]], dtype=float32),
array([[-259516.4 ],
[-259519.81],
[-259545.69],
[-259428.45],
[-259527.28]], dtype=float32),
array([[-259508.94],
[-259521.22],
[-259533.28],
[-259465.56],
[-259523.88]], dtype=float32),
array([[-259533.56],
[-259303.9 ],
[-259509.53],
[-259369.22],
[-259514.4 ]], dtype=float32),
array([[-259368.25],
[-259487.45],
[-259545.94],
[-259379.47],
[-259494.19]], dtype=float32),
array([[-259533.64],
[-259453. ],
[-259542.69],
[-259451.9 ],
[-259213.11]], dtype=float32),
array([[-259562.5 ],
[-259531.6 ],
[-259526.5 ],
[-259530.3 ],
[-259389.12]], dtype=float32),
array([[-259515.03],
[-259530.69],
[-259476.9 ],
[-259267.77],
[-259535.11]], dtype=float32),
array([[-259548.77],
[-259530.8 ],
[-259401.7 ],
[-259542.12],
[-259419.86]], dtype=float32),
array([[-259386.81],
[-259291.75],
[-259419.61],
[-259488.25],
[-259334.34]], dtype=float32)],
901.1434895833332,
29.08229192097982)
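The returned pred_list holds one NumPy array of predicted energies per evaluation batch, as shown above. The following is a minimal post-processing sketch, assuming the default return format; the output file name is only an example.

import numpy as np

pred_list, test_loss, metric = pred(configs)
# Flatten the per-batch arrays into a single 1-D array of predicted energies
energies = np.concatenate(pred_list, axis=0).squeeze(-1)
np.save("pred_energies.npy", energies)   # save predictions for later analysis
print(energies.shape, energies.mean())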