diff --git a/DALI_EXTRA_VERSION b/DALI_EXTRA_VERSION
index 0f5b7e170a..48be880891 100644
--- a/DALI_EXTRA_VERSION
+++ b/DALI_EXTRA_VERSION
@@ -1 +1 @@
-1ffbeaf1d085bb00f124038503508b3cb68e1a05
+21dd6148f0cf4557531d54b810379c681c898e91
diff --git a/dali/operators/video/color_space.cu b/dali/operators/video/color_space.cu
index b36f039b8f..1a289d0354 100644
--- a/dali/operators/video/color_space.cu
+++ b/dali/operators/video/color_space.cu
@@ -51,18 +51,20 @@ __global__ static void VideoColorSpaceConversionKernel(
     #pragma unroll
     for (int j = 0; j < 2; j++) {
       float cx = halfx + j * 0.5f + 0.25f;
-      u8vec3 yuv_val;
+      vec3 yuv_val;
       yuv_val[0] = Y.at(ivec2{x + j, y + i}, 0, kernels::BorderClamp());
       UV(&yuv_val[1], vec2(cx, cy), kernels::BorderClamp());
-      u8vec3 out_val;
+      yuv_val *= 1.0f / 255.0f;
+
+      vec3 out_val;
       switch (conversion_type) {
         case VIDEO_COLOR_SPACE_CONVERSION_TYPE_YUV_TO_RGB_FULL_RANGE:
-          out_val = dali::kernels::color::jpeg::ycbcr_to_rgb<uint8_t>(yuv_val);
+          out_val = dali::kernels::color::jpeg::ycbcr_to_rgb<float>(yuv_val);
           break;
         case VIDEO_COLOR_SPACE_CONVERSION_TYPE_YUV_TO_RGB:
-          out_val = dali::kernels::color::itu_r_bt_601::ycbcr_to_rgb<uint8_t>(yuv_val);
+          out_val = dali::kernels::color::itu_r_bt_601::ycbcr_to_rgb<float>(yuv_val);
           break;
         case VIDEO_COLOR_SPACE_CONVERSION_TYPE_YUV_UPSAMPLE:
           out_val = yuv_val;
@@ -71,10 +73,11 @@ __global__ static void VideoColorSpaceConversionKernel(
           assert(false);
       }
       if (normalized_range) {
-        output({x + j, y + i, 0}) = ConvertNorm<Out>(out_val.x);
-        output({x + j, y + i, 1}) = ConvertNorm<Out>(out_val.y);
-        output({x + j, y + i, 2}) = ConvertNorm<Out>(out_val.z);
+        output({x + j, y + i, 0}) = ConvertSatNorm<Out>(out_val.x);
+        output({x + j, y + i, 1}) = ConvertSatNorm<Out>(out_val.y);
+        output({x + j, y + i, 2}) = ConvertSatNorm<Out>(out_val.z);
       } else {
+        out_val *= 255.0f;
        output({x + j, y + i, 0}) = ConvertSat<Out>(out_val.x);
        output({x + j, y + i, 1}) = ConvertSat<Out>(out_val.y);
        output({x + j, y + i, 2}) = ConvertSat<Out>(out_val.z);
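The color_space.cu change above moves the kernel's intermediate math from 8-bit (u8vec3) to normalized float (vec3): the YCbCr sample is scaled to [0, 1], converted, and only scaled back to [0, 255] with saturation for non-normalized outputs. Below is a minimal NumPy sketch of that math, using the BT.601 limited-range and full-range (JPEG) constants that also appear in the deleted imgproc.cu further down; it is an illustration, not DALI's kernel code.

import numpy as np

def yuv_to_rgb_norm(yuv, full_range):
    """YCbCr -> RGB on values normalized to [0, 1]."""
    y = yuv[..., 0]
    cb = yuv[..., 1] - 128.0 / 255.0
    cr = yuv[..., 2] - 128.0 / 255.0
    if full_range:  # JPEG / full-range
        r = y + 1.402 * cr
        g = y - 0.344136285 * cb - 0.714136285 * cr
        b = y + 1.772 * cb
    else:           # limited-range BT.601
        y = 1.164383 * (y - 16.0 / 255.0)
        r = y + 1.596027 * cr
        g = y - 0.391762 * cb - 0.812968 * cr
        b = y + 2.017232 * cb
    return np.clip(np.stack([r, g, b], axis=-1), 0.0, 1.0)

# uint8 output path: scale back to [0, 255] with saturation, as the kernel now does
yuv_u8 = np.array([[16, 128, 128], [235, 128, 128]], dtype=np.uint8)
rgb = yuv_to_rgb_norm(yuv_u8.astype(np.float32) / 255.0, full_range=False)
print(np.clip(rgb * 255.0 + 0.5, 0, 255).astype(np.uint8))  # ~[[0 0 0] [255 255 255]]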
diff --git a/dali/operators/video/frames_decoder_cpu.cc b/dali/operators/video/frames_decoder_cpu.cc
index fb624b7d1b..50d16d8a87 100644
--- a/dali/operators/video/frames_decoder_cpu.cc
+++ b/dali/operators/video/frames_decoder_cpu.cc
@@ -93,7 +93,7 @@ void FramesDecoderCpu::CopyToOutput(uint8_t *data) {
                   Width(),
                   Height(),
                   sws_output_format,
-                  SWS_BILINEAR,
+                  SWS_BILINEAR | SWS_FULL_CHR_H_INT | SWS_ACCURATE_RND,
                   nullptr,
                   nullptr,
                   nullptr),
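The CPU decoder's swscale context now requests full chroma interpolation and accurate rounding in addition to bilinear scaling; the Python reference in test_video.py (further down) passes the same flag bits to PyAV's interpolation argument. A short sketch of the equivalent PyAV call, with the flag values spelled out; the input file name is hypothetical.

import av

# libswscale flag bits; the same values are OR-ed together in test_video.py below
SWS_BILINEAR = 0x40
SWS_FULL_CHR_H_INT = 0x2000
SWS_ACCURATE_RND = 0x40000

with av.open("video.mp4") as container:  # hypothetical input file
    frame = next(container.decode(video=0))
    rgb = frame.to_ndarray(format="rgb24",
                           interpolation=SWS_BILINEAR | SWS_FULL_CHR_H_INT | SWS_ACCURATE_RND)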
diff --git a/dali/operators/video/legacy/reader/nvdecoder/imgproc.cu b/dali/operators/video/legacy/reader/nvdecoder/imgproc.cu
deleted file mode 100644
index d9078e0c30..0000000000
--- a/dali/operators/video/legacy/reader/nvdecoder/imgproc.cu
+++ /dev/null
@@ -1,245 +0,0 @@
-// Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "dali/operators/video/legacy/reader/nvdecoder/imgproc.h"
-
-#include <cuda_fp16.h>
-
-namespace dali {
-
-namespace {
-
-// using math from https://msdn.microsoft.com/en-us/library/windows/desktop/dd206750(v=vs.85).aspx
-
-template <typename T>
-struct YCbCr {
-  T y, cb, cr;
-};
-
-// https://docs.microsoft.com/en-gb/windows/desktop/medfound/recommended-8-bit-yuv-formats-for-video-rendering#converting-8-bit-yuv-to-rgb888
-__constant__ float ycbcr2rgb_mat_norm[9] = {
-  1.164383f,  0.0f,       1.596027f,
-  1.164383f, -0.391762f, -0.812968f,
-  1.164383f,  2.017232f,  0.0f
-};
-
-// not normalized need *255
-__constant__ float ycbcr2rgb_mat[9] = {
-  1.164383f * 255.0f,  0.0f,                1.596027f * 255.0f,
-  1.164383f * 255.0f, -0.391762f * 255.0f, -0.812968f * 255.0f,
-  1.164383f * 255.0f,  2.017232f * 255.0f,  0.0f
-};
-
-
-// https://en.wikipedia.org/wiki/YUV#Y%E2%80%B2UV444_to_RGB888_conversion
-__constant__ float ycbcr2rgb_mat_norm_full_range[9] = {
-  1,  0.0f,          1.402f,
-  1, -0.344136285f, -0.714136285f,
-  1,  1.772f,        0.0f
-};
-
-// not normalized need *255
-__constant__ float ycbcr2rgb_mat_full_range[9] = {
-  1 * 255,  0.0f,                1.402f * 255,
-  1 * 255, -0.344136285f * 255, -0.714136285f * 255,
-  1 * 255,  1.772f * 255,        0.0f
-};
-
-__device__ float clip(float x, float max) {
-  return fminf(fmaxf(x, 0.0f), max);
-}
-
-template <typename T>
-__device__ T convert(const float x) {
-  return static_cast<T>(x);
-}
-
-#if 0
-template <>
-__device__ half convert<half>(const float x) {
-  return __float2half(x);
-}
-
-template <>
-__device__ uint8_t convert<uint8_t>(const float x) {
-  return static_cast<uint8_t>(roundf(x));
-}
-#endif
-
-template <typename RGB_T, bool Normalized = false>
-__device__ void ycbcr2rgb(const YCbCr<float>& ycbcr, RGB_T* rgb,
-                          size_t stride) {
-  auto y  = (static_cast<float>(ycbcr.y)  - 16.0f/255.0f);
-  auto cb = (static_cast<float>(ycbcr.cb) - 128.0f/255.0f);
-  auto cr = (static_cast<float>(ycbcr.cr) - 128.0f/255.0f);
-
-
-  float r, g, b;
-  if (Normalized) {
-    auto& m = ycbcr2rgb_mat_norm;
-    r = clip(y*m[0] + cb*m[1] + cr*m[2], 1.0f);
-    g = clip(y*m[3] + cb*m[4] + cr*m[5], 1.0f);
-    b = clip(y*m[6] + cb*m[7] + cr*m[8], 1.0f);
-  } else {
-    auto& m = ycbcr2rgb_mat;
-    r = clip(y*m[0] + cb*m[1] + cr*m[2], 255.0f);
-    g = clip(y*m[3] + cb*m[4] + cr*m[5], 255.0f);
-    b = clip(y*m[6] + cb*m[7] + cr*m[8], 255.0f);
-  }
-
-  rgb[0]        = convert<RGB_T>(r);
-  rgb[stride]   = convert<RGB_T>(g);
-  rgb[stride*2] = convert<RGB_T>(b);
-}
-
-
-template <typename RGB_T, bool Normalized = false>
-__device__ void ycbcr2rgb_full_range(const YCbCr<float>& ycbcr, RGB_T* rgb,
-                                     size_t stride) {
-  auto y  = (static_cast<float>(ycbcr.y));
-  auto cb = (static_cast<float>(ycbcr.cb) - 128.0f/255.0f);
-  auto cr = (static_cast<float>(ycbcr.cr) - 128.0f/255.0f);
-
-
-  float r, g, b;
-  if (Normalized) {
-    auto& m = ycbcr2rgb_mat_norm_full_range;
-    r = clip(y*m[0] + cb*m[1] + cr*m[2], 1.0f);
-    g = clip(y*m[3] + cb*m[4] + cr*m[5], 1.0f);
-    b = clip(y*m[6] + cb*m[7] + cr*m[8], 1.0f);
-  } else {
-    auto& m = ycbcr2rgb_mat_full_range;
-    r = clip(y*m[0] + cb*m[1] + cr*m[2], 255.0f);
-    g = clip(y*m[3] + cb*m[4] + cr*m[5], 255.0f);
-    b = clip(y*m[6] + cb*m[7] + cr*m[8], 255.0f);
-  }
-
-  rgb[0]        = convert<RGB_T>(r);
-  rgb[stride]   = convert<RGB_T>(g);
-  rgb[stride*2] = convert<RGB_T>(b);
-}
-
-template <typename T, bool Normalized, bool RGB, bool FullRange>
-__global__ void process_frame_kernel(
-    cudaTextureObject_t luma, cudaTextureObject_t chroma,
-    T* dst, int index,
-    float fx, float fy,
-    int dst_width, int dst_height, int c) {
-  const int dst_x = blockIdx.x * blockDim.x + threadIdx.x;
-  const int dst_y = blockIdx.y * blockDim.y + threadIdx.y;
-
-  if (dst_x >= dst_width || dst_y >= dst_height)
-    return;
-
-  auto src_x = 0.0f;
-  // TODO(spanev) something less hacky here, why 4:2:0 fails on this edge?
-  float shift = (dst_x == dst_width - 1) ? 0 : 0.5f;
-  src_x = static_cast<float>(dst_x) * fx + shift;
-  auto src_y = static_cast<float>(dst_y) * fy + shift;
-
-  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#tex2d-object
-  YCbCr<float> ycbcr;
-  ycbcr.y = tex2D<float>(luma, src_x, src_y);
-  auto cbcr = tex2D<float2>(chroma, src_x * 0.5f, src_y * 0.5f);
-  ycbcr.cb = cbcr.x;
-  ycbcr.cr = cbcr.y;
-
-  auto* out = &dst[(dst_x + dst_y * dst_width) * c];
-
-  constexpr size_t stride = 1;
-  if (RGB) {
-    if (FullRange) {
-      ycbcr2rgb_full_range<T, Normalized>(ycbcr, out, stride);
-    } else {
-      ycbcr2rgb<T, Normalized>(ycbcr, out, stride);
-    }
-  } else {
-    constexpr float scaling = Normalized ? 1.0f : 255.0f;
-    out[0]        = convert<T>(ycbcr.y  * scaling);
-    out[stride]   = convert<T>(ycbcr.cb * scaling);
-    out[stride*2] = convert<T>(ycbcr.cr * scaling);
-  }
-}
-
-inline constexpr int divUp(int total, int grain) {
-  return (total + grain - 1) / grain;
-}
-
-}  // namespace
-
-template <typename T>
-void process_frame(
-    cudaTextureObject_t chroma, cudaTextureObject_t luma,
-    SequenceWrapper& output, int index, cudaStream_t stream,
-    uint16_t input_width, uint16_t input_height,
-    bool rgb, bool normalized, bool full_range) {
-  auto scale_width = input_width;
-  auto scale_height = input_height;
-
-  auto fx = static_cast<float>(input_width) / scale_width;
-  auto fy = static_cast<float>(input_height) / scale_height;
-
-  dim3 block(32, 8);
-  dim3 grid(divUp(output.width, block.x), divUp(output.height, block.y));
-
-  auto frame_stride =
-      static_cast<int64_t>(index) * output.height * output.width * output.channels;
-  LOG_LINE << "Processing frame " << index
-           << " (frame_stride=" << frame_stride << ")" << std::endl;
-  auto* tensor_out = output.sequence.mutable_data<T>() + frame_stride;
-
-  if (normalized) {
-    if (rgb) {
-      if (full_range) {
-        process_frame_kernel<T, true, true, true><<<grid, block, 0, stream>>>
-          (luma, chroma, tensor_out, index, fx, fy, output.width, output.height, output.channels);
-      } else {
-        process_frame_kernel<T, true, true, false><<<grid, block, 0, stream>>>
-          (luma, chroma, tensor_out, index, fx, fy, output.width, output.height, output.channels);
-      }
-    } else {
-      process_frame_kernel<T, true, false, false><<<grid, block, 0, stream>>>
-        (luma, chroma, tensor_out, index, fx, fy, output.width, output.height, output.channels);
-    }
-  } else {
-    if (rgb) {
-      if (full_range) {
-        process_frame_kernel<T, false, true, true><<<grid, block, 0, stream>>>
-          (luma, chroma, tensor_out, index, fx, fy, output.width, output.height, output.channels);
-      } else {
-        process_frame_kernel<T, false, true, false><<<grid, block, 0, stream>>>
-          (luma, chroma, tensor_out, index, fx, fy, output.width, output.height, output.channels);
-      }
-    } else {
-      process_frame_kernel<T, false, false, false><<<grid, block, 0, stream>>>
-        (luma, chroma, tensor_out, index, fx, fy, output.width, output.height, output.channels);
-    }
-  }
-}
-
-template
-void process_frame<uint8_t>(
-    cudaTextureObject_t chroma, cudaTextureObject_t luma,
-    SequenceWrapper& output, int index, cudaStream_t stream,
-    uint16_t input_width, uint16_t input_height,
-    bool rgb, bool normalized, bool full_range);
-
-template
-void process_frame<float>(
-    cudaTextureObject_t chroma, cudaTextureObject_t luma,
-    SequenceWrapper& output, int index, cudaStream_t stream,
-    uint16_t input_width, uint16_t input_height,
-    bool rgb, bool normalized, bool full_range);
-
-}  // namespace dali
diff --git a/dali/operators/video/legacy/reader/nvdecoder/imgproc.h b/dali/operators/video/legacy/reader/nvdecoder/imgproc.h
deleted file mode 100644
index fd65b3ec85..0000000000
--- a/dali/operators/video/legacy/reader/nvdecoder/imgproc.h
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef DALI_OPERATORS_READER_NVDECODER_IMGPROC_H_
-#define DALI_OPERATORS_READER_NVDECODER_IMGPROC_H_
-
-
-#include "dali/core/common.h"
-#include "dali/operators/video/legacy/reader/nvdecoder/sequencewrapper.h"
-
-namespace dali {
-
-template <typename T>
-DLL_PUBLIC void process_frame(
-    cudaTextureObject_t chroma, cudaTextureObject_t luma,
-    SequenceWrapper& output, int index, cudaStream_t stream,
-    uint16_t input_width, uint16_t input_height,
-    bool rgb, bool normalized, bool full_range);
-
-}  // namespace dali
-
-#endif  // DALI_OPERATORS_READER_NVDECODER_IMGPROC_H_
diff --git a/dali/operators/video/legacy/reader/nvdecoder/nvdecoder.cc b/dali/operators/video/legacy/reader/nvdecoder/nvdecoder.cc
index f73a5976ca..72ec122394 100644
--- a/dali/operators/video/legacy/reader/nvdecoder/nvdecoder.cc
+++ b/dali/operators/video/legacy/reader/nvdecoder/nvdecoder.cc
@@ -28,7 +28,7 @@
 #include "dali/core/error_handling.h"
 #include "dali/core/static_switch.h"
 #include "dali/operators/video/legacy/reader/nvdecoder/cuvideoparser.h"
-#include "dali/operators/video/legacy/reader/nvdecoder/imgproc.h"
+#include "dali/operators/video/color_space.h"
 
 namespace dali {
 
@@ -49,7 +49,7 @@ NvDecoder::NvDecoder(int device_id,
       frame_in_use_(32),  // 32 is cuvid's max number of decode surfaces
       frame_full_range_(32),  // 32 is cuvid's max number of decode surfaces
       recv_queue_(), frame_queue_(),
-      current_recv_(), req_ready_(VidReqStatus::REQ_READY), textures_(), stop_(false) {
+      current_recv_(), req_ready_(VidReqStatus::REQ_READY), stop_(false) {
   DALI_ENFORCE(cuInitChecked(),
                "Failed to load libcuda.so. "
" @@ -279,45 +279,6 @@ unsigned int NvDecoder::MappedFrame::get_pitch() const { return pitch_; } -NvDecoder::TextureObject::TextureObject() : valid_{false} { -} - -NvDecoder::TextureObject::TextureObject(const cudaResourceDesc* pResDesc, - const cudaTextureDesc* pTexDesc, - const cudaResourceViewDesc* pResViewDesc) - : valid_{false} -{ - CUDA_CALL(cudaCreateTextureObject(&object_, pResDesc, pTexDesc, pResViewDesc)); - valid_ = true; -} - -NvDecoder::TextureObject::~TextureObject() { - if (valid_) { - cudaDestroyTextureObject(object_); - } -} - -NvDecoder::TextureObject::TextureObject(NvDecoder::TextureObject&& other) - : valid_{other.valid_}, object_{other.object_} -{ - other.valid_ = false; -} - -NvDecoder::TextureObject& NvDecoder::TextureObject::operator=(NvDecoder::TextureObject&& other) { - valid_ = other.valid_; - object_ = other.object_; - other.valid_ = false; - return *this; -} - -NvDecoder::TextureObject::operator cudaTextureObject_t() const { - if (valid_) { - return object_; - } else { - return cudaTextureObject_t{}; - } -} - // Callback called by the driver decoder once a frame has been decoded int NvDecoder::handle_display_(CUVIDPARSERDISPINFO* disp_info) { auto frame = av_rescale_q(disp_info->timestamp, @@ -417,13 +378,31 @@ void NvDecoder::receive_frames(SequenceWrapper& sequence) { auto* frame_disp_info = frame_queue_.pop(); if (stop_) break; - auto frame = MappedFrame{frame_disp_info, decoder_, stream_}; + MappedFrame frame{frame_disp_info, decoder_, stream_}; sequence.timestamps.push_back(frame_disp_info->timestamp * av_q2d( nv_time_base_)); if (stop_) break; - convert_frame(frame, sequence, i); + auto is_full_range = frame_full_range_[frame.disp_info->picture_index]; + auto conversion_type = rgb_ ? + is_full_range ? VIDEO_COLOR_SPACE_CONVERSION_TYPE_YUV_TO_RGB_FULL_RANGE : + VIDEO_COLOR_SPACE_CONVERSION_TYPE_YUV_TO_RGB : + VIDEO_COLOR_SPACE_CONVERSION_TYPE_YUV_UPSAMPLE; + auto frame_stride = + static_cast(i) * sequence.height * sequence.width * sequence.channels; + TYPE_SWITCH(dtype_, type2id, OutType, NVDECODER_SUPPORTED_TYPES, ( + auto* tensor_out = sequence.sequence.mutable_data() + frame_stride; + VideoColorSpaceConversion( + reinterpret_cast(tensor_out), sequence.width * sequence.channels, + reinterpret_cast(frame.get_ptr()), static_cast(frame.get_pitch()), + decoder_.height(), + decoder_.width(), + conversion_type, + normalized_, + stream_); + ), DALI_FAIL(make_string("Unsupported type: ", dtype_))); // synchronize before MappedFrame is destroyed and cuvidUnmapVideoFrame is called CUDA_CALL(cudaStreamSynchronize(stream_)); + frame_in_use_[frame.disp_info->picture_index] = false; } if (captured_exception_) std::rethrow_exception(captured_exception_); @@ -440,85 +419,6 @@ void NvDecoder::receive_frames(SequenceWrapper& sequence) { record_sequence_event_(sequence); } -// We assume here that a pointer and scale_method -// uniquely identifies a texture -const NvDecoder::TextureObjects& -NvDecoder::get_textures(uint8_t* input, unsigned int input_pitch, - uint16_t input_width, uint16_t input_height, - ScaleMethod scale_method) { - auto tex_id = std::make_tuple(input, scale_method, input_height, input_width, input_pitch); - auto tex = textures_.find(tex_id); - if (tex != textures_.end()) { - return tex->second; - } - TextureObjects objects; - cudaTextureDesc tex_desc = {}; - tex_desc.addressMode[0] = cudaAddressModeClamp; - tex_desc.addressMode[1] = cudaAddressModeClamp; - if (scale_method == ScaleMethod_Nearest) { - tex_desc.filterMode = cudaFilterModePoint; - } else { - 
diff --git a/dali/operators/video/legacy/reader/nvdecoder/nvdecoder.h b/dali/operators/video/legacy/reader/nvdecoder/nvdecoder.h
index a7802f8665..27fb20612a 100644
--- a/dali/operators/video/legacy/reader/nvdecoder/nvdecoder.h
+++ b/dali/operators/video/legacy/reader/nvdecoder/nvdecoder.h
@@ -25,8 +25,6 @@ extern "C" {
 #include <mutex>
 #include <queue>
 #include <string>
 #include <thread>
-#include <tuple>
-#include <unordered_map>
 #include <utility>
 #include <vector>
@@ -66,19 +64,6 @@ struct FrameReq {
   bool full_range;
 };
 
-enum ScaleMethod {
-  /**
-   * The value for the nearest neighbor is used, no interpolation
-   */
-  ScaleMethod_Nearest,
-
-  /**
-   * Simple bilinear interpolation of four nearest neighbors
-   */
-  ScaleMethod_Linear
-};
-
-
 enum class VidReqStatus {
   REQ_READY = 0,
   REQ_IN_PROGRESS,
@@ -156,35 +141,6 @@ class NvDecoder {
     CUVIDPROCPARAMS params_;
   };
 
-  class TextureObject {
-   public:
-    TextureObject();
-    TextureObject(const cudaResourceDesc* pResDesc,
-                  const cudaTextureDesc* pTexDesc,
-                  const cudaResourceViewDesc* pResViewDesc);
-    ~TextureObject();
-    TextureObject(TextureObject&& other);
-    TextureObject& operator=(TextureObject&& other);
-    TextureObject(const TextureObject&) = delete;
-    TextureObject& operator=(const TextureObject&) = delete;
-    operator cudaTextureObject_t() const;
-   private:
-    bool valid_;
-    cudaTextureObject_t object_ = 0;
-  };
-
-  struct TextureObjects {
-    TextureObject luma;
-    TextureObject chroma;
-  };
-
-  const TextureObjects& get_textures(uint8_t* input, unsigned int input_pitch,
-                                     uint16_t input_width, uint16_t input_height,
-                                     ScaleMethod scale_method);
-  void convert_frame(const MappedFrame& frame, SequenceWrapper& sequence,
-                     int index);
-
-
   const int device_id_;
   CUDAStreamLease stream_;
@@ -205,31 +161,6 @@ class NvDecoder {
   FrameReq current_recv_;
   VidReqStatus req_ready_;
 
-  using TexID = std::tuple<uint8_t*, ScaleMethod, uint16_t, uint16_t, unsigned int>;
-
-  struct tex_hash {
-    // hash_combine taken from
-    // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2014/n3876.pdf
-    template <typename T>
-    inline void hash_combine(size_t& seed, const T& value) const {
-      std::hash<T> hasher;
-      seed ^= hasher(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
-    }
-
-    std::size_t operator () (const TexID& tex) const {
-      size_t seed = 0;
-      hash_combine(seed, std::get<0>(tex));
-      hash_combine(seed, std::get<1>(tex));
-      hash_combine(seed, std::get<2>(tex));
-      hash_combine(seed, std::get<3>(tex));
-      hash_combine(seed, std::get<4>(tex));
-
-      return seed;
-    }
-  };
-
-  std::unordered_map<TexID, TextureObjects, tex_hash> textures_;
-
   volatile bool stop_;
   std::exception_ptr captured_exception_;
@@ -241,16 +172,4 @@ class NvDecoder {
 
 }  // namespace dali
 
-namespace std {
-template<>
-struct hash<dali::ScaleMethod> {
- public:
-  std::size_t operator()(dali::ScaleMethod const& s) const noexcept {
-    return std::hash<int>()(s);
-  }
-};
-
-}  // namespace std
-
-
 #endif  // DALI_OPERATORS_READER_NVDECODER_NVDECODER_H_
diff --git a/dali/test/python/decoder/test_video.py b/dali/test/python/decoder/test_video.py
index de19875b5e..7bde338c9b 100644
--- a/dali/test/python/decoder/test_video.py
+++ b/dali/test/python/decoder/test_video.py
@@ -16,6 +16,7 @@
 from nvidia.dali import pipeline_def
 import numpy as np
 import cv2
+import av
 import nvidia.dali.types as types
 import glob
 import os
@@ -709,7 +710,7 @@ def test_pipeline():
 def extract_frames_from_video(
     video_path, start_frame=None, sequence_length=None, end_frame=None, stride=None
 ):
-    """Extracts frames from a video file using OpenCV's VideoCapture.
+    """Extracts frames from a video file using PyAV.
 
     Args:
         video_path: Path to the video file
@@ -721,41 +722,43 @@
     Returns:
         List of frames as numpy arrays
     """
-    frames = []
-    cap = cv2.VideoCapture(video_path)
-    if not cap.isOpened():
-        raise RuntimeError(f"Failed to open video file: {video_path}")
+    frames = []
 
     if sequence_length is not None and end_frame is not None:
         raise ValueError("Cannot specify both sequence_length and end_frame")
 
-    # Set starting frame
-    if start_frame is not None:
-        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
+    container = av.open(video_path)
+    video_stream = next((s for s in container.streams if s.type == "video"), None)
+    if video_stream is None:
+        raise RuntimeError(f"No video stream found in {video_path}")
 
-    frame_count = 0
+    # Calculate the correct starting point
+    start = start_frame or 0
+    stride = stride or 1
     frames_read = 0
-    while True:
-        if sequence_length is not None and frames_read >= sequence_length:
-            break
-        if end_frame is not None and frame_count + (start_frame or 0) >= end_frame:
-            break
+    decoded_frame_idx = 0
 
-        ret, frame = cap.read()
-        if not ret:
+    for frame in container.decode(video_stream):
+        # Early exit if possible
+        if (sequence_length is not None and frames_read >= sequence_length) or (
+            end_frame is not None and decoded_frame_idx >= end_frame
+        ):
             break
-
-        # Only append frames according to stride
-        if stride is None or frame_count % stride == 0:
-            # Convert BGR to RGB since OpenCV uses BGR by default
-            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-            frames.append(frame)
+        if decoded_frame_idx < start:
+            decoded_frame_idx += 1
+            continue
+
+        # If a stride was requested, keep only frames on a stride boundary
+        if (decoded_frame_idx - start) % stride == 0:
+            # SWS_BILINEAR(0x40) | SWS_FULL_CHR_H_INT(0x2000) | SWS_ACCURATE_RND(0x40000)
+            nd_frame = frame.to_ndarray(format="rgb24", interpolation=0x40 + 0x2000 + 0x40000)
+            frames.append(nd_frame)
             frames_read += 1
-        frame_count += 1
+        decoded_frame_idx += 1
 
-    cap.release()
+    container.close()
     return frames
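Unlike cv2.CAP_PROP_POS_FRAMES, PyAV decodes sequentially from the first frame, so start_frame and stride are honored by skipping decoded frames rather than by seeking. A hypothetical call to the updated helper; the file name and parameter values are made up.

frames = extract_frames_from_video("video.mp4", start_frame=8, sequence_length=4, stride=2)
assert len(frames) <= 4
assert all(f.ndim == 3 and f.shape[-1] == 3 for f in frames)  # HWC, RGB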
diff --git a/dali/test/python/test_video_reader.py b/dali/test/python/test_video_reader.py
index 20c9f830b2..dc8084a609 100644
--- a/dali/test/python/test_video_reader.py
+++ b/dali/test/python/test_video_reader.py
@@ -51,7 +51,7 @@ def compare_frames(
     # Compare frames
     diff_pixels = np.count_nonzero(np.abs(np.float32(frame) - np.float32(ref_frame)) > diff_step)
     total_pixels = frame.size
-    # More than 3% of the pixels differ in more than 2 steps
+    # Fail if more than a threshold fraction of pixels differ by more than diff_step steps
     if diff_pixels / total_pixels > threshold:
         # Save the mismatched frames for inspection
         frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
@@ -64,7 +64,7 @@ def compare_frames(
         cv2.imwrite(ref_output_path, ref_frame_bgr)
         assert False, (
             f"Frame {frame_idx+1} differs from reference by more than {diff_step} steps in "
-            + f"{diff_pixels/total_pixels*100}% of pixels. "
+            + f"{diff_pixels/total_pixels*100}% of pixels (threshold: {threshold}). "
            + f"Expected {ref_frame_bgr} but got {frame_bgr}"
        )
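compare_frames fails a frame when the fraction of pixel values differing from the reference by more than diff_step exceeds threshold. A self-contained sketch of that metric; the function name is hypothetical, not the test's API.

import numpy as np

def frac_differing(frame, ref_frame, diff_step=2):
    """Fraction of pixel values differing from the reference by more than diff_step."""
    diff = np.abs(np.float32(frame) - np.float32(ref_frame))
    return np.count_nonzero(diff > diff_step) / frame.size

a = np.zeros((2, 2, 3), np.uint8)
b = a.copy()
b[0, 0, 0] = 5               # one channel value off by 5 (> 2 steps)
print(frac_differing(a, b))  # 1/12 ~= 0.083 -> fails a 0.03 threshold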
" + f"Expected {ref_frame_bgr} but got {frame_bgr}" ) diff --git a/qa/TL0_multigpu/test_nofw.sh b/qa/TL0_multigpu/test_nofw.sh index bc551a0efb..e6c74e1a4b 100755 --- a/qa/TL0_multigpu/test_nofw.sh +++ b/qa/TL0_multigpu/test_nofw.sh @@ -1,6 +1,6 @@ #!/bin/bash -e # used pip packages -pip_packages='${python_test_runner_package} numpy librosa scipy nvidia-ml-py==11.450.51 psutil dill cloudpickle opencv-python-headless pillow' +pip_packages='${python_test_runner_package} numpy librosa scipy nvidia-ml-py==11.450.51 psutil dill cloudpickle opencv-python-headless pillow av' target_dir=./dali/test/python # test_body definition is in separate file so it can be used without setup diff --git a/qa/TL0_python-self-test-readers-decoders/test_nofw.sh b/qa/TL0_python-self-test-readers-decoders/test_nofw.sh index d2b1969f29..2f37c7be62 100755 --- a/qa/TL0_python-self-test-readers-decoders/test_nofw.sh +++ b/qa/TL0_python-self-test-readers-decoders/test_nofw.sh @@ -1,6 +1,6 @@ #!/bin/bash -e # used pip packages -pip_packages='${python_test_runner_package} numpy librosa scipy nvidia-ml-py==11.450.51 psutil dill cloudpickle pillow opencv-python-headless astropy' +pip_packages='${python_test_runner_package} numpy librosa scipy nvidia-ml-py==11.450.51 psutil dill cloudpickle pillow opencv-python-headless astropy av' target_dir=./dali/test/python diff --git a/qa/TL1_python-self-test_conda/test_nofw.sh b/qa/TL1_python-self-test_conda/test_nofw.sh index 2f08354ed9..63414c21e1 100755 --- a/qa/TL1_python-self-test_conda/test_nofw.sh +++ b/qa/TL1_python-self-test_conda/test_nofw.sh @@ -1,6 +1,6 @@ #!/bin/bash -e # used pip packages -pip_packages='${python_test_runner_package} dataclasses numpy opencv-python-headless pillow librosa scipy nvidia-ml-py==11.450.51 numba lz4 psutil dill cloudpickle astropy' +pip_packages='${python_test_runner_package} dataclasses numpy opencv-python-headless pillow librosa scipy nvidia-ml-py==11.450.51 numba lz4 psutil dill cloudpickle astropy av' target_dir=./dali/test/python # test_body definition is in separate file so it can be used without setup