# Network Construction

[View Source on Gitee](https://gitee.com/mindspore/docs/blob/r2.5.0/docs/mindspore/source_en/migration_guide/model_development/model_and_cell.md)

## Basic Logic

PyTorch and MindSpore share the same basic logic: both generally go through network definition, forward computation, backward computation, and gradient update during implementation.

- Network definition: the forward network, loss function, and optimizer are defined here. To define the forward network, a PyTorch network inherits from nn.Module; similarly, a MindSpore network inherits from nn.Cell. In MindSpore, the loss function and optimizer can be customized in addition to using those provided by the framework; see [Model Module Customization](https://mindspore.cn/docs/en/r2.5.0/model_train/index.html). Interfaces such as functional/nn can be used to assemble the required forward network, loss function, and optimizer.
- Forward computation: run the instantiated network to get the logits, and use the logits and target as inputs to compute the loss. Note that if the forward function has more than one output, you need to consider the effect of the extra outputs when computing the backward function.
- Backward computation: after obtaining the loss, the backward calculation can be performed. In PyTorch, the gradient is computed with loss.backward(); in MindSpore, the backward function net_backward is first defined with mindspore.grad(), and the inputs are then passed into net_backward. If the forward function has more than one output, set has_aux to True so that only the first output participates in the derivation, while the other outputs are returned directly. For differences in interface usage in the backward calculation, see [Automatic Differentiation](./gradient.md).
- Gradient update: update the computed gradients into the Parameters of the network. PyTorch uses optim.step(); in MindSpore, the gradients of the Parameters are passed into the defined optim to complete the gradient update. A minimal training-step sketch covering these steps follows this list.
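As a reference, here is a minimal MindSpore training-step sketch that strings the four steps together. The network, loss function, optimizer, and shapes are placeholders chosen for illustration only:

```python
import mindspore as ms
from mindspore import nn, ops

# Network definition: forward network, loss function, and optimizer.
net = nn.Dense(10, 2)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
optim = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

def forward_fn(data, label):
    logits = net(data)              # forward computation
    loss = loss_fn(logits, label)
    return loss, logits             # more than one output

# has_aux=True: only the first output (loss) is differentiated;
# the remaining outputs (logits) are returned as-is.
grad_fn = ms.value_and_grad(forward_fn, None, net.trainable_params(), has_aux=True)

def train_step(data, label):
    (loss, logits), grads = grad_fn(data, label)  # backward computation
    optim(grads)                                  # gradient update
    return loss, logits

data = ops.randn(4, 10)
label = ms.Tensor([0, 1, 0, 1], ms.int32)
loss, _ = train_step(data, label)
print(loss)
```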
## Network Basic Unit: Cell

MindSpore uses [Cell](https://www.mindspore.cn/docs/en/r2.5.0/api_python/nn/mindspore.nn.Cell.html#mindspore.nn.Cell) to construct graphs. You need to define a class that inherits from the `Cell` base class, declare the required APIs and submodules in `__init__`, and perform the calculation in `construct`. A `Cell` is compiled into a computational graph in `GRAPH_MODE` (static graph mode) and is used as the basic module of a neural network in `PYNATIVE_MODE` (dynamic graph mode). The basic `Cell` setup process in PyTorch and MindSpore is as follows:

**PyTorch:**
```python
import torch.nn as torch_nn

class MyCell_pt(torch_nn.Module):
    def __init__(self, forward_net):
        super(MyCell_pt, self).__init__()
        self.net = forward_net
        self.relu = torch_nn.ReLU()

    def forward(self, x):
        y = self.net(x)
        return self.relu(y)

inner_net_pt = torch_nn.Conv2d(120, 240, kernel_size=4, bias=False)
pt_net = MyCell_pt(inner_net_pt)
for i in pt_net.parameters():
    print(i.shape)
```

Outputs:

```text
torch.Size([240, 120, 4, 4])
```
**MindSpore:**

```python
import mindspore.nn as nn
import mindspore.ops as ops

class MyCell(nn.Cell):
    def __init__(self, forward_net):
        super(MyCell, self).__init__(auto_prefix=True)
        self.net = forward_net
        self.relu = ops.ReLU()

    def construct(self, x):
        y = self.net(x)
        return self.relu(y)

inner_net = nn.Conv2d(120, 240, 4, has_bias=False)
my_net = MyCell(inner_net)
print(my_net.trainable_params())
```

Outputs:

```text
[Parameter (name=net.weight, shape=(240, 120, 4, 4), dtype=Float32, requires_grad=True)]
```
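After wiring up the cell, it is often useful to run it once on random data to check shapes. A minimal sketch, where the input shape (batch=2, 120 channels, 32x32) is an assumption chosen only for illustration:

```python
import numpy as np
import mindspore as ms

# Quick smoke test of the MyCell instance defined above.
x = ms.Tensor(np.random.randn(2, 120, 32, 32), ms.float32)
out = my_net(x)
print(out.shape)  # spatial size is preserved because MindSpore Conv2d defaults to pad_mode='same'
```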
Setting the model data type:

**PyTorch:**
```python
import torch
import torch.nn as nn

class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 12, kernel_size=3, padding=1),
            nn.BatchNorm2d(12),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(12, 4, kernel_size=3, padding=1),
            nn.BatchNorm2d(4),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.pool = nn.AdaptiveMaxPool2d((5, 5))
        self.fc = nn.Linear(100, 10)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        out = self.fc(x)
        return out

net = Network()
net = net.to(torch.float32)
for name, module in net.named_modules():
    if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        module.to(torch.float32)
loss = nn.CrossEntropyLoss(reduction='mean')
loss = loss.to(torch.float32)
```
**MindSpore:**

```python
import mindspore as ms
from mindspore import nn

# Define model
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.SequentialCell([
            nn.Conv2d(3, 12, kernel_size=3, pad_mode='pad', padding=1),
            nn.BatchNorm2d(12),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        ])
        self.layer2 = nn.SequentialCell([
            nn.Conv2d(12, 4, kernel_size=3, pad_mode='pad', padding=1),
            nn.BatchNorm2d(4),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        ])
        self.pool = nn.AdaptiveMaxPool2d((5, 5))
        self.fc = nn.Dense(100, 10)

    def construct(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.pool(x)
        x = x.view((-1, 100))
        out = self.fc(x)
        return out

net = Network()
net.to_float(ms.float16)  # Add the float16 flag to all operations in the net. The framework adds the cast method to the input during compilation.
for _, cell in net.cells_and_names():
    if isinstance(cell, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        cell.to_float(ms.float32)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean').to_float(ms.float32)
net_with_loss = nn.WithLossCell(net, loss_fn=loss)
```
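A quick way to check that the mixed-precision network and loss wrapper behave as intended is to run one batch through net_with_loss. A minimal sketch, assuming a float16-capable device (e.g., Ascend or GPU) and an input shape of (2, 3, 20, 20) chosen so that the flattened feature size matches nn.Dense(100, 10):

```python
import numpy as np
import mindspore as ms

# Assumed shapes for illustration: 20x20 -> 10x10 after layer1 -> 5x5 after layer2,
# 4 channels * 5 * 5 = 100 features, matching the Dense layer above.
data = ms.Tensor(np.random.randn(2, 3, 20, 20), ms.float32)
label = ms.Tensor(np.array([1, 6]), ms.int32)
loss_value = net_with_loss(data, label)
print(loss_value)
```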
Getting and freezing parameters (here the bias is frozen by setting requires_grad to False):

**PyTorch:**
```python
import torch.nn as nn

net = nn.Linear(2, 1)

for name, param in net.named_parameters():
    print("Parameter Name:", name)

for name, param in net.named_parameters():
    if "bias" in name:
        param.requires_grad = False

for name, param in net.named_parameters():
    if param.requires_grad:
        print("Parameter Name:", name)
```

Outputs:

```text
Parameter Name: weight
Parameter Name: bias
Parameter Name: weight
```
**MindSpore:**

```python
import mindspore.nn as nn

net = nn.Dense(2, 1, has_bias=True)
print(net.trainable_params())

for param in net.trainable_params():
    param_name = param.name
    if "bias" in param_name:
        param.requires_grad = False
print(net.trainable_params())
```

Outputs:

```text
[Parameter (name=weight, shape=(1, 2), dtype=Float32, requires_grad=True), Parameter (name=bias, shape=(1,), dtype=Float32, requires_grad=True)]
[Parameter (name=weight, shape=(1, 2), dtype=Float32, requires_grad=True)]
```
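After freezing, only the remaining trainable parameters are usually handed to the optimizer. A minimal sketch continuing from the net above (the optimizer type and learning rate are placeholders):

```python
import mindspore.nn as nn

# Only the non-frozen parameters (here just the weight) are optimized.
optim = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
print(optim.parameters)
```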
Saving and loading parameters:

**PyTorch:**
```python
import torch
import torch.nn as nn

linear_layer = nn.Linear(2, 1, bias=True)

linear_layer.weight.data.fill_(1.0)
linear_layer.bias.data.zero_()

print("Original linear layer parameters:")
print(linear_layer.weight)
print(linear_layer.bias)

torch.save(linear_layer.state_dict(), 'linear_layer_params.pth')

new_linear_layer = nn.Linear(2, 1, bias=True)
new_linear_layer.load_state_dict(torch.load('linear_layer_params.pth'))

# Print the loaded parameters, which should be the same as the original ones
print("Loaded linear layer parameters:")
print(new_linear_layer.weight)
print(new_linear_layer.bias)
```

Outputs:

```text
Original linear layer parameters:
Parameter containing:
tensor([[1., 1.]], requires_grad=True)
Parameter containing:
tensor([0.], requires_grad=True)
Loaded linear layer parameters:
Parameter containing:
tensor([[1., 1.]], requires_grad=True)
Parameter containing:
tensor([0.], requires_grad=True)
```
**MindSpore:**

```python
import mindspore as ms
import mindspore.ops as ops
import mindspore.nn as nn

net = nn.Dense(2, 1, has_bias=True)
for param in net.get_parameters():
    print(param.name, param.data.asnumpy())

ms.save_checkpoint(net, "dense.ckpt")
dense_params = ms.load_checkpoint("dense.ckpt")
print(dense_params)
new_params = {}
for param_name in dense_params:
    print(param_name, dense_params[param_name].data.asnumpy())
    new_params[param_name] = ms.Parameter(ops.ones_like(dense_params[param_name].data), name=param_name)

ms.load_param_into_net(net, new_params)
for param in net.get_parameters():
    print(param.name, param.data.asnumpy())
```

Outputs:

```text
weight [[-0.0042482 -0.00427286]]
bias [0.]
{'weight': Parameter (name=weight, shape=(1, 2), dtype=Float32, requires_grad=True), 'bias': Parameter (name=bias, shape=(1,), dtype=Float32, requires_grad=True)}
weight [[-0.0042482 -0.00427286]]
bias [0.]
weight [[1. 1.]]
bias [1.]
```
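A parameter can also be modified in place without going through a checkpoint, for example with Parameter.set_data. A minimal sketch continuing from the net above:

```python
import mindspore.ops as ops

# Overwrite the weight in place; set_data keeps the name, dtype, and requires_grad flag.
net.weight.set_data(ops.zeros_like(net.weight.data))
print(net.weight.data.asnumpy())
```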
Parameter initialization:

**torch.nn.init:**
```python
import torch

x = torch.empty(2, 2)
torch.nn.init.uniform_(x)
```
**mindspore.common.initializer:**

```python
import mindspore
from mindspore.common.initializer import initializer, Uniform

x = initializer(Uniform(), [1, 2, 3], mindspore.float32)
```
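Initializers can also be passed directly to layers when they are created, instead of filling tensors afterwards. A minimal sketch (the layer sizes and sigma are placeholders):

```python
import mindspore.nn as nn
from mindspore.common.initializer import Normal

# weight_init / bias_init accept initializer instances or names such as 'zeros'.
dense = nn.Dense(3, 4, weight_init=Normal(0.02), bias_init='zeros')
print(dense.weight.data.asnumpy())
```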
Adding and iterating over submodules:

**PyTorch:**
```python
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()

        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)

        # Add submodules using add_module
        self.add_module('conv3', nn.Conv2d(64, 128, 3, 1))

        self.sequential_block = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, 1),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.sequential_block(x)
        return x

module = MyModule()

# Iterate through all submodules (both direct and indirect) using named_modules
for name, module_instance in module.named_modules():
    print(f"Module name: {name}, type: {type(module_instance)}")
```

Output:

```text
Module name: , type: <class '__main__.MyModule'>
Module name: conv1, type: <class 'torch.nn.modules.conv.Conv2d'>
Module name: conv2, type: <class 'torch.nn.modules.conv.Conv2d'>
Module name: conv3, type: <class 'torch.nn.modules.conv.Conv2d'>
Module name: sequential_block, type: <class 'torch.nn.modules.container.Sequential'>
Module name: sequential_block.0, type: <class 'torch.nn.modules.activation.ReLU'>
Module name: sequential_block.1, type: <class 'torch.nn.modules.conv.Conv2d'>
Module name: sequential_block.2, type: <class 'torch.nn.modules.activation.ReLU'>
```
**MindSpore:**

```python
from mindspore import nn

class MyCell(nn.Cell):
    def __init__(self):
        super(MyCell, self).__init__()

        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)

        # Add submodules using insert_child_to_cell
        self.insert_child_to_cell('conv3', nn.Conv2d(64, 128, 3, 1))

        self.sequential_block = nn.SequentialCell(
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, 1),
            nn.ReLU()
        )

    def construct(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.sequential_block(x)
        return x

module = MyCell()

# Iterate through all submodules (both direct and indirect) using cells_and_names
for name, cell_instance in module.cells_and_names():
    print(f"Cell name: {name}, type: {type(cell_instance)}")
```

Output:

```text
Cell name: , type: <class '__main__.MyCell'>
Cell name: conv1, type: <class 'mindspore.nn.layer.conv.Conv2d'>
Cell name: conv2, type: <class 'mindspore.nn.layer.conv.Conv2d'>
Cell name: conv3, type: <class 'mindspore.nn.layer.conv.Conv2d'>
Cell name: sequential_block, type: <class 'mindspore.nn.layer.container.SequentialCell'>
Cell name: sequential_block.0, type: <class 'mindspore.nn.layer.activation.ReLU'>
Cell name: sequential_block.1, type: <class 'mindspore.nn.layer.conv.Conv2d'>
Cell name: sequential_block.2, type: <class 'mindspore.nn.layer.activation.ReLU'>
```
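To iterate over only the direct children rather than the full tree, PyTorch offers named_children() and MindSpore offers name_cells(). A minimal MindSpore sketch continuing from the cell above:

```python
# name_cells() returns only the immediate child cells, keyed by attribute name.
for name, cell_instance in module.name_cells().items():
    print(f"Direct child: {name}, type: {type(cell_instance)}")
```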
Setting the running device:

**PyTorch:**
```python
import torch

torch_net = torch.nn.Linear(3, 4)
torch_net.cpu()
```
**MindSpore:**

```python
import mindspore

mindspore.set_device(device_target="CPU")
ms_net = mindspore.nn.Dense(3, 4)
```
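Unlike PyTorch, MindSpore sets the device globally, so cells created afterwards run on it without an explicit `.to()`/`.cpu()` call. A minimal sketch continuing from the code above (the input shape is a placeholder):

```python
import numpy as np
import mindspore

# The Dense cell runs on the globally configured device (CPU here).
x = mindspore.Tensor(np.random.randn(2, 3).astype(np.float32))
print(ms_net(x).shape)
```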