diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt
index 0793c806..03f68f6b 100644
--- a/src/torchcodec/_core/CMakeLists.txt
+++ b/src/torchcodec/_core/CMakeLists.txt
@@ -88,6 +88,7 @@ function(make_torchcodec_libraries
       AVIOContextHolder.cpp
       AVIOTensorContext.cpp
       FFMPEGCommon.cpp
+      FilterGraph.cpp
       Frame.cpp
       DeviceInterface.cpp
       CpuDeviceInterface.cpp
diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp
index 4d0cbddf..ce24f20b 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.cpp
+++ b/src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -6,11 +6,6 @@
 
 #include "src/torchcodec/_core/CpuDeviceInterface.h"
 
-extern "C" {
-#include <libavfilter/buffersink.h>
-#include <libavfilter/buffersrc.h>
-}
-
 namespace facebook::torchcodec {
 namespace {
@@ -20,17 +15,15 @@ static bool g_cpu = registerDeviceInterface(
 
 } // namespace
 
-bool CpuDeviceInterface::DecodedFrameContext::operator==(
-    const CpuDeviceInterface::DecodedFrameContext& other) {
-  return decodedWidth == other.decodedWidth &&
-      decodedHeight == other.decodedHeight &&
-      decodedFormat == other.decodedFormat &&
-      expectedWidth == other.expectedWidth &&
-      expectedHeight == other.expectedHeight;
+bool CpuDeviceInterface::SwsFrameContext::operator==(
+    const CpuDeviceInterface::SwsFrameContext& other) const {
+  return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
+      inputFormat == other.inputFormat && outputWidth == other.outputWidth &&
+      outputHeight == other.outputHeight;
 }
 
-bool CpuDeviceInterface::DecodedFrameContext::operator!=(
-    const CpuDeviceInterface::DecodedFrameContext& other) {
+bool CpuDeviceInterface::SwsFrameContext::operator!=(
+    const CpuDeviceInterface::SwsFrameContext& other) const {
   return !(*this == other);
 }
 
@@ -75,22 +68,8 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
   }
 
   torch::Tensor outputTensor;
-  // We need to compare the current frame context with our previous frame
-  // context. If they are different, then we need to re-create our colorspace
-  // conversion objects. We create our colorspace conversion objects late so
-  // that we don't have to depend on the unreliable metadata in the header.
-  // And we sometimes re-create them because it's possible for frame
-  // resolution to change mid-stream. Finally, we want to reuse the colorspace
-  // conversion objects as much as possible for performance reasons.
   enum AVPixelFormat frameFormat =
       static_cast<AVPixelFormat>(avFrame->format);
-  auto frameContext = DecodedFrameContext{
-      avFrame->width,
-      avFrame->height,
-      frameFormat,
-      avFrame->sample_aspect_ratio,
-      expectedOutputWidth,
-      expectedOutputHeight};
 
   // By default, we want to use swscale for color conversion because it is
   // faster. However, it has width requirements, so we may need to fall back
@@ -111,12 +90,27 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
       videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
 
   if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
+    // We need to compare the current frame context with our previous frame
+    // context. If they are different, then we need to re-create our colorspace
+    // conversion objects. We create our colorspace conversion objects late so
+    // that we don't have to depend on the unreliable metadata in the header.
+    // And we sometimes re-create them because it's possible for frame
+    // resolution to change mid-stream. Finally, we want to reuse the colorspace
+    // conversion objects as much as possible for performance reasons.
+    SwsFrameContext swsFrameContext;
+
+    swsFrameContext.inputWidth = avFrame->width;
+    swsFrameContext.inputHeight = avFrame->height;
+    swsFrameContext.inputFormat = frameFormat;
+    swsFrameContext.outputWidth = expectedOutputWidth;
+    swsFrameContext.outputHeight = expectedOutputHeight;
+
     outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
         expectedOutputHeight, expectedOutputWidth, torch::kCPU));
 
-    if (!swsContext_ || prevFrameContext_ != frameContext) {
-      createSwsContext(frameContext, avFrame->colorspace);
-      prevFrameContext_ = frameContext;
+    if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
+      createSwsContext(swsFrameContext, avFrame->colorspace);
+      prevSwsFrameContext_ = swsFrameContext;
     }
     int resultHeight =
         convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
@@ -132,9 +126,27 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
 
     frameOutput.data = outputTensor;
   } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-    if (!filterGraphContext_.filterGraph || prevFrameContext_ != frameContext) {
-      createFilterGraph(frameContext, videoStreamOptions, timeBase);
-      prevFrameContext_ = frameContext;
+    FiltersContext filtersContext;
+
+    filtersContext.inputWidth = avFrame->width;
+    filtersContext.inputHeight = avFrame->height;
+    filtersContext.inputFormat = frameFormat;
+    filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
+    filtersContext.outputWidth = expectedOutputWidth;
+    filtersContext.outputHeight = expectedOutputHeight;
+    filtersContext.outputFormat = AV_PIX_FMT_RGB24;
+    filtersContext.timeBase = timeBase;
+
+    std::stringstream filters;
+    filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
+    filters << ":sws_flags=bilinear";
+
+    filtersContext.filtergraphStr = filters.str();
+
+    if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
+      filterGraphContext_ =
+          std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
+      prevFiltersContext_ = std::move(filtersContext);
     }
 
     outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
@@ -187,14 +199,8 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
 
 torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
     const UniqueAVFrame& avFrame) {
-  int status = av_buffersrc_write_frame(
-      filterGraphContext_.sourceContext, avFrame.get());
-  TORCH_CHECK(
-      status >= AVSUCCESS, "Failed to add frame to buffer source context");
+  UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);
 
