# WithLossCell

[![View Source On Gitee](https://gitee.com/mindspore/docs/raw/r1.3/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.3/docs/mindspore/programming_guide/source_zh_cn/withlosscell.md)

`WithLossCell` is essentially a `Cell` that wraps a network together with a loss function. To construct a `WithLossCell`, the network and the loss function must be defined first.

The following example shows how to use it. First, construct a network:

```python
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.initializer import Normal
from mindspore.nn import WithLossCell

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


class LeNet5(nn.Cell):
    """
    Lenet network

    Args:
        num_class (int): Number of classes. Default: 10.
        num_channel (int): Number of channels. Default: 1.
        include_top (bool): Whether to build the fully connected classifier head. Default: True.

    Returns:
        Tensor, output tensor

    Examples:
        >>> LeNet5(num_class=10)
    """
    def __init__(self, num_class=10, num_channel=1, include_top=True):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.include_top = include_top
        if self.include_top:
            self.flatten = nn.Flatten()
            self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
            self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
            self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        if not self.include_top:
            return x
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```

The following is an example of using `WithLossCell`. Define the network and the loss function, then create a `WithLossCell` from them and pass in the input data and the labels. `WithLossCell` feeds the data through the network, applies the loss function to the network output and the labels, and returns the resulting loss.

```python
data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([32]).astype(np.int32))
net = LeNet5()
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_criterion = WithLossCell(net, criterion)
loss = net_with_criterion(data, label)
print("+++++++++Loss+++++++++++++")
print(loss)
```

```text
+++++++++Loss+++++++++++++
2.302585
```
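To make the data flow concrete, the sketch below mimics the computation described above in plain NumPy, without MindSpore: the backbone produces logits, and the loss function turns those logits and the labels into a single mean loss. The names `PlainWithLoss`, `zero_backbone`, and `sparse_softmax_cross_entropy_mean` are hypothetical and are not part of the MindSpore API. The sketch also assumes that an untrained LeNet5 with small weights and a small constant input yields logits close to zero. When all 10 logits are equal, the softmax is uniform and the mean cross-entropy is exactly ln(10) ≈ 2.302585, which is consistent with the value printed above.

```python
import numpy as np


def sparse_softmax_cross_entropy_mean(logits, labels):
    """Mean-reduced softmax cross-entropy with integer class labels (NumPy sketch)."""
    # Subtract the row max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class for each sample.
    picked = log_probs[np.arange(labels.shape[0]), labels]
    return float(-picked.mean())


class PlainWithLoss:
    """Hypothetical stand-in for WithLossCell: backbone output is fed into a loss."""

    def __init__(self, backbone, loss_fn):
        self.backbone = backbone
        self.loss_fn = loss_fn

    def __call__(self, data, label):
        out = self.backbone(data)          # forward pass through the network
        return self.loss_fn(out, label)    # loss on network output and labels


# Hypothetical backbone: returns all-zero logits for 10 classes, standing in for
# an untrained network whose outputs are assumed to be close to zero.
def zero_backbone(data):
    return np.zeros((data.shape[0], 10), dtype=np.float32)


data = np.ones([32, 1, 32, 32], dtype=np.float32) * 0.01
label = np.ones([32], dtype=np.int32)

net_with_loss = PlainWithLoss(zero_backbone, sparse_softmax_cross_entropy_mean)
loss = net_with_loss(data, label)
print(round(loss, 6))  # ln(10), about 2.302585
```

Replacing `zero_backbone` with a function that returns non-uniform logits will move the loss away from ln(10), which is what training is expected to do.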