Skip to content

Commit 84b01ce

Browse files
Dmitry Rogozhkin (dvrogozh)
authored and committed
Generalize FilterGraph class to support HW backends
Signed-off-by: Dmitry Rogozhkin <[email protected]>
1 parent c7848ec commit 84b01ce

File tree

4 files changed

+98
-77
lines changed

4 files changed

+98
-77
lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,22 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
6565
// conversion objects as much as possible for performance reasons.
6666
enum AVPixelFormat frameFormat =
6767
static_cast<enum AVPixelFormat>(avFrame->format);
68-
auto frameContext = DecodedFrameContext{
69-
avFrame->width,
70-
avFrame->height,
71-
frameFormat,
72-
avFrame->sample_aspect_ratio,
73-
expectedOutputWidth,
74-
expectedOutputHeight};
68+
FiltersContext filtersContext;
69+
70+
filtersContext.inputWidth = avFrame->width;
71+
filtersContext.inputHeight = avFrame->height;
72+
filtersContext.inputFormat = frameFormat;
73+
filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
74+
filtersContext.outputWidth = expectedOutputWidth;
75+
filtersContext.outputHeight = expectedOutputHeight;
76+
filtersContext.outputFormat = AV_PIX_FMT_RGB24;
77+
filtersContext.timeBase = timeBase;
78+
79+
std::stringstream filters;
80+
filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
81+
filters << ":sws_flags=bilinear";
82+
83+
filtersContext.filters = filters.str();
7584

7685
// By default, we want to use swscale for color conversion because it is
7786
// faster. However, it has width requirements, so we may need to fall back
@@ -95,9 +104,9 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
95104
outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
96105
expectedOutputHeight, expectedOutputWidth, torch::kCPU));
97106

98-
if (!swsContext_ || prevFrameContext_ != frameContext) {
99-
createSwsContext(frameContext, avFrame->colorspace);
100-
prevFrameContext_ = frameContext;
107+
if (!swsContext_ || prevFiltersContext_ != filtersContext) {
108+
createSwsContext(filtersContext, avFrame->colorspace);
109+
prevFiltersContext_ = std::move(filtersContext);
101110
}
102111
int resultHeight =
103112
convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
@@ -113,10 +122,10 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
113122

114123
frameOutput.data = outputTensor;
115124
} else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
116-
if (!filterGraphContext_ || prevFrameContext_ != frameContext) {
117-
filterGraphContext_ = std::make_unique<FilterGraph>(
118-
frameContext, videoStreamOptions, timeBase);
119-
prevFrameContext_ = frameContext;
125+
if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
126+
filterGraphContext_ =
127+
std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
128+
prevFiltersContext_ = std::move(filtersContext);
120129
}
121130
outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
122131

@@ -187,15 +196,15 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
187196
}
188197

189198
void CpuDeviceInterface::createSwsContext(
190-
const DecodedFrameContext& frameContext,
199+
const FiltersContext& filtersContext,
191200
const enum AVColorSpace colorspace) {
192201
SwsContext* swsContext = sws_getContext(
193-
frameContext.decodedWidth,
194-
frameContext.decodedHeight,
195-
frameContext.decodedFormat,
196-
frameContext.expectedWidth,
197-
frameContext.expectedHeight,
198-
AV_PIX_FMT_RGB24,
202+
filtersContext.inputWidth,
203+
filtersContext.inputHeight,
204+
filtersContext.inputFormat,
205+
filtersContext.outputWidth,
206+
filtersContext.outputHeight,
207+
filtersContext.outputFormat,
199208
SWS_BILINEAR,
200209
nullptr,
201210
nullptr,

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,22 +43,17 @@ class CpuDeviceInterface : public DeviceInterface {
4343
const UniqueAVFrame& avFrame);
4444

4545
void createSwsContext(
46-
const DecodedFrameContext& frameContext,
46+
const FiltersContext& filtersContext,
4747
const enum AVColorSpace colorspace);
4848

49-
void createFilterGraph(
50-
const DecodedFrameContext& frameContext,
51-
const VideoStreamOptions& videoStreamOptions,
52-
const AVRational& timeBase);
53-
5449
// color-conversion fields. Only one of FilterGraphContext and
5550
// UniqueSwsContext should be non-null.
5651
std::unique_ptr<FilterGraph> filterGraphContext_;
5752
UniqueSwsContext swsContext_;
5853

5954
// Used to know whether a new FilterGraphContext or UniqueSwsContext should
6055
// be created before decoding a new frame.
61-
DecodedFrameContext prevFrameContext_;
56+
FiltersContext prevFiltersContext_;
6257
};
6358

6459
} // namespace facebook::torchcodec

