Commit bcde35e

Implement initializeFiltersContext for CPU device interface
Signed-off-by: Dmitry Rogozhkin <[email protected]>
1 parent: ebcb48d

File tree: 2 files changed (+108, -92 lines)

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 96 additions & 85 deletions
@@ -13,6 +13,35 @@ static bool g_cpu = registerDeviceInterface(
     torch::kCPU,
     [](const torch::Device& device) { return new CpuDeviceInterface(device); });
 
+ColorConversionLibrary getColorConversionLibrary(
+    const VideoStreamOptions& videoStreamOptions,
+    int width) {
+  // By default, we want to use swscale for color conversion because it is
+  // faster. However, it has width requirements, so we may need to fall back
+  // to filtergraph. We also need to respect what was requested from the
+  // options; we respect the options unconditionally, so it's possible for
+  // swscale's width requirements to be violated. We don't expose the ability to
+  // choose color conversion library publicly; we only use this ability
+  // internally.
+
+  // swscale requires widths to be multiples of 32:
+  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
+  // so we fall back to filtergraph if the width is not a multiple of 32.
+  auto defaultLibrary = (width % 32 == 0)
+      ? ColorConversionLibrary::SWSCALE
+      : ColorConversionLibrary::FILTERGRAPH;
+
+  ColorConversionLibrary colorConversionLibrary =
+      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
+
+  TORCH_CHECK(
+      colorConversionLibrary == ColorConversionLibrary::SWSCALE ||
+          colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH,
+      "Invalid color conversion library: ",
+      static_cast<int>(colorConversionLibrary));
+  return colorConversionLibrary;
+}
+
 } // namespace
 
 CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
@@ -22,6 +51,52 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
       device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
 }
 
+std::unique_ptr<FiltersContext> CpuDeviceInterface::initializeFiltersContextInternal(
+    const VideoStreamOptions& videoStreamOptions,
+    const UniqueAVFrame& avFrame,
+    const AVRational& timeBase) {
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(avFrame->format);
+  auto frameDims =
+      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
+  int expectedOutputHeight = frameDims.height;
+  int expectedOutputWidth = frameDims.width;
+
+  std::unique_ptr<FiltersContext> filtersContext =
+      std::make_unique<FiltersContext>();
+
+  filtersContext->inputWidth = avFrame->width;
+  filtersContext->inputHeight = avFrame->height;
+  filtersContext->inputFormat = frameFormat;
+  filtersContext->inputAspectRatio = avFrame->sample_aspect_ratio;
+  filtersContext->outputWidth = expectedOutputWidth;
+  filtersContext->outputHeight = expectedOutputHeight;
+  filtersContext->outputFormat = AV_PIX_FMT_RGB24;
+  filtersContext->timeBase = timeBase;
+
+  std::stringstream filters;
+  filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
+  filters << ":sws_flags=bilinear";
+
+  filtersContext->filters = filters.str();
+  return filtersContext;
+}
+
+std::unique_ptr<FiltersContext> CpuDeviceInterface::initializeFiltersContext(
+    const VideoStreamOptions& videoStreamOptions,
+    const UniqueAVFrame& avFrame,
+    const AVRational& timeBase) {
+  auto frameDims =
+      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
+  int expectedOutputWidth = frameDims.width;
+
+  if (getColorConversionLibrary(videoStreamOptions, expectedOutputWidth) ==
+      ColorConversionLibrary::SWSCALE) {
+    return nullptr;
+  }
+
+  return initializeFiltersContextInternal(videoStreamOptions, avFrame, timeBase);
+}
+
 // Note [preAllocatedOutputTensor with swscale and filtergraph]:
 // Callers may pass a pre-allocated tensor, where the output.data tensor will
 // be stored. This parameter is honored in any case, but it only leads to a
@@ -56,56 +131,25 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
   }
 
   torch::Tensor outputTensor;
-  // We need to compare the current frame context with our previous frame
-  // context. If they are different, then we need to re-create our colorspace
-  // conversion objects. We create our colorspace conversion objects late so
-  // that we don't have to depend on the unreliable metadata in the header.
-  // And we sometimes re-create them because it's possible for frame
-  // resolution to change mid-stream. Finally, we want to reuse the colorspace
-  // conversion objects as much as possible for performance reasons.
-  enum AVPixelFormat frameFormat =
-      static_cast<enum AVPixelFormat>(avFrame->format);
-  FiltersContext filtersContext;
-
-  filtersContext.inputWidth = avFrame->width;
-  filtersContext.inputHeight = avFrame->height;
-  filtersContext.inputFormat = frameFormat;
-  filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
-  filtersContext.outputWidth = expectedOutputWidth;
-  filtersContext.outputHeight = expectedOutputHeight;
-  filtersContext.outputFormat = AV_PIX_FMT_RGB24;
-  filtersContext.timeBase = timeBase;
-
-  std::stringstream filters;
-  filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
-  filters << ":sws_flags=bilinear";
-
-  filtersContext.filters = filters.str();
-
-  // By default, we want to use swscale for color conversion because it is
-  // faster. However, it has width requirements, so we may need to fall back
-  // to filtergraph. We also need to respect what was requested from the
-  // options; we respect the options unconditionally, so it's possible for
-  // swscale's width requirements to be violated. We don't expose the ability to
-  // choose color conversion library publicly; we only use this ability
-  // internally.
-
-  // swscale requires widths to be multiples of 32:
-  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
-  // so we fall back to filtergraph if the width is not a multiple of 32.
-  auto defaultLibrary = (expectedOutputWidth % 32 == 0)
-      ? ColorConversionLibrary::SWSCALE
-      : ColorConversionLibrary::FILTERGRAPH;
-
   ColorConversionLibrary colorConversionLibrary =
-      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
+      getColorConversionLibrary(videoStreamOptions, expectedOutputWidth);
 
   if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
     outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
         expectedOutputHeight, expectedOutputWidth, torch::kCPU));
 
+    // We need to compare the current frame context with our previous frame
+    // context. If they are different, then we need to re-create our colorspace
+    // conversion objects. We create our colorspace conversion objects late so
+    // that we don't have to depend on the unreliable metadata in the header.
+    // And we sometimes re-create them because it's possible for frame
+    // resolution to change mid-stream. Finally, we want to reuse the colorspace
+    // conversion objects as much as possible for performance reasons.
+    std::unique_ptr<FiltersContext> filtersContext =
+        initializeFiltersContextInternal(videoStreamOptions, avFrame, timeBase);
+
     if (!swsContext_ || prevFiltersContext_ != filtersContext) {
-      createSwsContext(filtersContext, avFrame->colorspace);
+      createSwsContext(*filtersContext, avFrame->colorspace);
       prevFiltersContext_ = std::move(filtersContext);
     }
     int resultHeight =
@@ -122,25 +166,16 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
 
     frameOutput.data = outputTensor;
   } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-    if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
-      filterGraphContext_ =
-          std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
-      prevFiltersContext_ = std::move(filtersContext);
-    }
-    outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
+    TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24);
 
-    // Similarly to above, if this check fails it means the frame wasn't
-    // reshaped to its expected dimensions by filtergraph.
-    auto shape = outputTensor.sizes();
-    TORCH_CHECK(
-        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
-        "Expected output tensor of shape ",
-        expectedOutputHeight,
-        "x",
-        expectedOutputWidth,
-        "x3, got ",
-        shape);
+    std::vector<int64_t> shape = {expectedOutputHeight, expectedOutputWidth, 3};
+    std::vector<int64_t> strides = {avFrame->linesize[0], 3, 1};
+    AVFrame* avFramePtr = avFrame.release();
+    auto deleter = [avFramePtr](void*) {
+      UniqueAVFrame avFrameToDelete(avFramePtr);
+    };
+    outputTensor = torch::from_blob(
+        avFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
 
     if (preAllocatedOutputTensor.has_value()) {
       // We have already validated that preAllocatedOutputTensor and
@@ -150,11 +185,6 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
     } else {
       frameOutput.data = outputTensor;
     }
-  } else {
-    TORCH_CHECK(
-        false,
-        "Invalid color conversion library: ",
-        static_cast<int>(colorConversionLibrary));
   }
 }
 
@@ -176,25 +206,6 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
   return resultHeight;
 }
 
-torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
-    const UniqueAVFrame& avFrame) {
-  UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);
-
-  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
-
-  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
-  int height = frameDims.height;
-  int width = frameDims.width;
-  std::vector<int64_t> shape = {height, width, 3};
-  std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
-  AVFrame* filteredAVFramePtr = filteredAVFrame.release();
-  auto deleter = [filteredAVFramePtr](void*) {
-    UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
-  };
-  return torch::from_blob(
-      filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
-}
-
 void CpuDeviceInterface::createSwsContext(
     const FiltersContext& filtersContext,
     const enum AVColorSpace colorspace) {
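
For a concrete feel of the selection rule factored out into getColorConversionLibrary(): widths that are multiples of 32 default to the faster swscale path, any other width falls back to filtergraph, and an explicitly requested library always wins. A minimal standalone sketch of that rule, with a hypothetical Library enum and pickLibrary helper that are not part of this commit:

#include <iostream>
#include <optional>

enum class Library { SWSCALE, FILTERGRAPH };

// Mirrors the width-based default: swscale only when width % 32 == 0,
// and an explicit request overrides the default unconditionally.
Library pickLibrary(std::optional<Library> requested, int width) {
  Library fallback =
      (width % 32 == 0) ? Library::SWSCALE : Library::FILTERGRAPH;
  return requested.value_or(fallback);
}

int main() {
  // 1920 is a multiple of 32, so the default is the faster swscale path.
  std::cout << (pickLibrary(std::nullopt, 1920) == Library::SWSCALE) << "\n";
  // 1918 % 32 == 30, so the default falls back to filtergraph.
  std::cout << (pickLibrary(std::nullopt, 1918) == Library::FILTERGRAPH) << "\n";
  // An explicit request is honored even when swscale's width rule is violated.
  std::cout << (pickLibrary(Library::SWSCALE, 1918) == Library::SWSCALE) << "\n";
}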

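The rewritten FILTERGRAPH branch no longer copies pixels: it wraps the already-converted RGB24 frame with torch::from_blob, releases the AVFrame from its unique_ptr, and hands ownership to a custom deleter so the frame is freed only when the tensor's storage is. The row stride is linesize[0] rather than width * 3 because FFmpeg may pad each row. A self-contained sketch of the same ownership pattern; wrapRgb24Frame is an illustrative helper (not in the commit) that assumes libtorch and libavutil headers are available and that the caller transfers ownership of the frame:

#include <torch/torch.h>
#include <vector>
extern "C" {
#include <libavutil/frame.h>
}

// Wrap a packed-RGB24 AVFrame's pixels in an HWC uint8 tensor without
// copying. The deleter frees the frame once the tensor is destroyed.
torch::Tensor wrapRgb24Frame(AVFrame* frame, int height, int width) {
  std::vector<int64_t> shape = {height, width, 3};
  // linesize[0] is the byte stride between rows; it may exceed width * 3.
  std::vector<int64_t> strides = {frame->linesize[0], 3, 1};
  auto deleter = [frame](void*) {
    AVFrame* toFree = frame; // captured copy is const; av_frame_free needs AVFrame**
    av_frame_free(&toFree);
  };
  return torch::from_blob(
      frame->data[0], shape, strides, deleter, torch::kUInt8);
}
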
src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 12 additions & 7 deletions
@@ -26,6 +26,11 @@ class CpuDeviceInterface : public DeviceInterface {
   void initializeContext(
       [[maybe_unused]] AVCodecContext* codecContext) override {}
 
+  std::unique_ptr<FiltersContext> initializeFiltersContext(
+      const VideoStreamOptions& videoStreamOptions,
+      const UniqueAVFrame& avFrame,
+      const AVRational& timeBase) override;
+
   void convertAVFrameToFrameOutput(
       const VideoStreamOptions& videoStreamOptions,
       const AVRational& timeBase,
@@ -39,21 +44,21 @@ class CpuDeviceInterface : public DeviceInterface {
       const UniqueAVFrame& avFrame,
       torch::Tensor& outputTensor);
 
-  torch::Tensor convertAVFrameToTensorUsingFilterGraph(
-      const UniqueAVFrame& avFrame);
+  std::unique_ptr<FiltersContext> initializeFiltersContextInternal(
+      const VideoStreamOptions& videoStreamOptions,
+      const UniqueAVFrame& avFrame,
+      const AVRational& timeBase);
 
   void createSwsContext(
       const FiltersContext& filtersContext,
       const enum AVColorSpace colorspace);
 
-  // color-conversion fields. Only one of FilterGraphContext and
-  // UniqueSwsContext should be non-null.
-  std::unique_ptr<FilterGraph> filterGraphContext_;
+  // SWS color conversion context
   UniqueSwsContext swsContext_;
 
-  // Used to know whether a new FilterGraphContext or UniqueSwsContext should
+  // Used to know whether a new UniqueSwsContext should
   // be created before decoding a new frame.
-  FiltersContext prevFiltersContext_;
+  std::unique_ptr<FiltersContext> prevFiltersContext_;
 };
 
 } // namespace facebook::torchcodec
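
On the state tracked in this header: prevFiltersContext_ now holds the last frame's FiltersContext behind a unique_ptr, and the .cpp consults it so colorspace-conversion objects are rebuilt only when the frame context actually changes (for example, a mid-stream resolution change). A self-contained sketch of that caching idea, using hypothetical Ctx and Converter types in place of FiltersContext and the sws context; this sketch compares cached and current contexts by value, since a fresh context object is allocated for every frame:

#include <iostream>
#include <memory>

struct Ctx {
  int width = 0;
  int height = 0;
  bool operator!=(const Ctx& other) const {
    return width != other.width || height != other.height;
  }
};

// Stand-in for an expensive conversion object (e.g. an sws context).
struct Converter {
  explicit Converter(const Ctx& c) {
    std::cout << "rebuild for " << c.width << "x" << c.height << "\n";
  }
};

int main() {
  std::unique_ptr<Ctx> prev;
  std::unique_ptr<Converter> conv;
  for (int w : {1920, 1920, 1280}) {
    auto cur = std::make_unique<Ctx>(Ctx{w, 1080});
    // Compare the pointees, not the pointers: each frame gets a freshly
    // allocated context, so pointer comparison would always rebuild.
    if (!conv || !prev || *prev != *cur) {
      conv = std::make_unique<Converter>(*cur);
      prev = std::move(cur);
    }
  }
  // Prints "rebuild" twice: once for 1920x1080, once for 1280x1080.
}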
