-
Notifications
You must be signed in to change notification settings - Fork 647
Adding contrast-limited adaptive histogram equalization (CLAHE) to DALI image operators #6069
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
tonyreina
wants to merge
43
commits into
NVIDIA:main
Choose a base branch
from
tonyreina:main
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+3,455
−56
Open
Changes from all commits
Commits
Show all changes
43 commits
Select commit
Hold shift + click to select a range
6ae9c22
Add CLAHE (Contrast Limited Adaptive Histogram Equalization) operator…
tonyreina 0dff114
Update copyright year to 2025 in CLAHE files and improve code formatt…
tonyreina 5e3c8d7
Add CLAHE subdirectory and improve code formatting in test files
tonyreina 840c30a
Refactor error handling in CLAHE operators to use exceptions instead …
tonyreina 939e9ba
Refactor LAB conversion constants for consistency and clarity in CLAH…
tonyreina 205beeb
Enhance CLAHE testing with synthetic image generation and OpenCV comp…
tonyreina ab7f607
Refactor CLAHE function calls for improved readability and consistency
tonyreina 1a8374b
Optimize index calculation in histogram kernels for improved performance
tonyreina 20a35d4
Enhance CLAHE implementation and tests in DALI
tonyreina 7eeceef
Refactor CLAHE tests to use defined thresholds for MSE and MAE, and s…
tonyreina e8de829
Update CLAHE functions for OpenCV compatibility and enhance test conf…
tonyreina 8810da1
Add CLAHE tests for CPU and variable batch size support
tonyreina 64cc55a
Refactor CLAHE GPU tests to support variable batch sizes and enhance …
tonyreina 9fc49e0
Refactor CLAHE tests to use a global tolerance constant for CPU vs GP…
tonyreina 2352f74
moving from floats to exact hex values in OpenCV repo
tonyreina a4b9142
more defines for opencv constants as hex
tonyreina 5dcd23d
more conversion to defines for OpenCV constants. makes equations easi…
tonyreina 2d48e20
Refactor CLAHE constants for improved readability and accuracy in LAB…
tonyreina 89908e4
Add D65 white point constants and refactor LAB conversion to use them…
tonyreina 87a2ace
Refactor LAB conversion to use defined RGB constants and improve comp…
tonyreina e30e9ed
Refactor constants and calculations for improved readability and cons…
tonyreina 0fb437c
Refactor CLAHE test output message for improved readability
tonyreina 5004b3b
clang formatter
tonyreina 1775186
Refactor LAB conversion thresholds and scaling to align with OpenCV s…
tonyreina f0f48ea
Refactor LAB conversion constants and comments for improved clarity a…
tonyreina e5c359e
Refactor LAB conversion macros to use CV_HEX_CONST_F for improved typ…
tonyreina ef5d7fc
Implement histogram clipping, redistribution, and CDF calculation hel…
tonyreina 8e5dab8
updating jupyter notebook tutorial
tonyreina 6acc4fd
Merge branch 'NVIDIA:main' into main
tonyreina bd4dc23
updating tests
tonyreina f932dfe
handling unsigned versus signed int
tonyreina 2f1c4f3
adding clahe tests to cpu and variable batch size
tonyreina f361e19
adding eager coverage tet for clahe
tonyreina 5942e15
Merge remote-tracking branch 'upstream/main'
tonyreina 5fd29f2
fix thread for eager execution test
tonyreina dfcde76
Merge upstream NVIDIA/DALI main into fork
tonyreina 0d777f0
update based on greptile comments
tonyreina 051e307
update test
tonyreina 8b77fb2
adding skip for checkpoint
tonyreina b06754a
Optimize CLAHE GPU implementation - Phase 1
tonyreina 4ba27d8
Merge remote-tracking branch 'upstream/main'
tonyreina 53953d8
adding warning about luma_only flag for color images
tonyreina 5dee7a5
Add critical runtime warning for RGB channel order requirement
tonyreina File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,18 @@ | ||
| # Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| # Get all the source files and dump test files | ||
| collect_headers(DALI_INST_HDRS PARENT_SCOPE) | ||
| collect_sources(DALI_OPERATOR_SRCS PARENT_SCOPE) | ||
| collect_test_sources(DALI_OPERATOR_TEST_SRCS PARENT_SCOPE) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,158 @@ | ||
| // Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #include <opencv2/imgproc.hpp> | ||
| #include <opencv2/opencv.hpp> | ||
|
|
||
| #include "dali/core/error_handling.h" | ||
| #include "dali/pipeline/data/views.h" | ||
| #include "dali/pipeline/operator/operator.h" | ||
| #include "dali/pipeline/workspace/workspace.h" | ||
| #include "dali/util/ocv.h" | ||
|
|
||
| namespace dali { | ||
|
|
||
| // ----------------------------------------------------------------------------- | ||
| // CPU CLAHE Operator using OpenCV | ||
| // ----------------------------------------------------------------------------- | ||
| class ClaheCPU : public Operator<CPUBackend> { | ||
| public: | ||
| explicit ClaheCPU(const OpSpec &spec) | ||
| : Operator<CPUBackend>(spec), | ||
| tiles_x_(spec.GetArgument<int>("tiles_x")), | ||
| tiles_y_(spec.GetArgument<int>("tiles_y")), | ||
| clip_limit_(spec.GetArgument<float>("clip_limit")), | ||
| luma_only_(spec.GetArgument<bool>("luma_only")) { | ||
| // Create OpenCV CLAHE object with specified parameters | ||
| clahe_ = cv::createCLAHE(clip_limit_, cv::Size(tiles_x_, tiles_y_)); | ||
| } | ||
|
|
||
| bool SetupImpl(std::vector<OutputDesc> &outputs, const Workspace &ws) override { | ||
| const auto &in = ws.Input<CPUBackend>(0); | ||
|
|
||
| if (in.type() != DALI_UINT8) { | ||
| throw std::invalid_argument("ClaheCPU currently supports only uint8 input."); | ||
| } | ||
|
|
||
| outputs.resize(1); | ||
| outputs[0].type = in.type(); | ||
| outputs[0].shape = in.shape(); // same layout/shape as input | ||
| return true; | ||
| } | ||
|
|
||
| void RunImpl(Workspace &ws) override { | ||
| const auto &input = ws.Input<CPUBackend>(0); | ||
| auto &output = ws.Output<CPUBackend>(0); | ||
| auto in_view = view<const uint8_t>(input); | ||
| auto out_view = view<uint8_t>(output); | ||
|
|
||
| int ndim = in_view.shape.sample_dim(); | ||
| if (ndim != 2 && ndim != 3) { | ||
| throw std::invalid_argument("ClaheCPU expects HW (grayscale) or HWC (color) input layout."); | ||
| } | ||
|
|
||
| // Warn user about RGB channel order requirement for RGB images | ||
| static bool warned_rgb_order = false; | ||
| if (luma_only_ && !warned_rgb_order && ndim == 3) { | ||
| // Check if we have any RGB samples (3 channels) | ||
| bool has_rgb = false; | ||
| for (int i = 0; i < in_view.num_samples(); i++) { | ||
| if (in_view[i].shape.size() == 3 && in_view[i].shape[2] == 3) { | ||
| has_rgb = true; | ||
| break; | ||
| } | ||
| } | ||
| if (has_rgb) { | ||
| DALI_WARN("CRITICAL: CLAHE expects RGB channel order (Red, Green, Blue). " | ||
| "If your images are in BGR order (common with OpenCV cv2.imread), " | ||
| "the luminance calculation will be INCORRECT. " | ||
| "Convert BGR to RGB using fn.reinterpret or similar operators before CLAHE."); | ||
| warned_rgb_order = true; | ||
| } | ||
| } | ||
|
|
||
| auto &tp = ws.GetThreadPool(); | ||
| int num_samples = in_view.num_samples(); | ||
|
|
||
| for (int sample_idx = 0; sample_idx < num_samples; sample_idx++) { | ||
| tp.AddWork([this, &in_view, &out_view, sample_idx](int) { | ||
| // Create a thread-local CLAHE object to avoid race conditions | ||
| // OpenCV CLAHE objects are not thread-safe | ||
| auto local_clahe = cv::createCLAHE(clip_limit_, cv::Size(tiles_x_, tiles_y_)); | ||
| ProcessSample(out_view[sample_idx], in_view[sample_idx], local_clahe); | ||
| }, in_view[sample_idx].shape.num_elements()); | ||
tonyreina marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
| tp.RunAll(); | ||
| } | ||
|
|
||
| private: | ||
| template <int ndim> | ||
| void ProcessSample(TensorView<StorageCPU, uint8_t, ndim> out_sample, | ||
| TensorView<StorageCPU, const uint8_t, ndim> in_sample, | ||
| cv::Ptr<cv::CLAHE> clahe) { | ||
| auto &shape = in_sample.shape; | ||
| int H = shape[0]; | ||
| int W = shape[1]; | ||
| int C = (shape.size() >= 3) ? shape[2] : 1; | ||
tonyreina marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| if (C != 1 && C != 3) { | ||
| throw std::invalid_argument("ClaheCPU supports 1 or 3 channels."); | ||
| } | ||
|
|
||
| if (C == 1) { | ||
| // Grayscale processing | ||
| cv::Mat src(H, W, CV_8UC1, const_cast<uint8_t *>(in_sample.data)); | ||
| cv::Mat dst(H, W, CV_8UC1, out_sample.data); | ||
| clahe->apply(src, dst); | ||
tonyreina marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } else { | ||
| // RGB processing | ||
| cv::Mat src(H, W, CV_8UC3, const_cast<uint8_t *>(in_sample.data)); | ||
| cv::Mat dst(H, W, CV_8UC3, out_sample.data); | ||
tonyreina marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| if (luma_only_) { | ||
| // Apply CLAHE to luminance channel only (preserves color relationships) | ||
| cv::Mat lab, lab_dst; | ||
| cv::cvtColor(src, lab, cv::COLOR_RGB2Lab); | ||
|
|
||
| std::vector<cv::Mat> lab_channels; | ||
| cv::split(lab, lab_channels); | ||
|
|
||
| // Apply CLAHE to L (luminance) channel | ||
| clahe->apply(lab_channels[0], lab_channels[0]); | ||
|
|
||
| cv::merge(lab_channels, lab_dst); | ||
| cv::cvtColor(lab_dst, dst, cv::COLOR_Lab2RGB); | ||
| } else { | ||
| // Apply CLAHE to each channel independently | ||
| std::vector<cv::Mat> channels; | ||
| cv::split(src, channels); | ||
|
|
||
| for (auto &channel : channels) { | ||
| clahe->apply(channel, channel); | ||
| } | ||
|
|
||
| cv::merge(channels, dst); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| int tiles_x_, tiles_y_; | ||
| float clip_limit_; | ||
| bool luma_only_; | ||
| cv::Ptr<cv::CLAHE> clahe_; | ||
| }; | ||
|
|
||
| DALI_REGISTER_OPERATOR(Clahe, ClaheCPU, CPU); | ||
|
|
||
| } // namespace dali | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.