"""Torch Module for SubgraphX"""
import math
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
from .... import to_heterogeneous, to_homogeneous
from ....base import NID
from ....convert import to_networkx
from ....subgraph import node_subgraph
from ....transforms.functional import remove_nodes
# Public explainers exported by this module.
__all__ = ["SubgraphX", "HeteroSubgraphX"]
class MCTSNode:
    r"""Monte Carlo Tree Search Node

    Parameters
    ----------
    nodes : Tensor or dict[str, Tensor]
        The node IDs of the graph that are associated with this tree node.
        :class:`SubgraphX` passes a Tensor of node IDs, while
        :class:`HeteroSubgraphX` passes a dictionary mapping each node type
        to a Tensor of node IDs.
    """

    def __init__(self, nodes):
        self.nodes = nodes
        # Number of times MCTS has selected this node during rollouts.
        self.num_visit = 0
        # Sum of rewards back-propagated through this node by rollouts.
        self.total_reward = 0.0
        # Shapley value of the subgraph; 0.0 means "not computed yet".
        self.immediate_reward = 0.0
        # Child tree nodes, filled lazily when the node is first expanded.
        self.children = []

    def __repr__(self):
        r"""Get the string representation of the node.

        Returns
        -------
        str
            The string representation of the node's ``nodes``.
        """
        return str(self.nodes)
