import os
import os.path as osp
import shutil

import numpy as np
import pandas as pd

import jittor as jt
from jittor_geometric.data import InMemoryDataset, download_url, decide_download, extract_zip
from jittor_geometric.io import read_graph, read_heterograph, read_node_label_hetero, read_nodesplitidx_split_hetero
class OGBNodePropPredDataset(InMemoryDataset):
r"""The Open Graph Benchmark (OGB) Node Property Prediction Datasets, provided by the OGB team.
These datasets are designed to benchmark large-scale node property prediction tasks on real-world graphs.
This class provides access to various OGB datasets focused on node property prediction tasks. Each dataset contains
nodes representing entities (e.g., papers, products) and edges representing relationships (e.g., citations, co-purchases).
The goal is to predict specific node-level properties, such as categories or timestamps, based on the graph structure
and node features.
Dataset Details:
- **ogbn-arxiv**: A citation network where nodes represent arXiv papers and directed edges indicate citation relationships.
The task is to predict the subject area of each paper based on word2vec features derived from the title and abstract.
- **ogbn-products**: An Amazon product co-purchasing network where nodes represent products and edges indicate frequently
co-purchased products. The task is to classify each product based on its category, with node features based on product descriptions.
- **ogbn-paper100M**: A large-scale citation network where nodes represent research papers and edges indicate citation links.
The node features are derived from word embeddings of the paper abstracts. The task is to predict the subject area of each paper.
These datasets are provided by the Open Graph Benchmark (OGB) team, which aims to facilitate machine learning research
on graphs by offering diverse, large-scale datasets. For more details, visit the OGB website: https://ogb.stanford.edu/.
Args:
name (str): The name of the dataset to load. Options include:
- :obj:`"ogbn-arxiv"`
- :obj:`"ogbn-products"`
- :obj:`"ogbn-paper100M"`
root (str): Root directory where the dataset folder will be stored.
transform (callable, optional): A function/transform that takes in a graph object and returns a transformed version.
The graph object will be transformed on each access. (default: :obj:`None`)
pre_transform (callable, optional): A function/transform that takes in a graph object and returns a transformed version.
The graph object will be transformed before being saved to disk. (default: :obj:`None`)
meta_dict (dict, optional): A dictionary containing meta-information about the dataset.
When provided, it overrides default meta-information, useful for debugging or contributions from external users.
Example:
>>> dataset = OGBNodePropPredDataset(name="ogbn-arxiv", root="path/to/dataset")
>>> data = dataset[0] # Access the first graph object

    Acknowledgment:
        The :class:`OGBNodePropPredDataset` is developed and maintained by the Open Graph Benchmark
        (OGB) team. We sincerely thank the OGB team for their significant contributions to the
        graph machine learning community.
    """
    def __init__(self, name, root='dataset', transform=None, pre_transform=None, meta_dict=None):
        self.name = name  # original dataset name, e.g., ogbn-arxiv
        if meta_dict is None:
            self.dir_name = '_'.join(name.split('-'))
            # Check whether a previously-downloaded folder exists; if so, reuse it.
            if osp.exists(osp.join(root, self.dir_name)):
                self.dir_name = self.dir_name
            self.original_root = root
            self.root = osp.join(root, self.dir_name)
            master = pd.read_csv(osp.join(osp.dirname(__file__), 'master.csv'), index_col=0, keep_default_na=False)
            if self.name not in master:
                error_mssg = 'Invalid dataset name {}.\n'.format(self.name)
                error_mssg += 'Available datasets are as follows:\n'
                error_mssg += '\n'.join(master.keys())
                raise ValueError(error_mssg)
            self.meta_info = master[self.name]
        else:
            self.dir_name = meta_dict['dir_path']
            self.original_root = ''
            self.root = meta_dict['dir_path']
            self.meta_info = meta_dict

        # Version check: if the dataset has already been downloaded, verify that it is the
        # newest version; if it is outdated, notify the user and offer to re-download it.
        if osp.isdir(self.root) and (not osp.exists(osp.join(self.root, 'RELEASE_v' + str(self.meta_info['version']) + '.txt'))):
            print(self.name + ' has been updated.')
            if input('Will you update the dataset now? (y/N)\n').lower() == 'y':
                shutil.rmtree(self.root)

        self.download_name = self.meta_info['download_name']  # name of the downloaded folder, e.g., arxiv
        self.num_tasks = int(self.meta_info['num tasks'])
        self.task_type = self.meta_info['task type']
        self.eval_metric = self.meta_info['eval metric']
        self.__num_classes__ = int(self.meta_info['num classes'])
        self.is_hetero = self.meta_info['is hetero'] == 'True'
        self.binary = self.meta_info['binary'] == 'True'

        super(OGBNodePropPredDataset, self).__init__(self.root, transform, pre_transform)
        self.data, self.slices = jt.load(self.processed_paths[0])
    def get_idx_split(self, split_type=None):
        """Return the standard train/valid/test node split as a dictionary.

        For homogeneous datasets the values are index arrays; for heterogeneous
        datasets the values are dictionaries mapping node types to index arrays.
        """
        if split_type is None:
            split_type = self.meta_info['split']

        path = osp.join(self.root, 'split', split_type)

        # Short-cut: load a cached split_dict.pkl if one exists.
        if osp.isfile(osp.join(path, 'split_dict.pkl')):
            return jt.load(osp.join(path, 'split_dict.pkl'))

        if self.is_hetero:
            train_idx_dict, valid_idx_dict, test_idx_dict = read_nodesplitidx_split_hetero(path)
            for nodetype in train_idx_dict.keys():
                train_idx_dict[nodetype] = jt.array(train_idx_dict[nodetype]).int32()
                valid_idx_dict[nodetype] = jt.array(valid_idx_dict[nodetype]).int32()
                test_idx_dict[nodetype] = jt.array(test_idx_dict[nodetype]).int32()
            return {'train': train_idx_dict, 'valid': valid_idx_dict, 'test': test_idx_dict}
        else:
            train_idx = jt.array(pd.read_csv(osp.join(path, 'train.csv.gz'), compression='gzip', header=None).values.T[0]).int32()
            valid_idx = jt.array(pd.read_csv(osp.join(path, 'valid.csv.gz'), compression='gzip', header=None).values.T[0]).int32()
            test_idx = jt.array(pd.read_csv(osp.join(path, 'test.csv.gz'), compression='gzip', header=None).values.T[0]).int32()
            return {'train': train_idx, 'valid': valid_idx, 'test': test_idx}
    @property
    def num_classes(self):
        return self.__num_classes__

    @property
    def raw_file_names(self):
        if self.binary:
            if self.is_hetero:
                return ['edge_index_dict.npz']
            else:
                return ['data.npz']
        else:
            if self.is_hetero:
                return ['num-node-dict.csv.gz', 'triplet-type-list.csv.gz']
            else:
                file_names = ['edge']
                if self.meta_info['has_node_attr'] == 'True':
                    file_names.append('node-feat')
                if self.meta_info['has_edge_attr'] == 'True':
                    file_names.append('edge-feat')
                return [file_name + '.csv.gz' for file_name in file_names]

    @property
    def processed_file_names(self):
        return 'geometric_data_processed.pkl'
    def download(self):
        url = self.meta_info['url']
        if decide_download(url):
            path = download_url(url, self.original_root)
            extract_zip(path, self.original_root)
            os.unlink(path)
            # Replace any stale dataset folder with the freshly extracted one.
            shutil.rmtree(self.root)
            shutil.move(osp.join(self.original_root, self.download_name), self.root)
        else:
            print('Stop downloading.')
            shutil.rmtree(self.root)
            exit(-1)
    def process(self):
        add_inverse_edge = self.meta_info['add_inverse_edge'] == 'True'

        if self.meta_info['additional node files'] == 'None':
            additional_node_files = []
        else:
            additional_node_files = self.meta_info['additional node files'].split(',')

        if self.meta_info['additional edge files'] == 'None':
            additional_edge_files = []
        else:
            additional_edge_files = self.meta_info['additional edge files'].split(',')

        if self.is_hetero:
            data = read_heterograph(self.raw_dir, add_inverse_edge=add_inverse_edge, additional_node_files=additional_node_files, additional_edge_files=additional_edge_files, binary=self.binary)[0]

            if self.binary:
                tmp = np.load(osp.join(self.raw_dir, 'node-label.npz'))
                node_label_dict = {key: tmp[key] for key in tmp.keys()}
                del tmp
            else:
                node_label_dict = read_node_label_hetero(self.raw_dir)

            # Add the prediction targets per node type.
            data.y_dict = {}
            if 'classification' in self.task_type:
                for nodetype, node_label in node_label_dict.items():
                    # If the labels contain any NaN (i.e., unlabeled nodes), keep them as
                    # floats; otherwise store integer class labels.
                    if np.isnan(node_label).any():
                        data.y_dict[nodetype] = jt.array(node_label).float32()
                    else:
                        data.y_dict[nodetype] = jt.array(node_label).int32()
            else:
                for nodetype, node_label in node_label_dict.items():
                    data.y_dict[nodetype] = jt.array(node_label).float32()
        else:
            data = read_graph(self.raw_dir, add_inverse_edge=add_inverse_edge, additional_node_files=additional_node_files, additional_edge_files=additional_edge_files, binary=self.binary)[0]

            # Add the prediction target.
            if self.binary:
                node_label = np.load(osp.join(self.raw_dir, 'node-label.npz'))['node_label']
            else:
                node_label = pd.read_csv(osp.join(self.raw_dir, 'node-label.csv.gz'), compression='gzip', header=None).values

            if 'classification' in self.task_type:
                # If the labels contain any NaN (i.e., unlabeled nodes), keep them as
                # floats; otherwise store integer class labels.
                if np.isnan(node_label).any():
                    data.y = jt.array(node_label).float32()
                else:
                    data.y = jt.array(node_label).int32()
            else:
                data.y = jt.array(node_label).float32()

        data = data if self.pre_transform is None else self.pre_transform(data)

        print('Saving...')
        jt.save(self.collate([data]), self.processed_paths[0])
    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)

if __name__ == '__main__':
    # Simple smoke test: ogbn-mag exercises the heterogeneous code path.
    dataset = OGBNodePropPredDataset(name='ogbn-mag')
    split_index = dataset.get_idx_split()
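    # A minimal usage sketch: for a heterogeneous dataset such as ogbn-mag, each split
    # entry maps node types to index arrays (see get_idx_split above).
    graph = dataset[0]
    print(graph)
    print(split_index['train'])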