{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "# Quick Start\n", "\n", "[![Download Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/master/mindchemistry/zh_cn/quick_start/mindspore_quick_start.ipynb) [![Download Sample Code](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/master/mindchemistry/zh_cn/quick_start/mindspore_quick_start.py) [![View Source File](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/docs/mindchemistry/docs/source_zh_cn/quick_start/quick_start.ipynb)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Overview\n", "\n", "This tutorial uses Allegro to predict molecular potential energy as an example.\n", "\n", "Allegro is a state-of-the-art model built on equivariant graph neural networks, and the corresponding paper was published in Nature Communications. This case verifies the effectiveness of Allegro for molecular potential energy prediction, which has high application value.\n", "\n", "This tutorial introduces the research background and technical approach of Allegro, and shows how to train the model and run fast inference with MindSpore Chemistry. For more information, see the [paper](https://www.nature.com/articles/s41467-023-36329-y)." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Technical Approach\n", "\n", "MindSpore Chemistry solves the molecular potential energy prediction problem with the following workflow:\n", "\n", "1. Dataset creation\n", "2. Model construction\n", "3. Loss function\n", "4. Optimizer\n", "5. Model training\n", "6. 
Model inference" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import time\n", "import random\n", "\n", "import mindspore as ms\n", "import numpy as np\n", "from mindspore import nn\n", "from mindspore.experimental import optim" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The `src` package used below can be downloaded from [allegro/src](https://gitee.com/mindspore/mindscience/tree/master/MindChemistry/applications/allegro/src)." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "from mindchemistry.cell.allegro import Allegro\n", "from mindchemistry.utils.load_config import load_yaml_config_from_path\n", "\n", "from src.allegro_embedding import AllegroEmbedding\n", "from src.dataset import create_training_dataset, create_test_dataset\n", "from src.potential import Potential" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "ms.set_seed(123)\n", "ms.dataset.config.set_seed(1)\n", "np.random.seed(1)\n", "random.seed(1)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The model parameters, optimizer settings, and data configuration are defined in [config](https://gitee.com/mindspore/mindscience/blob/master/MindChemistry/applications/allegro/rmd.yaml)." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" }, "scrolled": true }, "outputs": [], "source": [ "configs = load_yaml_config_from_path(\"rmd.yaml\")\n", "ms.set_context(mode=ms.GRAPH_MODE)\n", "ms.set_device(\"Ascend\", 0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset Creation\n", "\n", "Download the training and validation datasets from [Revised MD17 dataset (rMD17)](https://gitee.com/link?target=https%3A%2F%2Ffigshare.com%2Farticles%2Fdataset%2FRevised_MD17_dataset_rMD17_%2F12672038) to the `./dataset/rmd17/npz_data/` directory. By default, the configuration file reads the dataset from 
dataset/rmd17/npz_data/rmd17_uracil.npz." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading data... \n" ] } ], "source": [ "n_epoch = 5\n", "batch_size = configs['BATCH_SIZE']\n", "batch_size_eval = configs['BATCH_SIZE_EVAL']\n", "learning_rate = configs['LEARNING_RATE']\n", "is_profiling = configs['IS_PROFILING']\n", "shuffle = configs['SHUFFLE']\n", "split_random = configs['SPLIT_RANDOM']\n", "lrdecay = configs['LRDECAY']\n", "n_train = configs['N_TRAIN']\n", "n_eval = configs['N_EVAL']\n", "patience = configs['PATIENCE']\n", "factor = configs['FACTOR']\n", "parallel_mode = \"NONE\"\n", "\n", "print(\"Loading data... \")\n", "data_path = configs['DATA_PATH']\n", "ds_train, edge_index, batch, ds_test, eval_edge_index, eval_batch, num_type = create_training_dataset(\n", "    config={\n", "        \"path\": data_path,\n", "        \"batch_size\": batch_size,\n", "        \"batch_size_eval\": batch_size_eval,\n", "        \"n_train\": n_train,\n", "        \"n_val\": n_eval,\n", "        \"split_random\": split_random,\n", "        \"shuffle\": shuffle\n", "    },\n", "    dtype=ms.float32,\n", "    pred_force=False,\n", "    parallel_mode=parallel_mode\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%% md\n" } }, "source": [ "## Model Construction\n", "\n", "The Allegro model can be imported from the mindchemistry library, and the embedding and potential energy prediction modules can be imported from src." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "def build(num_type, configs):\n", "    \"\"\" Build Potential model\n", "\n", "    Args:\n", "        num_type (int): number of atom types\n", "        configs (dict): model configuration loaded from the YAML file\n", "\n", "    Returns:\n", "        net (Potential): Potential model\n", "    \"\"\"\n", "    literal_hidden_dims = 'hidden_dims'\n", "    literal_activation = 'activation'\n", "    literal_weight_init = 'weight_init'\n", "    literal_uniform = 'uniform'\n", 
"\n", "    emb = AllegroEmbedding(\n", "        num_type=num_type,\n", "        cutoff=configs['CUTOFF']\n", "    )\n", "\n", "    model = Allegro(\n", "        l_max=configs['L_MAX'],\n", "        irreps_in={\n", "            \"pos\": \"1x1o\",\n", "            \"edge_index\": None,\n", "            \"node_attrs\": f\"{num_type}x0e\",\n", "            \"node_features\": f\"{num_type}x0e\",\n", "            \"edge_embedding\": f\"{configs['NUM_BASIS']}x0e\"\n", "        },\n", "        avg_num_neighbor=configs['AVG_NUM_NEIGHBOR'],\n", "        num_layers=configs['NUM_LAYERS'],\n", "        env_embed_multi=configs['ENV_EMBED_MULTI'],\n", "        two_body_kwargs={\n", "            literal_hidden_dims: configs['two_body_latent_mlp_latent_dimensions'],\n", "            literal_activation: 'silu',\n", "            literal_weight_init: literal_uniform\n", "        },\n", "        latent_kwargs={\n", "            literal_hidden_dims: configs['latent_mlp_latent_dimensions'],\n", "            literal_activation: 'silu',\n", "            literal_weight_init: literal_uniform\n", "        },\n", "        env_embed_kwargs={\n", "            literal_hidden_dims: configs['env_embed_mlp_latent_dimensions'],\n", "            literal_activation: None,\n", "            literal_weight_init: literal_uniform\n", "        },\n", "        enable_mix_precision=configs['enable_mix_precision'],\n", "    )\n", "\n", "    net = Potential(\n", "        embedding=emb,\n", "        model=model,\n", "        avg_num_neighbor=configs['AVG_NUM_NEIGHBOR'],\n", "        edge_eng_mlp_latent_dimensions=configs['edge_eng_mlp_latent_dimensions']\n", "    )\n", "\n", "    return net" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initializing model... \n" ] } ], "source": [ "print(\"Initializing model... 
\")\n", "model = build(num_type, configs)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Loss Function\n", "\n", "Allegro uses mean squared error (MSE) as the training loss and mean absolute error (MAE) as the evaluation metric." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "loss_fn = nn.MSELoss()\n", "metric_fn = nn.MAELoss()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Optimizer\n", "\n", "The Adam optimizer is used, with ReduceLROnPlateau as the learning-rate schedule." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "optimizer = optim.Adam(params=model.trainable_params(), lr=learning_rate)\n", "lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=factor, patience=patience)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Model Training\n", "\n", "In this tutorial, we define custom train_step and test_step functions and use them to train the model." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "pycharm": { "name": "#%%\n" }, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initializing train... 
\n", "seed is: 123\n", "Epoch 1\n", "-------------------------------\n", "loss: 468775552.0000000000000000 [ 0/190]\n", "loss: 468785216.0000000000000000 [ 10/190]\n", "loss: 468774752.0000000000000000 [ 20/190]\n", "loss: 468760512.0000000000000000 [ 30/190]\n", "loss: 468342560.0000000000000000 [ 40/190]\n", "loss: 414531872.0000000000000000 [ 50/190]\n", "loss: 435014.9062500000000000 [ 60/190]\n", "loss: 132964368.0000000000000000 [ 70/190]\n", "loss: 82096352.0000000000000000 [ 80/190]\n", "loss: 12417458.0000000000000000 [ 90/190]\n", "loss: 202487.4687500000000000 [100/190]\n", "loss: 300066.6875000000000000 [110/190]\n", "loss: 468295.9375000000000000 [120/190]\n", "loss: 1230706.0000000000000000 [130/190]\n", "loss: 487508.2812500000000000 [140/190]\n", "loss: 242425.6406250000000000 [150/190]\n", "loss: 841241.0000000000000000 [160/190]\n", "loss: 84912.1328125000000000 [170/190]\n", "loss: 1272812.5000000000000000 [180/190]\n", "train loss: 139695450.3818462193012238, time gap: 101.1313\n", "Test: mse loss: 328262.9555555555853061\n", "Test: mae metric: 463.6971659342447083\n", "lr: 0.0020000001\n", "\n", "Epoch 2\n", "-------------------------------\n", "loss: 469823.8750000000000000 [ 0/190]\n", "loss: 650901.1875000000000000 [ 10/190]\n", "loss: 183339.9687500000000000 [ 20/190]\n", "loss: 256283.4218750000000000 [ 30/190]\n", "loss: 335927.1875000000000000 [ 40/190]\n", "loss: 913293.2500000000000000 [ 50/190]\n", "loss: 1257833.7500000000000000 [ 60/190]\n", "loss: 630779.2500000000000000 [ 70/190]\n", "loss: 1652336.1250000000000000 [ 80/190]\n", "loss: 155349.2500000000000000 [ 90/190]\n", "loss: 183506.0468750000000000 [100/190]\n", "loss: 322167.0000000000000000 [110/190]\n", "loss: 738248.0000000000000000 [120/190]\n", "loss: 628022.9375000000000000 [130/190]\n", "loss: 693525.0000000000000000 [140/190]\n", "loss: 237971.7812500000000000 [150/190]\n", "loss: 728099.2500000000000000 [160/190]\n", "loss: 50060.8867187500000000 [170/190]\n", 
"loss: 1229544.0000000000000000 [180/190]\n", "train loss: 441576.3795847039436921, time gap: 26.0444\n", "Test: mse loss: 366946.9187499999534339\n", "Test: mae metric: 493.0364359537759924\n", "lr: 0.0020000001\n", "\n", "Epoch 3\n", "-------------------------------\n", "loss: 522800.6250000000000000 [ 0/190]\n", "loss: 751694.5000000000000000 [ 10/190]\n", "loss: 187226.3750000000000000 [ 20/190]\n", "loss: 240447.8906250000000000 [ 30/190]\n", "loss: 302177.7812500000000000 [ 40/190]\n", "loss: 834946.6875000000000000 [ 50/190]\n", "loss: 1170818.2500000000000000 [ 60/190]\n", "loss: 596591.8750000000000000 [ 70/190]\n", "loss: 1559648.0000000000000000 [ 80/190]\n", "loss: 144896.1718750000000000 [ 90/190]\n", "loss: 171495.7656250000000000 [100/190]\n", "loss: 302823.1250000000000000 [110/190]\n", "loss: 681209.3750000000000000 [120/190]\n", "loss: 594635.8750000000000000 [130/190]\n", "loss: 648062.7500000000000000 [140/190]\n", "loss: 221139.5312500000000000 [150/190]\n", "loss: 684927.0000000000000000 [160/190]\n", "loss: 57762.1718750000000000 [170/190]\n", "loss: 1197153.3750000000000000 [180/190]\n", "train loss: 414760.5352384868310764, time gap: 25.4267\n", "Test: mse loss: 337391.1312500000349246\n", "Test: mae metric: 473.0654032389323334\n", "lr: 0.0020000001\n", "\n", "Epoch 4\n", "-------------------------------\n", "loss: 479449.6250000000000000 [ 0/190]\n", "loss: 658094.1875000000000000 [ 10/190]\n", "loss: 167701.1875000000000000 [ 20/190]\n", "loss: 218166.0000000000000000 [ 30/190]\n", "loss: 274594.4375000000000000 [ 40/190]\n", "loss: 752581.0000000000000000 [ 50/190]\n", "loss: 1039454.4375000000000000 [ 60/190]\n", "loss: 581997.6250000000000000 [ 70/190]\n", "loss: 1481623.0000000000000000 [ 80/190]\n", "loss: 131388.5000000000000000 [ 90/190]\n", "loss: 159510.3593750000000000 [100/190]\n", "loss: 284669.9687500000000000 [110/190]\n", "loss: 635406.0625000000000000 [120/190]\n", "loss: 547961.0000000000000000 [130/190]\n", "loss: 
594799.6875000000000000 [140/190]\n", "loss: 206942.5937500000000000 [150/190]\n", "loss: 635930.1250000000000000 [160/190]\n", "loss: 53651.5429687500000000 [170/190]\n", "loss: 1120740.7500000000000000 [180/190]\n", "train loss: 384702.7380550986854360, time gap: 25.3646\n", "Test: mse loss: 312383.8472222221898846\n", "Test: mae metric: 455.6222493489584053\n", "lr: 0.0020000001\n", "\n", "Epoch 5\n", "-------------------------------\n", "loss: 442066.5625000000000000 [ 0/190]\n", "loss: 610366.8750000000000000 [ 10/190]\n", "loss: 154912.0781250000000000 [ 20/190]\n", "loss: 199916.3125000000000000 [ 30/190]\n", "loss: 253701.6875000000000000 [ 40/190]\n", "loss: 695447.4375000000000000 [ 50/190]\n", "loss: 973856.6875000000000000 [ 60/190]\n", "loss: 529174.5625000000000000 [ 70/190]\n", "loss: 1359184.8750000000000000 [ 80/190]\n", "loss: 120610.0546875000000000 [ 90/190]\n", "loss: 145533.5312500000000000 [100/190]\n", "loss: 253629.2500000000000000 [110/190]\n", "loss: 602776.2500000000000000 [120/190]\n", "loss: 479350.7187500000000000 [130/190]\n", "loss: 522066.7812500000000000 [140/190]\n", "loss: 197747.9687500000000000 [150/190]\n", "loss: 585378.6875000000000000 [160/190]\n", "loss: 39960.7265625000000000 [170/190]\n", "loss: 1010730.2500000000000000 [180/190]\n", "train loss: 355614.9552425986621529, time gap: 26.2478\n", "Test: mse loss: 291521.1986111110891216\n", "Test: mae metric: 440.5232421874999886\n", "lr: 0.0020000001\n", "\n", "Training Done!\n" ] } ], "source": [ "# 1. 
Define forward function\n", "def forward(x, pos, edge_index, batch, batch_size, energy):\n", "    pred = model(x, pos, edge_index, batch, batch_size)\n", "    loss = loss_fn(pred, energy)\n", "    if batch_size != 0:\n", "        square_atom_num = (x.shape[0] / batch_size) ** 2\n", "    else:\n", "        raise ValueError(\"batch_size should not be zero\")\n", "    if square_atom_num != 0:\n", "        loss = loss / square_atom_num\n", "    else:\n", "        raise ValueError(\"square_atom_num should not be zero\")\n", "    return loss\n", "\n", "# 2. Get gradient function\n", "backward = ms.value_and_grad(forward, None, optimizer.parameters)\n", "\n", "# 3. Define function of one-step training and validation\n", "@ms.jit\n", "def train_step(x, pos, edge_index, batch, batch_size, energy):\n", "    loss_, grads_ = backward(x, pos, edge_index, batch, batch_size, energy)\n", "    optimizer(grads_)\n", "    return loss_\n", "\n", "@ms.jit\n", "def test_step(x, pos, edge_index, batch, batch_size):\n", "    return model(x, pos, edge_index, batch, batch_size)\n", "\n", "def _unpack(data):\n", "    return (data['x'], data['pos']), data['energy']\n", "\n", "def train_epoch(model, trainset, edge_index, batch, batch_size, loss_train: list):\n", "    size = trainset.get_dataset_size()\n", "    model.set_train()\n", "    total_train_loss = 0\n", "    loss_train_epoch = []\n", "    ti = time.time()\n", "    for current, data_dict in enumerate(trainset.create_dict_iterator()):\n", "        inputs, label = _unpack(data_dict)\n", "        loss = train_step(inputs[0], inputs[1], edge_index, batch, batch_size, label)\n", "        # AtomWise\n", "        loss = loss.asnumpy()\n", "        loss_train_epoch.append(loss)\n", "        if current % 10 == 0:\n", "            # pylint: disable=W1203\n", "            print(f\"loss: {loss:.16f} [{current:>3d}/{size:>3d}]\")\n", "        total_train_loss += loss\n", "\n", "    loss_train.append(loss_train_epoch)\n", "    if size != 0:\n", "        loss_train_avg = total_train_loss / size\n", "    else:\n", "        raise ValueError(\"size should not be zero\")\n", "    t_now = time.time()\n", "    print('train loss: %.16f, 
time gap: %.4f' %(loss_train_avg, (t_now - ti)))\n", "\n", "def test(model, dataset, edge_index, batch, batch_size, loss_fn, loss_eval: list, metric_fn, metric_list: list):\n", "    num_batches = dataset.get_dataset_size()\n", "    model.set_train(False)\n", "    test_loss = 0\n", "    metric = 0\n", "    for _, data_dict in enumerate(dataset.create_dict_iterator()):\n", "        inputs, label = _unpack(data_dict)\n", "        if batch_size != 0:\n", "            atom_num = inputs[0].shape[0] / batch_size\n", "        else:\n", "            raise ValueError(\"batch_size should not be zero\")\n", "        square_atom_num = atom_num ** 2\n", "        pred = test_step(inputs[0], inputs[1], edge_index, batch, batch_size)\n", "        if square_atom_num != 0:\n", "            test_loss += loss_fn(pred, label).asnumpy() / square_atom_num\n", "        else:\n", "            raise ValueError(\"square_atom_num should not be zero\")\n", "        if atom_num != 0:\n", "            metric += metric_fn(pred, label).asnumpy() / atom_num\n", "        else:\n", "            raise ValueError(\"atom_num should not be zero\")\n", "\n", "    test_loss /= num_batches\n", "    metric /= num_batches\n", "    # AtomWise\n", "    loss_eval.append(test_loss)\n", "    metric_list.append(metric)\n", "    print(\"Test: mse loss: %.16f\" %test_loss)\n", "    print(\"Test: mae metric: %.16f\" %metric)\n", "    return test_loss\n", "\n", "# == Training ==\n", "if is_profiling:\n", "    print(\"Initializing profiler... \")\n", "    profiler = ms.Profiler(output_path=\"dump_output\" + \"/profiler_data\", profile_memory=True)\n", "\n", "print(\"Initializing train... 
\")\n", "print(\"seed is: %d\" %ms.get_seed())\n", "loss_eval = []\n", "loss_train = []\n", "metric_list = []\n", "for t in range(n_epoch):\n", "    print(\"Epoch %d\\n-------------------------------\" %(t + 1))\n", "    train_epoch(model, ds_train, edge_index, batch, batch_size, loss_train)\n", "    test_loss = test(\n", "        model, ds_test, eval_edge_index, eval_batch, batch_size_eval, loss_fn, loss_eval, metric_fn, metric_list\n", "    )\n", "\n", "    if lrdecay:\n", "        lr_scheduler.step(test_loss)\n", "        last_lr = optimizer.param_groups[0].get('lr').value()\n", "        print(\"lr: %.10f\\n\" %last_lr)\n", "\n", "    if (t + 1) % 50 == 0:\n", "        ms.save_checkpoint(model, \"./model.ckpt\")\n", "\n", "if is_profiling:\n", "    profiler.analyse()\n", "\n", "print(\"Training Done!\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model Inference\n", "\n", "Define a custom pred function that runs model inference and returns the results." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "def pred(configs, dtype=ms.float32):\n", "    \"\"\"Run model prediction on the eval dataset.\"\"\"\n", "    batch_size_eval = configs['BATCH_SIZE_EVAL']\n", "    n_eval = configs['N_EVAL']\n", "\n", "    print(\"Loading data... \")\n", "    data_path = configs['DATA_PATH']\n", "    _, _, _, ds_test, eval_edge_index, eval_batch, num_type = create_test_dataset(\n", "        config={\n", "            \"path\": data_path,\n", "            \"batch_size_eval\": batch_size_eval,\n", "            \"n_val\": n_eval,\n", "        },\n", "        dtype=dtype,\n", "        pred_force=False\n", "    )\n", "\n", "    # Define model\n", "    print(\"Initializing model... \")\n", "    model = build(num_type, configs)\n", "\n", "    # load checkpoint\n", "    ckpt_file = './model.ckpt'\n", "    ms.load_checkpoint(ckpt_file, model)\n", "\n", "    # Instantiate loss function and metric function\n", "    loss_fn = nn.MSELoss()\n", "    metric_fn = nn.MAELoss()\n", "\n", "    # == Evaluation ==\n", "    print(\"Initializing Evaluation... 
\")\n", "    print(\"seed is: %d\" %ms.get_seed())\n", "\n", "    pred_list, test_loss, metric = evaluation(\n", "        model, ds_test, eval_edge_index, eval_batch, batch_size_eval, loss_fn, metric_fn\n", "    )\n", "\n", "    print(\"prediction saved\")\n", "    print(\"Test: mse loss: %.16f\" %test_loss)\n", "    print(\"Test: mae metric: %.16f\" %metric)\n", "\n", "    print(\"Predict Done!\")\n", "\n", "    return pred_list, test_loss, metric\n", "\n", "\n", "def evaluation(model, dataset, edge_index, batch, batch_size, loss_fn, metric_fn):\n", "    \"\"\"evaluation\"\"\"\n", "    num_batches = dataset.get_dataset_size()\n", "    model.set_train(False)\n", "    test_loss = 0\n", "    metric = 0\n", "    pred_list = []\n", "    for _, data_dict in enumerate(dataset.create_dict_iterator()):\n", "        inputs, label = _unpack(data_dict)\n", "        if batch_size != 0:\n", "            atom_num = inputs[0].shape[0] / batch_size\n", "        else:\n", "            raise ValueError(\"batch_size should not be zero\")\n", "        square_atom_num = atom_num ** 2\n", "        prediction = model(inputs[0], inputs[1], edge_index, batch, batch_size)\n", "        pred_list.append(prediction.asnumpy())\n", "        if square_atom_num != 0:\n", "            test_loss += loss_fn(prediction, label).asnumpy() / square_atom_num\n", "        else:\n", "            raise ValueError(\"square_atom_num should not be zero\")\n", "        if atom_num != 0:\n", "            metric += metric_fn(prediction, label).asnumpy() / atom_num\n", "        else:\n", "            raise ValueError(\"atom_num should not be zero\")\n", "\n", "    test_loss /= num_batches\n", "    metric /= num_batches\n", "\n", "    return pred_list, test_loss, metric" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading data... \n", "Initializing model... \n", "Initializing Evaluation... 
\n", "seed is: 123\n", "prediction saved\n", "Test: mse loss: 901.1434895833332348\n", "Test: mae metric: 29.0822919209798201\n", "Predict Done!\n" ] }, { "data": { "text/plain": [ "([array([[-259531.56],\n", " [-259377.28],\n", " [-259534.83],\n", " [-259243.62],\n", " [-259541.62]], dtype=float32),\n", " array([[-259516.4 ],\n", " [-259519.81],\n", " [-259545.69],\n", " [-259428.45],\n", " [-259527.28]], dtype=float32),\n", " array([[-259508.94],\n", " [-259521.22],\n", " [-259533.28],\n", " [-259465.56],\n", " [-259523.88]], dtype=float32),\n", " array([[-259533.56],\n", " [-259303.9 ],\n", " [-259509.53],\n", " [-259369.22],\n", " [-259514.4 ]], dtype=float32),\n", " array([[-259368.25],\n", " [-259487.45],\n", " [-259545.94],\n", " [-259379.47],\n", " [-259494.19]], dtype=float32),\n", " array([[-259533.64],\n", " [-259453. ],\n", " [-259542.69],\n", " [-259451.9 ],\n", " [-259213.11]], dtype=float32),\n", " array([[-259562.5 ],\n", " [-259531.6 ],\n", " [-259526.5 ],\n", " [-259530.3 ],\n", " [-259389.12]], dtype=float32),\n", " array([[-259515.03],\n", " [-259530.69],\n", " [-259476.9 ],\n", " [-259267.77],\n", " [-259535.11]], dtype=float32),\n", " array([[-259548.77],\n", " [-259530.8 ],\n", " [-259401.7 ],\n", " [-259542.12],\n", " [-259419.86]], dtype=float32),\n", " array([[-259386.81],\n", " [-259291.75],\n", " [-259419.61],\n", " [-259488.25],\n", " [-259334.34]], dtype=float32)],\n", " 901.1434895833332,\n", " 29.08229192097982)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pred(configs)" ] } ], "metadata": { "interpreter": { "hash": "58e6709d8bbc21fe79376972d6b15d6c06efb7b1d41f6d4b946e12f7486761ac" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": 
"ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 4 }