class SubgraphX(nn.Module):
    r"""SubgraphX from `On Explainability of Graph Neural Networks via Subgraph
    Explorations <https://arxiv.org/abs/2102.05152>`__

    It identifies the most important subgraph from the original graph that
    plays a critical role in GNN-based graph classification.

    It employs Monte Carlo tree search (MCTS) in efficiently exploring
    different subgraphs for explanation and uses Shapley values as the measure
    of subgraph importance.

    Parameters
    ----------
    model : nn.Module
        The GNN model to explain that tackles multiclass graph classification

        * Its forward function must have the form
          :attr:`forward(self, graph, nfeat)`.
        * The output of its forward function is the logits.
    num_hops : int
        Number of message passing layers in the model
    coef : float, optional
        This hyperparameter controls the trade-off between exploration and
        exploitation. A higher value encourages the algorithm to explore
        relatively unvisited nodes. Default: 10.0
    high2low : bool, optional
        If True, it will use the "High2low" strategy for pruning actions,
        expanding children nodes from high degree to low degree when extending
        the children nodes in the search tree. Otherwise, it will use the
        "Low2high" strategy. Default: True
    num_child : int, optional
        This is the number of children nodes to expand when extending the
        children nodes in the search tree. Default: 12
    num_rollouts : int, optional
        This is the number of rollouts for MCTS. Default: 20
    node_min : int, optional
        This is the threshold to define a leaf node based on the number of
        nodes in a subgraph. Default: 3
    shapley_steps : int, optional
        This is the number of steps for Monte Carlo sampling in estimating
        Shapley values. Default: 100
    log : bool, optional
        If True, it will log the progress. Default: False
    """

    def __init__(
        self,
        model,
        num_hops,
        coef=10.0,
        high2low=True,
        num_child=12,
        num_rollouts=20,
        node_min=3,
        shapley_steps=100,
        log=False,
    ):
        super().__init__()
        self.num_hops = num_hops
        self.coef = coef
        self.high2low = high2low
        self.num_child = num_child
        self.num_rollouts = num_rollouts
        self.node_min = node_min
        self.shapley_steps = shapley_steps
        self.log = log

        self.model = model

    @staticmethod
    def _node_key(nodes):
        r"""Return an exact, hashable key identifying a set of node IDs.

        ``str(tensor)`` is not used because PyTorch abbreviates tensors
        longer than its print threshold with ``...``, which would make
        different node sets (e.g. a child and its parent) share one key.
        """
        return tuple(sorted(nodes.tolist()))

    def shapley(self, subgraph_nodes):
        r"""Compute Shapley value with Monte Carlo approximation.

        For each sample, a random coalition :math:`S_i` is drawn from the
        ``(num_hops - 1)``-hop neighborhood of the subgraph. Node features
        outside the kept set are zeroed, and the target-class probability
        (softmax of the model logits) is compared with and without the
        subgraph nodes.

        Parameters
        ----------
        subgraph_nodes : tensor
            The tensor node ids of the subgraph that are associated with this
            tree node

        Returns
        -------
        float
            Shapley value
        """
        num_nodes = self.graph.num_nodes()
        subgraph_nodes = subgraph_nodes.tolist()

        # Obtain neighboring nodes of the subgraph g_i, P'.
        local_region = subgraph_nodes
        for _ in range(self.num_hops - 1):
            in_neighbors, _ = self.graph.in_edges(local_region)
            _, out_neighbors = self.graph.out_edges(local_region)
            neighbors = torch.cat([in_neighbors, out_neighbors]).tolist()
            local_region = list(set(local_region + neighbors))

        # ``split_point`` is a sentinel that cannot collide with a real node
        # ID; in each random permutation the players placed before it form
        # the coalition S_i.
        split_point = num_nodes
        coalition_space = list(set(local_region) - set(subgraph_nodes)) + [
            split_point
        ]

        marginal_contributions = []
        device = self.feat.device
        for _ in range(self.shapley_steps):
            permuted_space = np.random.permutation(coalition_space)
            split_idx = int(np.flatnonzero(permuted_space == split_point)[0])
            selected_nodes = permuted_space[:split_idx]

            # Mask for coalition set S_i
            exclude_mask = torch.ones(num_nodes)
            exclude_mask[local_region] = 0.0
            exclude_mask[selected_nodes] = 1.0

            # Mask for set S_i and g_i
            include_mask = exclude_mask.clone()
            include_mask[subgraph_nodes] = 1.0

            exclude_feat = self.feat * exclude_mask.unsqueeze(1).to(device)
            include_feat = self.feat * include_mask.unsqueeze(1).to(device)

            with torch.no_grad():
                exclude_probs = self.model(
                    self.graph, exclude_feat, **self.kwargs
                ).softmax(dim=-1)
                exclude_value = exclude_probs[:, self.target_class]
                include_probs = self.model(
                    self.graph, include_feat, **self.kwargs
                ).softmax(dim=-1)
                include_value = include_probs[:, self.target_class]
            marginal_contributions.append(include_value - exclude_value)

        return torch.cat(marginal_contributions).mean().item()

    def get_mcts_children(self, mcts_node):
        r"""Get the children of the MCTS node for the search.

        Children are created by removing one of the ``num_child`` highest-
        (or lowest-, if ``high2low`` is False) degree nodes and keeping the
        largest weakly connected component of what remains. Newly created
        children get their Shapley value as immediate reward.

        Parameters
        ----------
        mcts_node : MCTSNode
            Node in MCTS

        Returns
        -------
        list
            Children nodes after pruning
        """
        if len(mcts_node.children) > 0:
            return mcts_node.children

        subg = node_subgraph(self.graph, mcts_node.nodes)
        node_degrees = subg.out_degrees() + subg.in_degrees()
        k = min(subg.num_nodes(), self.num_child)
        chosen_nodes = torch.topk(
            node_degrees, k, largest=self.high2low
        ).indices

        mcts_children_maps = {}
        for node in chosen_nodes:
            new_subg = remove_nodes(subg, node.to(subg.idtype), store_ids=True)
            # Get the largest weakly connected component in the subgraph.
            nx_graph = to_networkx(new_subg.cpu())
            largest_cc_nids = list(
                max(nx.weakly_connected_components(nx_graph), key=len)
            )
            # Map to the original node IDs.
            largest_cc_nids = new_subg.ndata[NID][largest_cc_nids].long()
            largest_cc_nids = subg.ndata[NID][largest_cc_nids].sort().values

            # Reuse the tree node if this subgraph was already reached.
            key = self._node_key(largest_cc_nids)
            child_mcts_node = self.mcts_node_maps.get(key)
            if child_mcts_node is None:
                child_mcts_node = MCTSNode(largest_cc_nids)
                self.mcts_node_maps[key] = child_mcts_node
            mcts_children_maps.setdefault(key, child_mcts_node)

        mcts_node.children = list(mcts_children_maps.values())
        for child_mcts_node in mcts_node.children:
            if child_mcts_node.immediate_reward == 0:
                child_mcts_node.immediate_reward = self.shapley(
                    child_mcts_node.nodes
                )

        return mcts_node.children

    def mcts_rollout(self, mcts_node):
        r"""Perform a MCTS rollout.

        Parameters
        ----------
        mcts_node : MCTSNode
            Starting node for MCTS

        Returns
        -------
        float
            Reward for visiting the node this time
        """
        # Leaves (at most ``node_min`` nodes) return their Shapley value.
        if len(mcts_node.nodes) <= self.node_min:
            return mcts_node.immediate_reward

        children_nodes = self.get_mcts_children(mcts_node)
        children_visit_sum = sum(child.num_visit for child in children_nodes)
        children_visit_sum_sqrt = math.sqrt(children_visit_sum)
        # Selection score: mean reward (exploitation) plus a bonus that
        # favours high-Shapley, rarely visited children (exploration).
        chosen_child = max(
            children_nodes,
            key=lambda c: c.total_reward / max(c.num_visit, 1)
            + self.coef
            * c.immediate_reward
            * children_visit_sum_sqrt
            / (1 + c.num_visit),
        )
        reward = self.mcts_rollout(chosen_child)
        chosen_child.num_visit += 1
        chosen_child.total_reward += reward

        return reward

    def explain_graph(self, graph, feat, target_class, **kwargs):
        r"""Find the most important subgraph from the original graph for the
        model to classify the graph into the target class.

        Parameters
        ----------
        graph : DGLGraph
            A homogeneous graph
        feat : Tensor
            The input node feature of shape :math:`(N, D)`, :math:`N` is the
            number of nodes, and :math:`D` is the feature size
        target_class : int
            The target class to explain
        kwargs : dict
            Additional arguments passed to the GNN model

        Returns
        -------
        Tensor
            Nodes that represent the most important subgraph

        Raises
        ------
        RuntimeError
            If no leaf subgraph was explored (e.g. ``num_rollouts`` is 0).

        Examples
        --------

        >>> import torch
        >>> import torch.nn as nn
        >>> import torch.nn.functional as F
        >>> from dgl.data import GINDataset
        >>> from dgl.dataloading import GraphDataLoader
        >>> from dgl.nn import GraphConv, AvgPooling, SubgraphX

        >>> # Define the model
        >>> class Model(nn.Module):
        ...     def __init__(self, in_dim, n_classes, hidden_dim=128):
        ...         super().__init__()
        ...         self.conv1 = GraphConv(in_dim, hidden_dim)
        ...         self.conv2 = GraphConv(hidden_dim, n_classes)
        ...         self.pool = AvgPooling()
        ...
        ...     def forward(self, g, h):
        ...         h = F.relu(self.conv1(g, h))
        ...         h = self.conv2(g, h)
        ...         return self.pool(g, h)

        >>> # Load dataset
        >>> data = GINDataset('MUTAG', self_loop=True)
        >>> dataloader = GraphDataLoader(data, batch_size=64, shuffle=True)

        >>> # Train the model
        >>> feat_size = data[0][0].ndata['attr'].shape[1]
        >>> model = Model(feat_size, data.gclasses)
        >>> criterion = nn.CrossEntropyLoss()
        >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
        >>> for bg, labels in dataloader:
        ...     logits = model(bg, bg.ndata['attr'])
        ...     loss = criterion(logits, labels)
        ...     optimizer.zero_grad()
        ...     loss.backward()
        ...     optimizer.step()

        >>> # Initialize the explainer
        >>> explainer = SubgraphX(model, num_hops=2)

        >>> # Explain the prediction for graph 0
        >>> graph, l = data[0]
        >>> graph_feat = graph.ndata.pop("attr")
        >>> g_nodes_explain = explainer.explain_graph(graph, graph_feat,
        ...                                           target_class=l)
        """
        # Explanations should not depend on train-time stochasticity such as
        # dropout.
        self.model.eval()
        assert graph.num_nodes() > self.node_min, (
            f"The number of nodes in the graph {graph.num_nodes()} should be "
            f"bigger than {self.node_min}."
        )

        self.graph = graph
        self.feat = feat
        self.target_class = target_class
        self.kwargs = kwargs

        # book all nodes in MCTS, keyed by their exact node-ID set
        self.mcts_node_maps = {}

        root = MCTSNode(graph.nodes())
        self.mcts_node_maps[self._node_key(root.nodes)] = root

        for i in range(self.num_rollouts):
            if self.log:
                print(
                    f"Rollout {i}/{self.num_rollouts}, "
                    f"{len(self.mcts_node_maps)} subgraphs have been explored."
                )
            self.mcts_rollout(root)

        # The explanation is the best-scoring leaf (<= node_min nodes).
        best_leaf = None
        best_immediate_reward = float("-inf")
        for mcts_node in self.mcts_node_maps.values():
            if len(mcts_node.nodes) > self.node_min:
                continue

            if mcts_node.immediate_reward > best_immediate_reward:
                best_leaf = mcts_node
                best_immediate_reward = mcts_node.immediate_reward

        if best_leaf is None:
            raise RuntimeError(
                "No leaf subgraph was explored; consider increasing "
                "num_rollouts."
            )
        return best_leaf.nodes
