mindspore.dataset.dataloader.TensorDataset

class mindspore.dataset.dataloader.TensorDataset(*tensors)

Dataset defined by a collection of mindspore.Tensor objects.

Each Tensor represents one feature column of the dataset. All Tensors must have the same size in the first dimension, and that size is the total number of samples. A sample is retrieved by indexing every Tensor along its first dimension at the same position, so each sample is a tuple with one element per Tensor.
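To make the indexing rule concrete, here is a minimal, framework-free sketch. SimpleTensorDataset is a hypothetical class written for illustration only; it is not part of MindSpore and is not its implementation. Plain Python sequences stand in for mindspore.Tensor, and raising ValueError on a size mismatch is this sketch's own choice, not documented MindSpore behavior.

```python
from typing import Any, Iterator, Sequence, Tuple


class SimpleTensorDataset:
    """Hypothetical sketch of TensorDataset-style indexing (not MindSpore code).

    Each positional argument plays the role of one feature column; plain
    Python sequences stand in for mindspore.Tensor.
    """

    def __init__(self, *tensors: Sequence[Any]) -> None:
        if not tensors:
            raise ValueError("at least one column is required")
        num_samples = len(tensors[0])
        # Every column must have the same size in the first dimension.
        if any(len(t) != num_samples for t in tensors):
            raise ValueError("all columns must have the same first-dimension size")
        self.tensors = tensors

    def __len__(self) -> int:
        # The shared first-dimension size is the number of samples.
        return len(self.tensors[0])

    def __getitem__(self, index: int) -> Tuple[Any, ...]:
        # One sample = the index-th entry of every column, in argument order.
        return tuple(t[index] for t in self.tensors)

    def __iter__(self) -> Iterator[Tuple[Any, ...]]:
        # Walk the first dimension from 0 to len(self) - 1.
        for i in range(len(self)):
            yield self[i]


features = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # 3 samples, 2 values each
labels = [0, 1, 0]                               # 3 samples
ds = SimpleTensorDataset(features, labels)

print(len(ds))  # prints 3
print(ds[1])    # prints ([0.3, 0.4], 1)
```

As in the Examples section below, each sample is a tuple holding one entry per column, taken at the same position along the first dimension.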

Parameters

*tensors (mindspore.Tensor) – One or more mindspore.Tensor objects, one per feature column. All of them must have the same size in the first dimension.

Examples

>>> from mindspore import Tensor, int32
>>> from mindspore.dataset.dataloader import TensorDataset
>>>
>>> dataset = TensorDataset(Tensor([0, 1], dtype=int32), Tensor([2, 3], dtype=int32))
>>> for sample in dataset:
...     print(sample)
(Tensor(shape=[], dtype=Int32, value= 0), Tensor(shape=[], dtype=Int32, value= 2))
(Tensor(shape=[], dtype=Int32, value= 1), Tensor(shape=[], dtype=Int32, value= 3))