-  UniqueAVFrame filteredAVFrame(av_frame_alloc());
-  status = av_buffersink_get_frame(
-      filterGraphContext_.sinkContext, filteredAVFrame.get());
   TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
 
   auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
@@ -210,117 +216,15 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
       filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
 }
 
-void CpuDeviceInterface::createFilterGraph(
-    const DecodedFrameContext& frameContext,
-    const VideoStreamOptions& videoStreamOptions,
-    const AVRational& timeBase) {
-  filterGraphContext_.filterGraph.reset(avfilter_graph_alloc());
-  TORCH_CHECK(filterGraphContext_.filterGraph.get() != nullptr);
-
-  if (videoStreamOptions.ffmpegThreadCount.has_value()) {
-    filterGraphContext_.filterGraph->nb_threads =
-        videoStreamOptions.ffmpegThreadCount.value();
-  }
-
-  const AVFilter* buffersrc = avfilter_get_by_name("buffer");
-  const AVFilter* buffersink = avfilter_get_by_name("buffersink");
-
-  std::stringstream filterArgs;
-  filterArgs << "video_size=" << frameContext.decodedWidth << "x"
-             << frameContext.decodedHeight;
-  filterArgs << ":pix_fmt=" << frameContext.decodedFormat;
-  filterArgs << ":time_base=" << timeBase.num << "/" << timeBase.den;
-  filterArgs << ":pixel_aspect=" << frameContext.decodedAspectRatio.num << "/"
-             << frameContext.decodedAspectRatio.den;
-
-  int status = avfilter_graph_create_filter(
-      &filterGraphContext_.sourceContext,
-      buffersrc,
-      "in",
-      filterArgs.str().c_str(),
-      nullptr,
-      filterGraphContext_.filterGraph.get());
-  TORCH_CHECK(
-      status >= 0,
-      "Failed to create filter graph: ",
-      filterArgs.str(),
-      ": ",
-      getFFMPEGErrorStringFromErrorCode(status));
-
-  status = avfilter_graph_create_filter(
-      &filterGraphContext_.sinkContext,
-      buffersink,
-      "out",
-      nullptr,
-      nullptr,
-      filterGraphContext_.filterGraph.get());
-  TORCH_CHECK(
-      status >= 0,
-      "Failed to create filter graph: ",
-      getFFMPEGErrorStringFromErrorCode(status));
-
-  enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
-
-  status = av_opt_set_int_list(
-      filterGraphContext_.sinkContext,
-      "pix_fmts",
-      pix_fmts,
-      AV_PIX_FMT_NONE,
-      AV_OPT_SEARCH_CHILDREN);
-  TORCH_CHECK(
-      status >= 0,
-      "Failed to set output pixel formats: ",
-      getFFMPEGErrorStringFromErrorCode(status));
-
-  UniqueAVFilterInOut outputs(avfilter_inout_alloc());
-  UniqueAVFilterInOut inputs(avfilter_inout_alloc());
-
-  outputs->name = av_strdup("in");
-  outputs->filter_ctx = filterGraphContext_.sourceContext;
-  outputs->pad_idx = 0;
-  outputs->next = nullptr;
-  inputs->name = av_strdup("out");
-  inputs->filter_ctx = filterGraphContext_.sinkContext;
-  inputs->pad_idx = 0;
-  inputs->next = nullptr;
-
-  std::stringstream description;
-  description << "scale=" << frameContext.expectedWidth << ":"
-              << frameContext.expectedHeight;
-  description << ":sws_flags=bilinear";
-
-  AVFilterInOut* outputsTmp = outputs.release();
-  AVFilterInOut* inputsTmp = inputs.release();
-  status = avfilter_graph_parse_ptr(
-      filterGraphContext_.filterGraph.get(),
-      description.str().c_str(),
-      &inputsTmp,
-      &outputsTmp,
-      nullptr);
-  outputs.reset(outputsTmp);
-  inputs.reset(inputsTmp);
-  TORCH_CHECK(
-      status >= 0,
-      "Failed to parse filter description: ",
-      getFFMPEGErrorStringFromErrorCode(status));
-
-  status =
-      avfilter_graph_config(filterGraphContext_.filterGraph.get(), nullptr);
-  TORCH_CHECK(
-      status >= 0,
-      "Failed to configure filter graph: ",
-      getFFMPEGErrorStringFromErrorCode(status));
-}
-
 void CpuDeviceInterface::createSwsContext(
-    const DecodedFrameContext& frameContext,
+    const SwsFrameContext& swsFrameContext,
     const enum AVColorSpace colorspace) {
   SwsContext* swsContext = sws_getContext(
-      frameContext.decodedWidth,
-      frameContext.decodedHeight,
-      frameContext.decodedFormat,
-      frameContext.expectedWidth,
-      frameContext.expectedHeight,
+      swsFrameContext.inputWidth,
+      swsFrameContext.inputHeight,
+      swsFrameContext.inputFormat,
+      swsFrameContext.outputWidth,
+      swsFrameContext.outputHeight,
       AV_PIX_FMT_RGB24,
       SWS_BILINEAR,
       nullptr,
diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h
index 404289bd..5d142913 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.h
+++ b/src/torchcodec/_core/CpuDeviceInterface.h
@@ -8,6 +8,7 @@
 
 #include "src/torchcodec/_core/DeviceInterface.h"
 #include "src/torchcodec/_core/FFMPEGCommon.h"
+#include "src/torchcodec/_core/FilterGraph.h"
 
 namespace facebook::torchcodec {
 
@@ -41,40 +42,29 @@ class CpuDeviceInterface : public DeviceInterface {
   torch::Tensor convertAVFrameToTensorUsingFilterGraph(
       const UniqueAVFrame& avFrame);
 
-  struct FilterGraphContext {
-    UniqueAVFilterGraph filterGraph;
-    AVFilterContext* sourceContext = nullptr;
-    AVFilterContext* sinkContext = nullptr;
-  };
-
-  struct DecodedFrameContext {
-    int decodedWidth;
-    int decodedHeight;
-    AVPixelFormat decodedFormat;
-    AVRational decodedAspectRatio;
-    int expectedWidth;
-    int expectedHeight;
-    bool operator==(const DecodedFrameContext&);
-    bool operator!=(const DecodedFrameContext&);
+  struct SwsFrameContext {
+    int inputWidth;
+    int inputHeight;
+    AVPixelFormat inputFormat;
+    int outputWidth;
+    int outputHeight;
+    bool operator==(const SwsFrameContext&) const;
+    bool operator!=(const SwsFrameContext&) const;
   };
 
   void createSwsContext(
-      const DecodedFrameContext& frameContext,
+      const SwsFrameContext& swsFrameContext,
       const enum AVColorSpace colorspace);
 
-  void createFilterGraph(
-      const DecodedFrameContext& frameContext,
-      const VideoStreamOptions& videoStreamOptions,
-      const AVRational& timeBase);
-
   // color-conversion fields. Only one of FilterGraphContext and
   // UniqueSwsContext should be non-null.
-  FilterGraphContext filterGraphContext_;
+  std::unique_ptr<FilterGraph> filterGraphContext_;
   UniqueSwsContext swsContext_;
 
   // Used to know whether a new FilterGraphContext or UniqueSwsContext should
   // be created before decoding a new frame.
-  DecodedFrameContext prevFrameContext_;
+  SwsFrameContext prevSwsFrameContext_;
+  FiltersContext prevFiltersContext_;
 };
 
 } // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h
index e03f8079..b8c9e621 100644
--- a/src/torchcodec/_core/FFMPEGCommon.h
+++ b/src/torchcodec/_core/FFMPEGCommon.h
@@ -13,6 +13,7 @@
 extern "C" {
 #include <libavcodec/avcodec.h>
 #include <libavfilter/avfilter.h>
+#include <libavfilter/buffersrc.h>
 #include <libavformat/avformat.h>
 #include <libavformat/avio.h>
 #include <libavutil/avutil.h>
@@ -41,6 +42,15 @@ struct Deleterp {
   }
 };
 
+template <typename T, typename R, R (*Fn)(void*)>
+struct Deleterv {
+  inline void operator()(T* p) const {
+    if (p) {
+      Fn(&p);
+    }
+  }
+};
+
 template <typename T, typename R, R (*Fn)(T*)>
 struct Deleter {
   inline void operator()(T* p) const {
@@ -78,6 +88,9 @@ using UniqueAVAudioFifo = std::
     unique_ptr<AVAudioFifo, Deleterp<AVAudioFifo, void, av_audio_fifo_free>>;
 using UniqueAVBufferRef =
     std::unique_ptr<AVBufferRef, Deleterp<AVBufferRef, void, av_buffer_unref>>;
+using UniqueAVBufferSrcParameters = std::unique_ptr<
+    AVBufferSrcParameters,
+    Deleterv<AVBufferSrcParameters, void, av_freep>>;
 
 // These 2 classes share the same underlying AVPacket object. They are meant to
 // be used in tandem, like so:
diff --git a/src/torchcodec/_core/FilterGraph.cpp b/src/torchcodec/_core/FilterGraph.cpp
new file mode 100644
index 00000000..f4e53b1b
--- /dev/null
+++ b/src/torchcodec/_core/FilterGraph.cpp
@@ -0,0 +1,142 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include "src/torchcodec/_core/FilterGraph.h"
+
+extern "C" {
+#include <libavfilter/buffersink.h>
+#include <libavfilter/buffersrc.h>
+}
+
+namespace facebook::torchcodec {
+
+bool operator==(const AVRational& lhs, const AVRational& rhs) {
+  return lhs.num == rhs.num && lhs.den == rhs.den;
+}
+
+bool FiltersContext::operator==(const FiltersContext& other) const {
+  return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
+      inputFormat == other.inputFormat && outputWidth == other.outputWidth &&
+      outputHeight == other.outputHeight &&
+      outputFormat == other.outputFormat &&
+      filtergraphStr == other.filtergraphStr && timeBase == other.timeBase &&
+      hwFramesCtx.get() == other.hwFramesCtx.get();
+}
+
+bool FiltersContext::operator!=(const FiltersContext& other) const {
+  return !(*this == other);
+}
+
+FilterGraph::FilterGraph(
+    const FiltersContext& filtersContext,
+    const VideoStreamOptions& videoStreamOptions) {
+  filterGraph_.reset(avfilter_graph_alloc());
+  TORCH_CHECK(filterGraph_.get() != nullptr);
+
+  if (videoStreamOptions.ffmpegThreadCount.has_value()) {
+    filterGraph_->nb_threads = videoStreamOptions.ffmpegThreadCount.value();
+  }
+
+  const AVFilter* buffersrc = avfilter_get_by_name("buffer");
+  const AVFilter* buffersink = avfilter_get_by_name("buffersink");
+
+  UniqueAVBufferSrcParameters srcParams(av_buffersrc_parameters_alloc());
+  TORCH_CHECK(srcParams, "Failed to allocate buffersrc params");
+
+  srcParams->format = filtersContext.inputFormat;
+  srcParams->width = filtersContext.inputWidth;
+  srcParams->height = filtersContext.inputHeight;
+  srcParams->sample_aspect_ratio = filtersContext.inputAspectRatio;
+  srcParams->time_base = filtersContext.timeBase;
+  if (filtersContext.hwFramesCtx) {
+    srcParams->hw_frames_ctx = av_buffer_ref(filtersContext.hwFramesCtx.get());
+  }
+
+  sourceContext_ =
+      avfilter_graph_alloc_filter(filterGraph_.get(), buffersrc, "in");
+  TORCH_CHECK(sourceContext_, "Failed to allocate filter graph");
+
+  int status = av_buffersrc_parameters_set(sourceContext_, srcParams.get());
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to create filter graph: ",
+      getFFMPEGErrorStringFromErrorCode(status));
+
+  status = avfilter_init_str(sourceContext_, nullptr);
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to create filter graph : ",
+      getFFMPEGErrorStringFromErrorCode(status));
+
+  status = avfilter_graph_create_filter(
+      &sinkContext_, buffersink, "out", nullptr, nullptr, filterGraph_.get());
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to create filter graph: ",
+      getFFMPEGErrorStringFromErrorCode(status));
+
+  enum AVPixelFormat pix_fmts[] = {
+      filtersContext.outputFormat, AV_PIX_FMT_NONE};
+
+  status = av_opt_set_int_list(
+      sinkContext_,
+      "pix_fmts",
+      pix_fmts,
+      AV_PIX_FMT_NONE,
+      AV_OPT_SEARCH_CHILDREN);
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to set output pixel formats: ",
+      getFFMPEGErrorStringFromErrorCode(status));
+
+  UniqueAVFilterInOut outputs(avfilter_inout_alloc());
+  UniqueAVFilterInOut inputs(avfilter_inout_alloc());
+
+  outputs->name = av_strdup("in");
+  outputs->filter_ctx = sourceContext_;
+  outputs->pad_idx = 0;
+  outputs->next = nullptr;
+  inputs->name = av_strdup("out");
+  inputs->filter_ctx = sinkContext_;
+  inputs->pad_idx = 0;
+  inputs->next = nullptr;
+
+  AVFilterInOut* outputsTmp = outputs.release();
+  AVFilterInOut* inputsTmp = inputs.release();
+  status = avfilter_graph_parse_ptr(
+      filterGraph_.get(),
+      filtersContext.filtergraphStr.c_str(),
+      &inputsTmp,
+      &outputsTmp,
+      nullptr);
+  outputs.reset(outputsTmp);
+  inputs.reset(inputsTmp);
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to parse filter description: ",
+      getFFMPEGErrorStringFromErrorCode(status));
+
+  status = avfilter_graph_config(filterGraph_.get(), nullptr);
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to configure filter graph: ",
+      getFFMPEGErrorStringFromErrorCode(status));
+}
+
+UniqueAVFrame FilterGraph::convert(const UniqueAVFrame& avFrame) {
+  int status = av_buffersrc_write_frame(sourceContext_, avFrame.get());
+  TORCH_CHECK(
+      status >= AVSUCCESS, "Failed to add frame to buffer source context");
+
+  UniqueAVFrame filteredAVFrame(av_frame_alloc());
+  status = av_buffersink_get_frame(sinkContext_, filteredAVFrame.get());
+  TORCH_CHECK(
+      status >= AVSUCCESS, "Failed to get frame from buffer sink context");
+
+  return filteredAVFrame;
+}
+
+} // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/FilterGraph.h b/src/torchcodec/_core/FilterGraph.h
new file mode 100644
index 00000000..a99507dc
--- /dev/null
+++ b/src/torchcodec/_core/FilterGraph.h
@@ -0,0 +1,45 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include "src/torchcodec/_core/FFMPEGCommon.h"
+#include "src/torchcodec/_core/StreamOptions.h"
+
+namespace facebook::torchcodec {
+
+struct FiltersContext {
+  int inputWidth = 0;
+  int inputHeight = 0;
+  AVPixelFormat inputFormat = AV_PIX_FMT_NONE;
+  AVRational inputAspectRatio = {0, 0};
+  int outputWidth = 0;
+  int outputHeight = 0;
+  AVPixelFormat outputFormat = AV_PIX_FMT_NONE;
+
+  std::string filtergraphStr;
+  AVRational timeBase = {0, 0};
+  UniqueAVBufferRef hwFramesCtx;
+
+  bool operator==(const FiltersContext&) const;
+  bool operator!=(const FiltersContext&) const;
+};
+
+class FilterGraph {
+ public:
+  FilterGraph(
+      const FiltersContext& filtersContext,
+      const VideoStreamOptions& videoStreamOptions);
+
+  UniqueAVFrame convert(const UniqueAVFrame& avFrame);
+
+ private:
+  UniqueAVFilterGraph filterGraph_;
+  AVFilterContext* sourceContext_ = nullptr;
+  AVFilterContext* sinkContext_ = nullptr;
+};
+
+} // namespace facebook::torchcodec