class HeteroSubgraphX(nn.Module):
    r"""SubgraphX from `On Explainability of Graph Neural Networks via Subgraph
    Explorations <https://arxiv.org/abs/2102.05152>`__, adapted for heterogeneous graphs

    It identifies the most important subgraph from the original graph that
    plays a critical role in GNN-based graph classification.

    It employs Monte Carlo tree search (MCTS) in efficiently exploring
    different subgraphs for explanation and uses Shapley values as the measure
    of subgraph importance.

    Parameters
    ----------
    model : nn.Module
        The GNN model to explain that tackles multiclass graph classification

        * Its forward function must have the form
          :attr:`forward(self, graph, nfeat)`.
        * The output of its forward function is the logits.
    num_hops : int
        Number of message passing layers in the model
    coef : float, optional
        This hyperparameter controls the trade-off between exploration and
        exploitation. A higher value encourages the algorithm to explore
        relatively unvisited nodes. Default: 10.0
    high2low : bool, optional
        If True, it will use the "High2low" strategy for pruning actions,
        expanding children nodes from high degree to low degree when extending
        the children nodes in the search tree. Otherwise, it will use the
        "Low2high" strategy. Default: True
    num_child : int, optional
        This is the number of children nodes to expand when extending the
        children nodes in the search tree. Default: 12
    num_rollouts : int, optional
        This is the number of rollouts for MCTS. Default: 20
    node_min : int, optional
        This is the threshold to define a leaf node based on the number of
        nodes in a subgraph. Default: 3
    shapley_steps : int, optional
        This is the number of steps for Monte Carlo sampling in estimating
        Shapley values. Default: 100
    log : bool, optional
        If True, it will log the progress. Default: False
    """

    def __init__(
        self,
        model,
        num_hops,
        coef=10.0,
        high2low=True,
        num_child=12,
        num_rollouts=20,
        node_min=3,
        shapley_steps=100,
        log=False,
    ):
        super().__init__()
        self.num_hops = num_hops
        self.coef = coef
        self.high2low = high2low
        self.num_child = num_child
        self.num_rollouts = num_rollouts
        self.node_min = node_min
        self.shapley_steps = shapley_steps
        self.log = log

        self.model = model

    @staticmethod
    def _node_key(nodes):
        r"""Return an exact, hashable key for a ``{ntype: node IDs}`` dict.

        Node types without nodes are ignored and IDs are sorted, so dicts
        describing the same node set map to the same key. ``str(dict)`` is
        not used because PyTorch abbreviates long tensors with ``...``.
        """
        return tuple(
            sorted(
                (str(ntype), tuple(sorted(ids.tolist())))
                for ntype, ids in nodes.items()
                if len(ids) > 0
            )
        )

    @staticmethod
    def _num_nodes(nodes):
        r"""Total number of nodes over all node types in ``nodes``."""
        return sum(len(ids) for ids in nodes.values())

    def shapley(self, subgraph_nodes):
        r"""Compute Shapley value with Monte Carlo approximation.

        For each sample, a random coalition :math:`S_i` is drawn per node type
        from the ``(num_hops - 1)``-hop neighborhood of the subgraph. Node
        features outside the kept set are zeroed, and the target-class
        probability (softmax of the model logits) is compared with
        (:math:`S_i \cup g_i`) and without (:math:`S_i`) the subgraph nodes.

        Parameters
        ----------
        subgraph_nodes : dict[str, Tensor]
            subgraph_nodes[nty] gives the tensor node IDs of node type nty
            in the subgraph, which are associated with this tree node

        Returns
        -------
        float
            Shapley value
        """
        subgraph_lists = {
            ntype: nodes.tolist() for ntype, nodes in subgraph_nodes.items()
        }

        # Obtain neighboring nodes of the subgraph g_i, P'.
        local_regions = {
            ntype: list(nodes) for ntype, nodes in subgraph_lists.items()
        }
        for _ in range(self.num_hops - 1):
            for c_etype in self.graph.canonical_etypes:
                src_ntype, _, dst_ntype = c_etype
                # Only grow along edge types whose endpoint types are tracked.
                if (
                    src_ntype not in local_regions
                    or dst_ntype not in local_regions
                ):
                    continue

                in_neighbors, _ = self.graph.in_edges(
                    local_regions[dst_ntype], etype=c_etype
                )
                _, out_neighbors = self.graph.out_edges(
                    local_regions[src_ntype], etype=c_etype
                )
                local_regions[src_ntype] = list(
                    set(local_regions[src_ntype] + in_neighbors.tolist())
                )
                local_regions[dst_ntype] = list(
                    set(local_regions[dst_ntype] + out_neighbors.tolist())
                )

        # Sentinel larger than any per-type node ID; players placed before it
        # in a random permutation form the coalition S_i.
        split_point = self.graph.num_nodes()
        coalition_space = {
            ntype: list(set(local_regions[ntype]) - set(subgraph_lists[ntype]))
            + [split_point]
            for ntype in subgraph_lists
        }

        marginal_contributions = []
        for _ in range(self.shapley_steps):
            selected_node_map = {}
            for ntype, nodes in coalition_space.items():
                permuted_space = np.random.permutation(nodes)
                split_idx = int(
                    np.flatnonzero(permuted_space == split_point)[0]
                )
                selected_node_map[ntype] = permuted_space[:split_idx]

            # Mask for coalition set S_i
            exclude_mask = {
                ntype: torch.ones(self.graph.num_nodes(ntype))
                for ntype in self.graph.ntypes
            }
            for ntype, region in local_regions.items():
                exclude_mask[ntype][region] = 0.0
            for ntype, selected_nodes in selected_node_map.items():
                exclude_mask[ntype][selected_nodes] = 1.0

            # Mask for set S_i and g_i: the subgraph nodes are added to the
            # include mask only.
            include_mask = {
                ntype: exclude_mask[ntype].clone()
                for ntype in self.graph.ntypes
            }
            for ntype, subgn in subgraph_lists.items():
                include_mask[ntype][subgn] = 1.0

            exclude_feat = {
                ntype: self.feat[ntype]
                * exclude_mask[ntype].unsqueeze(1).to(self.feat[ntype].device)
                for ntype in self.graph.ntypes
            }
            include_feat = {
                ntype: self.feat[ntype]
                * include_mask[ntype].unsqueeze(1).to(self.feat[ntype].device)
                for ntype in self.graph.ntypes
            }

            with torch.no_grad():
                exclude_probs = self.model(
                    self.graph, exclude_feat, **self.kwargs
                ).softmax(dim=-1)
                exclude_value = exclude_probs[:, self.target_class]
                include_probs = self.model(
                    self.graph, include_feat, **self.kwargs
                ).softmax(dim=-1)
                include_value = include_probs[:, self.target_class]
            marginal_contributions.append(include_value - exclude_value)

        return torch.cat(marginal_contributions).mean().item()

    def get_mcts_children(self, mcts_node):
        r"""Get the children of the MCTS node for the search.

        Children are created by removing one of the ``num_child`` highest-
        (or lowest-, if ``high2low`` is False) degree nodes across all node
        types and keeping the largest weakly connected component of what
        remains. If no edges remain, a single remaining node is kept at
        random. Newly created children get their Shapley value as immediate
        reward.

        Parameters
        ----------
        mcts_node : MCTSNode
            Node in MCTS

        Returns
        -------
        list
            Children nodes after pruning
        """
        if len(mcts_node.children) > 0:
            return mcts_node.children

        subg = node_subgraph(self.graph, mcts_node.nodes)
        # Total (in + out) degree of every node, summed over edge types.
        node_degrees_map = {
            ntype: torch.zeros(subg.num_nodes(ntype), device=subg.device)
            for ntype in subg.ntypes
        }
        for c_etype in subg.canonical_etypes:
            src_ntype, _, dst_ntype = c_etype
            node_degrees_map[src_ntype] += subg.out_degrees(etype=c_etype)
            node_degrees_map[dst_ntype] += subg.in_degrees(etype=c_etype)

        # Flatten to one degree vector with a parallel (ntype, id) list.
        candidates = [
            (ntype, i)
            for ntype in subg.ntypes
            for i in range(subg.num_nodes(ntype))
        ]
        node_degrees = torch.cat([node_degrees_map[nt] for nt in subg.ntypes])
        k = min(subg.num_nodes(), self.num_child)
        chosen_indices = torch.topk(
            node_degrees, k, largest=self.high2low
        ).indices.tolist()
        chosen_nodes = [candidates[i] for i in chosen_indices]

        mcts_children_maps = {}
        for ntype, node in chosen_nodes:
            new_subg = remove_nodes(subg, node, ntype, store_ids=True)

            if new_subg.num_edges() > 0:
                new_subg_homo = to_homogeneous(new_subg)
                # Get the largest weakly connected component in the subgraph.
                nx_graph = to_networkx(new_subg_homo.cpu())
                largest_cc_nids = list(
                    max(nx.weakly_connected_components(nx_graph), key=len)
                )
                largest_cc_homo = node_subgraph(new_subg_homo, largest_cc_nids)
                largest_cc_hetero = to_heterogeneous(
                    largest_cc_homo, new_subg.ntypes, new_subg.etypes
                )

                # Follow steps for backtracking to original graph node ids
                # 1. retrieve instanced homograph from connected-component homograph
                # 2. retrieve instanced heterograph from instanced homograph
                # 3. retrieve hetero-subgraph from instanced heterograph
                # 4. retrieve original graph ids from subgraph node ids
                cc_nodes = {}
                for cc_ntype in largest_cc_hetero.ntypes:
                    homo_ids = largest_cc_homo.ndata[NID][
                        largest_cc_hetero.nodes[cc_ntype].data[NID]
                    ]
                    typed_ids = new_subg_homo.ndata[NID][homo_ids]
                    subg_ids = new_subg.nodes[cc_ntype].data[NID][typed_ids]
                    cc_nodes[cc_ntype] = subg.nodes[cc_ntype].data[NID][
                        subg_ids
                    ]
            else:
                # No edges remain, so every node is its own component: keep a
                # single remaining node chosen at random.
                available_ntypes = [
                    nt for nt in new_subg.ntypes if new_subg.num_nodes(nt) > 0
                ]
                # str() so the key is a plain str, not a numpy.str_.
                chosen_ntype = str(np.random.choice(available_ntypes))
                # backtrack from subgraph node ids to entire graph
                remaining = new_subg.nodes[chosen_ntype].data[NID]
                picked = remaining[np.random.randint(len(remaining))]
                chosen_node = subg.nodes[chosen_ntype].data[NID][picked]
                cc_nodes = {chosen_ntype: chosen_node.reshape(1)}

            # Reuse the tree node if this subgraph was already reached.
            key = self._node_key(cc_nodes)
            child_mcts_node = self.mcts_node_maps.get(key)
            if child_mcts_node is None:
                child_mcts_node = MCTSNode(cc_nodes)
                self.mcts_node_maps[key] = child_mcts_node
            mcts_children_maps.setdefault(key, child_mcts_node)

        mcts_node.children = list(mcts_children_maps.values())
        for child_mcts_node in mcts_node.children:
            if child_mcts_node.immediate_reward == 0:
                child_mcts_node.immediate_reward = self.shapley(
                    child_mcts_node.nodes
                )

        return mcts_node.children

    def mcts_rollout(self, mcts_node):
        r"""Perform a MCTS rollout.

        Parameters
        ----------
        mcts_node : MCTSNode
            Starting node for MCTS

        Returns
        -------
        float
            Reward for visiting the node this time
        """
        # Leaves (at most ``node_min`` nodes in total) return their Shapley
        # value.
        if self._num_nodes(mcts_node.nodes) <= self.node_min:
            return mcts_node.immediate_reward

        children_nodes = self.get_mcts_children(mcts_node)
        children_visit_sum = sum(child.num_visit for child in children_nodes)
        children_visit_sum_sqrt = math.sqrt(children_visit_sum)
        # Selection score: mean reward (exploitation) plus a bonus that
        # favours high-Shapley, rarely visited children (exploration).
        chosen_child = max(
            children_nodes,
            key=lambda c: c.total_reward / max(c.num_visit, 1)
            + self.coef
            * c.immediate_reward
            * children_visit_sum_sqrt
            / (1 + c.num_visit),
        )
        reward = self.mcts_rollout(chosen_child)
        chosen_child.num_visit += 1
        chosen_child.total_reward += reward

        return reward

    def explain_graph(self, graph, feat, target_class, **kwargs):
        r"""Find the most important subgraph from the original graph for the
        model to classify the graph into the target class.

        Parameters
        ----------
        graph : DGLGraph
            A heterogeneous graph
        feat : dict[str, Tensor]
            The dictionary that associates input node features (values) with
            the respective node types (keys) present in the graph.
            The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is the
            number of nodes for node type :math:`t`, and :math:`D_t` is the feature size for
            node type :math:`t`
        target_class : int
            The target class to explain
        kwargs : dict
            Additional arguments passed to the GNN model

        Returns
        -------
        dict[str, Tensor]
            The dictionary associating tensor node ids (values) to
            node types (keys) that represents the most important subgraph

        Raises
        ------
        RuntimeError
            If no leaf subgraph was explored (e.g. ``num_rollouts`` is 0).

        Examples
        --------

        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch as th
        >>> import torch.nn as nn
        >>> import torch.nn.functional as F
        >>> from dgl.nn import HeteroSubgraphX

        >>> class Model(nn.Module):
        ...     def __init__(self, in_dim, num_classes, canonical_etypes):
        ...         super(Model, self).__init__()
        ...         self.etype_weights = nn.ModuleDict(
        ...             {
        ...                 "_".join(c_etype): nn.Linear(in_dim, num_classes)
        ...                 for c_etype in canonical_etypes
        ...             }
        ...         )
        ...
        ...     def forward(self, graph, feat):
        ...         with graph.local_scope():
        ...             c_etype_func_dict = {}
        ...             for c_etype in graph.canonical_etypes:
        ...                 src_type, etype, dst_type = c_etype
        ...                 wh = self.etype_weights["_".join(c_etype)](feat[src_type])
        ...                 graph.nodes[src_type].data[f"h_{c_etype}"] = wh
        ...                 c_etype_func_dict[c_etype] = (
        ...                     fn.copy_u(f"h_{c_etype}", "m"),
        ...                     fn.mean("m", "h"),
        ...                 )
        ...             graph.multi_update_all(c_etype_func_dict, "sum")
        ...             hg = 0
        ...             for ntype in graph.ntypes:
        ...                 if graph.num_nodes(ntype):
        ...                     hg = hg + dgl.mean_nodes(graph, "h", ntype=ntype)
        ...             return hg

        >>> input_dim = 5
        >>> num_classes = 2
        >>> g = dgl.heterograph({("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1])})
        >>> g.nodes["user"].data["h"] = th.randn(g.num_nodes("user"), input_dim)
        >>> g.nodes["game"].data["h"] = th.randn(g.num_nodes("game"), input_dim)

        >>> transform = dgl.transforms.AddReverse()
        >>> g = transform(g)

        >>> # define and train the model
        >>> model = Model(input_dim, num_classes, g.canonical_etypes)
        >>> feat = g.ndata["h"]
        >>> optimizer = th.optim.Adam(model.parameters())
        >>> for epoch in range(10):
        ...     logits = model(g, feat)
        ...     loss = F.cross_entropy(logits, th.tensor([1]))
        ...     optimizer.zero_grad()
        ...     loss.backward()
        ...     optimizer.step()

        >>> # Explain for the graph
        >>> explainer = HeteroSubgraphX(model, num_hops=1)
        >>> explainer.explain_graph(g, feat, target_class=1)
        {'game': tensor([0, 1]), 'user': tensor([1, 2])}
        """
        # Explanations should not depend on train-time stochasticity such as
        # dropout.
        self.model.eval()
        assert graph.num_nodes() > self.node_min, (
            f"The number of nodes in the graph {graph.num_nodes()} should be "
            f"bigger than {self.node_min}."
        )

        self.graph = graph
        self.feat = feat
        self.target_class = target_class
        self.kwargs = kwargs

        # book all nodes in MCTS, keyed by their exact per-type node-ID sets
        self.mcts_node_maps = {}

        root_dict = {ntype: graph.nodes(ntype) for ntype in graph.ntypes}
        root = MCTSNode(root_dict)
        self.mcts_node_maps[self._node_key(root_dict)] = root

        for i in range(self.num_rollouts):
            if self.log:
                print(
                    f"Rollout {i}/{self.num_rollouts}, "
                    f"{len(self.mcts_node_maps)} subgraphs have been explored."
                )
            self.mcts_rollout(root)

        # The explanation is the best-scoring leaf: a subgraph with at most
        # ``node_min`` nodes counted over all node types.
        best_leaf = None
        best_immediate_reward = float("-inf")
        for mcts_node in self.mcts_node_maps.values():
            if self._num_nodes(mcts_node.nodes) > self.node_min:
                continue

            if mcts_node.immediate_reward > best_immediate_reward:
                best_leaf = mcts_node
                best_immediate_reward = mcts_node.immediate_reward

        if best_leaf is None:
            raise RuntimeError(
                "No leaf subgraph was explored; consider increasing "
                "num_rollouts."
            )
        return best_leaf.nodes