Skip to content

Commit

Permalink
ExprOp.convolution: remove overload bloat and unify return type
Browse files Browse the repository at this point in the history
  • Loading branch information
Ichunjo committed Jan 12, 2025
1 parent 4382494 commit 6fe0158
Showing 1 changed file with 30 additions and 77 deletions.
107 changes: 30 additions & 77 deletions vsexprtools/exprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

from vstools import (
ColorRange, ConvMode, CustomEnum, CustomIndexError, CustomValueError, FuncExceptT,
HoldsVideoFormatT, PlanesT, StrArrOpt, StrList, VideoFormatT, VideoNodeIterable, flatten,
get_lowest_value, get_neutral_value, get_peak_value, vs
HoldsVideoFormatT, PlanesT, StrArrOpt, StrList, VideoFormatT, VideoNodeIterable, copy_signature,
flatten, flatten_vnodes, get_lowest_value, get_neutral_value, get_peak_value, vs
)

from .util import ExprVarRangeT, ExprVars, ExprVarsT, complexpr_available

__all__ = [
'ExprOp', 'ExprToken', 'ExprList'
'ExprOp', 'ExprToken', 'ExprList', 'TupleExprList'
]


Expand Down Expand Up @@ -123,6 +123,19 @@ def __call__(
clips, self, planes, format, opt, boundary, force_akarin, func, split_planes, **kwargs # type: ignore
)

class TupleExprList(tuple[ExprList, ...]):
@copy_signature(ExprList.__call__)
def __call__(self, *clips: VideoNodeIterable, **kwargs: Any) -> vs.VideoNode:
clip: list[vs.VideoNode] | vs.VideoNode = flatten_vnodes(clips)

for exprlist in self:
clip = exprlist(clip, **kwargs)

return clip[0] if isinstance(clip, list) else clip

def __str__(self) -> str:
return str(tuple(str(e) for e in self))


class ExprOpBase(str):
value: str
Expand Down Expand Up @@ -237,33 +250,10 @@ def clamp(

return ExprList([c, min, ExprOp.MAX, max, ExprOp.MIN])

@overload
@classmethod
def matrix(
cls, var: str, radius: int, mode: Literal[ConvMode.HV], exclude: Iterable[tuple[int, int]] | None = ...
) -> tuple[ExprList, ExprList]:
...

@overload
@classmethod
def matrix(
cls, var: str, radius: int,
mode: Literal[ConvMode.HORIZONTAL] | Literal[ConvMode.VERTICAL] | Literal[ConvMode.SQUARE],
exclude: Iterable[tuple[int, int]] | None = ...
) -> ExprList:
...

@overload
@classmethod
def matrix(
cls, var: str, radius: int, mode: ConvMode, exclude: Iterable[tuple[int, int]] | None = ...
) -> tuple[ExprList, ExprList] | ExprList:
...

@classmethod
def matrix(
cls, var: str, radius: int, mode: ConvMode, exclude: Iterable[tuple[int, int]] | None = None
) -> tuple[ExprList, ExprList] | ExprList:
) -> TupleExprList:
exclude = list(exclude) if exclude else list()

match mode:
Expand All @@ -278,60 +268,27 @@ def matrix(
case ConvMode.HORIZONTAL:
coordinates = [(xy, 0) for xy in range(-radius, radius + 1)]
case ConvMode.HV:
return (
cls.matrix(var, radius, ConvMode.VERTICAL, exclude),
cls.matrix(var, radius, ConvMode.HORIZONTAL, exclude),
)
return TupleExprList([
cls.matrix(var, radius, ConvMode.VERTICAL, exclude)[0],
cls.matrix(var, radius, ConvMode.HORIZONTAL, exclude)[0],
])
case _:
raise NotImplementedError

return ExprList([
return TupleExprList([ExprList([
var if x == y == 0 else
ExprOp.REL_PIX(var, x, y)
for (x, y) in coordinates
if (x, y) not in exclude
])

@overload
@classmethod
def convolution(
cls, var: str, matrix: Iterable[SupportsFloat] | Iterable[Iterable[SupportsFloat]],
bias: float | None = None, divisor: float | bool = True, saturate: bool = True,
mode: Literal[ConvMode.HV] = ...,
premultiply: float | int | None = None,
multiply: float | int | None = None, clamp: bool = False
) -> tuple[ExprList, ExprList]:
...

@overload
@classmethod
def convolution(
cls, var: str, matrix: Iterable[SupportsFloat] | Iterable[Iterable[SupportsFloat]],
bias: float | None = None, divisor: float | bool = True, saturate: bool = True,
mode: Literal[ConvMode.HORIZONTAL] | Literal[ConvMode.VERTICAL] | Literal[ConvMode.SQUARE] = ...,
premultiply: float | int | None = None,
multiply: float | int | None = None, clamp: bool = False
) -> ExprList:
...

@overload
@classmethod
def convolution(
cls, var: str, matrix: Iterable[SupportsFloat] | Iterable[Iterable[SupportsFloat]],
bias: float | None = None, divisor: float | bool = True, saturate: bool = True,
mode: ConvMode = ...,
premultiply: float | int | None = None,
multiply: float | int | None = None, clamp: bool = False
) -> tuple[ExprList, ExprList] | ExprList:
...
])])

@classmethod
def convolution(
cls, var: str, matrix: Iterable[SupportsFloat] | Iterable[Iterable[SupportsFloat]],
bias: float | None = None, divisor: float | bool = True, saturate: bool = True,
mode: ConvMode = ConvMode.HV, premultiply: float | int | None = None,
multiply: float | int | None = None, clamp: bool = False
) -> tuple[ExprList, ExprList] | ExprList:
) -> TupleExprList:
convolution = list[float](flatten(matrix)) # type: ignore

if not (conv_len := len(convolution)) % 2:
Expand All @@ -348,17 +305,13 @@ def convolution(

rel_pixels = cls.matrix(var, radius, mode)

def _make_output(rel_pixels: ExprList) -> ExprList:
return ExprList([
output = TupleExprList([
ExprList([
rel_pix if weight == 1 else [rel_pix, weight, ExprOp.MUL]
for rel_pix, weight in zip(rel_pixels, convolution)
for rel_pix, weight in zip(rel_px, convolution)
if weight != 0
])

output = (
[_make_output(rp) for rp in rel_pixels] if isinstance(rel_pixels, tuple) else [_make_output(rel_pixels)]
)

]) for rel_px in rel_pixels
])

for out in output:
out.extend(ExprOp.ADD * out.mlength)
Expand All @@ -385,7 +338,7 @@ def _make_output(rel_pixels: ExprList) -> ExprList:
if clamp:
out.append(ExprOp.clamp(ExprToken.RangeMin, ExprToken.RangeMax))

return output[0] if len(output) == 1 else tuple(output) # type: ignore[return-value]
return output

@staticmethod
def _parse_planes(
Expand Down

0 comments on commit 6fe0158

Please sign in to comment.