@@ -13,6 +13,35 @@ static bool g_cpu = registerDeviceInterface(
     torch::kCPU,
     [](const torch::Device& device) { return new CpuDeviceInterface(device); });
 
+ColorConversionLibrary getColorConversionLibrary(
+    const VideoStreamOptions& videoStreamOptions,
+    int width) {
+  // By default, we want to use swscale for color conversion because it is
+  // faster. However, it has width requirements, so we may need to fall back
+  // to filtergraph. We also need to respect what was requested from the
+  // options; we respect the options unconditionally, so it's possible for
+  // swscale's width requirements to be violated. We don't expose the ability to
+  // choose color conversion library publicly; we only use this ability
+  // internally.
+
+  // swscale requires widths to be multiples of 32:
+  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
+  // so we fall back to filtergraph if the width is not a multiple of 32.
+  auto defaultLibrary = (width % 32 == 0)
+      ? ColorConversionLibrary::SWSCALE
+      : ColorConversionLibrary::FILTERGRAPH;
+
+  ColorConversionLibrary colorConversionLibrary =
+      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
+
+  TORCH_CHECK(
+      colorConversionLibrary == ColorConversionLibrary::SWSCALE ||
+          colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH,
+      "Invalid color conversion library: ",
+      static_cast<int>(colorConversionLibrary));
+  return colorConversionLibrary;
+}
+
 } // namespace
 
 CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
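
A standalone sketch of the selection rule the new getColorConversionLibrary() helper implements, assuming only the enum and an optional user override (pickLibrary is an illustrative name, not part of the codebase):

    #include <optional>

    enum class ColorConversionLibrary { SWSCALE, FILTERGRAPH };

    // An explicit request always wins, even if it violates swscale's width
    // constraint; otherwise prefer swscale only when the width is a multiple
    // of 32, matching the alignment requirement cited above.
    ColorConversionLibrary pickLibrary(
        std::optional<ColorConversionLibrary> requested,
        int width) {
      auto fallback = (width % 32 == 0) ? ColorConversionLibrary::SWSCALE
                                        : ColorConversionLibrary::FILTERGRAPH;
      return requested.value_or(fallback);
    }

For example, pickLibrary(std::nullopt, 1280) yields SWSCALE because 1280 is a multiple of 32, while pickLibrary(std::nullopt, 1278) falls back to FILTERGRAPH.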
@@ -22,6 +51,52 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
       device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
 }
 
+std::unique_ptr<FiltersContext> CpuDeviceInterface::initializeFiltersContextInternal(
+    const VideoStreamOptions& videoStreamOptions,
+    const UniqueAVFrame& avFrame,
+    const AVRational& timeBase) {
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(avFrame->format);
+  auto frameDims =
+      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
+  int expectedOutputHeight = frameDims.height;
+  int expectedOutputWidth = frameDims.width;
+
+  std::unique_ptr<FiltersContext> filtersContext =
+      std::make_unique<FiltersContext>();
+
+  filtersContext->inputWidth = avFrame->width;
+  filtersContext->inputHeight = avFrame->height;
+  filtersContext->inputFormat = frameFormat;
+  filtersContext->inputAspectRatio = avFrame->sample_aspect_ratio;
+  filtersContext->outputWidth = expectedOutputWidth;
+  filtersContext->outputHeight = expectedOutputHeight;
+  filtersContext->outputFormat = AV_PIX_FMT_RGB24;
+  filtersContext->timeBase = timeBase;
+
+  std::stringstream filters;
+  filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
+  filters << ":sws_flags=bilinear";
+
+  filtersContext->filters = filters.str();
+  return filtersContext;
+}
+
+std::unique_ptr<FiltersContext> CpuDeviceInterface::initializeFiltersContext(
+    const VideoStreamOptions& videoStreamOptions,
+    const UniqueAVFrame& avFrame,
+    const AVRational& timeBase) {
+  auto frameDims =
+      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
+  int expectedOutputWidth = frameDims.width;
+
+  if (getColorConversionLibrary(videoStreamOptions, expectedOutputWidth) == ColorConversionLibrary::SWSCALE) {
+    return nullptr;
+  }
+
+  return initializeFiltersContextInternal(videoStreamOptions, avFrame, timeBase);
+}
+
 // Note [preAllocatedOutputTensor with swscale and filtergraph]:
 // Callers may pass a pre-allocated tensor, where the output.data tensor will
 // be stored. This parameter is honored in any case, but it only leads to a
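
Two things are worth noting in this hunk. First, initializeFiltersContext() returns nullptr when swscale is selected, so a filters context is only materialized for the filtergraph path. Second, the filter description is a single FFmpeg scale filter with bilinear interpolation. A minimal standalone sketch of that string construction (the function name is illustrative):

    #include <sstream>
    #include <string>

    // Produces e.g. "scale=1280:720:sws_flags=bilinear" for a 1280x720
    // target: one FFmpeg scale filter, resizing with bilinear interpolation.
    std::string buildScaleFilter(int outWidth, int outHeight) {
      std::stringstream filters;
      filters << "scale=" << outWidth << ":" << outHeight;
      filters << ":sws_flags=bilinear";
      return filters.str();
    }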
@@ -56,56 +131,25 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
   }
 
   torch::Tensor outputTensor;
-  // We need to compare the current frame context with our previous frame
-  // context. If they are different, then we need to re-create our colorspace
-  // conversion objects. We create our colorspace conversion objects late so
-  // that we don't have to depend on the unreliable metadata in the header.
-  // And we sometimes re-create them because it's possible for frame
-  // resolution to change mid-stream. Finally, we want to reuse the colorspace
-  // conversion objects as much as possible for performance reasons.
-  enum AVPixelFormat frameFormat =
-      static_cast<enum AVPixelFormat>(avFrame->format);
-  FiltersContext filtersContext;
-
-  filtersContext.inputWidth = avFrame->width;
-  filtersContext.inputHeight = avFrame->height;
-  filtersContext.inputFormat = frameFormat;
-  filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
-  filtersContext.outputWidth = expectedOutputWidth;
-  filtersContext.outputHeight = expectedOutputHeight;
-  filtersContext.outputFormat = AV_PIX_FMT_RGB24;
-  filtersContext.timeBase = timeBase;
-
-  std::stringstream filters;
-  filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
-  filters << ":sws_flags=bilinear";
-
-  filtersContext.filters = filters.str();
-
-  // By default, we want to use swscale for color conversion because it is
-  // faster. However, it has width requirements, so we may need to fall back
-  // to filtergraph. We also need to respect what was requested from the
-  // options; we respect the options unconditionally, so it's possible for
-  // swscale's width requirements to be violated. We don't expose the ability to
-  // choose color conversion library publicly; we only use this ability
-  // internally.
-
-  // swscale requires widths to be multiples of 32:
-  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
-  // so we fall back to filtergraph if the width is not a multiple of 32.
-  auto defaultLibrary = (expectedOutputWidth % 32 == 0)
-      ? ColorConversionLibrary::SWSCALE
-      : ColorConversionLibrary::FILTERGRAPH;
-
   ColorConversionLibrary colorConversionLibrary =
-      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
+      getColorConversionLibrary(videoStreamOptions, expectedOutputWidth);
 
   if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
     outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
         expectedOutputHeight, expectedOutputWidth, torch::kCPU));
 
+    // We need to compare the current frame context with our previous frame
+    // context. If they are different, then we need to re-create our colorspace
+    // conversion objects. We create our colorspace conversion objects late so
+    // that we don't have to depend on the unreliable metadata in the header.
+    // And we sometimes re-create them because it's possible for frame
+    // resolution to change mid-stream. Finally, we want to reuse the colorspace
+    // conversion objects as much as possible for performance reasons.
+    std::unique_ptr<FiltersContext> filtersContext =
+        initializeFiltersContextInternal(videoStreamOptions, avFrame, timeBase);
+
     if (!swsContext_ || prevFiltersContext_ != filtersContext) {
-      createSwsContext(filtersContext, avFrame->colorspace);
+      createSwsContext(*filtersContext, avFrame->colorspace);
       prevFiltersContext_ = std::move(filtersContext);
     }
     int resultHeight =
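
The swscale branch rebuilds its conversion state only when the frame's properties change. A self-contained sketch of that cache-invalidation pattern, with stand-in types (the real key is FiltersContext, which also carries pixel formats, aspect ratio, time base, and the filter string):

    #include <memory>

    // Stand-in for the properties that determine conversion state.
    struct Context {
      int width = 0;
      int height = 0;
      bool operator==(const Context& other) const {
        return width == other.width && height == other.height;
      }
      bool operator!=(const Context& other) const { return !(*this == other); }
    };

    struct Converter {
      explicit Converter(const Context&) { /* allocate conversion state */ }
    };

    struct ConverterCache {
      std::unique_ptr<Converter> converter;
      Context prev;

      // Created lazily (header metadata is unreliable) and rebuilt when the
      // context changes, e.g. on a mid-stream resolution switch; otherwise
      // the existing converter is reused for performance.
      Converter& get(const Context& current) {
        if (!converter || prev != current) {
          converter = std::make_unique<Converter>(current);
          prev = current;
        }
        return *converter;
      }
    };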
@@ -122,25 +166,16 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
 
     frameOutput.data = outputTensor;
   } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-    if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
-      filterGraphContext_ =
-          std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
-      prevFiltersContext_ = std::move(filtersContext);
-    }
-    outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
+    TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24);
 
-    // Similarly to above, if this check fails it means the frame wasn't
-    // reshaped to its expected dimensions by filtergraph.
-    auto shape = outputTensor.sizes();
-    TORCH_CHECK(
-        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
-        "Expected output tensor of shape ",
-        expectedOutputHeight,
-        "x",
-        expectedOutputWidth,
-        "x3, got ",
-        shape);
+    std::vector<int64_t> shape = {expectedOutputHeight, expectedOutputWidth, 3};
+    std::vector<int64_t> strides = {avFrame->linesize[0], 3, 1};
+    AVFrame* avFramePtr = avFrame.release();
+    auto deleter = [avFramePtr](void*) {
+      UniqueAVFrame avFrameToDelete(avFramePtr);
+    };
+    outputTensor = torch::from_blob(
+        avFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
 
     if (preAllocatedOutputTensor.has_value()) {
       // We have already validated that preAllocatedOutputTensor and
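
The filtergraph branch now wraps the frame's buffer directly in a tensor instead of going through the deleted helper (see the last hunk). A sketch of that zero-copy handoff, assuming an RGB24 frame (wrapRgb24Frame is an illustrative name):

    #include <torch/torch.h>
    extern "C" {
    #include <libavutil/frame.h>
    }

    // Wraps an RGB24 AVFrame's pixel buffer in an HWC uint8 tensor without
    // copying. The deleter owns the frame and frees it only when the
    // tensor's storage is released. Strides use linesize[0] because FFmpeg
    // may pad each row beyond width * 3 bytes.
    torch::Tensor wrapRgb24Frame(AVFrame* frame) {
      auto deleter = [frame](void*) {
        AVFrame* toFree = frame;
        av_frame_free(&toFree);
      };
      return torch::from_blob(
          frame->data[0],
          {frame->height, frame->width, 3}, // HWC shape
          {frame->linesize[0], 3, 1},       // strides in elements (uint8)
          deleter,
          torch::kUInt8);
    }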
@@ -150,11 +185,6 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
     } else {
       frameOutput.data = outputTensor;
     }
-  } else {
-    TORCH_CHECK(
-        false,
-        "Invalid color conversion library: ",
-        static_cast<int>(colorConversionLibrary));
   }
 }
 
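
The defensive else removed in the hunk above became unreachable once getColorConversionLibrary() (first hunk) began TORCH_CHECK-ing its result at the point of production. The general pattern, sketched in standard C++ with illustrative names:

    #include <stdexcept>

    enum class Library { SWSCALE, FILTERGRAPH };

    // Validate once, at the single site that produces the value...
    Library toLibrary(int raw) {
      if (raw != 0 && raw != 1) {
        throw std::invalid_argument("invalid color conversion library");
      }
      return static_cast<Library>(raw);
    }

    // ...so downstream branches can be exhaustive with no defensive else.
    const char* name(Library lib) {
      return lib == Library::SWSCALE ? "swscale" : "filtergraph";
    }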
@@ -176,25 +206,6 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
   return resultHeight;
 }
 
-torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
-    const UniqueAVFrame& avFrame) {
-  UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);
-
-  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
-
-  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
-  int height = frameDims.height;
-  int width = frameDims.width;
-  std::vector<int64_t> shape = {height, width, 3};
-  std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
-  AVFrame* filteredAVFramePtr = filteredAVFrame.release();
-  auto deleter = [filteredAVFramePtr](void*) {
-    UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
-  };
-  return torch::from_blob(
-      filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
-}
-
 void CpuDeviceInterface::createSwsContext(
     const FiltersContext& filtersContext,
     const enum AVColorSpace colorspace) {
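
The deleted helper above and the new filtergraph branch both rely on UniqueAVFrame for deterministic frame cleanup. Its definition is not shown in this diff; a plausible sketch, assuming the usual unique_ptr-with-custom-deleter idiom over av_frame_free:

    extern "C" {
    #include <libavutil/frame.h>
    }
    #include <memory>

    // Assumed shape of UniqueAVFrame: sole ownership of an AVFrame, freed
    // via av_frame_free even on early returns. Calling release(), as the
    // from_blob deleter does, transfers ownership out of the wrapper.
    struct AVFrameDeleter {
      void operator()(AVFrame* frame) const {
        av_frame_free(&frame);
      }
    };
    using UniqueAVFrame = std::unique_ptr<AVFrame, AVFrameDeleter>;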