Source code for jittor_geometric.datasets.molecule_net

import os
import os.path as osp
import re
import warnings
from typing import Callable, Dict, Optional, Tuple, Union
import jittor as jt
from jittor_geometric.data import InMemoryDataset, download_url, extract_gz
from huggingface_hub import hf_hub_download
from jittor_geometric.utils import from_smiles as _from_smiles


[docs] class MoleculeNet(InMemoryDataset): r"""The `MoleculeNet <http://moleculenet.org/datasets-1>`_ benchmark collection from the `"MoleculeNet: A Benchmark for Molecular Machine Learning" <https://arxiv.org/abs/1703.00564>`_ paper, containing datasets from physical chemistry, biophysics and physiology. All datasets come with the additional node and edge features introduced by the `Open Graph Benchmark <https://ogb.stanford.edu/docs/graphprop/>`_. Args: root (str): Root directory where the dataset should be saved. name (str): The name of the dataset (:obj:`"ESOL"`, :obj:`"FreeSolv"`, :obj:`"Lipo"`, :obj:`"PCBA"`, :obj:`"MUV"`, :obj:`"HIV"`, :obj:`"BACE"`, :obj:`"BBBP"`, :obj:`"Tox21"`, :obj:`"ToxCast"`, :obj:`"SIDER"`, :obj:`"ClinTox"`). transform (callable, optional): A function/transform that takes in an :obj:`jittor_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`jittor_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`jittor_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) from_smiles (callable, optional): A custom function that takes a SMILES string and outputs a :obj:`~jittor_geometric.data.Data` object. If not set, defaults to :meth:`~jittor_geometric.utils.from_smiles`. (default: :obj:`None`) **STATS:** .. list-table:: :widths: 20 10 10 10 10 10 :header-rows: 1 * - Name - #graphs - #nodes - #edges - #features - #classes * - ESOL - 1,128 - ~13.3 - ~27.4 - 9 - 1 * - FreeSolv - 642 - ~8.7 - ~16.8 - 9 - 1 * - ClinTox - 1,484 - ~26.1 - ~55.5 - 9 - 2 """ url = 'https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/{}' # Format: name: (display_name, url_name, csv_name, smiles_idx, y_idx) names: Dict[str, Tuple[str, str, str, int, Union[int, slice]]] = { 'esol': ('ESOL', 'delaney-processed.csv', 'delaney-processed', -1, -2), 'freesolv': ('FreeSolv', 'SAMPL.csv', 'SAMPL', 1, 2), 'lipo': ('Lipophilicity', 'Lipophilicity.csv', 'Lipophilicity', 2, 1), 'pcba': ('PCBA', 'pcba.csv.gz', 'pcba', -1, slice(0, 128)), 'muv': ('MUV', 'muv.csv.gz', 'muv', -1, slice(0, 17)), 'hiv': ('HIV', 'HIV.csv', 'HIV', 0, -1), 'bace': ('BACE', 'bace.csv', 'bace', 0, 2), 'bbbp': ('BBBP', 'BBBP.csv', 'BBBP', -1, -2), 'tox21': ('Tox21', 'tox21.csv.gz', 'tox21', -1, slice(0, 12)), 'toxcast': ('ToxCast', 'toxcast_data.csv.gz', 'toxcast_data', 0, slice(1, 618)), 'sider': ('SIDER', 'sider.csv.gz', 'sider', 0, slice(1, 28)), 'clintox': ('ClinTox', 'clintox.csv.gz', 'clintox', 0, slice(1, 3)), }
[docs] def __init__( self, root: str, name: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, from_smiles: Optional[Callable] = None, ) -> None: self.name = name.lower() assert self.name in self.names.keys() self.from_smiles = from_smiles or _from_smiles super().__init__(root, transform, pre_transform, pre_filter) self.data, self.slices = jt.load(self.processed_paths[0])
@property def raw_dir(self) -> str: return osp.join(self.root, self.name, 'raw') @property def processed_dir(self) -> str: return osp.join(self.root, self.name, 'processed') @property def raw_file_names(self) -> str: return f'{self.names[self.name][2]}.csv' @property def processed_file_names(self) -> str: return f'{self.name}.pkl'
[docs] def download(self) -> None: url = self.url.format(self.names[self.name][1]) path = download_url(url, self.raw_dir) if self.names[self.name][1][-2:] == 'gz': extract_gz(path, self.raw_dir) os.unlink(path)
[docs] def process(self) -> None: with open(self.raw_paths[0]) as f: dataset = f.read().split('\n')[1:-1] dataset = [x for x in dataset if len(x) > 0] # Filter empty lines. data_list = [] for line in dataset: line = re.sub(r'\".*\"', '', line) # Replace ".*" strings. values = line.split(',') smiles = values[self.names[self.name][3]] labels = values[self.names[self.name][4]] labels = labels if isinstance(labels, list) else [labels] ys = [float(y) if len(y) > 0 else float('NaN') for y in labels] y = jt.array(ys, dtype='float32').view(1, -1) data = self.from_smiles(smiles) data.y = y if data.num_nodes == 0: warnings.warn(f"Skipping molecule '{smiles}' since it " f"resulted in zero atoms") continue if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) jt.save(self.collate(data_list), self.processed_paths[0])
def __repr__(self) -> str: return f'{self.name}(len={len(self)})'