Function Differences with torch.Tensor.flatten


torch.Tensor.flatten

torch.Tensor.flatten(input, start_dim=0, end_dim=-1)

For more information, see torch.Tensor.flatten.

mindspore.Tensor.flatten

mindspore.Tensor.flatten(order='C', *, start_dim=0, end_dim=-1)

For more information, see mindspore.Tensor.flatten.

Usage

torch.Tensor.flatten has no order parameter; it always flattens in row-major order.

mindspore.Tensor.flatten selects the flattening order through its order parameter: "C" (the default) flattens in row-major order, and "F" flattens in column-major (Fortran-style) order.
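When porting code that passes order='F', PyTorch has no direct equivalent, but column-major flattening can be emulated. The sketch below uses a hypothetical helper, `flatten_fortran` (not part of either library). It reverses every axis with `permute` and then flattens in PyTorch's default row-major order. For a 2-D tensor this amounts to transposing and then flattening.

```python
import torch


def flatten_fortran(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: flatten `t` in column-major (Fortran) order.

    Reversing the order of all axes and then flattening row-major visits
    the elements in column-major order of the original tensor.
    """
    if t.dim() < 2:
        # 0-D and 1-D tensors read the same in either order.
        return t.flatten()
    # Reverse the axis order, e.g. (0, 1, 2) -> (2, 1, 0).
    reversed_dims = tuple(reversed(range(t.dim())))
    return t.permute(reversed_dims).flatten()


b = torch.tensor([[1, 2], [3, 4]])
print(b.flatten())         # row-major, like MindSpore order='C'
# tensor([1, 2, 3, 4])
print(flatten_fortran(b))  # column-major, like MindSpore order='F'
# tensor([1, 3, 2, 4])
```

`permute` returns a non-contiguous view, so the subsequent `flatten` copies the data. This is acceptable for a one-off conversion but worth keeping in mind for large tensors.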

Code Example

import mindspore as ms

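a = ms.Tensor([[1, 2], [3, 4]], ms.int32)
# Default order='C': row-major
print(a.flatten())
# [1 2 3 4]
# order='F': column-major
print(a.flatten('F'))
# [1 3 2 4]
# Flattening dims 1 to -1 of a 2-D tensor merges nothing, so the shape is unchanged
print(a.flatten(start_dim=1))
# [[1 2]
#  [3 4]]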

import torch

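b = torch.tensor([[1, 2], [3, 4]])
# Always row-major; there is no order argument
print(torch.Tensor.flatten(b))
# tensor([1, 2, 3, 4])
# start_dim=1 on a 2-D tensor leaves the shape unchanged
print(torch.Tensor.flatten(b, start_dim=1))
# tensor([[1, 2],
#         [3, 4]])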
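The examples above only use start_dim=1 on a 2-D tensor, where there is nothing to merge, so the output keeps its shape. A 3-D tensor makes the effect of start_dim and end_dim visible: every dimension from start_dim through end_dim (inclusive) is merged into one. This sketch uses PyTorch only. MindSpore's keyword-only start_dim and end_dim are assumed to produce the same shapes, but that is not checked here.

```python
import torch

# A 2 x 3 x 4 tensor; only the shapes matter here.
x = torch.zeros(2, 3, 4)

# Merge dims 1..-1 (sizes 3 and 4): (2, 3, 4) -> (2, 12)
print(x.flatten(start_dim=1).shape)
# torch.Size([2, 12])

# Merge dims 0..1 (sizes 2 and 3): (2, 3, 4) -> (6, 4)
print(x.flatten(start_dim=0, end_dim=1).shape)
# torch.Size([6, 4])

# Defaults (start_dim=0, end_dim=-1) merge everything: (2, 3, 4) -> (24,)
print(x.flatten().shape)
# torch.Size([24])
```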