{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 模型层\n", "\n", "[![](https://gitee.com/mindspore/docs/raw/r1.3/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.3/docs/mindspore/programming_guide/source_zh_cn/layer.ipynb) [![](https://gitee.com/mindspore/docs/raw/r1.3/resource/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.3/programming_guide/zh_cn/mindspore_layer.ipynb) [![](https://gitee.com/mindspore/docs/raw/r1.3/resource/_static/logo_modelarts.png)](https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9vYnMuZHVhbHN0YWNrLmNuLW5vcnRoLTQubXlodWF3ZWljbG91ZC5jb20vbWluZHNwb3JlLXdlYnNpdGUvbm90ZWJvb2svbW9kZWxhcnRzL3Byb2dyYW1taW5nX2d1aWRlL21pbmRzcG9yZV9sYXllci5pcHluYg==&imageid=65f636a0-56cf-49df-b941-7d2a07ba8c8c)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 概述\n", "\n", "在讲述了`Cell`的使用方法后可知,MindSpore能够以`Cell`为基类构造网络结构。\n", "\n", "为了方便用户的使用,MindSpore框架内置了大量的模型层,用户可以通过接口直接调用。\n", "\n", "同样,用户也可以自定义模型,此内容在“构建自定义网络”中介绍。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 内置模型层\n", "\n", "MindSpore框架在`mindspore.nn`的layer层内置了丰富的接口,主要内容如下:\n", "\n", "- 激活层\n", "\n", " 激活层内置了大量的激活函数,在定义网络结构中经常使用。激活函数为网络加入了非线性运算,使得网络能够拟合效果更好。\n", "\n", " 主要接口有`Softmax`、`Relu`、`Elu`、`Tanh`、`Sigmoid`等。\n", " \n", "\n", "- 基础层\n", "\n", " 基础层实现了网络中一些常用的基础结构,例如全连接层、Onehot编码、Dropout、平铺层等都在此部分实现。\n", "\n", " 主要接口有`Dense`、`Flatten`、`Dropout`、`Norm`、`OneHot`等。\n", " \n", "\n", "- 容器层\n", "\n", " 容器层主要功能是实现一些存储多个Cell的数据结构。\n", "\n", " 主要接口有`SequentialCell`、`CellList`等。\n", " \n", "\n", "- 卷积层\n", "\n", " 卷积层提供了一些卷积计算的功能,如普通卷积、深度卷积和卷积转置等。\n", "\n", " 主要接口有`Conv2d`、`Conv1d`、`Conv2dTranspose`、`Conv1dTranspose`等。\n", " \n", "\n", "- 池化层\n", "\n", " 池化层提供了平均池化和最大池化等计算的功能。\n", "\n", " 主要接口有`AvgPool2d`、`MaxPool2d`和`AvgPool1d`。\n", " \n", "\n", "- 嵌入层\n", "\n", " 嵌入层提供word embedding的计算功能,将输入的单词映射为稠密向量。\n", "\n", " 主要接口有`Embedding`、`EmbeddingLookup`、`EmbeddingLookUpSplitMode`等。\n", " \n", "\n", "- 长短记忆循环层\n", "\n", " 长短记忆循环层提供LSTM计算功能。其中`LSTM`内部会调用`LSTMCell`接口,`LSTMCell`是一个LSTM单元,对一个LSTM层做运算,当涉及多LSTM网络层运算时,使用`LSTM`接口。\n", "\n", " 主要接口有`LSTM`和`LSTMCell`。\n", " \n", "\n", "- 标准化层\n", "\n", " 标准化层提供了一些标准化的方法,即通过线性变换等方式将数据转换成均值和标准差。\n", "\n", " 主要接口有`BatchNorm1d`、`BatchNorm2d`、`LayerNorm`、`GroupNorm`、`GlobalBatchNorm`等。\n", " \n", "\n", "- 数学计算层\n", "\n", " 数学计算层提供一些算子拼接而成的计算功能,例如数据生成和一些数学计算等。\n", "\n", " 主要接口有`ReduceLogSumExp`、`Range`、`LinSpace`、`LGamma`等。\n", " \n", "\n", "- 图片层\n", "\n", " 图片计算层提供了一些矩阵计算相关的功能,将图片数据进行一些变换与计算。\n", "\n", " 主要接口有`ImageGradients`、`SSIM`、`MSSSIM`、`PSNR`、`CentralCrop`等。\n", " \n", "\n", "- 量化层\n", "\n", " 量化是指将数据从float的形式转换成一段数据范围的int类型,所以量化层提供了一些数据量化的方法和模型层结构封装。\n", "\n", " 主要接口有`Conv2dBnAct`、`DenseBnAct`、`Conv2dBnFoldQuant`、`LeakyReLUQuant`等。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 应用实例\n", "\n", "MindSpore的模型层在`mindspore.nn`下,使用方法如下所示:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2021-02-08T01:01:31.944015Z", "start_time": "2021-02-08T01:01:31.917571Z" } }, "outputs": [], "source": [ "import mindspore.nn as nn\n", "\n", "class Net(nn.Cell):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')\n", " self.bn = nn.BatchNorm2d(64)\n", " self.relu = nn.ReLU()\n", " self.flatten = nn.Flatten()\n", " self.fc = nn.Dense(64 * 222 * 222, 3)\n", "\n", " def construct(self, x):\n", " x = self.conv(x)\n", " x = self.bn(x)\n", " x = self.relu(x)\n", " x = self.flatten(x)\n", " out = self.fc(x)\n", " return out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "依然是上述网络构造的用例,从这个用例中可以看出,程序调用了`Conv2d`、`BatchNorm2d`、`ReLU`、`Flatten`和`Dense`模型层的接口。\n", "\n", "在`Net`初始化方法里被定义,然后在`construct`方法里真正运行,这些模型层接口有序的连接,形成一个可执行的网络。" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }