Source code for jittor_geometric.datasets.recsys

'''
Description: Recommendation-system datasets (MovieLens-1M/100K, Yelp-2018) and
    top-k ranking metrics (Hit, MRR, NDCG, Recall) for jittor_geometric.
Author: zhengyp
Date: 2025-07-13
'''
import os
import os.path as osp
import zipfile

import pandas as pd
import numpy as np
import jittor as jt
from jittor_geometric.data import InMemoryDataset, Data, download_url

class DataStruct:
    def __init__(self):
        self._data_dict = {}

    def update_tensor(self, name: str, value: jt.Var):
        # Detach from autograd (important) and store on the host as numpy,
        # so that repeated updates can be concatenated cheaply.
        if isinstance(value, jt.Var):
            value = value.detach().numpy()

        if name not in self._data_dict:
            self._data_dict[name] = value
        else:
            if not isinstance(self._data_dict[name], np.ndarray):
                raise ValueError(f"{name} is not a numpy array.")
            self._data_dict[name] = np.concatenate([self._data_dict[name], value], axis=0)

    def get_tensor(self, name: str):
        return self._data_dict.get(name, None)
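
# Usage sketch (illustrative only): DataStruct accumulates per-batch results
# under a name; the first update stores the value (converted to numpy), later
# updates concatenate along dim 0. The name "rec_mat" below is hypothetical.
#
#     collector = DataStruct()
#     collector.update_tensor("rec_mat", jt.rand(4, 11))  # stored as numpy
#     collector.update_tensor("rec_mat", jt.rand(4, 11))  # concatenated -> (8, 11)
#     full = collector.get_tensor("rec_mat")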

class TopkMetric:
    def __init__(self, k=10, metric_decimal_place=4):
        self.decimal_place = metric_decimal_place
        self.topk = [k] if isinstance(k, int) else list(k)

    def used_info(self, rec_mat):
        # rec_mat layout: [n_users, max_k + 1]. The first max_k columns are
        # 0/1 hit flags over the ranked list; the last column holds each
        # user's number of ground-truth positives.
        k = max(self.topk)
        topk_idx, pos_len_list = jt.split(rec_mat, [k, 1], dim=1)
        # Keep everything as jt.Var so the metric_info implementations below
        # can apply Jittor ops directly.
        return topk_idx.bool(), pos_len_list.squeeze(1)

    def topk_result(self, metric, value):
        metric_dict = {}
        avg_result = value.mean(dim=0)
        for k in self.topk:
            key = "{}@{}".format(metric, k)
            metric_dict[key] = round(avg_result[k - 1].item(), self.decimal_place)
        return metric_dict

class Hit(TopkMetric):
    def calculate_metric(self, rec_mat):
        pos_index, _ = self.used_info(rec_mat)
        result = self.metric_info(pos_index)
        metric_dict = self.topk_result("hit", result)
        return metric_dict

    def metric_info(self, pos_index):
        # A user counts as a hit at cutoff k if any of its top-k items is relevant.
        result = jt.cumsum(pos_index.astype(jt.int), dim=1)
        return (result > 0).astype(jt.int)

class MRR(TopkMetric):
    def calculate_metric(self, rec_mat):
        pos_index, _ = self.used_info(rec_mat)
        result = self.metric_info(pos_index)
        metric_dict = self.topk_result("mrr", result)
        return metric_dict

    def metric_info(self, pos_index):
        idxs, _ = jt.argmax(pos_index, dim=1)  # index of the first hit per row
        result = jt.zeros(pos_index.shape, dtype=jt.float)
        for row in range(pos_index.shape[0]):
            idx = int(idxs[row].item())  # convert to a Python scalar
            if pos_index[row, idx].item():
                # Use jt.where instead of slice assignment: every cutoff at or
                # beyond the first hit receives the reciprocal rank 1 / (idx + 1).
                mask = jt.array([i >= idx for i in range(pos_index.shape[1])], dtype=jt.bool)
                result[row] = jt.where(mask, 1.0 / (idx + 1), 0.0)
        return result

class NDCG(TopkMetric):
    def calculate_metric(self, rec_mat):
        pos_index, pos_len = self.used_info(rec_mat)
        result = self.metric_info(pos_index, pos_len)
        metric_dict = self.topk_result("ndcg", result)
        return metric_dict

    def metric_info(self, pos_index, pos_len):
        # NDCG@k = DCG@k / IDCG@k, where DCG@k sums 1 / log2(r + 1) over hit
        # ranks r <= k, and IDCG@k is the DCG of an ideal ranking that places
        # all positives first.
        len_rank = jt.full_like(pos_len, pos_index.shape[1])
        idcg_len = jt.minimum(pos_len, len_rank)

        iranks = jt.arange(1, pos_index.shape[1] + 1).astype(jt.float32).unsqueeze(0).broadcast(pos_index.shape)
        idcg = 1.0 / jt.log2(iranks + 1)
        idcg = jt.cumsum(idcg, dim=1)

        # Beyond the number of positives the ideal DCG stops growing: hold the
        # value reached at rank idcg_len for all remaining cutoffs.
        for row in range(idcg.shape[0]):
            idx = max(int(idcg_len[row].item()), 1)  # clamp to avoid an empty slice
            if idx < idcg.shape[1]:
                idcg[row, idx:] = idcg[row, idx - 1:idx]

        ranks = jt.arange(1, pos_index.shape[1] + 1).astype(jt.float32).unsqueeze(0).broadcast(pos_index.shape)
        dcg = 1.0 / jt.log2(ranks + 1)
        dcg = dcg * pos_index.astype(jt.float32)
        dcg = jt.cumsum(dcg, dim=1)

        result = dcg / idcg
        return result

class Recall(TopkMetric):
    def calculate_metric(self, rec_mat):
        pos_index, pos_len = self.used_info(rec_mat)
        result = self.metric_info(pos_index, pos_len)
        metric_dict = self.topk_result("recall", result)
        return metric_dict

    def metric_info(self, pos_index, pos_len):
        # Recall@k = (#hits within top-k) / (#ground-truth positives).
        return jt.cumsum(pos_index.astype(jt.float32), dim=1) / pos_len.reshape(-1, 1).astype(jt.float32)
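
# Usage sketch (illustrative, made-up numbers): `rec_mat` packs, per user,
# max(topk) 0/1 hit flags over the ranked list followed by one final column
# holding that user's number of ground-truth positives.
#
#     rec_mat = jt.array([
#         [1, 0, 1, 0, 0, 2],   # user 0: hits at ranks 1 and 3, 2 positives
#         [0, 0, 0, 1, 0, 1],   # user 1: hit at rank 4, 1 positive
#     ])
#     print(Hit(k=5).calculate_metric(rec_mat))         # {'hit@5': 1.0}
#     print(Recall(k=[1, 5]).calculate_metric(rec_mat)) # recall@1 = 0.25, recall@5 = 1.0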

def _ensure_dir(path: str):
    """Create the directory if it does not exist."""
    os.makedirs(path, exist_ok=True)

def _zero_based(df: pd.DataFrame, cols):
    """
    Convert the specified integer ID columns to zero-based indexing.
    If the minimum value in the column is 1, subtract 1 from all values.
    """
    for c in cols:
        df[c] = df[c].astype(np.int64)
        if df[c].min() == 1:
            df[c] = df[c] - 1
    return df

def _df_to_edge_index(df):
    src = jt.array(df["user_id"].to_numpy(), dtype="int32")
    dst = jt.array(df["item_id"].to_numpy(), dtype="int32")
    return jt.stack([src, dst], dim=0)  # [2, E]

def _df_to_edge_attr(df, cols):
    feats = []
    for c in cols:
        if c in df.columns:
            feats.append(jt.array(df[c].to_numpy().astype(np.float32)))
    return jt.stack(feats, dim=1) if feats else None
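
# For example (hedged sketch), a two-interaction frame
#     pd.DataFrame({"user_id": [0, 1], "item_id": [3, 5]})
# yields the [2, E] index [[0, 1], [3, 5]]: users in row 0, items in row 1.
# Item IDs stay in their own space; RecSysBase records num_users/num_items
# rather than baking a user offset into the index.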

def split_dataset(interactions, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, shuffle=True, group_by=None, seed=42):
    """
    Split a dataset into train/validation/test sets by given ratios.
    Optionally split within each group defined by `group_by` before combining results.

    Args:
        interactions (pd.DataFrame): The interaction records.
        train_ratio (float): Proportion of the training set.
        val_ratio (float): Proportion of the validation set.
        test_ratio (float): Proportion of the test set.
        shuffle (bool): Whether to shuffle before splitting.
        seed (int): Random seed for reproducibility.
        group_by (str or None): Optional column name to group by before splitting.

    Returns:
        tuple: (train_df, valid_df, test_df, used_interactions)
            train_df: DataFrame for training set
            valid_df: DataFrame for validation set
            test_df: DataFrame for test set
            used_interactions: DataFrame after any shuffling applied
    """
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, "Ratios must sum to 1."

    def _split_one(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Split one DataFrame according to the ratios."""
        n = len(df)
        n_train = int(train_ratio * n)  # round down
        n_val   = int(val_ratio * n)    # round down
        # The remaining samples go to test
        train = df.iloc[:n_train]
        valid = df.iloc[n_train:n_train + n_val]
        test  = df.iloc[n_train + n_val:]
        return train, valid, test

    if group_by is None:
        # No grouping: shuffle the entire dataset, then split
        used = interactions
        if shuffle:
            used = used.sample(frac=1, random_state=seed).reset_index(drop=True)

        train, valid, test = _split_one(used)
        return train, valid, test, used
    else:
        # Grouping mode: shuffle and split within each group separately
        if group_by not in interactions.columns:
            raise KeyError(f"Column not found for group_by: {group_by}")

        rng = np.random.RandomState(seed)

        train_parts, val_parts, test_parts, used_parts = [], [], [], []
        # Keep group order as in original DataFrame (sort=False)
        for _, gdf in interactions.groupby(group_by, sort=False):
            g_used = gdf
            if shuffle:
                # Use a different but reproducible sub-seed for each group
                g_seed = int(rng.randint(0, 2**31 - 1))
                g_used = g_used.sample(frac=1, random_state=g_seed).reset_index(drop=True)

            tr, va, te = _split_one(g_used)
            train_parts.append(tr)
            val_parts.append(va)
            test_parts.append(te)
            used_parts.append(g_used)

        # Concatenate splits from all groups
        train = pd.concat(train_parts, ignore_index=True)
        valid = pd.concat(val_parts, ignore_index=True)
        test  = pd.concat(test_parts,  ignore_index=True)
        used  = pd.concat(used_parts,  ignore_index=True)

        return train, valid, test, used
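
# Usage sketch (illustrative): per-user splitting keeps every user present in
# all three splits when they have enough interactions. `df` is a hypothetical
# ten-row frame with five interactions per user.
#
#     df = pd.DataFrame({
#         "user_id": [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
#         "item_id": [1, 2, 3, 4, 5, 2, 3, 4, 5, 6],
#     })
#     train, valid, test, used = split_dataset(df, group_by="user_id", seed=0)
#     # With 5 rows per user and 0.8/0.1/0.1 ratios, int() rounds down:
#     # 4 train rows per user, 0 validation rows, and the leftover row to test.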

class RecSysBase(InMemoryDataset):
    def __init__(self, root, name, transform=None, pre_transform=None,
                 train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=42,
                 edge_attr_columns=None,
                 group_by='user_id',
                 shuffle=True, with_aux: bool = False):
        self.name = name
        self._train_ratio = train_ratio
        self._val_ratio = val_ratio
        self._test_ratio = test_ratio
        self._seed = seed
        self._edge_attr_columns = edge_attr_columns
        self.group_by = group_by
        self._shuffle = shuffle
        self.with_aux = with_aux
        super().__init__(root, transform, pre_transform)

        proc_path = self.processed_paths[0]
        if not osp.exists(proc_path):
            self.process()
        self.data, self.slices = jt.load(proc_path)

    @property
    def raw_dir(self):
        return osp.join(self.root, self.name, "raw")

    @property
    def processed_dir(self):
        return osp.join(self.root, self.name, "processed")

    @property
    def processed_file_names(self):
        return "data.pkl"

    @property
    def raw_file_names(self):
        raise NotImplementedError

    def read_raw(self):
        """
        Must return either:
          - interactions_df                        (if self.with_aux == False)
          - (interactions_df, {'items': df?, 'users': df?})  (if self.with_aux == True)
        """
        raise NotImplementedError

    def process(self):
        _ensure_dir(self.processed_dir)
        raw = self.read_raw()
        if self.with_aux:
            interactions, aux = raw
            items = aux.get('items', None)
            users = aux.get('users', None)
        else:
            interactions = raw
            items = None
            users = None

        # Zero-based IDs
        interactions = _zero_based(interactions, ["user_id", "item_id"])
        num_users = int(interactions["user_id"].max()) + 1
        num_items = int(interactions["item_id"].max()) + 1
        print(f"Number of users: {num_users}")
        print(f"Number of items: {num_items}")
        print(f"Number of interactions: {len(interactions)}")

        train_df, val_df, test_df, inter_used = split_dataset(
            interactions,
            train_ratio=self._train_ratio,
            val_ratio=self._val_ratio,
            test_ratio=self._test_ratio,
            shuffle=self._shuffle, group_by=self.group_by,
            seed=self._seed
        )
        self.train_df = train_df
        train_edge_index = _df_to_edge_index(train_df)
        val_edge_index = _df_to_edge_index(val_df)
        test_edge_index = _df_to_edge_index(test_df)
        edge_index_all = _df_to_edge_index(inter_used)

        if self._edge_attr_columns is not None:
            train_edge_attr = _df_to_edge_attr(train_df, self._edge_attr_columns)
            val_edge_attr = _df_to_edge_attr(val_df, self._edge_attr_columns)
            test_edge_attr = _df_to_edge_attr(test_df, self._edge_attr_columns)
            edge_attr_all = _df_to_edge_attr(inter_used, self._edge_attr_columns)

        E = len(inter_used)
        n_train, n_val = len(train_df), len(val_df)
        print(f'Edges for train: {n_train}, valid: {n_val}, test: {E - n_train - n_val}')

        data = Data()
        data.edge_index = edge_index_all
        if self._edge_attr_columns is not None:
            data.edge_attr = edge_attr_all
            data.train_edge_attr = train_edge_attr
            data.val_edge_attr = val_edge_attr
            data.test_edge_attr = test_edge_attr
        data.num_users = num_users
        data.num_items = num_items
        data.num_nodes = num_users + num_items
        data.train_edge_index = train_edge_index
        data.val_edge_index = val_edge_index
        data.test_edge_index = test_edge_index

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        jt.save(self.collate([data]), self.processed_paths[0])

        if self.with_aux:
            if items is not None:
                items_path = osp.join(self.processed_dir, "items.csv")
                items.to_csv(items_path, index=False)
                self.items_df = items  # keep available at runtime
            if users is not None:
                users_path = osp.join(self.processed_dir, "users.csv")
                users.to_csv(users_path, index=False)
                self.users_df = users
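
# Subclassing sketch (hypothetical dataset and file layout): only
# `raw_file_names` and `read_raw` must be provided, plus `download` when the
# raw file is fetched remotely. "toy.inter" and its TSV columns are invented
# for this example.
#
#     class ToyDataset(RecSysBase):
#         def __init__(self, root, **kwargs):
#             super().__init__(root=root, name="toy", **kwargs)
#
#         @property
#         def raw_file_names(self):
#             return ["toy.inter"]
#
#         def download(self):
#             pass  # expects toy.inter to be placed in raw_dir manually
#
#         def read_raw(self):
#             df = pd.read_csv(osp.join(self.raw_dir, "toy.inter"), sep="\t")
#             return df[["user_id", "item_id"]]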

class MovieLens1M(RecSysBase):
    """
    MovieLens-1M dataset with auto-download from RecBole:
    https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/MovieLens/ml-1m.zip

    Expected (after extraction) in raw_dir:
      - ml-1m.item
      - ml-1m.user
      - ml-1m.inter

    Files are tab-separated; the first header row is skipped (skiprows=1).
    """
    url = "https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/MovieLens/ml-1m.zip"

    def __init__(self, root, transform=None, pre_transform=None,
                 train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=42,
                 shuffle=True, with_aux: bool = False):
        super().__init__(root=root, name="ml-1m",
                         transform=transform, pre_transform=pre_transform,
                         train_ratio=train_ratio, val_ratio=val_ratio,
                         test_ratio=test_ratio, seed=seed,
                         edge_attr_columns=None,
                         shuffle=shuffle, with_aux=with_aux)
        print('MovieLens1M - with_aux:', self.with_aux)

    @property
    def raw_file_names(self):
        return ["ml-1m.item", "ml-1m.user", "ml-1m.inter"]

    def _raw_exists(self):
        return all(osp.exists(osp.join(self.raw_dir, f)) for f in self.raw_file_names)

    def download(self):
        """Download and extract ml-1m.zip into raw_dir (idempotent)."""
        if self._raw_exists():
            return
        os.makedirs(self.raw_dir, exist_ok=True)
        zip_path = osp.join(self.raw_dir, "ml-1m.zip")
        # Download the zip
        download_url(self.url, self.raw_dir)
        # Extract
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(self.raw_dir)
        candidates = [
            self.raw_dir,
            osp.join(self.raw_dir, "ml-1m"),
            osp.join(self.raw_dir, "MovieLens-1M"),
        ]
        # Try to locate the three files and move them to raw_dir if needed
        for base in candidates:
            item_src = osp.join(base, "ml-1m.item")
            user_src = osp.join(base, "ml-1m.user")
            inter_src = osp.join(base, "ml-1m.inter")
            if all(osp.exists(p) for p in [item_src, user_src, inter_src]):
                # If base is not raw_dir, move files up
                if base != self.raw_dir:
                    for src in [item_src, user_src, inter_src]:
                        dst = osp.join(self.raw_dir, osp.basename(src))
                        if not osp.exists(dst):
                            os.replace(src, dst)
                break
        assert self._raw_exists(), (
            "Failed to locate ml-1m.item/ml-1m.user/ml-1m.inter after extraction."
        )

    def read_raw(self):
        inter_file = osp.join(self.raw_dir, "ml-1m.inter")
        interactions = pd.read_csv(inter_file, sep='\t', engine='python', skiprows=1,
                                   names=['user_id', 'item_id', 'rating', 'timestamp'])
        interactions['user_id'] = interactions['user_id'].astype(int)
        interactions['item_id'] = interactions['item_id'].astype(int)
        if not self.with_aux:
            return interactions

        items = None
        users = None
        print('load properties')
        item_file = osp.join(self.raw_dir, "ml-1m.item")
        if osp.exists(item_file):
            items = pd.read_csv(
                item_file, sep='\t', engine='python', skiprows=1,
                names=['item_id', 'movie_title', 'release_year', 'genre'],
                usecols=[0, 1, 2, 3]
            )
            items["item_id"] = items["item_id"].astype(int)
        user_file = osp.join(self.raw_dir, "ml-1m.user")
        if osp.exists(user_file):
            users = pd.read_csv(
                user_file, sep='\t', engine='python', skiprows=1, header=None,
                names=['user_id', 'age', 'gender', 'occupation', 'zip_code']
            )
            users["user_id"] = users["user_id"].astype(int)
        return interactions, {"items": items, "users": users}
class MovieLens100K(RecSysBase):
    """MovieLens-100K (RecBole processed).

    Downloads:
    https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/MovieLens/ml-100k.zip

    Expected in raw_dir after extraction:
      - ml-100k.item
      - ml-100k.user
      - ml-100k.inter
    """
    url = "https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/MovieLens/ml-100k.zip"

    def __init__(self, root, transform=None, pre_transform=None,
                 train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=42,
                 shuffle=True, with_aux: bool = False):
        super().__init__(root=root, name="ml-100k",
                         transform=transform, pre_transform=pre_transform,
                         train_ratio=train_ratio, val_ratio=val_ratio,
                         test_ratio=test_ratio, seed=seed,
                         shuffle=shuffle, with_aux=with_aux,
                         edge_attr_columns=("rating", "timestamp"))

    @property
    def raw_file_names(self):
        return ["ml-100k.item", "ml-100k.user", "ml-100k.inter"]

    def _raw_exists(self):
        return all(osp.exists(osp.join(self.raw_dir, f)) for f in self.raw_file_names)

    def download(self):
        if self._raw_exists():
            return
        os.makedirs(self.raw_dir, exist_ok=True)
        zip_path = osp.join(self.raw_dir, "ml-100k.zip")
        # Fetch
        download_url(self.url, self.raw_dir)  # -> raw/ml-100k.zip
        # Extract
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(self.raw_dir)
        # Normalize location if the files are nested one directory down
        candidates = [
            self.raw_dir,
            osp.join(self.raw_dir, "ml-100k"),
            osp.join(self.raw_dir, "MovieLens-100K"),
        ]
        for base in candidates:
            paths = [osp.join(base, n) for n in self.raw_file_names]
            if all(osp.exists(p) for p in paths):
                if base != self.raw_dir:
                    for src in paths:
                        dst = osp.join(self.raw_dir, osp.basename(src))
                        if not osp.exists(dst):
                            os.replace(src, dst)
                break
        assert self._raw_exists(), "ml-100k.* not found after extraction."

    def read_raw(self):
        inter_path = osp.join(self.raw_dir, "ml-100k.inter")
        if not osp.exists(inter_path):
            raise FileNotFoundError(f"Missing: {inter_path}")
        interactions = pd.read_csv(
            inter_path, sep='\t', engine='python', skiprows=1,
            names=['user_id', 'item_id', 'rating', 'timestamp']
        )
        interactions['user_id'] = interactions['user_id'].astype(int)
        interactions['item_id'] = interactions['item_id'].astype(int)
        if not self.with_aux:
            return interactions

        items, users = None, None
        item_path = osp.join(self.raw_dir, "ml-100k.item")
        if osp.exists(item_path):
            items = pd.read_csv(
                item_path, sep='\t', engine='python', skiprows=1,
                names=['item_id', 'movie_title', 'release_year', 'genre'],
                usecols=[0, 1, 2, 3]
            )
            items['item_id'] = items['item_id'].astype(int)
        user_path = osp.join(self.raw_dir, "ml-100k.user")
        if osp.exists(user_path):
            users = pd.read_csv(
                user_path, sep='\t', engine='python', skiprows=1, header=None,
                names=['user_id', 'age', 'gender', 'occupation', 'zip_code']
            )
            users['user_id'] = users['user_id'].astype(int)
        return interactions, {"items": items, "users": users}
class Yelp2018(RecSysBase):
    """Yelp-2018 (RecBole processed).

    Downloads:
    https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/Yelp/yelp2018.zip

    Accepts either file naming variant inside the zip:
      - yelp2018.item/.user/.inter (common)
      - yelp-2018.item/.user/.inter (also supported)

    After extraction, files are normalized to yelp-2018.* in raw_dir.
    """
    url = "https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/Yelp/yelp2018.zip"

    @property
    def raw_file_names(self):
        return ["yelp-2018.item", "yelp-2018.user", "yelp-2018.inter"]

    def __init__(self, root, transform=None, pre_transform=None,
                 train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=42,
                 shuffle=True, with_aux: bool = False):
        super().__init__(root=root, name="yelp-2018",
                         transform=transform, pre_transform=pre_transform,
                         train_ratio=train_ratio, val_ratio=val_ratio,
                         test_ratio=test_ratio, seed=seed,
                         shuffle=shuffle, with_aux=with_aux,
                         edge_attr_columns=("rating", "timestamp"))

    def _raw_exists(self):
        return all(osp.exists(osp.join(self.raw_dir, f)) for f in self.raw_file_names)

    def download(self):
        if self._raw_exists():
            return
        os.makedirs(self.raw_dir, exist_ok=True)
        zip_path = osp.join(self.raw_dir, "yelp2018.zip")
        download_url(self.url, self.raw_dir)  # -> raw/yelp2018.zip
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(self.raw_dir)
        # Possible inner names, normalized to the yelp-2018.* targets
        variants = [
            ("yelp2018.item", "yelp-2018.item"),
            ("yelp2018.user", "yelp-2018.user"),
            ("yelp2018.inter", "yelp-2018.inter"),
            ("yelp-2018.item", "yelp-2018.item"),
            ("yelp-2018.user", "yelp-2018.user"),
            ("yelp-2018.inter", "yelp-2018.inter"),
        ]
        # Directories the files may have been extracted into
        candidates = [
            self.raw_dir,
            osp.join(self.raw_dir, "yelp2018"),
            osp.join(self.raw_dir, "yelp-2018"),
            osp.join(self.raw_dir, "Yelp-2018"),
        ]
        for base in candidates:
            for src_name, target_name in variants:
                src = osp.join(base, src_name)
                if osp.exists(src):
                    dst = osp.join(self.raw_dir, target_name)
                    if not osp.exists(dst):
                        os.replace(src, dst)
        assert self._raw_exists(), "yelp-2018.* not found after extraction."

    def read_raw(self):
        inter_path = osp.join(self.raw_dir, "yelp-2018.inter")
        if not osp.exists(inter_path):
            raise FileNotFoundError(f"Missing: {inter_path}")
        interactions = pd.read_csv(
            inter_path, sep='\t', engine='python', skiprows=1,
            names=['user_id', 'item_id', 'rating', 'timestamp',
                   'useful', 'funny', 'cool', 'review_id']
        )
        if not self.with_aux:
            # String business/user IDs: remap to contiguous integer codes
            interactions['user_id'] = interactions['user_id'].astype('category').cat.codes.astype(np.int64)
            interactions['item_id'] = interactions['item_id'].astype('category').cat.codes.astype(np.int64)
            return interactions

        items, users = None, None
        item_path = osp.join(self.raw_dir, "yelp-2018.item")
        if osp.exists(item_path):
            items = pd.read_csv(
                item_path, sep='\t', engine='python', skiprows=1,
                names=['item_id', 'item_name', 'address', 'city', 'state',
                       'postal_code', 'latitude', 'longitude', 'item_stars',
                       'item_review_count', 'is_open', 'categories']
            )
        user_path = osp.join(self.raw_dir, "yelp-2018.user")
        if osp.exists(user_path):
            users = pd.read_csv(
                user_path, sep='\t', engine='python', skiprows=1,
                names=['user_id', 'user_name', 'user_review_count', 'yelping_since',
                       'user_useful', 'user_funny', 'user_cool', 'elite', 'fans',
                       'average_stars', 'compliment_hot', 'compliment_more',
                       'compliment_profile', 'compliment_cute', 'compliment_list',
                       'compliment_note', 'compliment_plain', 'compliment_cool',
                       'compliment_funny', 'compliment_writer', 'compliment_photos']
            )

        # If both aux frames exist, remap using their unique IDs for consistency
        if items is not None and users is not None and not items.empty and not users.empty:
            item_cats = items['item_id'].astype('category').cat.categories
            user_cats = users['user_id'].astype('category').cat.categories
            item_to_id = {cat: idx for idx, cat in enumerate(item_cats)}
            user_to_id = {cat: idx for idx, cat in enumerate(user_cats)}

            interactions['item_id'] = interactions['item_id'].map(item_to_id)
            interactions['user_id'] = interactions['user_id'].map(user_to_id)
            interactions = interactions.dropna(subset=['user_id', 'item_id'])
            interactions['user_id'] = interactions['user_id'].astype(np.int64)
            interactions['item_id'] = interactions['item_id'].astype(np.int64)

            items = items.copy()
            users = users.copy()
            items['item_id'] = items['item_id'].map(item_to_id).astype(np.int64)
            users['user_id'] = users['user_id'].map(user_to_id).astype(np.int64)
        else:
            # Fallback: interactions-only remap
            interactions['user_id'] = interactions['user_id'].astype('category').cat.codes.astype(np.int64)
            interactions['item_id'] = interactions['item_id'].astype('category').cat.codes.astype(np.int64)

        return interactions, {"items": items, "users": users}
if __name__ == '__main__':
    # Interactions only
    ds = MovieLens1M(root="./data", with_aux=False)
    data = ds.get(0)

    # With auxiliary metadata frames:
    # ds_aux = MovieLens1M(root="./data", with_aux=True)
    # data_aux = ds_aux.get(0)

    # ds_aux = MovieLens100K(root="./data", with_aux=True)
    # data_aux = ds_aux.get(0)

    # ds_aux = Yelp2018(root="./data", with_aux=True)
    # data_aux = ds_aux.get(0)