Quick Start

Overview

This tutorial uses the prediction of molecular potential energy with Allegro as an example.

Allegro is a state-of-the-art (SOTA) model built on an equivariant graph neural network; the corresponding paper was published in Nature Communications. This case study demonstrates Allegro's effectiveness for molecular potential energy prediction and its practical value.

This tutorial introduces the research background and technical path of Allegro, and shows how to train the model and run fast inference with MindSpore Chemistry. For more details, see the paper.

Technical Path

The workflow for solving the molecular potential energy prediction problem with MindSpore Chemistry is as follows:

  1. Create the dataset

  2. Build the model

  3. Define the loss function

  4. Define the optimizer

  5. Train the model

  6. Run inference with the model

[1]:
import time
import random

import mindspore as ms
import numpy as np
from mindspore import nn
from mindspore.experimental import optim

The src package used below can be downloaded from allegro/src.

[13]:
from mindchemistry.cell.allegro import Allegro
from mindchemistry.utils.load_config import load_yaml_config_from_path

from src.allegro_embedding import AllegroEmbedding
from src.dataset import create_training_dataset, create_test_dataset
from src.potential import Potential
[3]:
ms.set_seed(123)
ms.dataset.config.set_seed(1)
np.random.seed(1)
random.seed(1)

The parameters, optimizer settings, and data configuration used by the model are specified in the config file (rmd.yaml, loaded below).
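
For reference, the sketch below shows roughly what rmd.yaml might look like, limited to the keys read in this tutorial. The values are illustrative placeholders and may differ from the configuration shipped with the example.

# Illustrative sketch of rmd.yaml -- the keys are the ones read in this tutorial,
# the values are placeholders only.
DATA_PATH: "dataset/rmd17/npz_data/rmd17_uracil.npz"
BATCH_SIZE: 5
BATCH_SIZE_EVAL: 5
N_TRAIN: 950
N_EVAL: 50
LEARNING_RATE: 0.002
LRDECAY: True
PATIENCE: 25
FACTOR: 0.8
SHUFFLE: True
SPLIT_RANDOM: True
IS_PROFILING: False
CUTOFF: 4.0
L_MAX: 2
NUM_BASIS: 8
AVG_NUM_NEIGHBOR: 17.0
NUM_LAYERS: 2
ENV_EMBED_MULTI: 32
enable_mix_precision: False
two_body_latent_mlp_latent_dimensions: [64, 128, 256, 512]
latent_mlp_latent_dimensions: [512, 512]
env_embed_mlp_latent_dimensions: [128]
edge_eng_mlp_latent_dimensions: [128]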

[4]:
configs = load_yaml_config_from_path("rmd.yaml")
ms.set_context(mode=ms.GRAPH_MODE)
ms.set_device("Ascend", 0)

Creating the Dataset

Download the training and validation data from the Revised MD17 dataset (rMD17) into the ./dataset/rmd17/npz_data/ directory. The default configuration file reads the dataset from dataset/rmd17/npz_data/rmd17_uracil.npz.
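
Before training, you can optionally check that the file is in place and inspect which arrays it contains. This is a minimal sanity check, not part of the tutorial pipeline, and the array names depend on the rMD17 release you downloaded.

import os
import numpy as np

npz_path = "dataset/rmd17/npz_data/rmd17_uracil.npz"
assert os.path.exists(npz_path), f"dataset not found: {npz_path}"

data = np.load(npz_path)
print(data.files)  # names of the stored arrays, e.g. coordinates, energies, forces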

[5]:
n_epoch = 5
batch_size = configs['BATCH_SIZE']
batch_size_eval = configs['BATCH_SIZE_EVAL']
learning_rate = configs['LEARNING_RATE']
is_profiling = configs['IS_PROFILING']
shuffle = configs['SHUFFLE']
split_random = configs['SPLIT_RANDOM']
lrdecay = configs['LRDECAY']
n_train = configs['N_TRAIN']
n_eval = configs['N_EVAL']
patience = configs['PATIENCE']
factor = configs['FACTOR']
parallel_mode = "NONE"

print("Loading data...                ")
data_path = configs['DATA_PATH']
ds_train, edge_index, batch, ds_test, eval_edge_index, eval_batch, num_type = create_training_dataset(
    config={
        "path": data_path,
        "batch_size": batch_size,
        "batch_size_eval": batch_size_eval,
        "n_train": n_train,
        "n_val": n_eval,
        "split_random": split_random,
        "shuffle": shuffle
    },
    dtype=ms.float32,
    pred_force=False,
    parallel_mode=parallel_mode
)
Loading data...

Building the Model

The Allegro model can be imported from the mindchemistry package, and the embedding and potential energy prediction modules can be imported from src. The irreps_in strings passed to Allegro below follow e3nn-style irreducible-representation notation; for example, "1x1o" denotes a single odd-parity vector (the atomic positions).

[6]:
def build(num_type, configs):
    """ Build Potential model

    Args:
        num_type (int): number of atom types
        configs (dict): model configuration loaded from rmd.yaml

    Returns:
        net (Potential): Potential model
    """
    literal_hidden_dims = 'hidden_dims'
    literal_activation = 'activation'
    literal_weight_init = 'weight_init'
    literal_uniform = 'uniform'

    emb = AllegroEmbedding(
        num_type=num_type,
        cutoff=configs['CUTOFF']
    )

    model = Allegro(
        l_max=configs['L_MAX'],
        irreps_in={
            "pos": "1x1o",
            "edge_index": None,
            "node_attrs": f"{num_type}x0e",
            "node_features": f"{num_type}x0e",
            "edge_embedding": f"{configs['NUM_BASIS']}x0e"
        },
        avg_num_neighbor=configs['AVG_NUM_NEIGHBOR'],
        num_layers=configs['NUM_LAYERS'],
        env_embed_multi=configs['ENV_EMBED_MULTI'],
        two_body_kwargs={
            literal_hidden_dims: configs['two_body_latent_mlp_latent_dimensions'],
            literal_activation: 'silu',
            literal_weight_init: literal_uniform
        },
        latent_kwargs={
            literal_hidden_dims: configs['latent_mlp_latent_dimensions'],
            literal_activation: 'silu',
            literal_weight_init: literal_uniform
        },
        env_embed_kwargs={
            literal_hidden_dims: configs['env_embed_mlp_latent_dimensions'],
            literal_activation: None,
            literal_weight_init: literal_uniform
        },
        enable_mix_precision=configs['enable_mix_precision'],
    )

    net = Potential(
        embedding=emb,
        model=model,
        avg_num_neighbor=configs['AVG_NUM_NEIGHBOR'],
        edge_eng_mlp_latent_dimensions=configs['edge_eng_mlp_latent_dimensions']
    )

    return net
[7]:
print("Initializing model...              ")
model = build(num_type, configs)
Initializing model...

Loss Function

Allegro uses the mean squared error (MSE) as the training loss and the mean absolute error (MAE) as the evaluation metric.

[8]:
loss_fn = nn.MSELoss()
metric_fn = nn.MAELoss()
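
As a quick standalone illustration (not part of the training pipeline), both losses reduce their inputs to a scalar mean over all elements:

from mindspore import Tensor, nn
import mindspore as ms

example_pred = Tensor([[1.0], [2.0]], ms.float32)
example_label = Tensor([[1.5], [1.0]], ms.float32)
print(nn.MSELoss()(example_pred, example_label))  # ((0.5)**2 + (1.0)**2) / 2 = 0.625
print(nn.MAELoss()(example_pred, example_label))  # (0.5 + 1.0) / 2 = 0.75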

Optimizer

The Adam optimizer is used, and the learning rate is adjusted with the ReduceLROnPlateau schedule, which multiplies the learning rate by factor once the validation loss has stopped improving for patience epochs.

[9]:
optimizer = optim.Adam(params=model.trainable_params(), lr=learning_rate)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=factor, patience=patience)

Model Training

In this tutorial, we define custom train_step and test_step functions and then train the model. Because the total energy scales with the number of atoms in a molecule, the MSE loss in forward below is divided by the squared number of atoms per molecule, so it effectively measures the squared per-atom energy error; likewise, the MAE metric is divided by the number of atoms per molecule.

[10]:
# 1. Define forward function
def forward(x, pos, edge_index, batch, batch_size, energy):
    pred = model(x, pos, edge_index, batch, batch_size)
    loss = loss_fn(pred, energy)
    if batch_size != 0:
        square_atom_num = (x.shape[0] / batch_size) ** 2
    else:
        raise ValueError("batch_size should not be zero")
    if square_atom_num != 0:
        loss = loss / square_atom_num
    else:
        raise ValueError("square_atom_num should not be zero")
    return loss

# 2. Get gradient function
backward = ms.value_and_grad(forward, None, optimizer.parameters)

# 3. Define function of one-step training and validation
@ms.jit
def train_step(x, pos, edge_index, batch, batch_size, energy):
    loss_, grads_ = backward(x, pos, edge_index, batch, batch_size, energy)
    optimizer(grads_)
    return loss_

@ms.jit
def test_step(x, pos, edge_index, batch, batch_size):
    return model(x, pos, edge_index, batch, batch_size)

def _unpack(data):
    return (data['x'], data['pos']), data['energy']

def train_epoch(model, trainset, edge_index, batch, batch_size, loss_train: list):
    size = trainset.get_dataset_size()
    model.set_train()
    total_train_loss = 0
    loss_train_epoch = []
    ti = time.time()
    for current, data_dict in enumerate(trainset.create_dict_iterator()):
        inputs, label = _unpack(data_dict)
        loss = train_step(inputs[0], inputs[1], edge_index, batch, batch_size, label)
        # AtomWise
        loss = loss.asnumpy()
        loss_train_epoch.append(loss)
        if current % 10 == 0:
            # pylint: disable=W1203
            print(f"loss: {loss:.16f}  [{current:>3d}/{size:>3d}]")
        total_train_loss += loss

    loss_train.append(loss_train_epoch)
    if size != 0:
        loss_train_avg = total_train_loss / size
    else:
        raise ValueError("size should not be zero")
    t_now = time.time()
    print('train loss: %.16f, time gap: %.4f' %(loss_train_avg, (t_now - ti)))

def test(model, dataset, edge_index, batch, batch_size, loss_fn, loss_eval: list, metric_fn, metric_list: list):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    test_loss = 0
    metric = 0
    for _, data_dict in enumerate(dataset.create_dict_iterator()):
        inputs, label = _unpack(data_dict)
        if batch_size != 0:
            atom_num = inputs[0].shape[0] / batch_size
        else:
            raise ValueError("batch_size should not be zero")
        square_atom_num = atom_num ** 2
        pred = test_step(inputs[0], inputs[1], edge_index, batch, batch_size)
        if square_atom_num != 0:
            test_loss += loss_fn(pred, label).asnumpy() / square_atom_num
        else:
            raise ValueError("square_atom_num should not be zero")
        if atom_num != 0:
            metric += metric_fn(pred, label).asnumpy() / atom_num
        else:
            raise ValueError("atom_num should not be zero")

    test_loss /= num_batches
    metric /= num_batches
    # AtomWise
    loss_eval.append(test_loss)
    metric_list.append(metric)
    print("Test: mse loss: %.16f" %test_loss)
    print("Test: mae metric: %.16f" %metric)
    return test_loss

# == Training ==
if is_profiling:
    print("Initializing profiler...      ")
    profiler = ms.Profiler(output_path="dump_output" + "/profiler_data", profile_memory=True)

print("Initializing train...         ")
print("seed is: %d" %ms.get_seed())
loss_eval = []
loss_train = []
metric_list = []
for t in range(n_epoch):
    print("Epoch %d\n-------------------------------" %(t + 1))
    train_epoch(model, ds_train, edge_index, batch, batch_size, loss_train)
    test_loss = test(
        model, ds_test, eval_edge_index, eval_batch, batch_size_eval, loss_fn, loss_eval, metric_fn, metric_list
    )

    if lrdecay:
        lr_scheduler.step(test_loss)
        last_lr = optimizer.param_groups[0].get('lr').value()
        print("lr: %.10f\n" %last_lr)

    if (t + 1) % 50 == 0:
        ms.save_checkpoint(model, "./model.ckpt")

if is_profiling:
    profiler.analyse()

print("Training Done!")
Initializing train...
seed is: 123
Epoch 1
-------------------------------
loss: 468775552.0000000000000000  [  0/190]
loss: 468785216.0000000000000000  [ 10/190]
loss: 468774752.0000000000000000  [ 20/190]
loss: 468760512.0000000000000000  [ 30/190]
loss: 468342560.0000000000000000  [ 40/190]
loss: 414531872.0000000000000000  [ 50/190]
loss: 435014.9062500000000000  [ 60/190]
loss: 132964368.0000000000000000  [ 70/190]
loss: 82096352.0000000000000000  [ 80/190]
loss: 12417458.0000000000000000  [ 90/190]
loss: 202487.4687500000000000  [100/190]
loss: 300066.6875000000000000  [110/190]
loss: 468295.9375000000000000  [120/190]
loss: 1230706.0000000000000000  [130/190]
loss: 487508.2812500000000000  [140/190]
loss: 242425.6406250000000000  [150/190]
loss: 841241.0000000000000000  [160/190]
loss: 84912.1328125000000000  [170/190]
loss: 1272812.5000000000000000  [180/190]
train loss: 139695450.3818462193012238, time gap: 101.1313
Test: mse loss: 328262.9555555555853061
Test: mae metric: 463.6971659342447083
lr: 0.0020000001

Epoch 2
-------------------------------
loss: 469823.8750000000000000  [  0/190]
loss: 650901.1875000000000000  [ 10/190]
loss: 183339.9687500000000000  [ 20/190]
loss: 256283.4218750000000000  [ 30/190]
loss: 335927.1875000000000000  [ 40/190]
loss: 913293.2500000000000000  [ 50/190]
loss: 1257833.7500000000000000  [ 60/190]
loss: 630779.2500000000000000  [ 70/190]
loss: 1652336.1250000000000000  [ 80/190]
loss: 155349.2500000000000000  [ 90/190]
loss: 183506.0468750000000000  [100/190]
loss: 322167.0000000000000000  [110/190]
loss: 738248.0000000000000000  [120/190]
loss: 628022.9375000000000000  [130/190]
loss: 693525.0000000000000000  [140/190]
loss: 237971.7812500000000000  [150/190]
loss: 728099.2500000000000000  [160/190]
loss: 50060.8867187500000000  [170/190]
loss: 1229544.0000000000000000  [180/190]
train loss: 441576.3795847039436921, time gap: 26.0444
Test: mse loss: 366946.9187499999534339
Test: mae metric: 493.0364359537759924
lr: 0.0020000001

Epoch 3
-------------------------------
loss: 522800.6250000000000000  [  0/190]
loss: 751694.5000000000000000  [ 10/190]
loss: 187226.3750000000000000  [ 20/190]
loss: 240447.8906250000000000  [ 30/190]
loss: 302177.7812500000000000  [ 40/190]
loss: 834946.6875000000000000  [ 50/190]
loss: 1170818.2500000000000000  [ 60/190]
loss: 596591.8750000000000000  [ 70/190]
loss: 1559648.0000000000000000  [ 80/190]
loss: 144896.1718750000000000  [ 90/190]
loss: 171495.7656250000000000  [100/190]
loss: 302823.1250000000000000  [110/190]
loss: 681209.3750000000000000  [120/190]
loss: 594635.8750000000000000  [130/190]
loss: 648062.7500000000000000  [140/190]
loss: 221139.5312500000000000  [150/190]
loss: 684927.0000000000000000  [160/190]
loss: 57762.1718750000000000  [170/190]
loss: 1197153.3750000000000000  [180/190]
train loss: 414760.5352384868310764, time gap: 25.4267
Test: mse loss: 337391.1312500000349246
Test: mae metric: 473.0654032389323334
lr: 0.0020000001

Epoch 4
-------------------------------
loss: 479449.6250000000000000  [  0/190]
loss: 658094.1875000000000000  [ 10/190]
loss: 167701.1875000000000000  [ 20/190]
loss: 218166.0000000000000000  [ 30/190]
loss: 274594.4375000000000000  [ 40/190]
loss: 752581.0000000000000000  [ 50/190]
loss: 1039454.4375000000000000  [ 60/190]
loss: 581997.6250000000000000  [ 70/190]
loss: 1481623.0000000000000000  [ 80/190]
loss: 131388.5000000000000000  [ 90/190]
loss: 159510.3593750000000000  [100/190]
loss: 284669.9687500000000000  [110/190]
loss: 635406.0625000000000000  [120/190]
loss: 547961.0000000000000000  [130/190]
loss: 594799.6875000000000000  [140/190]
loss: 206942.5937500000000000  [150/190]
loss: 635930.1250000000000000  [160/190]
loss: 53651.5429687500000000  [170/190]
loss: 1120740.7500000000000000  [180/190]
train loss: 384702.7380550986854360, time gap: 25.3646
Test: mse loss: 312383.8472222221898846
Test: mae metric: 455.6222493489584053
lr: 0.0020000001

Epoch 5
-------------------------------
loss: 442066.5625000000000000  [  0/190]
loss: 610366.8750000000000000  [ 10/190]
loss: 154912.0781250000000000  [ 20/190]
loss: 199916.3125000000000000  [ 30/190]
loss: 253701.6875000000000000  [ 40/190]
loss: 695447.4375000000000000  [ 50/190]
loss: 973856.6875000000000000  [ 60/190]
loss: 529174.5625000000000000  [ 70/190]
loss: 1359184.8750000000000000  [ 80/190]
loss: 120610.0546875000000000  [ 90/190]
loss: 145533.5312500000000000  [100/190]
loss: 253629.2500000000000000  [110/190]
loss: 602776.2500000000000000  [120/190]
loss: 479350.7187500000000000  [130/190]
loss: 522066.7812500000000000  [140/190]
loss: 197747.9687500000000000  [150/190]
loss: 585378.6875000000000000  [160/190]
loss: 39960.7265625000000000  [170/190]
loss: 1010730.2500000000000000  [180/190]
train loss: 355614.9552425986621529, time gap: 26.2478
Test: mse loss: 291521.1986111110891216
Test: mae metric: 440.5232421874999886
lr: 0.0020000001

Training Done!

Model Inference

We define a custom pred function to run model inference and return the prediction results. Note that the training loop above saves ./model.ckpt only every 50 epochs, so a checkpoint file must already exist (for example, from a sufficiently long training run) before running inference.

[11]:
def pred(configs, dtype=ms.float32):
    """Pred the model on the eval dataset."""
    batch_size_eval = configs['BATCH_SIZE_EVAL']
    n_eval = configs['N_EVAL']

    print("Loading data...                ")
    data_path = configs['DATA_PATH']
    _, _, _, ds_test, eval_edge_index, eval_batch, num_type = create_test_dataset(
        config={
            "path": data_path,
            "batch_size_eval": batch_size_eval,
            "n_val": n_eval,
        },
        dtype=dtype,
        pred_force=False
    )

    # Define model
    print("Initializing model...              ")
    model = build(num_type, configs)

    # load checkpoint
    ckpt_file = './model.ckpt'
    ms.load_checkpoint(ckpt_file, model)

    # Instantiate loss function and metric function
    loss_fn = nn.MSELoss()
    metric_fn = nn.MAELoss()

    # == Evaluation ==
    print("Initializing Evaluation...         ")
    print("seed is: %d" %ms.get_seed())

    pred_list, test_loss, metric = evaluation(
        model, ds_test, eval_edge_index, eval_batch, batch_size_eval, loss_fn, metric_fn
    )

    print("prediction saved")
    print("Test: mse loss: %.16f" %test_loss)
    print("Test: mae metric: %.16f" %metric)

    print("Predict Done!")

    return pred_list, test_loss, metric


def evaluation(model, dataset, edge_index, batch, batch_size, loss_fn, metric_fn):
    """evaluation"""
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    test_loss = 0
    metric = 0
    pred_list = []
    for _, data_dict in enumerate(dataset.create_dict_iterator()):
        inputs, label = _unpack(data_dict)
        if batch_size != 0:
            atom_num = inputs[0].shape[0] / batch_size
        else:
            raise ValueError("batch_size should not be zero")
        square_atom_num = atom_num ** 2
        prediction = model(inputs[0], inputs[1], edge_index, batch, batch_size)
        pred_list.append(prediction.asnumpy())
        if square_atom_num != 0:
            test_loss += loss_fn(prediction, label).asnumpy() / square_atom_num
        else:
            raise ValueError("square_atom_num should not be zero")
        if atom_num != 0:
            metric += metric_fn(prediction, label).asnumpy() / atom_num
        else:
            raise ValueError("atom_num should not be zero")

    test_loss /= num_batches
    metric /= num_batches

    return pred_list, test_loss, metric
[14]:
pred(configs)
Loading data...
Initializing model...
Initializing Evaluation...
seed is: 123
prediction saved
Test: mse loss: 901.1434895833332348
Test: mae metric: 29.0822919209798201
Predict Done!
[14]:
([array([[-259531.56],
         [-259377.28],
         [-259534.83],
         [-259243.62],
         [-259541.62]], dtype=float32),
  array([[-259516.4 ],
         [-259519.81],
         [-259545.69],
         [-259428.45],
         [-259527.28]], dtype=float32),
  array([[-259508.94],
         [-259521.22],
         [-259533.28],
         [-259465.56],
         [-259523.88]], dtype=float32),
  array([[-259533.56],
         [-259303.9 ],
         [-259509.53],
         [-259369.22],
         [-259514.4 ]], dtype=float32),
  array([[-259368.25],
         [-259487.45],
         [-259545.94],
         [-259379.47],
         [-259494.19]], dtype=float32),
  array([[-259533.64],
         [-259453.  ],
         [-259542.69],
         [-259451.9 ],
         [-259213.11]], dtype=float32),
  array([[-259562.5 ],
         [-259531.6 ],
         [-259526.5 ],
         [-259530.3 ],
         [-259389.12]], dtype=float32),
  array([[-259515.03],
         [-259530.69],
         [-259476.9 ],
         [-259267.77],
         [-259535.11]], dtype=float32),
  array([[-259548.77],
         [-259530.8 ],
         [-259401.7 ],
         [-259542.12],
         [-259419.86]], dtype=float32),
  array([[-259386.81],
         [-259291.75],
         [-259419.61],
         [-259488.25],
         [-259334.34]], dtype=float32)],
 901.1434895833332,
 29.08229192097982)
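
If you want to keep the predictions for further analysis, the list of per-batch arrays returned by pred can be concatenated and saved. The snippet below is an optional sketch that reuses the pred function and configs defined above; the output filename is arbitrary.

import numpy as np

pred_list, test_loss, metric = pred(configs)
np.save("predictions.npy", np.concatenate(pred_list, axis=0))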