Skip to content

Commit

Permalink
2024-11-20 nightly release (c8f5deb)
Browse files Browse the repository at this point in the history
  • Loading branch information
pytorchbot committed Nov 20, 2024
1 parent c9c6cf2 commit d67d9c2
Show file tree
Hide file tree
Showing 6 changed files with 408 additions and 129 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ format you want. Refer to Nvidia's GPU support matrix for more details

## Benchmark Results

The following was generated by running [our benchmark script](./benchmarks/decoders/generate_readme_data.py) on a lightly loaded 56-core machine.
The following was generated by running [our benchmark script](./benchmarks/decoders/generate_readme_data.py) on a lightly loaded 22-core machine with an Nvidia A100 with
5 [NVDEC decoders](https://docs.nvidia.com/video-technologies/video-codec-sdk/12.1/nvdec-application-note/index.html#).

![benchmark_results](./benchmarks/decoders/benchmark_readme_chart.png)

Expand Down
99 changes: 97 additions & 2 deletions benchmarks/decoders/benchmark_decoders_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ def __init__(self):
def get_frames_from_video(self, video_file, pts_list):
pass

@abc.abstractmethod
def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
    """Decode frames from ``video_file``; subclasses must override this.

    Args:
        video_file: The video to decode.
        numFramesToDecode: Number of frames the implementation should decode
            (presumably consecutive frames from the start of the video --
            confirm against the concrete decoders).

    The abstract version does nothing and returns ``None``.
    """

@abc.abstractmethod
def decode_and_transform(self, video_file, pts_list, height, width, device):
    """Decode the frames at ``pts_list`` and resize them; subclasses override.

    Args:
        video_file: The video to decode.
        pts_list: Timestamps of the frames to decode.
        height: Target height of the resized frames.
        width: Target width of the resized frames.
        device: Device on which the resize is expected to run.

    The abstract version does nothing and returns ``None``.
    """


class DecordAccurate(AbstractDecoder):
def __init__(self):
Expand Down Expand Up @@ -89,8 +97,10 @@ def __init__(self, backend):
self._backend = backend
self._print_each_iteration_time = False
import torchvision # noqa: F401
from torchvision.transforms import v2 as transforms_v2

self.torchvision = torchvision
self.transforms_v2 = transforms_v2

def get_frames_from_video(self, video_file, pts_list):
self.torchvision.set_video_backend(self._backend)
Expand All @@ -111,6 +121,20 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
frames.append(frame["data"].permute(1, 2, 0))
return frames

def decode_and_transform(self, video_file, pts_list, height, width, device):
    """Decode one frame per timestamp and resize each of them on ``device``.

    Args:
        video_file: Path handed to ``torchvision.io.VideoReader``.
        pts_list: Timestamps to seek to; one frame is decoded after each seek.
        height: Target height passed to ``resize``.
        width: Target width passed to ``resize``.
        device: Device every frame is moved to before it is resized.

    Returns:
        A list with one resized frame per entry of ``pts_list``, kept in the
        layout produced by the reader.
    """
    self.torchvision.set_video_backend(self._backend)
    reader = self.torchvision.io.VideoReader(video_file, "video")
    frames = []
    for pts in pts_list:
        reader.seek(pts)
        # Keep the reader's layout: resize() works on the trailing two
        # dimensions ([..., H, W]), so permuting to H x W x C beforehand
        # would resize the width/channel axes instead of the image.
        frames.append(next(reader)["data"])
    resize = self.transforms_v2.functional.resize
    return [resize(frame.to(device), (height, width)) for frame in frames]


class TorchCodecCore(AbstractDecoder):
def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"):
Expand Down Expand Up @@ -239,6 +263,10 @@ def __init__(self, num_ffmpeg_threads=None, device="cpu"):
)
self._device = device

from torchvision.transforms import v2 as transforms_v2

self.transforms_v2 = transforms_v2

def get_frames_from_video(self, video_file, pts_list):
decoder = VideoDecoder(
video_file, num_ffmpeg_threads=self._num_ffmpeg_threads, device=self._device
Expand All @@ -258,6 +286,14 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
break
return frames

def decode_and_transform(self, video_file, pts_list, height, width, device):
    """Decode the frames played at ``pts_list`` and resize them on ``device``.

    Args:
        video_file: Video handed to ``VideoDecoder``.
        pts_list: Timestamps passed to ``get_frames_played_at``.
        height: Target height passed to ``resize``.
        width: Target width passed to ``resize``.
        device: Device the decoded batch is moved to before it is resized.

    Returns:
        The result of resizing the decoded batch to ``(height, width)``.
    """
    decoder = VideoDecoder(
        video_file, num_ffmpeg_threads=self._num_ffmpeg_threads, device=self._device
    )
    batch = decoder.get_frames_played_at(pts_list)
    # Honour the requested resize device like the other decoders do; the
    # argument used to be ignored, so the resize always ran on the decode
    # device. Tensor.to() is a no-op when the data already lives on `device`.
    on_device = batch.data.to(device)
    return self.transforms_v2.functional.resize(on_device, (height, width))


@torch.compile(fullgraph=True, backend="eager")
def compiled_seek_and_next(decoder, pts):
Expand Down Expand Up @@ -299,7 +335,9 @@ def __init__(self):

self.torchaudio = torchaudio

pass
from torchvision.transforms import v2 as transforms_v2

self.transforms_v2 = transforms_v2

def get_frames_from_video(self, video_file, pts_list):
stream_reader = self.torchaudio.io.StreamReader(src=video_file)
Expand All @@ -325,6 +363,21 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):

