# Source code for mindspore_gl.dataset.metr_la

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""METR LA"""
import os
import numpy as np
from .base_dataset import BaseDataSet


#pylint: disable=W0223
class MetrLa(BaseDataSet):
    """
    METR-LA Dataset, a source dataset for reading and parsing METR-LA dataset.

    About METR-LA dataset:

    METR-LA is a large-scale dataset collected from 1500 traffic loop
    detectors in Los Angeles County road network. This dataset includes
    speed, volume and occupancy data, covering approximately 3,420 miles.

    Statistics:

    - Time step: 12,6850
    - Nodes: 207
    - Edges: 1515

    Dataset can be downloaded here:
    `METR-LA <https://graphmining.ai/temporal_datasets/METR-LA.zip>`_ .
    You can organize the dataset files into the following directory structure
    and read by `mindspore_gl.dataset.MetrLa.get_data` API.

    .. code-block::

        .
        ├── adj_mat.npy
        └── node_values.npy

    Args:
        root(str): path to the root directory that directly contains
            adj_mat.npy and node_values.npy.

    Inputs:
        - **in_timestep** (int) - numbers of input time sequence.
        - **out_timestep** (int) - numbers of output time sequence.

    Raises:
        TypeError: if `root` is not a str.
        RuntimeError: if `root` does not contain data files.
        TypeError: If `in_timestep` or `out_timestep` is not a positive int.

    Examples:
        >>> from mindspore_gl.dataset.metr_la import MetrLa
        >>> root = "path/to/metrla"
        >>> dataset = MetrLa(root)
        >>> features, labels = dataset.get_data(in_timestep, out_timestep)
    """

    def __init__(self, root):
        if not isinstance(root, str):
            # Fall back to the class name when the base class does not
            # provide `cls_name`, so the intended TypeError is always raised
            # instead of an AttributeError while building the message.
            cls_name = getattr(self, 'cls_name', type(self).__name__)
            raise TypeError(f"For '{cls_name}', the 'root' should be a str, "
                            f"but got {type(root)}.")
        self._root = root
        self._adj = os.path.join(root, 'adj_mat.npy')
        self._node = os.path.join(root, 'node_values.npy')
        # `isfile` already implies existence; both files are required.
        if not (os.path.isfile(self._adj) and os.path.isfile(self._node)):
            # RuntimeError matches the documented contract and is still an
            # `Exception`, so callers catching `Exception` keep working.
            raise RuntimeError('data file does not exist')
        self._load()

    def _load(self):
        """Load the adjacency matrix and node values, then standardise them.

        Sets:
            - ``adj``: the array stored in adj_mat.npy.
            - ``edge_index``: indices of the non-zero entries of ``adj``,
              stacked along axis 0 (one row per dimension of ``adj``).
            - ``edge_attr``: the non-zero values of ``adj`` in the same order.
            - ``x``: node values with axes reordered by ``transpose((1, 2, 0))``
              and standardised per index of axis 1.
        """
        self.adj = np.load(self._adj)
        index = np.nonzero(self.adj)
        self.edge_attr = self.adj[index]
        self.edge_index = np.stack(index, axis=0)

        # After the transpose the last axis is the one `get_data` slices by
        # time step and axis 0 is the one counted as nodes.
        # NOTE(review): presumably the file is stored as (time, node, channel)
        # -- confirm against the dataset.
        self.x = np.load(self._node).transpose((1, 2, 0))
        # Zero-mean / unit-variance per channel (axis 1), computed over all
        # entries of axes 0 and 2.
        means = np.mean(self.x, axis=(0, 2))
        self.x = self.x - means.reshape(1, -1, 1)
        stds = np.std(self.x, axis=(0, 2))
        self.x = self.x / stds.reshape(1, -1, 1)

    def get_data(self, in_timestep, out_timestep):
        """
        Get sequence time feature and label.

        Slides a window of ``in_timestep + out_timestep`` steps along the last
        axis of the loaded data with stride 1. For each window the first
        ``in_timestep`` steps of all channels form the feature and the
        remaining ``out_timestep`` steps of channel 0 form the label. The
        results are also stored in ``self.features`` and ``self.labels``.

        Args:
            in_timestep(int): numbers of input time sequence.
            out_timestep(int): numbers of output time sequence.

        Returns:
            - numpy.ndarray, features, one entry per window, each of shape
              ``self.x.shape[:2] + (in_timestep,)``.
            - numpy.ndarray, labels, one entry per window, each of shape
              ``(self.x.shape[0], out_timestep)``.

            Both arrays are empty when the data holds fewer than
            ``in_timestep + out_timestep`` steps.

        Raises:
            TypeError: If `in_timestep` or `out_timestep` is not a positive int.
        """
        if not (isinstance(in_timestep, int) and in_timestep > 0):
            raise TypeError('the in_timestep must be a positive integer value')
        if not (isinstance(out_timestep, int) and out_timestep > 0):
            raise TypeError('the out_timestep must be a positive integer value')

        window = in_timestep + out_timestep
        # A non-positive count yields an empty range, hence empty outputs.
        num_windows = self.x.shape[2] - window + 1

        features, labels = [], []
        for start in range(num_windows):
            split = start + in_timestep
            features.append(self.x[:, :, start:split])
            labels.append(self.x[:, 0, split:start + window])

        self.features = np.array(features)
        self.labels = np.array(labels)
        return self.features, self.labels

    @property
    def node_count(self):
        """
        Number of nodes.

        Returns:
            - int, number of node.

        Examples:
            >>> #dataset is an instance object of Dataset
            >>> node_count = dataset.node_count
        """
        # Axis 0 of `x` is the axis that becomes axis 1 of `features`, so this
        # is the same value without requiring a prior `get_data` call (and it
        # stays valid when `get_data` produced no windows).
        return self.x.shape[0]