比较与torch.cat的功能差异

查看源文件

torch.cat

torch.cat(
    tensors,
    dim=0,
    out=None
)

更多内容详见torch.cat

mindspore.ops.Concat

class mindspore.ops.Concat(
    axis=0
)(input_x)

更多内容详见mindspore.ops.Concat

使用方式

PyTorch: 输入tensor的数据类型不同时,低精度tensor会自动转成高精度tensor。

MindSpore: 当前要求输入tensor的数据类型保持一致,若不一致时可通过ops.Cast把低精度tensor转成高精度类型再调用Concat算子。

代码示例

import mindspore
import mindspore.ops as ops
from mindspore import Tensor
import torch
import numpy as np

# In MindSpore,converting low precision to high precision is needed before concat.
a = Tensor(np.ones([2, 3]).astype(np.float16))
b = Tensor(np.ones([2, 3]).astype(np.float32))
concat_op = ops.Concat()
cast_op = ops.Cast()
output = concat_op((cast_op(a, mindspore.float32), b))
print(output.shape)
# Out:
# (4, 3)

# In Pytorch.
a = torch.tensor(np.ones([2, 3]).astype(np.float16))
b = torch.tensor(np.ones([2, 3]).astype(np.float32))
output = torch.cat((a, b))
print(output.size())
# Out:
# torch.Size([4, 3])