Source code for webdnn.graph.operators.linear

from typing import Optional

from webdnn.graph.axis import Axis
from webdnn.graph.operator import Operator
from webdnn.graph.order import OrderNC
from webdnn.graph.placeholder import Placeholder
from webdnn.graph.variable import Variable


class Linear(Operator):
    """Linear(name)

    Linear operator, also known as a Dense or Fully-Connected operator.

    This operator computes the dot product of the input and weight variables.
    All axes except :obj:`~webdnn.Axis.N` are used in the calculation. If
    :code:`x` is :obj:`~webdnn.graph.order.OrderNC` and :code:`w` is
    :obj:`~webdnn.graph.order.OrderCN`, this operation is the same as
    :code:`np.tensordot(x, w, (1, 0))` in NumPy.

    4D variables are also acceptable as input and weight. If :code:`x` is
    :obj:`~webdnn.graph.order.OrderNCHW` and :code:`w` is
    :obj:`~webdnn.graph.order.OrderNHWC`, this operation is the same as
    :code:`np.tensordot(x, w, ((1, 2, 3), (3, 1, 2)))` in NumPy.

    The input and weight variables must have the same number of dimensions;
    otherwise, an assertion error is raised.

    Args:
        name (str): Operator name.

    Signature
        .. code::

            y, = op(x, w)

        - **x** - Input variable.
        - **w** - Weight variable.
        - **y** - Output variable. Its order is :obj:`~webdnn.graph.order.OrderNC`.
          Its :obj:`~webdnn.Axis.N` size is the same as :code:`x.shape_dict[Axis.N]`,
          and its :obj:`~webdnn.Axis.C` size is the same as :code:`w.shape_dict[Axis.N]`.
    """

    def __init__(self, name: Optional[str]):
        super().__init__(name)

    def __call__(self, x: Variable, w: Variable):
        # The channel axes of input and weight must agree (checked only when
        # both sizes are already resolved, i.e. not symbolic placeholders).
        if Placeholder.check_resolved(x.shape_dict[Axis.C]) and Placeholder.check_resolved(w.shape_dict[Axis.C]):
            assert x.shape_dict[Axis.C] == w.shape_dict[Axis.C], \
                "Input and Weight variables of Linear operator must have the same shape except for Axis.N: " \
                f"x.shape_dict[Axis.C]={x.shape_dict[Axis.C]}, " \
                f"w.shape_dict[Axis.C]={w.shape_dict[Axis.C]}"

        assert x.ndim == w.ndim, \
            "Input and Weight variables of Linear operator must have the same number of dimensions: " \
            f"x.ndim={x.ndim}, w.ndim={w.ndim}"

        if x.ndim == 4:
            # For 4D variables, the spatial axes must also agree.
            if Placeholder.check_resolved(x.shape_dict[Axis.H]) and Placeholder.check_resolved(w.shape_dict[Axis.H]):
                assert x.shape_dict[Axis.H] == w.shape_dict[Axis.H], \
                    "Input and Weight variables of Linear operator must have the same shape except for Axis.N: " \
                    f"x.shape_dict[Axis.H]={x.shape_dict[Axis.H]}, " \
                    f"w.shape_dict[Axis.H]={w.shape_dict[Axis.H]}"

            if Placeholder.check_resolved(x.shape_dict[Axis.W]) and Placeholder.check_resolved(w.shape_dict[Axis.W]):
                assert x.shape_dict[Axis.W] == w.shape_dict[Axis.W], \
                    "Input and Weight variables of Linear operator must have the same shape except for Axis.N: " \
                    f"x.shape_dict[Axis.W]={x.shape_dict[Axis.W]}, " \
                    f"w.shape_dict[Axis.W]={w.shape_dict[Axis.W]}"

        # Output has shape (x's batch size, w's output-channel size) in OrderNC.
        y = Variable([x.shape_dict[Axis.N], w.shape_dict[Axis.N]], OrderNC)
        self.append_input("x", x)
        self.append_input("w", w)
        self.append_output("y", y)
        return y,
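
For reference, the sketch below contrasts the NumPy :code:`tensordot` forms named in the docstring with the graph-level call signature. It is a minimal illustration, not part of the library source: the operator name "fc1" and all concrete sizes are made up, and the graph-level imports assume the module paths shown on this page and in the docstring.

    import numpy as np

    # NumPy view of the 2D case: x is OrderNC (batch, channels),
    # w is OrderCN (channels, output channels).
    x_data = np.random.rand(32, 128).astype(np.float32)
    w_data = np.random.rand(128, 64).astype(np.float32)
    y_data = np.tensordot(x_data, w_data, (1, 0))
    assert y_data.shape == (32, 64)

    # NumPy view of the 4D case: x is OrderNCHW, w is OrderNHWC.
    # Every axis except N is contracted, pairing C with C, H with H, W with W.
    x4 = np.random.rand(32, 16, 4, 4).astype(np.float32)
    w4 = np.random.rand(64, 4, 4, 16).astype(np.float32)
    y4 = np.tensordot(x4, w4, ((1, 2, 3), (3, 1, 2)))
    assert y4.shape == (32, 64)

    # Graph-level equivalent of the 2D case (illustrative sizes and name):
    from webdnn.graph.axis import Axis
    from webdnn.graph.operators.linear import Linear
    from webdnn.graph.order import OrderCN, OrderNC
    from webdnn.graph.variable import Variable

    x = Variable([32, 128], OrderNC)  # 32 samples, 128 input channels
    w = Variable([128, 64], OrderCN)  # 128 input channels, 64 output channels
    y, = Linear("fc1")(x, w)

    print(y.order)                # OrderNC
    print(y.shape_dict[Axis.N])   # 32, taken from x
    print(y.shape_dict[Axis.C])   # 64, taken from w's Axis.N

Note that the weight's :obj:`~webdnn.Axis.N` plays the role of the output channel count, which is why the output's :obj:`~webdnn.Axis.C` size equals :code:`w.shape_dict[Axis.N]`.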