Commit

2024-11-21 nightly release (fd4c288)
pytorchbot committed Nov 21, 2024
1 parent d67d9c2 commit 6354933
Showing 8 changed files with 505 additions and 479 deletions.
94 changes: 53 additions & 41 deletions README.md
@@ -2,10 +2,10 @@

# TorchCodec

TorchCodec is a Python library for decoding videos into PyTorch tensors. It aims
to be fast, easy to use, and well integrated into the PyTorch ecosystem. If you
want to use PyTorch to train ML models on videos, TorchCodec is how you turn
those videos into data.
TorchCodec is a Python library for decoding videos into PyTorch tensors, on CPU
and CUDA GPU. It aims to be fast, easy to use, and well integrated into the
PyTorch ecosystem. If you want to use PyTorch to train ML models on videos,
TorchCodec is how you turn those videos into data.

We achieve these capabilities through:

@@ -19,21 +19,24 @@ We achieve these capabilities through:
or used directly to train models.

> [!NOTE]
> ⚠️ TorchCodec is still in early development stage and some APIs may be updated
> in future versions without a deprecation cycle, depending on user feedback.
> ⚠️ TorchCodec is still in development stage and some APIs may be updated
> in future versions, depending on user feedback.
> If you have any suggestions or issues, please let us know by
> [opening an issue](https://github.com/pytorch/torchcodec/issues/new/choose)!
## Using TorchCodec

Here's a condensed summary of what you can do with TorchCodec. For a more
detailed example, [check out our
Here's a condensed summary of what you can do with TorchCodec. For more detailed
examples, [check out our
documentation](https://pytorch.org/torchcodec/stable/generated_examples/)!

#### Decoding

```python
from torchcodec.decoders import VideoDecoder

decoder = VideoDecoder("path/to/video.mp4")
device = "cpu" # or e.g. "cuda" !
decoder = VideoDecoder("path/to/video.mp4", device=device)

decoder.metadata
# VideoStreamMetadata:
@@ -44,39 +47,47 @@ decoder.metadata
# average_fps: 25.0
# ... (truncated output)

len(decoder) # == decoder.metadata.num_frames!
# 250
decoder.metadata.average_fps # Note: instantaneous fps can be higher or lower
# 25.0

# Simple Indexing API
decoder[0] # uint8 tensor of shape [C, H, W]
decoder[0 : -1 : 20] # uint8 stacked tensor of shape [N, C, H, W]

# Indexing, with PTS and duration info:
decoder.get_frames_at(indices=[2, 100])
# FrameBatch:
# data (shape): torch.Size([2, 3, 270, 480])
# pts_seconds: tensor([0.0667, 3.3367], dtype=torch.float64)
# duration_seconds: tensor([0.0334, 0.0334], dtype=torch.float64)

# Iterate over frames:
for frame in decoder:
    pass
# Time-based indexing with PTS and duration info
decoder.get_frames_played_at(seconds=[0.5, 10.4])
# FrameBatch:
# data (shape): torch.Size([2, 3, 270, 480])
# pts_seconds: tensor([ 0.4671, 10.3770], dtype=torch.float64)
# duration_seconds: tensor([0.0334, 0.0334], dtype=torch.float64)
```

# Indexing, with PTS and duration info
decoder.get_frame_at(len(decoder) - 1)
# Frame:
# data (shape): torch.Size([3, 400, 640])
# pts_seconds: 9.960000038146973
# duration_seconds: 0.03999999910593033
#### Clip sampling

decoder.get_frames_in_range(start=10, stop=30, step=5)
# FrameBatch:
# data (shape): torch.Size([4, 3, 400, 640])
# pts_seconds: tensor([0.4000, 0.6000, 0.8000, 1.0000])
# duration_seconds: tensor([0.0400, 0.0400, 0.0400, 0.0400])
```python

# Time-based indexing with PTS and duration info
decoder.get_frame_played_at(pts_seconds=2)
# Frame:
# data (shape): torch.Size([3, 400, 640])
# pts_seconds: 2.0
# duration_seconds: 0.03999999910593033
from torchcodec.samplers import clips_at_regular_timestamps

clips_at_regular_timestamps(
    decoder,
    seconds_between_clip_starts=1.5,
    num_frames_per_clip=4,
    seconds_between_frames=0.1
)
# FrameBatch:
# data (shape): torch.Size([9, 4, 3, 270, 480])
# pts_seconds: tensor([[ 0.0000, 0.0667, 0.1668, 0.2669],
# [ 1.4681, 1.5682, 1.6683, 1.7684],
# [ 2.9696, 3.0697, 3.1698, 3.2699],
# ... (truncated), dtype=torch.float64)
# duration_seconds: tensor([[0.0334, 0.0334, 0.0334, 0.0334],
# [0.0334, 0.0334, 0.0334, 0.0334],
# [0.0334, 0.0334, 0.0334, 0.0334],
# ... (truncated), dtype=torch.float64)
```

You can use the following snippet to generate a video with FFmpeg and try out
@@ -142,7 +153,7 @@ format you want. Refer to Nvidia's GPU support matrix for more details
[official instructions](https://pytorch.org/get-started/locally/).

3. Install or compile FFmpeg with NVDEC support.
TorchCodec with CUDA should work with FFmpeg versions in [4, 7].
TorchCodec with CUDA should work with FFmpeg versions in [5, 7].

If FFmpeg is not already installed, or you need a more recent version, an
easy way to install it is to use `conda`:
@@ -172,16 +183,17 @@ format you want. Refer to Nvidia's GPU support matrix for more details
ffmpeg -hwaccel cuda -hwaccel_output_format cuda -i test/resources/nasa_13013.mp4 -f null -
```

4. Install TorchCodec by passing in an `--index-url` parameter that corresponds to your CUDA
Toolkit version, example:
4. Install TorchCodec by passing in an `--index-url` parameter that corresponds
to your CUDA Toolkit version, example:

```bash
# This corresponds to CUDA Toolkit version 12.4 and nightly Pytorch.
pip install torchcodec --index-url=https://download.pytorch.org/whl/nightly/cu124
# This corresponds to CUDA Toolkit version 12.4. It should be the same one
# you used when you installed PyTorch (If you installed PyTorch with pip).
pip install torchcodec --index-url=https://download.pytorch.org/whl/cu124
```

Note that without passing in the `--index-url` parameter, `pip` installs TorchCodec
binaries from PyPi which are CPU-only and do not have CUDA support.
Note that without passing in the `--index-url` parameter, `pip` installs
the CPU-only version of TorchCodec.
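
As a rough illustration of how the `--index-url` value relates to the CUDA Toolkit version, the sketch below builds the URL from a `"major.minor"` string. The helper name `torchcodec_index_url` and its `nightly` flag are hypothetical and not part of TorchCodec; only the two URL shapes (`.../whl/cu124` and `.../whl/nightly/cu124`) come from the README text. It assumes the usual `cu<major><minor>` suffix and does not check that wheels are actually published for a given version.

```python
def torchcodec_index_url(cuda_version: str, nightly: bool = False) -> str:
    """Build a PyTorch wheel index URL such as .../whl/cu124 from "12.4".

    Hypothetical helper for illustration; expects a "major.minor" string.
    """
    major, minor = cuda_version.split(".")[:2]
    channel = "nightly/" if nightly else ""
    return f"https://download.pytorch.org/whl/{channel}cu{major}{minor}"


url = torchcodec_index_url("12.4")
print(f"pip install torchcodec --index-url={url}")
```

The CUDA version passed here should match the one used for the installed PyTorch build, as the README comment notes.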

## Benchmark Results

20 changes: 13 additions & 7 deletions benchmarks/decoders/benchmark_decoders_library.py
@@ -258,9 +258,7 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):

class TorchCodecPublic(AbstractDecoder):
    def __init__(self, num_ffmpeg_threads=None, device="cpu"):
        self._num_ffmpeg_threads = (
            int(num_ffmpeg_threads) if num_ffmpeg_threads else None
        )
        self._num_ffmpeg_threads = int(num_ffmpeg_threads) if num_ffmpeg_threads else 1
        self._device = device

        from torchvision.transforms import v2 as transforms_v2
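
The hunk above changes the fallback for `num_ffmpeg_threads` from `None` to `1`. A minimal sketch of that expression in isolation follows, to make its edge cases visible. The function name `resolve_num_ffmpeg_threads` is hypothetical and does not exist in the benchmark; only the conditional expression is taken from the diff.

```python
from typing import Optional, Union


def resolve_num_ffmpeg_threads(
    num_ffmpeg_threads: Optional[Union[int, str]] = None,
) -> int:
    """Same expression as the new line in the diff: falsy values become 1."""
    return int(num_ffmpeg_threads) if num_ffmpeg_threads else 1


print(resolve_num_ffmpeg_threads())     # unset: falls back to 1
print(resolve_num_ffmpeg_threads("4"))  # strings are converted with int()
print(resolve_num_ffmpeg_threads(0))    # integer 0 is falsy: falls back to 1
```

Note that the string `"0"` is truthy, so it is converted to the integer `0` rather than replaced by `1`.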
@@ -446,11 +444,11 @@ def retrieve_videos(urls_and_dest_paths):
urllib.request.urlretrieve(url, path)


def plot_data(df_data, plot_path):
def plot_data(json_data, plot_path):
    plt.rcParams["font.size"] = 18

    # Creating the DataFrame
    df = pd.DataFrame(df_data)
    df = pd.DataFrame(json_data["experiments"])

    # Sorting by video, type, and frame_count
    df_sorted = df.sort_values(by=["video", "type", "frame_count"])
@@ -520,6 +518,14 @@ def plot_data(df_data, plot_path):
for col in range(video_type_combinations[unique_videos[row]], max_combinations):
fig.delaxes(axes[row, col])

plt.gcf().text(
0.005,
0.87,
"\n".join([f"{k}: {v}" for k, v in json_data["system_metadata"].items()]),
fontsize=11,
bbox=dict(facecolor="white"),
)

# Adjust layout to avoid overlap
plt.tight_layout()

@@ -628,7 +634,7 @@ def run_benchmarks(
},
label=f"video={video_file_path} {metadata_label}",
sub_label=decoder_name,
description=f"dataloader[threads={bp.num_threads} batch_size={bp.batch_size}] {num_samples} decode_and_transform()",
description=f"dataloader[threads={bp.num_threads},batch_size={bp.batch_size}] {num_samples} decode_and_transform()",
)
results.append(
dataloader_result.blocked_autorange(
@@ -673,7 +679,7 @@ def run_benchmarks(
decoder_name,
video_file_path,
num_samples,
f"{kind} seek()+next()",
f"{num_samples} x {kind} seek()+next()",
)
)
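
With this change, `plot_data` takes a dictionary with an `"experiments"` entry (fed to `pd.DataFrame`) and a `"system_metadata"` entry (rendered as a text box on the figure). The sketch below shows one plausible shape for that input and the string the text box would display. Only the two top-level keys and the column names `video`, `type`, and `frame_count` appear in the diff; all values, the metadata field names, and the helper `format_system_metadata` are made up for illustration.

```python
# Hypothetical payload: the values and the metadata field names are invented.
json_data = {
    "experiments": [
        {"video": "a.mp4", "type": "seek()+next()", "frame_count": 10},
    ],
    "system_metadata": {"cpu_count": 8, "python_version": "3.11"},
}


def format_system_metadata(metadata: dict) -> str:
    """Join key/value pairs one per line, like the join in the diff."""
    return "\n".join(f"{k}: {v}" for k, v in metadata.items())


label = format_system_metadata(json_data["system_metadata"])
print(label)
```

In the benchmark itself this string is passed to `plt.gcf().text(...)`; the sketch leaves out matplotlib so it runs without plotting dependencies.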

Binary file modified benchmarks/decoders/benchmark_readme_chart.png
