# Source code for webdnn.graph.operators.slice

from typing import Optional, Union

from webdnn.graph.axis import AxisKeyDict
from webdnn.graph.graph import Graph
from webdnn.graph.operator import Operator
from webdnn.graph.operators.attributes.tensorwise import Tensorwise
from webdnn.graph.optimize_rule import OptimizeRule
from webdnn.graph.order import Order
from webdnn.graph.variable import Variable
from webdnn.graph.variables.constant_variable import ConstantVariable


def normalize_slice(s: slice, size: int):
    """Return a slice equivalent to ``s`` whose three fields are all explicit.

    Args:
        s: Slice that may have ``None`` in ``start``, ``stop`` or ``step``.
        size: Length of the axis the slice is applied to.

    Returns:
        A new ``slice(start, stop, step)`` where

        - ``step`` defaults to ``1``;
        - a missing ``start`` becomes ``0`` for a positive step and
          ``size - 1`` otherwise;
        - a missing ``stop`` becomes ``size`` for a positive step and ``-1``
          otherwise;
        - a negative ``start``/``stop`` is shifted by ``size`` exactly once.

    No clamping to the axis bounds is performed; range checks are left to the
    caller.
    """
    stride = s.step
    if stride is None:
        stride = 1

    begin = s.start
    if begin is None:
        begin = 0 if stride > 0 else size - 1
    elif not begin >= 0:
        # Negative bound counts from the end of the axis.
        begin = begin + size

    end = s.stop
    if end is None:
        end = size if stride > 0 else -1
    elif not end >= 0:
        end = end + size

    return slice(begin, end, stride)


