Skip to content

Commit

Permalink
Added Python side support, extended tests.
Browse files (browse the repository at this point in the history)
  • Loading branch information
scotts committed Dec 20, 2024
1 parent 97ac764 commit 35f2e59
Show file tree
Hide file tree
Showing 11 changed files with 335 additions and 186 deletions.
36 changes: 17 additions & 19 deletions benchmarks/decoders/benchmark_decoders_library.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -22,7 +22,6 @@
get_frames_by_pts,
get_json_metadata,
get_next_frame,
scan_all_streams_to_update_metadata,
seek_to_pts,
)

Expand Down Expand Up @@ -154,8 +153,7 @@ def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"
self._device = device

def decode_frames(self, video_file, pts_list):
decoder = create_from_file(video_file)
scan_all_streams_to_update_metadata(decoder)
decoder = create_from_file(video_file, seek_mode="exact")
_add_video_stream(
decoder,
num_threads=self._num_threads,
Expand All @@ -170,7 +168,7 @@ def decode_frames(self, video_file, pts_list):
return frames

def decode_first_n_frames(self, video_file, n):
decoder = create_from_file(video_file)
decoder = create_from_file(video_file, seek_mode="approximate")
_add_video_stream(
decoder,
num_threads=self._num_threads,
Expand All @@ -197,7 +195,7 @@ def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"
self.transforms_v2 = transforms_v2

def decode_frames(self, video_file, pts_list):
decoder = create_from_file(video_file)
decoder = create_from_file(video_file, seek_mode="approximate")
num_threads = int(self._num_threads) if self._num_threads else 0
_add_video_stream(
decoder,
Expand All @@ -216,7 +214,7 @@ def decode_frames(self, video_file, pts_list):

def decode_first_n_frames(self, video_file, n):
num_threads = int(self._num_threads) if self._num_threads else 0
decoder = create_from_file(video_file)
decoder = create_from_file(video_file, seek_mode="approximate")
_add_video_stream(
decoder,
num_threads=num_threads,
Expand All @@ -233,7 +231,7 @@ def decode_first_n_frames(self, video_file, n):

def decode_and_resize(self, video_file, pts_list, height, width, device):
num_threads = int(self._num_threads) if self._num_threads else 1
decoder = create_from_file(video_file)
decoder = create_from_file(video_file, seek_mode="approximate")
_add_video_stream(
decoder,
num_threads=num_threads,
Expand Down Expand Up @@ -263,8 +261,7 @@ def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"
self._device = device

def decode_frames(self, video_file, pts_list):
decoder = create_from_file(video_file)
scan_all_streams_to_update_metadata(decoder)
decoder = create_from_file(video_file, seek_mode="exact")
_add_video_stream(
decoder,
num_threads=self._num_threads,
Expand All @@ -279,8 +276,7 @@ def decode_frames(self, video_file, pts_list):
return frames

def decode_first_n_frames(self, video_file, n):
decoder = create_from_file(video_file)
scan_all_streams_to_update_metadata(decoder)
decoder = create_from_file(video_file, seek_mode="exact")
_add_video_stream(
decoder,
num_threads=self._num_threads,
Expand All @@ -297,9 +293,10 @@ def decode_first_n_frames(self, video_file, n):


class TorchCodecPublic(AbstractDecoder):
def __init__(self, num_ffmpeg_threads=None, device="cpu"):
def __init__(self, num_ffmpeg_threads=None, device="cpu", seek_mode="exact"):
self._num_ffmpeg_threads = num_ffmpeg_threads
self._device = device
self._seek_mode = seek_mode

from torchvision.transforms import v2 as transforms_v2

Expand All @@ -310,7 +307,7 @@ def decode_frames(self, video_file, pts_list):
int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 0
)
decoder = VideoDecoder(
video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device
video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device, seek_mode=self._seek_mode
)
return decoder.get_frames_played_at(pts_list)

Expand All @@ -319,7 +316,7 @@ def decode_first_n_frames(self, video_file, n):
int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 0
)
decoder = VideoDecoder(
video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device
video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device, seek_mode=self._seek_mode
)
frames = []
count = 0
Expand All @@ -335,17 +332,18 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):
int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 1
)
decoder = VideoDecoder(
video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device
video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device, seek_mode=self._seek_mode
)
frames = decoder.get_frames_played_at(pts_list)
frames = self.transforms_v2.functional.resize(frames.data, (height, width))
return frames


class TorchCodecPublicNonBatch(AbstractDecoder):
def __init__(self, num_ffmpeg_threads=None, device="cpu"):
def __init__(self, num_ffmpeg_threads=None, device="cpu", seek_mode="exact"):
self._num_ffmpeg_threads = num_ffmpeg_threads
self._device = device
self._seek_mode = seek_mode

from torchvision.transforms import v2 as transforms_v2

Expand All @@ -356,7 +354,7 @@ def decode_frames(self, video_file, pts_list):
int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 0
)
decoder = VideoDecoder(
video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device
video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device, seek_mode=self._seek_mode
)

frames = []
Expand All @@ -370,7 +368,7 @@ def decode_first_n_frames(self, video_file, n):
int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 0
)
decoder = VideoDecoder(
video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device
video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device, seek_mode=self._seek_mode
)
frames = []
count = 0
Expand All @@ -386,7 +384,7 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):
int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 1
)
decoder = VideoDecoder(
video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device
video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device, seek_mode=self._seek_mode
)

frames = []
Expand Down
Loading

0 comments on commit 35f2e59

Please sign in to comment.