Source code for mindspore.dataset.engine.graphdata

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
graphdata.py supports loading graph dataset for GNN network training,
and provides operations related to graph data.
"""
import numpy as np
from mindspore._c_dataengine import Graph
from mindspore._c_dataengine import Tensor

from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \
    check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \
    check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, check_gnn_random_walk


[docs]class GraphData: """ Reads the graph dataset used for GNN training from the shared file and database. Args: dataset_file (str): One of file names in dataset. num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel (default=None). """ @check_gnn_graphdata def __init__(self, dataset_file, num_parallel_workers=None): self._dataset_file = dataset_file if num_parallel_workers is None: num_parallel_workers = 1 self._graph = Graph(dataset_file, num_parallel_workers)
[docs] @check_gnn_get_all_nodes def get_all_nodes(self, node_type): """ Get all nodes in the graph. Args: node_type (int): Specify the type of node. Returns: numpy.ndarray: array of nodes. Examples: >>> import mindspore.dataset as ds >>> data_graph = ds.GraphData('dataset_file', 2) >>> nodes = data_graph.get_all_nodes(0) Raises: TypeError: If `node_type` is not integer. """ return self._graph.get_all_nodes(node_type).as_array()
[docs] @check_gnn_get_all_edges def get_all_edges(self, edge_type): """ Get all edges in the graph. Args: edge_type (int): Specify the type of edge. Returns: numpy.ndarray: array of edges. Examples: >>> import mindspore.dataset as ds >>> data_graph = ds.GraphData('dataset_file', 2) >>> nodes = data_graph.get_all_edges(0) Raises: TypeError: If `edge_type` is not integer. """ return self._graph.get_all_edges(edge_type).as_array()
[docs] @check_gnn_get_nodes_from_edges def get_nodes_from_edges(self, edge_list): """ Get nodes from the edges. Args: edge_list (list or numpy.ndarray): The given list of edges. Returns: numpy.ndarray: array of nodes. Raises: TypeError: If `edge_list` is not list or ndarray. """ return self._graph.get_nodes_from_edges(edge_list).as_array()
[docs] @check_gnn_get_all_neighbors def get_all_neighbors(self, node_list, neighbor_type): """ Get `neighbor_type` neighbors of the nodes in `node_list`. Args: node_list (list or numpy.ndarray): The given list of nodes. neighbor_type (int): Specify the type of neighbor. Returns: numpy.ndarray: array of nodes. Examples: >>> import mindspore.dataset as ds >>> data_graph = ds.GraphData('dataset_file', 2) >>> nodes = data_graph.get_all_nodes(0) >>> neighbors = data_graph.get_all_neighbors(nodes, 0) Raises: TypeError: If `node_list` is not list or ndarray. TypeError: If `neighbor_type` is not integer. """ return self._graph.get_all_neighbors(node_list, neighbor_type).as_array()
[docs] @check_gnn_get_sampled_neighbors def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types): """ Get sampled neighbor information, maximum support 6-hop sampling. Args: node_list (list or numpy.ndarray): The given list of nodes. neighbor_nums (list or numpy.ndarray): Number of neighbors sampled per hop. neighbor_types (list or numpy.ndarray): Neighbor type sampled per hop. Returns: numpy.ndarray: array of nodes. Examples: >>> import mindspore.dataset as ds >>> data_graph = ds.GraphData('dataset_file', 2) >>> nodes = data_graph.get_all_nodes(0) >>> neighbors = data_graph.get_all_neighbors(nodes, [2, 2], [0, 0]) Raises: TypeError: If `node_list` is not list or ndarray. TypeError: If `neighbor_nums` is not list or ndarray. TypeError: If `neighbor_types` is not list or ndarray. """ return self._graph.get_sampled_neighbors( node_list, neighbor_nums, neighbor_types).as_array()
[docs] @check_gnn_get_neg_sampled_neighbors def get_neg_sampled_neighbors(self, node_list, neg_neighbor_num, neg_neighbor_type): """ Get `neg_neighbor_type` negative sampled neighbors of the nodes in `node_list`. Args: node_list (list or numpy.ndarray): The given list of nodes. neg_neighbor_num (int): Number of neighbors sampled. neg_neighbor_type (int): Specify the type of negative neighbor. Returns: numpy.ndarray: array of nodes. Examples: >>> import mindspore.dataset as ds >>> data_graph = ds.GraphData('dataset_file', 2) >>> nodes = data_graph.get_all_nodes(0) >>> neg_neighbors = data_graph.get_neg_sampled_neighbors(nodes, 5, 0) Raises: TypeError: If `node_list` is not list or ndarray. TypeError: If `neg_neighbor_num` is not integer. TypeError: If `neg_neighbor_type` is not integer. """ return self._graph.get_neg_sampled_neighbors( node_list, neg_neighbor_num, neg_neighbor_type).as_array()
[docs] @check_gnn_get_node_feature def get_node_feature(self, node_list, feature_types): """ Get `feature_types` feature of the nodes in `node_list`. Args: node_list (list or numpy.ndarray): The given list of nodes. feature_types (list or ndarray): The given list of feature types. Returns: numpy.ndarray: array of features. Examples: >>> import mindspore.dataset as ds >>> data_graph = ds.GraphData('dataset_file', 2) >>> nodes = data_graph.get_all_nodes(0) >>> features = data_graph.get_node_feature(nodes, [1]) Raises: TypeError: If `node_list` is not list or ndarray. TypeError: If `feature_types` is not list or ndarray. """ if isinstance(node_list, list): node_list = np.array(node_list, dtype=np.int32) return [ t.as_array() for t in self._graph.get_node_feature( Tensor(node_list), feature_types)]
[docs] def graph_info(self): """ Get the meta information of the graph, including the number of nodes, the type of nodes, the feature information of nodes, the number of edges, the type of edges, and the feature information of edges. Returns: Dict: Meta information of the graph. The key is node_type, edge_type, node_num, edge_num, node_feature_type and edge_feature_type. """ return self._graph.graph_info()
[docs] @check_gnn_random_walk def random_walk( self, target_nodes, meta_path, step_home_param=1.0, step_away_param=1.0, default_node=-1): """ Random walk in nodes. Args: target_nodes (list[int]): Start node list in random walk meta_path (list[int]): node type for each walk step step_home_param (float): return hyper parameter in node2vec algorithm step_away_param (float): inout hyper parameter in node2vec algorithm default_node (int): default node if no more neighbors found Returns: numpy.ndarray: array of nodes. Examples: >>> import mindspore.dataset as ds >>> data_graph = ds.GraphData('dataset_file', 2) >>> nodes = data_graph.random_walk([1,2], [1,2,1,2,1]) Raises: TypeError: If `target_nodes` is not list or ndarray. TypeError: If `meta_path` is not list or ndarray. """ return self._graph.random_walk(target_nodes, meta_path, step_home_param, step_away_param, default_node).as_array()