Function Differences with torch.flatten

torch.flatten

torch.flatten(
    input,
    start_dim=0,
    end_dim=-1
)

For more information, see torch.flatten.

mindspore.ops.flatten

mindspore.ops.flatten(input, order='C', *, start_dim=1, end_dim=-1)

For more information, see mindspore.ops.flatten.

Differences

PyTorch: Flattens the dimensions of the input from start_dim through end_dim into a single dimension. start_dim defaults to 0 and end_dim defaults to -1, so by default the whole tensor is flattened to one dimension.

MindSpore: Flattens the dimensions from start_dim through end_dim in the same way, but start_dim defaults to 1 (end_dim still defaults to -1), so the first dimension is kept by default. The additional order parameter selects row-major ('C') or column-major ('F') flattening.
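The effect of the different start_dim defaults on the output shape can be sketched in plain Python. This is only an illustration of the shape arithmetic: `flattened_shape` is a hypothetical helper, not part of either framework, and it assumes the input has at least one dimension.

```python
from math import prod


def flattened_shape(shape, start_dim, end_dim=-1):
    """Hypothetical helper: shape after merging dims start_dim..end_dim (inclusive)."""
    ndim = len(shape)
    # Map negative indices to their non-negative equivalents.
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim
    if not (0 <= start <= end < ndim):
        raise ValueError("start_dim and end_dim must select a valid range")
    merged = prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


shape = (1, 2, 3, 4)
torch_default = flattened_shape(shape, start_dim=0)      # torch.flatten default
mindspore_default = flattened_shape(shape, start_dim=1)  # mindspore.ops.flatten default
print(torch_default)      # (24,)
print(mindspore_default)  # (1, 24)
```

When migrating code that relies on a default, passing start_dim explicitly on both sides is a simple way to keep the shapes aligned.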

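| Categories | Subcategories | PyTorch | MindSpore | Differences |
| --- | --- | --- | --- | --- |
| Parameter | Parameter 1 | input | input | Same function |
| | Parameter 2 | - | order | Flatten order; PyTorch does not have this parameter |
| | Parameter 3 | start_dim | start_dim | Same function; the default is 0 in PyTorch and 1 in MindSpore |
| | Parameter 4 | end_dim | end_dim | Same function |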
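As a rough illustration of what the two values of order mean, the sketch below flattens a small 2-D nested list in row-major and column-major order using plain Python. It does not call either framework, and it only covers the simple case of flattening an entire 2-D input; how order interacts with start_dim and end_dim on higher-rank tensors is not shown here.

```python
rows = [[1, 2, 3],
        [4, 5, 6]]

# order='C': row-major, the last index varies fastest.
c_order = [value for row in rows for value in row]

# order='F': column-major, the first index varies fastest.
f_order = [value for column in zip(*rows) for value in column]

print(c_order)  # [1, 2, 3, 4, 5, 6]
print(f_order)  # [1, 4, 2, 5, 3, 6]
```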

Code Example

import mindspore as ms
import mindspore.ops as ops
import torch
import numpy as np

# MindSpore: start_dim defaults to 1, so the first dimension is kept
input_tensor = ms.Tensor(np.ones(shape=[1, 2, 3, 4]), ms.float32)
output = ops.flatten(input_tensor)
print(output.shape)
# Out:
# (1, 24)

input_tensor = ms.Tensor(np.ones(shape=[1, 2, 3, 4]), ms.float32)
output = ops.flatten(input_tensor, start_dim=2)
print(output.shape)
# Out:
# (1, 2, 12)

# PyTorch: start_dim=1 is passed explicitly to match MindSpore's default (torch.flatten defaults to 0)
input_tensor = torch.Tensor(np.ones(shape=[1, 2, 3, 4]))
output1 = torch.flatten(input=input_tensor, start_dim=1)
print(output1.shape)
# Out:
# torch.Size([1, 24])

input_tensor = torch.Tensor(np.ones(shape=[1, 2, 3, 4]))
output2 = torch.flatten(input=input_tensor, start_dim=2)
print(output2.shape)
# Out:
# torch.Size([1, 2, 12])