Skip to content

Commit c7848ec

Browse files
Dmitry Rogozhkin (dvrogozh)
authored and committed
Move filter graph to a standalone class
FFmpeg filter graphs can cover many use cases, including both CPU and GPU usage. This commit moves filter graph support out of the CPU device interface into a standalone class, which allows it to be reused flexibly in other contexts. Signed-off-by: Dmitry Rogozhkin <[email protected]>
1 parent 6dc8d12 commit c7848ec

File tree

5 files changed

+185
-148
lines changed

5 files changed

+185
-148
lines changed

src/torchcodec/_core/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ function(make_torchcodec_libraries
8888
AVIOContextHolder.cpp
8989
AVIOTensorContext.cpp
9090
FFMPEGCommon.cpp
91+
FilterGraph.cpp
9192
Frame.cpp
9293
DeviceInterface.cpp
9394
CpuDeviceInterface.cpp

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 4 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,6 @@
66

77
#include "src/torchcodec/_core/CpuDeviceInterface.h"
88

9-
extern "C" {
10-
#include <libavfilter/buffersink.h>
11-
#include <libavfilter/buffersrc.h>
12-
}
13-
149
namespace facebook::torchcodec {
1510
namespace {
1611

@@ -20,20 +15,6 @@ static bool g_cpu = registerDeviceInterface(
2015

2116
} // namespace
2217

23-
bool CpuDeviceInterface::DecodedFrameContext::operator==(
24-
const CpuDeviceInterface::DecodedFrameContext& other) {
25-
return decodedWidth == other.decodedWidth &&
26-
decodedHeight == other.decodedHeight &&
27-
decodedFormat == other.decodedFormat &&
28-
expectedWidth == other.expectedWidth &&
29-
expectedHeight == other.expectedHeight;
30-
}
31-
32-
bool CpuDeviceInterface::DecodedFrameContext::operator!=(
33-
const CpuDeviceInterface::DecodedFrameContext& other) {
34-
return !(*this == other);
35-
}
36-
3718
CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
3819
: DeviceInterface(device) {
3920
TORCH_CHECK(g_cpu, "CpuDeviceInterface was not registered!");
@@ -132,8 +113,9 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
132113

133114
frameOutput.data = outputTensor;
134115
} else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
135-
if (!filterGraphContext_.filterGraph || prevFrameContext_ != frameContext) {
136-
createFilterGraph(frameContext, videoStreamOptions, timeBase);
116+
if (!filterGraphContext_ || prevFrameContext_ != frameContext) {
117+
filterGraphContext_ = std::make_unique<FilterGraph>(
118+
frameContext, videoStreamOptions, timeBase);
137119
prevFrameContext_ = frameContext;
138120
}
139121
outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
@@ -187,14 +169,8 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
187169

188170
torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
189171
const UniqueAVFrame& avFrame) {
190-
int status = av_buffersrc_write_frame(
191-
filterGraphContext_.sourceContext, avFrame.get());
192-
TORCH_CHECK(
193-
status >= AVSUCCESS, "Failed to add frame to buffer source context");
172+
UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);
194173

195-
UniqueAVFrame filteredAVFrame(av_frame_alloc());
196-
status = av_buffersink_get_frame(
197-
filterGraphContext_.sinkContext, filteredAVFrame.get());
198174
TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
199175

200176
auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
@@ -210,108 +186,6 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
210186
filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
211187
}
212188

213-
void CpuDeviceInterface::createFilterGraph(
214-
const DecodedFrameContext& frameContext,
215-
const VideoStreamOptions& videoStreamOptions,
216-
const AVRational& timeBase) {
217-
filterGraphContext_.filterGraph.reset(avfilter_graph_alloc());
218-
TORCH_CHECK(filterGraphContext_.filterGraph.get() != nullptr);
219-
220-
if (videoStreamOptions.ffmpegThreadCount.has_value()) {
221-
filterGraphContext_.filterGraph->nb_threads =
222-
videoStreamOptions.ffmpegThreadCount.value();
223-
}
224-
225-
const AVFilter* buffersrc = avfilter_get_by_name("buffer");
226-
const AVFilter* buffersink = avfilter_get_by_name("buffersink");
227-
228-
std::stringstream filterArgs;
229-
filterArgs << "video_size=" << frameContext.decodedWidth << "x"
230-
<< frameContext.decodedHeight;
231-
filterArgs << ":pix_fmt=" << frameContext.decodedFormat;
232-
filterArgs << ":time_base=" << timeBase.num << "/" << timeBase.den;
233-
filterArgs << ":pixel_aspect=" << frameContext.decodedAspectRatio.num << "/"
234-
<< frameContext.decodedAspectRatio.den;
235-
236-
int status = avfilter_graph_create_filter(
237-
&filterGraphContext_.sourceContext,
238-
buffersrc,
239-
"in",
240-
filterArgs.str().c_str(),
241-
nullptr,
242-
filterGraphContext_.filterGraph.get());
243-
TORCH_CHECK(
244-
status >= 0,
245-
"Failed to create filter graph: ",
246-
filterArgs.str(),
247-
": ",
248-
getFFMPEGErrorStringFromErrorCode(status));
249-
250-
status = avfilter_graph_create_filter(
251-
&filterGraphContext_.sinkContext,
252-
buffersink,
253-
"out",
254-
nullptr,
255-
nullptr,
256-
filterGraphContext_.filterGraph.get());
257-
TORCH_CHECK(
258-
status >= 0,
259-
"Failed to create filter graph: ",
260-
getFFMPEGErrorStringFromErrorCode(status));
261-
262-
enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
263-
264-
status = av_opt_set_int_list(
265-
filterGraphContext_.sinkContext,
266-
"pix_fmts",
267-
pix_fmts,
268-
AV_PIX_FMT_NONE,
269-
AV_OPT_SEARCH_CHILDREN);
270-
TORCH_CHECK(
271-
status >= 0,
272-
"Failed to set output pixel formats: ",
273-
getFFMPEGErrorStringFromErrorCode(status));
274-
275-
UniqueAVFilterInOut outputs(avfilter_inout_alloc());
276-
UniqueAVFilterInOut inputs(avfilter_inout_alloc());
277-
278-
outputs->name = av_strdup("in");
279-
outputs->filter_ctx = filterGraphContext_.sourceContext;
280-
outputs->pad_idx = 0;
281-
outputs->next = nullptr;
282-
inputs->name = av_strdup("out");
283-
inputs->filter_ctx = filterGraphContext_.sinkContext;
284-
inputs->pad_idx = 0;
285-
inputs->next = nullptr;
286-
287-
std::stringstream description;
288-
description << "scale=" << frameContext.expectedWidth << ":"
289-
<< frameContext.expectedHeight;
290-
description << ":sws_flags=bilinear";
291-
292-
AVFilterInOut* outputsTmp = outputs.release();
293-
AVFilterInOut* inputsTmp = inputs.release();
294-
status = avfilter_graph_parse_ptr(
295-
filterGraphContext_.filterGraph.get(),
296-
description.str().c_str(),
297-
&inputsTmp,
298-
&outputsTmp,
299-
nullptr);
300-
outputs.reset(outputsTmp);
301-
inputs.reset(inputsTmp);
302-
TORCH_CHECK(
303-
status >= 0,
304-
"Failed to parse filter description: ",
305-
getFFMPEGErrorStringFromErrorCode(status));
306-
307-
status =
308-
avfilter_graph_config(filterGraphContext_.filterGraph.get(), nullptr);
309-
TORCH_CHECK(
310-
status >= 0,
311-
"Failed to configure filter graph: ",
312-
getFFMPEGErrorStringFromErrorCode(status));
313-
}
314-
315189
void CpuDeviceInterface::createSwsContext(
316190
const DecodedFrameContext& frameContext,
317191
const enum AVColorSpace colorspace) {

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "src/torchcodec/_core/DeviceInterface.h"
1010
#include "src/torchcodec/_core/FFMPEGCommon.h"
11+
#include "src/torchcodec/_core/FilterGraph.h"
1112

1213
namespace facebook::torchcodec {
1314

@@ -41,23 +42,6 @@ class CpuDeviceInterface : public DeviceInterface {
4142
torch::Tensor convertAVFrameToTensorUsingFilterGraph(
4243
const UniqueAVFrame& avFrame);
4344

44-
struct FilterGraphContext {
45-
UniqueAVFilterGraph filterGraph;
46-
AVFilterContext* sourceContext = nullptr;
47-
AVFilterContext* sinkContext = nullptr;
48-
};
49-
50-
struct DecodedFrameContext {
51-
int decodedWidth;
52-
int decodedHeight;
53-
AVPixelFormat decodedFormat;
54-
AVRational decodedAspectRatio;
55-
int expectedWidth;
56-
int expectedHeight;
57-
bool operator==(const DecodedFrameContext&);
58-
bool operator!=(const DecodedFrameContext&);
59-
};
60-
6145
void createSwsContext(
6246
const DecodedFrameContext& frameContext,
6347
const enum AVColorSpace colorspace);
@@ -69,7 +53,7 @@ class CpuDeviceInterface : public DeviceInterface {
6953

7054
// color-conversion fields. Only one of FilterGraphContext and
7155
// UniqueSwsContext should be non-null.
72-
FilterGraphContext filterGraphContext_;
56+
std::unique_ptr<FilterGraph> filterGraphContext_;
7357
UniqueSwsContext swsContext_;
7458

7559
// Used to know whether a new FilterGraphContext or UniqueSwsContext should

src/torchcodec/_core/FilterGraph.cpp

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#include "src/torchcodec/_core/FilterGraph.h"
8+
9+
extern "C" {
10+
#include <libavfilter/buffersink.h>
11+
#include <libavfilter/buffersrc.h>
12+
}
13+
14+
namespace facebook::torchcodec {
15+
16+
bool DecodedFrameContext::operator==(const DecodedFrameContext& other) {
17+
return decodedWidth == other.decodedWidth &&
18+
decodedHeight == other.decodedHeight &&
19+
decodedFormat == other.decodedFormat &&
20+
expectedWidth == other.expectedWidth &&
21+
expectedHeight == other.expectedHeight;
22+
}
23+
24+
bool DecodedFrameContext::operator!=(const DecodedFrameContext& other) {
25+
return !(*this == other);
26+
}
27+
28+
FilterGraph::FilterGraph(
29+
const DecodedFrameContext& frameContext,
30+
const VideoStreamOptions& videoStreamOptions,
31+
const AVRational& timeBase) {
32+
filterGraph_.reset(avfilter_graph_alloc());
33+
TORCH_CHECK(filterGraph_.get() != nullptr);
34+
35+
if (videoStreamOptions.ffmpegThreadCount.has_value()) {
36+
filterGraph_->nb_threads = videoStreamOptions.ffmpegThreadCount.value();
37+
}
38+
39+
const AVFilter* buffersrc = avfilter_get_by_name("buffer");
40+
const AVFilter* buffersink = avfilter_get_by_name("buffersink");
41+
42+
std::stringstream filterArgs;
43+
filterArgs << "video_size=" << frameContext.decodedWidth << "x"
44+
<< frameContext.decodedHeight;
45+
filterArgs << ":pix_fmt=" << frameContext.decodedFormat;
46+
filterArgs << ":time_base=" << timeBase.num << "/" << timeBase.den;
47+
filterArgs << ":pixel_aspect=" << frameContext.decodedAspectRatio.num << "/"
48+
<< frameContext.decodedAspectRatio.den;
49+
50+
int status = avfilter_graph_create_filter(
51+
&sourceContext_,
52+
buffersrc,
53+
"in",
54+
filterArgs.str().c_str(),
55+
nullptr,
56+
filterGraph_.get());
57+
TORCH_CHECK(
58+
status >= 0,
59+
"Failed to create filter graph: ",
60+
filterArgs.str(),
61+
": ",
62+
getFFMPEGErrorStringFromErrorCode(status));
63+
64+
status = avfilter_graph_create_filter(
65+
&sinkContext_, buffersink, "out", nullptr, nullptr, filterGraph_.get());
66+
TORCH_CHECK(
67+
status >= 0,
68+
"Failed to create filter graph: ",
69+
getFFMPEGErrorStringFromErrorCode(status));
70+
71+
enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
72+
73+
status = av_opt_set_int_list(
74+
sinkContext_,
75+
"pix_fmts",
76+
pix_fmts,
77+
AV_PIX_FMT_NONE,
78+
AV_OPT_SEARCH_CHILDREN);
79+
TORCH_CHECK(
80+
status >= 0,
81+
"Failed to set output pixel formats: ",
82+
getFFMPEGErrorStringFromErrorCode(status));
83+
84+
UniqueAVFilterInOut outputs(avfilter_inout_alloc());
85+
UniqueAVFilterInOut inputs(avfilter_inout_alloc());
86+
87+
outputs->name = av_strdup("in");
88+
outputs->filter_ctx = sourceContext_;
89+
outputs->pad_idx = 0;
90+
outputs->next = nullptr;
91+
inputs->name = av_strdup("out");
92+
inputs->filter_ctx = sinkContext_;
93+
inputs->pad_idx = 0;
94+
inputs->next = nullptr;
95+
96+
std::stringstream description;
97+
description << "scale=" << frameContext.expectedWidth << ":"
98+
<< frameContext.expectedHeight;
99+
description << ":sws_flags=bilinear";
100+
101+
AVFilterInOut* outputsTmp = outputs.release();
102+
AVFilterInOut* inputsTmp = inputs.release();
103+
status = avfilter_graph_parse_ptr(
104+
filterGraph_.get(),
105+
description.str().c_str(),
106+
&inputsTmp,
107+
&outputsTmp,
108+
nullptr);
109+
outputs.reset(outputsTmp);
110+
inputs.reset(inputsTmp);
111+
TORCH_CHECK(
112+
status >= 0,
113+
"Failed to parse filter description: ",
114+
getFFMPEGErrorStringFromErrorCode(status));
115+
116+
status = avfilter_graph_config(filterGraph_.get(), nullptr);
117+
TORCH_CHECK(
118+
status >= 0,
119+
"Failed to configure filter graph: ",
120+
getFFMPEGErrorStringFromErrorCode(status));
121+
}
122+
123+
UniqueAVFrame FilterGraph::convert(const UniqueAVFrame& avFrame) {
124+
int status = av_buffersrc_write_frame(sourceContext_, avFrame.get());
125+
TORCH_CHECK(
126+
status >= AVSUCCESS, "Failed to add frame to buffer source context");
127+
128+
UniqueAVFrame filteredAVFrame(av_frame_alloc());
129+
status = av_buffersink_get_frame(sinkContext_, filteredAVFrame.get());
130+
TORCH_CHECK(
131+
status >= AVSUCCESS, "Failed to fet frame from buffer sink context");
132+
TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
133+
134+
return filteredAVFrame;
135+
}
136+
137+
} // namespace facebook::torchcodec

src/torchcodec/_core/FilterGraph.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#include "src/torchcodec/_core/FFMPEGCommon.h"
10+
#include "src/torchcodec/_core/StreamOptions.h"
11+
12+
namespace facebook::torchcodec {
13+
14+
struct DecodedFrameContext {
15+
int decodedWidth;
16+
int decodedHeight;
17+
AVPixelFormat decodedFormat;
18+
AVRational decodedAspectRatio;
19+
int expectedWidth;
20+
int expectedHeight;
21+
22+
bool operator==(const DecodedFrameContext&);
23+
bool operator!=(const DecodedFrameContext&);
24+
};
25+
26+
class FilterGraph {
27+
public:
28+
FilterGraph(
29+
const DecodedFrameContext& frameContext,
30+
const VideoStreamOptions& videoStreamOptions,
31+
const AVRational& timeBase);
32+
33+
UniqueAVFrame convert(const UniqueAVFrame& avFrame);
34+
35+
private:
36+
UniqueAVFilterGraph filterGraph_;
37+
AVFilterContext* sourceContext_ = nullptr;
38+
AVFilterContext* sinkContext_ = nullptr;
39+
};
40+
41+
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)