Source code for webdnn.graph.operators.embedding

from typing import Optional

from webdnn.graph.axis import Axis
from webdnn.graph.operator import Operator
from webdnn.graph.operators.attributes.tensorwise import Tensorwise
from webdnn.graph.order import OrderNTC, OrderNC, OrderNT
from webdnn.graph.variable import Variable


class Embedding(Operator):
    """Embedding(name)

    Word embedding operator.

    Args:
        name (str): Operator name.

    Signature
        .. code::

            y, = op(x, w)

        - **x** - Input variable. It must have exactly 2 axes, :obj:`~webdnn.Axis.N` and :obj:`~webdnn.Axis.T`.
        - **w** - Dictionary variable. It must have exactly :obj:`~webdnn.Axis.N` and :obj:`~webdnn.Axis.C`.
          Its size of :obj:`~webdnn.Axis.C` must be the same as the vocabulary size. Its size of
          :obj:`~webdnn.Axis.N` must be the same as the embedding feature size.
        - **y** - Output variable. Its order is :obj:`~webdnn.graph.order.OrderNTC`.
    """

    def __init__(self, name: Optional[str]):
        super().__init__(name)
        # The operator is declared tensorwise along the batch (N) and sequence (T) axes.
        self.attributes.add(Tensorwise(Axis.N))
        self.attributes.add(Tensorwise(Axis.T))

    def __call__(self, x: Variable, w: Variable):
        """Register inputs ``x`` and ``w`` and create the output variable.

        Args:
            x (Variable): Input variable whose axes must be exactly {N, T}.
            w (Variable): Dictionary variable whose axes must be exactly {N, C}.

        Returns:
            tuple of Variable: A 1-tuple ``(y,)``. ``y`` has order ``OrderNTC`` and shape
            ``[x.N, x.T, w.N]``.

        Raises:
            AssertionError: If ``x`` or ``w`` does not have the required axes.
        """
        x_shape_dict = x.shape_dict
        w_shape_dict = w.shape_dict

        # Report x's own axes here. Previously this message printed w's axes, which was misleading.
        assert x.order.check_same_axes(OrderNT), f"""
[Embedding] Input variable "x" must have only Axis.N and Axis.T:
    (x.order.axes) = {x.order.axes}"""

        assert w.order.check_same_axes(OrderNC), f"""
[Embedding] Dictionary variable "w" must have only Axis.N and Axis.C:
    (w.order.axes) = {w.order.axes}"""

        # The sizes are looked up by axis rather than by position, so the memory order
        # of x and w does not matter here.
        batch_size = x_shape_dict[Axis.N]
        sequence_len = x_shape_dict[Axis.T]
        # By the docstring's convention, w's N axis carries the embedding feature size.
        embedding_dim = w_shape_dict[Axis.N]

        y = Variable([batch_size, sequence_len, embedding_dim], OrderNTC)

        self.append_input("x", x)
        self.append_input("w", w)
        self.append_output("y", y)
        return y,