mindspore.dataset.dataloader.TensorDataset

查看源文件
class mindspore.dataset.dataloader.TensorDataset(*tensors)[源代码]

mindspore.Tensor 集合定义的数据集。

每个 Tensor 表示数据集的一列特征,其第0维的大小必须相同,即样本总数。 数据集将沿着 Tensor 的第0维来检索样本。

参数:

样例:

>>> from mindspore import Tensor, int32
>>> from mindspore.dataset.dataloader import TensorDataset
>>>
>>> dataset = TensorDataset(Tensor([0, 1], dtype=int32), Tensor([2, 3], dtype=int32))
>>> for sample in dataset:
...     print(sample)
(Tensor(shape=[], dtype=Int32, value= 0), Tensor(shape=[], dtype=Int32, value= 2))
(Tensor(shape=[], dtype=Int32, value= 1), Tensor(shape=[], dtype=Int32, value= 3))