mindspore_gl.nn.ASTGCN

class mindspore_gl.nn.ASTGCN(n_blocks: int, in_channels: int, k: int, n_chev_filters: int, n_time_filters: int, time_conv_strides: int, num_for_predict: int, len_input: int, n_vertices: int, normalization: Optional[str] = 'sym', bias: bool = True)

Attention Based Spatial-Temporal Graph Convolutional Network (ASTGCN), from the paper Attention Based Spatial-Temporal Graph Convolutional Networks for Traffic Flow Forecasting.

Parameters
  • n_blocks (int) – Number of ASTGCN blocks.

  • in_channels (int) – Input node feature size.

  • k (int) – Order of Chebyshev polynomials.

  • n_chev_filters (int) – Number of Chebyshev filters.

  • n_time_filters (int) – Number of time filters.

  • time_conv_strides (int) – Time strides during temporal convolution.

  • num_for_predict (int) – Number of predictions to make in the future.

  • len_input (int) – Length of the input sequence.

  • n_vertices (int) – Number of vertices in the graph.

  • normalization (str, optional) – The normalization scheme for the graph Laplacian. Default: 'sym', the symmetric normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\), where \(\mathbf{L}\) is the normalized Laplacian matrix, \(\mathbf{D}\) is the degree matrix, \(\mathbf{A}\) is the adjacency matrix, and \(\mathbf{I}\) is the identity matrix.

  • bias (bool, optional) – Whether the layer will learn an additive bias. Default: True.
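
The 'sym' normalization formula above can be illustrated with a small NumPy sketch. This is not the library's implementation: the helper name sym_normalized_laplacian is hypothetical, the matrix is dense, and edges are assumed to be undirected. It reuses the edge list from the Examples section.

```python
import numpy as np

def sym_normalized_laplacian(edge_index, n_vertices):
    """Return L = I - D^{-1/2} A D^{-1/2} as a dense array (illustrative helper)."""
    adj = np.zeros((n_vertices, n_vertices), dtype=np.float64)
    src, dst = edge_index
    adj[src, dst] = 1.0
    adj[dst, src] = 1.0  # assumption: every edge is undirected
    deg = adj.sum(axis=1)
    # Isolated vertices get a zero scaling factor instead of a division by zero.
    inv_sqrt_deg = np.zeros_like(deg)
    nonzero = deg > 0
    inv_sqrt_deg[nonzero] = deg[nonzero] ** -0.5
    norm_adj = inv_sqrt_deg[:, None] * adj * inv_sqrt_deg[None, :]
    return np.eye(n_vertices) - norm_adj

edge_index = np.array([[0, 0, 0, 0, 1, 1, 1, 2, 2, 3],
                       [1, 4, 2, 3, 2, 3, 4, 3, 4, 4]])
laplacian = sym_normalized_laplacian(edge_index, 5)
print(laplacian.shape)  # (5, 5)
```

Because this edge list connects every pair of the five vertices, each vertex has degree 4, so the diagonal entries are 1 and the off-diagonal entries are -0.25.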

Inputs:
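  • x (Tensor) - The input node features for T time periods. The shape is \((B, N, F_{in}, T_{in})\), where \(B\) is the batch size, \(N\) is the number of nodes, \(F_{in}\) is the input node feature size, and \(T_{in}\) is the length of the input sequence.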

  • g (Graph) - The input graph.

Outputs:
  • Tensor, output node features with shape of \((B, N, T_{out})\).
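
The input and output shapes above can be made concrete with a sliding-window sketch in NumPy. The helper make_windows and its target_channel argument are hypothetical and not part of mindspore_gl. The sketch assumes the raw data is stored as (time, nodes, features) and that \(T_{out}\) equals num_for_predict, which matches the example output shape below but is otherwise an assumption.

```python
import numpy as np

def make_windows(series, len_input, num_for_predict, target_channel=0):
    """Slice a (T_total, N, F_in) series into model inputs and targets.

    Returns x with shape (B, N, F_in, len_input) and
    y with shape (B, N, num_for_predict).
    """
    total_steps = series.shape[0]
    n_samples = total_steps - len_input - num_for_predict + 1
    if n_samples <= 0:
        raise ValueError("series is too short for the requested window sizes")
    xs, ys = [], []
    for start in range(n_samples):
        split = start + len_input
        # Move time to the last axis: (len_input, N, F_in) -> (N, F_in, len_input)
        xs.append(series[start:split].transpose(1, 2, 0))
        # Future values of one feature: (num_for_predict, N) -> (N, num_for_predict)
        ys.append(series[split:split + num_for_predict, :, target_channel].T)
    return np.stack(xs).astype(np.float32), np.stack(ys).astype(np.float32)

# 12 time steps, 5 nodes, 2 features per node
series = np.arange(12 * 5 * 2, dtype=np.float32).reshape(12, 5, 2)
x, y = make_windows(series, len_input=4, num_for_predict=4)
print(x.shape, y.shape)  # (5, 5, 2, 4) (5, 5, 4)
```

An array built this way can be wrapped in a Tensor of dtype float32 and passed as x.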

Raises
  • TypeError – If n_blocks, in_channels, k, n_chev_filters, n_time_filters, time_conv_strides, num_for_predict, len_input or n_vertices is not a positive int.

  • ValueError – If normalization is not 'sym'.

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> from mindspore_gl.graph import norm
>>> from mindspore_gl.nn import ASTGCN
>>> from mindspore_gl import GraphField
>>> node_count = 5
>>> num_for_predict = 4
>>> len_input = 4
>>> n_time_strides = 1
>>> node_features = 2
>>> nb_block = 2
>>> k = 3
>>> n_chev_filters = 8
>>> n_time_filters = 8
>>> batch_size = 2
>>> normalization = "sym"
>>> edge_index = np.array([[0, 0, 0, 0, 1, 1, 1, 2, 2, 3],
                           [1, 4, 2, 3, 2, 3, 4, 3, 4, 4]])
>>> model = ASTGCN(nb_block, node_features, k, n_chev_filters, n_time_filters,
            n_time_strides, num_for_predict, len_input, node_count, normalization)
>>> edge_index_norm, edge_weight_norm = norm(Tensor(edge_index, dtype=ms.int32), node_count)
>>> graph = GraphField(edge_index_norm[1], edge_index_norm[0], node_count, len(edge_index_norm[0]))
>>> x_seq = Tensor(np.ones([batch_size, node_count, node_features, len_input]), dtype=ms.float32)
>>> output = model(x_seq, edge_weight_norm, *graph.get_graph())
>>> print(output.shape)
(2, 5, 4)