src/torchcodec/_core/FilterGraph.cpp

Lines changed: 49 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,26 @@ extern "C" {
1313

1414
namespace facebook::torchcodec {
1515

16-
bool DecodedFrameContext::operator==(const DecodedFrameContext& other) {
17-
return decodedWidth == other.decodedWidth &&
18-
decodedHeight == other.decodedHeight &&
19-
decodedFormat == other.decodedFormat &&
20-
expectedWidth == other.expectedWidth &&
21-
expectedHeight == other.expectedHeight;
16+
bool operator==(const AVRational& lhs, const AVRational& rhs) {
17+
return lhs.num == rhs.num && lhs.den == rhs.den;
2218
}
2319

24-
bool DecodedFrameContext::operator!=(const DecodedFrameContext& other) {
20+
bool FiltersContext::operator==(const FiltersContext& other) {
21+
return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
22+
inputFormat == other.inputFormat && outputWidth == other.outputWidth &&
23+
outputHeight == other.outputHeight &&
24+
outputFormat == other.outputFormat && filters == other.filters &&
25+
timeBase == other.timeBase &&
26+
hwFramesCtx.get() == other.hwFramesCtx.get();
27+
}
28+
29+
bool FiltersContext::operator!=(const FiltersContext& other) {
2530
return !(*this == other);
2631
}
2732

2833
FilterGraph::FilterGraph(
29-
const DecodedFrameContext& frameContext,
30-
const VideoStreamOptions& videoStreamOptions,
31-
const AVRational& timeBase) {
34+
const FiltersContext& filtersContext,
35+
const VideoStreamOptions& videoStreamOptions) {
3236
filterGraph_.reset(avfilter_graph_alloc());
3337
TORCH_CHECK(filterGraph_.get() != nullptr);
3438

@@ -39,26 +43,40 @@ FilterGraph::FilterGraph(
3943
const AVFilter* buffersrc = avfilter_get_by_name("buffer");
4044
const AVFilter* buffersink = avfilter_get_by_name("buffersink");
4145

42-
std::stringstream filterArgs;
43-
filterArgs << "video_size=" << frameContext.decodedWidth << "x"
44-
<< frameContext.decodedHeight;
45-
filterArgs << ":pix_fmt=" << frameContext.decodedFormat;
46-
filterArgs << ":time_base=" << timeBase.num << "/" << timeBase.den;
47-
filterArgs << ":pixel_aspect=" << frameContext.decodedAspectRatio.num << "/"
48-
<< frameContext.decodedAspectRatio.den;
49-
50-
int status = avfilter_graph_create_filter(
51-
&sourceContext_,
52-
buffersrc,
53-
"in",
54-
filterArgs.str().c_str(),
55-
nullptr,
56-
filterGraph_.get());
46+
auto deleter = [](AVBufferSrcParameters* p) {
47+
if (p) {
48+
av_freep(&p);
49+
}
50+
};
51+
std::unique_ptr<AVBufferSrcParameters, decltype(deleter)> srcParams(
52+
nullptr, deleter);
53+
54+
srcParams.reset(av_buffersrc_parameters_alloc());
55+
TORCH_CHECK(srcParams, "Failed to allocate buffersrc params");
56+
57+
srcParams->format = filtersContext.inputFormat;
58+
srcParams->width = filtersContext.inputWidth;
59+
srcParams->height = filtersContext.inputHeight;
60+
srcParams->sample_aspect_ratio = filtersContext.inputAspectRatio;
61+
srcParams->time_base = filtersContext.timeBase;
62+
if (filtersContext.hwFramesCtx) {
63+
srcParams->hw_frames_ctx = av_buffer_ref(filtersContext.hwFramesCtx.get());
64+
}
65+
66+
sourceContext_ =
67+
avfilter_graph_alloc_filter(filterGraph_.get(), buffersrc, "in");
68+
TORCH_CHECK(sourceContext_, "Failed to allocate filter graph");
69+
70+
int status = av_buffersrc_parameters_set(sourceContext_, srcParams.get());
5771
TORCH_CHECK(
5872
status >= 0,
5973
"Failed to create filter graph: ",
60-
filterArgs.str(),
61-
": ",
74+
getFFMPEGErrorStringFromErrorCode(status));
75+
76+
status = avfilter_init_str(sourceContext_, nullptr);
77+
TORCH_CHECK(
78+
status >= 0,
79+
"Failed to create filter graph : ",
6280
getFFMPEGErrorStringFromErrorCode(status));
6381

6482
status = avfilter_graph_create_filter(
@@ -68,7 +86,8 @@ FilterGraph::FilterGraph(
6886
"Failed to create filter graph: ",
6987
getFFMPEGErrorStringFromErrorCode(status));
7088

71-
enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
89+
enum AVPixelFormat pix_fmts[] = {
90+
filtersContext.outputFormat, AV_PIX_FMT_NONE};
7291

7392
status = av_opt_set_int_list(
7493
sinkContext_,
@@ -93,16 +112,11 @@ FilterGraph::FilterGraph(
93112
inputs->pad_idx = 0;
94113
inputs->next = nullptr;
95114

96-
std::stringstream description;
97-
description << "scale=" << frameContext.expectedWidth << ":"
98-
<< frameContext.expectedHeight;
99-
description << ":sws_flags=bilinear";
100-
101115
AVFilterInOut* outputsTmp = outputs.release();
102116
AVFilterInOut* inputsTmp = inputs.release();
103117
status = avfilter_graph_parse_ptr(
104118
filterGraph_.get(),
105-
description.str().c_str(),
119+
filtersContext.filters.c_str(),
106120
&inputsTmp,
107121
&outputsTmp,
108122
nullptr);
@@ -128,8 +142,7 @@ UniqueAVFrame FilterGraph::convert(const UniqueAVFrame& avFrame) {
128142
UniqueAVFrame filteredAVFrame(av_frame_alloc());
129143
status = av_buffersink_get_frame(sinkContext_, filteredAVFrame.get());
130144
TORCH_CHECK(
131-
status >= AVSUCCESS, "Failed to fet frame from buffer sink context");
132-
TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
145+
status >= AVSUCCESS, "Failed to get frame from buffer sink context");
133146

134147
return filteredAVFrame;
135148
}

src/torchcodec/_core/FilterGraph.h

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,28 @@
1111

1212
namespace facebook::torchcodec {
1313

14-
struct DecodedFrameContext {
15-
int decodedWidth;
16-
int decodedHeight;
17-
AVPixelFormat decodedFormat;
18-
AVRational decodedAspectRatio;
19-
int expectedWidth;
20-
int expectedHeight;
21-
22-
bool operator==(const DecodedFrameContext&);
23-
bool operator!=(const DecodedFrameContext&);
14+
struct FiltersContext {
15+
int inputWidth = 0;
16+
int inputHeight = 0;
17+
AVPixelFormat inputFormat = AV_PIX_FMT_NONE;
18+
AVRational inputAspectRatio = {0, 0};
19+
int outputWidth = 0;
20+
int outputHeight = 0;
21+
AVPixelFormat outputFormat = AV_PIX_FMT_NONE;
22+
23+
std::string filters;
24+
AVRational timeBase = {0, 0};
25+
UniqueAVBufferRef hwFramesCtx;
26+
27+
bool operator==(const FiltersContext&);
28+
bool operator!=(const FiltersContext&);
2429
};
2530

2631
class FilterGraph {
2732
public:
2833
FilterGraph(
29-
const DecodedFrameContext& frameContext,
30-
const VideoStreamOptions& videoStreamOptions,
31-
const AVRational& timeBase);
34+
const FiltersContext& filtersContext,
35+
const VideoStreamOptions& videoStreamOptions);
3236

3337
UniqueAVFrame convert(const UniqueAVFrame& avFrame);
3438

0 commit comments

Comments (0)