import itertools
from typing import Tuple, Sequence, Union
from webdnn.graph.axis import Axis, AxisKeyDict
class Order:
    """Order(axes)

    Descriptor class for representing semantics of data order.

    For example, :obj:`~webdnn.graph.order.OrderNHWC` means that the axes of the data are arranged in the order
    N, H, W, C.

    Popular data orders are already defined in :mod:`webdnn.graph.order`. You should use the pre-defined order
    instances where possible. If you have to define a new data order, you can do so as follows.

    .. code::

        OrderHCNW = Order([Axis.H, Axis.C, Axis.N, Axis.W])

    Args:
        axes(list of :class:`~webdnn.Axis`): list of axis. A ``None`` entry is replaced by a newly created
            :class:`~webdnn.Axis`.

    Raises:
        AssertionError: if two of the given axes compare equal.
    """

    def __init__(self, axes: Sequence[Union[Axis, None]]):
        # Each ``None`` entry is replaced by a newly constructed Axis.
        axes = tuple(Axis() if a is None else a for a in axes)

        # Compare every unordered pair once. The error is raised explicitly instead of through an ``assert``
        # statement so that the check is not stripped under ``python -O``; the exception type is kept as
        # AssertionError so existing callers observe the same failure.
        for a1, a2 in itertools.combinations(axes, 2):
            if a1 != a2:
                continue

            raise AssertionError(
                "\n"
                "[Order] Axes are duplicated:\n"
                f"(axes) = {axes}\n"
            )

        self._axes = axes

    # NOTE: ``__eq__`` is defined below without ``__hash__``, so instances are not hashable.

    @property
    def axes(self) -> Tuple[Axis, ...]:
        """Axes of this order, as a tuple in the order given to the constructor."""
        return self._axes

    @property
    def ndim(self) -> int:
        """Number of axes in this order."""
        return len(self.axes)

    @property
    def axes_dict(self) -> AxisKeyDict[int]:
        """:class:`AxisKeyDict` built from the axes (keys) and their positions ``0 .. ndim - 1`` (values)."""
        return AxisKeyDict(self.axes, range(self.ndim))

    def __eq__(self, other):
        # Two orders are equal only when they hold equal axes in the same positions.
        # Anything that is not an Order compares unequal.
        if isinstance(other, Order):
            return self.axes == other.axes

        return False

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        return f"[{', '.join(axis.name for axis in self.axes)}]"

    def check_same_axes(self, other: "Order") -> bool:
        """
        check_same_axes(order)

        check if 2 orders have same axes (the axis order is not considered)

        Args:
            other: other order

        Returns:
            ``True`` if every axis of each order is also contained in the other order.
        """
        return all(axis in other.axes for axis in self.axes) and all(axis in self.axes for axis in other.axes)

    def get_common_axes(self, other: "Order") -> Sequence[Axis]:
        """
        get_common_axes(order)

        return axes which are included in both of the two orders.

        Args:
            other: other order

        Returns:
            list of the common axes, in the order they appear in this order.
        """
        return [axis for axis in self.axes if axis in other.axes]

    def get_all_axes(self, other: "Order") -> Sequence[Axis]:
        """
        get_all_axes(order)

        return axes which are included in either of the two orders.

        Args:
            other: other order

        Returns:
            list of this order's axes followed by the axes of :code:`other` which are not contained in this order.
        """
        return list(self.axes) + [axis for axis in other.axes if axis not in self.axes]

    def unify(self, other: "Order"):
        """
        unify(order)

        call :code:`unify` on each axis of this order with the axis of :code:`other` at the same position.

        Axes are processed from the first position onwards, so when a :class:`ValueError` is raised for some
        position, :code:`unify` has already been called for all earlier positions.

        Args:
            other: other order

        Raises:
            ValueError: if the numbers of dimensions differ, or if unifying a pair of axes raises
                :class:`ValueError` (the original error is attached as the cause).
        """
        if self.ndim != other.ndim:
            raise ValueError(
                "\n"
                "Unification failed: Number of dimension mismatch\n"
                f"(self.ndim) = {self.ndim}\n"
                f"(other.ndim) = {other.ndim}"
            )

        for i, (axis1, axis2) in enumerate(zip(self.axes, other.axes)):
            try:
                axis1.unify(axis2)

            except ValueError as err:
                # Report which position failed and keep the underlying error as the explicit cause.
                raise ValueError(
                    "\n"
                    f"Unification failed: self.axes[{i}] != other.axes[{i}]\n"
                    f"(self) = {self}\n"
                    f"(other) = {other}"
                ) from err
# Pre-defined orders.
# The comment above each constant records the usage noted for that order.

# Usage: Bias Filter
OrderC = Order((Axis.C,))

# Usage: Fully-Connected Input/Output.
OrderNC = Order((Axis.N, Axis.C))

# Usage: Fully-Connected Filter
OrderCN = Order((Axis.C, Axis.N))

# Usage: Convolution2D Input/Output of WebGPU
OrderNHWC = Order((Axis.N, Axis.H, Axis.W, Axis.C))

# Usage: Convolution2D Filter of WebGPU
OrderHWNC = Order((Axis.H, Axis.W, Axis.N, Axis.C))

# Usage: Fully-Connected Filter when Input variable is 4D.
OrderHWCN = Order((Axis.H, Axis.W, Axis.C, Axis.N))

# Usage: Chainer
OrderNCHW = Order((Axis.N, Axis.C, Axis.H, Axis.W))

# Usage: Chainer Deconvolution2D Filter
OrderCNHW = Order((Axis.C, Axis.N, Axis.H, Axis.W))

# Usage: Chainer Deconvolution2D Filter
# NOTE(review): this is the same usage note as OrderCNHW above; confirm which order the note belongs to.
OrderCHWN = Order((Axis.C, Axis.H, Axis.W, Axis.N))

# No usage note is recorded for the following two orders.
OrderNT = Order((Axis.N, Axis.T))
OrderNTC = Order((Axis.N, Axis.T, Axis.C))