# Source code for jittor_geometric.nn.conv.gat_conv

'''
Description: GATConv, a graph attention (GAT) layer built on MessagePassingNts.
Author: lusz
Date: 2024-06-26 10:57:06
'''
from typing import Optional, Tuple
from jittor_geometric.typing import Adj, OptVar

import jittor as jt
from jittor import Var
from jittor_geometric.nn.conv import MessagePassingNts
from jittor_geometric.utils import add_remaining_self_loops
from jittor_geometric.utils.num_nodes import maybe_num_nodes

from ..inits import glorot, zeros
from jittor_geometric.data import CSC,CSR
from jittor_geometric.ops import ScatterToEdge,EdgeSoftmax,aggregateWithWeight,ScatterToVertex


class GATConv(MessagePassingNts):
    r"""The graph attention operator from the `"Graph Attention Networks"
    <https://arxiv.org/abs/1710.10903>`_ paper (ICLR 2018).

    Pipeline implemented by :meth:`execute`:

    1. ``vertex_forward``: ``h = relu(x @ W)`` with ``W`` of shape
       ``(in_channels, out_channels)``.
    2. ``scatter_to_edge``: build one row per edge, ``[h_src || h_dst]``.
    3. ``edge_forward``: ``score = leaky_relu([h_src || h_dst] @ a, 0.2)``,
       where ``a`` has shape ``(2 * out_channels, 1)``.
       Then ``alpha = EdgeSoftmax(score, csc)``, and the edge message is
       ``h_src * alpha``.
    4. ``scatter_to_vertex``: aggregate the edge messages back onto vertices.

    NOTE(review): unlike the original GAT formulation, a ReLU is applied to
    the projected node features *before* attention (see ``vertex_forward``).

    NOTE(review): ``improved``, ``cached``, ``add_self_loops`` and
    ``normalize`` are stored as attributes but never read by this class.
    ``bias`` and ``e_num`` are accepted but unused; no bias term is created.
    """

    _cached_edge_index: Optional[Tuple[Var, Var]]
    _cached_csc: Optional[CSC]

    def __init__(self, in_channels: int, out_channels: int, e_num: int,
                 improved: bool = False, cached: bool = False,
                 add_self_loops: bool = True, normalize: bool = True,
                 bias: bool = True, **kwargs):
        """Create the layer.

        Args:
            in_channels: size of each input node feature vector.
            out_channels: size of each output node feature vector.
            e_num: accepted for API compatibility; not used here.
            improved, cached, add_self_loops, normalize: stored on the
                instance; not read by this class.
            bias: accepted for API compatibility; no bias is created.
            **kwargs: forwarded to ``MessagePassingNts`` (``aggr`` defaults
                to ``'add'``).
        """
        kwargs.setdefault('aggr', 'add')
        super(GATConv, self).__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.add_self_loops = add_self_loops
        self.normalize = normalize
        self._cached_edge_index = None
        self._cached_adj_t = None
        # Node projection W: (in_channels, out_channels).
        self.weight = jt.random((in_channels, out_channels))
        # Attention vector a applied to the concatenated [h_src || h_dst]
        # edge features, hence 2 * out_channels rows.
        self.edge_weight = jt.random((2 * out_channels, 1))
        self.reset_parameters()

    def reset_parameters(self):
        """Glorot-initialise W and a, and drop any cached graph structures."""
        glorot(self.weight)
        glorot(self.edge_weight)
        self._cached_adj_t = None
        self._cached_csc = None

    def execute(self, x: Var, csc: CSC) -> Var:
        """Run one GAT layer.

        Args:
            x: node features; presumably (num_nodes, in_channels) -- confirm.
            csc: graph structure in CSC form, consumed by the edge/vertex ops.

        Returns:
            The node features produced by :meth:`propagate`.
        """
        out = self.vertex_forward(x)
        out = self.propagate(x=out, csc=csc)
        return out

    def propagate(self, x, csc):
        """Compute attention-weighted edge messages and aggregate them."""
        e_msg = self.scatter_to_edge(x, csc)
        out = self.edge_forward(e_msg, csc)
        out = self.scatter_to_vertex(out, csc)
        return out

    def scatter_to_edge(self, x, csc) -> Var:
        """Build per-edge features ``[x_src || x_dst]``.

        The result is concatenated along dim 1, so it has twice the feature
        width of ``x``.
        """
        src_feat = ScatterToEdge(x, csc, "src")
        dst_feat = ScatterToEdge(x, csc, "dst")
        return jt.contrib.concat([src_feat, dst_feat], dim=1)

    def edge_forward(self, x, csc) -> Var:
        """Turn ``[h_src || h_dst]`` edge rows into attention-weighted messages.

        Args:
            x: per-edge features from :meth:`scatter_to_edge`; the first
                half of the columns holds the source-node features.
            csc: graph structure passed to ``EdgeSoftmax``.

        Returns:
            ``h_src * alpha``, where ``alpha`` is the edge-softmaxed
            LeakyReLU attention score.
        """
        score = jt.nn.leaky_relu(x @ self.edge_weight, scale=0.2)
        # Presumably normalises scores over each node's incident edges as
        # defined by csc -- see EdgeSoftmax.
        alpha = EdgeSoftmax(score, csc)
        # Integer split point: the concat in scatter_to_edge guarantees an
        # even width, so floor division is exact (no float round-trip).
        half_dim = x.shape[1] // 2
        src_msg = x[:, :half_dim]
        return src_msg * alpha

    def scatter_to_vertex(self, edge, csc) -> Var:
        """Aggregate per-edge messages onto vertices via ``ScatterToVertex``.

        NOTE(review): the ``'src'`` mode flag's exact meaning is defined by
        ``ScatterToVertex``; confirm it targets the intended endpoint.
        """
        out = ScatterToVertex(edge, csc, 'src')
        return out

    def vertex_forward(self, x: Var) -> Var:
        """Project node features with W and apply ReLU."""
        out = x @ self.weight
        out = jt.nn.relu(out)
        return out

    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_channels}, {self.out_channels})'