SubgraphX

class dgl.nn.pytorch.explain.SubgraphX(model, num_hops, coef=10.0, high2low=True, num_child=12, num_rollouts=20, node_min=3, shapley_steps=100, log=False)

Bases: torch.nn.modules.module.Module

SubgraphX from On Explainability of Graph Neural Networks via Subgraph Explorations <https://arxiv.org/abs/2102.05152>

It identifies the subgraph of the original graph that is most important to a GNN's graph classification result.

It uses Monte Carlo tree search (MCTS) to explore candidate subgraphs efficiently and uses Shapley values as the measure of subgraph importance.
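As a rough illustration of the Shapley part, the sketch below estimates the Shapley value of a candidate subgraph by permutation sampling, treating the subgraph as a single player. It is a minimal sketch and not DGL's implementation: `mc_shapley` and `value_fn` are hypothetical names, `value_fn` stands in for the model's target-class score on a graph restricted to the given nodes, and the sketch samples over all remaining nodes rather than a restricted neighborhood.

```python
import random


def mc_shapley(players, coalition, value_fn, steps=100, seed=0):
    """Monte Carlo estimate of the Shapley value of `coalition`.

    players   : iterable of all node ids in the graph
    coalition : node ids of the candidate subgraph (treated as one player)
    value_fn  : hypothetical scoring function mapping a set of node ids
                to a float
    steps     : number of sampled permutations
    """
    rng = random.Random(seed)
    coalition = set(coalition)
    others = [p for p in players if p not in coalition]
    total = 0.0
    for _ in range(steps):
        # Shuffle the other players together with a marker (None) for the
        # coalition; the players placed before the marker form the subset
        # that the coalition joins in this sampled permutation.
        order = others + [None]
        rng.shuffle(order)
        subset = set(order[: order.index(None)])
        # Marginal contribution of adding the coalition to that subset.
        total += value_fn(subset | coalition) - value_fn(subset)
    return total / steps


# Toy additive game: each node contributes a fixed weight, so the estimate
# equals the sum of the coalition's weights.
weights = {0: 1.0, 1: 2.0, 2: 3.0, 3: 4.0}
score = mc_shapley(list(weights), {1, 2}, lambda s: sum(weights[n] for n in s))
print(score)  # 5.0
```

The `steps` argument plays the role that `shapley_steps` plays in the constructor: more samples reduce the variance of the estimate at the cost of more model evaluations.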

Parameters
  • model (nn.Module) –

    The GNN model to explain, which performs multiclass graph classification

    • Its forward function must have the form forward(self, graph, nfeat).

    • The output of its forward function is the logits.

  • num_hops (int) – Number of message passing layers in the model

  • coef (float, optional) – Hyperparameter that controls the trade-off between exploration and exploitation in the tree search. A higher value encourages the algorithm to explore relatively unvisited nodes. Default: 10.0

  • high2low (bool, optional) – If True, use the “High2low” strategy for pruning actions, which expands children in the search tree from high-degree nodes to low-degree nodes. Otherwise, use the “Low2high” strategy. Default: True

  • num_child (int, optional) – Number of children to expand when extending a node of the search tree. Default: 12

  • num_rollouts (int, optional) – Number of rollouts for MCTS. Default: 20

  • node_min (int, optional) – Node-count threshold used to decide whether a subgraph is a leaf of the search tree. Default: 3

  • shapley_steps (int, optional) – Number of Monte Carlo sampling steps used to estimate Shapley values. Default: 100

  • log (bool, optional) – If True, log the progress. Default: False
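To make the roles of coef, high2low, and num_child more concrete, here is a minimal sketch of the two search-tree steps they influence. It follows the selection rule described in the SubgraphX paper (average reward plus an exploration bonus scaled by coef); whether DGL's implementation matches it term for term is an assumption. The helper names `selection_score` and `prune_candidates` are hypothetical.

```python
import math


def selection_score(total_reward, visits, immediate_reward,
                    sibling_visits, coef=10.0):
    """Score used to pick which child to descend into during a rollout.

    total_reward     : sum of rewards collected through this child
    visits           : number of times this child has been visited
    immediate_reward : importance score of the child's subgraph
    sibling_visits   : total visits over all children of the parent
    coef             : exploration weight (larger favors rarely visited children)
    """
    # Exploitation term: average reward observed so far.
    q = total_reward / visits if visits > 0 else 0.0
    # Exploration term: shrinks as this child accumulates visits.
    u = coef * immediate_reward * math.sqrt(sibling_visits) / (1 + visits)
    return q + u


def prune_candidates(degrees, num_child=12, high2low=True):
    """Choose which nodes to consider pruning when expanding children.

    degrees : dict mapping node id -> degree within the current subgraph
    Returns at most `num_child` node ids, ordered from high to low degree
    when `high2low` is True and from low to high degree otherwise.
    """
    ordered = sorted(degrees, key=degrees.get, reverse=high2low)
    return ordered[:num_child]


degrees = {0: 3, 1: 1, 2: 2, 3: 5}
print(prune_candidates(degrees, num_child=2, high2low=True))   # [3, 0]
print(prune_candidates(degrees, num_child=2, high2low=False))  # [1, 2]

# With equal rewards, an unvisited child scores higher than a visited one.
print(selection_score(2.0, 4, 0.5, 9))  # 3.5
print(selection_score(0.0, 0, 0.5, 9))  # 15.0
```

In this sketch, raising coef increases the exploration term for every child, which matters most for children with few visits.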

explain_graph(graph, feat, target_class, **kwargs)

Find the most important subgraph from the original graph for the model to classify the graph into the target class.

Parameters
  • graph (DGLGraph) – A homogeneous graph

  • feat (Tensor) – The input node features of shape \((N, D)\), where \(N\) is the number of nodes and \(D\) is the feature size

  • target_class (int) – The target class to explain

  • kwargs (dict) – Additional arguments passed to the GNN model

Returns

Nodes that represent the most important subgraph

Return type

Tensor
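The returned node IDs define an induced subgraph: the edges of the original graph whose endpoints are both in the returned set. In DGL, dgl.node_subgraph(graph, nodes) builds that subgraph directly. The plain-Python sketch below shows the same idea on an edge list; `induced_edges` is a hypothetical helper, and the node list stands in for the returned tensor converted with `.tolist()`.

```python
def induced_edges(edges, keep_nodes):
    """Return the edges whose two endpoints are both in `keep_nodes`.

    edges      : list of (src, dst) pairs from the original graph
    keep_nodes : node ids of the explanation subgraph
    """
    keep = set(keep_nodes)
    return [(u, v) for u, v in edges if u in keep and v in keep]


# A 4-cycle where the explanation keeps nodes 0, 1 and 2.
edges = [(0, 1), (1, 2), (2, 3), (3, 0)]
print(induced_edges(edges, [0, 1, 2]))  # [(0, 1), (1, 2)]
```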

Examples

>>> import torch
>>> import torch.nn as nn
>>> import torch.nn.functional as F
>>> from dgl.data import GINDataset
>>> from dgl.dataloading import GraphDataLoader
>>> from dgl.nn import GraphConv, AvgPooling, SubgraphX
>>> # Define the model
>>> class Model(nn.Module):
...     def __init__(self, in_dim, n_classes, hidden_dim=128):
...         super().__init__()
...         self.conv1 = GraphConv(in_dim, hidden_dim)
...         self.conv2 = GraphConv(hidden_dim, n_classes)
...         self.pool = AvgPooling()
...
...     def forward(self, g, h):
...         h = F.relu(self.conv1(g, h))
...         h = self.conv2(g, h)
...         return self.pool(g, h)
>>> # Load dataset
>>> data = GINDataset('MUTAG', self_loop=True)
>>> dataloader = GraphDataLoader(data, batch_size=64, shuffle=True)
>>> # Train the model for a single epoch
>>> feat_size = data[0][0].ndata['attr'].shape[1]
>>> model = Model(feat_size, data.gclasses)
>>> criterion = nn.CrossEntropyLoss()
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
>>> for bg, labels in dataloader:
...     logits = model(bg, bg.ndata['attr'])
...     loss = criterion(logits, labels)
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Initialize the explainer
>>> explainer = SubgraphX(model, num_hops=2)
>>> # Explain the prediction for graph 0
>>> graph, label = data[0]
>>> graph_feat = graph.ndata.pop("attr")
>>> g_nodes_explain = explainer.explain_graph(graph, graph_feat,
...                                           target_class=label)

forward(*input: Any) → None

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.