@@ -13,22 +13,26 @@ extern "C" {
13
13
14
14
namespace facebook ::torchcodec {
15
15
16
- bool DecodedFrameContext::operator ==(const DecodedFrameContext& other) {
17
- return decodedWidth == other.decodedWidth &&
18
- decodedHeight == other.decodedHeight &&
19
- decodedFormat == other.decodedFormat &&
20
- expectedWidth == other.expectedWidth &&
21
- expectedHeight == other.expectedHeight ;
16
+ bool operator ==(const AVRational& lhs, const AVRational& rhs) {
17
+ return lhs.num == rhs.num && lhs.den == rhs.den ;
22
18
}
23
19
24
- bool DecodedFrameContext::operator !=(const DecodedFrameContext& other) {
20
+ bool FiltersContext::operator ==(const FiltersContext& other) {
21
+ return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
22
+ inputFormat == other.inputFormat && outputWidth == other.outputWidth &&
23
+ outputHeight == other.outputHeight &&
24
+ outputFormat == other.outputFormat && filters == other.filters &&
25
+ timeBase == other.timeBase &&
26
+ hwFramesCtx.get () == other.hwFramesCtx .get ();
27
+ }
28
+
29
+ bool FiltersContext::operator !=(const FiltersContext& other) {
25
30
return !(*this == other);
26
31
}
27
32
28
33
FilterGraph::FilterGraph (
29
- const DecodedFrameContext& frameContext,
30
- const VideoStreamOptions& videoStreamOptions,
31
- const AVRational& timeBase) {
34
+ const FiltersContext& filtersContext,
35
+ const VideoStreamOptions& videoStreamOptions) {
32
36
filterGraph_.reset (avfilter_graph_alloc ());
33
37
TORCH_CHECK (filterGraph_.get () != nullptr );
34
38
@@ -39,26 +43,40 @@ FilterGraph::FilterGraph(
39
43
const AVFilter* buffersrc = avfilter_get_by_name (" buffer" );
40
44
const AVFilter* buffersink = avfilter_get_by_name (" buffersink" );
41
45
42
- std::stringstream filterArgs;
43
- filterArgs << " video_size=" << frameContext.decodedWidth << " x"
44
- << frameContext.decodedHeight ;
45
- filterArgs << " :pix_fmt=" << frameContext.decodedFormat ;
46
- filterArgs << " :time_base=" << timeBase.num << " /" << timeBase.den ;
47
- filterArgs << " :pixel_aspect=" << frameContext.decodedAspectRatio .num << " /"
48
- << frameContext.decodedAspectRatio .den ;
49
-
50
- int status = avfilter_graph_create_filter (
51
- &sourceContext_,
52
- buffersrc,
53
- " in" ,
54
- filterArgs.str ().c_str (),
55
- nullptr ,
56
- filterGraph_.get ());
46
+ auto deleter = [](AVBufferSrcParameters* p) {
47
+ if (p) {
48
+ av_freep (&p);
49
+ }
50
+ };
51
+ std::unique_ptr<AVBufferSrcParameters, decltype (deleter)> srcParams (
52
+ nullptr , deleter);
53
+
54
+ srcParams.reset (av_buffersrc_parameters_alloc ());
55
+ TORCH_CHECK (srcParams, " Failed to allocate buffersrc params" );
56
+
57
+ srcParams->format = filtersContext.inputFormat ;
58
+ srcParams->width = filtersContext.inputWidth ;
59
+ srcParams->height = filtersContext.inputHeight ;
60
+ srcParams->sample_aspect_ratio = filtersContext.inputAspectRatio ;
61
+ srcParams->time_base = filtersContext.timeBase ;
62
+ if (filtersContext.hwFramesCtx ) {
63
+ srcParams->hw_frames_ctx = av_buffer_ref (filtersContext.hwFramesCtx .get ());
64
+ }
65
+
66
+ sourceContext_ =
67
+ avfilter_graph_alloc_filter (filterGraph_.get (), buffersrc, " in" );
68
+ TORCH_CHECK (sourceContext_, " Failed to allocate filter graph" );
69
+
70
+ int status = av_buffersrc_parameters_set (sourceContext_, srcParams.get ());
57
71
TORCH_CHECK (
58
72
status >= 0 ,
59
73
" Failed to create filter graph: " ,
60
- filterArgs.str (),
61
- " : " ,
74
+ getFFMPEGErrorStringFromErrorCode (status));
75
+
76
+ status = avfilter_init_str (sourceContext_, nullptr );
77
+ TORCH_CHECK (
78
+ status >= 0 ,
79
+ " Failed to create filter graph : " ,
62
80
getFFMPEGErrorStringFromErrorCode (status));
63
81
64
82
status = avfilter_graph_create_filter (
@@ -68,7 +86,8 @@ FilterGraph::FilterGraph(
68
86
" Failed to create filter graph: " ,
69
87
getFFMPEGErrorStringFromErrorCode (status));
70
88
71
- enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
89
+ enum AVPixelFormat pix_fmts[] = {
90
+ filtersContext.outputFormat , AV_PIX_FMT_NONE};
72
91
73
92
status = av_opt_set_int_list (
74
93
sinkContext_,
@@ -93,16 +112,11 @@ FilterGraph::FilterGraph(
93
112
inputs->pad_idx = 0 ;
94
113
inputs->next = nullptr ;
95
114
96
- std::stringstream description;
97
- description << " scale=" << frameContext.expectedWidth << " :"
98
- << frameContext.expectedHeight ;
99
- description << " :sws_flags=bilinear" ;
100
-
101
115
AVFilterInOut* outputsTmp = outputs.release ();
102
116
AVFilterInOut* inputsTmp = inputs.release ();
103
117
status = avfilter_graph_parse_ptr (
104
118
filterGraph_.get (),
105
- description. str () .c_str (),
119
+ filtersContext. filters .c_str (),
106
120
&inputsTmp,
107
121
&outputsTmp,
108
122
nullptr );
@@ -128,8 +142,7 @@ UniqueAVFrame FilterGraph::convert(const UniqueAVFrame& avFrame) {
128
142
UniqueAVFrame filteredAVFrame (av_frame_alloc ());
129
143
status = av_buffersink_get_frame (sinkContext_, filteredAVFrame.get ());
130
144
TORCH_CHECK (
131
- status >= AVSUCCESS, " Failed to fet frame from buffer sink context" );
132
- TORCH_CHECK_EQ (filteredAVFrame->format , AV_PIX_FMT_RGB24);
145
+ status >= AVSUCCESS, " Failed to get frame from buffer sink context" );
133
146
134
147
return filteredAVFrame;
135
148
}
0 commit comments