mindspore.numpy
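The MindSpore NumPy toolkit provides a set of NumPy-like interfaces, so users can build models on MindSpore with NumPy-like syntax.
MindSpore NumPy has four functional modules: array generation, array operations, logic operations, and math operations.
In the API examples, the module is commonly imported as follows: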
import mindspore.numpy as np
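Note
MindSpore NumPy builds its interfaces by composing lower-level operators. The result is a programming experience consistent with NumPy, which makes the interfaces easier to use and makes existing code easier to port. Compared with MindSpore's function and ops interfaces, these interfaces follow the signatures and behavior of the original NumPy more closely, so they are easier to understand and use. Because of these NumPy-compatibility requirements, some interfaces may perform worse than the corresponding function and ops interfaces. Choose the interface type that suits your needs.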
Array Generation
Generation operators create and construct arrays (Tensors) with specified values, types, and shapes.
Example of creating an array:
import mindspore.numpy as np
import mindspore.ops as ops
input_x = np.array([1, 2, 3], np.float32)
print("input_x =", input_x)
print("type of input_x =", ops.typeof(input_x))
Output:
input_x = [1. 2. 3.]
type of input_x = Tensor[Float32]
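Besides the method above, arrays can also be created in the following ways.

- Arrays filled with the same value

  Example of creating an array in which every element has the same value:

    input_x = np.full((2, 3), 6, np.float32)
    print(input_x)

  Output:

    [[6. 6. 6.]
     [6. 6. 6.]]

  Example of creating an all-ones array of a specified shape:

    input_x = np.ones((2, 3), np.float32)
    print(input_x)

  Output:

    [[1. 1. 1.]
     [1. 1. 1.]]

- Arrays of values within a range

  Example of creating an arithmetic sequence within a specified range:

    input_x = np.arange(0, 5, 1)
    print(input_x)

  Output:

    [0 1 2 3 4]

- Special arrays

  Example of creating a matrix whose elements at and below the given diagonal are 1 and whose elements above it are 0:

    input_x = np.tri(3, 3, 1)
    print(input_x)

  Output:

    [[1. 1. 0.]
     [1. 1. 1.]
     [1. 1. 1.]]

  Example of creating a 2-D matrix with ones on the diagonal and zeros elsewhere:

    input_x = np.eye(2, 2)
    print(input_x)

  Output:

    [[1. 0.]
     [0. 1.]]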
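The element rules behind `np.tri` and `np.eye` in the last item can be written out directly. The sketch below is a plain-Python illustration that keeps only the index rule (nested lists of floats instead of Tensors, no `dtype` argument). The local functions `tri` and `eye` are hypothetical stand-ins, not the mindspore.numpy implementation. Element `(i, j)` of `tri(N, M, k)` is 1 when `j <= i + k`, which is why `k=1` also fills the first superdiagonal.

```python
def tri(n, m=None, k=0):
    """Hypothetical sketch of the tri() rule: element (i, j) is 1.0 when
    j <= i + k (at or below the k-th diagonal), otherwise 0.0."""
    if m is None:
        m = n
    return [[1.0 if j <= i + k else 0.0 for j in range(m)] for i in range(n)]


def eye(n, m=None):
    """Hypothetical sketch of eye(): ones only where row index == column index."""
    if m is None:
        m = n
    return [[1.0 if i == j else 0.0 for j in range(m)] for i in range(n)]


print(tri(3, 3, 1))  # [[1.0, 1.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
print(eye(2, 2))     # [[1.0, 0.0], [0.0, 1.0]]
```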
| API Name | Description |
| --- | --- |
| `mindspore.numpy.arange` | Returns evenly spaced values within a given interval, using a given step. |
| `mindspore.numpy.array` | Creates a tensor. |
| `mindspore.numpy.asarray` | Converts the input to a tensor. |
| `mindspore.numpy.asfarray` | Similar to asarray, converts the input to a float tensor. |
| `mindspore.numpy.bartlett` | Returns the Bartlett window. |
| `mindspore.numpy.blackman` | Returns the Blackman window. |
| `mindspore.numpy.copy` | Returns a tensor copy of the given object. |
| `mindspore.numpy.diag` | Extracts a diagonal or constructs a diagonal array. |
| `mindspore.numpy.diag_indices` | Returns the indices to access the main diagonal of an array. |
| `mindspore.numpy.diagflat` | Creates a two-dimensional array with the flattened input as a diagonal. |
| `mindspore.numpy.diagonal` | Returns specified diagonals. |
| `mindspore.numpy.empty` | Returns a new array of given shape and type, without initializing entries. |
| `mindspore.numpy.empty_like` | Returns a new array with the same shape and type as a given array. |
| `mindspore.numpy.eye` | Returns a 2-D tensor with ones on the diagonal and zeros elsewhere. |
| `mindspore.numpy.full` | Returns a new tensor of given shape and type, filled with `fill_value`. |
| `mindspore.numpy.full_like` | Returns a full array with the same shape and type as a given array. |
| `mindspore.numpy.geomspace` | Returns numbers spaced evenly on a log scale (a geometric progression). |
| `mindspore.numpy.hamming` | Returns the Hamming window. |
| `mindspore.numpy.hanning` | Returns the Hanning window. |
| `mindspore.numpy.histogram_bin_edges` | Calculates only the edges of the bins used by the histogram function. |
| `mindspore.numpy.identity` | Returns the identity tensor. |
| `mindspore.numpy.indices` | Returns an array representing the indices of a grid. |
| `mindspore.numpy.ix_` | Constructs an open mesh from multiple sequences. |
| `mindspore.numpy.linspace` | Returns a specified number of evenly spaced values over a given interval. |
| `mindspore.numpy.logspace` | Returns numbers spaced evenly on a log scale. |
| `mindspore.numpy.meshgrid` | Returns coordinate matrices from coordinate vectors. |
| `mindspore.numpy.mgrid` | `mgrid` is an `nd_grid` instance with `sparse=False`. |
| `mindspore.numpy.ogrid` | `ogrid` is an `nd_grid` instance with `sparse=True`. |
| `mindspore.numpy.ones` | Returns a new tensor of given shape and type, filled with ones. |
| `mindspore.numpy.ones_like` | Returns an array of ones with the same shape and type as a given array. |
| `mindspore.numpy.pad` | Pads an array. |
| `mindspore.numpy.rand` | Returns a new Tensor with given shape and dtype, filled with random numbers from the uniform distribution on the interval \([0, 1)\). |
| `mindspore.numpy.randint` | Returns random integers from `minval` (inclusive) to `maxval` (exclusive). |
| `mindspore.numpy.randn` | Returns a new Tensor with given shape and dtype, filled with a sample (or samples) from the standard normal distribution. |
| `mindspore.numpy.trace` | Returns the sum along diagonals of the array. |
| `mindspore.numpy.tri` | Returns a tensor with ones at and below the given diagonal and zeros elsewhere. |
| `mindspore.numpy.tril` | Returns a lower triangle of a tensor. |
| `mindspore.numpy.tril_indices` | Returns the indices for the lower triangle of an (n, m) array. |
| `mindspore.numpy.tril_indices_from` | Returns the indices for the lower triangle of `arr`. |
| `mindspore.numpy.triu` | Returns an upper triangle of a tensor. |
| `mindspore.numpy.triu_indices` | Returns the indices for the upper triangle of an (n, m) array. |
| `mindspore.numpy.triu_indices_from` | Returns the indices for the upper triangle of `arr`. |
| `mindspore.numpy.vander` | Generates a Vandermonde matrix. |
| `mindspore.numpy.zeros` | Returns a new tensor of given shape and type, filled with zeros. |
| `mindspore.numpy.zeros_like` | Returns an array of zeros with the same shape and type as a given array. |
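`arange` and `linspace` both return evenly spaced values, but they are parameterized differently. `arange` takes a step and excludes `stop`. `linspace` takes a sample count and, by default, includes `stop`. The following plain-Python sketch shows the two conventions as NumPy documents them, including the `num=50` and `endpoint=True` defaults for `linspace`. mindspore.numpy is assumed to mirror these conventions. The local functions are hypothetical and return Python lists rather than Tensors.

```python
import math


def arange(start, stop, step=1):
    """Step-based: start, start + step, ... over the half-open interval [start, stop)."""
    # Number of elements is ceil((stop - start) / step), never negative.
    n = max(0, math.ceil((stop - start) / step))
    return [start + i * step for i in range(n)]


def linspace(start, stop, num=50, endpoint=True):
    """Count-based: exactly `num` samples; `stop` is included when endpoint=True."""
    if num == 1:
        return [float(start)]
    div = (num - 1) if endpoint else num
    step = (stop - start) / div
    return [start + i * step for i in range(num)]


print(arange(0, 5, 1))    # [0, 1, 2, 3, 4]
print(linspace(0, 1, 5))  # [0.0, 0.25, 0.5, 0.75, 1.0]
```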
Array Operations
Manipulation operators mainly perform dimension transformations, splitting, and concatenation of arrays.

- Changing array dimensions

  Matrix transpose:

    input_x = np.arange(10).reshape(5, 2)
    output = np.transpose(input_x)
    print(output)

  Output:

    [[0 2 4 6 8]
     [1 3 5 7 9]]

  Swapping specified axes:

    input_x = np.ones((1, 2, 3))
    output = np.swapaxes(input_x, 0, 1)
    print(output.shape)

  Output:

    (2, 1, 3)

- Splitting arrays

  Splitting the input array evenly into multiple arrays:

    input_x = np.arange(9)
    output = np.split(input_x, 3)
    print(output)

  Output:

    (Tensor(shape=[3], dtype=Int32, value= [0, 1, 2]),
     Tensor(shape=[3], dtype=Int32, value= [3, 4, 5]),
     Tensor(shape=[3], dtype=Int32, value= [6, 7, 8]))

- Joining arrays

  Concatenating two arrays along a specified axis:

    input_x = np.arange(0, 5)
    input_y = np.arange(10, 15)
    output = np.concatenate((input_x, input_y), axis=0)
    print(output)

  Output:

    [ 0 1 2 3 4 10 11 12 13 14]
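The split and concatenate examples above can be reproduced on 1-D Python lists. The following sketch keeps NumPy's rule for `split` with an integer number of sections: the length must divide evenly, and every chunk has the same size. The table below lists `array_split` as the variant for uneven pieces. The function names are hypothetical. Raising `ValueError` mirrors NumPy and is an assumption for mindspore.numpy.

```python
def split_equal(values, sections):
    """Hypothetical sketch of split() with an integer `sections` on a 1-D list."""
    if len(values) % sections != 0:
        raise ValueError("array split does not result in an equal division")
    size = len(values) // sections
    return tuple(values[i * size:(i + 1) * size] for i in range(sections))


def concatenate(*arrays):
    """1-D concatenation along axis 0: elements of each input, in order."""
    out = []
    for a in arrays:
        out.extend(a)
    return out


print(split_equal(list(range(9)), 3))  # ([0, 1, 2], [3, 4, 5], [6, 7, 8])
print(concatenate(list(range(0, 5)), list(range(10, 15))))
```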
| API Name | Description |
| --- | --- |
| `mindspore.numpy.append` | Appends values to the end of a tensor. |
| `mindspore.numpy.apply_along_axis` | Applies a function to 1-D slices along the given axis. |
| `mindspore.numpy.apply_over_axes` | Applies a function repeatedly over multiple axes. |
| `mindspore.numpy.argwhere` | Finds the indices of Tensor elements that are non-zero, grouped by element. |
| `mindspore.numpy.array_split` | Splits a tensor into multiple sub-tensors. |
| `mindspore.numpy.array_str` | Returns a string representation of the data in an array. |
| `mindspore.numpy.atleast_1d` | Converts inputs to arrays with at least one dimension. |
| `mindspore.numpy.atleast_2d` | Reshapes inputs as arrays with at least two dimensions. |
| `mindspore.numpy.atleast_3d` | Reshapes inputs as arrays with at least three dimensions. |
| `mindspore.numpy.broadcast_arrays` | Broadcasts any number of arrays against each other. |
| `mindspore.numpy.broadcast_to` | Broadcasts an array to a new shape. |
| `mindspore.numpy.choose` | Constructs an array from an index array and a list of arrays to choose from. |
| `mindspore.numpy.column_stack` | Stacks 1-D tensors as columns into a 2-D tensor. |
| `mindspore.numpy.concatenate` | Joins a sequence of tensors along an existing axis. |
| `mindspore.numpy.dsplit` | Splits a tensor into multiple sub-tensors along the 3rd axis (depth). |
| `mindspore.numpy.dstack` | Stacks tensors in sequence depth-wise (along the third axis). |
| `mindspore.numpy.expand_dims` | Expands the shape of a tensor. |
| `mindspore.numpy.flip` | Reverses the order of elements in an array along the given axis. |
| `mindspore.numpy.fliplr` | Flips the entries in each row in the left/right direction. |
| `mindspore.numpy.flipud` | Flips the entries in each column in the up/down direction. |
| `mindspore.numpy.hsplit` | Splits a tensor into multiple sub-tensors horizontally (column-wise). |
| `mindspore.numpy.hstack` | Stacks tensors in sequence horizontally. |
| `mindspore.numpy.intersect1d` | Finds the intersection of two Tensors. |
| `mindspore.numpy.moveaxis` | Moves axes of an array to new positions. |
| `mindspore.numpy.piecewise` | Evaluates a piecewise-defined function. |
| `mindspore.numpy.ravel` | Returns a contiguous flattened tensor. |
| `mindspore.numpy.repeat` | Repeats elements of an array. |
| `mindspore.numpy.reshape` | Reshapes a tensor without changing its data. |
| `mindspore.numpy.roll` | Rolls a tensor along given axes. |
| `mindspore.numpy.rollaxis` | Rolls the specified axis backwards until it lies in the given position. |
| `mindspore.numpy.rot90` | Rotates a tensor by 90 degrees in the plane specified by axes. |
| `mindspore.numpy.select` | Returns an array drawn from elements in `choicelist`, depending on conditions. |
| `mindspore.numpy.setdiff1d` | Finds the set difference of two Tensors. |
| `mindspore.numpy.size` | Returns the number of elements along a given axis. |
| `mindspore.numpy.split` | Splits a tensor into multiple sub-tensors along the given axis. |
| `mindspore.numpy.squeeze` | Removes single-dimensional entries from the shape of a tensor. |
| `mindspore.numpy.stack` | Joins a sequence of arrays along a new axis. |
| `mindspore.numpy.swapaxes` | Interchanges two axes of a tensor. |
| `mindspore.numpy.take` | Takes elements from an array along an axis. |
| `mindspore.numpy.take_along_axis` | Takes values from the input array by matching 1-D index and data slices. |
| `mindspore.numpy.tile` | Constructs an array by repeating `a` the number of times given by `reps`. |
| `mindspore.numpy.transpose` | Reverses or permutes the axes of a tensor; returns the modified tensor. |
| `mindspore.numpy.unique` | Finds the unique elements of a tensor. |
| `mindspore.numpy.unravel_index` | Converts a flat index or array of flat indices into a tuple of coordinate arrays. |
| `mindspore.numpy.vsplit` | Splits a tensor into multiple sub-tensors vertically (row-wise). |
| `mindspore.numpy.vstack` | Stacks tensors in sequence vertically. |
| `mindspore.numpy.where` | Returns elements chosen from `x` or `y` depending on `condition`. |
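`unravel_index` turns flat indices back into coordinates. In row-major (C) order this is repeated division by the axis lengths, peeling coordinates off the last axis first. The sketch below handles a single scalar index in C order only. The mindspore.numpy function also accepts arrays of flat indices, so this hypothetical helper illustrates the arithmetic and is not a drop-in replacement.

```python
def unravel_index(flat_index, shape):
    """Hypothetical C-order sketch: convert one flat index into a coordinate tuple."""
    size = 1
    for dim in shape:
        size *= dim
    if not 0 <= flat_index < size:
        raise IndexError(f"index {flat_index} is out of bounds for size {size}")
    coords = []
    # Walk the axes from last to first: remainder is the coordinate, quotient carries over.
    for dim in reversed(shape):
        coords.append(flat_index % dim)
        flat_index //= dim
    return tuple(reversed(coords))


print(unravel_index(22, (7, 6)))  # (3, 4)
print(unravel_index(41, (7, 6)))  # (6, 5)
```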
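Logic Operations
Logic operators perform logic-related operations such as element-wise comparisons and Boolean tests.
Example of the equal and less operations: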
input_x = np.arange(0, 5)
input_y = np.arange(0, 10, 2)
output = np.equal(input_x, input_y)
print("output of equal:", output)
output = np.less(input_x, input_y)
print("output of less:", output)
Output:
output of equal: [ True False False False False]
output of less: [False  True  True  True  True]
| API Name | Description |
| --- | --- |
| `mindspore.numpy.array_equal` | Returns True if input arrays have the same shapes and all elements are equal. |
| `mindspore.numpy.array_equiv` | Returns True if input arrays are shape-consistent and all elements are equal. |
| `mindspore.numpy.equal` | Returns the truth value of `(x1 == x2)` element-wise. |
| `mindspore.numpy.greater` | Returns the truth value of `(x1 > x2)` element-wise. |
| `mindspore.numpy.greater_equal` | Returns the truth value of `(x1 >= x2)` element-wise. |
| `mindspore.numpy.in1d` | Tests whether each element of a 1-D array is also present in a second array. |
| `mindspore.numpy.isclose` | Returns a boolean tensor where two tensors are element-wise equal within a tolerance. |
| `mindspore.numpy.isfinite` | Tests element-wise for finiteness (neither infinity nor NaN). |
| `mindspore.numpy.isin` | Calculates `element in test_elements`, broadcasting over `element` only. |
| `mindspore.numpy.isinf` | Tests element-wise for positive or negative infinity. |
| `mindspore.numpy.isnan` | Tests element-wise for NaN and returns the result as a boolean array. |
| `mindspore.numpy.isneginf` | Tests element-wise for negative infinity and returns the result as a bool array. |
| `mindspore.numpy.isposinf` | Tests element-wise for positive infinity and returns the result as a bool array. |
| `mindspore.numpy.isscalar` | Returns True if the type of element is a scalar type. |
| `mindspore.numpy.less` | Returns the truth value of `(x1 < x2)` element-wise. |
| `mindspore.numpy.less_equal` | Returns the truth value of `(x1 <= x2)` element-wise. |
| `mindspore.numpy.logical_and` | Computes the truth value of x1 AND x2 element-wise. |
| `mindspore.numpy.logical_not` | Computes the truth value of NOT a element-wise. |
| `mindspore.numpy.logical_or` | Computes the truth value of x1 OR x2 element-wise. |
| `mindspore.numpy.logical_xor` | Computes the truth value of x1 XOR x2 element-wise. |
| `mindspore.numpy.not_equal` | Returns `(x1 != x2)` element-wise. |
| `mindspore.numpy.signbit` | Returns element-wise True where the sign bit is set (less than zero). |
| `mindspore.numpy.sometrue` | Tests whether any array element along a given axis evaluates to True. |
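`isclose` compares element-wise "within a tolerance". NumPy documents that tolerance as `|a - b| <= atol + rtol * |b|`, with defaults `rtol=1e-05` and `atol=1e-08`. The test is asymmetric because only `|b|` is scaled. The plain-Python sketch below assumes that mindspore.numpy.isclose follows the same rule and defaults, so check its API reference before relying on that. The function works on flat lists and is hypothetical.

```python
import math


def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
    """Hypothetical sketch of NumPy's documented rule: |a - b| <= atol + rtol * |b|."""
    result = []
    for x, y in zip(a, b):
        if math.isnan(x) or math.isnan(y):
            # NaN never compares close unless equal_nan is set and both are NaN.
            result.append(equal_nan and math.isnan(x) and math.isnan(y))
        elif math.isinf(x) or math.isinf(y):
            # Infinities are close only to an infinity of the same sign.
            result.append(x == y)
        else:
            result.append(abs(x - y) <= atol + rtol * abs(y))
    return result


print(isclose([1e10, 1e-7], [1.00001e10, 1e-8]))  # [True, False]
```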
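Math Operations
Math operators cover a wide range of mathematical operations: addition, subtraction, multiplication, division, and exponentiation, as well as common functions such as exponentials and logarithms.
Math operations support NumPy-style broadcasting.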
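Under NumPy's broadcasting rule, shapes are aligned on the right. Each pair of dimensions must be equal or contain a 1, and the result takes the larger size in each position. The plain-Python sketch below computes only the result shape, assuming mindspore.numpy follows the NumPy rule as the sentence above states. `broadcast_shape` is a hypothetical helper; `mindspore.numpy.broadcast_to` and `broadcast_arrays` perform the actual broadcasting.

```python
def broadcast_shape(*shapes):
    """Hypothetical sketch of NumPy's broadcasting rule for result shapes."""
    ndim = max(len(s) for s in shapes)
    # Left-pad every shape with 1s so all shapes have the same rank.
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    result = []
    for dims in zip(*padded):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {shapes} cannot be broadcast together")
        result.append(sizes.pop() if sizes else 1)
    return tuple(result)


print(broadcast_shape((3, 2), (2,)))       # (3, 2)
print(broadcast_shape((4, 1, 3), (5, 1)))  # (4, 5, 3)
```

By this rule, shapes `(3, 2)` and `(2,)` are compatible. The addition example that follows should therefore give the same result if `input_y` were the 1-D array `np.array([3, 4])` instead of a full `(3, 2)` array.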
- Addition

  The following code adds the two arrays input_x and input_y:

    input_x = np.full((3, 2), [1, 2])
    input_y = np.full((3, 2), [3, 4])
    output = np.add(input_x, input_y)
    print(output)

  Output:

    [[4 6]
     [4 6]
     [4 6]]

- Matrix multiplication

  The following code multiplies the two matrices input_x and input_y:

    input_x = np.arange(2*3).reshape(2, 3).astype('float32')
    input_y = np.arange(3*4).reshape(3, 4).astype('float32')
    output = np.matmul(input_x, input_y)
    print(output)

  Output:

    [[20. 23. 26. 29.]
     [56. 68. 80. 92.]]

- Mean

  The following code computes the mean of all elements of input_x:

    input_x = np.arange(6).astype('float32')
    output = np.mean(input_x)
    print(output)

  Output:

    2.5

- Exponential

  The following code raises the natural constant e to the power of each element of input_x:

    input_x = np.arange(5).astype('float32')
    output = np.exp(input_x)
    print(output)

  Output:

    [ 1. 2.7182817 7.389056 20.085537 54.59815 ]
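The matrix-multiplication result above can be checked by hand. Each output entry is the dot product of a row of `input_x` with a column of `input_y`. For example, the top-left value is 0·0 + 1·4 + 2·8 = 20. The sketch below recomputes the whole product in plain Python, using integer lists instead of float32 Tensors. `matmul` here is a hypothetical helper, not `mindspore.numpy.matmul`.

```python
def matmul(a, b):
    """Plain-Python 2-D matrix product: out[i][j] = sum_k a[i][k] * b[k][j]."""
    inner = len(b)
    if any(len(row) != inner for row in a):
        raise ValueError("inner dimensions do not match")
    cols = len(b[0])
    return [[sum(a[i][k] * b[k][j] for k in range(inner)) for j in range(cols)]
            for i in range(len(a))]


# Same values as arange(6).reshape(2, 3) and arange(12).reshape(3, 4) above.
x = [[0, 1, 2], [3, 4, 5]]
y = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
print(matmul(x, y))  # [[20, 23, 26, 29], [56, 68, 80, 92]]
```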
| API Name | Description |
| --- | --- |
| `mindspore.numpy.absolute` | Calculates the absolute value element-wise. |
| `mindspore.numpy.add` | Adds arguments element-wise. |
| `mindspore.numpy.amax` | Returns the maximum of an array or maximum along an axis. |
| `mindspore.numpy.amin` | Returns the minimum of an array or minimum along an axis. |
| `mindspore.numpy.arccos` | Trigonometric inverse cosine, element-wise. |
| `mindspore.numpy.arccosh` | Inverse hyperbolic cosine, element-wise. |
| `mindspore.numpy.arcsin` | Inverse sine, element-wise. |
| `mindspore.numpy.arcsinh` | Inverse hyperbolic sine, element-wise. |
| `mindspore.numpy.arctan` | Trigonometric inverse tangent, element-wise. |
| `mindspore.numpy.arctan2` | Element-wise arc tangent of \(x1/x2\), choosing the quadrant correctly. |
| `mindspore.numpy.arctanh` | Inverse hyperbolic tangent, element-wise. |
| `mindspore.numpy.argmax` | Returns the indices of the maximum values along an axis. |
| `mindspore.numpy.argmin` | Returns the indices of the minimum values along an axis. |
| `mindspore.numpy.around` | Evenly rounds to the given number of decimals. |
| `mindspore.numpy.average` | Computes the weighted average along the specified axis. |
| `mindspore.numpy.bincount` | Counts the number of occurrences of each value in an array of non-negative ints. |
| `mindspore.numpy.bitwise_and` | Computes the bit-wise AND of two arrays element-wise. |
| `mindspore.numpy.bitwise_or` | Computes the bit-wise OR of two arrays element-wise. |
| `mindspore.numpy.bitwise_xor` | Computes the bit-wise XOR of two arrays element-wise. |
| `mindspore.numpy.cbrt` | Returns the cube root of a tensor, element-wise. |
| `mindspore.numpy.ceil` | Returns the ceiling of the input, element-wise. |
| `mindspore.numpy.clip` | Clips (limits) the values in an array. |
| `mindspore.numpy.convolve` | Returns the discrete, linear convolution of two one-dimensional sequences. |
| `mindspore.numpy.copysign` | Changes the sign of x1 to that of x2, element-wise. |
| `mindspore.numpy.corrcoef` | Returns Pearson product-moment correlation coefficients. |
| `mindspore.numpy.correlate` | Cross-correlation of two 1-dimensional sequences. |
| `mindspore.numpy.cos` | Cosine, element-wise. |
| `mindspore.numpy.cosh` | Hyperbolic cosine, element-wise. |
| `mindspore.numpy.count_nonzero` | Counts the number of non-zero values in the tensor `x`. |
| `mindspore.numpy.cov` | Estimates a covariance matrix, given data and weights. |
| `mindspore.numpy.cross` | Returns the cross product of two (arrays of) vectors. |
| `mindspore.numpy.cumprod` | Returns the cumulative product of elements along a given axis. |
| `mindspore.numpy.cumsum` | Returns the cumulative sum of the elements along a given axis. |
| `mindspore.numpy.deg2rad` | Converts angles from degrees to radians. |
| `mindspore.numpy.diff` | Calculates the n-th discrete difference along the given axis. |
| `mindspore.numpy.digitize` | Returns the indices of the bins to which each value in the input array belongs. |
| `mindspore.numpy.divide` | Returns a true division of the inputs, element-wise. |
| `mindspore.numpy.divmod` | Returns element-wise quotient and remainder simultaneously. |
| `mindspore.numpy.dot` | Returns the dot product of two arrays. |
| `mindspore.numpy.ediff1d` | Returns the differences between consecutive elements of a tensor. |
| `mindspore.numpy.exp` | Calculates the exponential of all elements in the input array. |
| `mindspore.numpy.exp2` | Calculates `2**p` for all `p` in the input array. |
| `mindspore.numpy.expm1` | Calculates `exp(x) - 1` for all elements in the array. |
| `mindspore.numpy.fix` | Rounds to the nearest integer towards zero. |
| `mindspore.numpy.float_power` | First array elements raised to powers from the second array, element-wise. |
| `mindspore.numpy.floor` | Returns the floor of the input, element-wise. |
| `mindspore.numpy.floor_divide` | Returns the largest integer smaller than or equal to the division of the inputs. |
| `mindspore.numpy.fmod` | Returns the element-wise remainder of division. |
| `mindspore.numpy.gcd` | Returns the greatest common divisor of `|x1|` and `|x2|`. |
| `mindspore.numpy.gradient` | Returns the gradient of an N-dimensional array. |
| `mindspore.numpy.heaviside` | Computes the Heaviside step function. |
| `mindspore.numpy.histogram` | Computes the histogram of a dataset. |
| `mindspore.numpy.histogram2d` | Computes the two-dimensional histogram of two data samples. |
| `mindspore.numpy.histogramdd` | Computes the multidimensional histogram of some data. |
| `mindspore.numpy.hypot` | Given the "legs" of a right triangle, returns its hypotenuse. |
| `mindspore.numpy.inner` | Returns the inner product of two tensors. |
| `mindspore.numpy.interp` | One-dimensional linear interpolation for monotonically increasing sample points. |
| `mindspore.numpy.invert` | Computes bit-wise inversion, or bit-wise NOT, element-wise. |
| `mindspore.numpy.kron` | Kronecker product of two arrays. |
| `mindspore.numpy.lcm` | Returns the lowest common multiple of `|x1|` and `|x2|`. |
| `mindspore.numpy.log` | Returns the natural logarithm, element-wise. |
| `mindspore.numpy.log10` | Base-10 logarithm of x. |
| `mindspore.numpy.log1p` | Returns the natural logarithm of one plus the input array, element-wise. |
| `mindspore.numpy.log2` | Base-2 logarithm of x. |
| `mindspore.numpy.logaddexp` | Logarithm of the sum of exponentiations of the inputs. |
| `mindspore.numpy.logaddexp2` | Logarithm of the sum of exponentiations of the inputs in base 2. |
| `mindspore.numpy.matmul` | Returns the matrix product of two arrays. |
| `mindspore.numpy.matrix_power` | Raises a square matrix to the (integer) power n. |
| `mindspore.numpy.maximum` | Returns the element-wise maximum of array elements. |
| `mindspore.numpy.mean` | Computes the arithmetic mean along the specified axis. |
| `mindspore.numpy.minimum` | Element-wise minimum of tensor elements. |
| `mindspore.numpy.multi_dot` | Computes the dot product of two or more arrays in a single function call, while automatically selecting the fastest evaluation order. |
| `mindspore.numpy.multiply` | Multiplies arguments element-wise. |
| `mindspore.numpy.nancumsum` | Returns the cumulative sum of array elements over a given axis, treating NaNs as zero. |
| `mindspore.numpy.nanmax` | Returns the maximum of an array or maximum along an axis, ignoring any NaNs. |
| `mindspore.numpy.nanmean` | Computes the arithmetic mean along the specified axis, ignoring NaNs. |
| `mindspore.numpy.nanmin` | Returns the minimum of array elements over a given axis, ignoring any NaNs. |
| `mindspore.numpy.nanstd` | Computes the standard deviation along the specified axis, ignoring NaNs. |
| `mindspore.numpy.nansum` | Returns the sum of array elements over a given axis, treating NaNs as zero. |
| `mindspore.numpy.nanvar` | Computes the variance along the specified axis, ignoring NaNs. |
| `mindspore.numpy.negative` | Numerical negative, element-wise. |
| `mindspore.numpy.norm` | Matrix or vector norm. |
| `mindspore.numpy.outer` | Computes the outer product of two vectors. |
| `mindspore.numpy.polyadd` | Finds the sum of two polynomials. |
| `mindspore.numpy.polyder` | Returns the derivative of the specified order of a polynomial. |
| `mindspore.numpy.polyint` | Returns an antiderivative (indefinite integral) of a polynomial. |
| `mindspore.numpy.polymul` | Finds the product of two polynomials. |
| `mindspore.numpy.polysub` | Difference (subtraction) of two polynomials. |
| `mindspore.numpy.polyval` | Evaluates a polynomial at specific values. |
| `mindspore.numpy.positive` | Numerical positive, element-wise. |
| `mindspore.numpy.power` | First array elements raised to powers from the second array, element-wise. |
| `mindspore.numpy.promote_types` | Returns the data type with the smallest size and smallest scalar kind. |
| `mindspore.numpy.ptp` | Range of values (maximum - minimum) along an axis. |
| `mindspore.numpy.rad2deg` | Converts angles from radians to degrees. |
| `mindspore.numpy.radians` | Converts angles from degrees to radians. |
| `mindspore.numpy.ravel_multi_index` | Converts a tuple of index arrays into an array of flat indices, applying boundary modes to the multi-index. |
| `mindspore.numpy.reciprocal` | Returns the reciprocal of the argument, element-wise. |
| `mindspore.numpy.remainder` | Returns the element-wise remainder of division. |
| `mindspore.numpy.result_type` | Returns the type that results from applying the type promotion rules to the arguments. |
| `mindspore.numpy.rint` | Rounds elements of the array to the nearest integer. |
| `mindspore.numpy.searchsorted` | Finds indices where elements should be inserted to maintain order. |
| `mindspore.numpy.sign` | Returns an element-wise indication of the sign of a number. |
| `mindspore.numpy.sin` | Trigonometric sine, element-wise. |
| `mindspore.numpy.sinh` | Hyperbolic sine, element-wise. |
| `mindspore.numpy.sqrt` | Returns the non-negative square root of an array, element-wise. |
| `mindspore.numpy.square` | Returns the element-wise square of the input. |
| `mindspore.numpy.std` | Computes the standard deviation along the specified axis. |
| `mindspore.numpy.subtract` | Subtracts arguments, element-wise. |
| `mindspore.numpy.sum` | Returns the sum of array elements over a given axis. |
| `mindspore.numpy.tan` | Computes tangent element-wise. |
| `mindspore.numpy.tanh` | Computes hyperbolic tangent element-wise. |
| `mindspore.numpy.tensordot` | Computes the tensor dot product along specified axes. |
| `mindspore.numpy.trapz` | Integrates along the given axis using the composite trapezoidal rule. |
| `mindspore.numpy.true_divide` | Returns a true division of the inputs, element-wise. |
| `mindspore.numpy.trunc` | Returns the truncated value of the input, element-wise. |
| `mindspore.numpy.unwrap` | Unwraps by changing deltas between values to their `2*pi` complement. |
| `mindspore.numpy.var` | Computes the variance along the specified axis. |
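`trapz` uses the composite trapezoidal rule: each interval contributes its width times the average of its two end samples. The plain-Python sketch below handles 1-D samples only, with either explicit sample points `x` or a uniform spacing `dx`. The default `dx=1.0` follows NumPy and is assumed, not confirmed, for mindspore.numpy. The mindspore.numpy function additionally takes an `axis` for N-D input. The helper is hypothetical.

```python
def trapz(y, x=None, dx=1.0):
    """Hypothetical 1-D composite trapezoidal rule:
    sum over intervals of width * (y[i] + y[i + 1]) / 2."""
    if x is None:
        return sum(dx * (y[i] + y[i + 1]) / 2 for i in range(len(y) - 1))
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    return sum((x[i + 1] - x[i]) * (y[i] + y[i + 1]) / 2 for i in range(len(y) - 1))


print(trapz([1, 2, 3]))               # 4.0
print(trapz([1, 2, 3], x=[4, 6, 8]))  # 8.0
```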
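Combining MindSpore NumPy with MindSpore Features
mindspore.numpy can take full advantage of MindSpore's capabilities. Its operators support automatic differentiation, and computations can be accelerated in graph mode, which helps users quickly build efficient models. MindSpore also supports multiple backend devices, including Ascend, GPU, and CPU, which users can select as needed. Several commonly used tools are listed below:
- jit decorator: wraps code into graph mode to improve execution efficiency.
- GradOperation: performs automatic differentiation.
- mindspore.set_context: sets the execution mode, the backend device, and other options.
- mindspore.nn.Cell: builds deep learning models.
Usage examples: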
- jit decorator example

  First, consider the matrix multiplication and matrix addition operators commonly used in neural networks:

    import mindspore.numpy as np

    x = np.arange(8).reshape(2, 4).astype('float32')
    w1 = np.ones((4, 8))
    b1 = np.zeros((8,))
    w2 = np.ones((8, 16))
    b2 = np.zeros((16,))
    w3 = np.ones((16, 4))
    b3 = np.zeros((4,))

    def forward(x, w1, b1, w2, b2, w3, b3):
        x = np.dot(x, w1) + b1
        x = np.dot(x, w2) + b2
        x = np.dot(x, w3) + b3
        return x

    print(forward(x, w1, b1, w2, b2, w3, b3))

  Output:

    [[ 768. 768. 768. 768.]
     [2816. 2816. 2816. 2816.]]

  For the example above, the jit decorator can compile all the operators into a single static graph to speed up execution:

    from mindspore import jit

    # Compile forward into a static graph, then call the compiled version.
    forward_compiled = jit(forward)
    print(forward_compiled(x, w1, b1, w2, b2, w3, b3))

  Output:

    [[ 768. 768. 768. 768.]
     [2816. 2816. 2816. 2816.]]

  Note: static graphs currently cannot run in Python interactive mode, and some syntax restrictions apply.
- GradOperation example

  GradOperation performs automatic differentiation. The following example differentiates the computation defined by the forward function above, the version without jit:

    from mindspore import ops

    grad_all = ops.GradOperation(get_all=True)
    print(grad_all(forward)(x, w1, b1, w2, b2, w3, b3))

  Output:

    (Tensor(shape=[2, 4], dtype=Float32, value=
    [[ 5.12000000e+02, 5.12000000e+02, 5.12000000e+02, 5.12000000e+02],
     [ 5.12000000e+02, 5.12000000e+02, 5.12000000e+02, 5.12000000e+02]]),
    Tensor(shape=[4, 8], dtype=Float32, value=
    [[ 2.56000000e+02, 2.56000000e+02, 2.56000000e+02 ... 2.56000000e+02, 2.56000000e+02, 2.56000000e+02],
     [ 3.84000000e+02, 3.84000000e+02, 3.84000000e+02 ... 3.84000000e+02, 3.84000000e+02, 3.84000000e+02],
     [ 5.12000000e+02, 5.12000000e+02, 5.12000000e+02 ... 5.12000000e+02, 5.12000000e+02, 5.12000000e+02],
     [ 6.40000000e+02, 6.40000000e+02, 6.40000000e+02 ... 6.40000000e+02, 6.40000000e+02, 6.40000000e+02]]),
    ...
    Tensor(shape=[4], dtype=Float32, value= [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]))

  To differentiate the jit-compiled forward, first set the execution mode to graph mode with set_context:

    from mindspore import jit, set_context, GRAPH_MODE, ops

    set_context(mode=GRAPH_MODE)
    grad_all = ops.GradOperation(get_all=True)
    print(grad_all(jit(forward))(x, w1, b1, w2, b2, w3, b3))

  Output:

    (Tensor(shape=[2, 4], dtype=Float32, value=
    [[ 5.12000000e+02, 5.12000000e+02, 5.12000000e+02, 5.12000000e+02],
     [ 5.12000000e+02, 5.12000000e+02, 5.12000000e+02, 5.12000000e+02]]),
    Tensor(shape=[4, 8], dtype=Float32, value=
    [[ 2.56000000e+02, 2.56000000e+02, 2.56000000e+02 ... 2.56000000e+02, 2.56000000e+02, 2.56000000e+02],
     [ 3.84000000e+02, 3.84000000e+02, 3.84000000e+02 ... 3.84000000e+02, 3.84000000e+02, 3.84000000e+02],
     [ 5.12000000e+02, 5.12000000e+02, 5.12000000e+02 ... 5.12000000e+02, 5.12000000e+02, 5.12000000e+02],
     [ 6.40000000e+02, 6.40000000e+02, 6.40000000e+02 ... 6.40000000e+02, 6.40000000e+02, 6.40000000e+02]]),
    ...
    Tensor(shape=[4], dtype=Float32, value= [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]))

  For more details, see the API GradOperation.
- mindspore.set_context example

  MindSpore supports multiple backends, which can be configured through mindspore.set_context. Most mindspore.numpy operators can run in either graph mode or PyNative mode, and on backend devices such as CPU, GPU, or Ascend.

    from mindspore import set_context, GRAPH_MODE, PYNATIVE_MODE

    # Execution in static graph mode
    set_context(mode=GRAPH_MODE)

    # Execution in PyNative mode
    set_context(mode=PYNATIVE_MODE)

    # Execution on CPU backend
    set_context(device_target="CPU")

    # Execution on GPU backend
    set_context(device_target="GPU")

    # Execution on Ascend backend
    set_context(device_target="Ascend")
    ...

  For more details, see the API mindspore.set_context.
- mindspore.numpy example

  This example builds a network model with mindspore.numpy. mindspore.numpy interfaces can be used inside an nn.Cell to build the network:

    import mindspore.numpy as np
    from mindspore import set_context, GRAPH_MODE
    from mindspore.nn import Cell

    set_context(mode=GRAPH_MODE)

    x = np.arange(8).reshape(2, 4).astype('float32')
    w1 = np.ones((4, 8))
    b1 = np.zeros((8,))
    w2 = np.ones((8, 16))
    b2 = np.zeros((16,))
    w3 = np.ones((16, 4))
    b3 = np.zeros((4,))

    class NeuralNetwork(Cell):
        def construct(self, x, w1, b1, w2, b2, w3, b3):
            x = np.dot(x, w1) + b1
            x = np.dot(x, w2) + b2
            x = np.dot(x, w3) + b3
            return x

    net = NeuralNetwork()
    print(net(x, w1, b1, w2, b2, w3, b3))

  Output:

    [[ 768. 768. 768. 768.]
     [2816. 2816. 2816. 2816.]]
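The values printed in the jit, GradOperation, and Cell examples above can be checked by hand. Every weight is all ones and every bias is zero, so each layer replaces every entry with the sum of its row. The sketch below recomputes these values in plain Python, using lists instead of Tensors; the helper functions are hypothetical. For the gradients, it assumes an all-ones output gradient, which is what GradOperation attaches when no sensitivity is passed. Under that assumption it reproduces the printed values: 512 for every entry of the gradient of `x`, rows of 256, 384, 512, and 640 for `w1`, and 2 for `b3`.

```python
def matmul(a, b):
    """Plain-Python matrix product: out[i][j] = sum_k a[i][k] * b[k][j]."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    return [list(col) for col in zip(*m)]


def ones(rows, cols):
    return [[1.0] * cols for _ in range(rows)]


x = [[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]]  # arange(8).reshape(2, 4)
w1, w2, w3 = ones(4, 8), ones(8, 16), ones(16, 4)
# All biases are zeros in the example, so they are left out of this re-computation.

# Forward pass: each all-ones layer replaces every entry with its row sum.
h1 = matmul(x, w1)    # rows of 6.0 and 22.0
h2 = matmul(h1, w2)   # rows of 48.0 and 176.0
out = matmul(h2, w3)  # rows of 768.0 and 2816.0

# Backward pass, propagating an all-ones gradient from the output.
g_out = ones(2, 4)
g_h2 = matmul(g_out, transpose(w3))       # every entry 4.0
g_h1 = matmul(g_h2, transpose(w2))        # every entry 64.0
g_x = matmul(g_h1, transpose(w1))         # every entry 512.0
g_w1 = matmul(transpose(x), g_h1)         # rows of 256.0, 384.0, 512.0, 640.0
g_b3 = [sum(col) for col in zip(*g_out)]  # 2.0 each: summed over a batch of 2

print(out)
print([row[0] for row in g_w1], g_b3)
```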