6
6
7
7
#include " src/torchcodec/_core/CpuDeviceInterface.h"
8
8
9
- extern " C" {
10
- #include < libavfilter/buffersink.h>
11
- #include < libavfilter/buffersrc.h>
12
- }
13
-
14
9
namespace facebook ::torchcodec {
15
10
namespace {
16
11
@@ -20,20 +15,6 @@ static bool g_cpu = registerDeviceInterface(
20
15
21
16
} // namespace
22
17
23
- bool CpuDeviceInterface::DecodedFrameContext::operator ==(
24
- const CpuDeviceInterface::DecodedFrameContext& other) {
25
- return decodedWidth == other.decodedWidth &&
26
- decodedHeight == other.decodedHeight &&
27
- decodedFormat == other.decodedFormat &&
28
- expectedWidth == other.expectedWidth &&
29
- expectedHeight == other.expectedHeight ;
30
- }
31
-
32
- bool CpuDeviceInterface::DecodedFrameContext::operator !=(
33
- const CpuDeviceInterface::DecodedFrameContext& other) {
34
- return !(*this == other);
35
- }
36
-
37
18
CpuDeviceInterface::CpuDeviceInterface (const torch::Device& device)
38
19
: DeviceInterface(device) {
39
20
TORCH_CHECK (g_cpu, " CpuDeviceInterface was not registered!" );
@@ -132,8 +113,9 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
132
113
133
114
frameOutput.data = outputTensor;
134
115
} else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
135
- if (!filterGraphContext_.filterGraph || prevFrameContext_ != frameContext) {
136
- createFilterGraph (frameContext, videoStreamOptions, timeBase);
116
+ if (!filterGraphContext_ || prevFrameContext_ != frameContext) {
117
+ filterGraphContext_ = std::make_unique<FilterGraph>(
118
+ frameContext, videoStreamOptions, timeBase);
137
119
prevFrameContext_ = frameContext;
138
120
}
139
121
outputTensor = convertAVFrameToTensorUsingFilterGraph (avFrame);
@@ -187,14 +169,8 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
187
169
188
170
torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph (
189
171
const UniqueAVFrame& avFrame) {
190
- int status = av_buffersrc_write_frame (
191
- filterGraphContext_.sourceContext , avFrame.get ());
192
- TORCH_CHECK (
193
- status >= AVSUCCESS, " Failed to add frame to buffer source context" );
172
+ UniqueAVFrame filteredAVFrame = filterGraphContext_->convert (avFrame);
194
173
195
- UniqueAVFrame filteredAVFrame (av_frame_alloc ());
196
- status = av_buffersink_get_frame (
197
- filterGraphContext_.sinkContext , filteredAVFrame.get ());
198
174
TORCH_CHECK_EQ (filteredAVFrame->format , AV_PIX_FMT_RGB24);
199
175
200
176
auto frameDims = getHeightAndWidthFromResizedAVFrame (*filteredAVFrame.get ());
@@ -210,108 +186,6 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
210
186
filteredAVFramePtr->data [0 ], shape, strides, deleter, {torch::kUInt8 });
211
187
}
212
188
213
- void CpuDeviceInterface::createFilterGraph (
214
- const DecodedFrameContext& frameContext,
215
- const VideoStreamOptions& videoStreamOptions,
216
- const AVRational& timeBase) {
217
- filterGraphContext_.filterGraph .reset (avfilter_graph_alloc ());
218
- TORCH_CHECK (filterGraphContext_.filterGraph .get () != nullptr );
219
-
220
- if (videoStreamOptions.ffmpegThreadCount .has_value ()) {
221
- filterGraphContext_.filterGraph ->nb_threads =
222
- videoStreamOptions.ffmpegThreadCount .value ();
223
- }
224
-
225
- const AVFilter* buffersrc = avfilter_get_by_name (" buffer" );
226
- const AVFilter* buffersink = avfilter_get_by_name (" buffersink" );
227
-
228
- std::stringstream filterArgs;
229
- filterArgs << " video_size=" << frameContext.decodedWidth << " x"
230
- << frameContext.decodedHeight ;
231
- filterArgs << " :pix_fmt=" << frameContext.decodedFormat ;
232
- filterArgs << " :time_base=" << timeBase.num << " /" << timeBase.den ;
233
- filterArgs << " :pixel_aspect=" << frameContext.decodedAspectRatio .num << " /"
234
- << frameContext.decodedAspectRatio .den ;
235
-
236
- int status = avfilter_graph_create_filter (
237
- &filterGraphContext_.sourceContext ,
238
- buffersrc,
239
- " in" ,
240
- filterArgs.str ().c_str (),
241
- nullptr ,
242
- filterGraphContext_.filterGraph .get ());
243
- TORCH_CHECK (
244
- status >= 0 ,
245
- " Failed to create filter graph: " ,
246
- filterArgs.str (),
247
- " : " ,
248
- getFFMPEGErrorStringFromErrorCode (status));
249
-
250
- status = avfilter_graph_create_filter (
251
- &filterGraphContext_.sinkContext ,
252
- buffersink,
253
- " out" ,
254
- nullptr ,
255
- nullptr ,
256
- filterGraphContext_.filterGraph .get ());
257
- TORCH_CHECK (
258
- status >= 0 ,
259
- " Failed to create filter graph: " ,
260
- getFFMPEGErrorStringFromErrorCode (status));
261
-
262
- enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
263
-
264
- status = av_opt_set_int_list (
265
- filterGraphContext_.sinkContext ,
266
- " pix_fmts" ,
267
- pix_fmts,
268
- AV_PIX_FMT_NONE,
269
- AV_OPT_SEARCH_CHILDREN);
270
- TORCH_CHECK (
271
- status >= 0 ,
272
- " Failed to set output pixel formats: " ,
273
- getFFMPEGErrorStringFromErrorCode (status));
274
-
275
- UniqueAVFilterInOut outputs (avfilter_inout_alloc ());
276
- UniqueAVFilterInOut inputs (avfilter_inout_alloc ());
277
-
278
- outputs->name = av_strdup (" in" );
279
- outputs->filter_ctx = filterGraphContext_.sourceContext ;
280
- outputs->pad_idx = 0 ;
281
- outputs->next = nullptr ;
282
- inputs->name = av_strdup (" out" );
283
- inputs->filter_ctx = filterGraphContext_.sinkContext ;
284
- inputs->pad_idx = 0 ;
285
- inputs->next = nullptr ;
286
-
287
- std::stringstream description;
288
- description << " scale=" << frameContext.expectedWidth << " :"
289
- << frameContext.expectedHeight ;
290
- description << " :sws_flags=bilinear" ;
291
-
292
- AVFilterInOut* outputsTmp = outputs.release ();
293
- AVFilterInOut* inputsTmp = inputs.release ();
294
- status = avfilter_graph_parse_ptr (
295
- filterGraphContext_.filterGraph .get (),
296
- description.str ().c_str (),
297
- &inputsTmp,
298
- &outputsTmp,
299
- nullptr );
300
- outputs.reset (outputsTmp);
301
- inputs.reset (inputsTmp);
302
- TORCH_CHECK (
303
- status >= 0 ,
304
- " Failed to parse filter description: " ,
305
- getFFMPEGErrorStringFromErrorCode (status));
306
-
307
- status =
308
- avfilter_graph_config (filterGraphContext_.filterGraph .get (), nullptr );
309
- TORCH_CHECK (
310
- status >= 0 ,
311
- " Failed to configure filter graph: " ,
312
- getFFMPEGErrorStringFromErrorCode (status));
313
- }
314
-
315
189
void CpuDeviceInterface::createSwsContext (
316
190
const DecodedFrameContext& frameContext,
317
191
const enum AVColorSpace colorspace) {
0 commit comments