Skip to content

Commit

Permalink
Merge branch 'main' of github.com:pytorch-labs/torchcodec into metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Jun 25, 2024
2 parents bc235ed + 5717599 commit d4a24fc
Show file tree
Hide file tree
Showing 24 changed files with 226 additions and 65 deletions.
86 changes: 69 additions & 17 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,30 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
return output;
}

// Allocates the (uninitialized) uint8 output tensor for a batch of `numFrames`
// 3-channel frames.
//
// The layout follows options.shape:
//   "NHWC" -> {numFrames, height, width, 3}
//   "NCHW" -> {numFrames, 3, height, width}
// Height and width are each taken from `options` when set there, otherwise
// from `metadata`.
//
// Throws std::runtime_error when options.shape is neither "NHWC" nor "NCHW",
// or when a dimension is available from neither `options` nor `metadata`.
torch::Tensor VideoDecoder::getEmptyTensorForBatch(
    int64_t numFrames,
    const VideoStreamDecoderOptions& options,
    const StreamMetadata& metadata) {
  // Validate the layout first so an unsupported shape is always reported with
  // the same message, regardless of which dimensions are available.
  const bool channelsLast = (options.shape == "NHWC");
  if (!channelsLast && options.shape != "NCHW") {
    // TODO: should this be a TORCH macro of some kind?
    throw std::runtime_error("Unsupported frame shape=" + options.shape);
  }

  // Resolve a dimension lazily. Writing requested.value_or(*fallback) would
  // dereference the fallback even when the requested value is present, which
  // is undefined behavior whenever the fallback is empty.
  const auto resolveDimension = [](const auto& requested,
                                   const auto& fallback,
                                   const char* name) -> int64_t {
    if (requested) {
      return static_cast<int64_t>(*requested);
    }
    if (!fallback) {
      throw std::runtime_error(
          std::string("Cannot determine frame ") + name +
          ": it is set neither in the stream options nor in the stream metadata");
    }
    return static_cast<int64_t>(*fallback);
  };
  const int64_t height =
      resolveDimension(options.height, metadata.height, "height");
  const int64_t width =
      resolveDimension(options.width, metadata.width, "width");

  if (channelsLast) {
    return torch::empty({numFrames, height, width, 3}, {torch::kUInt8});
  }
  return torch::empty({numFrames, 3, height, width}, {torch::kUInt8});
}

