Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/torchcodec/_core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ function(make_torchcodec_libraries
AVIOContextHolder.cpp
AVIOTensorContext.cpp
FFMPEGCommon.cpp
FilterGraph.cpp
Frame.cpp
DeviceInterface.cpp
CpuDeviceInterface.cpp
Expand Down
179 changes: 31 additions & 148 deletions src/torchcodec/_core/CpuDeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@

#include "src/torchcodec/_core/CpuDeviceInterface.h"

extern "C" {
#include <libavfilter/buffersink.h>
#include <libavfilter/buffersrc.h>
}

namespace facebook::torchcodec {
namespace {

Expand All @@ -20,20 +15,6 @@ static bool g_cpu = registerDeviceInterface(

} // namespace

// Returns true when every parameter that influences the color-conversion
// objects is identical in both contexts.
//
// The pixel aspect ratio is compared as well: it is passed to the filter
// graph's buffer source ("pixel_aspect"), so a cached graph built for a
// different aspect ratio must not be reused. The ratio is compared field by
// field; two equivalent but non-normalized ratios (e.g. 1/1 and 2/2) are
// treated as different, which only costs a rebuild and is always safe.
bool CpuDeviceInterface::DecodedFrameContext::operator==(
    const CpuDeviceInterface::DecodedFrameContext& other) {
  return decodedWidth == other.decodedWidth &&
      decodedHeight == other.decodedHeight &&
      decodedFormat == other.decodedFormat &&
      decodedAspectRatio.num == other.decodedAspectRatio.num &&
      decodedAspectRatio.den == other.decodedAspectRatio.den &&
      expectedWidth == other.expectedWidth &&
      expectedHeight == other.expectedHeight;
}

// Inequality is defined purely as the negation of operator==, so the two
// operators can never disagree.
bool CpuDeviceInterface::DecodedFrameContext::operator!=(
    const CpuDeviceInterface::DecodedFrameContext& other) {
  const bool sameContext = (*this == other);
  return !sameContext;
}

CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
: DeviceInterface(device) {
TORCH_CHECK(g_cpu, "CpuDeviceInterface was not registered!");
Expand Down Expand Up @@ -84,13 +65,22 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
// conversion objects as much as possible for performance reasons.
enum AVPixelFormat frameFormat =
static_cast<enum AVPixelFormat>(avFrame->format);
auto frameContext = DecodedFrameContext{
avFrame->width,
avFrame->height,
frameFormat,
avFrame->sample_aspect_ratio,
expectedOutputWidth,
expectedOutputHeight};
FiltersContext filtersContext;

filtersContext.inputWidth = avFrame->width;
filtersContext.inputHeight = avFrame->height;
filtersContext.inputFormat = frameFormat;
filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
filtersContext.outputWidth = expectedOutputWidth;
filtersContext.outputHeight = expectedOutputHeight;
filtersContext.outputFormat = AV_PIX_FMT_RGB24;
filtersContext.timeBase = timeBase;

std::stringstream filters;
filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
filters << ":sws_flags=bilinear";

filtersContext.filters = filters.str();

// By default, we want to use swscale for color conversion because it is
// faster. However, it has width requirements, so we may need to fall back
Expand All @@ -114,9 +104,9 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
expectedOutputHeight, expectedOutputWidth, torch::kCPU));

if (!swsContext_ || prevFrameContext_ != frameContext) {
createSwsContext(frameContext, avFrame->colorspace);
prevFrameContext_ = frameContext;
if (!swsContext_ || prevFiltersContext_ != filtersContext) {
createSwsContext(filtersContext, avFrame->colorspace);
prevFiltersContext_ = std::move(filtersContext);
}
int resultHeight =
convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
Expand All @@ -132,9 +122,10 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(

frameOutput.data = outputTensor;
} else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
if (!filterGraphContext_.filterGraph || prevFrameContext_ != frameContext) {
createFilterGraph(frameContext, videoStreamOptions, timeBase);
prevFrameContext_ = frameContext;
if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
filterGraphContext_ =
std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
prevFiltersContext_ = std::move(filtersContext);
}
outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);

Expand Down Expand Up @@ -187,14 +178,8 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(

torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
const UniqueAVFrame& avFrame) {
int status = av_buffersrc_write_frame(
filterGraphContext_.sourceContext, avFrame.get());
TORCH_CHECK(
status >= AVSUCCESS, "Failed to add frame to buffer source context");
UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);

UniqueAVFrame filteredAVFrame(av_frame_alloc());
status = av_buffersink_get_frame(
filterGraphContext_.sinkContext, filteredAVFrame.get());
TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);

auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
Expand All @@ -210,118 +195,16 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
}

// Builds the libavfilter graph stored in filterGraphContext_.
//
// The resulting graph is:
//   buffer source ("in") -> scale=<expectedWidth>:<expectedHeight>,
//   sws_flags=bilinear -> buffersink ("out")
// with the sink restricted to AV_PIX_FMT_RGB24 output.
//
// frameContext:       decoded size/format/aspect ratio used to configure the
//                     source, and expected size used for the scale filter.
// videoStreamOptions: only ffmpegThreadCount is read here.
// timeBase:           time base advertised to the buffer source.
//
// Every FFmpeg call that returns a status code is checked with TORCH_CHECK.
// NOTE(review): the results of avfilter_get_by_name, avfilter_inout_alloc and
// av_strdup are used without a null check.
void CpuDeviceInterface::createFilterGraph(
const DecodedFrameContext& frameContext,
const VideoStreamOptions& videoStreamOptions,
const AVRational& timeBase) {
// Allocate a fresh graph and hand it to the owning member (presumably a
// unique_ptr-style wrapper, so a previous graph is released — TODO confirm).
filterGraphContext_.filterGraph.reset(avfilter_graph_alloc());
TORCH_CHECK(filterGraphContext_.filterGraph.get() != nullptr);

// Honor the user-requested thread count, if one was given.
if (videoStreamOptions.ffmpegThreadCount.has_value()) {
filterGraphContext_.filterGraph->nb_threads =
videoStreamOptions.ffmpegThreadCount.value();
}

const AVFilter* buffersrc = avfilter_get_by_name("buffer");
const AVFilter* buffersink = avfilter_get_by_name("buffersink");

// Option string describing the frames that will be pushed into the source.
std::stringstream filterArgs;
filterArgs << "video_size=" << frameContext.decodedWidth << "x"
<< frameContext.decodedHeight;
filterArgs << ":pix_fmt=" << frameContext.decodedFormat;
filterArgs << ":time_base=" << timeBase.num << "/" << timeBase.den;
filterArgs << ":pixel_aspect=" << frameContext.decodedAspectRatio.num << "/"
<< frameContext.decodedAspectRatio.den;

// Create the source endpoint ("in") inside the graph.
int status = avfilter_graph_create_filter(
&filterGraphContext_.sourceContext,
buffersrc,
"in",
filterArgs.str().c_str(),
nullptr,
filterGraphContext_.filterGraph.get());
TORCH_CHECK(
status >= 0,
"Failed to create filter graph: ",
filterArgs.str(),
": ",
getFFMPEGErrorStringFromErrorCode(status));

// Create the sink endpoint ("out") inside the graph; it takes no options.
status = avfilter_graph_create_filter(
&filterGraphContext_.sinkContext,
buffersink,
"out",
nullptr,
nullptr,
filterGraphContext_.filterGraph.get());
TORCH_CHECK(
status >= 0,
"Failed to create filter graph: ",
getFFMPEGErrorStringFromErrorCode(status));

// Restrict the sink to RGB24; the list is terminated by AV_PIX_FMT_NONE.
enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};

status = av_opt_set_int_list(
filterGraphContext_.sinkContext,
"pix_fmts",
pix_fmts,
AV_PIX_FMT_NONE,
AV_OPT_SEARCH_CHILDREN);
TORCH_CHECK(
status >= 0,
"Failed to set output pixel formats: ",
getFFMPEGErrorStringFromErrorCode(status));

// Describe the open pads of the graph for the parser: "outputs" is attached
// to the source filter (label "in") and "inputs" to the sink filter (label
// "out").
UniqueAVFilterInOut outputs(avfilter_inout_alloc());
UniqueAVFilterInOut inputs(avfilter_inout_alloc());

outputs->name = av_strdup("in");
outputs->filter_ctx = filterGraphContext_.sourceContext;
outputs->pad_idx = 0;
outputs->next = nullptr;
inputs->name = av_strdup("out");
inputs->filter_ctx = filterGraphContext_.sinkContext;
inputs->pad_idx = 0;
inputs->next = nullptr;

// Textual description of the filter chain placed between source and sink.
std::stringstream description;
description << "scale=" << frameContext.expectedWidth << ":"
<< frameContext.expectedHeight;
description << ":sws_flags=bilinear";

// avfilter_graph_parse_ptr takes the addresses of the raw pointers, so
// ownership is temporarily released from the smart pointers and re-adopted
// afterwards with whatever values the call left behind (presumably the call
// may update or free the lists — TODO confirm against the FFmpeg docs).
AVFilterInOut* outputsTmp = outputs.release();
AVFilterInOut* inputsTmp = inputs.release();
status = avfilter_graph_parse_ptr(
filterGraphContext_.filterGraph.get(),
description.str().c_str(),
&inputsTmp,
&outputsTmp,
nullptr);
outputs.reset(outputsTmp);
inputs.reset(inputsTmp);
TORCH_CHECK(
status >= 0,
"Failed to parse filter description: ",
getFFMPEGErrorStringFromErrorCode(status));

// Final step: validate and configure all links in the graph.
status =
avfilter_graph_config(filterGraphContext_.filterGraph.get(), nullptr);
TORCH_CHECK(
status >= 0,
"Failed to configure filter graph: ",
getFFMPEGErrorStringFromErrorCode(status));
}

void CpuDeviceInterface::createSwsContext(
const DecodedFrameContext& frameContext,
const FiltersContext& filtersContext,
const enum AVColorSpace colorspace) {
SwsContext* swsContext = sws_getContext(
frameContext.decodedWidth,
frameContext.decodedHeight,
frameContext.decodedFormat,
frameContext.expectedWidth,
frameContext.expectedHeight,
AV_PIX_FMT_RGB24,
filtersContext.inputWidth,
filtersContext.inputHeight,
filtersContext.inputFormat,
filtersContext.outputWidth,
filtersContext.outputHeight,
filtersContext.outputFormat,
SWS_BILINEAR,
nullptr,
nullptr,
Expand Down
29 changes: 4 additions & 25 deletions src/torchcodec/_core/CpuDeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "src/torchcodec/_core/DeviceInterface.h"
#include "src/torchcodec/_core/FFMPEGCommon.h"
#include "src/torchcodec/_core/FilterGraph.h"

namespace facebook::torchcodec {

Expand Down Expand Up @@ -41,40 +42,18 @@ class CpuDeviceInterface : public DeviceInterface {
torch::Tensor convertAVFrameToTensorUsingFilterGraph(
const UniqueAVFrame& avFrame);

// Bundles the libavfilter graph with its two endpoints.
struct FilterGraphContext {
// Owning handle to the filter graph.
UniqueAVFilterGraph filterGraph;
// Buffer source ("in") that decoded frames are written to. Raw pointer;
// presumably owned by filterGraph — TODO confirm.
AVFilterContext* sourceContext = nullptr;
// Buffer sink ("out") that converted frames are read from. Raw pointer;
// presumably owned by filterGraph — TODO confirm.
AVFilterContext* sinkContext = nullptr;
};

// Parameters of a decoded frame together with the requested output size.
// Two contexts are compared to decide whether the cached swscale context or
// filter graph can be reused for the next frame.
//
// NOTE: instances are created with aggregate initialization, so the order of
// the fields below must not change.
struct DecodedFrameContext {
// Width of the decoded frame (avFrame->width at the call site).
int decodedWidth;
// Height of the decoded frame (avFrame->height at the call site).
int decodedHeight;
// Pixel format of the decoded frame.
AVPixelFormat decodedFormat;
// Sample aspect ratio of the decoded frame.
AVRational decodedAspectRatio;
// Requested output width.
int expectedWidth;
// Requested output height.
int expectedHeight;
// Comparison operators; defined in CpuDeviceInterface.cpp. They are declared
// as non-const member functions.
bool operator==(const DecodedFrameContext&);
bool operator!=(const DecodedFrameContext&);
};

void createSwsContext(
const DecodedFrameContext& frameContext,
const FiltersContext& filtersContext,
const enum AVColorSpace colorspace);

void createFilterGraph(
const DecodedFrameContext& frameContext,
const VideoStreamOptions& videoStreamOptions,
const AVRational& timeBase);

// Color-conversion fields. Only one of filterGraphContext_ and swsContext_
// should be non-null.
FilterGraphContext filterGraphContext_;
std::unique_ptr<FilterGraph> filterGraphContext_;
UniqueSwsContext swsContext_;

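// Parameters used for the previously converted frame. They are compared with
// the current frame's parameters to decide whether the filter graph or the
// swscale context must be recreated before converting the frame.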
DecodedFrameContext prevFrameContext_;
FiltersContext prevFiltersContext_;
};

} // namespace facebook::torchcodec
Loading
Loading