Source code for webdnn.frontend.chainer.converter

# -*- coding:utf-8 -*-

"""
Chainer Link -> Graph object converters
Assuming Chainer >=1.23, <5.0.0
"""
import traceback
import warnings
from typing import List, Union, Sequence, Dict, Tuple

import numpy as np

from webdnn.frontend.chainer.placeholder_variable import PlaceholderVariable
from webdnn.frontend.converter import Converter, CyclicGraphError
from webdnn.frontend.util import semver
from webdnn.graph.graph import Graph
from webdnn.graph.order import Order
from webdnn.graph.variable import Variable
from webdnn.graph.variables.attributes.input import Input
from webdnn.graph.variables.attributes.output import Output
from webdnn.graph.variables.constant_variable import ConstantVariable
from webdnn.util import console

# Placeholder declarations. The actual definitions are chosen further below, depending on which Chainer version is
# installed (the functions here are rebound inside the version-detection block).
#
# T_FUNCTION: type of function nodes in the computational graph
# T_VARIABLE: type of variable nodes in the computational graph
T_NODE = Union["T_VARIABLE", "T_FUNCTION"]


def get_variable_data(variable: "T_VARIABLE") -> Union[np.ndarray, "chainer.cuda.ndarray"]:
    """Return the data array held by `variable`. Placeholder; rebound once the Chainer version is detected."""


def to_variable_node(c_var: "chainer.Variable") -> "T_VARIABLE":
    """Return the graph variable node (`T_VARIABLE`) for `c_var`. Placeholder; rebound once the Chainer version is detected."""


FLAG_CHAINER_INSTALLED = False

try:
    import chainer
    import chainer.computational_graph

    VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH = semver(chainer.__version__)

    if VERSION_MAJOR >= 2:
        # v2.x.x and later
        if VERSION_MAJOR >= 5:
            warnings.warn(f"WebDNN does not support Chainer version >= 5. Currently, Chainer {chainer.__version__} is installed.")

        if VERSION_MAJOR >= 3:
            # In v3 and later, many functions are represented as instances of `chainer.function_node.FunctionNode`. However,
            # some functions are still instances of `chainer.function.Function` (ex. Im2Col).
            T_FUNCTION = (chainer.FunctionNode, chainer.Function)
        else:
            T_FUNCTION = chainer.Function

        # Since v2, the computational graph is made of `VariableNode`s rather than `chainer.Variable`s.
        T_VARIABLE = chainer.variable.VariableNode


        def get_variable_data(variable: T_VARIABLE):
            # When the node itself holds no data, fall back to the data of the `chainer.Variable` obtained via `_variable()`.
            # noinspection PyProtectedMember
            return variable._variable().data if variable.data is None else variable.data


        def to_variable_node(c_var: chainer.Variable):
            return c_var.node

    elif (VERSION_MAJOR == 1) and (VERSION_MINOR >= 23):
        # v1.x.x: the graph is made of `chainer.Variable`s directly.
        T_FUNCTION = chainer.Function
        T_VARIABLE = chainer.Variable


        def get_variable_data(variable: T_VARIABLE):
            return variable.data


        def to_variable_node(c_var: chainer.Variable):
            return c_var

    else:
        raise NotImplementedError(f"WebDNN does not support Chainer older than v1.23. Currently, Chainer {chainer.__version__} is installed.")

    # Reached only when every step above succeeded.
    FLAG_CHAINER_INSTALLED = True

except Exception:
    # Chainer is missing or unsupported: report why, but keep this module importable. The type aliases are referenced by
    # name elsewhere in this module (e.g. in `isinstance` checks and annotations), so bind empty fallbacks; an empty
    # tuple never matches in `isinstance`. FLAG_CHAINER_INSTALLED stays False, so `ChainerConverter()` still refuses to run.
    console.warning(traceback.format_exc())
    T_FUNCTION = ()
    T_VARIABLE = ()


def _listup_functions(inputs: Sequence[T_NODE], outputs: Sequence[T_NODE]):
    """Collect the function nodes lying between `inputs` and `outputs`, in dependency (topological) order.

    Traversal walks backwards from `outputs` (variable -> creator function -> function inputs -> ...) and stops at any
    node contained in `inputs`. A function is emitted only after every function it depends on has been emitted.

    Args:
        inputs: boundary nodes; traversal does not go past them.
        outputs: nodes from which the backward traversal starts.

    Returns:
        list of function nodes, dependencies first.

    Raises:
        CyclicGraphError: if a node is reached again while it is still waiting for its own predecessors.
    """
    boundary = set(inputs)

    def predecessors(node: T_NODE) -> Sequence[T_NODE]:
        # NOTE(Kiikurage):
        # In chainer v1, "Variable" doesn't support "__eq__" method, so "list.__contains__" cannot be used for list of variables.
        # However, "Variable.__hash__" is implemented and "set.__contains__" is available.
        if node in boundary:
            return []

        if isinstance(node, T_VARIABLE):
            creator = node.creator
            return [] if creator is None else [creator]

        return node.inputs

    ordered_functions = []  # type: List[T_FUNCTION]
    work = [(output, None) for output in outputs]  # type: List[Tuple[T_NODE, T_NODE]]
    remaining = {}  # type: Dict[T_NODE, int]  # node -> number of predecessors not finished yet

    while work:
        current, consumer = work.pop()

        if current not in remaining:
            # First visit: put the entry back underneath its unfinished predecessors, so it is revisited after them.
            work.append((current, consumer))
            remaining[current] = 0
            for pred in predecessors(current):
                if remaining.get(pred, 1) > 0:
                    remaining[current] += 1
                    work.append((pred, current))
            continue

        if remaining[current] != 0:
            # Revisited while still waiting on its predecessors -> the node is its own ancestor.
            raise CyclicGraphError("[ChainerConverter] Cycles are detected, but ChainerConverter cannot convert cyclic graph")

        # Every predecessor is finished: emit functions and release one dependency of the consumer.
        if isinstance(current, T_FUNCTION):
            ordered_functions.append(current)

        if consumer is not None:
            remaining[consumer] -= 1

    return ordered_functions


