# Source code for dgl.function.reducer

"""Built-in reducer function."""
# pylint: disable=redefined-builtin
from __future__ import absolute_import

import sys

from .base import BuiltinFunction, TargetCode
from .._deprecate.runtime import ir
from .._deprecate.runtime.ir import var


class ReduceFunction(BuiltinFunction):
    """Abstract base for builtin reduce functions.

    Both members below raise ``NotImplementedError``; concrete reducers
    are expected to override them.
    """

    @property
    def name(self):
        """Name of this builtin function (must be overridden)."""
        raise NotImplementedError

    def _invoke(self, graph, edge_frame, out_size, edge_map=None,
                out_map=None):
        """Build the symbolic computation for this builtin function.

        The base implementation is a placeholder that always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError


class SimpleReduceFunction(ReduceFunction):
    """Reducer that aggregates one message field into one output field.

    Attributes are set from the constructor arguments: ``_name`` (the
    reducer name, exposed through ``name``), ``msg_field`` and
    ``out_field``.
    """

    def __init__(self, name, msg_field, out_field):
        self._name = name
        self.msg_field = msg_field
        self.out_field = out_field

    @property
    def name(self):
        """Reducer name given at construction time."""
        return self._name

    def _invoke(self, graph, edge_frame, out_size, edge_map=None,
                out_map=None):
        """Symbolically execute this reducer.

        Wraps ``graph`` and the two maps as IR variables, reads the
        ``msg_field`` column from ``edge_frame`` and returns the result of
        ``ir.COPY_REDUCE`` targeting ``TargetCode.EDGE``.
        """
        # The wrappers are created in the same order as before, in case
        # the IR helpers are order-sensitive.
        graph_var = var.GRAPH(graph)
        edge_map_var = var.MAP(edge_map)
        out_map_var = var.MAP(out_map)
        msg_column = ir.READ_COL(edge_frame, var.STR(self.msg_field))
        return ir.COPY_REDUCE(
            self._name,
            graph_var,
            TargetCode.EDGE,
            msg_column,
            out_size,
            edge_map_var,
            out_map_var,
        )


###############################################################################
# Generate the builtin reducer functions registered below:
# sum, max, min, mean (prod is not registered in _register_builtin_reduce_func)

def _gen_reduce_builtin(reducer):
    docstring = """Builtin reduce function that aggregates messages by {0}.

    Parameters
    ----------
    msg : str
        The message field.
    out : str
        The output node feature field.
    Examples
    --------
    >>> import dgl
    >>> reduce_func = dgl.function.{0}('m', 'h')

    The above example is equivalent to the following user defined function
    (if using PyTorch):

    >>> import torch
    >>> def reduce_func(nodes):
    >>>     return {{'h': torch.{0}(nodes.mailbox['m'], dim=1)}}
    """.format(reducer)

    def func(msg, out):
        return SimpleReduceFunction(reducer, msg, out)
    func.__name__ = reducer
    func.__doc__ = docstring
    return func


# Public names exported by this module. It starts empty and is filled in by
# _register_builtin_reduce_func() as each reducer is registered.
__all__ = []


def _register_builtin_reduce_func():
    """Install the builtin reducers as attributes of this module.

    For each of "max", "min", "sum" and "mean", a factory is generated
    with ``_gen_reduce_builtin``, bound on the module under that name, and
    the name is appended to ``__all__``.
    """
    this_module = sys.modules[__name__]
    for op_name in ("max", "min", "sum", "mean"):
        setattr(this_module, op_name, _gen_reduce_builtin(op_name))
        __all__.append(op_name)


# Run at import time so the generated reducers exist as module attributes
# and are listed in __all__.
_register_builtin_reduce_func()