mindspore.dataset.dataloader.TensorDataset
- class mindspore.dataset.dataloader.TensorDataset(*tensors)[源代码]
由
mindspore.Tensor
集合定义的数据集。每个
Tensor
表示数据集的一列特征,其第0维的大小必须相同,即样本总数。 数据集将沿着Tensor
的第0维来检索样本。- 参数:
*tensors (mindspore.Tensor) -
mindspore.Tensor
集合。
样例:
>>> from mindspore import Tensor, int32 >>> from mindspore.dataset.dataloader import TensorDataset >>> >>> dataset = TensorDataset(Tensor([0, 1], dtype=int32), Tensor([2, 3], dtype=int32)) >>> for sample in dataset: ... print(sample) (Tensor(shape=[], dtype=Int32, value= 0), Tensor(shape=[], dtype=Int32, value= 2)) (Tensor(shape=[], dtype=Int32, value= 1), Tensor(shape=[], dtype=Int32, value= 3))