return frames

def decode_and_transform(self, video_file, pts_list, height, width, device):
    """Decode one frame per timestamp, then resize each frame on ``device``.

    Args:
        video_file: Source handed to ``torchaudio.io.StreamReader``.
        pts_list: Timestamps to seek to; one single-frame chunk is read per entry.
        height: Target height passed to ``resize``.
        width: Target width passed to ``resize``.
        device: Device every frame is moved to before it is resized.

    Returns:
        A list with one resized frame per entry of ``pts_list``.
    """
    reader = self.torchaudio.io.StreamReader(src=video_file)
    reader.add_basic_video_stream(frames_per_chunk=1)

    def frame_at(timestamp):
        # Seek, refill the buffer, and take the only frame of the first chunk.
        reader.seek(timestamp)
        reader.fill_buffer()
        chunks = reader.pop_chunks()
        return chunks[0][0]

    # All frames are decoded first; the transfer + resize happens afterwards.
    decoded = [frame_at(timestamp) for timestamp in pts_list]

    target_size = (height, width)
    resized = []
    for frame in decoded:
        resized.append(
            self.transforms_v2.functional.resize(frame.to(device), target_size)
        )
    return resized


def create_torchcodec_decoder_from_file(video_file):
video_decoder = create_from_file(video_file)
Expand Down Expand Up @@ -443,7 +496,7 @@ def plot_data(df_data, plot_path):

# Set the title for the subplot
base_video = Path(video).name.removesuffix(".mp4")
ax.set_title(f"{base_video}\n{vcount} x {vtype}", fontsize=11)
ax.set_title(f"{base_video}\n{vtype}", fontsize=11)

# Plot bars with error bars
ax.barh(
Expand Down Expand Up @@ -486,6 +539,14 @@ class BatchParameters:
batch_size: int


@dataclass
class DataLoaderInspiredWorkloadParameters:
    """Settings for the batched decode-and-resize benchmark workload.

    ``run_benchmarks`` forwards the resize fields to each decoder's
    ``decode_and_transform`` and uses ``batch_parameters`` for the batched run.
    """

    # Thread count and batch size used for the batched run.
    batch_parameters: BatchParameters
    # Target frame height handed to decode_and_transform().
    resize_height: int
    # Target frame width handed to decode_and_transform().
    resize_width: int
    # Device argument handed to decode_and_transform() for the resize step.
    resize_device: str


def run_batch_using_threads(
function,
*args,
Expand Down Expand Up @@ -525,6 +586,7 @@ def run_benchmarks(
num_sequential_frames_from_start: list[int],
min_runtime_seconds: float,
benchmark_video_creation: bool,
dataloader_parameters: DataLoaderInspiredWorkloadParameters = None,
batch_parameters: BatchParameters = None,
) -> list[dict[str, str | float | int]]:
# Ensure that we have the same seed across benchmark runs.
Expand All @@ -550,6 +612,39 @@ def run_benchmarks(
for decoder_name, decoder in decoder_dict.items():
print(f"video={video_file_path}, decoder={decoder_name}")

if dataloader_parameters:
bp = dataloader_parameters.batch_parameters
dataloader_result = benchmark.Timer(
stmt="run_batch_using_threads(decoder.decode_and_transform, video_file, pts_list, height, width, device, batch_parameters=batch_parameters)",
globals={
"video_file": str(video_file_path),
"pts_list": uniform_pts_list,
"decoder": decoder,
"run_batch_using_threads": run_batch_using_threads,
"batch_parameters": dataloader_parameters.batch_parameters,
"height": dataloader_parameters.resize_height,
"width": dataloader_parameters.resize_width,
"device": dataloader_parameters.resize_device,
},
label=f"video={video_file_path} {metadata_label}",
sub_label=decoder_name,
description=f"dataloader[threads={bp.num_threads} batch_size={bp.batch_size}] {num_samples} decode_and_transform()",
)
results.append(
dataloader_result.blocked_autorange(
min_run_time=min_runtime_seconds
)
)
df_data.append(
convert_result_to_df_item(
results[-1],
decoder_name,
video_file_path,
num_samples * dataloader_parameters.batch_parameters.batch_size,
f"dataloader[threads={bp.num_threads} batch_size={bp.batch_size}] {num_samples} x decode_and_transform()",
)
)

for kind, pts_list in [
("uniform", uniform_pts_list),
("random", random_pts_list),
Expand Down
Binary file modified benchmarks/decoders/benchmark_readme_chart.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit d67d9c2

Please sign in to comment.