class ChainerConverter(Converter["T_FUNCTION"]):
    """ChainerConverter()

    Converter for `Chainer <https://chainer.org/>`_.

    Currently, from :code:`v1.23` to :code:`v4.x` (i.e. below :code:`v5.0.0`) is supported.
    """

    def __init__(self):
        super(ChainerConverter, self).__init__()

        if not FLAG_CHAINER_INSTALLED:
            raise ImportError("[ChainerConverter] Failed to import Chainer.")

    def convert_from_inout_vars(self, inputs: List["chainer.Variable"], outputs: List["chainer.Variable"]):
        """convert_from_inout_vars(inputs, outputs)

        Construct computational graph from input and output chainer variables, and convert the graph into WebDNN IR.

        Args:
            inputs(list of chainer.Variable): input chainer variables
            outputs(list of chainer.Variable): output chainer variables

        .. warning::
            This method will be removed in the future version. Use
            :func:`~webdnn.frontend.chainer.ChainerConverter.convert(inputs, outputs)`.

        Returns:
            (:class:`~webdnn.Graph`): WebDNN Graph
        """
        warnings.warn("This method will be removed in the future version. Use ChainerConverter#convert(inputs, outputs).",
                      DeprecationWarning)
        return self.convert(inputs, outputs)

    def convert(self, inputs: List["chainer.Variable"], outputs: List["chainer.Variable"]) -> Graph:
        """convert(inputs, outputs)

        Convert chainer computational graph into WebDNN IR.

        Args:
            inputs(list of chainer.Variable): input chainer variables. A :class:`PlaceholderVariable` input is
                converted using its :code:`actual_shape`. An input that was itself produced by some function is
                treated as a graph input, and the part of the graph upstream of it is not converted.
            outputs(list of chainer.Variable): output chainer variables

        .. admonition:: example

            Convert pre-trained ResNet model

            .. code::

                model = chainer.links.model.vision.resnet.ResNet50Layers()

                # Forward propagation with dummy input to build computational graph
                x = chainer.Variable(np.empty((1, 3, 224, 224), dtype=np.float32))
                y = model(x, layers=["fc6"])["fc6"]

                graph = ChainerConverter().convert([x], [y])

        Returns:
            (:class:`~webdnn.Graph`): WebDNN Graph
        """
        for v in inputs:
            if isinstance(v, PlaceholderVariable):
                n_var = Variable(v.actual_shape, Order([None] * v.ndim))
                self.set_variable(to_variable_node(v), n_var)

        inputs = [to_variable_node(v) for v in inputs]
        outputs = [to_variable_node(v) for v in outputs]

        # Inputs produced by some function (non-leaf nodes) are skipped by the parameter loop below, which only looks
        # at nodes without a creator, and `_listup_functions` stops at inputs, so their creators are never converted
        # either. Convert them explicitly as ordinary (non-constant) variables.
        for node in inputs:
            if not self.has_variable(node) and node.creator is not None:
                self._convert_var(node, constant=False)

        # Convert parameters into constant variable
        input_set = set(inputs)
        for node in chainer.computational_graph.build_computational_graph(outputs).nodes:
            if isinstance(node, T_VARIABLE) and not self.has_variable(node) and node.creator is None:
                # If "c_var.creator" is None, it's input variable or parameters.

                # NOTE(Kiikurage):
                # In chainer v1, "Variable" doesn't support "__eq__" method, so "list.__contains__" cannot be used for list of variables.
                # However, "Variable.__hash__" is implemented and "set.__contains__" is available.
                self._convert_var(node, constant=node not in input_set)

        # Convert each Chainer function into WebDNN operators
        for c_opr in _listup_functions(inputs, outputs):
            self._convert_operator(c_opr)

        # Build graph
        graph = Graph([self.get_variable(c_var) for c_var in inputs], [self.get_variable(c_var) for c_var in outputs])

        for v in graph.inputs:
            v.attributes.add(Input())

        for v in graph.outputs:
            v.attributes.add(Output())

        return graph

    # The annotation is a string on purpose: `T_VARIABLE` is only bound after Chainer has been imported successfully,
    # and a bare name here would be evaluated when the class body executes.
    def _convert_var(self, c_var: "T_VARIABLE", constant=False):
        """Create and register the WebDNN variable corresponding to the Chainer variable node `c_var`.

        Args:
            c_var: variable node to convert; it must not have been converted yet.
            constant: if True, the node's current data is copied to CPU and wrapped in a :class:`ConstantVariable`
                (this is how parameters are converted); otherwise a plain :class:`Variable` of the same shape is created.

        Returns:
            the newly registered WebDNN variable
        """
        assert not self.has_variable(c_var), f"{c_var} is already converted"

        # Every axis is left unnamed (None) in the order.
        order = Order([None] * len(c_var.shape))

        if constant:
            data = get_variable_data(c_var)
            n_var = ConstantVariable(chainer.cuda.to_cpu(data), order)  # force on CPU
        else:
            n_var = Variable(c_var.shape, order)

        self.set_variable(c_var, n_var)
        return n_var