HeteroPGExplainer
- class dgl.nn.pytorch.explain.HeteroPGExplainer(model, num_features, num_hops=None, explain_graph=True, coff_budget=0.01, coff_connect=0.0005, sample_bias=0.0)
Bases: PGExplainer
PGExplainer from Parameterized Explainer for Graph Neural Network, adapted for heterogeneous graphs.
PGExplainer adopts a deep neural network (explanation network) to parameterize the generation process of explanations, which enables it to explain multiple instances collectively. PGExplainer models the underlying structure as edge distributions, from which the explanatory graph is sampled.
- Parameters:
model (nn.Module) – The GNN model to explain, which tackles multiclass graph classification. Its forward function must have the form forward(self, graph, nfeat, embed, edge_weight). The output of its forward function is the logits if embed=False; otherwise it is the intermediate node embeddings.
num_features (int) – Node embedding size used by model.
coff_budget (float, optional) – Size regularization to constrain the explanation size. Default: 0.01.
coff_connect (float, optional) – Entropy regularization to constrain the connectivity of explanation. Default: 5e-4.
sample_bias (float, optional) – Some members of a population are systematically more likely to be selected in a sample than others. Default: 0.0.
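The description above says the explanatory graph is sampled from edge distributions, and the methods below take a temperature for that sampling. A minimal sketch of one standard way to do this, the binary concrete (relaxed Bernoulli) trick, is shown here for a single edge logit. The function names and the way sample_bias narrows the noise range are illustrative assumptions, not a copy of the library's code.

```python
import math
import random


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def concrete_sample(logit, temperature=1.0, sample_bias=0.0, training=True, rng=random):
    """Draw a relaxed Bernoulli sample in (0, 1) for one edge logit (hypothetical helper)."""
    if not training:
        # Without noise the mask value is simply the sigmoid of the logit.
        return sigmoid(logit)
    # Assumption: sample_bias restricts the uniform noise to [bias, 1 - bias].
    eps = sample_bias + (1.0 - 2.0 * sample_bias) * rng.random()
    # Guard against log(0) when sample_bias is 0.
    eps = min(max(eps, 1e-12), 1.0 - 1e-12)
    logistic_noise = math.log(eps) - math.log(1.0 - eps)
    # A higher temperature flattens the sample towards 0.5; a lower one sharpens it.
    return sigmoid((logistic_noise + logit) / temperature)
```

With the same noise draw, a larger temperature yields a value closer to 0.5, which is why training loops typically start with a high temperature and lower it over time.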
- explain_graph(graph, feat, temperature=1.0, training=False, **kwargs)
Learn and return an edge mask that plays a crucial role in explaining the prediction made by the GNN for a graph. Also return the prediction made with the edges chosen based on the edge mask.
- Parameters:
graph (DGLGraph) – A heterogeneous graph.
feat (dict[str, Tensor]) – A dict mapping node types (keys) to feature tensors (values). The input features are of shape \((N_t, D_t)\), where \(N_t\) is the number of nodes for node type \(t\) and \(D_t\) is the feature size for node type \(t\).
temperature (float) – The temperature parameter fed to the sampling procedure.
training (bool) – Whether the explanation network is being trained.
kwargs (dict) – Additional arguments passed to the GNN model.
- Returns:
Tensor – Classification probabilities given the masked graph. It is a tensor of shape \((B, L)\), where \(L\) is the number of distinct labels in the dataset and \(B\) is the batch size.
dict[str, Tensor] – A dict mapping edge types (keys) to edge tensors (values) of shape \((E_t)\), where \(E_t\) is the number of edges in the graph for edge type \(t\). A higher weight suggests a larger contribution of the edge.
Examples
>>> import dgl
>>> import torch as th
>>> import torch.nn as nn
>>> import numpy as np

>>> # Define the model
>>> class Model(nn.Module):
...     def __init__(self, in_feats, hid_feats, out_feats, rel_names):
...         super().__init__()
...         self.conv = dgl.nn.HeteroGraphConv(
...             {rel: dgl.nn.GraphConv(in_feats, hid_feats) for rel in rel_names},
...             aggregate="sum",
...         )
...         self.fc = nn.Linear(hid_feats, out_feats)
...         nn.init.xavier_uniform_(self.fc.weight)
...
...     def forward(self, g, h, embed=False, edge_weight=None):
...         if edge_weight:
...             mod_kwargs = {
...                 etype: {"edge_weight": mask} for etype, mask in edge_weight.items()
...             }
...             h = self.conv(g, h, mod_kwargs=mod_kwargs)
...         else:
...             h = self.conv(g, h)
...
...         if embed:
...             return h
...
...         with g.local_scope():
...             g.ndata["h"] = h
...             hg = 0
...             for ntype in g.ntypes:
...                 hg = hg + dgl.mean_nodes(g, "h", ntype=ntype)
...             return self.fc(hg)

>>> # Load dataset
>>> input_dim = 5
>>> hidden_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes["user"].data["h"] = th.randn(g.num_nodes("user"), input_dim)
>>> g.nodes["game"].data["h"] = th.randn(g.num_nodes("game"), input_dim)

>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)

>>> # define and train the model
>>> model = Model(input_dim, hidden_dim, num_classes, g.canonical_etypes)
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
...     logits = model(g, g.ndata["h"])
...     loss = th.nn.functional.cross_entropy(logits, th.tensor([1]))
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()

>>> # Initialize the explainer
>>> explainer = dgl.nn.HeteroPGExplainer(model, hidden_dim)

>>> # Train the explainer
>>> # Define explainer temperature parameter
>>> init_tmp, final_tmp = 5.0, 1.0
>>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)
>>> for epoch in range(20):
...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))
...     loss = explainer.train_step(g, g.ndata["h"], tmp)
...     optimizer_exp.zero_grad()
...     loss.backward()
...     optimizer_exp.step()

>>> # Explain the graph
>>> feat = g.ndata.pop("h")
>>> probs, edge_mask = explainer.explain_graph(g, feat)
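Since a higher weight in the returned edge mask suggests a larger contribution, a common follow-up is to keep only the strongest edges per edge type. The sketch below is one possible post-processing step, not part of the explainer's API. The helper top_k_edges, the edge-type keys, and the mask values are all hypothetical, and plain lists stand in for the mask tensors (calling .tolist() on a real mask tensor would give input of the same shape).

```python
def top_k_edges(edge_mask, k):
    """Return, per edge type, the IDs of the k highest-weighted edges (hypothetical helper)."""
    selected = {}
    for etype, weights in edge_mask.items():
        # Sort edge IDs by descending mask weight.
        order = sorted(range(len(weights)), key=lambda i: weights[i], reverse=True)
        selected[etype] = order[:k]
    return selected


# Made-up mask values for two edge types with four edges each.
edge_mask = {
    "plays": [0.91, 0.12, 0.55, 0.30],
    "rev_plays": [0.08, 0.77, 0.41, 0.63],
}
print(top_k_edges(edge_mask, k=2))
# {'plays': [0, 2], 'rev_plays': [1, 3]}
```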
- explain_node(nodes, graph, feat, temperature=1.0, training=False, **kwargs)
Learn and return an edge mask that plays a crucial role in explaining the prediction made by the GNN for the provided set of node IDs. Also return the prediction made with the batched graph and edge mask.
- Parameters:
nodes (dict[str, Iterable[int]]) – A dict mapping node types (keys) to an iterable set of node IDs (values).
graph (DGLGraph) – A heterogeneous graph.
feat (dict[str, Tensor]) – A dict mapping node types (keys) to feature tensors (values). The input features are of shape \((N_t, D_t)\), where \(N_t\) is the number of nodes for node type \(t\) and \(D_t\) is the feature size for node type \(t\).
temperature (float) – The temperature parameter fed to the sampling procedure.
training (bool) – Whether the explanation network is being trained.
kwargs (dict) – Additional arguments passed to the GNN model.
- Returns:
dict[str, Tensor] – A dict mapping node types (keys) to classification probabilities for node labels (values). The values are tensors of shape \((N_t, L)\), where \(L\) is the number of distinct node labels in the dataset and \(N_t\) is the number of nodes in the graph for node type \(t\).
dict[str, Tensor] – A dict mapping edge types (keys) to edge tensors (values) of shape \((E_t)\), where \(E_t\) is the number of edges in the graph for edge type \(t\). A higher weight suggests a larger contribution of the edge.
DGLGraph – The batched set of subgraphs induced on the k-hop in-neighborhood of the input center nodes.
dict[str, Tensor] – A dict mapping node types (keys) to a tensor of node IDs (values) which correspond to the subgraph center nodes.
Examples
>>> import dgl
>>> import torch as th
>>> import torch.nn as nn
>>> import numpy as np

>>> # Define the model
>>> class Model(nn.Module):
...     def __init__(self, in_feats, hid_feats, out_feats, rel_names):
...         super().__init__()
...         self.conv = dgl.nn.HeteroGraphConv(
...             {rel: dgl.nn.GraphConv(in_feats, hid_feats) for rel in rel_names},
...             aggregate="sum",
...         )
...         self.fc = nn.Linear(hid_feats, out_feats)
...         nn.init.xavier_uniform_(self.fc.weight)
...
...     def forward(self, g, h, embed=False, edge_weight=None):
...         if edge_weight:
...             mod_kwargs = {
...                 etype: {"edge_weight": mask} for etype, mask in edge_weight.items()
...             }
...             h = self.conv(g, h, mod_kwargs=mod_kwargs)
...         else:
...             h = self.conv(g, h)
...
...         return h

>>> # Load dataset
>>> input_dim = 5
>>> hidden_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes["user"].data["h"] = th.randn(g.num_nodes("user"), input_dim)
>>> g.nodes["game"].data["h"] = th.randn(g.num_nodes("game"), input_dim)

>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)

>>> # define and train the model
>>> model = Model(input_dim, hidden_dim, num_classes, g.canonical_etypes)
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
...     logits = model(g, g.ndata["h"])['user']
...     loss = th.nn.functional.cross_entropy(logits, th.tensor([1,1,1]))
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()

>>> # Initialize the explainer
>>> explainer = dgl.nn.HeteroPGExplainer(
...     model, hidden_dim, num_hops=2, explain_graph=False
... )

>>> # Train the explainer
>>> # Define explainer temperature parameter
>>> init_tmp, final_tmp = 5.0, 1.0
>>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)
>>> for epoch in range(20):
...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))
...     loss = explainer.train_step_node(
...         { ntype: g.nodes(ntype) for ntype in g.ntypes },
...         g, g.ndata["h"], tmp
...     )
...     optimizer_exp.zero_grad()
...     loss.backward()
...     optimizer_exp.step()

>>> # Explain the prediction for node 0 of type "user"
>>> feat = g.ndata.pop("h")
>>> probs, edge_mask, bg, inverse_indices = explainer.explain_node(
...     { "user": [0] }, g, feat
... )
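The returned probabilities cover every node of the batched subgraphs, while the last return value identifies the center nodes. A minimal sketch of reading off the center-node rows follows. It assumes that the probability rows are indexed by node ID within the batched graph and that the returned center-node IDs refer to that same graph. The helper center_node_probs and all values are hypothetical, with nested lists standing in for tensors.

```python
def center_node_probs(probs, inverse_indices):
    """Pick the probability rows that belong to the explained center nodes (hypothetical helper)."""
    return {
        ntype: [probs[ntype][i] for i in ids]
        for ntype, ids in inverse_indices.items()
    }


# Made-up outputs: three "user" nodes and two "game" nodes in the batched graph,
# with the single explained "user" node sitting at position 1.
probs = {
    "user": [[0.2, 0.8], [0.6, 0.4], [0.1, 0.9]],
    "game": [[0.5, 0.5], [0.3, 0.7]],
}
inverse_indices = {"user": [1]}

center = center_node_probs(probs, inverse_indices)
print(center)
# {'user': [[0.6, 0.4]]}
```

With real tensors, indexing `probs[ntype][ids]` with an integer index tensor would select the same rows in one step.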
- forward(*input: Any) → None
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- train_step(graph, feat, temperature, **kwargs)
Compute the loss of the explanation network for graph classification.
- Parameters:
graph (DGLGraph) – Input batched heterogeneous graph.
feat (dict[str, Tensor]) – A dict mapping node types (keys) to feature tensors (values). The input features are of shape \((N_t, D_t)\), where \(N_t\) is the number of nodes for node type \(t\) and \(D_t\) is the feature size for node type \(t\).
temperature (float) – The temperature parameter fed to the sampling procedure.
kwargs (dict) – Additional arguments passed to the GNN model.
- Returns:
A scalar tensor representing the loss.
- Return type:
Tensor
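The temperature passed to train_step is usually annealed across epochs, as the examples on this page do with a geometric decay from 5.0 towards 1.0. The sketch below isolates that schedule in plain Python. The function name annealed_temperature is hypothetical, and the default values simply mirror the examples.

```python
def annealed_temperature(epoch, num_epochs, init_tmp=5.0, final_tmp=1.0):
    """Geometric interpolation from init_tmp (epoch 0) towards final_tmp (hypothetical helper)."""
    return init_tmp * (final_tmp / init_tmp) ** (epoch / num_epochs)


num_epochs = 20
schedule = [annealed_temperature(epoch, num_epochs) for epoch in range(num_epochs)]

# The schedule starts at init_tmp and decreases monotonically.
print(schedule[0])   # 5.0
print(schedule[10])  # about 2.236, i.e. sqrt(5)
```

Because `epoch / num_epochs` stays below 1 inside `range(num_epochs)`, the last epoch ends slightly above final_tmp rather than exactly on it.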
- train_step_node(nodes, graph, feat, temperature, **kwargs)
Compute the loss of the explanation network for node classification.
- Parameters:
nodes (dict[str, Iterable[int]]) – A dict mapping node types (keys) to an iterable set of node IDs (values).
graph (DGLGraph) – Input heterogeneous graph.
feat (dict[str, Tensor]) – A dict mapping node types (keys) to feature tensors (values). The input features are of shape \((N_t, D_t)\), where \(N_t\) is the number of nodes for node type \(t\) and \(D_t\) is the feature size for node type \(t\).
temperature (float) – The temperature parameter fed to the sampling procedure.
kwargs (dict) – Additional arguments passed to the GNN model.
- Returns:
A scalar tensor representing the loss.
- Return type:
Tensor
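The scalar loss returned by the training steps is shaped in part by the constructor's coff_budget (size regularization) and coff_connect (entropy regularization). As a rough illustration of what such terms can look like, the sketch below computes a size penalty and a mean binary-entropy penalty over a list of mask values. Both function names and the exact formulas are assumptions chosen for illustration; the library's internal loss may differ in detail.

```python
import math


def size_penalty(edge_mask, coff_budget=0.01):
    """Penalize large explanations: weight times the total mask mass (assumed form)."""
    return coff_budget * sum(edge_mask)


def entropy_penalty(edge_mask, coff_connect=5e-4, eps=1e-6):
    """Penalize indecisive masks: weight times the mean binary entropy (assumed form)."""
    total = 0.0
    for m in edge_mask:
        # Clamp away from 0 and 1 so the logarithms stay finite.
        m = min(max(m, eps), 1.0 - eps)
        total += -m * math.log(m) - (1.0 - m) * math.log(1.0 - m)
    return coff_connect * total / len(edge_mask)


# Made-up masks: one undecided, one close to binary.
undecided = [0.5, 0.5, 0.5, 0.5]
decisive = [0.99, 0.01, 0.99, 0.01]

# The entropy term is largest when every edge sits at 0.5.
print(entropy_penalty(undecided) > entropy_penalty(decisive))  # True
```

Under this formulation, raising coff_budget pushes the mask towards fewer selected edges, while raising coff_connect pushes individual mask values towards 0 or 1.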