mindspore.mint.triangular_solve

mindspore.mint.triangular_solve(b, A, upper=True, transpose=False, unitriangular=False)

Solve a system of equations with a square upper or lower triangular invertible matrix A and multiple right-hand sides b.

In symbols, it solves \(A X = b\), assuming that A is a square upper-triangular matrix (or lower-triangular if upper=False) with no zeros on the diagonal.

Warning

This is an experimental API that is subject to change or deletion.

Parameters
  • b (Tensor) – A tensor of shape \((*, M, K)\) where * is zero or more batch dimensions.

  • A (Tensor) – A tensor of shape \((*, M, M)\) where * is zero or more batch dimensions.

  • upper (bool, optional) – Whether A is upper or lower triangular. If False, A is lower triangular. Default: True.

  • transpose (bool, optional) – Solve \(op(A) X = b\), where \(op(A) = A^T\) if this flag is True and \(op(A) = A\) if it is False. Default: False.

  • unitriangular (bool, optional) – Whether A is unit triangular. If True, the diagonal elements of A are assumed to be 1 and are not read from A. Default: False.
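
The way the three flags interact can be sketched with a pure-Python reference for a single, non-batched system. This is an illustrative sketch, not MindSpore's implementation: `triangular_solve_ref` is a hypothetical helper, inputs are plain nested lists rather than Tensors, and it assumes that the triangle not selected by upper is ignored.

```python
def triangular_solve_ref(b, A, upper=True, transpose=False, unitriangular=False):
    """Hypothetical reference: solve op(A) X = b for one M x M matrix A.

    b is an M x K list of lists and A is an M x M list of lists.
    """
    m = len(A)
    k = len(b[0])

    # Keep only the triangle selected by `upper` (assumption: the rest is ignored).
    T = [[A[i][j] if (j >= i if upper else j <= i) else 0.0 for j in range(m)]
         for i in range(m)]

    # With unitriangular=True the stored diagonal is not read; 1 is used instead.
    if unitriangular:
        for i in range(m):
            T[i][i] = 1.0

    # Transposing an upper-triangular matrix gives a lower-triangular one.
    if transpose:
        T = [[T[j][i] for j in range(m)] for i in range(m)]
        upper = not upper

    X = [[0.0] * k for _ in range(m)]
    # Back substitution (last row first) for upper, forward substitution for lower.
    rows = range(m - 1, -1, -1) if upper else range(m)
    for i in rows:
        for c in range(k):
            s = sum(T[i][j] * X[j][c] for j in range(m) if j != i)
            X[i][c] = (b[i][c] - s) / T[i][i]
    return X


A = [[2.0, 1.0],
     [0.0, 4.0]]
b = [[4.0],
     [8.0]]
print(triangular_solve_ref(b, A))                      # [[1.0], [2.0]]
print(triangular_solve_ref(b, A, transpose=True))      # [[2.0], [1.5]]
print(triangular_solve_ref(b, A, unitriangular=True))  # [[-4.0], [8.0]]
```

Each right-hand-side column is solved independently, which is why b may carry K columns for a single A.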

Returns

A tuple (X, A), where X is the solution of the system and A is the coefficient matrix.

Raises
  • ValueError – If the rank of b or A is not in the range \([2, 6]\).

  • ValueError – If the shapes of b and A do not match.

Supported Platforms:

Ascend

Examples

>>> import mindspore
>>> b = mindspore.mint.ones((2, 3, 4))
>>> A = mindspore.mint.randn(2, 3, 3).triu()
>>> A
Tensor(shape=[2, 3, 3], dtype=Float32, value=
[[[-2.30607644e-01,  4.21591580e-01, -3.21297437e-01],
  [ 0.00000000e+00,  1.52568102e+00,  1.92393506e+00],
  [ 0.00000000e+00,  0.00000000e+00, -5.91036975e-01]],
 [[-1.01349080e+00,  1.60536671e+00,  6.57448649e-01],
  [ 0.00000000e+00, -5.75404823e-01,  2.84124088e+00],
  [ 0.00000000e+00,  0.00000000e+00, -1.64982307e+00]]])
>>> mindspore.mint.triangular_solve(b, A)[0]
Tensor(shape=[2, 3, 4], dtype=Float32, value=
[[[ 3.11981082e+00,  3.11981082e+00,  3.11981082e+00,  3.11981082e+00],
  [ 2.78904009e+00,  2.78904009e+00,  2.78904009e+00,  2.78904009e+00],
  [-1.69194150e+00, -1.69194150e+00, -1.69194150e+00, -1.69194150e+00]],
 [[-8.87352180e+00, -8.87352180e+00, -8.87352180e+00, -8.87352180e+00],
  [-4.73084164e+00, -4.73084164e+00, -4.73084164e+00, -4.73084164e+00],
  [-6.06125593e-01, -6.06125593e-01, -6.06125593e-01, -6.06125593e-01]]])
>>> A = mindspore.mint.randn(2, 3, 3).tril()
>>> A
Tensor(shape=[2, 3, 3], dtype=Float32, value=
[[[-2.55578011e-01,  0.00000000e+00,  0.00000000e+00],
  [-3.43545020e-01, -6.31254315e-01,  0.00000000e+00],
  [-7.08323777e-01,  1.44433156e-01,  7.35948741e-01]],
 [[ 1.84799409e+00,  0.00000000e+00,  0.00000000e+00],
  [ 9.20043513e-02, -1.35322273e+00,  0.00000000e+00],
  [ 1.86560547e+00, -1.13146865e+00,  6.90285027e-01]]])
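>>> # upper defaults to True, so only the diagonal and upper triangle of A are read here;
>>> # pass upper=False to solve with the full lower-triangular A.
>>> mindspore.mint.triangular_solve(b, A)[0]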
Tensor(shape=[2, 3, 4], dtype=Float32, value=
[[[-3.91269970e+00, -3.91269970e+00, -3.91269970e+00, -3.91269970e+00],
  [-1.58414757e+00, -1.58414757e+00, -1.58414757e+00, -1.58414757e+00],
  [ 1.35879028e+00,  1.35879028e+00,  1.35879028e+00,  1.35879028e+00]],
 [[ 5.41127264e-01,  5.41127264e-01,  5.41127264e-01,  5.41127264e-01],
  [-7.38976657e-01, -7.38976657e-01, -7.38976657e-01, -7.38976657e-01],
  [ 1.44867694e+00,  1.44867694e+00,  1.44867694e+00,  1.44867694e+00]]])
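
A returned solution can be sanity-checked by substituting it back into \(A X = b\). The sketch below does this in plain Python for the first example, using values copied from the printed tensors (batch 0 of A and the first column of X). The variable names are illustrative only, and the residual is limited by the precision of the printed Float32 values.

```python
# Batch 0 of the upper-triangular A printed in the first example.
A0 = [
    [-2.30607644e-01, 4.21591580e-01, -3.21297437e-01],
    [0.0, 1.52568102e+00, 1.92393506e+00],
    [0.0, 0.0, -5.91036975e-01],
]
# First column of batch 0 of the returned X.
x0 = [3.11981082e+00, 2.78904009e+00, -1.69194150e+00]

# Every entry of b is 1.0, so each row of A0 @ x0 should be close to 1.0.
residual = [sum(a * x for a, x in zip(row, x0)) - 1.0 for row in A0]
max_abs_residual = max(abs(r) for r in residual)
print(max_abs_residual)
```

A residual near Float32 rounding error indicates that the solution is consistent with the inputs.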