Source code for webdnn.graph.operators.prod

from typing import Optional

import numpy as np

from webdnn.graph.axis import Axis
from webdnn.graph.graph import Graph
from webdnn.graph.operators.reduce import Reduce
from webdnn.graph.optimize_rule import OptimizeRule
from webdnn.graph.order import Order
from webdnn.graph.variables.constant_variable import ConstantVariable


class Prod(Reduce):
    """Prod(name, axis)

    Return the product of the input tensor's elements along the specified axis.

    Args:
        name (str) : Operator name.
        axis (:obj:`~webdnn.graph.axis.Axis`) : Axis which will be reduced.

    Signature
        .. code::

            y, = op(x)

        - **x** - Input variable.
        - **y** - Output variable.
    """

    def __init__(self, name: Optional[str], axis: Axis):
        # All axis bookkeeping is delegated to the Reduce base class.
        super().__init__(name, axis=axis)

    def fold_constance(self, graph: Graph):
        """Precompute the product and splice the result into ``graph`` as a constant.

        The output variable ``y`` is replaced by a new ``ConstantVariable`` whose
        data is ``np.prod`` of the input data over ``self.axis``. This operator is
        then detached from the graph.

        Args:
            graph: The graph that owns this operator; ``y`` is replaced inside it.
        """
        # NOTE(review): assumes the input is a ConstantVariable (it reads ``.data``);
        # presumably the caller only folds operators with constant inputs.
        src = self.inputs["x"]  # type: ConstantVariable
        dst = self.outputs["y"]

        # np.prod drops the reduced dimension, so the folded constant's order is the
        # input order with ``self.axis`` removed.
        remaining_axes = list(src.order.axes)
        remaining_axes.remove(self.axis)

        reduced_data = np.prod(src.data, axis=src.order.axes_dict[self.axis])
        folded = ConstantVariable(reduced_data, Order(remaining_axes))

        # Match the memory layout the original output variable was declared with.
        folded.change_order(dst.order)

        OptimizeRule.replace_variable(graph, dst, folded)
        self.remove_all()