class Slice(Operator):
    """Slice(name, indices)

    Slice original variable to sub variable.

    Each entry of ``indices`` maps an axis to one of:

    - a ``slice``: keep the axis and select a range of it,
    - an ``int``: select a single position and remove the axis from the output,
    - ``None``: insert a new axis of size 1 (the axis must not exist in the input).

    Args:
        name (str): Operator name.
        indices (:class:`AxisKeyDict` of slice, int, or None): slicing indices

    Signature
        .. code::

            y = x[indices]

        - **x** - Input variable.
        - **y** - Sliced variable.
    """

    def __init__(self, name: Optional[str], indices: AxisKeyDict[Union[slice, int, None]]):
        """Store ``indices`` as the operator parameter after type-checking each entry.

        Raises:
            TypeError: if any index is not a ``slice``, ``int`` or ``None``.
        """
        super().__init__(name)

        for axis, index in indices.items():
            if not (isinstance(index, (slice, int)) or index is None):
                raise TypeError(f"""
[Slice] Each index must be an instance of slice, int, or NoneType:
    (indices) = {indices}
    (type(indices[{axis.name}])) = {type(index)}
""")

        self.parameters["indices"] = indices

    @staticmethod
    def _sliced_size(index: slice):
        """Return the element count selected by an already-normalized slice.

        The distance between ``start`` and ``stop`` is taken with ``abs()``, so
        a slice whose direction contradicts the sign of its step (for example
        ``3:1:1``) is not reported as empty by this formula.
        """
        return ((abs(index.stop - index.start) - 1) // abs(index.step)) + 1

    def __call__(self, x: Variable):
        """Validate the indices against ``x``, build the output variable and wire it up.

        Args:
            x: Input variable to slice.

        Returns:
            A 1-tuple ``(y,)`` containing the sliced variable.

        Raises:
            ValueError: if an index is out of range, a slice selects zero
                elements, ``None`` is given for an axis that already exists, or
                a non-``None`` index is given for an axis missing from ``x``.
            NotImplementedError: if every index is an ``int`` (single-element
                access).
        """
        # assert index is valid
        for axis, index in self.indices.items():
            if axis in x.order.axes:
                if isinstance(index, slice):
                    size = x.shape_dict[axis]
                    index = normalize_slice(index, size)
                    valid_start = -size <= index.start <= size
                    valid_stop = -size <= index.stop <= size
                    if not valid_start or not valid_stop:
                        raise ValueError(f"""
[Slice] Index {index} in {axis} is out of range:
    (x.order) = {x.order}
    (x.shape) = {x.shape}
    (indices) = {self.indices}
    (indices[{axis.name}]) = {index}
""")

                    # The size formula is never negative (it bottoms out at 0
                    # when start == stop), so the empty case must be caught
                    # with "<= 0"; a strict "< 0" test can never trigger.
                    if self._sliced_size(index) <= 0:
                        raise ValueError(f"""
[Slice] Slice operator doesn't support 0-size output:
    (x.order) = {x.order}
    (x.shape) = {x.shape}
    (indices) = {self.indices}
    (indices[{axis.name}]) = {index}
""")

                elif isinstance(index, int):
                    if not -x.shape_dict[axis] <= index < x.shape_dict[axis]:
                        raise ValueError(f"""
[Slice] Index {index} in {axis} is out of range:
    (x.order) = {x.order}
    (x.shape) = {x.shape}
    (indices) = {self.indices}
    (indices[{axis.name}]) = {index}
    (valid range) = [{-x.shape_dict[axis]}, {x.shape_dict[axis]})
""")

                elif index is None:
                    # None means "insert a new axis", which is impossible when
                    # the axis is already part of the input.
                    raise ValueError(f"""
[Slice] Axis {axis} is already exist:
    (x.order) = {x.order}
    (x.shape) = {x.shape}
    (indices) = {self.indices}
    (indices[{axis.name}]) = {index}
""")

            else:
                if index is not None:
                    raise ValueError(f"""
[Slice] Axis {axis} is not exist in input variable. In this case, index must be "None" (=insert new axis):
    (x.order) = {x.order}
    (x.shape) = {x.shape}
    (indices) = {self.indices}
    (indices[{axis.name}]) = {index}
""")

        if all(isinstance(index, int) for index in self.indices.values()):
            raise NotImplementedError(f"""
[Slice] Accessing to one element is not supported:
    (indices) = {self.indices}
""")

        # Build the output shape in the iteration order of the indices.
        y_shape_dict = AxisKeyDict()
        for axis, index in self.indices.items():
            if isinstance(index, slice):
                index = normalize_slice(index, x.shape_dict[axis])
                y_shape_dict[axis] = self._sliced_size(index)

            elif isinstance(index, int):
                pass  # Remove axis

            elif index is None:
                y_shape_dict[axis] = 1  # Insert axis

        y = Variable(list(y_shape_dict.values()), Order(list(y_shape_dict.keys())))

        for axis in x.order.axes:
            if axis in self.indices:
                index = self.indices[axis]

                if isinstance(index, slice) and index.start is None and index.stop is None and index.step is None:
                    # slice(None, None, None) selects the whole axis, so this
                    # axis is effectively not sliced.
                    self.attributes.add(Tensorwise(axis))

            else:
                # No index was given for this axis, so it is not sliced.
                self.attributes.add(Tensorwise(axis))

        self.append_input("x", x)
        self.append_output("y", y)
        return y,

    @property
    def indices(self) -> AxisKeyDict[Union[slice, int, None]]:
        """The slicing indices passed to the constructor."""
        return self.parameters["indices"]

    def fold_constance(self, graph: Graph):
        """Apply the slice to the constant input's data and replace ``y`` in ``graph``.

        The input's data is indexed with this operator's indices (in the
        input's axis order) followed by one ``None`` per newly inserted axis.
        The result is wrapped in a ``ConstantVariable``, reordered to ``y``'s
        order and substituted for ``y``; finally ``self.remove_all()`` is called.
        """
        x = self.inputs["x"]  # type: ConstantVariable
        y = self.outputs["y"]

        remained_axes_in_x_order = [a for a in x.order.axes if a in y.order.axes]
        new_axes = [a for a in y.order.axes if a not in x.order.axes]

        # The index must be a tuple: NumPy no longer interprets a list of
        # slices/None as a multidimensional index.
        slices = tuple(self.indices[a] for a in x.order.axes) + (None,) * len(new_axes)

        new_y = ConstantVariable(x.data[slices], Order(remained_axes_in_x_order + new_axes))
        new_y.change_order(y.order)
        OptimizeRule.replace_variable(graph, y, new_y)
        self.remove_all()