GNNExplainer

class dgl.nn.pytorch.explain.GNNExplainer(model, num_hops, lr=0.01, num_epochs=100, *, alpha1=0.005, alpha2=1.0, beta1=1.0, beta2=0.1, log=True)

Bases: Module

GNNExplainer model from the paper "GNNExplainer: Generating Explanations for Graph Neural Networks".

It identifies compact subgraph structures and small subsets of node features that play a critical role in GNN-based node classification and graph classification.

To generate an explanation, it learns an edge mask \(M\) and a feature mask \(F\) by minimizing the following objective function:

\[l(y, \hat{y}) + \alpha_1 \|M\|_1 + \alpha_2 H(M) + \beta_1 \|F\|_1 + \beta_2 H(F)\]

where \(l\) is the loss function, \(y\) is the original model prediction, \(\hat{y}\) is the model prediction with the edge and feature masks applied, and \(H\) is the entropy function.
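
As a rough illustration of how the terms combine, the following pure-Python sketch evaluates the objective for given mask values. It assumes both masks are already in the range (0, 1), computes the L1 norms as plain sums (mask entries are non-negative), and takes \(H\) to be the mean element-wise binary entropy. The names binary_entropy and explanation_objective are hypothetical and are not part of DGL. Note that the beta1 description below refers to the mean of the feature mask, so the library's size term may be normalized differently from this literal reading of the formula.

```python
import math
from typing import Sequence


def binary_entropy(mask: Sequence[float], eps: float = 1e-15) -> float:
    """Mean element-wise binary entropy of mask values in (0, 1)."""
    total = 0.0
    for m in mask:
        # eps guards against log(0) for entries at exactly 0 or 1.
        total += -m * math.log(m + eps) - (1 - m) * math.log(1 - m + eps)
    return total / len(mask)


def explanation_objective(
    pred_loss: float,
    edge_mask: Sequence[float],
    feat_mask: Sequence[float],
    alpha1: float = 0.005,
    alpha2: float = 1.0,
    beta1: float = 1.0,
    beta2: float = 0.1,
) -> float:
    """l(y, y_hat) + alpha1*||M||_1 + alpha2*H(M) + beta1*||F||_1 + beta2*H(F)."""
    # Mask entries are non-negative, so each L1 norm is just a sum.
    edge_size = sum(edge_mask)
    feat_size = sum(feat_mask)
    return (
        pred_loss
        + alpha1 * edge_size
        + alpha2 * binary_entropy(edge_mask)
        + beta1 * feat_size
        + beta2 * binary_entropy(feat_mask)
    )
```

Entries near 0.5 give the largest entropy, so the \(\alpha_2\) and \(\beta_2\) terms favor masks whose entries commit to values close to 0 or 1.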

Parameters:
  • model (nn.Module) –

    The GNN model to explain.

    • The required arguments of its forward function are graph and feat. The latter is for input node features.

    • It should also accept an optional eweight argument for edge weights and multiply the messages by it during message passing.

    • The output of its forward function is the logits for the predicted node/graph classes.

    See also the examples in explain_node() and explain_graph().

  • num_hops (int) – The number of hops for GNN information aggregation.

  • lr (float, optional) – The learning rate to use. Defaults to 0.01.

  • num_epochs (int, optional) – The number of epochs to train. Defaults to 100.

  • alpha1 (float, optional) – A higher value makes the explanation edge mask sparser by decreasing the sum of the edge mask. Defaults to 0.005.

  • alpha2 (float, optional) – A higher value pushes the explanation edge mask values toward 0 or 1 by decreasing the entropy of the edge mask. Defaults to 1.0.

  • beta1 (float, optional) – A higher value makes the explanation node feature mask sparser by decreasing the mean of the node feature mask. Defaults to 1.0.

  • beta2 (float, optional) – A higher value pushes the explanation node feature mask values toward 0 or 1 by decreasing the entropy of the node feature mask. Defaults to 0.1.

  • log (bool, optional) – If True, log the computation process. Defaults to True.
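
The eweight requirement means each message is scaled by the weight of its edge before aggregation, which is how a learned edge mask can switch edges on or off inside the model. The pure-Python sketch below shows sum aggregation over an edge list with and without edge weights, mirroring the fn.copy_u and fn.u_mul_e branches in the examples below. The function propagate, the (src, dst) edge-list representation, and the scalar node features are illustrative assumptions, not DGL APIs.

```python
from typing import List, Optional, Sequence, Tuple


def propagate(
    num_nodes: int,
    edges: Sequence[Tuple[int, int]],  # (src, dst) pairs; edge ID = list position
    h: Sequence[float],  # one scalar feature per node, for simplicity
    eweight: Optional[Sequence[float]] = None,  # one weight per edge
) -> List[float]:
    """Sum the incoming messages at each destination node."""
    out = [0.0] * num_nodes
    for e, (src, dst) in enumerate(edges):
        msg = h[src]  # copy the source feature (like fn.copy_u)
        if eweight is not None:
            # Scale the message by the weight of edge e (like fn.u_mul_e).
            msg *= eweight[e]
        out[dst] += msg  # sum reduction (like fn.sum)
    return out


edges = [(0, 2), (1, 2)]
h = [1.0, 10.0, 0.0]
print(propagate(3, edges, h))              # [0.0, 0.0, 11.0]
print(propagate(3, edges, h, [1.0, 0.0]))  # [0.0, 0.0, 1.0]
```

A weight of 0 removes an edge's contribution entirely, so a mask value near 0 effectively drops that edge from the computation.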

explain_graph(graph, feat, **kwargs)

Learn and return a node feature mask and an edge mask that play a crucial role in explaining the prediction made by the GNN for a graph.

Parameters:
  • graph (DGLGraph) – A homogeneous graph.

  • feat (Tensor) – The input feature of shape \((N, D)\). \(N\) is the number of nodes, and \(D\) is the feature size.

  • kwargs (dict) – Additional arguments passed to the GNN model. Tensors whose first dimension is the number of nodes or edges will be assumed to be node/edge features.

Returns:

  • feat_mask (Tensor) – Learned feature importance mask of shape \((D)\), where \(D\) is the feature size. The values are within the range \((0, 1)\); higher values indicate greater importance.

  • edge_mask (Tensor) – Learned importance mask of the edges in the graph, of shape \((E)\), where \(E\) is the number of edges in the graph. The values are within the range \((0, 1)\); higher values indicate greater importance.
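
Since edge_mask holds one importance score per edge, one common way to read it is to keep only the highest-scoring edges. A minimal pure-Python sketch of top-k selection follows; the helper top_k_edges and the choice of k are illustrative assumptions. With tensors, torch.topk(edge_mask, k) returns the same indices (up to the order of ties), and the selected IDs could then be passed to dgl.edge_subgraph to materialize the explanation subgraph.

```python
from typing import List, Sequence


def top_k_edges(edge_mask: Sequence[float], k: int) -> List[int]:
    """Return the IDs of the k edges with the largest mask values, most important first."""
    # Edge IDs are positions in edge_mask; sort them by score, descending.
    order = sorted(range(len(edge_mask)), key=lambda e: edge_mask[e], reverse=True)
    return order[:k]


print(top_k_edges([0.1, 0.9, 0.4, 0.8], 2))  # [1, 3]
```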

Examples

>>> import dgl.function as fn
>>> import torch
>>> import torch.nn as nn
>>> from dgl.data import GINDataset
>>> from dgl.dataloading import GraphDataLoader
>>> from dgl.nn import AvgPooling, GNNExplainer
>>> # Load dataset
>>> data = GINDataset('MUTAG', self_loop=True)
>>> dataloader = GraphDataLoader(data, batch_size=64, shuffle=True)
>>> # Define a model
>>> class Model(nn.Module):
...     def __init__(self, in_feats, out_feats):
...         super(Model, self).__init__()
...         self.linear = nn.Linear(in_feats, out_feats)
...         self.pool = AvgPooling()
...
...     def forward(self, graph, feat, eweight=None):
...         with graph.local_scope():
...             feat = self.linear(feat)
...             graph.ndata['h'] = feat
...             if eweight is None:
...                 graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
...             else:
...                 graph.edata['w'] = eweight
...                 graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))
...             return self.pool(graph, graph.ndata['h'])
>>> # Train the model
>>> feat_size = data[0][0].ndata['attr'].shape[1]
>>> model = Model(feat_size, data.gclasses)
>>> criterion = nn.CrossEntropyLoss()
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
>>> for bg, labels in dataloader:
...     logits = model(bg, bg.ndata['attr'])
...     loss = criterion(logits, labels)
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Explain the prediction for graph 0
>>> explainer = GNNExplainer(model, num_hops=1)
>>> g, _ = data[0]
>>> features = g.ndata['attr']
>>> feat_mask, edge_mask = explainer.explain_graph(g, features)
>>> feat_mask
tensor([0.2362, 0.2497, 0.2622, 0.2675, 0.2649, 0.2962, 0.2533])
>>> edge_mask
tensor([0.2154, 0.2235, 0.8325, ..., 0.7787, 0.1735, 0.1847])

explain_node(node_id, graph, feat, **kwargs)

Learn and return a node feature mask and a subgraph that play a crucial role in explaining the prediction made by the GNN for node node_id.

Parameters:
  • node_id (int) – The node to explain.

  • graph (DGLGraph) – A homogeneous graph.

  • feat (Tensor) – The input feature of shape \((N, D)\). \(N\) is the number of nodes, and \(D\) is the feature size.

  • kwargs (dict) – Additional arguments passed to the GNN model. Tensors whose first dimension is the number of nodes or edges will be assumed to be node/edge features.
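
Because explain_node works on an extracted subgraph, per-node and per-edge tensors in kwargs presumably have to be restricted to that subgraph as well. The sketch below illustrates the first-dimension rule with plain Python lists: a value whose length equals the number of nodes is indexed by the subgraph's original node IDs, and one whose length equals the number of edges by its original edge IDs. This is an assumption about how the rule is applied, and slice_kwargs is a hypothetical helper, not a DGL function.

```python
from typing import Any, Dict, Sequence


def slice_kwargs(
    kwargs: Dict[str, Any],
    num_nodes: int,
    num_edges: int,
    sg_nodes: Sequence[int],  # original IDs of the subgraph's nodes
    sg_edges: Sequence[int],  # original IDs of the subgraph's edges
) -> Dict[str, Any]:
    sliced = {}
    for key, value in kwargs.items():
        if isinstance(value, list) and len(value) == num_nodes:
            # Treated as a node feature: keep the rows of subgraph nodes.
            # (If num_nodes == num_edges, this branch wins in this sketch.)
            sliced[key] = [value[n] for n in sg_nodes]
        elif isinstance(value, list) and len(value) == num_edges:
            # Treated as an edge feature: keep the rows of subgraph edges.
            sliced[key] = [value[e] for e in sg_edges]
        else:
            # Anything else is passed to the model unchanged.
            sliced[key] = value
    return sliced


kw = {"node_w": [10, 11, 12, 13, 14], "edge_w": [0.1, 0.2, 0.3], "scale": 2.0}
print(slice_kwargs(kw, 5, 3, [1, 2], [2]))
# {'node_w': [11, 12], 'edge_w': [0.3], 'scale': 2.0}
```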

Returns:

  • new_node_id (Tensor) – The ID of the input center node in the returned subgraph sg.

  • sg (DGLGraph) – The subgraph induced on the k-hop in-neighborhood of the input center node, where k is num_hops.

  • feat_mask (Tensor) – Learned node feature importance mask of shape \((D)\), where \(D\) is the feature size. The values are within the range \((0, 1)\); higher values indicate greater importance.

  • edge_mask (Tensor) – Learned importance mask of the edges in the subgraph, of shape \((E)\), where \(E\) is the number of edges in sg. The values are within the range \((0, 1)\); higher values indicate greater importance.
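
The k-hop in-neighborhood contains every node from which node_id can be reached in at most k = num_hops steps along edge directions; for a model with num_hops message-passing layers, only these nodes can influence the center node's output. The breadth-first sketch below walks edges backwards from the center. khop_in_nodes is a hypothetical helper (DGL itself provides dgl.khop_in_subgraph for this extraction), and it assumes the subgraph's nodes are relabeled in ascending order of original ID, which agrees with the example below, where node 10 among nodes 9 to 12 becomes node 1. The induced subgraph then keeps every edge whose endpoints are both in the returned set.

```python
from collections import deque
from typing import Dict, List, Sequence, Tuple


def khop_in_nodes(
    edges: Sequence[Tuple[int, int]],  # (src, dst) pairs
    center: int,
    k: int,
) -> Tuple[List[int], int]:
    """Return sorted original IDs within k in-hops of center, and center's new ID."""
    # preds[dst] lists every src with an edge src -> dst.
    preds: Dict[int, List[int]] = {}
    for src, dst in edges:
        preds.setdefault(dst, []).append(src)
    dist = {center: 0}
    queue = deque([center])
    while queue:
        node = queue.popleft()
        if dist[node] == k:
            continue  # do not expand beyond k hops
        for src in preds.get(node, []):
            if src not in dist:
                dist[src] = dist[node] + 1
                queue.append(src)
    nodes = sorted(dist)
    return nodes, nodes.index(center)


edges = [(0, 1), (1, 2), (2, 3), (4, 3)]
print(khop_in_nodes(edges, 3, 1))  # ([2, 3, 4], 1)
print(khop_in_nodes(edges, 3, 2))  # ([1, 2, 3, 4], 2)
```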

Examples

>>> import dgl
>>> import dgl.function as fn
>>> import torch
>>> import torch.nn as nn
>>> from dgl.data import CoraGraphDataset
>>> from dgl.nn import GNNExplainer
>>> # Load dataset
>>> data = CoraGraphDataset()
>>> g = data[0]
>>> features = g.ndata['feat']
>>> labels = g.ndata['label']
>>> train_mask = g.ndata['train_mask']
>>> # Define a model
>>> class Model(nn.Module):
...     def __init__(self, in_feats, out_feats):
...         super(Model, self).__init__()
...         self.linear = nn.Linear(in_feats, out_feats)
...
...     def forward(self, graph, feat, eweight=None):
...         with graph.local_scope():
...             feat = self.linear(feat)
...             graph.ndata['h'] = feat
...             if eweight is None:
...                 graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
...             else:
...                 graph.edata['w'] = eweight
...                 graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))
...             return graph.ndata['h']
>>> # Train the model
>>> model = Model(features.shape[1], data.num_classes)
>>> criterion = nn.CrossEntropyLoss()
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
>>> for epoch in range(10):
...     logits = model(g, features)
...     loss = criterion(logits[train_mask], labels[train_mask])
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Explain the prediction for node 10
>>> explainer = GNNExplainer(model, num_hops=1)
>>> new_center, sg, feat_mask, edge_mask = explainer.explain_node(10, g, features)
>>> new_center
tensor([1])
>>> sg.num_edges()
12
>>> # Old IDs of the nodes in the subgraph
>>> sg.ndata[dgl.NID]
tensor([ 9, 10, 11, 12])
>>> # Old IDs of the edges in the subgraph
>>> sg.edata[dgl.EID]
tensor([51, 53, 56, 48, 52, 57, 47, 50, 55, 46, 49, 54])
>>> feat_mask
tensor([0.2638, 0.2738, 0.3039,  ..., 0.2794, 0.2643, 0.2733])
>>> edge_mask
tensor([0.0937, 0.1496, 0.8287, 0.8132, 0.8825, 0.8515, 0.8146, 0.0915, 0.1145,
        0.9011, 0.1311, 0.8437])
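
edge_mask is aligned with the edges of sg, not of g. To view it on the original graph, one can scatter each score to the edge ID stored in sg.edata[dgl.EID], leaving edges outside the neighborhood at 0. The pure-Python sketch below uses the first three IDs and scores from the example output above and an illustrative edge count of 60 (not Cora's actual size); scatter_edge_mask is a hypothetical helper. With tensors, the same effect is full = torch.zeros(g.num_edges()) followed by full[sg.edata[dgl.EID]] = edge_mask.

```python
from typing import List, Sequence


def scatter_edge_mask(
    num_edges: int,
    sg_eids: Sequence[int],  # original IDs of the subgraph's edges
    edge_mask: Sequence[float],  # one score per subgraph edge, same order
) -> List[float]:
    full = [0.0] * num_edges  # edges outside the subgraph get score 0
    for eid, score in zip(sg_eids, edge_mask):
        full[eid] = score
    return full


full = scatter_edge_mask(60, [51, 53, 56], [0.0937, 0.1496, 0.8287])
print(full[56])  # 0.8287
print(full[52])  # 0.0
```
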

forward(*input: Any) → None

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.