mindspore_gl.graph.graph_csr_data

查看源文件
mindspore_gl.graph.graph_csr_data(src_idx, dst_idx, n_nodes, n_edges, node_feat=None, node_label=None, train_mask=None, val_mask=None, test_mask=None, rerank=False)[源代码]

将COO类型的整图转为CSR类型。

参数:
  • src_idx (Union[Tensor, numpy.ndarray]) - shape为 \((N\_EDGES)\) 的int类型Tensor,表示COO边矩阵的源节点索引。

  • dst_idx (Union[Tensor, numpy.ndarray]) - shape为 \((N\_EDGES)\) 的int类型Tensor,表示COO边矩阵的目标节点索引。

  • n_nodes (int) - 图中节点数量。

  • n_edges (int) - 图中边数量。

  • node_feat (Union[Tensor, numpy.ndarray, 可选]) - 节点特征。默认值:None

  • node_label (Union[Tensor, numpy.ndarray, 可选]) - 节点标签。默认值:None

  • train_mask (Union[Tensor, numpy.ndarray, 可选]) - 训练索引的掩码。默认值:None

  • val_mask (Union[Tensor, numpy.ndarray, 可选]) - 验证索引的掩码。默认值:None

  • test_mask (Union[Tensor, numpy.ndarray, 可选]) - 测试索引的掩码。默认值:None

  • rerank (bool, 可选) - 是否对节点特征、标签、掩码进行重排序。默认值:False

返回:
  • csr_g (tuple) - CSR图的信息,它包含CSR图的indices,CSR图的indptr,CSR图的节点数、CSR图的边数、CSR图的预存的反向indices、CSR图的预存储反向indptr。

  • in_deg - 每个节点的入度。

  • out_deg - 每个节点的出度。

  • node_feat (Union[Tensor, numpy.ndarray, 可选]) - 重排序的节点特征。

  • node_label (Union[Tensor, numpy.ndarray, 可选]) - 重排序的节点标签。

  • train_mask (Union[Tensor, numpy.ndarray, 可选]) - 重排序的训练索引的掩码。

  • val_mask (Union[Tensor, numpy.ndarray, 可选]) - 重排序的验证索引的掩码。

  • test_mask (Union[Tensor, numpy.ndarray, 可选]) - 重排序的测试索引的掩码。

支持平台:

Ascend GPU

样例:

>>> import numpy as np
>>> from mindspore_gl.graph import graph_csr_data
>>> node_feat = np.array([[1, 2, 3, 4], [2, 4, 1, 3], [1, 3, 2, 4],
...                       [9, 7, 5, 8], [8, 7, 6, 5], [8, 6, 4, 6], [1, 2, 1, 1]], np.float32)
>>> n_nodes = 7
>>> n_edges = 8
>>> edge_feat_size = 7
>>> src_idx = np.array([0, 2, 2, 3, 4, 5, 5, 6], np.int32)
>>> dst_idx = np.array([1, 0, 1, 5, 3, 4, 6, 4], np.int32)
>>> node_label = np.array([0, 1, 0, 1, 0, 1, 0])
>>> train_mask = np.array([True, True, True, True, False, False, False])
>>> val_mask = np.array([False, False, False, False, True, True, True])
>>> g, in_deg, out_deg, node_feat, node_label, train_mask, val_mask,\
>>> test_mask = graph_csr_data(src_idx,dst_idx, n_nodes, n_edges, node_feat, node_label,
...                            train_mask, val_mask, test_mask=None, rerank=True)
>>> print(g[0], g[1])
[2 3 5 6 3 4 0 6] [0 2 4 5 6 7 8 8]
>>> print(node_feat, node_label)
[[8. 7. 6. 5.]
[2. 4. 1. 3.]
[1. 2. 1. 1.]
[8. 6. 4. 6.]
[9. 7. 5. 8.]
[1. 2. 3. 4.]
[1. 3. 2. 4.]] [0 1 0 1 1 0 0]
>>> print(train_mask, val_mask)
[False  True False False  True  True  True] [ True False  True  True False False False]