VideoDecoder::DecodedOutput VideoDecoder::getFrameDisplayedAtTimestamp(
double seconds) {
for (auto& [streamIndex, stream] : streams_) {
Expand Down Expand Up @@ -778,26 +802,13 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndexes(
throw std::runtime_error(
"Invalid stream index=" + std::to_string(streamIndex));
}

BatchDecodedOutput output;
const auto& streamMetadata = containerMetadata_.streams[streamIndex];
const auto& options = streams_[streamIndex].options;
if (options.shape == "NHWC") {
output.frames = torch::empty(
{(long)frameIndexes.size(),
options.height.value_or(*streamMetadata.height),
options.width.value_or(*streamMetadata.width),
3},
{torch::kUInt8});
} else if (options.shape == "NCHW") {
output.frames = torch::empty(
{(long)frameIndexes.size(),
3,
options.height.value_or(*streamMetadata.height),
options.width.value_or(*streamMetadata.width)},
{torch::kUInt8});
} else {
throw std::runtime_error("Unsupported frame shape=" + options.shape);
}
output.frames =
getEmptyTensorForBatch(frameIndexes.size(), options, streamMetadata);

int i = 0;
if (streams_.count(streamIndex) == 0) {
throw std::runtime_error(
Expand All @@ -817,6 +828,47 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndexes(
return output;
}

// Returns the frames of stream `streamIndex` at indices
//   start, start + step, start + 2 * step, ...   (all strictly below stop)
// stacked into a single batch tensor.
//
// Preconditions, enforced with TORCH_CHECK:
//   - streamIndex is a valid index into the container metadata's streams and
//     has an entry in streams_;
//   - 0 <= start <= stop <= number of entries in the stream's allFrames;
//   - step > 0.
// An empty range (start == stop) yields a batch tensor with zero frames.
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
    int streamIndex,
    int64_t start,
    int64_t stop,
    int64_t step) {
  // Both bounds must hold: with `||` every non-negative index was accepted.
  // The cast is safe because the non-negativity test short-circuits first.
  TORCH_CHECK(
      streamIndex >= 0 &&
          static_cast<size_t>(streamIndex) < containerMetadata_.streams.size(),
      "Invalid stream index=" + std::to_string(streamIndex));
  TORCH_CHECK(
      streams_.count(streamIndex) > 0,
      "Invalid stream index=" + std::to_string(streamIndex));

  const auto& streamMetadata = containerMetadata_.streams[streamIndex];
  const auto& stream = streams_[streamIndex];
  // Compare in the signed domain so a negative `stop` is not silently
  // converted to a huge unsigned value.
  const auto numFrames = static_cast<int64_t>(stream.allFrames.size());
  TORCH_CHECK(
      start >= 0, "Range start, " + std::to_string(start) + " is less than 0.");
  TORCH_CHECK(
      stop <= numFrames,
      "Range stop, " + std::to_string(stop) +
          ", is more than the number of frames, " + std::to_string(numFrames));
  TORCH_CHECK(
      step > 0, "Step must be greater than 0; is " + std::to_string(step));
  // Without this check a reversed range produced a negative frame count and
  // failed later, inside the tensor allocation, with an unrelated message.
  TORCH_CHECK(
      stop >= start,
      "Range stop, " + std::to_string(stop) + ", is less than range start, " +
          std::to_string(start));

  // ceil((stop - start) / step) in exact integer arithmetic; both operands
  // are non-negative here, so truncating division is floor division.
  const int64_t span = stop - start;
  const int64_t numOutputFrames = span / step + (span % step != 0 ? 1 : 0);
  const auto& options = stream.options;
  BatchDecodedOutput output;
  output.frames =
      getEmptyTensorForBatch(numOutputFrames, options, streamMetadata);

  // Iterate over output slots rather than incrementing the frame index by
  // `step`, so the index never advances past `stop` (no overflow for very
  // large steps) and writes always stay within the allocated batch.
  for (int64_t f = 0; f < numOutputFrames; ++f) {
    const int64_t frameIndex = start + f * step;
    const int64_t pts = stream.allFrames[frameIndex].pts;
    // NOTE(review): only timeBase.den is used, so this presumably assumes
    // timeBase.num == 1 -- confirm against how timeBase is populated.
    setCursorPtsInSeconds(1.0 * pts / stream.timeBase.den);
    torch::Tensor frame = getNextDecodedOutput().frame;
    output.frames[f] = frame;
  }

  return output;
}

VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutput() {
return getDecodedOutputWithFilter(
[this](int frameStreamIndex, AVFrame* frame) {
Expand Down
11 changes: 11 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,13 @@ class VideoDecoder {
BatchDecodedOutput getFramesAtIndexes(
int streamIndex,
const std::vector<int64_t>& frameIndexes);
// Returns the frames of the given stream whose indices fall in the half-open
// range [start, stop), stacked into a single Tensor. The frame indices
// retrieved are:
//   start, start+step, start+(2*step), start+(3*step), ...
// stopping before stop is reached.
// The implementation requires start >= 0, stop <= the number of frames in the
// stream, and step > 0. This declaration gives step no default value, so
// callers must always pass one (pass 1 to get consecutive frames).
BatchDecodedOutput
getFramesInRange(int streamIndex, int64_t start, int64_t stop, int64_t step);

// --------------------------------------------------------------------------
// DECODER PERFORMANCE STATISTICS API
Expand Down Expand Up @@ -273,6 +280,10 @@ class VideoDecoder {
DecodedOutput convertAVFrameToDecodedOutput(
int streamIndex,
UniqueAVFrame frame);
// Allocates (with torch::empty) the uint8 output tensor for a batch of
// numFrames 3-channel frames. The dimensions are
// {numFrames, height, width, 3} when options.shape is "NHWC" and
// {numFrames, 3, height, width} when it is "NCHW"; any other shape value
// throws std::runtime_error. Height and width come from options when set
// there, otherwise from metadata.
torch::Tensor getEmptyTensorForBatch(
int64_t numFrames,
const VideoStreamDecoderOptions& options,
const StreamMetadata& metadata);

DecoderOptions options_;
ContainerMetadata containerMetadata_;
Expand Down
27 changes: 21 additions & 6 deletions src/torchcodec/decoders/_core/VideoDecoderOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def("get_next_frame(Tensor(a!) decoder) -> Tensor");
m.def("get_frame_at_pts(Tensor(a!) decoder, float seconds) -> Tensor");
m.def(
"get_frame_at_index(Tensor(a!) decoder, *, int frame_index, int stream_index) -> Tensor");
"get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> Tensor");
m.def(
"get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices, int stream_index) -> Tensor");
"get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> Tensor");
m.def(
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> Tensor");
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
m.def(
Expand Down Expand Up @@ -124,24 +126,36 @@ at::Tensor get_frame_at_pts(at::Tensor& decoder, double seconds) {

at::Tensor get_frame_at_index(
at::Tensor& decoder,
int64_t frame_index,
int64_t stream_index) {
int64_t stream_index,
int64_t frame_index) {
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
auto result = videoDecoder->getFrameAtIndex(stream_index, frame_index);
return result.frame;
}

at::Tensor get_frames_at_indices(
at::Tensor& decoder,
at::IntArrayRef frame_indices,
int64_t stream_index) {
int64_t stream_index,
at::IntArrayRef frame_indices) {
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
std::vector<int64_t> frameIndicesVec(
frame_indices.begin(), frame_indices.end());
auto result = videoDecoder->getFramesAtIndexes(stream_index, frameIndicesVec);
return result.frames;
}

// Op implementation for get_frames_in_range: forwards to
// VideoDecoder::getFramesInRange on the decoder stored behind the `decoder`
// tensor's data pointer and returns the stacked frames tensor.
at::Tensor get_frames_in_range(
    at::Tensor& decoder,
    int64_t stream_index,
    int64_t start,
    int64_t stop,
    std::optional<int64_t> step = std::nullopt) {
  // An omitted step means a stride of 1.
  const int64_t stride = step.has_value() ? *step : 1;
  auto* decoderImpl = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
  return decoderImpl->getFramesInRange(stream_index, start, stop, stride)
      .frames;
}

std::string quoteValue(const std::string& value) {
return "\"" + value + "\"";
}
Expand Down Expand Up @@ -330,6 +344,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
m.impl("get_frame_at_pts", &get_frame_at_pts);
m.impl("get_frame_at_index", &get_frame_at_index);
m.impl("get_frames_at_indices", &get_frames_at_indices);
m.impl("get_frames_in_range", &get_frames_in_range);
}

} // namespace facebook::torchcodec
17 changes: 13 additions & 4 deletions src/torchcodec/decoders/_core/VideoDecoderOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,24 @@ at::Tensor get_frame_at_pts(at::Tensor& decoder, double seconds);
// Return the frame that is visible at a given index in the video.
at::Tensor get_frame_at_index(
at::Tensor& decoder,
int64_t frame_index,
std::optional<int64_t> stream_index = std::nullopt);
int64_t stream_index,
int64_t frame_index);

// Return the frames at a given index for a given stream as a single stacked
// Tensor.
at::Tensor get_frames_at_indices(
at::Tensor& decoder,
at::IntArrayRef frame_indices,
std::optional<int64_t> stream_index = std::nullopt);
int64_t stream_index,
at::IntArrayRef frame_indices);

// Return the frames inside a range as a single stacked Tensor. The range is
// defined as [start, stop): frames at indices start, start+step,
// start+(2*step), ... are returned, stopping before stop.
// stream_index selects the stream to read from. When step is omitted
// (std::nullopt), the implementation uses a step of 1.
at::Tensor get_frames_in_range(
at::Tensor& decoder,
int64_t stream_index,
int64_t start,
int64_t stop,
std::optional<int64_t> step = std::nullopt);

// Get the next frame from the video as a tensor.
at::Tensor get_next_frame(at::Tensor& decoder);
Expand Down
16 changes: 15 additions & 1 deletion src/torchcodec/decoders/_core/video_decoder_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def load_torchcodec_extension():
get_frame_at_pts = torch.ops.torchcodec_ns.get_frame_at_pts.default
get_frame_at_index = torch.ops.torchcodec_ns.get_frame_at_index.default
get_frames_at_indices = torch.ops.torchcodec_ns.get_frames_at_indices.default
get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default
get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default
_get_container_json_metadata = (
torch.ops.torchcodec_ns.get_container_json_metadata.default
Expand Down Expand Up @@ -118,7 +119,7 @@ def get_frame_at_pts_abstract(decoder: torch.Tensor, seconds: float) -> torch.Te

@register_fake("torchcodec_ns::get_frame_at_index")
def get_frame_at_index_abstract(
decoder: torch.Tensor, *, frame_index: int, stream_index: int
decoder: torch.Tensor, *, stream_index: int, frame_index: int
) -> torch.Tensor:
image_size = [get_ctx().new_dynamic_size() for _ in range(3)]
return torch.empty(image_size)
Expand All @@ -128,8 +129,21 @@ def get_frame_at_index_abstract(
def get_frames_at_indices_abstract(
    decoder: torch.Tensor,
    *,
    stream_index: int,
    frame_indices: List[int],
) -> torch.Tensor:
    """Return an empty tensor with four dynamically sized dimensions.

    None of the arguments are inspected; every dimension is a fresh dynamic
    size obtained from ``get_ctx()``.
    """
    dynamic_dims = []
    for _ in range(4):
        dynamic_dims.append(get_ctx().new_dynamic_size())
    return torch.empty(dynamic_dims)


@register_fake("torchcodec_ns::get_frames_in_range")
def get_frames_in_range_abstract(
    decoder: torch.Tensor,
    *,
    stream_index: int,
    start: int,
    stop: int,
    step: Optional[int] = None,
) -> torch.Tensor:
    """Fake implementation registered for ``torchcodec_ns::get_frames_in_range``.

    None of the arguments are inspected; the result is an empty tensor whose
    four dimensions are fresh dynamic sizes obtained from ``get_ctx()``.
    """
    dynamic_dims = []
    for _ in range(4):
        dynamic_dims.append(get_ctx().new_dynamic_size())
    return torch.empty(dynamic_dims)
Expand Down
2 changes: 1 addition & 1 deletion src/torchcodec/samplers/video_clip_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ def _get_clips_for_index_based_sampling(
]
frames = get_frames_at_indices(
video_decoder,
frame_indices=batch_indexes,
stream_index=metadata_json["bestVideoStreamIndex"],
frame_indices=batch_indexes,
)
clips.append(frames)

Expand Down
File renamed without changes.
33 changes: 0 additions & 33 deletions test/decoders/generate_reference_resources.sh

This file was deleted.

Loading

0 comments on commit d4a24fc

Please sign in to comment.