class dgl.nn.pytorch.conv.CuGraphSAGEConv(in_feats, out_feats, aggregator_type='mean', feat_drop=0.0, bias=True)[source]

Bases: torch.nn.modules.module.Module

An accelerated GraphSAGE layer from "Inductive Representation Learning on Large Graphs" that leverages the highly optimized aggregation primitives in cugraph-ops:

\[ \begin{aligned}h_{\mathcal{N}(i)}^{(l+1)} &= \mathrm{aggregate} \left(\{h_{j}^{(l)}, \forall j \in \mathcal{N}(i) \}\right)\\h_{i}^{(l+1)} &= W \cdot \mathrm{concat} (h_{i}^{(l)}, h_{\mathcal{N}(i)}^{(l+1)})\end{aligned} \]
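To make the two update equations concrete, here is a minimal pure-Python reference sketch of one layer. It is not the cugraph-ops kernel and says nothing about how DGL orders or fuses the computation; the helper names `aggregate` and `sage_layer`, the list-of-lists weight `W` of shape `(out_feats, 2 * in_feats)`, the optional bias `b`, and the `in_nbrs` adjacency dictionary are all hypothetical choices made for illustration.

```python
from typing import Dict, List, Optional, Sequence

Vector = List[float]


def aggregate(vectors: Sequence[Vector], kind: str = "mean") -> Vector:
    """Element-wise aggregation of neighbor features (mean, sum, min or max)."""
    if not vectors:
        raise ValueError("node has no in-neighbors")
    columns = list(zip(*vectors))  # one tuple per feature dimension
    if kind == "mean":
        return [sum(c) / len(c) for c in columns]
    if kind == "sum":
        return [float(sum(c)) for c in columns]
    if kind == "min":
        return [float(min(c)) for c in columns]
    if kind == "max":
        return [float(max(c)) for c in columns]
    raise ValueError(f"unknown aggregator: {kind}")


def sage_layer(
    h: List[Vector],
    in_nbrs: Dict[int, List[int]],
    W: List[Vector],
    b: Optional[Vector] = None,
    kind: str = "mean",
) -> List[Vector]:
    """Hypothetical reference: h_i' = W . concat(h_i, aggregate({h_j : j in N(i)})) (+ b)."""
    out = []
    for i, h_i in enumerate(h):
        # First equation: aggregate the features of i's in-neighbors.
        h_nbr = aggregate([h[j] for j in in_nbrs[i]], kind)
        # Second equation: concatenate (self first, as in the formula), then apply W.
        z = h_i + h_nbr  # list concatenation, length 2 * in_feats
        row = [sum(w * x for w, x in zip(w_row, z)) for w_row in W]
        if b is not None:
            row = [r + bias for r, bias in zip(row, b)]
        out.append(row)
    return out
```

For example, with `h = [[1.0], [3.0]]`, `in_nbrs = {0: [1], 1: [0, 1]}` and `W = [[1.0, 1.0]]`, the mean aggregator gives `[[4.0], [5.0]]`: node 1 averages 1.0 and 3.0 to 2.0, then computes 3.0 + 2.0.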

This module depends on the pylibcugraphops package, which can be installed via conda install -c nvidia "pylibcugraphops>=23.02". Quote the version requirement; otherwise the shell treats >= as an output redirection.
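Since the layer cannot run without this dependency, a caller may want to check for it before constructing the module, for instance to choose a non-accelerated layer instead. The sketch below uses only the standard library; the helper name `has_pylibcugraphops` is hypothetical, and it only checks importability, not the installed version or the presence of a CUDA device.

```python
import importlib.util


def has_pylibcugraphops() -> bool:
    """Return True if the pylibcugraphops module can be found on the import path.

    This does not import the module, check its version (>= 23.02),
    or verify that a CUDA device is available.
    """
    return importlib.util.find_spec("pylibcugraphops") is not None
```

Using `find_spec` rather than a bare `import` avoids executing the package's import-time code just to test for its presence.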

Note

This is an experimental feature.

Parameters

  • in_feats (int) – Input feature size.

  • out_feats (int) – Output feature size.

  • aggregator_type (str) – Aggregator type to use: one of mean, sum, min or max. Default: mean.

  • feat_drop (float) – Dropout rate on features. Default: 0.

  • bias (bool) – If True, adds a learnable bias to the output. Default: True.
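Because W in the formula multiplies the concatenation of h_i and the aggregated neighbor features, it needs 2 × in_feats input columns. Assuming W is the only weight and the bias is a single vector of length out_feats (read off the formula, not a statement about DGL's internal parameter layout), the number of learnable parameters can be sketched as follows; the function name `sage_num_params` is hypothetical. The aggregators and feature dropout add no parameters.

```python
def sage_num_params(in_feats: int, out_feats: int, bias: bool = True) -> int:
    """Parameter count for h_i' = W . concat(h_i, h_N(i)) (+ b).

    Assumes W has shape (out_feats, 2 * in_feats) and, if bias is True,
    a bias vector with out_feats entries.
    """
    weight = out_feats * 2 * in_feats
    return weight + (out_feats if bias else 0)
```

Under these assumptions, the configuration used in the example below, CuGraphSAGEConv(10, 2, 'mean'), gives 2 × 10 × 2 + 2 = 42 parameters.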

Examples

>>> import dgl
>>> import torch
>>> from dgl.nn import CuGraphSAGEConv
>>> device = 'cuda'
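>>> # Six nodes (0-5); node 5 has no incoming edge until self-loops are added
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])).to(device)
>>> g = dgl.add_self_loop(g)
>>> feat = torch.ones(6, 10).to(device)
>>> conv = CuGraphSAGEConv(10, 2, 'mean').to(device)
>>> # Identical input rows give identical output rows; the values
>>> # themselves depend on the random weight initialization.
>>> res = conv(g, feat)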
>>> res
tensor([[-1.1690,  0.1952],
        [-1.1690,  0.1952],
        [-1.1690,  0.1952],
        [-1.1690,  0.1952],
        [-1.1690,  0.1952],
        [-1.1690,  0.1952]], device='cuda:0', grad_fn=<AddmmBackward0>)
forward(g, feat, max_in_degree=None)[source]

Forward computation.

Parameters

  • g (DGLGraph) – The graph.

  • feat (torch.Tensor) – Node features. Shape: \((N, D_{in})\).

  • max_in_degree (int) – Maximum in-degree of destination nodes. It takes effect only when g is a DGLBlock, i.e., a bipartite graph. When g is generated by a neighbor sampler, set it to the corresponding fanout. If not given, max_in_degree is computed on the fly.
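When max_in_degree is omitted, the value has to be derived from the graph. The pure-Python sketch below illustrates only the definition (count incoming edges per destination node, take the maximum) and is not how DGL computes it on the GPU; the function name `max_in_degree` and its edge-list input are assumptions for illustration. With a fixed fanout k, a sampler gives each destination node at most k sampled in-edges, which is why the fanout is a safe value to pass.

```python
from collections import Counter
from typing import Iterable


def max_in_degree(dst: Iterable[int]) -> int:
    """Largest number of incoming edges over all destination nodes.

    dst lists the destination node ID of every edge, like the second
    sequence passed to dgl.graph((src, dst)). Returns 0 for no edges.
    """
    counts = Counter(dst)  # node ID -> number of incoming edges
    return max(counts.values(), default=0)
```

For the graph in the class example after add_self_loop, the destination list is [1, 2, 3, 4, 0, 3] plus one self-loop per node, so max_in_degree([1, 2, 3, 4, 0, 3] + list(range(6))) returns 3: node 3 receives edges from nodes 2 and 5 and from itself.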

Returns

Output node features. Shape: \((N, D_{out})\).

Return type



Reinitialize learnable parameters.