mindspore.mint.triangular_solve
- mindspore.mint.triangular_solve(b, A, upper=True, transpose=False, unitriangular=False)
Solve a system of equations with a square upper or lower triangular invertible matrix A and multiple right-hand sides b.
In symbols, it solves \(A X = b\), assuming A is square and upper triangular (or lower triangular if
upper = False) with no zeros on its diagonal.

Warning
This is an experimental API that is subject to change or deletion.
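Triangular systems are usually solved by substitution rather than by a general factorization. With an upper-triangular A, the last row yields the last unknown directly, and each row above it involves only unknowns that are already known; a lower-triangular A is handled the same way starting from the first row. This is also why zeros on the diagonal are ruled out: each step divides by a diagonal entry. The plain-Python sketch below illustrates the idea for a single, non-batched system. `substitution_solve` is a hypothetical helper for illustration, not MindSpore's Ascend implementation.

```python
def substitution_solve(A, b, upper=True):
    """Solve A X = b for a triangular A given as lists of lists.

    A is M x M, b is M x K; returns X as an M x K list of lists.
    Hypothetical helper for illustration only.
    """
    m = len(A)
    k = len(b[0])
    X = [[0.0] * k for _ in range(m)]
    # Upper triangular: back substitution, last row first.
    # Lower triangular: forward substitution, first row first.
    rows = range(m - 1, -1, -1) if upper else range(m)
    for i in rows:
        if A[i][i] == 0:
            # A zero on the diagonal means A is singular; the solve is undefined.
            raise ZeroDivisionError(f"A[{i}][{i}] is zero")
        # Unknowns that have already been solved when row i is reached.
        known = range(i + 1, m) if upper else range(i)
        for c in range(k):
            s = b[i][c] - sum(A[i][j] * X[j][c] for j in known)
            X[i][c] = s / A[i][i]
    return X


# Upper triangular: 4*x2 = 8 gives x2 = 2, then 2*x1 + x2 = 5 gives x1 = 1.5.
print(substitution_solve([[2.0, 1.0], [0.0, 4.0]], [[5.0], [8.0]]))  # [[1.5], [2.0]]
```

Each right-hand-side column is solved independently, which is what "multiple right-hand sides" amounts to: K columns of b give K columns of X.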
- Parameters
b (Tensor) – A tensor of shape \((*, M, K)\) where * is zero or more batch dimensions.
A (Tensor) – A tensor of shape \((*, M, M)\) where * is zero or more batch dimensions.
upper (bool, optional) – Whether A is upper or lower triangular. If False, A is lower triangular. Default: True.
transpose (bool, optional) – Solve \(op(A) X = b\), where \(op(A) = A^T\) if this flag is True and \(op(A) = A\) if it is False. Default: False.
unitriangular (bool, optional) – Whether A is unit triangular. If True, the diagonal elements of A are assumed to be 1 and are not read from A. Default: False.
- Returns
A tuple (X, A), where X is the solution of the system, with the same shape as b, and A is the coefficient matrix.
- Raises
ValueError – If the rank of b or A is not in the range \([2, 6]\).
ValueError – If the shapes of b and A do not match.
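The two ValueError conditions can be read as shape checks on the inputs described under Parameters. The sketch below is a hypothetical helper, `check_triangular_solve_shapes`, written from those documented shapes. Treating "matched" as equal batch dimensions plus b.shape[-2] == M is an assumption; MindSpore's actual checks may differ (for example, it may accept broadcastable batch dimensions).

```python
def check_triangular_solve_shapes(b_shape, A_shape):
    """Hypothetical mirror of the documented ValueError conditions.

    Expects b_shape = (*, M, K) and A_shape = (*, M, M).
    """
    for name, shape in (("b", b_shape), ("A", A_shape)):
        # Documented rule: rank must lie in the closed range [2, 6].
        if not 2 <= len(shape) <= 6:
            raise ValueError(f"rank of {name} must be in [2, 6], got {len(shape)}")
    *a_batch, m, m2 = A_shape
    *b_batch, mb, _k = b_shape
    if m != m2:
        raise ValueError(f"A must be square in its last two dims, got {tuple(A_shape)}")
    if mb != m:
        raise ValueError(f"b has {mb} rows but A is {m} x {m}")
    # Assumption: batch dimensions must be identical (no broadcasting).
    if a_batch != b_batch:
        raise ValueError(f"batch dims differ: {b_batch} vs {a_batch}")


check_triangular_solve_shapes((2, 3, 4), (2, 3, 3))  # shapes from Examples: no error
```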
- Supported Platforms:
Ascend
Examples
>>> import mindspore
>>> b = mindspore.mint.ones((2, 3, 4))
>>> A = mindspore.mint.randn(2, 3, 3).triu()
>>> A
Tensor(shape=[2, 3, 3], dtype=Float32, value=
[[[-2.30607644e-01,  4.21591580e-01, -3.21297437e-01],
  [ 0.00000000e+00,  1.52568102e+00,  1.92393506e+00],
  [ 0.00000000e+00,  0.00000000e+00, -5.91036975e-01]],
 [[-1.01349080e+00,  1.60536671e+00,  6.57448649e-01],
  [ 0.00000000e+00, -5.75404823e-01,  2.84124088e+00],
  [ 0.00000000e+00,  0.00000000e+00, -1.64982307e+00]]])
>>> mindspore.mint.triangular_solve(b, A)[0]
Tensor(shape=[2, 3, 4], dtype=Float32, value=
[[[ 3.11981082e+00,  3.11981082e+00,  3.11981082e+00,  3.11981082e+00],
  [ 2.78904009e+00,  2.78904009e+00,  2.78904009e+00,  2.78904009e+00],
  [-1.69194150e+00, -1.69194150e+00, -1.69194150e+00, -1.69194150e+00]],
 [[-8.87352180e+00, -8.87352180e+00, -8.87352180e+00, -8.87352180e+00],
  [-4.73084164e+00, -4.73084164e+00, -4.73084164e+00, -4.73084164e+00],
  [-6.06125593e-01, -6.06125593e-01, -6.06125593e-01, -6.06125593e-01]]])
>>> A = mindspore.mint.randn(2, 3, 3).tril()
>>> A
Tensor(shape=[2, 3, 3], dtype=Float32, value=
[[[-2.55578011e-01,  0.00000000e+00,  0.00000000e+00],
  [-3.43545020e-01, -6.31254315e-01,  0.00000000e+00],
  [-7.08323777e-01,  1.44433156e-01,  7.35948741e-01]],
 [[ 1.84799409e+00,  0.00000000e+00,  0.00000000e+00],
  [ 9.20043513e-02, -1.35322273e+00,  0.00000000e+00],
  [ 1.86560547e+00, -1.13146865e+00,  6.90285027e-01]]])
>>> mindspore.mint.triangular_solve(b, A)[0]
Tensor(shape=[2, 3, 4], dtype=Float32, value=
[[[-3.91269970e+00, -3.91269970e+00, -3.91269970e+00, -3.91269970e+00],
  [-1.58414757e+00, -1.58414757e+00, -1.58414757e+00, -1.58414757e+00],
  [ 1.35879028e+00,  1.35879028e+00,  1.35879028e+00,  1.35879028e+00]],
 [[ 5.41127264e-01,  5.41127264e-01,  5.41127264e-01,  5.41127264e-01],
  [-7.38976657e-01, -7.38976657e-01, -7.38976657e-01, -7.38976657e-01],
  [ 1.44867694e+00,  1.44867694e+00,  1.44867694e+00,  1.44867694e+00]]])
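Because A is drawn with randn, rerunning these examples prints different numbers. The printed results can still be checked offline. The NumPy sketch below re-solves both examples from the values copied out of the output (so agreement holds only up to float32 rounding). It also shows that the second call, which passes a lower-triangular A but leaves upper at its default True, reads only the upper triangle of A, which for a lower-triangular matrix is just its diagonal, so the printed X equals b divided by diag(A). To solve the lower-triangular system itself, pass upper=False.

```python
import numpy as np

# Values copied from the printed output above (float32, 9 significant digits).
A_upper = np.array([
    [[-2.30607644e-01, 4.21591580e-01, -3.21297437e-01],
     [0.0, 1.52568102e+00, 1.92393506e+00],
     [0.0, 0.0, -5.91036975e-01]],
    [[-1.01349080e+00, 1.60536671e+00, 6.57448649e-01],
     [0.0, -5.75404823e-01, 2.84124088e+00],
     [0.0, 0.0, -1.64982307e+00]],
])
A_lower = np.array([
    [[-2.55578011e-01, 0.0, 0.0],
     [-3.43545020e-01, -6.31254315e-01, 0.0],
     [-7.08323777e-01, 1.44433156e-01, 7.35948741e-01]],
    [[1.84799409e+00, 0.0, 0.0],
     [9.20043513e-02, -1.35322273e+00, 0.0],
     [1.86560547e+00, -1.13146865e+00, 6.90285027e-01]],
])
b = np.ones((2, 3, 4))

# First column of each printed X (all columns match because b is all ones).
X1_printed = np.array([[3.11981082, 2.78904009, -1.69194150],
                       [-8.87352180, -4.73084164, -0.606125593]])
X2_printed = np.array([[-3.91269970, -1.58414757, 1.35879028],
                       [0.541127264, -0.738976657, 1.44867694]])

# Example 1: A is upper triangular, so a general solve gives the same X.
X1 = np.linalg.solve(A_upper, b)

# Example 2 with upper=True: only the diagonal of the lower-triangular A is read.
diag = np.diagonal(A_lower, axis1=-2, axis2=-1)  # shape (2, 3)
X2_upper_default = b / diag[..., :, None]

# What solving the lower-triangular system (upper=False) gives instead.
X2_lower = np.linalg.solve(A_lower, b)

print(np.allclose(X1[..., 0], X1_printed, rtol=1e-5))                # True
print(np.allclose(X2_upper_default[..., 0], X2_printed, rtol=1e-5))  # True
print(np.allclose(X2_lower[..., 0], X2_printed, rtol=1e-5))          # False
```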