{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "# Quick Start\n", "\n", "[![DownloadNotebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.5.0/resource/_static/logo_notebook_en.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.5.0/mindchemistry/en/quick_start/mindspore_quick_start.ipynb) [![DownloadCode](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.5.0/resource/_static/logo_download_code_en.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.5.0/mindchemistry/en/quick_start/mindspore_quick_start.py) [![View Source On Gitee](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.5.0/resource/_static/logo_source_en.svg)](https://gitee.com/mindspore/docs/blob/r2.5.0/docs/mindchemistry/docs/source_en/quick_start/quick_start.ipynb)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Overview\n", "\n", "This quick start uses the prediction of interatomic potentials with Allegro as an example.\n", "\n", "Allegro is a state-of-the-art model built on equivariant graph neural networks. The related paper has been published in the journal Nature Communications. This case study demonstrates the effectiveness of Allegro in molecular potential energy prediction, with high application value.\n", "\n", "This tutorial introduces the research background and technical path of Allegro, and demonstrates how to train and perform fast inference with MindSpore Chemistry. More information can be found in the [paper](https://www.nature.com/articles/s41467-023-36329-y)." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Technology Path\n", "\n", "MindSpore Chemistry solves the problem as follows:\n", "\n", "1. Data Construction.\n", "2. Model Construction.\n", "3. 
Loss Function.\n", "4. Optimizer.\n", "5. Model Training.\n", "6. Model Prediction." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import time\n", "import random\n", "\n", "import mindspore as ms\n", "import numpy as np\n", "from mindspore import nn\n", "from mindspore.experimental import optim" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The `src` package imported below can be downloaded from [allegro/src](https://gitee.com/mindspore/mindscience/tree/r0.7/MindChemistry/applications/allegro/src)." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "from mindchemistry.cell.allegro import Allegro\n", "from mindchemistry.utils.load_config import load_yaml_config_from_path\n", "\n", "from src.allegro_embedding import AllegroEmbedding\n", "from src.dataset import create_training_dataset, create_test_dataset\n", "from src.potential import Potential" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "ms.set_seed(123)\n", "ms.dataset.config.set_seed(1)\n", "np.random.seed(1)\n", "random.seed(1)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The model, data, and optimizer parameters are read from the [config](https://gitee.com/mindspore/mindscience/blob/r0.7/MindChemistry/applications/allegro/rmd.yaml) file." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" }, "scrolled": true }, "outputs": [], "source": [ "configs = load_yaml_config_from_path(\"rmd.yaml\")\n", "ms.set_context(mode=ms.GRAPH_MODE)\n", "ms.set_device(\"Ascend\", 0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data Construction\n", "\n", "Download the [Revised MD17 dataset (rMD17)](https://gitee.com/link?target=https%3A%2F%2Ffigshare.com%2Farticles%2Fdataset%2FRevised_MD17_dataset_rMD17_%2F12672038) to the `./dataset/rmd17/npz_data/` directory. The default configuration file expects the dataset at `dataset/rmd17/npz_data/rmd17_uracil.npz`." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading data... \n" ] } ], "source": [ "n_epoch = 5\n", "batch_size = configs['BATCH_SIZE']\n", "batch_size_eval = configs['BATCH_SIZE_EVAL']\n", "learning_rate = configs['LEARNING_RATE']\n", "is_profiling = configs['IS_PROFILING']\n", "shuffle = configs['SHUFFLE']\n", "split_random = configs['SPLIT_RANDOM']\n", "lrdecay = configs['LRDECAY']\n", "n_train = configs['N_TRAIN']\n", "n_eval = configs['N_EVAL']\n", "patience = configs['PATIENCE']\n", "factor = configs['FACTOR']\n", "parallel_mode = \"NONE\"\n", "\n", "print(\"Loading data... 
\")\n", "data_path = configs['DATA_PATH']\n", "ds_train, edge_index, batch, ds_test, eval_edge_index, eval_batch, num_type = create_training_dataset(\n", "    config={\n", "        \"path\": data_path,\n", "        \"batch_size\": batch_size,\n", "        \"batch_size_eval\": batch_size_eval,\n", "        \"n_train\": n_train,\n", "        \"n_val\": n_eval,\n", "        \"split_random\": split_random,\n", "        \"shuffle\": shuffle\n", "    },\n", "    dtype=ms.float32,\n", "    pred_force=False,\n", "    parallel_mode=parallel_mode\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%% md\n" } }, "source": [ "## Model Construction\n", "\n", "The Allegro model can be imported from the mindchemistry library, while the embedding and potential energy prediction modules are imported from `src`." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "def build(num_type, configs):\n", "    \"\"\"Build the Potential model.\n", "\n", "    Args:\n", "        num_type (int): number of atom types\n", "        configs: configuration loaded from the YAML file\n", "\n", "    Returns:\n", "        net (Potential): Potential model\n", "    \"\"\"\n", "    literal_hidden_dims = 'hidden_dims'\n", "    literal_activation = 'activation'\n", "    literal_weight_init = 'weight_init'\n", "    literal_uniform = 'uniform'\n", "\n", "    emb = AllegroEmbedding(\n", "        num_type=num_type,\n", "        cutoff=configs['CUTOFF']\n", "    )\n", "\n", "    model = Allegro(\n", "        l_max=configs['L_MAX'],\n", "        irreps_in={\n", "            \"pos\": \"1x1o\",\n", "            \"edge_index\": None,\n", "            \"node_attrs\": f\"{num_type}x0e\",\n", "            \"node_features\": f\"{num_type}x0e\",\n", "            \"edge_embedding\": f\"{configs['NUM_BASIS']}x0e\"\n", "        },\n", "        avg_num_neighbor=configs['AVG_NUM_NEIGHBOR'],\n", "        num_layers=configs['NUM_LAYERS'],\n", "        env_embed_multi=configs['ENV_EMBED_MULTI'],\n", "        two_body_kwargs={\n", "            literal_hidden_dims: 
configs['two_body_latent_mlp_latent_dimensions'],\n", "            literal_activation: 'silu',\n", "            literal_weight_init: literal_uniform\n", "        },\n", "        latent_kwargs={\n", "            literal_hidden_dims: configs['latent_mlp_latent_dimensions'],\n", "            literal_activation: 'silu',\n", "            literal_weight_init: literal_uniform\n", "        },\n", "        env_embed_kwargs={\n", "            literal_hidden_dims: configs['env_embed_mlp_latent_dimensions'],\n", "            literal_activation: None,\n", "            literal_weight_init: literal_uniform\n", "        },\n", "        enable_mix_precision=configs['enable_mix_precision'],\n", "    )\n", "\n", "    net = Potential(\n", "        embedding=emb,\n", "        model=model,\n", "        avg_num_neighbor=configs['AVG_NUM_NEIGHBOR'],\n", "        edge_eng_mlp_latent_dimensions=configs['edge_eng_mlp_latent_dimensions']\n", "    )\n", "\n", "    return net" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initializing model... \n" ] } ], "source": [ "print(\"Initializing model... \")\n", "model = build(num_type, configs)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Loss Function\n", "\n", "Training minimizes the mean squared error (MSE); the mean absolute error (MAE) is reported as the evaluation metric." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "loss_fn = nn.MSELoss()\n", "metric_fn = nn.MAELoss()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Optimizer\n", "\n", "The Adam optimizer is used, and the learning rate update strategy is ReduceLROnPlateau." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "optimizer = optim.Adam(params=model.trainable_params(), lr=learning_rate)\n", "lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=factor, patience=patience)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Model Training\n", "\n", "In this tutorial, we customize the train_step and test_step, and perform model training." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "pycharm": { "name": "#%%\n" }, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initializing train... \n", "seed is: 123\n", "Epoch 1\n", "-------------------------------\n", "loss: 468775552.0000000000000000 [ 0/190]\n", "loss: 468785216.0000000000000000 [ 10/190]\n", "loss: 468774752.0000000000000000 [ 20/190]\n", "loss: 468760512.0000000000000000 [ 30/190]\n", "loss: 468342560.0000000000000000 [ 40/190]\n", "loss: 414531872.0000000000000000 [ 50/190]\n", "loss: 435014.9062500000000000 [ 60/190]\n", "loss: 132964368.0000000000000000 [ 70/190]\n", "loss: 82096352.0000000000000000 [ 80/190]\n", "loss: 12417458.0000000000000000 [ 90/190]\n", "loss: 202487.4687500000000000 [100/190]\n", "loss: 300066.6875000000000000 [110/190]\n", "loss: 468295.9375000000000000 [120/190]\n", "loss: 1230706.0000000000000000 [130/190]\n", "loss: 487508.2812500000000000 [140/190]\n", "loss: 242425.6406250000000000 [150/190]\n", "loss: 841241.0000000000000000 [160/190]\n", "loss: 84912.1328125000000000 [170/190]\n", "loss: 1272812.5000000000000000 [180/190]\n", "train loss: 139695450.3818462193012238, time gap: 101.1313\n", "Test: mse loss: 328262.9555555555853061\n", "Test: mae metric: 463.6971659342447083\n", "lr: 0.0020000001\n", "\n", "Epoch 2\n", "-------------------------------\n", "loss: 469823.8750000000000000 [ 0/190]\n", "loss: 650901.1875000000000000 [ 10/190]\n", "loss: 
183339.9687500000000000 [ 20/190]\n", "loss: 256283.4218750000000000 [ 30/190]\n", "loss: 335927.1875000000000000 [ 40/190]\n", "loss: 913293.2500000000000000 [ 50/190]\n", "loss: 1257833.7500000000000000 [ 60/190]\n", "loss: 630779.2500000000000000 [ 70/190]\n", "loss: 1652336.1250000000000000 [ 80/190]\n", "loss: 155349.2500000000000000 [ 90/190]\n", "loss: 183506.0468750000000000 [100/190]\n", "loss: 322167.0000000000000000 [110/190]\n", "loss: 738248.0000000000000000 [120/190]\n", "loss: 628022.9375000000000000 [130/190]\n", "loss: 693525.0000000000000000 [140/190]\n", "loss: 237971.7812500000000000 [150/190]\n", "loss: 728099.2500000000000000 [160/190]\n", "loss: 50060.8867187500000000 [170/190]\n", "loss: 1229544.0000000000000000 [180/190]\n", "train loss: 441576.3795847039436921, time gap: 26.0444\n", "Test: mse loss: 366946.9187499999534339\n", "Test: mae metric: 493.0364359537759924\n", "lr: 0.0020000001\n", "\n", "Epoch 3\n", "-------------------------------\n", "loss: 522800.6250000000000000 [ 0/190]\n", "loss: 751694.5000000000000000 [ 10/190]\n", "loss: 187226.3750000000000000 [ 20/190]\n", "loss: 240447.8906250000000000 [ 30/190]\n", "loss: 302177.7812500000000000 [ 40/190]\n", "loss: 834946.6875000000000000 [ 50/190]\n", "loss: 1170818.2500000000000000 [ 60/190]\n", "loss: 596591.8750000000000000 [ 70/190]\n", "loss: 1559648.0000000000000000 [ 80/190]\n", "loss: 144896.1718750000000000 [ 90/190]\n", "loss: 171495.7656250000000000 [100/190]\n", "loss: 302823.1250000000000000 [110/190]\n", "loss: 681209.3750000000000000 [120/190]\n", "loss: 594635.8750000000000000 [130/190]\n", "loss: 648062.7500000000000000 [140/190]\n", "loss: 221139.5312500000000000 [150/190]\n", "loss: 684927.0000000000000000 [160/190]\n", "loss: 57762.1718750000000000 [170/190]\n", "loss: 1197153.3750000000000000 [180/190]\n", "train loss: 414760.5352384868310764, time gap: 25.4267\n", "Test: mse loss: 337391.1312500000349246\n", "Test: mae metric: 473.0654032389323334\n", "lr: 
0.0020000001\n", "\n", "Epoch 4\n", "-------------------------------\n", "loss: 479449.6250000000000000 [ 0/190]\n", "loss: 658094.1875000000000000 [ 10/190]\n", "loss: 167701.1875000000000000 [ 20/190]\n", "loss: 218166.0000000000000000 [ 30/190]\n", "loss: 274594.4375000000000000 [ 40/190]\n", "loss: 752581.0000000000000000 [ 50/190]\n", "loss: 1039454.4375000000000000 [ 60/190]\n", "loss: 581997.6250000000000000 [ 70/190]\n", "loss: 1481623.0000000000000000 [ 80/190]\n", "loss: 131388.5000000000000000 [ 90/190]\n", "loss: 159510.3593750000000000 [100/190]\n", "loss: 284669.9687500000000000 [110/190]\n", "loss: 635406.0625000000000000 [120/190]\n", "loss: 547961.0000000000000000 [130/190]\n", "loss: 594799.6875000000000000 [140/190]\n", "loss: 206942.5937500000000000 [150/190]\n", "loss: 635930.1250000000000000 [160/190]\n", "loss: 53651.5429687500000000 [170/190]\n", "loss: 1120740.7500000000000000 [180/190]\n", "train loss: 384702.7380550986854360, time gap: 25.3646\n", "Test: mse loss: 312383.8472222221898846\n", "Test: mae metric: 455.6222493489584053\n", "lr: 0.0020000001\n", "\n", "Epoch 5\n", "-------------------------------\n", "loss: 442066.5625000000000000 [ 0/190]\n", "loss: 610366.8750000000000000 [ 10/190]\n", "loss: 154912.0781250000000000 [ 20/190]\n", "loss: 199916.3125000000000000 [ 30/190]\n", "loss: 253701.6875000000000000 [ 40/190]\n", "loss: 695447.4375000000000000 [ 50/190]\n", "loss: 973856.6875000000000000 [ 60/190]\n", "loss: 529174.5625000000000000 [ 70/190]\n", "loss: 1359184.8750000000000000 [ 80/190]\n", "loss: 120610.0546875000000000 [ 90/190]\n", "loss: 145533.5312500000000000 [100/190]\n", "loss: 253629.2500000000000000 [110/190]\n", "loss: 602776.2500000000000000 [120/190]\n", "loss: 479350.7187500000000000 [130/190]\n", "loss: 522066.7812500000000000 [140/190]\n", "loss: 197747.9687500000000000 [150/190]\n", "loss: 585378.6875000000000000 [160/190]\n", "loss: 39960.7265625000000000 [170/190]\n", "loss: 1010730.2500000000000000 
[180/190]\n", "train loss: 355614.9552425986621529, time gap: 26.2478\n", "Test: mse loss: 291521.1986111110891216\n", "Test: mae metric: 440.5232421874999886\n", "lr: 0.0020000001\n", "\n", "Training Done!\n" ] } ], "source": [ "# 1. Define forward function\n", "def forward(x, pos, edge_index, batch, batch_size, energy):\n", "    pred = model(x, pos, edge_index, batch, batch_size)\n", "    loss = loss_fn(pred, energy)\n", "    if batch_size != 0:\n", "        square_atom_num = (x.shape[0] / batch_size) ** 2\n", "    else:\n", "        raise ValueError(\"batch_size should not be zero\")\n", "    if square_atom_num != 0:\n", "        loss = loss / square_atom_num\n", "    else:\n", "        raise ValueError(\"square_atom_num should not be zero\")\n", "    return loss\n", "\n", "# 2. Get gradient function\n", "backward = ms.value_and_grad(forward, None, optimizer.parameters)\n", "\n", "# 3. Define function of one-step training and validation\n", "@ms.jit\n", "def train_step(x, pos, edge_index, batch, batch_size, energy):\n", "    loss_, grads_ = backward(x, pos, edge_index, batch, batch_size, energy)\n", "    optimizer(grads_)\n", "    return loss_\n", "\n", "@ms.jit\n", "def test_step(x, pos, edge_index, batch, batch_size):\n", "    return model(x, pos, edge_index, batch, batch_size)\n", "\n", "def _unpack(data):\n", "    return (data['x'], data['pos']), data['energy']\n", "\n", "def train_epoch(model, trainset, edge_index, batch, batch_size, loss_train: list):\n", "    size = trainset.get_dataset_size()\n", "    model.set_train()\n", "    total_train_loss = 0\n", "    loss_train_epoch = []\n", "    ti = time.time()\n", "    for current, data_dict in enumerate(trainset.create_dict_iterator()):\n", "        inputs, label = _unpack(data_dict)\n", "        loss = train_step(inputs[0], inputs[1], edge_index, batch, batch_size, label)\n", "        # AtomWise\n", "        loss = loss.asnumpy()\n", "        loss_train_epoch.append(loss)\n", "        if current % 10 == 0:\n", "            # pylint: disable=W1203\n", "            print(f\"loss: {loss:.16f} [{current:>3d}/{size:>3d}]\")\n", "        total_train_loss += 
loss\n", "\n", "    loss_train.append(loss_train_epoch)\n", "    if size != 0:\n", "        loss_train_avg = total_train_loss / size\n", "    else:\n", "        raise ValueError(\"size should not be zero\")\n", "    t_now = time.time()\n", "    print('train loss: %.16f, time gap: %.4f' %(loss_train_avg, (t_now - ti)))\n", "\n", "def test(model, dataset, edge_index, batch, batch_size, loss_fn, loss_eval: list, metric_fn, metric_list: list):\n", "    num_batches = dataset.get_dataset_size()\n", "    model.set_train(False)\n", "    test_loss = 0\n", "    metric = 0\n", "    for _, data_dict in enumerate(dataset.create_dict_iterator()):\n", "        inputs, label = _unpack(data_dict)\n", "        if batch_size != 0:\n", "            atom_num = inputs[0].shape[0] / batch_size\n", "        else:\n", "            raise ValueError(\"batch_size should not be zero\")\n", "        square_atom_num = atom_num ** 2\n", "        pred = test_step(inputs[0], inputs[1], edge_index, batch, batch_size)\n", "        if square_atom_num != 0:\n", "            test_loss += loss_fn(pred, label).asnumpy() / square_atom_num\n", "        else:\n", "            raise ValueError(\"square_atom_num should not be zero\")\n", "        if atom_num != 0:\n", "            metric += metric_fn(pred, label).asnumpy() / atom_num\n", "        else:\n", "            raise ValueError(\"atom_num should not be zero\")\n", "\n", "    test_loss /= num_batches\n", "    metric /= num_batches\n", "    # AtomWise\n", "    loss_eval.append(test_loss)\n", "    metric_list.append(metric)\n", "    print(\"Test: mse loss: %.16f\" %test_loss)\n", "    print(\"Test: mae metric: %.16f\" %metric)\n", "    return test_loss\n", "\n", "# == Training ==\n", "if is_profiling:\n", "    print(\"Initializing profiler... \")\n", "    profiler = ms.Profiler(output_path=\"dump_output\" + \"/profiler_data\", profile_memory=True)\n", "\n", "print(\"Initializing train... 
\")\n", "print(\"seed is: %d\" %ms.get_seed())\n", "loss_eval = []\n", "loss_train = []\n", "metric_list = []\n", "for t in range(n_epoch):\n", "    print(\"Epoch %d\\n-------------------------------\" %(t + 1))\n", "    train_epoch(model, ds_train, edge_index, batch, batch_size, loss_train)\n", "    test_loss = test(\n", "        model, ds_test, eval_edge_index, eval_batch, batch_size_eval, loss_fn, loss_eval, metric_fn, metric_list\n", "    )\n", "\n", "    if lrdecay:\n", "        lr_scheduler.step(test_loss)\n", "        last_lr = optimizer.param_groups[0].get('lr').value()\n", "        print(\"lr: %.10f\\n\" %last_lr)\n", "\n", "    # A checkpoint is saved every 50 epochs; with n_epoch = 5 this run never writes one.\n", "    if (t + 1) % 50 == 0:\n", "        ms.save_checkpoint(model, \"./model.ckpt\")\n", "\n", "if is_profiling:\n", "    profiler.analyse()\n", "\n", "print(\"Training Done!\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model Prediction\n", "\n", "Define a custom `pred` function that runs the model on the evaluation dataset and returns the predictions together with the loss and metric." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "def pred(configs, dtype=ms.float32):\n", "    \"\"\"Pred the model on the eval dataset.\"\"\"\n", "    batch_size_eval = configs['BATCH_SIZE_EVAL']\n", "    n_eval = configs['N_EVAL']\n", "\n", "    print(\"Loading data... \")\n", "    data_path = configs['DATA_PATH']\n", "    _, _, _, ds_test, eval_edge_index, eval_batch, num_type = create_test_dataset(\n", "        config={\n", "            \"path\": data_path,\n", "            \"batch_size_eval\": batch_size_eval,\n", "            \"n_val\": n_eval,\n", "        },\n", "        dtype=dtype,\n", "        pred_force=False\n", "    )\n", "\n", "    # Define model\n", "    print(\"Initializing model... \")\n", "    model = build(num_type, configs)\n", "\n", "    # load checkpoint (./model.ckpt must already exist)\n", "    ckpt_file = './model.ckpt'\n", "    ms.load_checkpoint(ckpt_file, model)\n", "\n", "    # Instantiate loss function and metric function\n", "    loss_fn = nn.MSELoss()\n", "    metric_fn = nn.MAELoss()\n", "\n", "    # == Evaluation ==\n", "    print(\"Initializing Evaluation... 
\")\n", "    print(\"seed is: %d\" %ms.get_seed())\n", "\n", "    pred_list, test_loss, metric = evaluation(\n", "        model, ds_test, eval_edge_index, eval_batch, batch_size_eval, loss_fn, metric_fn\n", "    )\n", "\n", "    print(\"prediction saved\")\n", "    print(\"Test: mse loss: %.16f\" %test_loss)\n", "    print(\"Test: mae metric: %.16f\" %metric)\n", "\n", "    print(\"Predict Done!\")\n", "\n", "    return pred_list, test_loss, metric\n", "\n", "\n", "def evaluation(model, dataset, edge_index, batch, batch_size, loss_fn, metric_fn):\n", "    \"\"\"Evaluate the model and return the predictions, MSE loss, and MAE metric.\"\"\"\n", "    num_batches = dataset.get_dataset_size()\n", "    model.set_train(False)\n", "    test_loss = 0\n", "    metric = 0\n", "    pred_list = []\n", "    for _, data_dict in enumerate(dataset.create_dict_iterator()):\n", "        inputs, label = _unpack(data_dict)\n", "        if batch_size != 0:\n", "            atom_num = inputs[0].shape[0] / batch_size\n", "        else:\n", "            raise ValueError(\"batch_size should not be zero\")\n", "        square_atom_num = atom_num ** 2\n", "        prediction = model(inputs[0], inputs[1], edge_index, batch, batch_size)\n", "        pred_list.append(prediction.asnumpy())\n", "        if square_atom_num != 0:\n", "            test_loss += loss_fn(prediction, label).asnumpy() / square_atom_num\n", "        else:\n", "            raise ValueError(\"square_atom_num should not be zero\")\n", "        if atom_num != 0:\n", "            metric += metric_fn(prediction, label).asnumpy() / atom_num\n", "        else:\n", "            raise ValueError(\"atom_num should not be zero\")\n", "\n", "    test_loss /= num_batches\n", "    metric /= num_batches\n", "\n", "    return pred_list, test_loss, metric" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading data... \n", "Initializing model... \n", "Initializing Evaluation... 
\n", "seed is: 123\n", "prediction saved\n", "Test: mse loss: 901.1434895833332348\n", "Test: mae metric: 29.0822919209798201\n", "Predict Done!\n" ] }, { "data": { "text/plain": [ "([array([[-259531.56],\n", " [-259377.28],\n", " [-259534.83],\n", " [-259243.62],\n", " [-259541.62]], dtype=float32),\n", " array([[-259516.4 ],\n", " [-259519.81],\n", " [-259545.69],\n", " [-259428.45],\n", " [-259527.28]], dtype=float32),\n", " array([[-259508.94],\n", " [-259521.22],\n", " [-259533.28],\n", " [-259465.56],\n", " [-259523.88]], dtype=float32),\n", " array([[-259533.56],\n", " [-259303.9 ],\n", " [-259509.53],\n", " [-259369.22],\n", " [-259514.4 ]], dtype=float32),\n", " array([[-259368.25],\n", " [-259487.45],\n", " [-259545.94],\n", " [-259379.47],\n", " [-259494.19]], dtype=float32),\n", " array([[-259533.64],\n", " [-259453. ],\n", " [-259542.69],\n", " [-259451.9 ],\n", " [-259213.11]], dtype=float32),\n", " array([[-259562.5 ],\n", " [-259531.6 ],\n", " [-259526.5 ],\n", " [-259530.3 ],\n", " [-259389.12]], dtype=float32),\n", " array([[-259515.03],\n", " [-259530.69],\n", " [-259476.9 ],\n", " [-259267.77],\n", " [-259535.11]], dtype=float32),\n", " array([[-259548.77],\n", " [-259530.8 ],\n", " [-259401.7 ],\n", " [-259542.12],\n", " [-259419.86]], dtype=float32),\n", " array([[-259386.81],\n", " [-259291.75],\n", " [-259419.61],\n", " [-259488.25],\n", " [-259334.34]], dtype=float32)],\n", " 901.1434895833332,\n", " 29.08229192097982)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pred(configs)" ] } ], "metadata": { "interpreter": { "hash": "58e6709d8bbc21fe79376972d6b15d6c06efb7b1d41f6d4b946e12f7486761ac" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": 
"ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 4 }