import warnings
from typing import Optional
import jittor
from jittor import Var
import jittor_geometric.typing
from jittor_geometric.typing import jt_scatter
import numpy as np
# Mirror upstream `torch_geometric`'s PyTorch >= 1.12 feature check. The flag
# is currently unused: this port always takes the first branch below, and the
# `else` fallback is kept for parity with upstream.
major, minor, _ = jittor.__version__.split('.', maxsplit=2)
major, minor = int(major), int(minor)
has_pyjittor112 = major > 1 or (major == 1 and minor >= 12)

if True:  # pragma: no cover
warnings.filterwarnings('ignore', '.*is in beta and the API may change.*')
def broadcast(src: Var, ref: Var, dim: int) -> Var:
size = [1] * ref.dim()
size[dim] = -1
return src.view(size).expand_as(ref)
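
    # A minimal sketch of what `broadcast` does (shapes chosen purely for
    # illustration): a 1-D index vector is viewed with singleton dimensions
    # everywhere except `dim`, then expanded to the reference shape:
    #
    #   index = jittor.array([0, 1, 0])        # shape [3]
    #   ref = jittor.ones((3, 4))              # shape [3, 4]
    #   broadcast(index, ref, dim=0)           # viewed as [3, 1] -> [3, 4]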
def scatter(src: Var, index: Var, dim: int = 0,
dim_size: Optional[int] = None, reduce: str = 'sum') -> Var:
r"""Reduces all values from the :obj:`src` tensor at the indices
specified in the :obj:`index` tensor along a given dimension
:obj:`dim`. See the `documentation
<https://pytorch-scatter.readthedocs.io/en/latest/functions/
scatter.html>`__ of the :obj:`torch_scatter` package for more
information.
Args:
            src (jittor.Var): The source tensor.
            index (jittor.Var): The index tensor.
dim (int, optional): The dimension along which to index.
(default: :obj:`0`)
dim_size (int, optional): The size of the output tensor at
dimension :obj:`dim`. If set to :obj:`None`, will create a
minimal-sized output tensor according to
:obj:`index.max() + 1`. (default: :obj:`None`)
reduce (str, optional): The reduce operation (:obj:`"sum"`,
:obj:`"mean"`, :obj:`"mul"`, :obj:`"min"` or :obj:`"max"`).
(default: :obj:`"sum"`)
"""
if index.dim() != 1:
raise ValueError(f"The `index` argument must be one-dimensional "
f"(got {index.dim()} dimensions)")
dim = src.dim() + dim if dim < 0 else dim
if dim < 0 or dim >= src.dim():
raise ValueError(f"The `dim` argument must lay between 0 and "
f"{src.dim() - 1} (got {dim})")
if dim_size is None:
            if index.numel() > 0:
                # Move the index to the CPU first to compute its maximum.
                index_cpu = index.numpy()
                dim_size = int(np.max(index_cpu)) + 1
else:
dim_size = 0
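        # For example, index = [0, 2, 1] yields dim_size = max(index) + 1 = 3,
        # the smallest output size that can hold every target position.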
        # For now, we maintain various different code paths, based on whether
        # the input requires gradients and whether it lies on the CPU/GPU.
        # For example, `torch_scatter` is usually faster than
        # `torch.scatter_reduce` on GPU, while `torch.scatter_reduce` is
        # faster on CPU.
# `torch.scatter_reduce` has a faster forward implementation for
# "min"/"max" reductions since it does not compute additional arg
# indices, but is therefore way slower in its backward implementation.
# More insights can be found in `test/utils/test_scatter.py`.
size = list(src.size())
size[dim] = dim_size
# For "sum" and "mean" reduction, we make use of `scatter_add_`:
        if reduce == 'sum' or reduce == 'add':
            index = broadcast(index, src, dim)
            # `jittor.scatter` mirrors `Tensor.scatter_`, which names the
            # summing reduction 'add', so normalize 'sum' to 'add' here.
            return jittor.scatter(src.new_zeros(size), dim, index, src,
                                  reduce='add')
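        # e.g. src = [1., 2., 3.] with index = [0, 0, 1] sums to
        # out = [3., 3.] (illustrative values).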
if reduce == 'mean':
count = src.new_zeros(dim_size)
count.scatter_add_(0, index, src.new_ones(src.size(dim)))
count = count.clamp(min=1)
index = broadcast(index, src, dim)
out = src.new_zeros(size).scatter_add_(dim, index, src)
return out / broadcast(count, out, dim)
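        # Worked example (illustrative values): src = [1., 3., 5.] with
        # index = [0, 0, 1] gives sums [4., 5.] and counts [2, 1] (clamped to
        # a minimum of 1 for empty buckets), hence out = [2., 5.].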
# For "min" and "max" reduction, we prefer `scatter_reduce_` on CPU or
# in case the input does not require gradients:
if reduce == 'min' or reduce == 'max':
            # Jittor selects the device globally via `jittor.flags.use_cuda`
            # rather than per tensor:
            if (not jittor_geometric.typing.WITH_TORCH_SCATTER
                    or not jittor.flags.use_cuda or not src.requires_grad):

                if jittor.flags.use_cuda and src.requires_grad:
                    warnings.warn(f"The usage of `scatter(reduce='{reduce}')` "
                                  f"can be accelerated via the "
                                  f"'torch-scatter' package, but it was not "
                                  f"found")
index = broadcast(index, src, dim)
return src.new_zeros(size).scatter_reduce_(
dim, index, src, reduce=f'a{reduce}', include_self=False)
return jt_scatter.scatter(src, index, dim, dim_size=dim_size,
reduce=reduce)
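        # e.g. reduce='max' with src = [1., 4., 2.] and index = [0, 0, 1]
        # gives out = [4., 2.] (illustrative values); `include_self=False`
        # keeps the zero-initialized output out of the reduction itself.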
# For "mul" reduction, we prefer `scatter_reduce_` on CPU:
if reduce == 'mul':
            if (not jittor_geometric.typing.WITH_TORCH_SCATTER
                    or not jittor.flags.use_cuda):

                if jittor.flags.use_cuda:
                    warnings.warn(f"The usage of `scatter(reduce='{reduce}')` "
                                  f"can be accelerated via the "
                                  f"'torch-scatter' package, but it was not "
                                  f"found")
index = broadcast(index, src, dim)
# We initialize with `one` here to match `scatter_mul` output:
return src.new_ones(size).scatter_reduce_(
dim, index, src, reduce='prod', include_self=True)
return jt_scatter.scatter(src, index, dim, dim_size=dim_size,
reduce='mul')
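        # e.g. src = [2., 3., 4.] with index = [0, 0, 1] multiplies to
        # out = [6., 4.]; untouched positions keep the initial value of one.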
raise ValueError(f"Encountered invalid `reduce` argument '{reduce}'")
else:
def scatter(src: Var, index: Var, dim: int = 0,
dim_size: Optional[int] = None, reduce: str = 'sum') -> Var:
r"""Reduces all values from the :obj:`src` tensor at the indices
specified in the :obj:`index` tensor along a given dimension
:obj:`dim`. See the `documentation
<https://pytorch-scatter.readthedocs.io/en/latest/functions/
scatter.html>`_ of the :obj:`torch_scatter` package for more
information.
Args:
            src (jittor.Var): The source tensor.
            index (jittor.Var): The index tensor.
dim (int, optional): The dimension along which to index.
(default: :obj:`0`)
dim_size (int, optional): The size of the output tensor at
dimension :obj:`dim`. If set to :obj:`None`, will create a
minimal-sized output tensor according to
:obj:`index.max() + 1`. (default: :obj:`None`)
reduce (str, optional): The reduce operation (:obj:`"sum"`,
:obj:`"mean"`, :obj:`"mul"`, :obj:`"min"` or :obj:`"max"`).
(default: :obj:`"sum"`)
"""
        # Unlike upstream, no hard 'torch-scatter' import check is enforced
        # here; the call is delegated to the bundled `jt_scatter` fallback.
return jt_scatter.scatter(src, index, dim, dim_size=dim_size,
reduce=reduce)
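
# Fallback-path usage sketch (illustrative values): the call signature matches
# the fast path above, with all work delegated to `jt_scatter`:
#
#   out = scatter(src, index, dim=0, dim_size=4, reduce='max')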