{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Implementing Simple Linear Function Fitting\n", "\n", "Author: [杨奕](https://github.com/helloyesterday)    Editor: [吕明赋](https://gitee.com/lvmingfu)\n", "\n", "`Linux` `Windows` `Ascend` `GPU` `CPU` `Whole Process` `Beginner` `Intermediate` `Expert`\n", "\n", "[![](https://gitee.com/mindspore/docs/raw/r1.2/tutorials/training/source_zh_cn/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.2/tutorials/training/source_zh_cn/quick_start/linear_regression.ipynb) [![](https://gitee.com/mindspore/docs/raw/r1.2/tutorials/training/source_zh_cn/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.2/mindspore_linear_regression.ipynb) [![](https://gitee.com/mindspore/docs/raw/r1.2/tutorials/training/source_zh_cn/_static/logo_modelarts.png)](https://console.huaweicloud.com/modelarts/?region=cn-north-4#/notebook/loading?share-url-b64=aHR0cHM6Ly9vYnMuZHVhbHN0YWNrLmNuLW5vcnRoLTQubXlodWF3ZWljbG91ZC5jb20vbWluZHNwb3JlLXdlYnNpdGUvbm90ZWJvb2svbW9kZWxhcnRzL21pbmRzcG9yZV9saW5lYXJfcmVncmVzc2lvbi5pcHluYg==&image_id=65f636a0-56cf-49df-b941-7d2a07ba8c8c) [![](https://gitee.com/mindspore/docs/raw/r1.2/tutorials/training/source_zh_cn/_static/logo_online_experience.png)](https://ascend.huawei.com/zh/#/college/onlineExperiment/codeLabMindSpore/linearRegression)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Overview\n", "\n", "Regression algorithms use a set of attributes to predict a value, and the predicted value is continuous. Examples include predicting a house's price from features such as its floor area and number of bedrooms, or predicting upcoming temperatures from the past week's temperature changes and satellite cloud images. If a house actually sells for 5 million yuan and the regression predicts 4.99 million yuan, the prediction is considered good. Common regression methods in machine learning include linear regression, polynomial regression, and logistic regression (which, despite its name, is mainly used for classification). This example introduces the linear regression algorithm and walks through training a linear regression model with MindSpore.\n", "\n", "The overall workflow is as follows:\n", "\n", "1. Generate the dataset.\n", "2. Define the training network.\n", "3. Define the forward and backward propagation networks and connect them.\n", "4. Prepare to visualize the fitting process.\n", "5. Run the training." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> This document applies to CPU, GPU, and Ascend environments. Source code for this example: ." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Environment Setup\n", "\n", "Configure the MindSpore runtime settings." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2021-01-04T07:04:52.617310Z", "start_time": "2021-01-04T07:04:51.919345Z" } }, "outputs": [], "source": [ "from mindspore import context\n", "\n", "context.set_context(mode=context.GRAPH_MODE, device_target=\"CPU\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`GRAPH_MODE`: graph mode, in which the network is compiled into a static graph before it runs.\n", "\n", "`device_target`: sets the training hardware to CPU. Set it to `\"GPU\"` or `\"Ascend\"` to run on those backends.\n", "\n", "> The code in this tutorial depends on the third-party package `matplotlib`, which can be installed with `pip install matplotlib`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generating the Dataset\n", "\n", "### Defining the Dataset Generation Function\n", "\n", "`get_data` generates both the training and the test datasets. Because we are fitting linear data, assume the target function is $f(x)=2x+3$. The training samples should then be scattered randomly around this line, so they are generated as $f(x)=2x+3+noise$, where $x$ is drawn uniformly from $[-10, 10)$ and `noise` is a random value drawn from the standard normal distribution." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2021-01-04T07:04:52.623357Z", "start_time": "2021-01-04T07:04:52.618320Z" } }, "outputs": [], "source": [ "import numpy as np\n", "\n", "def get_data(num, w=2.0, b=3.0):\n", "    # Yield `num` samples (x, y) scattered around the line y = w*x + b\n", "    for _ in range(num):\n", "        x = np.random.uniform(-10.0, 10.0)\n", "        noise = np.random.normal(0, 1)  # standard normal noise\n", "        y = x * w + b + noise\n", "        # Each sample is a pair of float32 arrays of shape (1,)\n", "        yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Use `get_data` to generate 50 test samples and visualize them." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2021-01-04T07:04:52.988318Z", "start_time": "2021-01-04T07:04:52.624363Z" } }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXkAAAEICAYAAAC6fYRZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3de3yU5Zn/8c/FwRNYxSVBFBE8tFu1Fdsg1EOrQpWqLZWtVi3Iri7xUFrRdn/i4VdPuxC79dDtzxXSFnRbqIcqeAIFFKwBFYO1cgj+RAUMYBJqBTyFhrn2j+eZMElmkgwzTyYz+b5fr7wyz2HmvnkyXLlzP9dct7k7IiJSmLrlugMiIhIdBXkRkQKmIC8iUsAU5EVECpiCvIhIAVOQFxEpYAryIs2Y2RIz+9d2nnuamVVH3SeRPaUgL3nLzNab2adm9lHC1//Ldb9SMbN/NrOKXPdDupYeue6ASIa+7e6Lct0Jkc5KI3kpOGa2t5l9aGbHJewrCkf9xWbWx8yeMrM6M/tb+HhAO197XzO7P3zeGmBos+OTzextM9thZmvM7Lxw/xeBacDXwr84Pgz3n2Nmfzaz7Wb2npndkq3rIAIK8lKA3L0eeAy4KGH3BcAL7l5L8L6fCRwODAQ+Bdo7zXMzcGT4dRYwvtnxt4FTgQOAW4Hfm1l/d68CrgBecvfe7n5geP7HwCXAgcA5wJVm9t00/rkirVKQl3w3Nxy1x78mhPtn0zTIXxzuw93/6u6Puvsn7r4D+A/gG+1s7wLgP9z9A3d/D/ivxIPu/oi7b3b3mLs/BLwFnJjqxdx9ibuvDM9/A/hDGn0RaZPm5CXffTfFnPzzwL5mNgx4HxgCzAEws/2Au4FRQJ/w/P3NrLu772qjvUOA9xK2NyQeNLNLgGuBQeGu3kDfVC8W9q8MOA7YC9gbeKSNPoi0m0byUpDcPQY8TDCavxh4Khy1A/wE+AIwzN0/B3w93G/teOktwGEJ2wPjD8zscODXwETgH8IpmVUJr5us5Ots4AngMHc/gGDevj39EGkXBXkpZLOB7wM/CB/H7U8wD/+hmR1EMM/eXg8D14c3bwcAP0o41osgkNcBmNm/EIzQ42qAAWa2V7O+fODun5nZiQS/kESyRkFe8t2TzfLk58QPuPsrBDc2DwHmJzznHmBfYCvwMvBMGu3dSjBF8y6wAPhdQntrgDuBlwgC+peApQnPfR5YDbxvZlvDfVcBt5nZDuBnBL9ERLLGtGiIiEjh0kheRKSAKciLiBQwBXkRkQKmIC8iUsA61Yeh+vbt64MGDcp1N0RE8sqKFSu2untRsmMZB3kzOwz4H+BgIAaUu/svw0JLEwhzhoEb3H1ea681aNAgKisrM+2SiEiXYmYbUh3Lxki+AfiJu79mZvsDK8xsYXjsbnf/RRbaEBGRPZBxkHf3LQQf9cbdd5hZFXBopq8rIiKZy+qNVzMbBJwAvBLummhmb5jZDDPrk/KJIiISiawFeTPrDTwKTHL37cB9BDW3hxCM9O9M8bxSM6s0s8q6urpkp4iIyB7KSpA3s54EAX6Wuz8G4O417r4rrAb4a1LU1Hb3cncvcfeSoqKkN4dFRGQPZRzkzcyA3wJV7n5Xwv7+CaedR1ByVUREOlA2smtOBsYBK83s9XDfDcBFZjaEoPTqeuDyLLQlIpLfYjGoq4PiYrDolw7IRnZNBckXOWg1J15EpMuIxaCmBtzhwgvhpZfgpJNg8WLoFm3hgU71iVcRkYITi8Hpp8OLLwZBPm7ZMqithYMPjrR51a4REYlSXV0Q0Juv3dHQABdcEPwSiJCCvIhIlIqLg6mZ+Pz7sGG7p2heein4JRAhBXkRkSiZBXPvmzbBli3BqP6UU6BHDzjpJBr6HsQNz93AypqVkTSvOXkRkah16wb9E7LKFy+Gujoe2foCF/x7sK77xzs/5pff+mXWm1aQFxHpYNt27uDAabtvuI48YiT3jLonkrY
0XSMi0oGmvDiFA+84sHF7zVVrWDhuIRZRzrxG8iIiHWDDhxsY9MtBjdvXDr+WO89KWtIrqzSSFxGJkLvzg8d+0CTA1744nDu/+Z8d0r5G8iIiEXml+hWG/3Z443b5092Y8GoMelQGqZP9+kXeB43kRUTSlVimIImGWAPH/vexjQG+uFcxn97wCRP23Z06SXFxh3RVQV5EJB3xMgUDBsBpp7X4xOojqx+h5+09WVO3BoCF4xZS89Ma9um5b5A6WV0NS5Z0SHEy0HSNiEh64mUKGhqC7+G0y/b67RxQdkDjaSOPGMmCsQuaZs1069YhUzSJNJIXka6tjamXFuJlChKmXaa8OKVJgE+ZFpluW1mgkbyIdF3xqZdly9pf+jdepqCujg17fcqg23af32pa5J60lQUK8iLSdaWYemmLmzF26bXMXjm7cV/tT2sp6tXKEqZ72FamsrH832FmttjMqsxstZldHe4/yMwWmtlb4fc+mXdXRCSLkky9tGX5puV0u61bY4AvP7ccv9lbD/B72FY2mGc4NxSu5drf3V8zs/2BFcB3gX8GPnD3MjObDPRx9+tae62SkhKvrKzMqD8iImlp53J8DbEGhkwbwuq61UCQFrlh0gb26bFP1ttKl5mtcPeSZMcyHsm7+xZ3fy18vAOoAg4FRgMPhKc9QBD4RUQ6l3jGSytBN54WGQ/wjWmR6QT4draVbVmdkzezQcAJwCtAP3ffAsEvAjNL+reJmZUCpQADBw7MZndERDLSrrTITi5rt3bNrDfwKDDJ3be393nuXu7uJe5eUlTUxpyWiEhHiMWYMv+GJgF+9VWrI60WGZWsjOTNrCdBgJ/l7o+Fu2vMrH84iu8P1GajLRGRKG344F0G/eqIxu1rh1/DnWfdlcMeZSYb2TUG/BaocvfEK/EEMD58PB54PNO2RESi0lgtMiHA1/4c7vzyv+WwV5nLxkj+ZGAcsNLMXg/33QCUAQ+b2WXARuD8LLQlIpJ1yzctZ9hvhjVulz8JE1YQ3CDNs+mZ5jIO8u5eAaS6CiMyfX0Rkag0xBo4YfoJrKpdBYRpkVevZ5/nR0GP8JOpHVxrJtv0iVcR6ZIeWf0IF/zxgsbtheMWMvKIkcFGWLYg2/nsuaAgLyKFK8mHjzprtcioqAqliBSmJHXfm1eLzNe0yHRoJC8ihSmhINiG1UsZdHv3xkPXDL+Gu/I4LTIdCvIiUpj69sWHljD20FeYfdyuxt01P62huFfHFAfrDBTkRaTwxGIsH13CsLNeb9xVfm45E746IYedyg0FeRHJX0lurDbEGjjh3i+xauhaAIo/gg2T1rPPoYfnsqc5oxuvIpKfktxY/eOaP9Lz9p6s+iAI8AtmdaNm+dfZ55CuW/xQI3kRyU8JN1a3Vy7lgIQbqyOPGMmCi5/BrthaELnumVCQF5H8FK60NMUquPH03TdWV1+1mmOKjgk2CiTXPRMK8iKSlzZs28igM/7UuN2V0iLToSAvInnF3Rk7Z2yTRbS7WlpkOhTkRSRvNK8WOf3c6ZR+tTSHPer8FORFpNNLWi0y3UW0uyilUIpIp9aYFhkG+AVjF+zZItpdVLaW/5sBnAvUuvtx4b5bgAlAXXjaDe4+LxvtiUjha7VaZJIPQUly2RrJ3w+MSrL/bncfEn4pwItIu0x9cWrqapFJPgQlqWVlJO/ufzKzQdl4LRHpujZ8uIFBvxzUuJ00LTLhQ1AsWxZsKx8+pajn5Cea2RtmNsPM+iQ7wcxKzazSzCrr6uqSnSIiXcDYx8Y2CfA1P61JnvcefgiKHj2C78VKnWxNlEH+PuBIYAiwBbgz2UnuXu7uJe5eUlRUFGF3RKQzWr5pOXarMWvlLCBIi/SbvWneeywGNTXgHszBL14M1dWwZInm5NsQWQqlu9fEH5vZr4GnompLRPJPu9Mi43Pwy8KFtRcvLqjl+aIW2UjezPonbJ4HrIqqLRHp5BJH4qSZFplsDl7aLVsplH8ATgP6mlk1cDN
wmpkNARxYD1yejbZEJM8kjMS3n3oiB3xjWeOhpItoNxefg4+P5DUHnxbz8DdrZ1BSUuKVlZW57oaIZFNNDQwYwNThDdwwcvfuJtUi28p7V158q8xshbuXJDumT7yKSKQ27PUpdtPuAH/NsEn4zd40wLeV9x6fg1eAT5tq14hIZMY+NrYxawag5ifvU9y72Q1T5b1HSiN5Ecm6FmmR50wL0iKbB3hQ3nvENJIXkaxpiDXwlelfYWXtSgCKdvZk450x9nl+NiyeEEy7NBfPe9eceyQ0kheRrIinRcYD/IJzHqL2584+9bvaTn3UnHtkNJIXkYykrBYJcNK9Sn3MMQV5EdljU1+cyg3P39C43SQtEjQN0wkoyItI2ppXi5w0bBJ3j7q75YkqP5BzCvIikpYWaZFaRLtTU5AXkXbRItr5SUFeRFrVIi1yvyI2XrNRa6zmCaVQikhKLdIixy6g9t9qFeDziEbyItJCq4toS17RSF5EmiirKEu9iLbkHY3kRQRiMTa+82cOn7W7Wm3KtEjJK9laNGQGcC5Q6+7HhfsOAh4CBhEsGnKBu/8tG+2JSBbFYoz9YX9mHVzbuEtpkYUjW9M19wOjmu2bDDzn7kcDz4XbItKJLN+0HLu9e2OAn/50N/yK9xXgC0hWgry7/wn4oNnu0cAD4eMHgO9moy0RaUOz9VSTaYg1cPy04xvz3ot29uTTqd0p3fcU1ZgpMFHeeO3n7lsAwu9J3zlmVmpmlWZWWacFekUy09AAp5zS6ipL8bTIN2reAGDBrG7Uvjicfd59D5YsUY2ZApPzG6/uXg6UQ7DGa467I5K/YjE49VR4+eVgu9kqS83TIkcceioLrlxGt4Zd8O5LQZ0ZBfiCE+VIvsbM+gOE32vbOF9EMlFXB6++unt76NDGqZdkaZGLLnuBbiedrBWZClyUI/kngPFAWfj98QjbEpHiYjj5ZFi6NAjwFRVs3P4eh99zeOMpLdIiVQq44GUrhfIPwGlAXzOrBm4mCO4Pm9llwEbg/Gy0JSIpNFtGb9zcS/j9G79vPJw0LVKlgAteVoK8u1+U4tCIbLy+iCQRi7UchXfrxvKGDQy77eDG01QtsmvL+Y1XEdkDsRicfvrupfUWL6aBGF8t/2pj1oyqRQqodo1IfqqrCwJ8QwMsW8ajr8xsmhbZVrXIduTSS2FQkBfJR8XFcNJJbN+vO3ZTA99b8K8AjBg8gl0/28U3j/xm6ufG/wpoJZdeCoeCvEg+MqPs30dxwP/Z1bhr9VWrWXTJIrpZG/+tm/0VgD6EWNA0Jy+SZzZu29h6WmRbwr8CGufzlR9f0BTkRfLIuMfG8vtMF9Fulmqp/PjCpiAvkgde3fQqJ/7mxMbtaWuP4vJZbwZ57slSKdui/PguQ3PyIp1YvFpkPMD3/Rg++Xe4/I/rg8Ce7k1UZdV0OQryIp3Uo2sebZoW+YNnqXvl6+xLQq2ZdG6iKqumS9J0jUgns6N+B58r+1zj9ojBI1gwbkGQNbN4ZNOpmXRuoib7haApm4KnkbxIJ1JWUdYkwK+6clXTtMj4XHp87j1+E7W6uu1a8PFfCKo62aVoJC/SCWSUFtnem6jKqumSFORFcmzcnHFtV4vMFmXVdDkK8iI50iIt8pxpXF5yeQ57JIVIQV4kKiny1xtiDU2qRfbdry8bJ21k35775qqnUsAiv/FqZuvNbKWZvW5mlVG3J9IppEhXbJkW+Qx1l6xiX5UDloh01Ej+dHff2kFtieRes3TFHZve5XMzjmo8PGLwCBb84Bm6nTGiSU14uinhTbJL7yiRKCSkK5ZdNKBJgG9Mi9z6V1WDlMh1RJB3YIGZrTCzFmuQmVmpmVWaWWWd3uRSKMzYOPcB7KYGrj9yPQBXD7sav9k5tvjY4BzlrUsHMI+4hoWZHeLum82sGFgI/Mjd/5Ts3JKSEq+s1LS95L9L5lzC7974XeN2yrTIPSkuJtKMma1w95JkxyKfk3f3zeH3WjObA5wIJA3
yIvku7bRI5a1LxCIN8mbWC+jm7jvCx2cCt0XZpkguKC1SOquoR/L9gDkW/BnaA5jt7s9E3KZIh3p0zaN875HvNW4/O/ZZzjzyzBz2SGS3SIO8u78DHB9lGyK50mq1SJFOQu9GkT1wR8UdrVeLFOkkVNZAJA3Nq0VePexq7hl1Tw57JNI6BXmRdrrs8cuY8fqMxu1Iq0WKZIn+thRpw6btmzjvofMaA/y0c6bhN7sCvOQFjeRFUoh5jOmV05n83GR27tpJ2Ygyfjzsx0qLlLyiIC+SxJq6NZQ+WcrS95YyYvAIpp87nSMPOjLX3RJJm4K8SIL6hnqmVkxlyotT2H/v/bn/OzO4pP+3sD76VKrkJ83Ji4QqNlZwwvQTuPWFWzn/2POpunI14yfdjx12WJOa8CL5REFeurxtn23jyqeu5NSZp/LJ3z9h3sXzmDVmFsWfmEoBS97TdI10aXOq5jBx/kTe/+h9rh1+Lbeefiu99+odHIyXAo4v6qFSwJKHFOSlS9q0fRMT509k7tq5DOk3hMe/OYOS485sWu7XLFitSaWAJY9puka6lJjHuO/V+zjmv4/hmXXPcMeIMpY/2JuSr5ybfN49XgpYAV7ylEbyUpiSLMaRNC3y771h6U1N591V310KiEbyUnhiMTj9dBgwAE47jfqdn3LLklsYMm0IVVuruH/0/SwctzDIe9cSfFLgNJKXwlNX15gVU/HeUkqnHU/V397i4i9dzN1n3d20HIHm3aXART6SN7NRZvamma0zs8lRtydCcTHbTj2RK79tnDp+F5/Edu5Oi0xWb0bz7lLAol7+rztwL/BNoBp41cyecPc1UbYrXductXOZeO563v/IuObEq7ntjNt3p0WKdDFRT9ecCKwLV4jCzB4ERgMK8pJ1m7Zv4kfzf8SctXM4vt/xzP3+XIYeOjTX3RLJqaiD/KHAewnb1cCwiNuULqZ5tcg7Rt7BNcOvoWf3nrnumkjORR3kk01yepMTzEqBUoCBAwdG3B0pNKoWKdK6qG+8VgOHJWwPADYnnuDu5e5e4u4lRUVFEXdHCkV9Q32TtMiZo2fuTosUkUZRj+RfBY42s8HAJuBC4OKI25QCV7GxgtInS6naWpU8LVJEGkUa5N29wcwmAs8C3YEZ7r46yjalcG37bBuTF01m2oppHH7A4cy7eB7fOvpbue6WSKcW+Yeh3H0eMC/qdqSwJVaLvGbYJKVFirSTPvEqnVqTtMgdvZj7sDH0udfgzP1y3TWRvKDaNdIpJVaLnL9uPmXDbuTVX33G0Pd2aQEPkTRoJC+dTtK0yD5HwPAXtYCHSJoU5KXTaL6I9szRMxl//HjMHWpr4fnnYetWFRITSYOCvHQKiWmRFx13EfeMuidIi4yXDY6P4BcvVoAXSYPm5CWnki2iPfufZu/Oe08oG6y5eJH0aSQvOdMiLfK4H9H70MFNT9Ji2iIZ0UhesiMWg5oacG/z1M07NjPmoTGMeXgMRfsV8fKly7hr6mv0HvyFluusxhf1qK6GJUs0VSOSJgV5yVyz5fZaLIYdP81jTKucxhfv/WKQFjmijFcnvMrQnoNan5LRoh4ie0zTNZK5ZPPmzRbDbrVapKZkRCKjIC+ZayVIp0yLTByVa51VkcgoyEvmUgTppRuXMuHJCS3TIpOJT8lAMN2jgC+SFZqTl+xImDePp0WeMvOU5GmRrWnn/L6ItI9G8pJVTdIih1/Dbaffll61yHbM74tI+2kkL1nRIi3yspe566y70i8HHJ/f79FDN2FFskAjeclIzGOUryjnukXXsXPXTspGlHHt167d80W0dRNWJKsiC/JmdgswAYgnPd8QLiAiBaKqrorSp0qp2FjBiMEjmHbuNI466KjMXzjxJqyIZCTqkfzd7v6LiNuQDtautEgR6RQ0XSNpSSstUkRyLuobrxPN7A0zm2FmfZKdYGalZlZpZpV1qjDYaSWmRX5c/Q7z/tCN2f+1ieJ9++a6ayLSCvN2FJRK+WS
zRcDBSQ7dCLwMbAUcuB3o7+6XtvZ6JSUlXllZucf9kWgkpkVe/aUJ3PaD39D7011BBkx1tebPRXLMzFa4e0myYxlN17j7yHZ24NfAU5m0JR1v847NTJw3MVhEu9/xzP3+XIYeUgJDq1RnRiRPRJld09/dt4Sb5wGrompLsqvNtEilOIrkjShvvP7czIYQTNesBy6PsC3Jkqqa1ZTOvYyK91/hjMFnMP3c6S3TIpXiKJI3Igvy7j4uqteW7KtvqGfqi1OYsuR29v/Mmbn+84y/aQHWvXuuuyYiGVAKpTRNi1xj3DMfiuvfga1bNWIXyXOqXdOFNUmL/PvHzLvoaWbXnkpxverGiBQKjeS7qMS0yEnDJnH7GbcHxcQWj9JNVZECoiDfxWzeVs3Exy9nzrvzdqdFHjp09wm6qSpSUBTku4iYxyivnMZ1T/yYnb6Lsg2DufaGV+jZc+9cd01EIqQg3wUkVos84z1j+pNw1Pb34BcfatQuUuB047WA1TfUc+uSWxkyfQira1cz8zszWPTuKRy1vY0bq7EY1NRABiUvRKRz0Eg+36VY9DpltcjF41u/sRpfYzVetmDx4mCeXkTykv735rMki163SItsvoh2woLbSSVbY1VE8pZG8vmsWUCes/wBJi67qWVaZDria6yqAJlIQVCQz2dhQN78xlImXnwAc569lC/3+3LLtMh0aI1VkYKiIJ/HYjjlv/g+1y36MztjH1P2jQwX0Y5TrrxIwVCQz1NN0iJTVYsUkS5PQT7P1DfUU1ZRxpSKKfTq2UuLaItIqxTk84gW0RaRdGWUQmlm55vZajOLmVlJs2PXm9k6M3vTzM7KrJtdW5tpkSIiKWQ6kl8FjAGmJ+40s2OAC4FjgUOARWb2eXfflWF7XU7KapEiIu2Q6ULeVUCy+eDRwIPuXg+8a2brgBOBlzJprytJXEQ747RIEemyopqTPxR4OWG7OtzXgpmVAqUAAwcOjKg7eSIWI1ZbQ/nGOVz33GR2xv7echFtEZE0tBnkzWwRcHCSQze6++OpnpZkX9JqV+5eDpQDlJSUdN2KWLEYVecMo/SQSioGwhnvwvT3h3LU9f+m2jEissfaDPLuPnIPXrcaOCxhewCweQ9ep0uob6in7NmbmFJSSa+/w8y5MP51sB5/Dj55qg8micgeimq65glgtpndRXDj9WhgeURt5bUmaZFbi7hn9gcUW2/o/pFqx4hIxjIK8mZ2HvAroAh42sxed/ez3H21mT0MrAEagB8qs6apbZ9t4/rnrue+yvsY+NnePD2nG2cX/yOsfSgI7Fu3qnaMiGTMvBMtDFFSUuKVlZW57kbkEtMif3zcv3L72N/S+9Nd0KMHVFdrekZE0mJmK9y9JNkxfeK1A7VIi7zgMYb2OBxKquCllzQ9IyJZpyDfAWIeo3xFOdctuo6du3YydcRUfjLsGnqOPHN33faNG+HggzU9IyJZpSAfsZTVImtqmq7A1K2bAryIZJ0SsCPSYhHt0TNZNG7R7nLA8RWYerSxqLaISAY0ko9Au6pFagUmEekACvJZ1CQt8oCBPH3x05x99Nmpn6AVmEQkYgry6YrFko6+VS1SRDojzcmnIxaD00+HAQPgtNMgFmPzjs2MeWgMYx4eQ9/9+vLyZS9z96i7FeBFpFPQSD4ddXWNGTGxZUspX/ILrlv+H7vTIr/2E1WLFJFORUE+HWFGTNWbSym9qBcVL16nRbRFpFNTkE9D/a6dlN1yOlMqXqbXXt2ZeZYW0RaRzk1Bvp20iLaI5CMF+TaknRYpItKJKMi3Yu7aufxw3g+VFikieUtBPi4h/33zR1u0iLaIFISM8uTN7HwzW21mMTMrSdg/yMw+NbPXw69pmXc1QmH+e2zAoUy7+PN88d4vMn/dfKaOmErlhEoFeBHJW5mO5FcBY4DpSY697e5DMnz9jlFXx9o3lzJh3C4qDl/HGX1PYfqYmUqLFJG8l1GQd/cqIH9SCJOUJKhvqKes6j6mXB6
jVz3MXPN5xv/fF7Bu+jCwiOS/KCPZYDP7s5m9YGanpjrJzErNrNLMKuvq6qLrTZKSBEs3LuWE6Sdwywu38k9f+j5rr1jJPz+4VgFeRApGmyN5M1sEHJzk0I3u/niKp20BBrr7X83sq8BcMzvW3bc3P9Hdy4FyCNZ4bX/X05RQkmDbiqVc/+il3LfmAaVFikhBazPIu/vIdF/U3euB+vDxCjN7G/g8kLtVusOSBHO3VvDD0d15v+p3SosUkYIXSQqlmRUBH7j7LjM7AjgaeCeKttpr80db+NFV/8Bja2N8ud8/Mvfbv1HWjIgUvIyCvJmdB/wKKAKeNrPX3f0s4OvAbWbWAOwCrnD3DzLubbpiMWK1NZS/N5frnpusapEi0uVkml0zB5iTZP+jwKOZvHbGYjHWnjOMCYdUUjEQzhh0BtO/rWqRItK1FGQayc5dO7lt/mSOL6lkdRHMeLIbi86apQAvIl1OwZU1WLpxKaVPlbKmbg0XbS3i7j98QL/jT9ZaqiLSJRXMSH7bZ9u46umrOGXmKXy08yOevvhpZt/7Pv3e3ARLljRZj1VEpKsoiJF85eZKRj84Onm1SI3gRaQLK4ggf0SfIzi26FhVixQRaaYggvxB+x7EgnELct0NEZFOp2Dm5EVEpCUFeRGRAqYgLyJSwBTkRUQKmIK8iEgBU5AXESlgCvIiIgVMQV5EpICZe3Qr7qXLzOqADRm8RF9ga5a6k03qV3rUr/SoX+kpxH4d7u5FyQ50qiCfKTOrdPeSXPejOfUrPepXetSv9HS1fmm6RkSkgCnIi4gUsEIL8uW57kAK6ld61K/0qF/p6VL9Kqg5eRERaarQRvIiIpJAQV5EpIDlVZA3s/PNbLWZxcyspNmx681snZm9aWZnpXj+YDN7xczeMrOHzGyviPr5kJm9Hn6tN7PXU5y33sxWhudVRtGXZu3dYmabEvp2dorzRoXXcZ2ZTe6Afv2nma01szfMbI6ZHZjivMivV1v/djPbO/z5rgvfS4Oi6EeSdg8zs8VmVhX+H7g6yTmnmdm2hJ/vzzqob63+XCzwX+E1e8PMvtIBffpCwnV43cy2m9mkZud0yPUysxlmVmtmqxL2HWRmC8NYtNsmmtMAAATRSURBVNDM+qR47vjwnLfMbPwedcDd8+YL+CLwBWAJUJKw/xjgL8DewGDgbaB7kuc/DFwYPp4GXNkBfb4T+FmKY+uBvh14/W4BftrGOd3D63cEsFd4XY+JuF9nAj3Cx3cAd+TierXn3w5cBUwLH18IPNRBP7v+wFfCx/sD/z9J304Dnuqo91N7fy7A2cB8wIDhwCsd3L/uwPsEHxjq8OsFfB34CrAqYd/Pgcnh48nJ3vPAQcA74fc+4eM+6bafVyN5d69y9zeTHBoNPOju9e7+LrAOODHxBDMz4Azgj+GuB4DvRtnfsM0LgD9E2U6WnQisc/d33H0n8CDB9Y2Muy9w94Zw82VgQJTttaI9//bRBO8dCN5LI8Kfc6TcfYu7vxY+3gFUAYdG3W6WjAb+xwMvAweaWf8ObH8E8La7Z/Jp+j3m7n8CPmi2O/F9lCoWnQUsdPcP3P1vwEJgVLrt51WQb8WhwHsJ29W0/A/wD8CHCcEk2TnZdipQ4+5vpTjuwAIzW2FmpRH3JW5i+CfzjBR/IrbnWkbpUoJRXzJRX6/2/NsbzwnfS9sI3lsdJpwiOgF4Jcnhr5nZX8xsvpkd20Fdauvnkuv31IWkHmjl4noB9HP3LRD8AgeKk5yTlevW6RbyNrNFwMFJDt3o7o+nelqSfc1zQ9tzTru1s58X0foo/mR332xmxcBCM1sb/tbfY631C7gPuJ3g3307wVTSpc1fIslzM86zbc/1MrMbgQZgVoqXyfr1at7NJPsifR+ly8x6A48Ck9x9e7PDrxFMSXwU3m+ZCxzdAd1q6+eSs2sW3nf7DnB9ksO5ul7tlZXr1umCvLuP3IOnVQOHJWwPADY3O2crwZ+JPcIRWLJz2q2
tfppZD2AM8NVWXmNz+L3WzOYQTBdkFLTae/3M7NfAU0kOtedaZr1f4U2lc4ERHk5IJnmNrF+vZtrzb4+fUx3+jA+g5Z/ikTCzngQBfpa7P9b8eGLQd/d5ZvbfZtbX3SMtxtWOn0sk76l2+hbwmrvXND+Qq+sVqjGz/u6+JZy6qk1yTjXBfYO4AQT3I9NSKNM1TwAXhpkPgwl+Gy9PPCEMHIuB74W7xgOp/jLIhpHAWnevTnbQzHqZ2f7xxwQ3H1clOzdbms2DnpeivVeBoy3IRNqL4E/dJyLu1yjgOuA77v5JinM64nq159/+BMF7B4L30vOpfillUzjv/1ugyt3vSnHOwfH7A2Z2IsH/779G3K/2/FyeAC4Js2yGA9viUxUdIOVf07m4XgkS30epYtGzwJlm1iecWj0z3JeeqO8sZ/OLIDBVA/VADfBswrEbCTIj3gS+lbB/HnBI+PgIguC/DngE2DvCvt4PXNFs3yHAvIS+/CX8Wk0wbRH19fsdsBJ4I3yT9W/er3D7bILsjbc7qF/rCOYeXw+/pjXvV0ddr2T/duA2gl9AAPuE75114XvpiKivT9juKQR/qr+RcJ3OBq6Iv8+AieG1+QvBDeyTOqBfSX8uzfplwL3hNV1JQmZcxH3bjyBoH5Cwr8OvF8EvmS3A38P4dRnBfZzngLfC7weF55YAv0l47qXhe20d8C970r7KGoiIFLBCma4REZEkFORFRAqYgryISAFTkBcRKWAK8iIiBUxBXkSkgCnIi4gUsP8FUy7XCPXim48AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "eval_data = list(get_data(50))\n", "# Sample the target line f(x) = 2x + 3 on [-10, 10) with step 0.1\n", "x_target_label = np.arange(-10, 10, 0.1)\n", "y_target_label = x_target_label * 2 + 3\n", "x_eval_label, y_eval_label = zip(*eval_data)\n", "\n", "plt.scatter(x_eval_label, y_eval_label, color=\"red\", s=5)\n", "plt.plot(x_target_label, y_target_label, color=\"green\")\n", "plt.title(\"Eval data\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the figure above, the green line is the target function and the red dots are the evaluation data `eval_data`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Defining the Dataset Processing Function\n", "\n", "MindSpore's dataset API is used to process the generated data as follows:\n", "\n", "- `ds.GeneratorDataset`: converts the generated data into a MindSpore dataset and stores the x and y values of each sample in the `data` and `label` columns, respectively.\n", "- `batch`: combines every `batch_size` samples into one batch.\n", "- `repeat`: repeats the whole dataset `repeat_size` times." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2021-01-04T07:04:52.993381Z", "start_time": "2021-01-04T07:04:52.990360Z" } }, "outputs": [], "source": [ "from mindspore import dataset as ds\n", "\n", "def create_dataset(num_data, batch_size=16, repeat_size=1):\n", "    # Wrap the generated samples as a dataset with columns 'data' (x) and 'label' (y)\n", "    input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])\n", "    # Group every batch_size samples into one batch\n", "    input_data = input_data.batch(batch_size)\n", "    # Repeat the whole dataset repeat_size times\n", "    input_data = input_data.repeat(repeat_size)\n", "    return input_data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Use `create_dataset` to generate the training data and inspect its format." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2021-01-04T07:04:53.079377Z", "start_time": "2021-01-04T07:04:52.994402Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The dataset size of ds_train: 100\n", "dict_keys(['data', 'label'])\n", "The x label value shape: (16, 1)\n", "The y label value shape: (16, 1)\n" ] } ], "source": [ "data_number = 1600\n", "batch_number = 16\n", "repeat_number = 1\n", "\n", "ds_train = create_dataset(data_number, batch_size=batch_number, repeat_size=repeat_number)\n", "print(\"The dataset size of ds_train:\", ds_train.get_dataset_size())\n", "dict_datasets = next(ds_train.create_dict_iterator())\n", "\n", "print(dict_datasets.keys())\n", "print(\"The x label value shape:\", dict_datasets[\"data\"].shape)\n", "print(\"The y label value shape:\", dict_datasets[\"label\"].shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`create_dataset` groups the 1,600 generated samples into 100 batches, each of shape (16, 1)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Defining the Training Network\n", "\n", "In MindSpore, `nn.Dense` is used to build a linear model with a single input and a single output:\n", "\n", "$$f(x)=wx+b\\tag{1}$$\n", "\n", "The weight $w$ and the bias $b$ are initialized randomly with the `Normal` initializer (mean 0, standard deviation 0.02)." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2021-01-04T07:04:53.085026Z", "start_time": "2021-01-04T07:04:53.080390Z" } }, "outputs": [], "source": [ "from mindspore.common.initializer import Normal\n", "from mindspore import nn\n", "\n", "class LinearNet(nn.Cell):\n", "    def __init__(self):\n", "        super(LinearNet, self).__init__()\n", "        # Fully connected layer with 1 input and 1 output, i.e. f(x) = w*x + b\n", "        self.fc = nn.Dense(1, 1, Normal(0.02), Normal(0.02))\n", "\n", "    def construct(self, x):\n", "        x = self.fc(x)\n", "        return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Instantiate the network and inspect the initialized model parameters." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2021-01-04T07:04:53.100773Z", "start_time": "2021-01-04T07:04:53.086027Z" }, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Parameter (name=fc.weight) [[-0.02289871]]\n", "Parameter (name=fc.bias) [0.01492652]\n" ] } ], "source": [ "net = LinearNet()\n", "model_params = net.trainable_params()\n", "for param in model_params:\n", "    print(param, param.asnumpy())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After initializing the network, plot the initial model function together with the training data to see how the model behaves before fitting." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": 
{ "end_time": "2021-01-04T07:04:53.242097Z", "start_time": "2021-01-04T07:04:53.102786Z" }, "scrolled": true }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYQAAAD8CAYAAAB3u9PLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3dd3iUVdrH8e+dhIAUBSFApFqAVUQpERBBQFGUteEuoq7KKor4iksRK6wRwYaKa0cQxYoooriISKRYAJVQVBClL0QgIFWlJnPeP2YyCWFCEqYm+X2uK1fmqefmyTD3POec5xxzziEiIhIX7QBERCQ2KCGIiAighCAiIj5KCCIiAighiIiIjxKCiIgAIUgIZlbPzGab2XIzW2Zm/X3rHzSzX81sie+nW/DhiohIuFiwzyGYWTKQ7JxbZGZVgIXAFcBVwB/OuSeDD1NERMItIdgTOOc2AZt8r383s+VAnWDPKyIikRX0HcIhJzNrCHwJnA4MAv4J7AbSgTudczsCHNMH6ANQqVKlVn/5y19CFo+ISFmwcOHC35xzScGeJ2QJwcwqA18ADzvnJptZLeA3wAHD8VYr3XSkc6SkpLj09PSQxCMiUlaY2ULnXEqw5wlJLyMzKwd8ALztnJsM4JzLdM5lO+c8wFigdSjKEhGR8AhFLyMDxgHLnXOj8qxPzrNbd2BpsGWJiEj4BN2oDJwDXA/8aGZLfOvuB64xs+Z4q4zWAbeGoCwRkbLH44HMTDCDWrW8v8MgFL2MvgYCRTct2HOLiJRpOYmgZ0/4+mtwDjp0gDlzIC70zxWH4g5BRERCzeOBzp1h3jzIyspdP28ebNkCtWuHvEgNXSEiEou2bs1NBnmriLKz4aqrvAkjxJQQRERiUc2a0K4dJCR4q4mWLMmtJpo/35swQkxVRiIiscgMZs/2fvDXrOld1769966hXTuoWZMde3fQ4uUWIStSCUFEJFbFxXl7FeXwJQiXlMQ1H1zDxGUTQ1tcSM8mIiLhExfH+E2fEjc83p8MUjumhuz0ukMQESkBftr6E01fbOpfblG7Bd/c/A2J8YkMY1hIylBCEBGJYX8e+JNTXziVDbs3+Net7b+WhlUbhrwsVRmJiMSoftP6UfnRyv5k8OF7cbhZ59Lw2PphKU93CCIiMWbKz1O4YuIV/uV+p9/Ec1e/4X0mIWGet+dR3sbmEFFCEBGJBI8ntwtpAWMRrdu5jhOfOdG/XO/Yeiy/fTmVylWEdqsO6XIaDkoIIiLhlncYinbtvN1H84xFdCD7AG1facvizYv965betpSmNXMbkQ95JiFMg9upDUFEJNzyDkMxb94hTxkPmzOM8iPK+5PBa5e/hkt1hyYDyH0mIUzJAHSHICJSPEWo+jlMzjAUeap8Zq+dzXlvnOffpWfTnkz42wQs0DmPpsyjoIQgIlJUhVT9FCjPMBSZFR21H8o9pnx8eTbduYlqx1QLbZlHIRQzptUzs9lmttzMlplZf9/6480szcxW+n4X8K8VESkhjlD1U5hsHBel9aL2qNzJJL/p/Q37hu4rOBkEWWZxhSLNZAF3OudOBdoCt5vZacC9wEznXCNgpm9ZRKTkyjsCaTF6+zz/3fMkDE/gs9WfATDqwlG4VEebum3CVubRCMWMaZuATb7Xv5vZcqAOcDnQybfb68Ac4J5gyxMRiZr8I5AWUp+/cONCUsam+Jc7N+zMjOtnkBBXjI/eYpYZjJC2IZhZQ6AF8C1Qy5cscM5tMrPwpTURkUjJPwJpALv27aLe0/X4/cDv/nUbB20kuUryEY4KrsxQCFnLhJlVBj4ABjjndhfjuD5mlm5m6
VvDWDcmIhJuzjmum3wdVR+v6k8Gaden4VLd0SeDCApJQjCzcniTwdvOucm+1ZlmluzbngxsCXSsc26Mcy7FOZeSlJQUinBERCLL4+Htr14k7qE43v7xbQDub38/LtXR5aQuUQ6u6IKuMjJvp9lxwHLn3Kg8mz4GegGP+X5PCbYsEZFY8/OWnzj1pdyHyJrVbMaCWxZQPqF8FKM6OqFoQzgHuB740cyW+NbdjzcRvGdmvYH1QI8QlCUiEhP2HNzD6S+eztqda/3rVj0DJy/7DEpgMoDQ9DL6Giio2fv8YM8vIhJrBk4fyH++/Y9/edJ78Lef8PYACmMvoHDTWEYiIkU0dcVUbJj5k8GtrW7F8+9s/lbjXO9zAh06RKQ3ULho6AoRkUKs37WeBv9p4F+uVakWq/61isqJlb0rIvScQLgpIYiIQMAB5A5mH6T9a+357tfv/Lt93/d7zqh1xqHHRug5gXBTlZGISM4AcnXrQqdO4PHw8JcPkzgi0Z8Mxl46FpfqDk8GpYjuEEREMjP9A8h9tWEu5w6P92/q/pfuTLpqEnFW+r8/KyGISNnm8UDPnmxNzKLmUIBsAOIsjszBmdSoWCOq4UWSEoKIlGmeLZlcVv9rPsnTSX7uTXNpV69d9IKKktJ/DyQiAt47gcxMcM6/anT6aOJfPoFPGnnXPbbqRNwDnjKZDEB3CCJSFuSbdWzxO6No+UrusNQd6ndg1kUTSKh9QonuNhosJQQRKf18s47tjs+i4dlfsiNPMtgwcAN1j60bxeBihxKCiJR6LimJm3pXZ3xypn/dp//4lItOuSiKUcUetSGISKn27tJ3iRse708Gd509GJfqlAwC0B2CiJRKK7etpPHzjf3LTao3YUnfJVRIqBDFqGKbEoKIlCp7D+7lzNFnsnL7Sv+6Ff1W0Kh6oyhGVTKoykhESo270+6m4iMV/cng3b+9i0t1SgZFpDsEESnxpq+azsVvX+xfvrH5jYy7bBxWhruQHo2QJAQzexW4BNjinDvdt+5B4BZgq2+3+51z00JRnogIwK+7f6Xu07ldRo8/5njW9l/LseWP9a4IMIKpFCxUVUbjgUBN9k8755r7fpQMRCQksjxZtH+1/SHJYFGfRWy7e9uhySDfCKZyZCFJCM65L4HtoTiXiMiRjJw7knLDyzF3w1wAXvrrS7hUR4vkFofu6HsYjaws7++tWwOcTfIKdxtCPzO7AUgH7nTO7ci/g5n1AfoA1K9fP8zhiEhJNXf9XNq/1t6/fEnjS5hy9ZSCh6WuWRPatfMPV0HNmhGKtOQyl2egp6BOZNYQmJqnDaEW8BvggOFAsnPupiOdIyUlxaWnp4ckHhEpHbbt2UbSE0k4cj+rtgzeQlKlpMN3zt9mUEbaEMxsoXMupfA9jyxs3U6dc5nOuWznnAcYC7QOV1kiUvp4nIfuE7tT44ka/mTwxT+/wKW6gpNB/jaDnKktS3EyCKWwJQQzS86z2B1YGq6yRKSEyzc09diFY4l/KJ6Pfv4IgIfPexiX6ji3wbkFn0NtBkELVbfTCUAnoIaZZQCpQCcza463ymgdcGsoyhKRUibP0NQ/XHAGZ7ZZ5N/Upk4bvrrxK8pZvDdhHKnqR20GQQtJQnDOXRNg9bhQnFtESrmtW/k9fS6nDMhmS+XcZPC/Af+j/nH1D5vLgNmzvVVB+Zl5t5WBNoNw0dAVIhI1zjn6fDeUY+/OZktl77r/Xv0xLtV5kwEUrypIbQZB0dAVIhIVk36aRI/3e/iXBzTrw9PdRx/+Ya6qoIhRQhCRiFq9fTWnPHeKf/nkYxvy4+3LOCaxYuADVBUUMaoyEpGI2J+1n9NfPP2QZLD8mxRW3Z3BMRdcfOShJVQVFBFKCCISdkNmDqHCwxVYtnUZAG92fxPXdzN/+XyJuonGEFUZiUjYpK1O48K3LvQvX3fGdbxxxRveYamdU9tAjFFCEJGQ2/j7RuqMquNfrpJYhfUD11O1QtXcndQ2EHOUEEQkZ
LI8WVzw5gXMWTfHv27BLQtIOaGAYXZy2gYkJqgNQURCYtT8UZQbXs6fDJ696Flcqis4GUjM0R2CiATl24xvaTuurX+568ld+eTaT4iPi49iVHI0lBBE5Khs37ud5KeSOZB9wL9u852bqVVZVUAllaqMRKRYnHNc9f5VVB9Z3Z8MZt0wC5fqlAxKON0hiEiRjV8ynhun3OhffrDjg6R2So1iRBJKSggiUjDfjGPL2Mrpo5v5V7dKbsW83vNIjE+MYnASakoIIhKYx8Of559Lk5Zz+fXY3NVr+6+lYdWGUQtLwickbQhm9qqZbTGzpXnWHW9maWa20ve7WijKEpHI6Df5Zip3yk0GH3Z9DZfqlAxKsVA1Ko8HLsq37l5gpnOuETDTtywi0ZBvisoj+ejnj7BhxgvLXgOg3wLDzTqXK9r0CneUEmWhmjHtSzNrmG/15Xin1QR4HZgD3BOK8kSkGIo449jaHWs56dmT/Mv1dsHyRe2oNGES1K6toSXKgHC2IdRyzm0CcM5tMrOAI1eZWR+gD0D9+vXDGI5IGeTxwE8/HT7jWJ7hIg5kH6DtK21ZvHmxf93S0fE03ZwNCd95k4eSQZkQ9ecQnHNjnHMpzrmUpKSkaIcjUnrk3Bk0bw6VKkFCwmGjiqbOTqX8iPL+ZDD+8vG4Bzw0bXxOwP2ldAvnHUKmmSX77g6SgS1hLEtE8suZizg7G/74A5YsgaZNwYzZa2dz3hvn+Xft2bQnE/42wTssNWgU0jIqnAnhY6AX8Jjv95QwliUi+eWfi7hpUzb/mUnyU8n+XSokVGDjoI1UOyZfJ0CNQlomhSQhmNkEvA3INcwsA0jFmwjeM7PewHqgR8FnEJGQyzPfQHaN6lz8VlfS1qT5N3/T+xva1G0TxQAl1oSql9E1BWw6PxTnF5FC+J4oPqyKJy6O59a9x79G/8u/atSFoxh49sAoBCmxTk8qi5R0BXQrTd+Yzlljz/Lvdt6J5zHjuhkalloKpIQgUtLlNB77upXuzFhF3bda8ufBP/27bBy0keQqyUc4CQXfZUiZEfVupyISJF/jsUuI59qbq1HttSb+ZJB2fRou1RUtGXTuDHXrQqdO3mUpc5QQREo6M9585ibihmYzofZWAIZ0GIJLdXQ5qUvRzpHvLoOtW8MYsMQqVRmJlFQeD8tXzOW0ief6VzWr2YwFtyygfEL54p0rfxdVPYxWJikhiJRAe/b/wempSaw9Zp9/3ao7VnHy8Scf3QnzdFFVG0LZpSojkRJmwPQBVHqsij8ZTJoUh+u7OTcZFGNk00PkPIymZFBmKSGIlBBTV0zFhhnPfPsMALf+moxnRDx/O759bhXP0TYOH20SkVJFVUYiMW79rvU0+E8D/3KtSrVY9a9VVE6oCMPzVfEEahwubAiKIg6PLaWf/uoiMepg9kFaj219SDL4oe8PbB68mcqJlQNX8eQ0DhdnpFL1MBIfJQSRGDTiyxEkjkhkwcYFAIy9dCwu1dGsVrMjH5jTOJyRAXPmFK094GiSiJRKqjISiSFf/u9LOo7v6F++8tQreb/H+8RZMb67FXekUvUwEh8lBJEYsOXPLdR6MvdDPN7iyRycSfWK1SMTgIa7FpQQRKLK4zxc8s4lfLrqU/+6uTfNpV29dlGMSsoqtSGIhNMRunO+tOAl4h+K9yeDx7s8jkt1SgYSNWG/QzCzdcDvQDaQ5ZxLCXeZIjGhgO6cizctpuWYlv7dOtTvwKxes0ggzps8VI8vURKpKqPOzrnfIlSWSGzI151zV8ZqGrxzFrv27/LvkjEwgzrH1tGzABIT9I4TCZc8w1L/s3d1qr7W2J8Mpv9jOi7VeZMB6FkAiQmRSAgOmGFmC82sT/6NZtbHzNLNLH2r/hNIaWLGu8/dStzQbF5PzgTg7nZ341IdXU/peui+ehZAYoC5MI9dYmYnOOc2mllNIA24wzn3ZaB9U1JSXHp6eljjEYmEFdtW0OT5Jv7lJtWbsKTvEiokVCj4I
M1YJkfJzBaGon027G0IzrmNvt9bzOxDoDUQMCGIlHR7D+7lzNFnsnL7Sv+6Ff1W0Kh6o8IP1rMAEmVhrTIys0pmViXnNXAhsDScZYpEy10z7qLiIxX9yeDdv72LS3VFSwYiMSDcdwi1gA/Ne/ubALzjnJse5jJFIurTlZ/S7Z1u/uWbmt/EK5e9gqnaR0qYsCYE59wa4MxwliESLRm7M6j3dD3/cvVjqrOm/xqOLX9sFKMSOXoaukKkmA5mH6Tj+I7Mz5jvX7f41sU0r908ilGJBE/PIYgUw+NfP07iiER/Mnjpry/hUp2SgZQKukMQKYL8s5Zd2vhSPrr6o+INSy0S45QQRI5gz8E9jJw7kpFzRwIQZ3FsvnMzSZWSohyZSOgpIYgE4Jxj4rKJ3J12Nxt2b6Bn05483uVxGlRtUPjBIiWUEoJIPgs3LqT/9P7M3TCXFjWa8Xavt+jQ8NxohyUSdqoAFfHZ/Mdmek/pzVljz2Ll9pW8srwRCwb+RIde//YOKyFSyukOQcq8/Vn7efbbZxn+5XD2Ze3jzrPvZOhf+nDcfadBVnbu6KMaVkJKOSUEKbOcc0xdMZVBMwaxavsqLml8CU9d+BSNqzf2znDWrl3u/AQafVTKACUEKZOWbVnGwM8GkrYmjVNrnMr0f0w/dEhqM+8kNRp9VMoQJQQpU7bv3c6Dcx7kxQUvUqV8FZ656Blua3kr5bbv9N4V5P3g1+ijUsaoUVlKlwImtc/yZPHCdy/Q6LlGvLDgBfq06sPKO1byr7P6Ua7LhVC3LnTqpMZjKdOUEKT0yJmXON+H+8w1M2k+ujn9Pu3HmbXOZMmtS3jxry9So2INTV0pkocSgpQe+T7cV69aQPeJ3enyZhf2HNzD5KsmM/OGmTSr1Sz3GE1dKeKnNgQpPXwf7r+nz+Xhnsk8/d65lIsrx6PnP8qAtgMCT1+pxmMRv7AnBDO7CHgGiAdecc49Fu4ypWzy4HjjP7247/Of2bxnA71O78Uj5z/CCVVOOPKBajwWAcKcEMwsHngBuADIABaY2cfOuZ8C7f/zz9Chg/f/p5n395FeF3U/HROdY8wi94V73oZ59J/en/SN6bSt25Yp1/6X1nVaR6ZwkVIi3HcIrYFVvpnTMLN3gcuBgAkhLg4SE71tgR6Ptyo457VzgV8faVs4jpHiC2fiya6UwdYz72FXg3dI2HsC9Ze9xb4p1/B/L8WFPMHNnw8ZGeG/VgkJEB+f+7uw10Xd72iOCee5g40nTi2gIRfuhFAH2JBnOQNok3cHM+sD9AGoX78+M2eGOaIg5SSJcCceHXPk1wfcXn4+/klWVX8MRzZNMofSKPMe4uIq46lT+PE5XzaKE1u4kwH4/m0Hwl+OSCDhTgiBKgwO+Z7tnBsDjAFISUmJ+e/gZt5vJxIdzjkm/TSJwWmDWb9rPX8/7e88ccETNKzaMNqhBc05yM72/mRlHfq7sNdF3a80HaM79tALd0LIAOrlWa4LbAxzmVJKLd60mP7T+/PV+q84s9aZvHHFG3Rs2DHaYYWMWW61SPny0Y5GSpJQtdWFOyEsABqZ2YnAr8DVwLVhLlNKmS1/bmHorKG8sugVqleszsuXvEzvFr2Jj9OtmkgohTUhOOeyzKwf8BnebqevOueWhbNMKT0OZB/g+W+fY9gXw9iTtZcBbQfwQMcHqFqharRDEymVwv4cgnNuGjAt3OVI6eGcY9rKaQz6bBArtq+g20pj1LZWNBnypLqWiISRnlSWmLJ863IGzRjE9FXTaVL1FD6ZEEe3XzyQsEST1IiEmb5uSUzYsXcHA6YP4IzRZzB/w3xGXTiKH25fSrda7TXOkEiE6A5Boirbk83YRWMZOmso2/dup0+rPgzvOIykPUB8osYZEokg3SFI1MxeO5uWY1py2ye3cXrN01l06yJGd3uRpEuuyh3CGrzVREoGImGnOwSJuLU71jI4bTCTl0+mwXENmNRjEleeeiVm5p3cJv/8BGo3EIkIJ
QSJmD8O/MGjXz3KU/OfIj4unhGdRzDo7EEcU+6Y3J1y5ifQ5PYiEaeEIGHncR7e+uEt7v38Xjb9sYnrz7ieRzs/TJ39iZB/jgLNTyASNWpDkOIrYN7iQL7N+Jazx51Nr496Ue+4eszvPZ83Lh9PncuvK3ge45z5CZQMRCJKCUGKp4B5i/P7dfev3PDhDbQd15YNuzbw+hWvM7/3fNrWbat5jEVilKqMpHgCfZjnafTdl7WPUfNH8chXj3DQc5D72t/Hfe3vo0r5KrnnUDuBSExSQpDiKeDD3DnH5OWTGZw2mHU713HlqVfyxAVPcFK1kw4/h9oJRGKSEoIUT4AP8+83f8+AzwYwZ90cmtVsxswbZnLeiecd+Tz55zH2eJQgRKJMbQhSfL4P8617fqPv1L60HNOSHzN/5MVuL7Lo1kWFJ4P8itguISLhpTsEKbaD2Qd5YcELPDjnQf448Ad3tL6D1I6pVDum2tGdsJB2CRGJDCUEKZZPV37KoBmD+Pm3n+l6cldGdR3FaUmnBXdSNTKLxISwJQQzexC4BcjpU3i/b24EKYF++e0XBs0YxLSV02h0fCOmXjOVbo26eYebCJYamUViQrjvEJ52zj0Z5jIkjHbu28nwL4bz7HfPUrFcRZ684EnuaHMHifGJoS0ofyOziEScqowkoGxPNuMWjWXozKH8tm87N7e8mRHnjaBmJVXniJRW4e5l1M/MfjCzV80sYIujmfUxs3QzS9+qJ1ZjwhfrvqDVmFbc+slt/GXFNhZ+14Ixfx2tZCBSypkrwng0BR5s9jlQO8CmIcA3wG+AA4YDyc65m450vpSUFJeenn7U8Uhw1u1cx91pd/P+T+9Tv3Idnnh9Ez1+9GAJCZCRoSodkRhlZgudcynBnieoKiPnXJei7GdmY4GpwZQl4fPngT957OvHeHL+kxjGsE7DGHz2nVT8uBskqOePSFkRzl5Gyc65Tb7F7sDScJUlR8c5xzs/vsM9n9/Dr7//yrXNruWx8x+j3nH1vDuo549ImRLORuWRZtYcb5XROuDWMJYlxeHxsGDZDPp/9xDzM+bTKrkVE/8+kXPqn3Pofur5I1KmhC0hOOeuD9e55eht2vUr993TiteTM6m1vxyv/v0VerW4kTjTKCYiZZ0+BcqIfVn7eOzrx2j8QhMmJGVyz9ew4hkPN9a5RMlARAA9h1DqOeeY8ssU7pxxJ2t2rOHyJpfz5PiNnDJnsRqLReQQSgil2I+ZPzLgswHMWjuLpklNSbs+jS4ndYGrNNS0iBxOCaG08XjYtuEXHlj6HKMXvsxx5Y/j+Yuf59aUW0mI8/251VgsIgEoIZQiBw/u56V/NuXBeqvZXR7+r83tPNhpGNUrVo92aCJSAqg1sZSYsXoGZ77UjP6NV9NqI3w/Jp7nWv1byUBEikwJoYRbuW0ll024jK5vdeUAHqb80JQZE+Jp2vgcNRiLSLGoyqiE2rVvFyO+HMEz3z5DhYQKjOwykn+1+Rfl+5UrWoOx5jAWkXyUEGJZgA/tbE8245eM5/5Z97P1z63c2PxGHj7/YWpXzjPGYGENxjlzGOfMUDZ7trehWUTKNH0KxKoAE89/vf5rWr/Smpv/ezOnHH8K393yHeMuH3doMiiKQHMYi0iZp4QQq/J8aK9fOper37mSDq91YMufW3jnynf4+savSTnhKEe7zZnDOCFBD6eJiJ+qjGJVzZrsad+GkXHzGHmOw/3vM1I7pnJXu7uolFgpuHNrDmMRCUAJIQY555i4bCJ3XfY/MnY7eja9ipEXjKT+cfVDV4geThORfJQQYszCjQvpP70/czfMpUXtFrxz5Tt0aNAh2mGJSBmghBAjNv+xmSEzh/DaktdIqpTEK5e+wj+b/5P4uPhohyYiZURQjcpm1sPMlpmZx8xS8m27z8xWmdkvZtY1uDBLr/1Z+xk5dySNn2vMmz+8yZ1n38mKfivo3bK3koGIRFSwdwhLgSuBl/OuNLPTgKuBpsAJwOdm1
tg5lx1keaWGc47/rvgvgz4bxOodq7m08aU8deFTNKreKNqhiUgZFVRCcM4tB7DDe6lcDrzrnNsPrDWzVUBrYH4w5ZUWy7YsY+BnA0lbk8apNU5l+j+m0/UU3USJSHSFqw2hDvBNnuUM37rDmFkfoA9A/foh7EUTSzweyMxk+/6dpP70Ai+lj6ZK+So8e9Gz9E3pS7n4ctGOUESk8IRgZp8DgR6FHeKcm1LQYQHWuUA7OufGAGMAUlJSAu5Tonk8ZJ3XiZf3fsUDnWDnMdD3rNsY1vkhalSsEe3oRET8Ck0IzrkuR3HeDKBenuW6wMajOE+J9/niDxjQ9CuW1YTz1sB/0uJodlsqKBmISIwJ19AVHwNXm1l5MzsRaAR8F6ayYtLq7au54t0ruGDqVeypUoEP34XP34Bmp2hYahGJTUG1IZhZd+A5IAn4xMyWOOe6OueWmdl7wE9AFnB7Welh9Pv+33n4q4d5+punKRdXjkfbDGFA5/uo0H+3d4iIWrU0VISIxCRzLnaq7VNSUlx6enq0wzgqHufhje/f4L6Z97H5j830OuMGHnnxF06Ys1BDTItIWJnZQufcUY52mUtPKofAvA3z6D+9P+kb02lbty1Trp5C64QGcFXdQ4eY1thBIhLD9JU1CBm7M/jH5H9wzqvnsPH3jbzV/S3m/fNrbzJIStIQ0yJSougO4SjsObiHJ+c9yeNzH8fjPAztMJR72t9D5YSKh85ENnMmbNumIaZFpERQQigG5xzv//Q+d6Xdxfpd6+lxWg9GXjCShlUbenfIzDx0JrJt21RNJCIlhqqMimjxpsV0HN+RnpN6Uq1CNeb0msN7Pd7LTQagmchEpETTHUIhtvy5hSEzhzBu8TiqV6zOy5e8TO8WBYxEqpnIRKQEU0IowIHsAzz37XM89OVD7Dm4h4FtB/Lvjv+maoWqRz5QM5GJSAlVthOCx3PYt3nnHJ+s/IRBnw1i5faVdGvUjVEXjqJJjSZRDlZEJLzKbhuCx+PtEVS3LnTqBB4Py7cu5+K3L+bSCZcSZ3FMu3Yan1z7iZKBiJQJZfcOYetWf4+gHYvmMuyjvjy/9FUqJ1bm6a5Pc/tZt2tYahEpU8puQqhZk6xzzmbs3rn8u4uxY+k4bml5C8M7DyepUlK0oxMRibgymxBmrZvNgJ47+XGLh44NOvDMRc9wZu0zox2WiEjUlLmEsGbHGu5Ku4vJyyfTsGpDJvWYxJWnXhloGlARkTKlzCSE3/f/zqNfP8qo+aOIj4tnRAuPU48AAAnnSURBVOcRDDp7EMeUOybaoYmIxITSlxDydSX1OA9v/fAW935+L5v+2MT1Z1zPo+c/Sp1jA07xLCJSZgXV7dTMepjZMjPzmFlKnvUNzWyvmS3x/YwOPtQiyNeV9Jv18zh73Nn0+qgX9Y6rx/ze83mj+xtKBiIiAQR7h7AUuBJ4OcC21c655kGev3h8XUl/PSaLe5O+4q3XziG5cjKvX/E6151xHXFWdh+7EBEpTFAJwTm3HIiZBtm91aow6pq6PFJvHdnxcH/7+7ivw/1UTqwc7dBERGJeONsQTjSzxcBuYKhz7quQnTlfO4Fzjg+Wf8BdaXex7uR1XHniX3nikmc46fiTQ1akiEhpV2hCMLPPgdoBNg1xzk0p4LBNQH3n3DYzawV8ZGZNnXO7A5y/D9AHoH79+oVHnNNO4JuE5vsJ/6H/jIF88b8vaFazGbNumEXnEzsXfh4RETlEoQnBOdeluCd1zu0H9vteLzSz1UBjID3AvmOAMQApKSmu0JP72gm2Jmbx7+O+YuwrKVSrUI2X/voSN7e8mYS40tdxSkQkEsLy6WlmScB251y2mZ0ENALWhOLcB6pX5YWe9RlWbw1/JsIdrfuR2vFBqh1TLRSnFxEps4JKCGbWHXgOSAI+MbMlzrmuwLnAQ2aWBWQDfZ1z24+6IF+bwae7FjJwxiB+abSGrvU68/Qlz3NqzdOC+SeIiIhPsL2MPgQ+DLD+A
+CDYM7t5/Hwy1/bMKj6QqY1cjQ6vhFTr5lKt0bdYqZ3k4hIaRDTFe479+3koU/v5bmz0ql4EJ5Mi+OO92aSeEK9aIcmIlLqxGRCyPZkM27xOIbMGsK2Pdu4eXNtRkzcSs3m50By3WiHJyJSKsVcQvhi3Rf0n96f7zO/p0N977DULWqdCcM1cb2ISDjFVEJYs2MNnV7vRP3j6jPx7xPpcVqP3HYCTVwvIhJWMZUQdu7byUOdHmJwu8EallpEJMLMucKfBYuUM1qc4X5Y/EO0wxARKVHMbKFzLqXwPY8spob/TIxPjHYIIiJlVkwlBBERiR4lBBERAZQQRETERwlBREQAJQQREfFRQhAREUAJQUREfJQQREQEUEIQERGfoBKCmT1hZj+b2Q9m9qGZVc2z7T4zW2Vmv5hZ1+BDFRGRcAr2DiENON05dwawArgPwMxOA64GmgIXAS+aWXyQZYmISBgFlRCcczOcc1m+xW+AnNlrLgfedc7td86tBVYBrYMpS0REwiuUw1/fBEz0va6DN0HkyPCtO4yZ9QH6+Bb3m9nSEMYULjWA36IdRBEoztBSnKFTEmKEkhNnk1CcpNCEYGafA7UDbBrinJvi22cIkAW8nXNYgP0DjrPtnBsDjPGdJz0UQ7iGm+IMLcUZWiUhzpIQI5SsOENxnkITgnOuSyGB9AIuAc53uZMrZAD18uxWF9h4tEGKiEj4BdvL6CLgHuAy59yePJs+Bq42s/JmdiLQCPgumLJERCS8gm1DeB4oD6T55j7+xjnX1zm3zMzeA37CW5V0u3MuuwjnGxNkPJGiOENLcYZWSYizJMQIZSzOmJpCU0REokdPKouICKCEICIiPhFPCGbWw8yWmZnHzFLybSt0uAszO9HMvjWzlWY20cwSIxDzRDNb4vtZZ2ZLCthvnZn96NsvJN3Aihnng2b2a55YuxWw30W+a7zKzO6NQpwFDnmSb7+IX8/Cro2vo8RE3/ZvzaxhJOLKF0M9M5ttZst9/5f6B9ink5ntyvNeeCDScfriOOLf0Lye9V3PH8ysZRRibJLnOi0xs91mNiDfPlG5nmb2qpltyft8lpkdb2Zpvs/ANDOrVsCxvXz7rPT1Bi2ccy6iP8CpeB+imAOk5Fl/GvA93kbqE4HVQHyA498Drva9Hg3cFuH4nwIeKGDbOqBGpK9pnvIfBAYXsk+879qeBCT6rvlpEY7zQiDB9/px4PFYuJ5FuTbA/wGjfa+vBiZG4e+cDLT0va6Cd9iY/HF2AqZGOrbi/g2BbsCneJ9dagt8G+V444HNQINYuJ7AuUBLYGmedSOBe32v7w30/wc4Hljj+13N97paYeVF/A7BObfcOfdLgE2FDndh3q5M5wGTfKteB64IZ7wByr8KmBCpMsOgNbDKObfGOXcAeBfvtY8YV/CQJ9FWlGtzOd73HXjfh+f73hcR45zb5Jxb5Hv9O7CcAkYCKAEuB95wXt8AVc0sOYrxnA+sds79L4ox+DnnvgS251ud9z1Y0GdgVyDNObfdObcD77hzFxVWXiy1IdQBNuRZDjTcRXVgZ54PkwKHxAiTDkCmc25lAdsdMMPMFvqG5IiGfr5b71cLuJUsynWOpJvwfkMMJNLXsyjXxr+P7324C+/7Mip8VVYtgG8DbD7bzL43s0/NrGlEA8tV2N8w1t6PV1PwF75YuJ4AtZxzm8D75QCoGWCfo7quoRzLyM+KMNxFoMMCrMvfJ7bIQ2IUVxFjvoYj3x2c45zbaGY18T6b8bMvw4fMkeIEXgKG470mw/FWb92U/xQBjg153+OiXE87fMiT/MJ+PfOJ6nuwuMysMvABMMA5tzvf5kV4qz3+8LUlfYT3AdFIK+xvGEvXMxG4DN+ozfnEyvUsqqO6rmFJCK6Q4S4KUJThLn7De0uZ4Pt2FrIhMQqL2cwSgCuBVkc4x0bf7y1m9iHeKoiQfoAV9dqa2VhgaoBNERlWpAjXM9CQJ/nPEfbrmU9Rrk3OPhm+98RxH
H5LH3ZmVg5vMnjbOTc5//a8CcI5N83MXjSzGs65iA7UVoS/YSwNc3MxsMg5l5l/Q6xcT59MM0t2zm3yVa9tCbBPBt52jxx18bbbHlEsVRkVOtyF74NjNvB336peQEF3HKHWBfjZOZcRaKOZVTKzKjmv8TacRnTk1nx1r90LKH8B0Mi8vbUS8d4ifxyJ+HJYwUOe5N0nGtezKNfmY7zvO/C+D2cVlNDCxddmMQ5Y7pwbVcA+tXPaNsysNd7/69siF2WR/4YfAzf4ehu1BXblVIdEQYE1ALFwPfPI+x4s6DPwM+BCM6vmqzq+0LfuyKLQat4db/baD2QCn+XZNgRvL49fgIvzrJ8GnOB7fRLeRLEKeB8oH6G4xwN98607AZiWJ67vfT/L8FaNRPravgn8CPzge9Mk54/Tt9wNb8+U1VGKcxXe+s0lvp/R+eOM1vUMdG2Ah/AmL4AKvvfdKt/78KQoXL/2eG//f8hzDbsBfXPeo0A/33X7Hm/DfbsoxBnwb5gvTgNe8F3vH8nT8zDCsVbE+wF/XJ51Ub+eeBPUJuCg73OzN942q5nASt/v4337pgCv5Dn2Jt/7dBVwY1HK09AVIiICxFaVkYiIRJESgoiIAEoIIiLio4QgIiKAEoKIiPgoIYiICKCEICIiPv8PsVW4qWEViusAAAAASUVORK5CYII=\n", "text/plain": [ "
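" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "from mindspore import Tensor\n", "\n", "# Sample the initial (untrained) model line over [-10, 10)\n", "x_model_label = np.arange(-10, 10, 0.1)\n", "y_model_label = (x_model_label * Tensor(model_params[0]).asnumpy()[0][0] + \n", "                 Tensor(model_params[1]).asnumpy()[0])\n", "\n", "plt.axis([-10, 10, -20, 25])\n", "# Red: evaluation samples; blue: initial model; green: target function\n", "plt.scatter(x_eval_label, y_eval_label, color=\"red\", s=5)\n", "plt.plot(x_model_label, y_model_label, color=\"blue\")\n", "plt.plot(x_target_label, y_target_label, color=\"green\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As the figure above shows, the initialized model function (blue line) still differs considerably from the target function (green line)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define the Forward and Backward Propagation Networks and Connect Them" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, define the model's loss function. Here the mean squared error (MSE) is used to measure how well the model fits the data: the smaller the MSE, the better the fit. The loss function is:\n", "\n", "$$J(w)=\\frac{1}{2m}\\sum_{i=1}^m(h(x_i)-y^{(i)})^2\\tag{2}$$\n", "\n", "Assuming the $i$-th training sample is $(x_i,y^{(i)})$, the parameters in Equation 2 are:\n", "\n", "- $J(w)$ is the loss value.\n", "\n", "- $m$ is the number of samples; in this example, $m$ equals `batch_number`.\n", "\n", "- $h(x_i)$ is the prediction obtained by substituting $x_i$ of the $i$-th sample into the model network (Equation 1).\n", "\n", "- $y^{(i)}$ is the $y$ value (the label) of the $i$-th sample.\n", "\n", "The factor $\\frac{1}{2}$ is a common convention that simplifies the derivative. `nn.loss.MSELoss` itself averages the squared errors without it; the two differ only by a constant factor and share the same minimizer.\n", "\n", "### Define the Forward Propagation Network\n", "\n", "The forward propagation network consists of two parts:\n", "\n", "1. Substitute the current parameters into the model network to obtain the predicted values.\n", "2. 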
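Compute the loss value from the predicted values and the training data.\n", "\n", "In MindSpore, this is implemented as follows:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2021-01-04T07:04:53.249228Z", "start_time": "2021-01-04T07:04:53.243109Z" } }, "outputs": [], "source": [ "# Model network that produces the predictions (Equation 1)\n", "net = LinearNet()\n", "# Mean squared error between predictions and labels\n", "net_loss = nn.loss.MSELoss()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define the Backward Propagation Network\n", "\n", "The goal of the backward propagation network is to keep adjusting the weights so that the loss reaches its minimum. Linear networks generally update the weights with the following formula:\n", "\n", "$$w_{t}=w_{t-1}-\\alpha\\frac{\\partial{J(w_{t-1})}}{\\partial{w}}\\tag{3}$$\n", "\n", "Parameters in Equation 3:\n", "\n", "- $w_{t}$ is the weight after the update.\n", "- $w_{t-1}$ is the weight before the update.\n", "- $\\alpha$ is the learning rate.\n", "- $\\frac{\\partial{J(w_{t-1})}}{\\partial{w}}$ is the derivative of the loss function with respect to the weight, evaluated at $w_{t-1}$.\n", "\n", "Once all weights have been updated, the new values are fed back into the model function; this is the backward propagation process. In MindSpore it is implemented with an optimizer. Equation 3 describes plain gradient descent; the `nn.Momentum` optimizer used below additionally accumulates past gradients, weighted by `momentum`:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2021-01-04T07:04:53.273562Z", "start_time": "2021-01-04T07:04:53.250245Z" } }, "outputs": [], "source": [ "# Gradient descent with momentum; learning_rate corresponds to alpha in Equation 3\n", "opt = nn.Momentum(net.trainable_params(), learning_rate=0.005, momentum=0.9)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Connect the Forward and Backward Propagation Networks\n", "\n", "After forward and backward propagation are defined, MindSpore's `Model` class combines the network, loss function, and optimizer defined above into a complete computation network." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2021-01-04T07:04:53.287238Z", "start_time": "2021-01-04T07:04:53.275579Z" } }, "outputs": [], "source": [ "from mindspore import Model\n", "\n", "# Wrap network, loss, and optimizer into a trainable model\n", "model = Model(net, net_loss, opt)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prepare to Visualize the Fitting Process\n", "\n", "### Define the Plotting Function\n", "\n", "To make the training process easier to follow, we visualize the test data, the target function, and the model network. The plotting function defined here is called at the end of every training step to show how the model network progressively fits the data." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": "2021-01-04T07:04:53.305631Z", "start_time": "2021-01-04T07:04:53.288251Z" } }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import time\n", "\n", "def plot_model_and_datasets(net, 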
eval_data):\n", "    # Current weight and bias of the linear model\n", "    weight = net.trainable_params()[0]\n", "    bias = net.trainable_params()[1]\n", "    x = np.arange(-10, 10, 0.1)\n", "    y = x * Tensor(weight).asnumpy()[0][0] + Tensor(bias).asnumpy()[0]\n", "    x1, y1 = zip(*eval_data)\n", "    # Target function f(x) = 2x + 3\n", "    x_target = x\n", "    y_target = x_target * 2 + 3\n", "\n", "    plt.axis([-11, 11, -20, 25])\n", "    plt.scatter(x1, y1, color=\"red\", s=5)\n", "    plt.plot(x, y, color=\"blue\")\n", "    plt.plot(x_target, y_target, color=\"green\")\n", "    plt.show()\n", "    # Short pause so each frame stays visible\n", "    time.sleep(0.2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define the Callback\n", "\n", "Callbacks are a MindSpore tool for customizing control over the training process. Here the plotting function is called in `step_end` to display the fitting process. For more usage, see the [official documentation](https://www.mindspore.cn/tutorial/training/zh-CN/r1.2/advanced_use/custom_debugging_info.html#callback).\n", "\n", "- `display.clear_output`: clears the previous output so that the successive plots appear as an animated fitting process." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "ExecuteTime": { "end_time": "2021-01-04T07:04:53.318392Z", "start_time": "2021-01-04T07:04:53.306647Z" } }, "outputs": [], "source": [ "from IPython import display\n", "from mindspore.train.callback import Callback\n", "\n", "class ImageShowCallback(Callback):\n", "    def __init__(self, net, eval_data):\n", "        self.net = net\n", "        self.eval_data = eval_data\n", "\n", "    def step_end(self, run_context):\n", "        # Redraw the current fit after every training step\n", "        plot_model_and_datasets(self.net, self.eval_data)\n", "        display.clear_output(wait=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run Training\n", "\n", "With the steps above complete, train the model on the training dataset `ds_train` by calling `model.train`. Its parameters are:\n", "\n", "- `epoch`: the number of passes over the entire dataset.\n", "- `ds_train`: the training dataset.\n", "- `callbacks`: the callbacks invoked during training.\n", "- `dataset_sink_mode`: dataset sink mode, supported on the Ascend and GPU platforms; because this example runs on CPU, it is set to `False`." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "ExecuteTime": { "end_time": "2021-01-04T07:05:27.693120Z", "start_time": "2021-01-04T07:04:53.319412Z" } }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXkAAAD8CAYAAACSCdTiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3dd3xUVfrH8c9DqFIEpAhSBLvgSomouO5aUBFdQXct6M+GEMDeQUBAsKJYkBpEAUWKikhViqKrCBKQHpGqhhISek1I5vz+mGGNYQIJMzeTTL7v1yuvzNx7554nN5Mnd8499znmnENERKJTsUgHICIi3lGSFxGJYkryIiJRTEleRCSKKcmLiEQxJXkRkSgWcpI3s9pm9o2ZJZrZSjN7LLC8t5ltMrMlga9WoYcrIiJ5YaGOkzezGkAN59xiMysPLALaALcB+5xzb4QepoiInIjioe7AObcF2BJ4vNfMEoHTQt2viIiELuQz+b/szOx04DugIfAkcB+wB0gAnnLO7QzymjggDqBs2bJNzz333LDFIyJSFCxatCjVOVc12LqwJXkzKwd8C7zknJtoZtWBVMABffF36bQ71j5iY2NdQkJCWOIRESkqzGyRcy422LqwjK4xsxLAZ8AY59xEAOdcsnMu0znnA4YDzcLRloiI5F44RtcYMAJIdM69mWV5jSyb3QysCLUtERHJm5AvvAKXAXcDy81sSWBZN6CtmTXC312zEegYhrZERAonnw9SUqBKFUhNhWrVwMzzZsMxuuZ7IFik00Pdt4hIVMjIgMsvh4ULoVw52LcPLrsMvvkGinl7T6rueBUR8ZLPB//4B8yfD5mZsHu3//u8ebBtm+fNK8mLiHgpJcV/Bn9E+fL+7xkZcNtt/n8CHlKSFxHxUrVq0Lw5FC8Ol14KiYkQE+Nf9+OP/n8CHgrHhVcREcmJmb/vPSXFn/DB3x8/bx40b056xaqU9LB5ncmLiHitWDGoXt2f8ANJP2PjH9z/r86Ue6Ypv/6+y7umPduziIgENXtRElVfbs/I/W0pXSqGbfu2e9aWumtERPLJ/gOZ3PLaQGZmdIdKjjsrv8nIHo9QIsa7VKwkLyKSD0bOWEbn6R04VOUnah5uyeSOQ2h6xumet6vuGhERD21JPUjjp5/j/h+bkl5mHV0+u5CkBftpWq9OvrSvM3kREQ84B71Hz+GlpR3JPHkdDQ7czYzBk6i9f6l/OGVKiv9irMd0Ji8iklc+HyQn+zN5ECvWbafuY/fRZ2MLYmKMwRfPYcWro6jdtLE/wTdv/udwSo/pTF5EJC98Prjyyv+Nc89afyYz09FhwFhGbn0cV3EnVxbvxqQXelDhpDL+12YdL58PxclASV5EJG9SUvwJPiPD/z3Q7TJr4UZuG92ZXVW+pEJGMz6+dTY3xP7tr689Ml4+H6m7RkSKtuN0vRwla5mC5s3ZX74y1/buz7WTGrCrwvfcVXkA2/vNOzrBn0hbYaAkLyJF15Gul1q14Iorclcs7EiZgqQk3u/yFlWeu5RZ9jQ1069i8QOr+OiRRyh+pDZNqG2FQThmhqptZt+YWaKZrTSzxwLLK5vZLDNbE/heKfRwRUTCKFjXSy5s2XGIRv3f4IEFzUgvncRzZ04gqd9kGtevHfa2QhWOM/kM4Cnn3HnAJcBDZnY+0BWY45w7C5gTeC4iUnBk63rJzYiXXqNnUvuVhiwt+wYND7dj4zOJvHzXrdjxLqSeQFvhEI6ZobYAWwKP95pZInAa0Bq4IrDZKGAu0CXU9kREwiZ7hchjJOrl61K4YcCT/FH5I0ra2Qy6eC4dW/7Tk7bCKayja8zsdKAxsACoHvgHgHNui5nlz78tEZG8OM6Il8xMxwMDPmR08pO4k/dwVczzTOrbjfJlSoe9LS+ELcmbWTngM+Bx59ye4350+fN1cUAcQJ06+XObr4hIbsxMWMftozux65TZnJzenLH/ief62AaRDitPwjK
6xsxK4E/wY5xzEwOLk82sRmB9DSDoZIbOuXjnXKxzLrZq1arhCEdEJCT796dxbfe+XDfpAnaVW8DdlQaz/Y3/FroED2E4kzf/KfsIINE592aWVZOBe4FXA9+/CLUtERGvjZj+Ew9NvY+06omctv4ypj4/lkZnHmPUTAEXjjP5y4C7gavMbEngqxX+5H6Nma0Brgk8FxEpkDan7qNRlydpv+BSDp+0nW7jmpA0+gcalS3chQHCMbrmeyCnDvirQ92/iIjXen44g5eXdiaz/G9csK890weOo1ba4j+n6yvEdMeriBRZy9YlU/uJtvRd34ri7iSGXfw9y/rFU+viJv7x7Jdfnu+jYcKtcH8OERE5AZmZjnbvfsCHyU/jyu3n6mIvMOnFLpQrU8q/QQTGs3tFSV5EopfPd1Sy/iphDbd/FMfuSnM5+dDljL0rnutjz/3r6yIwnt0r6q4RkeiUrSDYvv2HaNHnZVp+cQG7T/qZeyrGs73/3KMTfJTRmbyIRKfk5P8VBHtvw24e7hFLWsWV1DrwH6Y+OIALz6gR6QjzhZK8iEQfnw9uv53NxUpx/fWXsazZd8TsP43u9b/gxbtvinR0+UpJXkSiT0oKz6ek88pDFcis8B1/2xfH9K79OK1KhUhHlu+U5EWk8ApyYXXpui3cOOgxku5YQKnkMxm6pA3t5wwp9KNkTpQuvIpI4ZTtwmpmRgb3vBNPoxHnkVR2Mi3sRVKf/Zr2cz4rsgkedCYvIoVVlpmWZvySRNunrmR35e85+cAVjLtrGC0vOjvSERYISvIiUjhVq8a+S/9O6xgfX/99Ppaxk3sqjuD9HvcTE1N0z9yzU5IXkUIp/qt5PHrRNtIqrKL2rjuY+vDb/O2M6LiBKZyU5EWkUElK3U2r/s+xvPQQYqwOz9efRp+7W0U6rAJLSV5ECo3uH33Oa8seJrPMVi488DjTu/alZpVykQ6rQFOSF5ECb8m6Tdw4+BE2VficUukXMuzKSTxw/UWRDqtQUJIXkQIr0+fjvoHD+GhrVyiTTgteZdIrT1K2TIlIh1ZohGuO1/fNbJuZrciyrLeZbco2W5SISK5MX7iSyk9fzkc7H6TivmZ82XoFs3p18Sd4n89fm8a5SIdZ4IXrZqiRQMsgy99yzjUKfE0PU1siEsX2HjzE1S/25IYpjdlbcjX3nTyK1Ldmct1FZ/g3yHYTFD5fROMt6MLSXeOc+87MTg/HvkSk6Br21Xc8NiuOtPKrqb3n/5j2yJtccEbVv26U5SYo5s3zP4+S2u9e8LqswcNmtizQnVMp2AZmFmdmCWaWkJKS4nE4IlIQJaXu4oJucXSa/08yXBo9633J729/eHSCB3+dmubN/dPzNW/ufy45MhemPq3AmfxU51zDwPPqQCrggL5ADedcu2PtIzY21iUkJIQlHhEp+JxzdP/4U/ote5TM0tu48OCTTH+2NzWrlP3rhtkLkQUpTFaUmdki51xssHWenck755Kdc5nOOR8wHGjmVVsiUvj8vP4Paj/bmlfW3kbxQzV57+KFLOn3evAEn70P/sj0fErwx+XZEEozq+Gc2xJ4ejOw4ljbi0gUy3LmneHzcd+gwYxJ7gYlfVzj3uDzVx+jbJkc0pH64EMSliRvZmOBK4AqZpYE9AKuMLNG+LtrNgIdw9GWiBQyR87E581j+j9a0LbRTvZUWEDFPdcx/p4hXHtRvT+HRAbrfjnSBz9vnvrgT0C4Rte0DbJ4RDj2LSKFXEoKexfMp/U/m/JN89lYWiXurzCG4T3a+qtFZvknQPPm8M03/u6YI8z8y9QHf0J0x6uIeGrYzyt5tFMN0istoM6q65n20mganlnlzw1y0x1zpA9e8kwzQ4mIJ/5I3UHD7u3otOBqMosVp1fVSfw2btpfEzxoSKTHdCYvImHlnKPbx+Ppt/wxfKW202jnU0zv3ocaVU8K/gJ1x3hKSV5Ewmbx+t/419DObC47g9IHLmLwvKu4/8d34OeFR/e1Z6X
uGM8oyYtIyDIyM7ln8ADGJveAEsa1vreZ+MStlD27roY+RpiSvIiEZGrCEu4a14E95ROouKsVE+4ZzDXN6vorRGroY8QpyYvICdlz8ACt33qBuWn9sWJVaFd+PPE9bv1zEm31tRcISvIikmeDv5rFE3M6kV52PXV2tGfa4/1oeEaQGoTqa484JXkRybXfU1O5/u0nWVXiQ2IOn03v07+hZ88rdJJegCnJi8hxOefoOnYMbyx/Al+JXTTe04Np3bpTo2rpSIcmx6EkLyLHlLBuPTfFd2bLSTMpve8SBrcczv03NIx0WJJLSvIiEtThzAzuHfI2Y7f2hJjiXJc5kM9e60TZk2IiHZrkgZK8iBxlyqJF3DWuA3vL/UzFHTcx4d5BXHNxrUiHJSdASV5E/lfvfXe5srR+pxffpr2NUY125T5lWPdbKF5cV1YLKyV5kaIuUOp30NbNPHnDAdJP3kyd1I5Me+JVGp5ZMdLRSYjCUoUyMFH3NjNbkWVZZTObZWZrAt+DTuQtIpG18ZdVnF9lDw/fuZbM9LK8cMp0Ng4cqgQfJcJVangk0DLbsq7AHOfcWcCcwHMR8dqRWZacO+Zmzjme+eh9zvjwnyQ2WEnjua354+cz6PlQS417jyJhSfLOue+AHdkWtwZGBR6PAtqEoy0ROYZgk14HsXDdWmo+14I31j1AyW2n88Gc21n88RBqfD9d5QeijJeThlQ/MpF34HvQ6kRmFmdmCWaWkJKS4mE4IkVAcvLRsyxlkZ5xmDsGvkqzkRew1RJoObUNqR+s5L6fxvlLECjBR52IzwzlnIt3zsU652KrVq0a6XBECi+fD26/3Z/gzY6q/PhFwk9U6RbL+O3PUSmlFbNar2LGSTsoG5OpKpFRzMvRNclmVsM5t8XMagDbPGxLRFJS4Mcf/Y+LFYPx48GMXQf20vqd5/kubQDmq0m7kz4nflAbYmJQlcgiwMsz+cnAvYHH9wJfeNiWiGSdK/Wyy6B6dd79airV+zTgu7QB1N32IMs6rmLEM4EED39WiVSCj1phOZM3s7HAFUAVM0sCegGvAhPM7AHgd+DWcLQlIjnIUr99g/lo1esOfomZQMyBBrxwzg883+tS5fIiKCxJ3jnXNodVV4dj/yISROAu1axdLc6Mp+dM4e0Vz+CLOUDjnX2Z1uNZalQrGeFgJVIifuFVRE5AkKGSC9atpma3K3lzTQdK7ryQD5otY/HbPZTgiziVNRApjFJS/jdUMu3HH7j7rZ58susNcGVomT6cT/u3o+xJxziHC/IpQKKTzuRFCqPARdbP69SgSqc6fLLvJSpta82sNonMeKn98RN8Lm6YkuigM3mRQmjnwb3c1LIh36f9F9tbiwfKTGHY4Bv/HDVzLFk+BfzvhinNwxq1lORFChOfj3cmjuGZhOc4XHozdbc+wtSnXqThWeVzv48jQy3nzdNNUEWAkrxIIbF+WxKtnr2Z1fUSiNnTgD5nfkqPXpfkvUs9y1BL9clHPyV5kQLO53w8PTaed1Z2wVcrnSaz/8O0Bd9y6h/1wDixi6hHboKSqKcLryIF2Px1idTs/k/eWtOZkttj+WBOWxbNn8Spl57nT+p5vYiayzLEEj2U5EUKoEOH07h1UG8uHXUhyZmruD7tA1L7z+a+H96DpCSYO9d/1h7sImpONKqmSFKSFylgPlv4PVWfb8SnqS9QacutzL45kekv30fZsnZ0rZms9WqOdxE1L/8QJGqoT16kgNi+fxet3+3KD2nDsPS6PFBpOkOHXE/xY/2V5uUiqkbVFElK8iIR5pzj7ZkT6TL3EQ6XTKbu1ieZ9lQfGpxdNnc7yO1FVI2qKZKU5EUiaO22JG4Y9BC/FptMzJ7G9GkyhR69m3qXfzWqpshRkheJgExfJk+NG8K7q7rhI4MmO15nWq/HObWa/iQlvPSOEvFKDuPX561dwc0fdGBbyfmUTrmGoTcM5d6b6kcwUIlmno+uMbONZrbczJaYWYLX7YkUCEG
GKx48fIh/D+rBZaMbs+3wWlod+pCU12dw78VlNW5dPJNfZ/JXOudS86ktkcjLNlzxkzlTuP/rZ9lf+lcqb7qHCe37c/XFlf3/CI6MdvnmG3+fuUgY6R0l4oXAcMXtZYtz2R1NuW1eGw4czKB9qZkkDxvF1ZdW0bh1yRf5keQdMNPMFplZXPaVZhZnZglmlpCiN7lECQf079qJGo+cwrz6CdRNepZlnZczvOs1f457z8uNTCInyJzHfYFmVtM5t9nMqgGzgEecc98F2zY2NtYlJKjbXgq3Ndt+54bBD7LGphGT3JRejYfTo33j4MMiNUOThIGZLXLOxQZb53mfvHNuc+D7NjP7HGgGBE3yIoVZpi+TJ8cN5N1V3XEOmu5+iym9HqZG9WP8mWncunjM0yRvZmWBYs65vYHH1wJ9vGxTJBK+X7uUWz7oQErJhZROvp6hNw7h3tZ1Ix2WiOdn8tWBz83/MbQ48LFz7kuP2xTJNwcPH+Su4X34PPl1SD+FVhljGf/W7ZQrp64XKRg8TfLOufXAhV62IRIp4xfO4YHPO7K/1Doq/96OCR1e5+rmlSMdlshf6I5XkTxK2bedmwY9xfxDo7D9Z9Kh3NcMir+SEiUiHZnI0ZTkRXLJOcfrX31M9/8+TkbMLupu6sbUZ3rQ8NwykQ5NJEdK8iK58EvyBm4c2pl1fEXMjmb0aTycHi/8TaMepcBTkhc5hgxfBo+Pe4fBiT1xmcVoumsAU3o9SI1TYyIdmkiuKMmL5ODbNYv5z6gOpJZYTJnN/2LIjYO49+bakQ5LJE+U5EWy2Z++n/97rzeTkt+CQ1W5Ye+HjH3rTspXUKknKXyU5EWyGLtwJu0ndeJAyQ1U+q09E1YdpsXP98OS4aoSKYWSkrwIkLw3hdZDnmTBwY+wvefQofq3DHrxHErUq/XXKpEqQSCFjJK8FGnOOV778kOe//5JMmL2cPofPZnS5TkanlvaP5FH8+Z/1ntXlUgphJTkpchKTF7HjUM7sZ7ZxKQ2p0+TeHq80ODPYZFm/i4aVYmUQkxJXoqcw5mHeWz8mwz9pTcuowRNdwxiysP/pkbDapA9j6tKpBRyuook0cnng+Tko+ZO/ebXhdTofRFD1nSl1B/XM7LJShISx1OjyZ9zsYpEEyV5iT5BJtHel76PmwY/wVVjLmH7gRRa7Z3ItgETuffykpqCT6Kaumsk+mSbO3XMnPHEze3KgZK/U3lDZ8Z3eIUWl5/s37ZcNV1claimJC/RJzB36tYlP3DTLQ1YOO9ObPf5dDj1ewa+dxklS2bZVhdXJcp53l1jZi3NbLWZrTWzrl63J+KAl7vcTe1HKrCwViJ1N7zAsocWE98jW4I/4sjFVSV4iUJeT/8XAwwCrgGSgIVmNtk5t8rLdqXoWrHlV24a3pENbi4xyZfTp0k83V84VzeqSpHldXdNM2BtYIYozGwc0BpQkpewSs9M59HxrxP/S1/c4dI03RHP5D4PULOGsrsUbV4n+dOAP7I8TwIuzrqBmcUBcQB16tTxOByJRnNWz+f2jzqwvfgKyvx2K0Nav8O9t9SIdFgiBYLXST5YJ+dfBi475+KBeIDY2FgXZHuRoPam7eXOEd2Yum0Q7D+NVkxm7Lv/okKFSEcmUnB4neSTgKwFuGsBmz1uU4qA0Qsm02nKQxwsvonKax9mfMeXaHF5+UiHJVLgeJ3kFwJnmVk9YBNwB3Cnx21KFNu8Zws3DX2URQc/xXY2pEP1Txj4/iXBR82IiLdJ3jmXYWYPA18BMcD7zrmVXrYp0cnnfLw0PZ4+P3Ylww5Rd+NLTHnuGS44v0SkQxMp0Dy/Gco5Nx2Y7nU7Er2Wb/mFm4bHsdH9l5iky+mz/FK6L+hKseIaOSNyPLrjVQqstIw0Hhn/Gu+tfgmXXpamMzsz+eeJ1Cz+I2x/UtUhRXJBp0JSIM1a/QM1+zRm+NpelN5wCx80WUVC+ZXULL5dNWZE8kBn8lKg7D60mzvf78r0lKG
wtw6t3DQ+HtCSk9NT4N9fQ2qqasyI5IHO5KXAGDn/c2q8eD7Tk+OpvPoJZrZZybQ3W3Jym0DZ4KuugqpVleBF8kBn8hJxSbs3cdOwh/n54CRsx4V0qD6Jdz+4iFKl8E/8kb3eu/riRXJNSV4ixud89J0xlL4/diXTHabuhteY3O0J/tYgy7DIaqr3LhIKJXmJiCWbV9LmvTh+c/OI2dSCPo0H0/2hChQ7NdtbUvXeRUKiPnkJjxzmVM3uUMYh2o/pSZNhjflt/2qa/j6Kjb2/5PlP21OsTg7zrKreu8gJU5KX0AWZUzWYr375jpp9GjFibV9Kr72d95smkjDiHmqVStU8qyIeUZKX0GWbUzV7kt55cCfXD+5Ay/H/ZOeedFpt/4rNgz7k/tuq+jc40u9evLj63UXCTH3yErocLo4653h//qc8PP0RDhVLpfLqZxjbqRfXXlH2r69Xv7uIZ5TkJXRBkvTvu/7gpvgHWXpwKqQ0oUP1Gbw7qrF/WGQwR/rdwd/do4QvEhbqrpHwCCTpTOej1/R3qd//fJbu/pq6v/RnyUMLiH/hGAk+q1z274tI7uhMXsLm583LaTOiA7/7FhDzx3W80HQIPfrWy9sk2sH693Xzk8gJU5KXkB08fJCHJvTlg19fh4OVaJoyhs/7tqV27RPoatHNTyJh5VmSN7PeQAfgyFCLboHa8hJFpiV+zf+N68iuYmsp8+t9DG7zBvfedsqJd6XrIqxIWHl9Jv+Wc+4Nj9uQCNh+YDt3jnyGmSkfwK4zuME3m4+GXE3FimHYedaLsCISEnXXSJ445xg+72Me/fIJ0ortoFJiV8Z37sk1V5aJdGgiEoTXSf5hM7sHSACecs7tzL6BmcUBcQB16tTxOBwJxYadG2kzvDPLDn4JybF0+OlRBszrRumTNEhLpKAyd5xaI8d8sdls4NQgq7oD84FUwAF9gRrOuXbH2l9sbKxLSEg44XjEGxm+DHrPeJdXF/QgM9OoOyeOL376igtjfoWkJHWtiESYmS1yzsUGWxfSmbxzrkUuAxgOTA2lLYmMhKQl3PxBe5J8i4jZeAMvNBlE95L3ERPzq0a/iBQCXo6uqeGc2xJ4ejOwwqu2JPwOHD7AgxNeYNSv/eFAFZpuG8/EF2+lTh2DR+do9ItIIeFln3w/M2uEv7tmI9DRw7YkXHw+pvz4KXfPfI7dxdZT5pf2DLq5H/fdXunPfK7RLyKFhmdJ3jl3t1f7Fm+k7ttG20evYHbdRNhxNq18X/PR0CupVCnSkYnIidIQSsE5x9B5H/HEl0+QVms3lb/tzNjvV3HtH+eDErxIoaYkX8St27GeNiM6seLALNhyKe0X/ot3V75E6cua6qKqSBRQki+iMnwZPD/9Lfr91AtfRnHqrh3IpO6daXQBkNJOF1VFooSSfFHj8/HT8tncMrkLm3xLiFnfmhdiB9K9by1iYgLb6KKqSNRQki9C9h/aS+e4S/mwXiLsr06TrZ8y8aVbqFtXZ+wi0UpJvoj4YuWX3DO+I3vO+J3SCfcwaM5+7l//d+xUJXiRaKYkH+W27d/GnaOeYE7Kx7DjXFp9/xofrniDypedB9V1YVUk2inJRynnHAN/GMnTM58inX1UWtGbjx/sSssBJSDl3mNfWNUcqyJRQ+UDCzufD5KTIUuhuTXb13LBGy14dE470jc1oH3GUjaN6UXLFqX+vFv1WAlec6yKRA0l+cIsW0I+fDiNLlNf4dwBF7ByRwJ1lw9l8SPfMvzl8yiT23LvweZYFZFCS0m+MMuSkOdtmEfdl5vSb1E3bM0N9K6ayLrxHWncKI+/4iNzrBYvriqTIlFAffKFWbVq7L28GZ1Kr+HjZqmwZxdNtk5i4iutqVv3BPepOVZFooqSfCH22YpptPvHH+whldLLOzPw5ldod1eF0POyqkyKRA0l+UJo676ttB39GHNTJkBKA1pl/sDo+Es55ZRIRyYiBY2SfCHicz4GfD+CLrOfJd13gErL+zL
mwWe5/tqSkQ5NRAqokC68mtmtZrbSzHxmFptt3XNmttbMVpvZdaGFKb+krKbhG1fyxNdxpP9+IQ8cXkbSxz2U4EXkmEI9k18B3AIMy7rQzM4H7gAaADWB2WZ2tnMuM8T2ipz0zHS6TX+NtxJexJd2EnVXv8fE59vRpIkuiIrI8YU6kXcigB19pa81MM45lwZsMLO1QDPgx1DaK2q+/+1HbvuwA1syVxKz+nZ6xb5NjxdPpbg62UQkl7xKF6cB87M8TwosO4qZxQFxAHXq1PEonEIiUE5gT/mSxI15ivGbRsKeWjTZPJVPX7mBevUiHaCIFDbHTfJmNhs4Nciq7s65L3J6WZBlLsgynHPxQDxAbGxs0G2KhMDdqxNSF9C+ZWn2lt9D6YUdGXjba7TrH4ZhkSJSJB03yTvnWpzAfpOA2lme1wI2n8B+iozNG5fTtkYi312VBlvPodX4NozaOowq7/cGqxDp8ESkkPKqrMFk4A4zK2Vm9YCzgJ88aqtQ8zkfb343lNNH/YPvztxLpdnPMO2j05m29SWqXHaOygqISEhC6pM3s5uBd4GqwDQzW+Kcu845t9LMJgCrgAzgIY2sOdrKbav496g4Vh/4AX67igeqDOGdUSdTtm4VSE1VWQERCZk5V3C6wWNjY11CQkKkw/BcWkYaz01/hbcXvYw7VJ463z7LxIRPaXrZSf66McVUN05Ecs/MFjnnYoOt02C8fPbtxv9y+0dxJGf+QkziXfQ4rw89Es6neGYazCvuLwymujEiEiZK8vlk16FdxE3owicb4mHn6TTePINPX76W+uW2wTfN4McfVdpXRMJO/QIec84xbuln1H71fD5Z9x6lFz/F8MYrWDT2Wuo/cCXUru3vd//9d5g7V33wIhJWOpP3UNKeJO748CF+SJ0MWxvT8vAURr/XlKpV8U/Zl3UGpmLFlOBFJOx0Ju+BTF8m/b4dSP3+5/PDlllUWvg609r8xIwRgQQPmoFJRPKFzuTDbHnycv4zKo5fD86HDdfQrtpQBoyvT9my2TbUDEwikg+U5MPkUMYhukx7kXd/fg13oCJ1Ej/i0553ctFFx0jemoFJRDymJJ9XgSJiWc++v14/l7Yfx7Etc5yzTP8AAArKSURBVA3FVtzD8xf1p8fEKpQoEeFYRaTIU5LPi0ARMebNg+bN2TH9M+I+68pnG0bAjvo03jSLT15rwRlnRDpQERE/Jfm8SEmBefNwGRmM2TmfTv3OZz87KL24C2/f3JO4t09S17qIFChK8nlRrRq/X9mEtlXXMu/sHbCpLi0Pz2TUiEYaHCMiBZKSfC5l+jJ5/buBPH/pSjIyoOL8N/nwkUe5sVVMpEMTEcmRknwuLN26lP+M7sDagwth3fW0qzaEdz6pS7lykY5MROTYlOSP4eDhgzwz7QUGL3kDt/8U6qwayye9bqdZM3W8i0jhoCR/RLahkTPXzub/xnUiJXMdMcva0e2i13l+UmUNixSRQiXUSUNuBXoD5wHNnHMJgeWnA4nA6sCm851znUJpy1NZhkZu/8dFtL/7LCb9Nhq2n0WjpK/5pN+VnHlmpIMUEcm7UM/kVwC3AMOCrFvnnGsU4v7zR0oKbt4PjD4PHmy6lAPrF1J6UTfevqUHce+U0bBIESm0QkryzrlEACvkWXBDif207VCTBdX/gKQLaHk4npHv/00VB0Sk0POyCmU9M/vZzL41s8s9bCf3fD5/id/AlIcZvgxe+qY/Zw+4gAUVd1Lxh3f44qbvmTFSCV5EosNxz+TNbDZwapBV3Z1zX+Twsi1AHefcdjNrCkwyswbOuT1B9h8HxAHUqVMn95HnVbaSBIs/7s9tYzqy7uBi+PVf3F99EO98Vpvy5b0LQUQkvx03yTvnWuR1p865NCAt8HiRma0DzgaOmqXbORcPxIN/Iu+8tpVrgZIE+y2Dp8osZtjwi2FfNeqs/ITxvf7NJZcU7i4nEZFgPBlCaWZVgR3OuUwzqw+cBaz3oq1cq1a
NGTecx92nb2R7pb0U+7kD3S56jZ6TK2lYpIhErVCHUN4MvAtUBaaZ2RLn3HXAP4A+ZpYBZAKdnHM7Qo42rwJj31PKGg98+iRTGi+H1HNo9PMwxvf7J2efne8RiYjkq1BH13wOfB5k+WfAZ6HsO2Q+H+7KK3h/73weua4kB0ukU2phT97+dzc6DiilYZEiUiRE7R2v69Yu5I7Tl5FQ/zD8fhHX7RvIyA8ac2qwS8giIlEq6pL84czDvDL3Tfr8tzeZNUtScepLjDy4h9bLG4HO3kWkiImqJL9w00JuH9OBDQeXwi83c3/1AbwzriTl61fVRNkiUiRFRZLfl76PJ6c+z3vLBuD2nkrt5ROZ0OdmLrkk0pGJiESWl3e85puP5yxj+NIB2OKOdK+4irVTleBFRCBKzuTb/r05U8au5Y3X63HOOZGORkSk4IiKJF++PEwZXS/SYYiIFDhR0V0jIiLBKcmLiEQxJXkRkSimJC8iEsWU5EVEopiSvIhIFFOSFxGJYkryIiJRTEleRCSKhZTkzex1M/vFzJaZ2edmVjHLuufMbK2ZrTaz60IPVURE8irUM/lZQEPn3N+AX4HnAMzsfOAOoAHQEhhsZjEhtiUiInkUUpJ3zs10zmUEns4HagUetwbGOefSnHMbgLVAs1DaEhGRvAtngbJ2wPjA49PwJ/0jkgLLjmJmcUBc4Ok+M1sdQgxVgNQQXu8VxZU3iitvFFfeRGNcdXNacdwkb2azgWAzo3Z3zn0R2KY7kAGMOfKyINu7YPt3zsUD8ceLIzfMLME5FxuOfYWT4sobxZU3iitvilpcx03yzrkWx1pvZvcCNwJXO+eOJPIkoHaWzWoBm080SBEROTGhjq5pCXQBbnLOHciyajJwh5mVMrN6wFnAT6G0JSIieRdqn/xAoBQwy/wTZc93znVyzq00swnAKvzdOA855zJDbCs3wtLt4wHFlTeKK28UV94Uqbjszx4WERGJNrrjVUQkiinJi4hEsUKV5M3sVjNbaWY+M4vNtu64ZRTMrJ6ZLTCzNWY23sxKehTneDNbEvjaaGZLcthuo5ktD2yX4EUs2drrbWabssTWKoftWgaO41oz65oPceVYHiPbdp4fr+P97IHBBOMD6xeY2elexBGk3dpm9o2ZJQb+Bh4Lss0VZrY7y++3Zz7Fdszfi/kNCByzZWbWJB9iOifLcVhiZnvM7PFs2+TL8TKz981sm5mtyLKsspnNCuSiWWZWKYfX3hvYZk1gJGPeOecKzRdwHnAOMBeIzbL8fGAp/ovA9YB1QEyQ108A7gg8Hgp0zoeY+wM9c1i3EaiSj8evN/D0cbaJCRy/+kDJwHE93+O4rgWKBx6/BrwWieOVm58deBAYGnh8BzA+n353NYAmgcfl8ZcRyR7bFcDU/Ho/5fb3ArQCZuC/f+YSYEE+xxcDbAXqRuJ4Af8AmgArsizrB3QNPO4a7D0PVAbWB75XCjyulNf2C9WZvHMu0TkX7I7Y45ZRMP/wn6uATwOLRgFtvIw30OZtwFgv2wmzZsBa59x651w6MA7/8fWMy7k8Rn7Lzc/eGv97B/zvpasDv2dPOee2OOcWBx7vBRLJ4S7yAqg1MNr5zQcqmlmNfGz/amCdc+63fGzzf5xz3wE7si3O+j7KKRddB8xyzu1wzu3EXyusZV7bL1RJ/hhOA/7I8jxYGYVTgF1ZkkmOpRbC6HIg2Tm3Jof1DphpZosC5R3yw8OBj8zv5/ARMTfH0kvt8J/1BeP18crNz/6/bQLvpd3431v5JtBF1BhYEGT1pWa21MxmmFmDfArpeL+XSL+n7iDnE61IHC+A6s65LeD/Bw5UC7JNWI5bOGvXhIXlooxCsJcFWZZ9bGiuSy3kRi7jbMuxz+Ivc85tNrNq+O81+CXwX/+EHSsuYAjQF//P3Rd/V1K77LsI8tqQx9nm5njZ0eUxsgv78coeZpBlnr6P8srMygGfAY875/ZkW70Yf5fEvsD1lkn4b0T02vF
+LxE7ZoHrbjcRqJCbTaSOV26F5bgVuCTvjlNGIQe5KaOQiv9jYvHAGVhIpRaOF6eZFQduAZoeYx+bA9+3mdnn+LsLQkpauT1+ZjYcmBpklSclKXJxvIKVx8i+j7Afr2xy87Mf2SYp8Ds+maM/invCzErgT/BjnHMTs6/PmvSdc9PNbLCZVXHOeVqMKxe/l0iWObkeWOycS86+IlLHKyDZzGo457YEuq62BdkmCf91gyNq4b8emSfR0l1z3DIKgcTxDfCfwKJ7gZw+GYRDC+AX51xSsJVmVtbMyh95jP/i44pg24ZLtn7Qm3NobyFwlvlHIpXE/1F3ssdx5VQeI+s2+XG8cvOzT8b/3gH/e+nrnP4phVOg338EkOicezOHbU49cn3AzJrh//ve7nFcufm9TAbuCYyyuQTYfaSrIh/k+Gk6Escri6zvo5xy0VfAtWZWKdC1em1gWd54fWU5nF/4E1MSkAYkA19lWdcd/8iI1cD1WZZPB2oGHtfHn/zXAp8ApTyMdSTQKduymsD0LLEsDXytxN9t4fXx+xBYDiwLvMlqZI8r8LwV/tEb6/IprrX4+x6XBL6GZo8rv45XsJ8d6IP/HxBA6cB7Z23gvVTf6+MTaPfv+D+qL8tynFoBnY68z4CHA8dmKf4L2M3zIa6gv5dscRkwKHBMl5NlZJzHsZ2EP2mfnGVZvh8v/P9ktgCHA/nrAfzXceYAawLfKwe2jQXey/LadoH32lrg/hNpX2UNRESiWLR014iISBBK8iIiUUxJXkQkiinJi4hEMSV5EZEopiQvIhLFlORFRKLY/wNQ0IMDOgZ73wAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Parameter (name=fc.weight) [[2.0064354]]\n", "Parameter (name=fc.bias) [2.9529438]\n" ] } ], "source": [ "epoch = 1\n", "imageshow_cb = ImageShowCallback(net, eval_data)\n", "# Train for one epoch, redrawing the fit after every step (no data sinking on CPU)\n", "model.train(epoch, ds_train, callbacks=[imageshow_cb], dataset_sink_mode=False)\n", "\n", "# Show the final fit and print the learned weight and bias\n", "plot_model_and_datasets(net, eval_data)\n", "for net_param in net.trainable_params():\n", "    print(net_param, net_param.asnumpy())" ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2020-09-14T04:00:18.787349Z", "start_time": "2020-09-14T04:00:18.784236Z" } }, "source": [ "After training, the final model parameters are printed: the weight is close to 2.0 and the bias is close to 3.0, matching the target function $f(x)=2x+3$ as expected." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "In this walkthrough we covered the principle behind linear fitting, implemented the corresponding algorithm in MindSpore, followed how a linear regression model is trained in MindSpore, and finally obtained a model function close to the target function. If you are interested, try extending the data generation range from (-10, 10) to (-100, 100) and see whether the weights get closer to the target function, adjust the learning rate and see whether the fitting speed changes, or explore how to use MindSpore to fit quadratic functions such as $f(x)=ax^2+bx+c$ or higher-order functions." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }