Tensor and Parameter
Tensor
A Tensor is the basic data structure in MindSpore network operations, and it works much like a NumPy array (ndarray). MindSpore uses Tensor to represent the data passed through a neural network.
For operations such as Tensor creation, Tensor operations, and Tensor to NumPy conversion, see Tensor.
Tensor Index Support
Single-level and multi-level Tensor indexing is supported in both PyNative and Graph modes.
Index Values
The index value can be int, bool, None, ellipsis, slice, Tensor, List, or Tuple.
int index value
Single-level and multi-level int index values are supported. The single-level int index value is tensor_x[int_index], and the multi-level int index value is tensor_x[int_index0][int_index1]....
The int index value is obtained on dimension 0 and must be less than the length of dimension 0. After the position data corresponding to dimension 0 is obtained, dimension 0 is eliminated.
For example, if a single-level int index value is obtained for a tensor whose shape is (3, 4, 5), the obtained shape is (4, 5).
The multi-level index value can be understood as obtaining the current-level int index value based on the previous-level index value.
For example:
import mindspore as ms
import mindspore.numpy as np
tensor_x = ms.Tensor(np.arange(2 * 3 * 2).reshape((2, 3, 2)))
data_single = tensor_x[0]
data_multi = tensor_x[0][1]
print('data_single:')
print(data_single)
print('data_multi:')
print(data_multi)
The result is as follows:
data_single:
[[0 1]
 [2 3]
 [4 5]]
data_multi:
[2 3]
bool index value
Single-level and multi-level bool index values are supported. The single-level bool index value is tensor_x[True], and the multi-level bool index value is tensor_x[True][True]....
The True index value is obtained on dimension 0. After all data is obtained, a dimension of length 1 is added on axis=0. False would introduce a 0 in the shape, so only True is currently supported.
For example, if a single-level True index value is obtained from a tensor whose shape is (3, 4, 5), the obtained shape is (1, 3, 4, 5).
The multi-level index value can be understood as obtaining the current-level bool index value based on the previous-level index value.
For example:
import mindspore as ms
import mindspore.numpy as np
tensor_x = ms.Tensor(np.arange(2 * 3).reshape((2, 3)))
data_single = tensor_x[True]
data_multi = tensor_x[True][True]
print('data_single:')
print(data_single)
print('data_multi:')
print(data_multi)
The result is as follows:
data_single:
[[[0 1 2]
  [3 4 5]]]
data_multi:
[[[[0 1 2]
   [3 4 5]]]]
None index value
The None index value behaves the same as the True index value. For details, see the True index value.
ellipsis index value
Single-level and multi-level ellipsis index values are supported. The single-level ellipsis index value is tensor_x[...], and the multi-level ellipsis index value is tensor_x[...][...]....
The ellipsis index value is obtained on all dimensions and returns the original data unchanged. It is generally used as a component of a Tuple index, which is described below.
For example, if the ellipsis index value is obtained for a tensor whose shape is (3, 4, 5), the obtained shape is still (3, 4, 5).
For example:
import mindspore as ms
import mindspore.numpy as np
tensor_x = ms.Tensor(np.arange(2 * 3).reshape((2, 3)))
data_single = tensor_x[...]
data_multi = tensor_x[...][...]
print('data_single:')
print(data_single)
print('data_multi:')
print(data_multi)
The result is as follows:
data_single:
[[0 1 2]
 [3 4 5]]
data_multi:
[[0 1 2]
 [3 4 5]]
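The shape rules stated so far for the int, True/None, and ellipsis index values can be summarized in a few lines of plain Python. This is only a sketch of the rules as described in this section, not MindSpore code; the helper names shape_after_index and shape_after_chain are hypothetical and introduced here for illustration.

```python
def shape_after_index(shape, index):
    """Return the result shape after one single-level index (sketch of the rules above)."""
    shape = tuple(shape)
    # Check True/None first: in Python, bool is a subclass of int.
    if index is True or index is None:
        return (1,) + shape              # a new dimension of length 1 on axis 0
    if index is Ellipsis:
        return shape                     # every dimension is kept unchanged
    if isinstance(index, int) and not isinstance(index, bool):
        if index >= shape[0]:
            raise IndexError("int index must be less than the length of dimension 0")
        return shape[1:]                 # dimension 0 is eliminated
    raise TypeError(f"index type not covered by this sketch: {index!r}")


def shape_after_chain(shape, indices):
    """Apply a multi-level index x[i0][i1]... one level at a time."""
    for index in indices:
        shape = shape_after_index(shape, index)
    return tuple(shape)


print(shape_after_chain((3, 4, 5), [0]))          # (4, 5)
print(shape_after_chain((3, 4, 5), [True]))       # (1, 3, 4, 5)
print(shape_after_chain((2, 3), [True, True]))    # (1, 1, 2, 3)
print(shape_after_chain((3, 4, 5), [..., ...]))   # (3, 4, 5)
```

Each level of a multi-level index simply starts from the shape produced by the previous level, which is why tensor_x[True][True] adds two leading dimensions.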
slice index value
Single-level and multi-level slice index values are supported. The single-level slice index value is tensor_x[slice_index], and the multi-level slice index value is tensor_x[slice_index0][slice_index1]....
The slice index value is obtained on dimension 0, and the elements at the sliced positions on dimension 0 are returned. Unlike the int index value, a slice does not reduce the dimension even if its length is 1.
For example, tensor_x[0:1:1] != tensor_x[0], because shape_former = (1,) + shape_latter.
The multi-level index value can be understood as obtaining the current-level slice index value based on the previous-level index value.
A slice consists of start, stop, and step. The default value of start is 0, the default value of stop is the length of the dimension, and the default value of step is 1.
Example: tensor_x[:] == tensor_x[0:length:1].
For example:
import mindspore as ms
import mindspore.numpy as np
tensor_x = ms.Tensor(np.arange(4 * 2 * 2).reshape((4, 2, 2)))
data_single = tensor_x[1:4:2]
data_multi = tensor_x[1:4:2][1:]
print('data_single:')
print(data_single)
print('data_multi:')
print(data_multi)
The result is as follows:
data_single:
[[[ 4  5]
  [ 6  7]]
 [[12 13]
  [14 15]]]
data_multi:
[[[12 13]
  [14 15]]]
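The default values of start, stop, and step can be checked with Python's built-in slice.indices method, which resolves a slice against a given dimension length. The sketch below uses only the standard library; resolve_slice and selected_positions are hypothetical helper names, and the lengths mirror the example above (dimension 0 has length 4, then length 2 after the first slice).

```python
def resolve_slice(slice_index, length):
    """Fill in the defaults of start/stop/step for a dimension of the given length."""
    # slice.indices() is a Python built-in; it returns (start, stop, step).
    return slice_index.indices(length)


def selected_positions(slice_index, length):
    """List the positions on dimension 0 that the slice selects."""
    start, stop, step = resolve_slice(slice_index, length)
    return list(range(start, stop, step))


print(resolve_slice(slice(None), 4))           # (0, 4, 1), i.e. [:] == [0:length:1]
print(selected_positions(slice(1, 4, 2), 4))   # [1, 3]
print(selected_positions(slice(1, None), 2))   # [1]
```

Positions 1 and 3 correspond to the two blocks printed in data_single, and position 1 of that result is the single block printed in data_multi.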
Tensor index value
Single-level and multi-level Tensor index values are supported. The single-level Tensor index value is tensor_x[tensor_index], and the multi-level Tensor index value is tensor_x[tensor_index0][tensor_index1]....
The Tensor index value is obtained on dimension 0, and the elements at the corresponding positions of dimension 0 are returned.
The data type of the Tensor index can be int or bool.
When the data type is int, it must be one of int8, int16, int32, and int64. The elements cannot be negative, and each value must be less than the length of dimension 0.
The shape of the result is data_shape = tensor_index.shape + tensor_x.shape[1:].
For example, if the index value is obtained for a tensor whose shape is (6, 4, 5) by using a tensor whose shape is (2, 3), the obtained shape is (2, 3, 4, 5).
When the data type is bool, the number of dimensions of the result is tensor_x.ndim - tensor_index.ndim + 1.
Let the number of True elements in tensor_index be num_true, the shape of tensor_x be (N0, N1, ..., Ni, Ni+1, ..., Nk), and the shape of tensor_index be (N0, N1, ..., Ni). Then the shape of the returned value is (num_true, Ni+1, Ni+2, ..., Nk).
For example:
from mindspore import dtype as mstype
import mindspore as ms
import mindspore.numpy as np
tensor_x = ms.Tensor([1, 2, 3])
tensor_index = ms.Tensor([True, False, True], dtype=mstype.bool_)
output = tensor_x[tensor_index]
print(output)
The result is as follows:
[1 3]
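The two shape rules for a Tensor index can be written down as small pure-Python functions that work on shape tuples only. This is a sketch of the rules stated above, assuming the bool index covers the leading dimensions of tensor_x; the function names are hypothetical and no MindSpore call is involved.

```python
def int_tensor_index_shape(x_shape, index_shape):
    """Shape rule for an int Tensor index: index.shape + x.shape[1:]."""
    return tuple(index_shape) + tuple(x_shape[1:])


def bool_tensor_index_shape(x_shape, index_shape, num_true):
    """Shape rule for a bool Tensor index that covers the leading dimensions of x."""
    x_shape, index_shape = tuple(x_shape), tuple(index_shape)
    covered = len(index_shape)
    if x_shape[:covered] != index_shape:
        raise ValueError("bool index shape must match the leading dimensions of x")
    # The covered dimensions collapse into one dimension of length num_true.
    return (num_true,) + x_shape[covered:]


print(int_tensor_index_shape((6, 4, 5), (2, 3)))      # (2, 3, 4, 5)
print(bool_tensor_index_shape((3,), (3,), 2))         # (2,)
print(bool_tensor_index_shape((4, 2, 3), (4, 2), 5))  # (5, 3)
```

In the last call the result has 3 - 2 + 1 = 2 dimensions, in line with tensor_x.ndim - tensor_index.ndim + 1.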
The multi-level index value can be understood as obtaining the current-level Tensor index value based on the previous-level index value.
For example:
from mindspore import dtype as mstype
import mindspore as ms
import mindspore.numpy as np
tensor_x = ms.Tensor(np.arange(4 * 2 * 3).reshape((4, 2, 3)))
tensor_index0 = ms.Tensor(np.array([[1, 2], [0, 3]]), mstype.int32)
tensor_index1 = ms.Tensor(np.array([[0, 0]]), mstype.int32)
data_single = tensor_x[tensor_index0]
data_multi = tensor_x[tensor_index0][tensor_index1]
print('data_single:')
print(data_single)
print('data_multi:')
print(data_multi)
The result is as follows:
data_single:
[[[[ 6  7  8]
   [ 9 10 11]]
  [[12 13 14]
   [15 16 17]]]
 [[[ 0  1  2]
   [ 3  4  5]]
  [[18 19 20]
   [21 22 23]]]]
data_multi:
[[[[[ 6  7  8]
    [ 9 10 11]]
   [[12 13 14]
    [15 16 17]]]
  [[[ 6  7  8]
    [ 9 10 11]]
   [[12 13 14]
    [15 16 17]]]]]
List index value
Single-level and multi-level List index values are supported. The single-level List index value is tensor_x[list_index], and the multi-level List index value is tensor_x[list_index0][list_index1]....
The List index value is obtained on dimension 0, and the elements at the corresponding positions of dimension 0 are returned.
The elements of a List index must be all bool, all int, or a mixture of the two. List elements of int type must be in the range [-dimension_shape, dimension_shape-1]. The number of List elements of bool type must equal the dimension_shape of dimension 0, and they act as a filter on the corresponding elements of the Tensor data. If the two types appear together, the bool elements are converted to 1/0 for True/False.
As with the Tensor index value, the result shape follows data_shape = tensor_index.shape + tensor_x.shape[1:].
For example, if the index value is obtained for a tensor whose shape is (6, 4, 5) by using a tensor whose shape is (2, 3), the obtained shape is (2, 3, 4, 5).
The multi-level index value can be understood as obtaining the current-level List index value based on the previous-level index value.
For example:
import mindspore as ms
import mindspore.numpy as np
tensor_x = ms.Tensor(np.arange(4 * 2 * 3).reshape((4, 2, 3)))
list_index0 = [1, 2, 0]
list_index1 = [True, False, True]
data_single = tensor_x[list_index0]
data_multi = tensor_x[list_index0][list_index1]
print('data_single:')
print(data_single)
print('data_multi:')
print(data_multi)
The result is as follows:
data_single:
[[[ 6  7  8]
  [ 9 10 11]]
 [[12 13 14]
  [15 16 17]]
 [[ 0  1  2]
  [ 3  4  5]]]
data_multi:
[[[ 6  7  8]
  [ 9 10 11]]
 [[ 0  1  2]
  [ 3  4  5]]]
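The List rules above (range check for int elements, filtering for an all-bool List, and 1/0 conversion when the two are mixed) can be illustrated by turning a List index into the positions it selects on dimension 0. This is a pure-Python sketch of the description in this section, not MindSpore's implementation; normalize_list_index is a hypothetical helper name.

```python
def normalize_list_index(list_index, dim_size):
    """Return the positions on dimension 0 selected by a List index (sketch)."""
    if any(not isinstance(item, int) for item in list_index):
        raise TypeError("List index elements must be int or bool")
    # An all-bool List acts as a filter and must match the dimension length.
    if all(isinstance(item, bool) for item in list_index):
        if len(list_index) != dim_size:
            raise IndexError("bool List length must equal the length of dimension 0")
        return [pos for pos, keep in enumerate(list_index) if keep]
    positions = []
    for item in list_index:
        value = int(item)                     # in a mixed List, True/False become 1/0
        if not -dim_size <= value <= dim_size - 1:
            raise IndexError("int element out of range")
        positions.append(value + dim_size if value < 0 else value)
    return positions


print(normalize_list_index([1, 2, 0], 4))            # [1, 2, 0]
print(normalize_list_index([True, False, True], 3))  # [0, 2]
print(normalize_list_index([-1, True, 0], 4))        # [3, 1, 0]
```

The first two calls correspond to list_index0 and list_index1 in the example above: the second level filters the three blocks picked by the first level and keeps the first and the last.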
Tuple index value
The elements of a Tuple index can be int, bool, None, slice, ellipsis, Tensor, List, or Tuple. Single-level and multi-level Tuple index values are supported. The single-level Tuple index value is tensor_x[tuple_index], and the multi-level Tuple index value is tensor_x[tuple_index0][tuple_index1].... List and Tuple elements follow the same rules as a single List index. Every other element follows the rules of the corresponding single index type.
Elements in a Tuple index are classified as Basic Index or Advanced Index. slice, ellipsis, int, and None are Basic Index, while bool, Tensor, List, and Tuple are Advanced Index. During getitem, all Advanced Index elements are broadcast to the same shape. The resulting dimensions are inserted at the position of the first Advanced Index element if the Advanced Index elements are contiguous; otherwise they are inserted at position 0.
Within the index, None elements expand the corresponding dimension, and bool elements expand the corresponding dimension and are broadcast with the other Advanced Index elements. All other elements, except ellipsis, bool, and None, correspond to dimensions by position. That is, the 0th element in the Tuple operates on dimension 0, and the 1st element operates on dimension 1. The index rule of each element is the same as the index value rule of the element type.
A Tuple index contains at most one ellipsis. The index elements before the ellipsis correspond to the Tensor dimensions starting from dimension 0, and the index elements after the ellipsis correspond to the Tensor dimensions counting back from the last dimension. Dimensions that are not specified are taken in full.
A Tensor contained in the Tuple can have data type int or bool, and the int type must be one of int8, int16, int32, and int64. In addition, the values of a Tensor element must be non-negative and less than the length of the dimension it operates on.
For example, tensor_x[0:3, 1, tensor_index] == tensor_x[(slice(0, 3), 1, tensor_index)], because 0:3, 1, tensor_index is a Tuple.
The multi-level index value can be understood as obtaining the current-level Tuple index value based on the previous-level index value.
For example:
from mindspore import dtype as mstype
import mindspore as ms
import mindspore.numpy as np
tensor_x = ms.Tensor(np.arange(2 * 3 * 4).reshape((2, 3, 4)))
tensor_index = ms.Tensor(np.array([[1, 2, 1], [0, 3, 2]]), mstype.int32)
data = tensor_x[1, 0:1, tensor_index]
print('data:')
print(data)
The result is as follows:
data:
[[[13]
  [14]
  [13]]
 [[12]
  [15]
  [14]]]
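Two points about Tuple indexing can be checked without MindSpore: Python itself packs comma-separated subscripts into a tuple before the object sees them, and a single ellipsis stands for all dimensions that the other elements leave unspecified. The sketch below uses a hypothetical KeyRecorder class and a hypothetical expand_ellipsis helper; expand_ellipsis assumes the key contains only int and slice elements besides the ellipsis, each consuming exactly one dimension.

```python
class KeyRecorder:
    """Hypothetical helper: returns whatever key Python passes to __getitem__."""

    def __getitem__(self, key):
        return key


def expand_ellipsis(key, ndim):
    """Replace the single ellipsis (or the missing tail) with full slices."""
    if key.count(Ellipsis) > 1:
        raise IndexError("a Tuple index contains at most one ellipsis")
    if Ellipsis not in key:
        return key + (slice(None),) * (ndim - len(key))
    pos = key.index(Ellipsis)
    missing = ndim - (len(key) - 1)          # dimensions the ellipsis stands for
    return key[:pos] + (slice(None),) * missing + key[pos + 1:]


recorder = KeyRecorder()
tensor_index = "tensor_index"                # placeholder standing in for a Tensor
key = recorder[0:3, 1, tensor_index]
print(type(key).__name__)                                   # tuple
print(key == recorder[(slice(0, 3), 1, tensor_index)])      # True
print(expand_ellipsis((0, ..., 1), 4))
```

For a 4-dimensional tensor, the last call pairs the leading 0 with dimension 0 and the trailing 1 with the last dimension, and fills the two middle dimensions with full slices.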
Index Value Assignment
For an assignment such as tensor_x[index] = value, the type of index can be int, bool, ellipsis, slice, None, Tensor, List, or Tuple.
The type of the assigned value can be Number, Tuple, List, or Tensor. The value is converted to a Tensor and cast to the same dtype as the original tensor (tensor_x) before being assigned.
When value is a Number, all elements selected by tensor_x[index] are updated to that Number.
When value is a Tuple, List, or Tensor that contains only Numbers, value.shape must be broadcastable to tensor_x[index].shape. After value is converted to a Tensor and broadcast, the elements at the positions tensor_x[index] are updated with broadcast(Tensor(value)).
When value is a Tuple or List that contains a mixture of Number, Tuple, List, and Tensor, only one-dimensional Tuple and List are currently supported.
When value is a Tuple or List that contains Tensor elements, all non-Tensor elements in value are first converted to Tensor, and then these Tensor values are packed along axis=0 to form a new Tensor. In this case, the value is assigned according to the rule for assigning a Tensor. All Tensors must have the same dtype.
Index value assignment can be understood as assigning values to the indexed elements according to these rules. No index value assignment changes the original shape of the Tensor.
If multiple index elements in indices correspond to the same position, the value of that position in the output is nondeterministic. For more details, see TensorScatterUpdate.
Only single-bracket indexing (tensor_x[index] = value) is supported. Multi-bracket indexing (tensor_x[index1][index2]... = value) is not supported.
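The requirement that value.shape can be broadcast to tensor_x[index].shape can be checked on shape tuples alone. The sketch below assumes NumPy-style broadcasting rules (dimensions are aligned from the right, and a dimension of length 1 can be stretched), which is an assumption about the rule set rather than something stated here; can_broadcast_to is a hypothetical helper name.

```python
def can_broadcast_to(value_shape, target_shape):
    """Check whether value_shape can be broadcast to target_shape (NumPy-style rules assumed)."""
    value_shape, target_shape = tuple(value_shape), tuple(target_shape)
    if len(value_shape) > len(target_shape):
        return False                          # broadcasting never drops dimensions
    # Compare dimensions from the right; missing leading dimensions are implied.
    for value_dim, target_dim in zip(reversed(value_shape), reversed(target_shape)):
        if value_dim != 1 and value_dim != target_dim:
            return False
    return True


print(can_broadcast_to((), (2, 3)))       # True: a Number fills every position
print(can_broadcast_to((3,), (2, 3)))     # True: one row repeated for each row
print(can_broadcast_to((2, 1), (2, 3)))   # True
print(can_broadcast_to((2,), (2, 3)))     # False
```

Under this rule a value of shape (3,) fits both a selected row of shape (3,) and a selection of shape (1, 2, 3), which matches the int and bool assignment examples in the following subsections.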
int index value assignment
Single-level int index value assignments are supported. The single-level int index value assignment is tensor_x[int_index] = u.
For example:
import mindspore.numpy as np
tensor_x = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)
tensor_y = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)
tensor_x[1] = 88.0
tensor_y[1] = np.array([66, 88, 99]).astype(np.float32)
print('tensor_x:')
print(tensor_x)
print('tensor_y:')
print(tensor_y)
The result is as follows:
tensor_x:
[[ 0.  1.  2.]
 [88. 88. 88.]]
tensor_y:
[[ 0.  1.  2.]
 [66. 88. 99.]]
bool index value assignment
Single-level bool index value assignments are supported. The single-level bool index value assignment is tensor_x[bool_index] = u.
For example:
import mindspore.numpy as np
tensor_x = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)
tensor_y = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)
tensor_z = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)
tensor_x[True] = 88.0
tensor_y[True] = np.array([66, 88, 99]).astype(np.float32)
tensor_z[True] = (66, 88, 99)
print('tensor_x:')
print(tensor_x)
print('tensor_y:')
print(tensor_y)
print('tensor_z:')
print(tensor_z)
The result is as follows:
tensor_x:
[[88. 88. 88.]
 [88. 88. 88.]]
tensor_y:
[[66. 88. 99.]
 [66. 88. 99.]]
tensor_z:
[[66. 88. 99.]
 [66. 88. 99.]]
ellipsis index value assignment
Single-level ellipsis index value assignments are supported. The single-level ellipsis index value assignment is tensor_x[...] = u.
For example:
import mindspore.numpy as np
tensor_x = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)
tensor_y = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)
tensor_z = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)
tensor_x[...] = 88.0
tensor_y[...] = np.array([[22, 44, 55], [22, 44, 55]])
tensor_z[...] = ([11, 22, 33], [44, 55, 66])
print('tensor_x:')
print(tensor_x)
print('tensor_y:')
print(tensor_y)
print('tensor_z:')
print(tensor_z)
The result is as follows:
tensor_x:
[[88. 88. 88.]
 [88. 88. 88.]]
tensor_y:
[[22. 44. 55.]
 [22. 44. 55.]]
tensor_z:
[[11. 22. 33.]
 [44. 55. 66.]]
slice index value assignment
Single-level slice index value assignments are supported. The single-level slice index value assignment is tensor_x[slice_index] = u.
For example:
import mindspore.numpy as np
tensor_x = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)
tensor_y = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)
tensor_z = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)
tensor_k = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)
tensor_x[0:1] = 88.0
tensor_y[0:2] = 88.0
tensor_z[0:2] = np.array([[11, 12, 13], [11, 12, 13]]).astype(np.float32)
tensor_k[0:2] = ([11, 12, 13], (14, 15, 16))
print('tensor_x:')
print(tensor_x)
print('tensor_y:')
print(tensor_y)
print('tensor_z:')
print(tensor_z)
print('tensor_k:')
print(tensor_k)
The result is as follows:
tensor_x:
[[88. 88. 88.]
 [ 3.  4.  5.]
 [ 6.  7.  8.]]
tensor_y:
[[88. 88. 88.]
 [88. 88. 88.]
 [ 6.  7.  8.]]
tensor_z:
[[11. 12. 13.]
 [11. 12. 13.]
 [ 6.  7.  8.]]
tensor_k:
[[11. 12. 13.]
 [14. 15. 16.]
 [ 6.  7.  8.]]
None index value assignment
Single-level None index value assignments are supported. The single-level None index value assignment is tensor_x[none_index] = u.
For example:
import mindspore.numpy as np
tensor_x = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)
tensor_y = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)
tensor_z = np.arange(2 * 3).reshape((2, 3)).astype(np.float32)
tensor_x[None] = 88.0
tensor_y[None] = np.array([66, 88, 99]).astype(np.float32)
tensor_z[None] = (66, 88, 99)
print('tensor_x:')
print(tensor_x)
print('tensor_y:')
print(tensor_y)
print('tensor_z:')
print(tensor_z)
The result is as follows:
tensor_x:
[[88. 88. 88.]
 [88. 88. 88.]]
tensor_y:
[[66. 88. 99.]
 [66. 88. 99.]]
tensor_z:
[[66. 88. 99.]
 [66. 88. 99.]]
Tensor index value assignment
Single-level Tensor index value assignments are supported. The single-level Tensor index value assignment is tensor_x[tensor_index] = u.
Currently, the supported index data types are int and bool.
An example of the int type is as follows:
import mindspore.numpy as np
tensor_x = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)
tensor_y = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)
tensor_z = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)
tensor_index = np.array([[2, 0, 2], [0, 2, 0], [0, 2, 0]], np.int32)
tensor_x[tensor_index] = 88.0
tensor_y[tensor_index] = np.array([11.0, 12.0, 13.0]).astype(np.float32)
tensor_z[tensor_index] = [11, 12, 13]
print('tensor_x:')
print(tensor_x)
print('tensor_y:')
print(tensor_y)
print('tensor_z:')
print(tensor_z)
The result is as follows:
tensor_x:
[[88. 88. 88.]
 [ 3.  4.  5.]
 [88. 88. 88.]]
tensor_y:
[[11. 12. 13.]
 [ 3.  4.  5.]
 [11. 12. 13.]]
tensor_z:
[[11. 12. 13.]
 [ 3.  4.  5.]
 [11. 12. 13.]]
An example of the bool type is as follows:
from mindspore import dtype as mstype
import mindspore as ms
tensor_x = ms.Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], mstype.float32)
tensor_index = ms.Tensor([True, False, True], mstype.bool_)
tensor_x[tensor_index] = -1
print(tensor_x)
The result is as follows:
[[-1. -1. -1.]
 [ 3.  4.  5.]
 [-1. -1. -1.]]
List index value assignment
Single-level List index value assignments are supported. The single-level List index value assignment is tensor_x[list_index] = u.
The index rules for List index value assignment are the same as those for the List index value.
For example:
import mindspore.numpy as np
tensor_x = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)
tensor_y = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)
tensor_index = np.array([[0, 1], [1, 0]]).astype(np.int32)
tensor_x[[0, 1]] = 88.0
tensor_y[[True, False, False]] = np.array([11, 12, 13]).astype(np.float32)
print('tensor_x:')
print(tensor_x)
print('tensor_y:')
print(tensor_y)
The result is as follows:
tensor_x:
[[88. 88. 88.]
 [88. 88. 88.]
 [ 6.  7.  8.]]
tensor_y:
[[11. 12. 13.]
 [ 3.  4.  5.]
 [ 6.  7.  8.]]
Tuple index value assignment
Single-level Tuple index value assignments are supported. The single-level Tuple index value assignment is tensor_x[tuple_index] = u.
The index rules for Tuple index value assignment are the same as those for the Tuple index value, except that the None type is not currently supported.
For example:
import mindspore.numpy as np
tensor_x = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)
tensor_y = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)
tensor_z = np.arange(3 * 3).reshape((3, 3)).astype(np.float32)
tensor_index = np.array([0, 1]).astype(np.int32)
tensor_x[1, 1:3] = 88.0
tensor_y[1:3, tensor_index] = 88.0
tensor_z[1:3, tensor_index] = np.array([11, 12]).astype(np.float32)
print('tensor_x:')
print(tensor_x)
print('tensor_y:')
print(tensor_y)
print('tensor_z:')
print(tensor_z)
The result is as follows:
tensor_x:
[[ 0.  1.  2.]
 [ 3. 88. 88.]
 [ 6.  7.  8.]]
tensor_y:
[[ 0.  1.  2.]
 [88. 88.  5.]
 [88. 88.  8.]]
tensor_z:
[[ 0.  1.  2.]
 [11. 12.  5.]
 [11. 12.  8.]]
Index Value Augmented-assignment
Index value augmented-assignment supports seven augmented assignment operations: +=, -=, *=, /=, %=, **=, and //=. The rules and constraints on index and value are the same as for index assignment. The index supports eight types: int, bool, ellipsis, slice, None, Tensor, List, and Tuple. The assigned value supports four types: Number, Tensor, Tuple, and List.
Index value augmented-assignment can be regarded as reading the elements at the indexed positions according to the indexing rules, applying the operator to them together with value, and then assigning the result back to the original Tensor. No index augmented-assignment changes the shape of the original Tensor.
If multiple index elements in indices correspond to the same position, the value of that position in the output is nondeterministic. For more details, see TensorScatterUpdate.
Currently, indices that contain True, False, and None are not supported.
Rules and constraints:
Compared with index assignment, augmented assignment adds the steps of reading the value and applying the operator. The constraints on index are the same as those on index in Index Values, and Int, Bool, Tensor, Slice, Ellipsis, None, List, and Tuple are supported. Any Int value contained in these index types must lie in the closed range [-dim_size, dim_size-1].
The constraints on value during the operation are the same as those on value in index assignment. The type of value must be one of Number, Tensor, List, and Tuple. If the type of value is not Number, value.shape must be broadcastable to tensor_x[index].shape.
For example:
import mindspore as ms
import mindspore.numpy as np
tensor_x = ms.Tensor(np.arange(3 * 4).reshape(3, 4).astype(np.float32))
tensor_y = ms.Tensor(np.arange(3 * 4).reshape(3, 4).astype(np.float32))
tensor_x[[0, 1], 1:3] += 2
tensor_y[[1], ...] -= [4, 3, 2, 1]
print('tensor_x:')
print(tensor_x)
print('tensor_y:')
print(tensor_y)
The result is as follows:
tensor_x:
[[ 0.  3.  4.  3.]
 [ 4.  7.  8.  7.]
 [ 8.  9. 10. 11.]]
tensor_y:
[[ 0.  1.  2.  3.]
 [ 0.  2.  4.  6.]
 [ 8.  9. 10. 11.]]
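The read, operate, and write-back sequence described for augmented assignment is also how plain Python evaluates x[index] += value on any object. The sketch below makes that sequence visible with a hypothetical list-backed LoggedList class; it illustrates the Python-level semantics only and makes no claim about how MindSpore implements the operation internally.

```python
class LoggedList:
    """Hypothetical container that logs every indexed read and write."""

    def __init__(self, data):
        self.data = list(data)
        self.log = []

    def __getitem__(self, index):
        self.log.append(("getitem", index))
        return self.data[index]

    def __setitem__(self, index, value):
        self.log.append(("setitem", index, value))
        self.data[index] = value


x = LoggedList([0.0, 1.0, 2.0])
x[1] += 2          # read x[1], add 2, then write the result back to x[1]
print(x.data)      # [0.0, 3.0, 2.0]
print(x.log)       # [('getitem', 1), ('setitem', 1, 3.0)]
```

Because the write-back targets the same index that was read, the container keeps its original length, which parallels the statement that augmented assignment never changes the shape of the Tensor.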
Tensor View
MindSpore allows a tensor to be a view of an existing tensor. A view tensor shares the same underlying data with its base tensor. Supporting views avoids explicit data copies, which allows fast and memory-efficient reshaping, slicing, and element-wise operations.
For example, to get a view of an existing tensor t, you can call t.view(…).
from mindspore import Tensor
import numpy as np
t = Tensor(np.array([[1, 2, 3], [2, 3, 4]], dtype=np.float32))
b = t.view((3, 2))
# Modifying view tensor changes base tensor as well.
b[0][0] = 100
print(t[0][0])
# 100
Since a view shares underlying data with its base tensor, any edit to the data in the view is reflected in the base tensor as well.
Typically a MindSpore op returns a new tensor as output, e.g. add(). In the case of view ops, however, outputs are views of the input tensors, which avoids unnecessary data copies. No data movement occurs when creating a view; the view tensor only changes the way it interprets the same data. Taking a view of a contiguous tensor can produce a non-contiguous tensor. Users should pay attention to this, as contiguity may have an implicit performance impact. transpose() is a common example.
from mindspore import Tensor
import numpy as np
base = Tensor([[0, 1], [2, 3]])
base.is_contiguous()
# True
t = base.transpose(1, 0) # t is a view of base. No data movement happened here.
t.is_contiguous()
# False
# To get a contiguous tensor, call `.contiguous()` to enforce
# copying data when `t` is not contiguous.
c = t.contiguous()
c.is_contiguous()
# True
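One way to picture why a transposed view is non-contiguous is a strides-based model: a view keeps a reference to the same flat storage and only changes the step taken per dimension. The pure-Python sketch below is a conceptual model under that assumption, not MindSpore's implementation; the StridedView class and its methods are hypothetical, and it ignores corner cases such as dimensions of length 1.

```python
from itertools import product


def contiguous_strides(shape):
    """Row-major strides (in elements) for a freshly allocated tensor."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


class StridedView:
    """Hypothetical model of a tensor: shared flat storage plus shape and strides."""

    def __init__(self, storage, shape, strides=None):
        self.storage = storage
        self.shape = tuple(shape)
        self.strides = tuple(strides) if strides is not None else contiguous_strides(shape)

    def get(self, *index):
        return self.storage[sum(i * s for i, s in zip(index, self.strides))]

    def is_contiguous(self):
        return self.strides == contiguous_strides(self.shape)

    def transpose(self):
        # No data movement: only shape and strides are reversed.
        return StridedView(self.storage, self.shape[::-1], self.strides[::-1])

    def contiguous(self):
        if self.is_contiguous():
            return self
        # Copy the elements into new storage in row-major order.
        data = [self.get(*idx) for idx in product(*(range(n) for n in self.shape))]
        return StridedView(data, self.shape)


base = StridedView([0, 1, 2, 3], (2, 2))
t = base.transpose()
c = t.contiguous()
print(base.is_contiguous(), t.is_contiguous(), c.is_contiguous())   # True False True
print(t.storage is base.storage, c.storage is base.storage)         # True False
print(c.storage)                                                    # [0, 2, 1, 3]
```

In this model the transposed view reads the same four stored values in a different order, and only contiguous() creates new storage, which mirrors the behavior shown in the example above.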
view-class Operators
For reference, here’s a full list of view ops in MindSpore:
Parameter
Parameter is a special class of Tensor, which is a variable whose value can be updated during model training. MindSpore provides the mindspore.Parameter class for Parameter construction. To distinguish between Parameters used for different purposes, two categories of Parameter are defined below:
Trainable parameter. A Tensor that is updated after the gradient is obtained by the backward propagation algorithm during model training. requires_grad needs to be set to True.
Untrainable parameter. A Tensor that does not participate in backward propagation but still needs its value updated (e.g. the mean and var variables in BatchNorm). requires_grad needs to be set to False.
Parameter is set to requires_grad=True by default.
We construct a simple fully-connected layer as follows:
import numpy as np
import mindspore
from mindspore import nn
from mindspore import ops
from mindspore import Tensor, Parameter
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
        self.b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # bias
    def construct(self, x):
        z = ops.matmul(x, self.w) + self.b
        return z
net = Network()
In the __init__ method of the Cell, we define two parameters, w and b, and set name on each for namespace management. In the construct method, they are accessed directly as self.attr and take part in Tensor operations.
Obtaining Parameter
After constructing the neural network layer by using Cell+Parameter, we can use various methods to obtain the Parameter managed by Cell.
Obtaining a Single Parameter
To get a particular parameter individually, access it directly as a member variable of the Python class.
print(net.b.asnumpy())
[-1.2192779 -0.36789745 0.0946381 ]
Obtaining a Trainable Parameter
Trainable parameters can be obtained by using the Cell.trainable_params method, and this interface is usually called when configuring the optimizer.
print(net.trainable_params())
[Parameter (name=w, shape=(5, 3), dtype=Float32, requires_grad=True), Parameter (name=b, shape=(3,), dtype=Float32, requires_grad=True)]
Obtaining All Parameters
Use the Cell.get_parameters() method to get all parameters. It returns a Python iterator.
print(type(net.get_parameters()))
<class 'generator'>
Or you can call Cell.parameters_and_names to return the parameter names and parameters.
for name, param in net.parameters_and_names():
    print(f"{name}:\n{param.asnumpy()}")
w:
[[ 4.15680408e-02 -1.20311625e-01 5.02573885e-02]
[ 1.22175144e-04 -1.34980649e-01 1.17642188e+00]
[ 7.57667869e-02 -1.74758151e-01 -5.19092619e-01]
[-1.67846107e+00 3.27240258e-01 -2.06452996e-01]
[ 5.72323874e-02 -8.27963874e-02 5.94243526e-01]]
b:
[-1.2192779 -0.36789745 0.0946381 ]
Modifying the Parameter
Modifying Parameter Values Directly
Parameter is a special kind of Tensor, so its value can be modified by using the Tensor index modification.
net.b[0] = 1.
print(net.b.asnumpy())
[ 1. -0.36789745 0.0946381 ]
Overriding the Modified Parameter Values
The Parameter.set_data method can be called to override the Parameter by using a Tensor with the same Shape. This method is commonly used for Cell traversal initialization by using Initializer.
net.b.set_data(Tensor([3, 4, 5]))
print(net.b.asnumpy())
[3. 4. 5.]
Modifying Parameter Values During Runtime
In deep learning model training, the core function of parameters is the iterative updating of their values to optimize model performance. Because MindSpore accelerates execution by compiling static graphs, the mindspore.ops.assign interface must be used to assign parameters in this case. This method is commonly used in custom optimizer scenarios. The following is a simple example of modifying parameter values during runtime:
import mindspore as ms
@ms.jit
def modify_parameter():
    b_hat = ms.Tensor([7, 8, 9])
    ops.assign(net.b, b_hat)
    return True
modify_parameter()
print(net.b.asnumpy())
[7. 8. 9.]
Parameter Tuple
ParameterTuple is a tuple of variables used to store multiple Parameters. It inherits from Python's tuple and provides a clone function.
The following example shows how to create and clone a ParameterTuple:
from mindspore.common.initializer import initializer
from mindspore import ParameterTuple
# Create ParameterTuple
x = Parameter(default_input=ms.Tensor(np.arange(2 * 3).reshape((2, 3))), name="x")
y = Parameter(default_input=initializer('ones', [1, 2, 3], ms.float32), name='y')
z = Parameter(default_input=2.0, name='z')
params = ParameterTuple((x, y, z))
# Clone ParameterTuple
params_copy = params.clone("params_copy")
print(params)
print(params_copy)
(Parameter (name=x, shape=(2, 3), dtype=Int64, requires_grad=True), Parameter (name=y, shape=(1, 2, 3), dtype=Float32, requires_grad=True), Parameter (name=z, shape=(), dtype=Float32, requires_grad=True))
(Parameter (name=params_copy.x, shape=(2, 3), dtype=Int64, requires_grad=True), Parameter (name=params_copy.y, shape=(1, 2, 3), dtype=Float32, requires_grad=True), Parameter (name=params_copy.z, shape=(), dtype=Float32, requires_grad=True))