mindspore.mint.triangular_solve

查看源文件
mindspore.mint.triangular_solve(b, A, upper=True, transpose=False, unitriangular=False)[源代码]

求解正上三角形或下三角形可逆矩阵 A 和包含多个元素的右侧边 b 的方程组的解。

用符号表示,它求解方程 \(A X = b\),并假设矩阵 A 是一个方阵,且为上三角矩阵(如果 upper = False,则为下三角矩阵),并且其对角线上没有零元素。

警告

这是一个实验性API,后续可能修改或删除。

参数:
  • b (Tensor) - shape是 \((*, M, K)\) 的tensor,其中*表示任意数量的维度。

  • A (Tensor) - shape是 \((*, M, M)\) 的tensor,其中*表示任意数量的维度。

  • upper (bool,可选) - 矩阵 A 是否为上三角矩阵。如果为 False,则矩阵 A 为下三角矩阵。默认 True

  • transpose (bool,可选) - 求解方程 \(op(A) X = b\)。其中如果此标志为 True,则 \(op(A) = A^T\);如果为 False,则 \(op(A) = A\)。默认 False

  • unitriangular (bool,可选) - 矩阵 A 是否为单位三角矩阵。如果为 True,则假设矩阵 A 的对角线元素为 1,并且不会从矩阵 A 中引用这些对角线元素。默认 False

返回:

包含 XA 的tuple。

异常:
  • ValueError - 如果 b 或者 A 的维度不在 \([2, 6]\) 范围内。

  • ValueError - 如果 bA 的shape不匹配。

支持平台:

Ascend

样例:

>>> import mindspore
>>> b = mindspore.mint.ones((2, 3, 4))
>>> A = mindspore.mint.randn(2, 3, 3).triu()
>>> A
Tensor(shape=[2, 3, 3], dtype=Float32, value=
[[[-2.30607644e-01,  4.21591580e-01, -3.21297437e-01],
  [ 0.00000000e+00,  1.52568102e+00,  1.92393506e+00],
  [ 0.00000000e+00,  0.00000000e+00, -5.91036975e-01]],
 [[-1.01349080e+00,  1.60536671e+00,  6.57448649e-01],
  [ 0.00000000e+00, -5.75404823e-01,  2.84124088e+00],
  [ 0.00000000e+00,  0.00000000e+00, -1.64982307e+00]]])
>>> mindspore.mint.triangular_solve(b, A)[0]
Tensor(shape=[2, 3, 4], dtype=Float32, value=
[[[ 3.11981082e+00,  3.11981082e+00,  3.11981082e+00,  3.11981082e+00],
  [ 2.78904009e+00,  2.78904009e+00,  2.78904009e+00,  2.78904009e+00],
  [-1.69194150e+00, -1.69194150e+00, -1.69194150e+00, -1.69194150e+00]],
 [[-8.87352180e+00, -8.87352180e+00, -8.87352180e+00, -8.87352180e+00],
  [-4.73084164e+00, -4.73084164e+00, -4.73084164e+00, -4.73084164e+00],
  [-6.06125593e-01, -6.06125593e-01, -6.06125593e-01, -6.06125593e-01]]])
>>> A = mindspore.mint.randn(2, 3, 3).tril()
>>> A
Tensor(shape=[2, 3, 3], dtype=Float32, value=
[[[-2.55578011e-01,  0.00000000e+00,  0.00000000e+00],
  [-3.43545020e-01, -6.31254315e-01,  0.00000000e+00],
  [-7.08323777e-01,  1.44433156e-01,  7.35948741e-01]],
 [[ 1.84799409e+00,  0.00000000e+00,  0.00000000e+00],
  [ 9.20043513e-02, -1.35322273e+00,  0.00000000e+00],
  [ 1.86560547e+00, -1.13146865e+00,  6.90285027e-01]]])
>>> mindspore.mint.triangular_solve(b, A)[0]
Tensor(shape=[2, 3, 4], dtype=Float32, value=
[[[-3.91269970e+00, -3.91269970e+00, -3.91269970e+00, -3.91269970e+00],
  [-1.58414757e+00, -1.58414757e+00, -1.58414757e+00, -1.58414757e+00],
  [ 1.35879028e+00,  1.35879028e+00,  1.35879028e+00,  1.35879028e+00]],
 [[ 5.41127264e-01,  5.41127264e-01,  5.41127264e-01,  5.41127264e-01],
  [-7.38976657e-01, -7.38976657e-01, -7.38976657e-01, -7.38976657e-01],
  [ 1.44867694e+00,  1.44867694e+00,  1.44867694e+00,  1.44867694e+00]]])