diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6ca21cc4..4664aab3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,3 +4,10 @@ repos: hooks: - id: ruff - id: ruff-format + - repo: https://github.com/pre-commit/mirrors-clang-format + rev: v18.1.4 + hooks: + - id: clang-format + files: '\.(c|cc|cpp|h|hpp|cxx|hh|cu|cuh)$' + exclude: '^third_party/|/pybind11/' + name: clang-format \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 2fdd68f7..7848e1fd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,8 +9,11 @@ set(CMAKE_CUDA_STANDARD 17) # Set default build type to Release if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) - set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build." FORCE) - set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") + set(CMAKE_BUILD_TYPE + Release + CACHE STRING "Choose the type of build." FORCE) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" + "MinSizeRel" "RelWithDebInfo") endif() # Check for CUDA @@ -23,35 +26,48 @@ if(CMAKE_CUDA_COMPILER) message(STATUS "CUDA compiler found: ${CMAKE_CUDA_COMPILER}") if(NOT SKBUILD) - message(FATAL_ERROR "Building standalone project directly without pip install is not supported" - "Please use pip install to build the project") + message( + FATAL_ERROR + "Building standalone project directly without pip install is not supported" + "Please use pip install to build the project") else() find_package(CUDAToolkit REQUIRED) - find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) + # Add the executable + find_package( + Python 3.11 REQUIRED + COMPONENTS Interpreter Development.Module + OPTIONAL_COMPONENTS Development.SABIModule) + execute_process( + COMMAND "${Python_EXECUTABLE}" "-c" + "from jax import ffi; print(ffi.include_dir())" + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE XLA_DIR) + message(STATUS "XLA include directory: ${XLA_DIR}") # Detect the installed nanobind package and import it into CMake - execute_process( - COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir - OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT) - find_package(nanobind CONFIG REQUIRED) - - nanobind_add_module(_s2fft STABLE_ABI - ${CMAKE_CURRENT_LIST_DIR}/lib/src/extensions.cc - ${CMAKE_CURRENT_LIST_DIR}/lib/src/s2fft.cu - ${CMAKE_CURRENT_LIST_DIR}/lib/src/s2fft_callbacks.cu - ${CMAKE_CURRENT_LIST_DIR}/lib/src/plan_cache.cc - ${CMAKE_CURRENT_LIST_DIR}/lib/src/s2fft_kernels.cu - ) + find_package(nanobind CONFIG REQUIRED) + + nanobind_add_module( + _s2fft + STABLE_ABI + ${CMAKE_CURRENT_LIST_DIR}/lib/src/extensions.cc + ${CMAKE_CURRENT_LIST_DIR}/lib/src/s2fft.cu + ${CMAKE_CURRENT_LIST_DIR}/lib/src/plan_cache.cc + ${CMAKE_CURRENT_LIST_DIR}/lib/src/s2fft_kernels.cu) target_link_libraries(_s2fft PRIVATE CUDA::cudart_static CUDA::cufft_static CUDA::culibos) - target_include_directories(_s2fft PUBLIC ${CMAKE_CURRENT_LIST_DIR}/lib/include) - set_target_properties(_s2fft PROPERTIES - LINKER_LANGUAGE CUDA - CUDA_SEPARABLE_COMPILATION ON) - set(CMAKE_CUDA_ARCHITECTURES "70;80;89" CACHE STRING "List of CUDA compute capabilities to build cuDecomp for.") + target_include_directories( + _s2fft PUBLIC ${CMAKE_CURRENT_LIST_DIR}/lib/include ${XLA_DIR} ${CUDAToolkit_INCLUDE_DIRS}) + set_target_properties(_s2fft PROPERTIES LINKER_LANGUAGE CUDA + CUDA_SEPARABLE_COMPILATION ON) + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -rdc=true") + set(CMAKE_CUDA_ARCHITECTURES + "70;80;89" + CACHE STRING "List of CUDA compute capabilities to build cuDecomp for.") message(STATUS "CUDA_ARCHITECTURES: ${CMAKE_CUDA_ARCHITECTURES}") - set_target_properties(_s2fft PROPERTIES CUDA_ARCHITECTURES "${CMAKE_CUDA_ARCHITECTURES}") + set_target_properties(_s2fft PROPERTIES CUDA_ARCHITECTURES + "${CMAKE_CUDA_ARCHITECTURES}") install(TARGETS _s2fft LIBRARY DESTINATION s2fft_lib) endif() @@ -60,26 +76,35 @@ else() if(SKBUILD) message(WARNING "CUDA compiler not found, building without CUDA support") - find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) + find_package( + Python 3.11 + COMPONENTS Interpreter Development.Module + REQUIRED) + # Add the executable execute_process( - COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir - OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT) - find_package(nanobind CONFIG REQUIRED) + COMMAND "${Python_EXECUTABLE}" "-c" + "from jax import ffi; print(ffi.include_dir())" + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE XLA_DIR) + message(STATUS "XLA include directory: ${XLA_DIR}") - nanobind_add_module(_s2fft STABLE_ABI - ${CMAKE_CURRENT_LIST_DIR}/lib/src/extensions.cc - ) + # Detect the installed nanobind package and import it into CMake + find_package(nanobind CONFIG REQUIRED) + + nanobind_add_module(_s2fft STABLE_ABI + ${CMAKE_CURRENT_LIST_DIR}/lib/src/extensions.cc) target_compile_definitions(_s2fft PRIVATE NO_CUDA_COMPILER) - target_include_directories(_s2fft PUBLIC ${CMAKE_CURRENT_LIST_DIR}/lib/include) + target_include_directories( + _s2fft PUBLIC ${CMAKE_CURRENT_LIST_DIR}/lib/include ${XLA_DIR}) install(TARGETS _s2fft LIBRARY DESTINATION s2fft_lib) else() - message(FATAL_ERROR "Building standalone project directly without pip install is not supported" - "Please use pip install to build the project") + message( + FATAL_ERROR + "Building standalone project directly without pip install is not supported" + "Please use pip install to build the project") endif() endif() - - diff --git a/lib/include/cudastreamhandler.hpp b/lib/include/cudastreamhandler.hpp new file mode 100644 index 00000000..f1b4ab4d --- /dev/null +++ b/lib/include/cudastreamhandler.hpp @@ -0,0 +1,165 @@ + +/** + * @file cudastreamhandler.hpp + * @brief Singleton class for managing CUDA streams and events. + * + * This header provides a singleton implementation that encapsulates the creation, + * management, and cleanup of CUDA streams and events. It offers functions to fork + * streams, add new streams, and synchronize (join) streams with a given dependency. + * + * Usage example: + * @code + * #include "cudastreamhandler.hpp" + * + * int main() { + * // Create a handler instance + * CudaStreamHandler handler; + * + * // Fork 4 streams dependent on a given stream 'stream_main' + * handler.Fork(stream_main, 4); + * + * // Do work on the forked streams... + * + * // Join the streams back to 'stream_main' + * handler.join(stream_main); + * + * return 0; + * } + * @endcode + * + * Author: Wassim KABALAN + */ + +#ifndef CUDASTREAMHANDLER_HPP +#define CUDASTREAMHANDLER_HPP + +#include +#include +#include +#include +#include +#include + +// Singleton class managing CUDA streams and events +class CudaStreamHandlerImpl { +public: + static CudaStreamHandlerImpl &instance() { + static CudaStreamHandlerImpl instance; + return instance; + } + + void AddStreams(int numStreams) { + if (numStreams > m_streams.size()) { + int streamsToAdd = numStreams - m_streams.size(); + m_streams.resize(numStreams); + std::generate(m_streams.end() - streamsToAdd, m_streams.end(), []() { + cudaStream_t stream; + cudaStreamCreate(&stream); + return stream; + }); + } + } + + void join(cudaStream_t finalStream) { + std::for_each(m_streams.begin(), m_streams.end(), [this, finalStream](cudaStream_t stream) { + cudaEvent_t event; + cudaEventCreate(&event); + cudaEventRecord(event, stream); + cudaStreamWaitEvent(finalStream, event, 0); + m_events.push_back(event); + }); + + if (!cleanup_thread.joinable()) { + stop_thread.store(false); + cleanup_thread = std::thread([this]() { this->AsyncEventCleanup(); }); + } + } + + // Fork function to add streams and set dependency on a given stream + void Fork(cudaStream_t dependentStream, int N) { + AddStreams(N); // Add N streams + + // Set dependency on the provided stream + std::for_each(m_streams.end() - N, m_streams.end(), [this, dependentStream](cudaStream_t stream) { + cudaEvent_t event; + cudaEventCreate(&event); + cudaEventRecord(event, dependentStream); + cudaStreamWaitEvent(stream, event, 0); // Set the stream to wait on the event + m_events.push_back(event); + }); + } + + auto getIterator() { return StreamIterator(m_streams.begin(), m_streams.end()); } + + ~CudaStreamHandlerImpl() { + stop_thread.store(true); + if (cleanup_thread.joinable()) { + cleanup_thread.join(); + } + + std::for_each(m_streams.begin(), m_streams.end(), cudaStreamDestroy); + std::for_each(m_events.begin(), m_events.end(), cudaEventDestroy); + } + + // Custom Iterator class to iterate over streams + class StreamIterator { + public: + StreamIterator(std::vector::iterator begin, std::vector::iterator end) + : current(begin), end(end) {} + + cudaStream_t next() { + if (current == end) { + throw std::out_of_range("No more streams."); + } + return *current++; + } + + bool hasNext() const { return current != end; } + + private: + std::vector::iterator current; + std::vector::iterator end; + }; + +private: + CudaStreamHandlerImpl() : stop_thread(false) {} + CudaStreamHandlerImpl(const CudaStreamHandlerImpl &) = delete; + CudaStreamHandlerImpl &operator=(const CudaStreamHandlerImpl &) = delete; + + void AsyncEventCleanup() { + while (!stop_thread.load()) { + std::for_each(m_events.begin(), m_events.end(), [this](cudaEvent_t &event) { + if (cudaEventQuery(event) == cudaSuccess) { + cudaEventDestroy(event); + event = nullptr; + } + }); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } + + std::vector m_streams; + std::vector m_events; + std::thread cleanup_thread; + std::atomic stop_thread; +}; + +// Public class for encapsulating the singleton operations +class CudaStreamHandler { +public: + CudaStreamHandler() = default; + ~CudaStreamHandler() = default; + + void AddStreams(int numStreams) { CudaStreamHandlerImpl::instance().AddStreams(numStreams); } + + void join(cudaStream_t finalStream) { CudaStreamHandlerImpl::instance().join(finalStream); } + + void Fork(cudaStream_t cudastream, int N) { CudaStreamHandlerImpl::instance().Fork(cudastream, N); } + + // Get the custom iterator for CUDA streams + CudaStreamHandlerImpl::StreamIterator getIterator() { + return CudaStreamHandlerImpl::instance().getIterator(); + } +}; + +#endif // CUDASTREAMHANDLER_HPP diff --git a/lib/include/kernel_helpers.h b/lib/include/kernel_helpers.h deleted file mode 100644 index 12980e08..00000000 --- a/lib/include/kernel_helpers.h +++ /dev/null @@ -1,76 +0,0 @@ -// Adapted from code in a tutorial by Dan Foreman-Mackey -// https://github.com/dfm/extending-jax/blob/c33869665236877a2ae281/lib/kernel_helpers.h -// -// Original license: -// -// MIT License -// -// Copyright (c) 2021 Dan Foreman-Mackey -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -// This header is not specific to our application and you'll probably want -// something like this for any extension you're building. This includes the -// infrastructure needed to serialize descriptors that are used with the -// "opaque" parameter of the GPU custom call. In our example we'll use this -// parameter to pass the size of our problem. - -#ifndef _KERNEL_HELPERS_H_ -#define _KERNEL_HELPERS_H_ - -#include -#include -#include -#include -#include - -namespace s2fft { - -// https://en.cppreference.com/w/cpp/numeric/bit_cast -template -typename std::enable_if::value && - std::is_trivially_copyable::value, - To>::type -bit_cast(const From &src) noexcept { - static_assert(std::is_trivially_constructible::value, - "This implementation additionally requires destination type to " - "be trivially constructible"); - - To dst; - - memcpy(&dst, &src, sizeof(To)); - return dst; -} - -template std::string PackDescriptorAsString(const T &descriptor) { - return std::string(bit_cast(&descriptor), sizeof(T)); -} - -template -const T *UnpackDescriptor(const char *opaque, std::size_t opaque_len) { - if (opaque_len != sizeof(T)) { - throw std::runtime_error("Invalid opaque object size"); - } - return bit_cast(opaque); -} - -} // namespace s2fft - -#endif // _KERNEL_HELPERS_H_ diff --git a/lib/include/kernel_nanobind_helpers.h b/lib/include/kernel_nanobind_helpers.h deleted file mode 100644 index f076b79f..00000000 --- a/lib/include/kernel_nanobind_helpers.h +++ /dev/null @@ -1,51 +0,0 @@ -// Adapted from code by JAX authors -// https://github.com/jax-ml/jax/blob/3d389a7fb440c412d/jaxlib/kernel_nanobind_helpers.h - -/* Copyright 2019 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef _KERNEL_NANOBIND_HELPERS_H_ -#define _KERNEL_NANOBIND_HELPERS_H_ - -#include - -#include "nanobind/nanobind.h" -#include "kernel_helpers.h" - -namespace s2fft { - -// Descriptor objects are opaque host-side objects used to pass data from JAX -// to the custom kernel launched by XLA. Currently simply treat host-side -// structures as byte-strings; this is not portable across architectures. If -// portability is needed, we could switch to using a representation such as -// protocol buffers or flatbuffers. - -// Packs a descriptor object into a nanobind::bytes structure. -// UnpackDescriptor() is available in kernel_helpers.h. -template -nanobind::bytes PackDescriptor(const T& descriptor) { - std::string s = PackDescriptorAsString(descriptor); - return nanobind::bytes(s.data(), s.size()); -} - -template -nanobind::capsule EncapsulateFunction(T* fn) { - return nanobind::capsule(bit_cast(fn), - "xla._CUSTOM_CALL_TARGET"); -} - -} // namespace s2fft - -#endif // _KERNEL_NANOBIND_HELPERS_H_ diff --git a/lib/include/plan_cache.h b/lib/include/plan_cache.h index 5543d446..9038cb76 100644 --- a/lib/include/plan_cache.h +++ b/lib/include/plan_cache.h @@ -1,4 +1,3 @@ - #ifndef PLAN_CACHE_H #define PLAN_CACHE_H @@ -9,26 +8,67 @@ #include "hresult.h" #include "s2fft.h" #include +#include namespace s2fft { +/** + * @brief Manages and caches s2fftExec instances to optimize resource usage. + * + * This class implements the singleton pattern to ensure only one instance + * of the PlanCache exists throughout the application. It stores pre-initialized + * s2fftExec objects based on their descriptors (parameters like nside, L, etc.) + * to avoid redundant initialization, which can be computationally expensive. + */ class PlanCache { public: + /** + * @brief Returns the singleton instance of the PlanCache. + * + * @return A reference to the single PlanCache instance. + */ static PlanCache &GetInstance() { static PlanCache instance; return instance; } - HRESULT GetS2FFTExec(s2fftDescriptor &descriptor, std::shared_ptr> &executor); + /** + * @brief Retrieves an s2fftExec instance from the cache or initializes a new one. + * + * This templated method attempts to find an existing s2fftExec instance + * matching the provided descriptor in its internal cache (m_Descriptors32 or m_Descriptors64) + * based on the Complex type T. If a matching instance is found, it is returned. + * Otherwise, a new s2fftExec instance is created, initialized with the descriptor, + * and then stored in the cache before being returned. + * + * @tparam T The complex type (cufftComplex or cufftDoubleComplex) of the s2fftExec instance. + * @param descriptor The s2fftDescriptor containing the parameters for the FFT. + * @param executor A shared_ptr that will point to the retrieved or newly initialized s2fftExec instance. + * @return HRESULT indicating success (S_OK if new, S_FALSE if from cache) or failure. + */ + template + HRESULT GetS2FFTExec(s2fftDescriptor &descriptor, std::shared_ptr> &executor); - HRESULT GetS2FFTExec(s2fftDescriptor &descriptor, - std::shared_ptr> &executor); + /** + * @brief Clears all cached s2fftExec instances. + * + * This method is typically called during application shutdown to release + * all resources held by the cached FFT plans. + */ + void Finalize(); - ~PlanCache() {} + /** + * @brief Destructor for PlanCache. + * + * Ensures that Finalize() is called when the PlanCache instance is destroyed, + * performing necessary cleanup. + */ + ~PlanCache(); private: bool is_initialized = false; + // Unordered maps to store cached s2fftExec instances for double and single precision std::unordered_map>, std::hash, std::equal_to<>> m_Descriptors64; @@ -36,9 +76,16 @@ class PlanCache { std::equal_to<>> m_Descriptors32; + /** + * @brief Private constructor for PlanCache. + * + * Initializes the PlanCache instance. This constructor is private to enforce + * the singleton pattern. + */ PlanCache(); public: + // Delete copy constructor and assignment operator to prevent copying PlanCache(PlanCache const &) = delete; void operator=(PlanCache const &) = delete; }; diff --git a/lib/include/s2fft.h b/lib/include/s2fft.h index af89416e..176b50c6 100644 --- a/lib/include/s2fft.h +++ b/lib/include/s2fft.h @@ -1,4 +1,3 @@ - #ifndef S2FFT_H #define S2FFT_H @@ -15,42 +14,107 @@ #include "cufft.h" #include "cufftXt.h" #include "thrust/device_vector.h" -#include "s2fft_callbacks.h" +#include "s2fft_kernels.h" namespace s2fft { +/** + * @brief Returns the appropriate cuFFT C2C type for a given complex type. + * + * This function is overloaded for `cufftDoubleComplex` and `cufftComplex` + * to return `CUFFT_Z2Z` (double precision) or `CUFFT_C2C` (single precision) + * respectively. + * + * @param dummy A dummy complex object used for type deduction. + * @return The corresponding cuFFT C2C type. + */ static cufftType get_cufft_type_c2c(cufftDoubleComplex) { return CUFFT_Z2Z; } static cufftType get_cufft_type_c2c(cufftComplex) { return CUFFT_C2C; } +/** + * @brief Transforms data from ring-based indexing to nphi-based indexing. + * + * This function is a placeholder for the actual implementation which would + * reorder data in memory according to the specified indexing scheme. + * + * @param data Pointer to the input/output data. + * @param nside The HEALPix Nside parameter. + */ void s2fft_rings_2_nphi(float *data, int nside); +/** + * @brief Transforms data from nphi-based indexing to ring-based indexing. + * + * This function is a placeholder for the actual implementation which would + * reorder data in memory according to the specified indexing scheme. + * + * @param data Pointer to the input/output data. + * @param nside The HEALPix Nside parameter. + */ void s2fft_nphi_2_rings(float *data, int nside); +/** + * @brief Descriptor class for s2fft operations. + * + * This class encapsulates all the necessary parameters to define a unique + * Spherical Harmonic Transform (SHT) operation, including Nside, harmonic + * band limit, reality, adjoint flag, forward/backward transform direction, + * normalization, shifting, and double precision usage. + */ class s2fftDescriptor { public: - int nside; - int harmonic_band_limit; + int64_t nside; + int64_t harmonic_band_limit; bool reality; + bool adjoint; bool forward = true; s2fftKernels::fft_norm norm = s2fftKernels::BACKWARD; bool shift = true; bool double_precision = false; - s2fftDescriptor(int nside, int harmonic_band_limit, bool reality, bool forward = true, - s2fftKernels::fft_norm norm = s2fftKernels::BACKWARD, bool shift = true, - bool double_precision = false) + /** + * @brief Constructs an s2fftDescriptor object. + * + * @param nside The HEALPix Nside parameter. + * @param harmonic_band_limit The harmonic band limit L. + * @param reality Flag indicating if the signal is real. + * @param adjoint Flag indicating if the adjoint transform is to be performed. + * @param forward Flag indicating if it's a forward transform (default: true). + * @param norm The FFT normalization type (default: BACKWARD). + * @param shift Flag indicating if FFT shifting should be applied (default: true). + * @param double_precision Flag indicating if double precision should be used (default: false). + */ + s2fftDescriptor(int64_t nside, int64_t harmonic_band_limit, bool reality, bool adjoint, + bool forward = true, s2fftKernels::fft_norm norm = s2fftKernels::BACKWARD, + bool shift = true, bool double_precision = false) : nside(nside), harmonic_band_limit(harmonic_band_limit), reality(reality), + adjoint(adjoint), norm(norm), forward(forward), shift(shift), double_precision(double_precision) {} + /** + * @brief Default constructor for s2fftDescriptor. + */ s2fftDescriptor() = default; + + /** + * @brief Destructor for s2fftDescriptor. + */ ~s2fftDescriptor() = default; + /** + * @brief Equality operator for s2fftDescriptor. + * + * Compares two s2fftDescriptor objects for equality based on their member values. + * + * @param other The other s2fftDescriptor to compare against. + * @return True if the descriptors are equal, false otherwise. + */ bool operator==(const s2fftDescriptor &other) const { return nside == other.nside && harmonic_band_limit == other.harmonic_band_limit && reality == other.reality && norm == other.norm && shift == other.shift && @@ -58,25 +122,78 @@ class s2fftDescriptor { } }; +/** + * @brief Executes Spherical Harmonic Transform (SHT) operations. + * + * This templated class provides methods for initializing FFT plans and executing + * forward and backward SHTs. It manages cuFFT handles and internal offsets + * required for the transforms. + * + * @tparam Complex The complex type (cufftComplex or cufftDoubleComplex) for the FFT operations. + */ template class s2fftExec { - friend class PlanCache; + friend class PlanCache; // Allows PlanCache to access private members for caching public: + /** + * @brief Default constructor for s2fftExec. + */ s2fftExec() {} - ~s2fftExec() {} - - HRESULT Initialize(const s2fftDescriptor &descriptor, size_t &worksize); - HRESULT Forward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data); + /** + * @brief Destructor for s2fftExec. + */ + ~s2fftExec() {} - HRESULT Backward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data); + /** + * @brief Initializes the FFT plans for the SHT. + * + * This method sets up the necessary cuFFT plans for both polar and equatorial + * rings based on the provided descriptor. It also calculates and stores the + * maximum required workspace size (m_work_size). + * + * @param descriptor The s2fftDescriptor containing the parameters for the FFT. + * @return HRESULT indicating success or failure. + */ + HRESULT Initialize(const s2fftDescriptor &descriptor); + + /** + * @brief Executes the forward Spherical Harmonic Transform. + * + * This method performs the forward FFT operations on the input data + * across polar and equatorial rings using the pre-initialized cuFFT plans. + * + * @param desc The s2fftDescriptor for the current transform. + * @param stream The CUDA stream to use for execution. + * @param data Pointer to the input/output data on the device. + * @param workspace Pointer to the workspace memory on the device. + * @return HRESULT indicating success or failure. + */ + HRESULT Forward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data, Complex *workspace); + + /** + * @brief Executes the backward Spherical Harmonic Transform. + * + * This method performs the inverse FFT operations on the input data + * across polar and equatorial rings using the pre-initialized cuFFT plans. + * + * @param desc The s2fftDescriptor for the current transform. + * @param stream The CUDA stream to use for execution. + * @param data Pointer to the input/output data on the device. + * @param workspace Pointer to the workspace memory on the device. + * @return HRESULT indicating success or failure. + */ + HRESULT Backward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data, Complex *workspace); public: + // cuFFT handles for polar and equatorial FFT plans std::vector m_polar_plans; cufftHandle m_equator_plan; std::vector m_inverse_polar_plans; cufftHandle m_inverse_equator_plan; + + // Parameters defining the SHT geometry and data layout int m_nside; int m_equatorial_ring_num; int64 m_total_pixels; @@ -84,18 +201,23 @@ class s2fftExec { int64 m_equatorial_offset_end; std::vector m_upper_ring_offsets; std::vector m_lower_ring_offsets; - - // Callback params stored for cleanup purposes - // thrust::device_vector m_cb_params; + size_t m_work_size = 0; // Maximum workspace size required for FFT plans }; } // namespace s2fft namespace std { +/** + * @brief Custom hash specialization for s2fftDescriptor. + * + * This specialization allows s2fftDescriptor objects to be used as keys + * in `std::unordered_map` by providing a hash function. + */ template <> struct hash { std::size_t operator()(const s2fft::s2fftDescriptor &k) const { - size_t hash = std::hash()(k.nside) ^ (std::hash()(k.harmonic_band_limit) << 1) ^ + // Combine hash values of individual members + size_t hash = std::hash()(k.nside) ^ (std::hash()(k.harmonic_band_limit) << 1) ^ (std::hash()(k.reality) << 2) ^ (std::hash()(k.norm) << 3) ^ (std::hash()(k.shift) << 4) ^ (std::hash()(k.double_precision) << 5); return hash; diff --git a/lib/include/s2fft_callbacks.h b/lib/include/s2fft_callbacks.h index 69a92e56..7f6d687c 100644 --- a/lib/include/s2fft_callbacks.h +++ b/lib/include/s2fft_callbacks.h @@ -1,3 +1,13 @@ +/** + * @file s2fft_callbacks.h + * @brief CUDA CUFFT callbacks for HEALPix spherical harmonic transforms + * + * @note CUFFT CALLBACKS DEPRECATED: This implementation no longer uses cuFFT callbacks. + * The previous callback-based approach has been replaced with direct kernel launches + * for better performance and maintainability. The files s2fft_callbacks.h and + * s2fft_callbacks.cc are no longer used and can be considered orphaned. + */ + #ifndef _S2FFT_CALLBACKS_CUH_ #define _S2FFT_CALLBACKS_CUH_ @@ -12,10 +22,44 @@ typedef long long int int64; namespace s2fftKernels { +/** + * @brief Defines the normalization types for FFT operations. + */ enum fft_norm { FORWARD = 1, BACKWARD = 2, ORTHO = 3, NONE = 4 }; -HRESULT setCallback(cufftHandle forwardPlan, cufftHandle backwardPlan, int64 *params_dev, bool shift, - bool equator, bool doublePrecision, fft_norm norm); +/** + * @brief Sets cuFFT callbacks specifically for a forward FFT plan. + * + * This function configures the cuFFT library to use custom callbacks + * for normalization and shifting operations during forward FFT execution. + * + * @param plan The cuFFT handle for the forward FFT plan. + * @param params_dev Pointer to device memory containing parameters for the callbacks. + * @param shift Boolean flag indicating whether to apply FFT shifting. + * @param equator Boolean flag indicating if the current operation is for the equatorial ring. + * @param doublePrecision Boolean flag indicating if double precision is used. + * @param norm The FFT normalization type to apply. + * @return HRESULT indicating success or failure. + */ +HRESULT setForwardCallback(cufftHandle plan, int64 *params_dev, bool shift, bool equator, + bool doublePrecision, fft_norm norm); + +/** + * @brief Sets cuFFT callbacks specifically for a backward FFT plan. + * + * This function configures the cuFFT library to use custom callbacks + * for normalization and shifting operations during backward FFT execution. + * + * @param plan The cuFFT handle for the inverse FFT plan. + * @param params_dev Pointer to device memory containing parameters for the callbacks. + * @param shift Boolean flag indicating whether to apply FFT shifting. + * @param equator Boolean flag indicating if the current operation is for the equatorial ring. + * @param doublePrecision Boolean flag indicating if double precision is used. + * @param norm The FFT normalization type to apply. + * @return HRESULT indicating success or failure. + */ +HRESULT setBackwardCallback(cufftHandle plan, int64 *params_dev, bool shift, bool equator, + bool doublePrecision, fft_norm norm); } // namespace s2fftKernels #endif \ No newline at end of file diff --git a/lib/include/s2fft_kernels.h b/lib/include/s2fft_kernels.h index 8825462c..103221c8 100644 --- a/lib/include/s2fft_kernels.h +++ b/lib/include/s2fft_kernels.h @@ -9,14 +9,82 @@ #include typedef long long int int64; +/** + * @file s2fft_kernels.h + * @brief CUDA kernels for HEALPix spherical harmonic transforms + * + * @note CUFT CALLBACKS DEPRECATED: This implementation no longer uses cuFFT callbacks. + * The previous callback-based approach has been replaced with direct kernel launches + * for better performance and maintainability. The files s2fft_callbacks.h and + * s2fft_callbacks.cc are no longer used and can be considered orphaned. + */ + namespace s2fftKernels { +enum fft_norm { FORWARD = 1, BACKWARD = 2, ORTHO = 3, NONE = 4 }; + +/** + * @brief Launches the spectral folding CUDA kernel. + * + * This function configures and launches the spectral_folding kernel with + * appropriate grid and block dimensions. It performs spectral folding operations + * on ring-ordered data, transforming from Fourier coefficient space to HEALPix + * pixel space with optional FFT shifting. + * + * @tparam complex The complex type (cufftComplex or cufftDoubleComplex). + * @param data Input data array containing Fourier coefficients per ring. + * @param output Output array for folded HEALPix pixel data. + * @param nside The HEALPix Nside parameter. + * @param L The harmonic band limit. + * @param shift Flag indicating whether to apply FFT shifting. + * @param stream CUDA stream for kernel execution. + * @return HRESULT indicating success or failure. + */ template HRESULT launch_spectral_folding(complex* data, complex* output, const int& nside, const int& L, const bool& shift, cudaStream_t stream); + +/** + * @brief Launches the spectral extension CUDA kernel. + * + * This function configures and launches the spectral_extension kernel with + * appropriate grid and block dimensions. It performs the inverse operation of + * spectral folding, extending HEALPix pixel data back to full Fourier coefficient + * space by mapping folded frequency components to their appropriate positions. + * + * @tparam complex The complex type (cufftComplex or cufftDoubleComplex). + * @param data Input array containing folded HEALPix pixel data. + * @param output Output array for extended Fourier coefficients per ring. + * @param nside The HEALPix Nside parameter. + * @param L The harmonic band limit. + * @param stream CUDA stream for kernel execution. + * @return HRESULT indicating success or failure. + */ template HRESULT launch_spectral_extension(complex* data, complex* output, const int& nside, const int& L, cudaStream_t stream); + +/** + * @brief Launches the shift/normalize CUDA kernel for HEALPix data processing. + * + * This function configures and launches the shift_normalize_kernel with appropriate + * grid and block dimensions. It handles both single and double precision complex + * types and applies the requested normalization and shifting operations to HEALPix + * pixel data on a per-ring basis. + * + * @tparam complex The complex type (cufftComplex or cufftDoubleComplex). + * @param stream CUDA stream for kernel execution. + * @param data Input/output array of HEALPix pixel data (in-place processing). + * @param nside The HEALPix Nside parameter. + * @param apply_shift Flag indicating whether to apply FFT shifting. + * @param norm Normalization type (0=by nphi, 1=by sqrt(nphi), 2=no normalization). + * @return HRESULT indicating success or failure. + */ +template +HRESULT launch_shift_normalize_kernel(cudaStream_t stream, + complex* data, // In-place data buffer + int nside, bool apply_shift, int norm); + } // namespace s2fftKernels #endif // _S2FFT_KERNELS_H \ No newline at end of file diff --git a/lib/src/extensions.cc b/lib/src/extensions.cc index 8d5a7c4c..e2ce1917 100644 --- a/lib/src/extensions.cc +++ b/lib/src/extensions.cc @@ -1,150 +1,423 @@ - -#include "kernel_nanobind_helpers.h" -#include "kernel_helpers.h" #include #include +#include +#include + +namespace nb = nanobind; #ifndef NO_CUDA_COMPILER +#include "xla/ffi/api/api.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" #include "cuda_runtime.h" #include "plan_cache.h" #include "s2fft_kernels.h" #include "s2fft.h" -#else -void print_error() { - - throw std::runtime_error("This extension was compiled without CUDA support. Cuda functions are not supported."); -} -#endif +#include "cudastreamhandler.hpp" // For forking and joining CUDA streams -namespace nb = nanobind; +namespace ffi = xla::ffi; namespace s2fft { -#ifdef NO_CUDA_COMPILER -void healpix_fft_cuda() { print_error(); } -#else -void healpix_forward(cudaStream_t stream, void** buffers, s2fftDescriptor descriptor) { - void* data = buffers[0]; - void* output = buffers[1]; +/** + * @brief Mapping from XLA DataType to CUFFT complex types. + */ +template +struct FftComplexType; - size_t work_size; - // Execute the kernel based on the Precision - if (descriptor.double_precision) { - auto executor = std::make_shared>(); - cufftDoubleComplex* data_c = reinterpret_cast(data); - cufftDoubleComplex* out_c = reinterpret_cast(output); +template <> +struct FftComplexType { + using type = cufftDoubleComplex; +}; - PlanCache::GetInstance().GetS2FFTExec(descriptor, executor); - // Run the fft part - executor->Forward(descriptor, stream, data_c); - // Run the spectral extension part - s2fftKernels::launch_spectral_extension(data_c, out_c, descriptor.nside, - descriptor.harmonic_band_limit, stream); +template <> +struct FftComplexType { + using type = cufftComplex; +}; + +template +using fft_complex_t = typename FftComplexType
::type; + +/** + * @brief Helper to indicate if using double precision. + * + * Default is false; specialized for C128. + */ +template +struct is_double : std::false_type {}; +template <> +struct is_double : std::true_type {}; + +template +constexpr bool is_double_v = is_double::value; + +/** + * @brief Performs a forward HEALPix transform on a single element or batch. + * + * For a batched call, the input buffer is assumed to be 2D: [batch_size, nside^2*12], + * and the output is 3D: [batch_size, (4*nside-1), 2*harmonic_band_limit]. + * + * For non-batched call, the input is 1D and the output is 1D. + * + * @tparam T The XLA data type (F32, F64, etc). + * @param stream CUDA stream to use. + * @param input Input buffer containing HEALPix pixel-space data. + * @param output Output buffer to store the FTM result. + * @param workspace Output buffer for temporary workspace memory. + * @param descriptor Descriptor containing transform parameters. + * @return ffi::Error indicating success or failure. + */ +template +ffi::Error healpix_forward(cudaStream_t stream, ffi::Buffer input, ffi::Result> output, + ffi::Result> workspace, s2fftDescriptor descriptor) { + // Step 1: Determine the complex type based on the XLA data type. + using fft_complex_type = fft_complex_t; + const auto& dim_in = input.dimensions(); + + // Step 2: Handle batched and non-batched cases separately. + if (dim_in.size() == 2) { + // Step 2a: Batched case. + int batch_count = dim_in[0]; + // Step 2b: Compute offsets for input and output for each batch. + int64_t input_offset = descriptor.nside * descriptor.nside * 12; + int64_t output_offset = (4 * descriptor.nside - 1) * (2 * descriptor.harmonic_band_limit); + + // Step 2c: Fork CUDA streams for parallel processing of batches. + CudaStreamHandler handler; + handler.Fork(stream, batch_count); + auto stream_iter = handler.getIterator(); + + // Step 2d: Iterate over each batch. + for (int i = 0; i < batch_count && stream_iter.hasNext(); ++i) { + cudaStream_t sub_stream = stream_iter.next(); + // Step 2e: Get or create an s2fftExec instance from the PlanCache. + auto executor = std::make_shared>(); + PlanCache::GetInstance().GetS2FFTExec(descriptor, executor); + + // Step 2f: Calculate device pointers for the current batch's data, output, and workspace. + fft_complex_type* data_c = + reinterpret_cast(input.typed_data() + i * input_offset); + fft_complex_type* out_c = + reinterpret_cast(output->typed_data() + i * output_offset); + fft_complex_type* workspace_c = + reinterpret_cast(workspace->typed_data() + i * executor->m_work_size); + + // Step 2g: Launch the forward transform on this sub-stream. + executor->Forward(descriptor, sub_stream, data_c, workspace_c); + // Step 2h: Launch spectral extension kernel. + s2fftKernels::launch_spectral_extension(data_c, out_c, descriptor.nside, + descriptor.harmonic_band_limit, sub_stream); + } + // Step 2i: Join all forked streams back to the main stream. + handler.join(stream); + return ffi::Error::Success(); } else { - auto executor = std::make_shared>(); - cufftComplex* data_c = reinterpret_cast(data); - cufftComplex* out_c = reinterpret_cast(output); + // Step 2j: Non-batched case. + // Step 2k: Get device pointers for data, output, and workspace. + fft_complex_type* data_c = reinterpret_cast(input.typed_data()); + fft_complex_type* out_c = reinterpret_cast(output->typed_data()); + fft_complex_type* workspace_c = reinterpret_cast(workspace->typed_data()); + // Step 2l: Get or create an s2fftExec instance from the PlanCache. + auto executor = std::make_shared>(); PlanCache::GetInstance().GetS2FFTExec(descriptor, executor); - // Run the fft part - executor->Forward(descriptor, stream, data_c); - // Run the spectral extension part + // Step 2m: Launch the forward transform. + executor->Forward(descriptor, stream, data_c, workspace_c); + // Step 2n: Launch spectral extension kernel. s2fftKernels::launch_spectral_extension(data_c, out_c, descriptor.nside, descriptor.harmonic_band_limit, stream); + return ffi::Error::Success(); } } -void healpix_backward(cudaStream_t stream, void** buffers, s2fftDescriptor descriptor) { - void* data = buffers[0]; - void* output = buffers[1]; +/** + * @brief Performs a backward HEALPix transform on a single element or batch. + * + * For a batched call, the input buffer is assumed to be 3D: [batch_size, (4*nside-1), 2*harmonic_band_limit], + * and the output is 2D: [batch_size, nside^2*12]. + * + * For non-batched call, the input is 1D and the output is 1D. + * + * @tparam T The XLA data type. + * @param stream CUDA stream to use. + * @param input Input buffer containing FTM data. + * @param output Output buffer to store HEALPix pixel-space data. + * @param workspace Output buffer for temporary workspace memory. + * @param descriptor Descriptor containing transform parameters. + * @return ffi::Error indicating success or failure. + */ +template +ffi::Error healpix_backward(cudaStream_t stream, ffi::Buffer input, ffi::Result> output, + ffi::Result> workspace, s2fftDescriptor descriptor) { + // Step 1: Determine the complex type based on the XLA data type. + using fft_complex_type = fft_complex_t; + const auto& dim_in = input.dimensions(); + const auto& dim_out = output->dimensions(); - size_t work_size; - // Execute the kernel based on the Precision - if (descriptor.double_precision) { - auto executor = std::make_shared>(); - cufftDoubleComplex* data_c = reinterpret_cast(data); - cufftDoubleComplex* out_c = reinterpret_cast(output); + // Step 2: Handle batched and non-batched cases separately. + if (dim_in.size() == 3) { + // Step 2a: Batched case. + // Assertions to ensure correct input/output dimensions for batched operations. + assert(dim_out.size() == 2); + assert(dim_in[0] == dim_out[0]); + int batch_count = dim_in[0]; + // Step 2b: Compute offsets for input, output, and callback parameters for each batch. + int64_t input_offset = (4 * descriptor.nside - 1) * (2 * descriptor.harmonic_band_limit); + int64_t output_offset = descriptor.nside * descriptor.nside * 12; - PlanCache::GetInstance().GetS2FFTExec(descriptor, executor); - // Run the spectral folding part - s2fftKernels::launch_spectral_folding(data_c, out_c, descriptor.nside, descriptor.harmonic_band_limit, - descriptor.shift, stream); - // Run the fft part - executor->Backward(descriptor, stream, out_c); + // Step 2c: Fork CUDA streams for parallel processing of batches. + CudaStreamHandler handler; + handler.Fork(stream, batch_count); + auto stream_iter = handler.getIterator(); + + // Step 2d: Iterate over each batch. + for (int i = 0; i < batch_count && stream_iter.hasNext(); ++i) { + cudaStream_t sub_stream = stream_iter.next(); + // Step 2e: Get or create an s2fftExec instance from the PlanCache. + auto executor = std::make_shared>(); + PlanCache::GetInstance().GetS2FFTExec(descriptor, executor); + + // Step 2f: Calculate device pointers for the current batch's data, output, and workspace. + fft_complex_type* data_c = + reinterpret_cast(input.typed_data() + i * input_offset); + fft_complex_type* out_c = + reinterpret_cast(output->typed_data() + i * output_offset); + fft_complex_type* workspace_c = + reinterpret_cast(workspace->typed_data() + i * executor->m_work_size); + // Step 2g: Launch spectral folding kernel. + s2fftKernels::launch_spectral_folding(data_c, out_c, descriptor.nside, + descriptor.harmonic_band_limit, descriptor.shift, + sub_stream); + // Step 2h: Launch the backward transform on this sub-stream. + executor->Backward(descriptor, sub_stream, out_c, workspace_c); + } + // Step 2i: Join all forked streams back to the main stream. + handler.join(stream); + return ffi::Error::Success(); } else { - auto executor = std::make_shared>(); - cufftComplex* data_c = reinterpret_cast(data); - cufftComplex* out_c = reinterpret_cast(output); + // Step 2j: Non-batched case. + // Assertions to ensure correct input/output dimensions for non-batched operations. + assert(dim_in.size() == 2); + assert(dim_out.size() == 1); + // Step 2k: Get device pointers for data, output, and workspace. + fft_complex_type* data_c = reinterpret_cast(input.typed_data()); + fft_complex_type* out_c = reinterpret_cast(output->typed_data()); + fft_complex_type* workspace_c = reinterpret_cast(workspace->typed_data()); + // Step 2l: Get or create an s2fftExec instance from the PlanCache. + auto executor = std::make_shared>(); PlanCache::GetInstance().GetS2FFTExec(descriptor, executor); - // Run the spectral folding part + // Step 2m: Launch spectral folding kernel. s2fftKernels::launch_spectral_folding(data_c, out_c, descriptor.nside, descriptor.harmonic_band_limit, descriptor.shift, stream); - // Run the fft part - executor->Backward(descriptor, stream, out_c); + // Step 2n: Launch the backward transform. + executor->Backward(descriptor, stream, out_c, workspace_c); + return ffi::Error::Success(); + } +} + +/** + * @brief Builds an s2fftDescriptor based on provided parameters. + * + * This descriptor is identical for all batch elements. It also ensures that + * an s2fftExec instance corresponding to the descriptor is initialized and cached. + * + * @tparam T The XLA data type. + * @param nside HEALPix resolution parameter. + * @param harmonic_band_limit Harmonic band limit L. + * @param reality Flag indicating whether data is real-valued. + * @param forward Flag indicating forward transform. + * @param normalize Flag for normalization. + * @param adjoint Flag indicating if an adjoint operation is desired. + * @param must_exist If true, throws an error if the plan does not exist in the cache. + * @return s2fftDescriptor configured with the given parameters. + */ +template +s2fftDescriptor build_descriptor(int64_t nside, int64_t harmonic_band_limit, bool reality, bool forward, + bool normalize, bool adjoint, bool must_exist, size_t& work_size) { + using fft_complex_type = fft_complex_t; + // Step 1: Determine FFT normalization type based on forward/normalize flags. + s2fftKernels::fft_norm norm = s2fftKernels::fft_norm::NONE; + if (forward && normalize) { + norm = s2fftKernels::fft_norm::FORWARD; + } else if (!forward && normalize) { + norm = s2fftKernels::fft_norm::BACKWARD; + } else if (forward && !normalize) { + norm = s2fftKernels::fft_norm::BACKWARD; + } else if (!forward && !normalize) { + norm = s2fftKernels::fft_norm::FORWARD; + } + // Step 2: Set shift flag (always true for now). + bool shift = true; + // Step 3: Create an s2fftDescriptor object with the given parameters. + s2fftDescriptor descriptor(nside, harmonic_band_limit, reality, adjoint, forward, norm, shift, + is_double_v); + + // Step 4: Get or create an s2fftExec instance from the PlanCache. + // This call will also initialize the executor if it's newly created. + auto executor = std::make_shared>(); + HRESULT hr = PlanCache::GetInstance().GetS2FFTExec(descriptor, executor); + // Step 5: Handle cases where the plan was expected to exist but didn't. + if (hr == S_OK && must_exist) { + // This is an error because S_OK means plan was created, but must_exist implies it should have been + // found. + throw std::runtime_error("S2FFT INTERNAL ERROR: Plan did not exist but it was expected to exist."); + } + // Step 6: If the executor was just created (S_OK), initialize it. + // Note: PlanCache::GetS2FFTExec now handles workspace initialization internally + if (hr == S_OK) { + executor->Initialize(descriptor); } + // Make sure workspace is set + assert(executor->m_work_size > 0 && "S2FFT INTERNAL ERROR: Workspace size is zero after initialization."); + work_size = executor->m_work_size; + // Step 7: Return the created descriptor. + return descriptor; } -void healpix_fft_cuda(cudaStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { - // Get the descriptor from the opaque parameter - s2fftDescriptor descriptor = *UnpackDescriptor(opaque, opaque_len); - size_t work_size; - // Execute the kernel based on the Precision - if (descriptor.forward) { - healpix_forward(stream, buffers, descriptor); +/** + * @brief Unified entry point for the HEALPix FFT transform. + * + * This function serves as the main FFI entry point for HEALPix FFT operations. + * Depending on the value of the 'forward' flag in the descriptor, it dispatches + * to either the forward (`healpix_forward`) or backward (`healpix_backward`) transform. + * + * @tparam T The XLA data type. + * @param stream CUDA stream to use. + * @param nside HEALPix resolution parameter. + * @param harmonic_band_limit Harmonic band limit L. + * @param reality Flag indicating whether data is real-valued. + * @param forward Flag indicating forward transform. + * @param normalize Flag for normalization. + * @param adjoint Flag indicating if an adjoint operation is desired. + * @param input Input buffer. + * @param output Output buffer. + * @param workspace Output buffer for temporary workspace memory. + * @return ffi::Error indicating success or failure. + */ +template +ffi::Error healpix_fft_cuda(cudaStream_t stream, int64_t nside, int64_t harmonic_band_limit, bool reality, + bool forward, bool normalize, bool adjoint, ffi::Buffer input, + ffi::Result> output, ffi::Result> workspace) { + // Step 1: Build the s2fftDescriptor based on the input parameters. + size_t work_size = 0; // Variable to hold the workspace size + s2fftDescriptor descriptor = build_descriptor(nside, harmonic_band_limit, reality, forward, normalize, + adjoint, true, work_size); + + // Step 2: Dispatch to either forward or backward transform based on the 'forward' flag. + if (forward) { + return healpix_forward(stream, input, output, workspace, descriptor); } else { - healpix_backward(stream, buffers, descriptor); + return healpix_backward(stream, input, output, workspace, descriptor); } } -#endif // NO_CUDA_COMPILER +/** + * @brief FFI registration for the HEALPix FFT CUDA functions. + * + * Registers the handlers for both C64 and C128 data types with XLA FFI. + * This makes the CUDA FFT functions callable from JAX. + */ +XLA_FFI_DEFINE_HANDLER_SYMBOL(healpix_fft_cuda_C64, healpix_fft_cuda, + ffi::Ffi::Bind() + .Ctx>() + .Attr("nside") + .Attr("harmonic_band_limit") + .Attr("reality") + .Attr("forward") + .Attr("normalize") + .Attr("adjoint") + .Arg>() + .Ret>() + .Ret>()); + +XLA_FFI_DEFINE_HANDLER_SYMBOL(healpix_fft_cuda_C128, healpix_fft_cuda, + ffi::Ffi::Bind() + .Ctx>() + .Attr("nside") + .Attr("harmonic_band_limit") + .Attr("reality") + .Attr("forward") + .Attr("normalize") + .Attr("adjoint") + .Arg>() + .Ret>() + .Ret>()); +/** + * @brief Encapsulates an FFI handler into a nanobind capsule. + * + * This helper function is used to wrap C++ FFI handlers so they can be exposed + * to Python via nanobind. + * + * @tparam T The function type of the FFI handler. + * @param fn Pointer to the FFI handler function. + * @return nb::capsule A nanobind capsule containing the FFI handler. + */ +template +nb::capsule EncapsulateFfiCall(T* fn) { + // Step 1: Assert that the provided function is a valid XLA FFI handler. + static_assert(std::is_invocable_r_v, + "Encapsulated function must be an XLA FFI handler"); + // Step 2: Return a nanobind capsule wrapping the function pointer. + return nb::capsule(reinterpret_cast(fn)); +} + +/** + * @brief Returns a dictionary of all registered FFI handlers. + * + * This function creates a nanobind dictionary that maps string names to + * encapsulated FFI handlers, allowing them to be looked up and called from Python. + * + * @return nb::dict A nanobind dictionary with keys for each handler. + */ nb::dict Registration() { + // Step 1: Create an empty nanobind dictionary. nb::dict dict; - dict["healpix_fft_cuda"] = EncapsulateFunction(healpix_fft_cuda); + // Step 2: Add encapsulated FFI handlers for C64 and C128 to the dictionary. + dict["healpix_fft_cuda_c64"] = EncapsulateFfiCall(healpix_fft_cuda_C64); + dict["healpix_fft_cuda_c128"] = EncapsulateFfiCall(healpix_fft_cuda_C128); + // Step 3: Return the populated dictionary. return dict; } } // namespace s2fft NB_MODULE(_s2fft, m) { + // Step 1: Expose the registration function to Python. m.def("registration", &s2fft::Registration); + // Step 2: Declare and expose build_descriptor functions for C64 and C128 to Python. + // These functions allow Python to query the required workspace size for a given descriptor. + m.def("build_descriptor_C64", [](int64_t nside, int64_t harmonic_band_limit, bool reality, bool forward, + bool normalize, bool adjoint) { + // Step 2a: Build the s2fftDescriptor. + size_t work_size = 0; // Variable to hold the workspace size + s2fft::s2fftDescriptor desc = s2fft::build_descriptor( + nside, harmonic_band_limit, reality, forward, normalize, adjoint, false, work_size); + return work_size; + }); + m.def("build_descriptor_C128", [](int64_t nside, int64_t harmonic_band_limit, bool reality, bool forward, + bool normalize, bool adjoint) { + // Step 2e: Build the s2fftDescriptor. + size_t work_size = 0; // Variable to hold the workspace size + s2fft::s2fftDescriptor desc = s2fft::build_descriptor( + nside, harmonic_band_limit, reality, forward, normalize, adjoint, false, work_size); + return work_size; + }); + // Step 3: Expose a boolean attribute indicating if CUDA support is compiled in. + m.attr("COMPILED_WITH_CUDA") = true; +} - m.def("build_healpix_fft_descriptor", - [](int nside, int harmonic_band_limit, bool reality, bool forward,bool normalize, bool double_precision) { -#ifndef NO_CUDA_COMPILER - size_t work_size; - // Only backward for now - s2fftKernels::fft_norm norm = s2fftKernels::fft_norm::NONE; - if (forward && normalize) { - norm = s2fftKernels::fft_norm::FORWARD; - } else if (!forward && normalize) { - norm = s2fftKernels::fft_norm::BACKWARD; - } else if (forward && !normalize) { - norm = s2fftKernels::fft_norm::BACKWARD; - } else if (!forward && !normalize) { - norm = s2fftKernels::fft_norm::FORWARD; - } - // Always shift - bool shift = true; - s2fft::s2fftDescriptor descriptor(nside, harmonic_band_limit, reality, forward, norm, shift, - double_precision); - - if (double_precision) { - auto executor = std::make_shared>(); - s2fft::PlanCache::GetInstance().GetS2FFTExec(descriptor, executor); - executor->Initialize(descriptor, work_size); - return PackDescriptor(descriptor); - } else { - auto executor = std::make_shared>(); - s2fft::PlanCache::GetInstance().GetS2FFTExec(descriptor, executor); - executor->Initialize(descriptor, work_size); - return PackDescriptor(descriptor); - } -#else - print_error(); -#endif - }); +#else // NO_CUDA_COMPILER + +// Step 1: Define a fallback NB_MODULE when CUDA is not compiled. +NB_MODULE(_s2fft, m) { + // Step 1a: Provide a dummy registration function that returns an empty dictionary. + m.def("registration", []() { return nb::dict(); }); + // Step 1b: Indicate that CUDA support is not compiled. + m.attr("COMPILED_WITH_CUDA") = false; } + +#endif // NO_CUDA_COMPILER \ No newline at end of file diff --git a/lib/src/plan_cache.cc b/lib/src/plan_cache.cc index f5e468bf..1dd34cb5 100644 --- a/lib/src/plan_cache.cc +++ b/lib/src/plan_cache.cc @@ -7,46 +7,112 @@ namespace s2fft { -PlanCache::PlanCache() { is_initialized = true; } +/** + * @brief Constructor for PlanCache. + * + * Initializes the `is_initialized` flag to true. + */ +PlanCache::PlanCache() { + // Step 1: Set the initialization flag. + is_initialized = true; +} -HRESULT PlanCache::GetS2FFTExec(s2fftDescriptor &descriptor, - std::shared_ptr> &executor) { - HRESULT hr(E_FAIL); +/** + * @brief Retrieves an s2fftExec instance from the cache or initializes a new one. + * + * This templated method attempts to find an existing s2fftExec instance + * matching the provided descriptor in its internal cache (m_Descriptors32 or m_Descriptors64) + * based on the Complex type T. If a matching instance is found, it is returned. + * Otherwise, a new s2fftExec instance is created, initialized with the descriptor, + * and then stored in the cache before being returned. + * + * @tparam T The complex type (cufftComplex or cufftDoubleComplex) of the s2fftExec instance. + * @param descriptor The s2fftDescriptor containing the parameters for the FFT. + * @param executor A shared_ptr that will point to the retrieved or newly initialized s2fftExec instance. + * @return HRESULT indicating success (S_OK if new, S_FALSE if from cache) or failure. + */ +template +HRESULT PlanCache::GetS2FFTExec(s2fftDescriptor &descriptor, std::shared_ptr> &executor) { + // Step 1: Check if the type is cufftComplex (single precision). + if constexpr (std::is_same_v) { + HRESULT hr(E_FAIL); + // Step 1a: Try to find the descriptor in the single-precision cache. + auto it = m_Descriptors32.find(descriptor); + if (it != m_Descriptors32.end()) { + // Step 1b: If found, retrieve the existing executor and set HR to S_FALSE (found in cache). + executor = it->second; + hr = S_FALSE; + } - auto it = m_Descriptors32.find(descriptor); - if (it != m_Descriptors32.end()) { - executor = it->second; - hr = S_FALSE; - } + // Step 1c: If not found (hr is still E_FAIL), + if (hr == E_FAIL) { + // Step 1d: Initialize a new executor with the descriptor. + hr = executor->Initialize(descriptor); + // Step 1e: If initialization is successful, store the new executor in the cache. + if (SUCCEEDED(hr)) { + m_Descriptors32[descriptor] = executor; + } + } + // Step 1f: Return the HRESULT. + return hr; + } else { // Step 2: If the type is not cufftComplex, it must be cufftDoubleComplex (double precision). + HRESULT hr(E_FAIL); + // Step 2a: Try to find the descriptor in the double-precision cache. + auto it = m_Descriptors64.find(descriptor); + if (it != m_Descriptors64.end()) { + // Step 2b: If found, retrieve the existing executor and set HR to S_FALSE (found in cache). + executor = it->second; + hr = S_FALSE; + } - if (hr == E_FAIL) { - size_t worksize(0); - hr = executor->Initialize(descriptor, worksize); - if (SUCCEEDED(hr)) { - m_Descriptors32[descriptor] = executor; + // Step 2c: If not found (hr is still E_FAIL), + if (hr == E_FAIL) { + // Step 2d: Initialize a new executor with the descriptor. + hr = executor->Initialize(descriptor); + // Step 2e: If initialization is successful, store the new executor in the cache. + if (SUCCEEDED(hr)) { + m_Descriptors64[descriptor] = executor; + } } + // Step 2f: Return the HRESULT. + return hr; } - return hr; } -HRESULT PlanCache::GetS2FFTExec(s2fftDescriptor &descriptor, - std::shared_ptr> &executor) { - HRESULT hr(E_FAIL); - - auto it = m_Descriptors64.find(descriptor); - if (it != m_Descriptors64.end()) { - executor = it->second; - hr = S_FALSE; +/** + * @brief Clears all cached s2fftExec instances. + * + * This method is typically called during application shutdown to release + * all resources held by the cached FFT plans. + */ +void PlanCache::Finalize() { + // Step 1: Check if the cache was initialized. + if (is_initialized) { + // Step 1a: Clear both single and double precision descriptor maps. + m_Descriptors32.clear(); + m_Descriptors64.clear(); } + // Step 2: Reset the initialization flag. + is_initialized = false; +} - if (hr == E_FAIL) { - size_t worksize(0); - hr = executor->Initialize(descriptor, worksize); - if (SUCCEEDED(hr)) { - m_Descriptors64[descriptor] = executor; - } - } - return hr; +/** + * @brief Destructor for PlanCache. + * + * Ensures that Finalize() is called when the PlanCache instance is destroyed, + * performing necessary cleanup. + */ +PlanCache::~PlanCache() { + // Step 1: Call Finalize to clean up resources. + Finalize(); } -} // namespace s2fft +// Explicitly instantiate the templates for the supported complex types. +// This is necessary for the linker to find the concrete implementations of the templated function. +template HRESULT PlanCache::GetS2FFTExec(s2fftDescriptor &descriptor, + std::shared_ptr> &executor); + +template HRESULT PlanCache::GetS2FFTExec( + s2fftDescriptor &descriptor, std::shared_ptr> &executor); + +} // namespace s2fft \ No newline at end of file diff --git a/lib/src/s2fft.cu b/lib/src/s2fft.cu index 99b1fd47..4429972e 100644 --- a/lib/src/s2fft.cu +++ b/lib/src/s2fft.cu @@ -12,24 +12,27 @@ #include #include -#include "s2fft_callbacks.h" +#include "s2fft_kernels.h" namespace s2fft { template -HRESULT s2fftExec::Initialize(const s2fftDescriptor &descriptor, size_t &worksize) { +HRESULT s2fftExec::Initialize(const s2fftDescriptor &descriptor) { + // Step 1: Store the Nside parameter from the descriptor. m_nside = descriptor.nside; + // Step 2: Initialize variables for ring offsets and workspace size. size_t start_index(0); size_t end_index(12 * m_nside * m_nside); size_t nphi(0); + size_t worksize(0); + // Step 3: Determine the cuFFT C2C type based on the complex type. const cufftType C2C_TYPE = get_cufft_type_c2c(Complex({0.0, 0.0})); - const s2fftKernels::fft_norm &norm = descriptor.norm; - const bool &shift = descriptor.shift; - const bool &isDouble = descriptor.double_precision; + // Step 4: Reserve space for upper and lower ring offset vectors. m_upper_ring_offsets.reserve(m_nside - 1); m_lower_ring_offsets.reserve(m_nside - 1); + // Step 5: Calculate and store offsets for polar rings. for (size_t i = 0; i < m_nside - 1; i++) { nphi = 4 * (i + 1); m_upper_ring_offsets.push_back(start_index); @@ -37,119 +40,172 @@ HRESULT s2fftExec::Initialize(const s2fftDescriptor &descriptor, size_t start_index += nphi; end_index -= nphi; - } + } // + // Step 6: Store offsets and number of equatorial rings. m_equatorial_offset_start = start_index; m_equatorial_offset_end = end_index; m_equatorial_ring_num = (end_index - start_index) / (4 * m_nside); - // Plan creation + // Step 7: Create cuFFT plans for polar rings. for (size_t i = 0; i < m_nside - 1; i++) { size_t polar_worksize{0}; int64 upper_ring_offset = m_upper_ring_offsets[i]; int64 lower_ring_offset = m_lower_ring_offsets[i]; + // Step 7a: Create cuFFT handles for forward and inverse plans. cufftHandle plan{}; cufftHandle inverse_plan{}; CUFFT_CALL(cufftCreate(&plan)); CUFFT_CALL(cufftCreate(&inverse_plan)); - // Plans are done on upper and lower polar rings - int rank = 1; // 1D FFT : In our case the rank is always 1 - int batch_size = 2; // Number of rings to transform - int64 n[] = {4 * ((int64)i + 1)}; // Size of each FFT 4 times the ring number (first is 4, second is - // 8, third is 12, etc) + + // Step 7b: Define parameters for 1D FFTs on polar rings. + int rank = 1; // 1D FFT + int batch_size = 2; // Number of rings to transform (upper and lower) + int64 n[] = {4 * ((int64)i + 1)}; // Size of each FFT int64 inembed[] = {0}; // Stride of input data (meaningless but has to be set) - int64 istride = 1; // Distance between consecutive elements in the same batch always 1 since we - // have contiguous data + int64 istride = 1; // Distance between consecutive elements in the same batch int64 idist = lower_ring_offset - - upper_ring_offset; // Distance between the starting points of two consecutive - // batches, it is equal to the distance between the two rings + upper_ring_offset; // Distance between starting points of two consecutive batches int64 onembed[] = {0}; // Stride of output data (meaningless but has to be set) - int64 ostride = 1; // Distance between consecutive elements in the output batch, also 1 since - // everything is done in place - int64 odist = - lower_ring_offset - upper_ring_offset; // Same as idist since we want to transform in place + int64 ostride = 1; // Distance between consecutive elements in the output batch + int64 odist = lower_ring_offset - upper_ring_offset; // Same as idist for in-place transform - // TODO CUFFT_C2C + // Step 7c: Create cuFFT plans for forward and inverse polar transforms. CUFFT_CALL(cufftMakePlanMany64(plan, rank, n, inembed, istride, idist, onembed, ostride, odist, C2C_TYPE, batch_size, &polar_worksize)); + // Step 7d: Update overall maximum workspace size. + worksize = std::max(worksize, polar_worksize); CUFFT_CALL(cufftMakePlanMany64(inverse_plan, rank, n, inembed, istride, idist, onembed, ostride, odist, C2C_TYPE, batch_size, &polar_worksize)); - int64 params[2]; - int64 *params_dev; - params[0] = n[0]; - params[1] = idist; - cudaMalloc(¶ms_dev, 2 * sizeof(int64)); - cudaMemcpy(params_dev, params, 2 * sizeof(int64), cudaMemcpyHostToDevice); - - s2fftKernels::setCallback(plan, inverse_plan, params_dev, shift, false, isDouble, norm); + // Step 7e: Update overall maximum workspace size again. + worksize = std::max(worksize, polar_worksize); + // Step 7f: Store the created plans. m_polar_plans.push_back(plan); m_inverse_polar_plans.push_back(inverse_plan); } - // Equator plan - - // Equator is a matrix with size 4 * m_nside x equatorial_ring_num - // cufftMakePlan1d is enough for this case + // Step 8: Create cuFFT plans for the equatorial ring. size_t equator_worksize{0}; int64 equator_size = (4 * m_nside); - // TODO CUFFT_C2C - // Forward plan + + // Step 8a: Create cuFFT handle for the forward equatorial plan. CUFFT_CALL(cufftCreate(&m_equator_plan)); CUFFT_CALL(cufftMakePlanMany64(m_equator_plan, 1, &equator_size, nullptr, 1, 1, nullptr, 1, 1, C2C_TYPE, m_equatorial_ring_num, &equator_worksize)); - // Inverse plan + // Step 8b: Update overall maximum workspace size. + worksize = std::max(worksize, equator_worksize); + + // Step 8c: Create cuFFT handle for the inverse equatorial plan. CUFFT_CALL(cufftCreate(&m_inverse_equator_plan)); CUFFT_CALL(cufftMakePlanMany64(m_inverse_equator_plan, 1, &equator_size, nullptr, 1, 1, nullptr, 1, 1, C2C_TYPE, m_equatorial_ring_num, &equator_worksize)); - - int64 equator_params[1]; - equator_params[0] = equator_size; - int64 *equator_params_dev; - cudaMalloc(&equator_params_dev, sizeof(int64)); - cudaMemcpy(equator_params_dev, equator_params, sizeof(int64), cudaMemcpyHostToDevice); - - s2fftKernels::setCallback(m_equator_plan, m_inverse_equator_plan, equator_params_dev, shift, true, - isDouble, norm); + // Step 8d: Update overall maximum workspace size again. + worksize = std::max(worksize, equator_worksize); + // Step 9: Store the final maximum workspace size. + this->m_work_size = worksize; return S_OK; } template -HRESULT s2fftExec::Forward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data) { - // Polar rings ffts*/ - +HRESULT s2fftExec::Forward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data, + Complex *workspace) { + // Step 1: Determine the FFT direction (forward or inverse based on adjoint flag). + const int DIRECTION = desc.adjoint ? CUFFT_INVERSE : CUFFT_FORWARD; + // Step 2: Extract normalization, shift, and double precision flags from the descriptor. + const s2fftKernels::fft_norm &norm = desc.norm; + const bool &shift = desc.shift; + + // Step 3: Execute FFTs for polar rings. for (int i = 0; i < m_nside - 1; i++) { + // Step 3a: Get upper and lower ring offsets. int upper_ring_offset = m_upper_ring_offsets[i]; - CUFFT_CALL(cufftSetStream(m_polar_plans[i], stream)) - CUFFT_CALL(cufftXtExec(m_polar_plans[i], data + upper_ring_offset, data + upper_ring_offset, - CUFFT_FORWARD)); + // Step 3e: Set the CUDA stream and work area for the cuFFT plan. + CUFFT_CALL(cufftSetStream(m_polar_plans[i], stream)); + CUFFT_CALL(cufftSetWorkArea(m_polar_plans[i], workspace)); + // Step 3f: Execute the cuFFT transform. + CUFFT_CALL( + cufftXtExec(m_polar_plans[i], data + upper_ring_offset, data + upper_ring_offset, DIRECTION)); } - // Equator fft - CUFFT_CALL(cufftSetStream(m_equator_plan, stream)) + // Step 4: Execute FFT for the equatorial ring. + // Step 4d: Set the CUDA stream and work area for the equatorial cuFFT plan. + CUFFT_CALL(cufftSetStream(m_equator_plan, stream)); + CUFFT_CALL(cufftSetWorkArea(m_equator_plan, workspace)); + // Step 4e: Execute the cuFFT transform for the equator. CUFFT_CALL(cufftXtExec(m_equator_plan, data + m_equatorial_offset_start, data + m_equatorial_offset_start, - CUFFT_FORWARD)); + DIRECTION)); + + // Step 5: Launch the custom kernel for normalization and shifting. + switch (norm) { + case s2fftKernels::fft_norm::NONE: + case s2fftKernels::fft_norm::BACKWARD: + // No normalization, only shift if required. + s2fftKernels::launch_shift_normalize_kernel(stream, data, m_nside, shift, 2); + break; + case s2fftKernels::fft_norm::FORWARD: + // Normalize by sqrt(Npix). + s2fftKernels::launch_shift_normalize_kernel(stream, data, m_nside, shift, 0); + break; + case s2fftKernels::fft_norm::ORTHO: + // Normalize by Npix. + s2fftKernels::launch_shift_normalize_kernel(stream, data, m_nside, shift, 1); + break; + default: + return E_INVALIDARG; // Invalid normalization type. + } return S_OK; } template -HRESULT s2fftExec::Backward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data) { - // Polar rings inverse FFTs +HRESULT s2fftExec::Backward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data, + Complex *workspace) { + // Step 1: Determine the FFT direction (forward or inverse based on adjoint flag). + const int DIRECTION = desc.adjoint ? CUFFT_FORWARD : CUFFT_INVERSE; + // Step 2: Extract normalization, shift, and double precision flags from the descriptor. + const s2fftKernels::fft_norm &norm = desc.norm; + + // Step 3: Execute inverse FFTs for polar rings. for (int i = 0; i < m_nside - 1; i++) { + // Step 3a: Get upper and lower ring offsets. int upper_ring_offset = m_upper_ring_offsets[i]; - CUFFT_CALL(cufftSetStream(m_inverse_polar_plans[i], stream)) + // Step 3e: Set the CUDA stream and work area for the cuFFT plan. + CUFFT_CALL(cufftSetStream(m_inverse_polar_plans[i], stream)); + CUFFT_CALL(cufftSetWorkArea(m_inverse_polar_plans[i], workspace)); + // Step 3f: Execute the cuFFT transform. CUFFT_CALL(cufftXtExec(m_inverse_polar_plans[i], data + upper_ring_offset, data + upper_ring_offset, - CUFFT_INVERSE)); + DIRECTION)); } - // Equator inverse FFT - CUFFT_CALL(cufftSetStream(m_inverse_equator_plan, stream)) + // Step 4: Execute inverse FFT for the equatorial ring. + // Step 4d: Set the CUDA stream and work area for the equatorial cuFFT plan. + CUFFT_CALL(cufftSetStream(m_inverse_equator_plan, stream)); + CUFFT_CALL(cufftSetWorkArea(m_inverse_equator_plan, workspace)); + // Step 4e: Execute the cuFFT transform for the equator. CUFFT_CALL(cufftXtExec(m_inverse_equator_plan, data + m_equatorial_offset_start, - data + m_equatorial_offset_start, CUFFT_INVERSE)); - // + data + m_equatorial_offset_start, DIRECTION)); + + // Step 5: Launch the custom kernel for normalization and shifting. + switch (norm) { + case s2fftKernels::fft_norm::NONE: + case s2fftKernels::fft_norm::FORWARD: + // No normalization, do nothing. + break; + case s2fftKernels::fft_norm::BACKWARD: + // Normalize by sqrt(Npix). + s2fftKernels::launch_shift_normalize_kernel(stream, data, m_nside, false, 0); + break; + case s2fftKernels::fft_norm::ORTHO: + // Normalize by Npix. + s2fftKernels::launch_shift_normalize_kernel(stream, data, m_nside, false, 1); + break; + default: + return E_INVALIDARG; // Invalid normalization type. + } + return S_OK; } diff --git a/lib/src/s2fft_callbacks.cu b/lib/src/s2fft_callbacks.cu index 937926d8..02349eca 100644 --- a/lib/src/s2fft_callbacks.cu +++ b/lib/src/s2fft_callbacks.cu @@ -1,4 +1,3 @@ - #include #include "hresult.h" #include @@ -10,173 +9,374 @@ namespace s2fftKernels { // Fundamental Functions +/** + * @brief Computes the shifted index for a 1D FFT. + * + * This function calculates the new index after applying an FFT shift, + * which effectively moves the zero-frequency component to the center of the spectrum. + * + * @param offset The original offset (index) of the element. + * @param params A pointer to an array containing FFT parameters: params[0] is n (size of FFT), params[1] is + * dist (distance between batches). + * @return The shifted index. + */ __device__ int64 fft_shift(size_t offset, int64 *params) { + // Step 1: Extract FFT size and distance between batches from parameters. int64 n = params[0]; int64 dist = params[1]; + // Step 2: Determine the offset of the first element in the current batch. int64 first_element_offset = offset < dist ? 0 : dist; + // Step 3: Calculate half the FFT size for shifting. int64 half = n / 2; + // Step 4: Normalize the offset relative to the start of its batch. int64 normalized_offset = offset - first_element_offset; + // Step 5: Apply the FFT shift. int64 shifted_index = normalized_offset + half; + // Step 6: Calculate the final index, ensuring it wraps around correctly within the batch. int64 indx = (shifted_index % n) + first_element_offset; return indx; } +/** + * @brief Computes the shifted index for an equatorial FFT. + * + * This function calculates the new index after applying an FFT shift specifically + * for the equatorial ring, where the data layout might differ slightly. + * + * @param offset The original offset (index) of the element. + * @param params A pointer to an array containing FFT parameters: params[0] is n (size of FFT). + * @return The shifted index. + */ __device__ int64 fft_shift_eq(size_t offset, int64 *params) { + // Step 1: Extract FFT size from parameters. int64 n = params[0]; + // Step 2: Calculate the starting offset of the current ring. int64 first_element_offset = (offset / n) * n; + // Step 3: Calculate the offset within the current ring. int64 offset_in_ring = first_element_offset + offset % n; + // Step 4: Calculate half the FFT size for shifting. int64 half = n / 2; + // Step 5: Apply the FFT shift within the ring. int64 shifted_index = offset_in_ring + half; + // Step 6: Calculate the final index, ensuring it wraps around correctly within the ring. int64 indx = (shifted_index % n) + first_element_offset; return indx; } +/** + * @brief Normalizes a complex element by dividing by the FFT size. + * + * @tparam Complex The complex type (cufftComplex or cufftDoubleComplex). + * @param element Pointer to the complex element to normalize. + * @param size The size of the FFT. + */ template __device__ void normalize(Complex *element, int64 size) { + // Step 1: Calculate the normalization factor. float norm_factor = 1.0f / (float)size; + // Step 2: Apply the normalization factor to the real part. element->x *= norm_factor; + // Step 3: Apply the normalization factor to the imaginary part. element->y *= norm_factor; } +/** + * @brief Normalizes a complex element by dividing by the square root of the FFT size (orthonormalization). + * + * @tparam Complex The complex type (cufftComplex or cufftDoubleComplex). + * @param element Pointer to the complex element to normalize. + * @param size The size of the FFT. + */ template __device__ void normalize_ortho(Complex *element, int64 size) { + // Step 1: Calculate the orthonormalization factor. float norm_factor = 1.0f / sqrtf((float)size); + // Step 2: Apply the normalization factor to the real part. element->x *= norm_factor; + // Step 3: Apply the normalization factor to the imaginary part. element->y *= norm_factor; } // Callbacks +/** + * @brief cuFFT callback function for applying FFT shift. + * + * This callback is executed by cuFFT to apply a circular shift to the output data. + * + * @tparam Complex The complex type (cufftComplex or cufftDoubleComplex). + * @param dataOut Pointer to the output data buffer. + * @param offset The current offset (index) within the output buffer. + * @param element The complex element at the current offset. + * @param callerInfo Pointer to user-defined parameters (params array). + * @param sharedPointer Pointer to shared memory (unused in this callback). + */ template __device__ void fft_shift_cb(void *dataOut, size_t offset, Complex element, void *callerInfo, void *sharedPointer) { + // Step 1: Cast callerInfo to the correct parameter type. int64 *params = (int64 *)callerInfo; + // Step 2: Cast dataOut to the correct complex data type. Complex *data = (Complex *)dataOut; + // Step 3: Calculate the shifted index. int64 indx = fft_shift(offset, params); + // Step 4: Store the element at the shifted index. data[indx] = element; } +/** + * @brief cuFFT callback function for applying FFT shift to equatorial data. + * + * This callback is executed by cuFFT to apply a circular shift to the output data + * specifically for the equatorial ring. + * + * @tparam Complex The complex type (cufftComplex or cufftDoubleComplex). + * @param dataOut Pointer to the output data buffer. + * @param offset The current offset (index) within the output buffer. + * @param element The complex element at the current offset. + * @param callerInfo Pointer to user-defined parameters (params array). + * @param sharedPointer Pointer to shared memory (unused in this callback). + */ template __device__ void fft_shift_eq_cb(void *dataOut, size_t offset, Complex element, void *callerInfo, void *sharedPointer) { + // Step 1: Cast callerInfo to the correct parameter type. int64 *params = (int64 *)callerInfo; + // Step 2: Cast dataOut to the correct complex data type. Complex *data = (Complex *)dataOut; + // Step 3: Calculate the shifted index for the equatorial ring. int64 indx = fft_shift_eq(offset, params); + // Step 4: Store the element at the shifted index. data[indx] = element; } +/** + * @brief cuFFT callback function for applying orthonormalization. + * + * This callback is executed by cuFFT to normalize the output data by 1/sqrt(N). + * + * @tparam Complex The complex type (cufftComplex or cufftDoubleComplex). + * @param dataOut Pointer to the output data buffer. + * @param offset The current offset (index) within the output buffer. + * @param element The complex element at the current offset. + * @param callerInfo Pointer to user-defined parameters (params array). + * @param sharedPointer Pointer to shared memory (unused in this callback). + */ template __device__ void fft_norm_ortho_cb(void *dataOut, size_t offset, Complex element, void *callerInfo, void *sharedPointer) { + // Step 1: Cast callerInfo to the correct parameter type. int64 *params = (int64 *)callerInfo; + // Step 2: Cast dataOut to the correct complex data type. Complex *data = (Complex *)dataOut; + // Step 3: Normalize the element using orthonormalization. normalize_ortho(&element, params[0]); + // Step 4: Store the normalized element at the original offset. data[offset] = element; } +/** + * @brief cuFFT callback function for applying standard normalization (1/N). + * + * This callback is executed by cuFFT to normalize the output data by 1/N. + * + * @tparam Complex The complex type (cufftComplex or cufftDoubleComplex). + * @param dataOut Pointer to the output data buffer. + * @param offset The current offset (index) within the output buffer. + * @param element The complex element at the current offset. + * @param callerInfo Pointer to user-defined parameters (params array). + * @param sharedPointer Pointer to shared memory (unused in this callback). + */ template __device__ void fft_norm_cb(void *dataOut, size_t offset, Complex element, void *callerInfo, void *sharedPointer) { + // Step 1: Cast callerInfo to the correct parameter type. int64 *params = (int64 *)callerInfo; + // Step 2: Cast dataOut to the correct complex data type. Complex *data = (Complex *)dataOut; + // Step 3: Normalize the element using standard normalization. normalize(&element, params[0]); + // Step 4: Store the normalized element at the original offset. data[offset] = element; } // Declare the callbacks with shifts +/** + * @brief cuFFT callback function for applying orthonormalization and FFT shift. + * + * This callback combines orthonormalization and circular shifting of the output data. + * + * @tparam Complex The complex type (cufftComplex or cufftDoubleComplex). + * @param dataOut Pointer to the output data buffer. + * @param offset The current offset (index) within the output buffer. + * @param element The complex element at the current offset. + * @param callerInfo Pointer to user-defined parameters (params array). + * @param sharedPointer Pointer to shared memory (unused in this callback). + */ template __device__ void fft_norm_ortho_shift_cb(void *dataOut, size_t offset, Complex element, void *callerInfo, void *sharedPointer) { + // Step 1: Cast callerInfo to the correct parameter type. int64 *params = (int64 *)callerInfo; + // Step 2: Cast dataOut to the correct complex data type. Complex *data = (Complex *)dataOut; + // Step 3: Normalize the element using orthonormalization. normalize_ortho(&element, params[0]); + // Step 4: Calculate the shifted index. int64 indx = fft_shift(offset, params); + // Step 5: Store the normalized element at the shifted index. data[indx] = element; } +/** + * @brief cuFFT callback function for applying standard normalization (1/N) and FFT shift. + * + * This callback combines standard normalization and circular shifting of the output data. + * + * @tparam Complex The complex type (cufftComplex or cufftDoubleComplex). + * @param dataOut Pointer to the output data buffer. + * @param offset The current offset (index) within the output buffer. + * @param element The complex element at the current offset. + * @param callerInfo Pointer to user-defined parameters (params array). + * @param sharedPointer Pointer to shared memory (unused in this callback). + */ template __device__ void fft_norm_shift_cb(void *dataOut, size_t offset, Complex element, void *callerInfo, void *sharedPointer) { + // Step 1: Cast callerInfo to the correct parameter type. int64 *params = (int64 *)callerInfo; + // Step 2: Cast dataOut to the correct complex data type. Complex *data = (Complex *)dataOut; + // Step 3: Normalize the element using standard normalization. normalize(&element, params[0]); + // Step 4: Calculate the shifted index. int64 indx = fft_shift(offset, params); + // Step 5: Store the normalized element at the shifted index. data[indx] = element; } +/** + * @brief cuFFT callback function for applying orthonormalization and equatorial FFT shift. + * + * This callback combines orthonormalization and circular shifting of the output data + * specifically for the equatorial ring. + * + * @tparam Complex The complex type (cufftComplex or cufftDoubleComplex). + * @param dataOut Pointer to the output data buffer. + * @param offset The current offset (index) within the output buffer. + * @param element The complex element at the current offset. + * @param callerInfo Pointer to user-defined parameters (params array). + * @param sharedPointer Pointer to shared memory (unused in this callback). + */ template __device__ void fft_norm_ortho_shift_eq_cb(void *dataOut, size_t offset, Complex element, void *callerInfo, void *sharedPointer) { + // Step 1: Cast callerInfo to the correct parameter type. int64 *params = (int64 *)callerInfo; + // Step 2: Cast dataOut to the correct complex data type. Complex *data = (Complex *)dataOut; + // Step 3: Normalize the element using orthonormalization. normalize_ortho(&element, params[0]); + // Step 4: Calculate the shifted index for the equatorial ring. int64 indx = fft_shift_eq(offset, params); + // Step 5: Store the normalized element at the shifted index. data[indx] = element; } +/** + * @brief cuFFT callback function for applying standard normalization (1/N) and equatorial FFT shift. + * + * This callback combines standard normalization and circular shifting of the output data + * specifically for the equatorial ring. + * + * @tparam Complex The complex type (cufftComplex or cufftDoubleComplex). + * @param dataOut Pointer to the output data buffer. + * @param offset The current offset (index) within the output buffer. + * @param element The complex element at the current offset. + * @param callerInfo Pointer to user-defined parameters (params array). + * @param sharedPointer Pointer to shared memory (unused in this callback). + */ template __device__ void fft_norm_shift_eq_cb(void *dataOut, size_t offset, Complex element, void *callerInfo, void *sharedPointer) { + // Step 1: Cast callerInfo to the correct parameter type. int64 *params = (int64 *)callerInfo; + // Step 2: Cast dataOut to the correct complex data type. Complex *data = (Complex *)dataOut; + // Step 3: Normalize the element using standard normalization. normalize(&element, params[0]); + // Step 4: Calculate the shifted index for the equatorial ring. int64 indx = fft_shift_eq(offset, params); + // Step 5: Store the normalized element at the shifted index. data[indx] = element; } -// Ortho double +// Pointers to device-managed cuFFT callback functions for different normalization and shift combinations. +// These are __managed__ to allow access from both host and device code. + +// Ortho double precision callbacks __device__ __managed__ cufftCallbackStoreZ fft_norm_ortho_double_no_shift_ptr = fft_norm_ortho_cb; __device__ __managed__ cufftCallbackStoreZ fft_norm_ortho_double_shift_ptr = fft_norm_ortho_shift_cb; __device__ __managed__ cufftCallbackStoreZ fft_norm_ortho_double_shift_eq_ptr = fft_norm_ortho_shift_eq_cb; -// Ortho float + +// Ortho single precision callbacks __device__ __managed__ cufftCallbackStoreC fft_norm_ortho_float_no_shift_ptr = fft_norm_ortho_cb; __device__ __managed__ cufftCallbackStoreC fft_norm_ortho_float_shift_ptr = fft_norm_ortho_shift_cb; __device__ __managed__ cufftCallbackStoreC fft_norm_ortho_float_shift_eq_ptr = fft_norm_ortho_shift_eq_cb; -// Norm FWD and BWD double +// Standard (1/N) normalization double precision callbacks __device__ __managed__ cufftCallbackStoreZ fft_norm_noshift_double_ptr = fft_norm_cb; __device__ __managed__ cufftCallbackStoreZ fft_norm_shift_double_ptr = fft_norm_shift_cb; __device__ __managed__ cufftCallbackStoreZ fft_norm_shift_eq_double_ptr = fft_norm_shift_eq_cb; -// Norm FWD and BWD float + +// Standard (1/N) normalization single precision callbacks __device__ __managed__ cufftCallbackStoreC fft_norm_noshift_float_ptr = fft_norm_cb; __device__ __managed__ cufftCallbackStoreC fft_norm_shift_float_ptr = fft_norm_shift_cb; __device__ __managed__ cufftCallbackStoreC fft_norm_shift_eq_float_ptr = fft_norm_shift_eq_cb; -// Shifts double +// Shift-only double precision callbacks __device__ __managed__ cufftCallbackStoreZ fft_shift_double_ptr = fft_shift_cb; __device__ __managed__ cufftCallbackStoreZ fft_shift_eq_double_ptr = fft_shift_eq_cb; -// Shifts float + +// Shift-only single precision callbacks __device__ __managed__ cufftCallbackStoreC fft_shift_float_ptr = fft_shift_cb; __device__ __managed__ cufftCallbackStoreC fft_shift_eq_float_ptr = fft_shift_eq_cb; -// This could have been done in a cleaner way perhaps. - +/** + * @brief Returns the appropriate orthonormalization callback function pointer for double precision. + * + * @param equator Boolean flag indicating if the current operation is for the equatorial ring. + * @param shift Boolean flag indicating whether to apply FFT shifting. + * @return A void** pointer to the selected cuFFT callback function. + */ static auto getfftNormOrthoDouble(bool equator, bool shift) { + // Step 1: Check if it's an equatorial ring. if (equator) { + // Step 1a: If equatorial, check for shift. if (shift) { return (void **)&fft_norm_ortho_double_shift_eq_ptr; } else { return (void **)&fft_norm_ortho_double_no_shift_ptr; } - } else { + } else { // Step 1b: If not equatorial, check for shift. if (shift) { return (void **)&fft_norm_ortho_double_shift_ptr; } else { @@ -185,14 +385,23 @@ static auto getfftNormOrthoDouble(bool equator, bool shift) { } } +/** + * @brief Returns the appropriate orthonormalization callback function pointer for single precision. + * + * @param equator Boolean flag indicating if the current operation is for the equatorial ring. + * @param shift Boolean flag indicating whether to apply FFT shifting. + * @return A void** pointer to the selected cuFFT callback function. + */ static auto getfftNormOrthoFloat(bool equator, bool shift) { + // Step 1: Check if it's an equatorial ring. if (equator) { + // Step 1a: If equatorial, check for shift. if (shift) { return (void **)&fft_norm_ortho_float_shift_eq_ptr; } else { return (void **)&fft_norm_ortho_float_no_shift_ptr; } - } else { + } else { // Step 1b: If not equatorial, check for shift. if (shift) { return (void **)&fft_norm_ortho_float_shift_ptr; } else { @@ -201,14 +410,23 @@ static auto getfftNormOrthoFloat(bool equator, bool shift) { } } +/** + * @brief Returns the appropriate standard normalization callback function pointer for double precision. + * + * @param equator Boolean flag indicating if the current operation is for the equatorial ring. + * @param shift Boolean flag indicating whether to apply FFT shifting. + * @return A void** pointer to the selected cuFFT callback function. + */ static auto getfftNormDouble(bool equator, bool shift) { + // Step 1: Check if it's an equatorial ring. if (equator) { + // Step 1a: If equatorial, check for shift. if (shift) { return (void **)&fft_norm_shift_eq_double_ptr; } else { return (void **)&fft_norm_noshift_double_ptr; } - } else { + } else { // Step 1b: If not equatorial, check for shift. if (shift) { return (void **)&fft_norm_shift_double_ptr; } else { @@ -217,14 +435,23 @@ static auto getfftNormDouble(bool equator, bool shift) { } } +/** + * @brief Returns the appropriate standard normalization callback function pointer for single precision. + * + * @param equator Boolean flag indicating if the current operation is for the equatorial ring. + * @param shift Boolean flag indicating whether to apply FFT shifting. + * @return A void** pointer to the selected cuFFT callback function. + */ static auto getfftNormFloat(bool equator, bool shift) { + // Step 1: Check if it's an equatorial ring. if (equator) { + // Step 1a: If equatorial, check for shift. if (shift) { return (void **)&fft_norm_shift_eq_float_ptr; } else { return (void **)&fft_norm_noshift_float_ptr; } - } else { + } else { // Step 1b: If not equatorial, check for shift. if (shift) { return (void **)&fft_norm_shift_float_ptr; } else { @@ -233,76 +460,97 @@ static auto getfftNormFloat(bool equator, bool shift) { } } +/** + * @brief Returns the appropriate shift-only callback function pointer for double precision. + * + * @param equator Boolean flag indicating if the current operation is for the equatorial ring. + * @return A void** pointer to the selected cuFFT callback function. + */ static auto getfftShiftDouble(bool equator) { + // Step 1: Check if it's an equatorial ring. if (equator) { return (void **)&fft_shift_eq_double_ptr; - } else { + } else { // Step 1a: If not equatorial. return (void **)&fft_shift_double_ptr; } } +/** + * @brief Returns the appropriate shift-only callback function pointer for single precision. + * + * @param equator Boolean flag indicating if the current operation is for the equatorial ring. + * @return A void** pointer to the selected cuFFT callback function. + */ static auto getfftShiftFloat(bool equator) { + // Step 1: Check if it's an equatorial ring. if (equator) { return (void **)&fft_shift_eq_float_ptr; - } else { + } else { // Step 1a: If not equatorial. return (void **)&fft_shift_float_ptr; } } -HRESULT setCallback(cufftHandle forwardPlan, cufftHandle backwardPlan, int64 *params_dev, bool shift, - bool equator, bool doublePrecision, fft_norm norm) { - // Set the callback for the forward and backward +/** + * @brief Sets cuFFT callbacks specifically for a forward FFT plan. + * + * This function configures the cuFFT library to use custom callbacks + * for normalization and shifting operations during forward FFT execution. + * + * @param plan The cuFFT handle for the forward FFT plan. + * @param params_dev Pointer to device memory containing parameters for the callbacks. + * @param shift Boolean flag indicating whether to apply FFT shifting. + * @param equator Boolean flag indicating if the current operation is for the equatorial ring. + * @param doublePrecision Boolean flag indicating if double precision is used. + * @param norm The FFT normalization type to apply. + * @return HRESULT indicating success or failure. + */ +HRESULT setForwardCallback(cufftHandle plan, int64 *params_dev, bool shift, bool equator, + bool doublePrecision, fft_norm norm) { + // Step 1: Set the callback for the forward plan based on normalization type. switch (norm) { case fft_norm::ORTHO: - // ORTHO double shift - // Shifting always happends in the load callback for the inverse fft + // Step 1a: Orthonormalization with optional shift. if (doublePrecision) { - CUFFT_CALL(cufftXtSetCallback(forwardPlan, getfftNormOrthoDouble(equator, shift), + CUFFT_CALL(cufftXtSetCallback(plan, getfftNormOrthoDouble(equator, shift), CUFFT_CB_ST_COMPLEX_DOUBLE, (void **)¶ms_dev)); - CUFFT_CALL(cufftXtSetCallback(backwardPlan, getfftNormOrthoDouble(equator, false), - CUFFT_CB_ST_COMPLEX_DOUBLE, (void **)¶ms_dev)) - // ORTHO float shift } else { - CUFFT_CALL(cufftXtSetCallback(forwardPlan, getfftNormOrthoFloat(equator, shift), - CUFFT_CB_ST_COMPLEX, (void **)¶ms_dev)); - CUFFT_CALL(cufftXtSetCallback(backwardPlan, getfftNormOrthoFloat(equator, false), - CUFFT_CB_ST_COMPLEX, (void **)¶ms_dev)); + CUFFT_CALL(cufftXtSetCallback(plan, getfftNormOrthoFloat(equator, shift), CUFFT_CB_ST_COMPLEX, + (void **)¶ms_dev)); } break; case fft_norm::BACKWARD: + // Step 1b: Backward normalization. Apply shift only if requested. if (doublePrecision) { if (shift) { - CUFFT_CALL(cufftXtSetCallback(forwardPlan, getfftShiftDouble(equator), + CUFFT_CALL(cufftXtSetCallback(plan, getfftShiftDouble(equator), CUFFT_CB_ST_COMPLEX_DOUBLE, (void **)¶ms_dev)); } - CUFFT_CALL(cufftXtSetCallback(backwardPlan, getfftNormDouble(equator, false), - CUFFT_CB_ST_COMPLEX_DOUBLE, (void **)¶ms_dev)); } else { if (shift) { - CUFFT_CALL(cufftXtSetCallback(forwardPlan, getfftShiftFloat(equator), CUFFT_CB_ST_COMPLEX, + CUFFT_CALL(cufftXtSetCallback(plan, getfftShiftFloat(equator), CUFFT_CB_ST_COMPLEX, (void **)¶ms_dev)); } - CUFFT_CALL(cufftXtSetCallback(backwardPlan, getfftNormFloat(equator, false), - CUFFT_CB_ST_COMPLEX, (void **)¶ms_dev)); } break; case fft_norm::FORWARD: + // Step 1c: Forward normalization. Apply normalization and shift. if (doublePrecision) { - CUFFT_CALL(cufftXtSetCallback(forwardPlan, getfftNormDouble(equator, shift), + CUFFT_CALL(cufftXtSetCallback(plan, getfftNormDouble(equator, shift), CUFFT_CB_ST_COMPLEX_DOUBLE, (void **)¶ms_dev)); } else { - CUFFT_CALL(cufftXtSetCallback(forwardPlan, getfftNormFloat(equator, shift), - CUFFT_CB_ST_COMPLEX, (void **)¶ms_dev)); + CUFFT_CALL(cufftXtSetCallback(plan, getfftNormFloat(equator, shift), CUFFT_CB_ST_COMPLEX, + (void **)¶ms_dev)); } break; case fft_norm::NONE: + // Step 1d: No normalization. Apply shift only if requested. if (shift) { if (doublePrecision) { - CUFFT_CALL(cufftXtSetCallback(forwardPlan, getfftShiftDouble(equator), + CUFFT_CALL(cufftXtSetCallback(plan, getfftShiftDouble(equator), CUFFT_CB_ST_COMPLEX_DOUBLE, (void **)¶ms_dev)); } else { - CUFFT_CALL(cufftXtSetCallback(forwardPlan, getfftShiftFloat(equator), CUFFT_CB_ST_COMPLEX, + CUFFT_CALL(cufftXtSetCallback(plan, getfftShiftFloat(equator), CUFFT_CB_ST_COMPLEX, (void **)¶ms_dev)); } } @@ -311,4 +559,53 @@ HRESULT setCallback(cufftHandle forwardPlan, cufftHandle backwardPlan, int64 *pa return S_OK; } -} // namespace s2fftKernels + +/** + * @brief Sets cuFFT callbacks specifically for a backward FFT plan. + * + * This function configures the cuFFT library to use custom callbacks + * for normalization and shifting operations during backward FFT execution. + * + * @param plan The cuFFT handle for the inverse FFT plan. + * @param params_dev Pointer to device memory containing parameters for the callbacks. + * @param shift Boolean flag indicating whether to apply FFT shifting. + * @param equator Boolean flag indicating if the current operation is for the equatorial ring. + * @param doublePrecision Boolean flag indicating if double precision is used. + * @param norm The FFT normalization type to apply. + * @return HRESULT indicating success or failure. + */ +HRESULT setBackwardCallback(cufftHandle plan, int64 *params_dev, bool shift, bool equator, + bool doublePrecision, fft_norm norm) { + // Step 1: Set the callback for the backward plan based on normalization type. + switch (norm) { + case fft_norm::ORTHO: + // Step 1a: Orthonormalization without shift (shift is handled in forward for ORTHO). + if (doublePrecision) { + CUFFT_CALL(cufftXtSetCallback(plan, getfftNormOrthoDouble(equator, false), + CUFFT_CB_ST_COMPLEX_DOUBLE, (void **)¶ms_dev)) + } else { + CUFFT_CALL(cufftXtSetCallback(plan, getfftNormOrthoFloat(equator, false), CUFFT_CB_ST_COMPLEX, + (void **)¶ms_dev)); + } + break; + + case fft_norm::BACKWARD: + // Step 1b: Backward normalization without shift. + if (doublePrecision) { + CUFFT_CALL(cufftXtSetCallback(plan, getfftNormDouble(equator, false), + CUFFT_CB_ST_COMPLEX_DOUBLE, (void **)¶ms_dev)); + } else { + CUFFT_CALL(cufftXtSetCallback(plan, getfftNormFloat(equator, false), CUFFT_CB_ST_COMPLEX, + (void **)¶ms_dev)); + } + break; + case fft_norm::FORWARD: + case fft_norm::NONE: + // Step 1c: No normalization or forward normalization for backward plan. + // No callback is set for these cases in the backward plan. + break; + } + + return S_OK; +} +} // namespace s2fftKernels \ No newline at end of file diff --git a/lib/src/s2fft_kernels.cu b/lib/src/s2fft_kernels.cu index 14986200..31bd5b0d 100644 --- a/lib/src/s2fft_kernels.cu +++ b/lib/src/s2fft_kernels.cu @@ -7,74 +7,203 @@ namespace s2fftKernels { -__device__ void computeNphi(int nside, int ring_index, int L, int& nphi, int& offset_ring) { - // Compute number of pixels +// ============================================================================ +// HELPER DEVICE FUNCTIONS +// ============================================================================ + +/** + * @brief Computes the number of pixels in the polar caps for a given Nside. + * + * This function calculates the total number of pixels contained within both + * polar caps (north and south) of a HEALPix sphere for the given Nside parameter. + * + * @param nside The HEALPix Nside parameter. + * @return The number of pixels in both polar caps combined. + */ +__device__ int ncap(int nside) { return 2 * nside * (nside - 1); } + +/** + * @brief Computes the total number of pixels for a given Nside. + * + * This function calculates the total number of pixels in a HEALPix sphere + * for the given Nside parameter. + * + * @param nside The HEALPix Nside parameter. + * @return The total number of pixels (12 * nside^2). + */ +__device__ int npix(int nside) { return 12 * nside * nside; } + +/** + * @brief Computes the maximum ring index for a given Nside. + * + * This function calculates the highest ring index in the HEALPix tessellation + * for the given Nside parameter. + * + * @param nside The HEALPix Nside parameter. + * @return The maximum ring index (4 * nside - 2). + */ +__device__ int rmax(int nside) { return 4 * nside - 2; } + +/** + * @brief Computes the number of pixels and ring offset for a given ring index. + * + * This function calculates the number of pixels (nphi) in a specific ring and + * the offset to the start of that ring in the HEALPix pixel numbering scheme. + * It handles polar caps and equatorial rings differently according to HEALPix geometry. + * + * @param nside The HEALPix Nside parameter. + * @param ring_index The index of the ring (0-based). + * @param L The harmonic band limit (unused in current implementation). + * @param nphi Reference to store the number of pixels in the ring. + * @param offset_ring Reference to store the offset to the start of the ring. + */ +__device__ void compute_nphi_offset_from_ring(int nside, int ring_index, int L, int& nphi, int& offset_ring) { + // Step 1: Compute basic HEALPix parameters int total_pixels = 12 * nside * nside; int total_rings = 4 * nside - 1; int upper_pixels = nside * (nside - 1) * 2; - // offset for original healpix ring - // Sum of all elements from 0 to n is n * (n + 1) / 2 in o(1) time .. times 4 to get the number of - // elements before current ring + // Step 2: Determine ring type and compute nphi and offset + // Use triangular number formula: sum from 0 to n = n * (n + 1) / 2 - // Upper Polar rings + // Step 2a: Upper Polar rings (0 to nside-2) if (ring_index < nside - 1) { nphi = 4 * (ring_index + 1); offset_ring = ring_index * (ring_index + 1) * 2; } - // Lower Polar rings + // Step 2b: Lower Polar rings (3*nside to 4*nside-2) else if (ring_index > 3 * nside - 1) { - // Compute lower pixel offset + // Compute lower pixel offset using symmetry nphi = 4 * (total_rings - ring_index); - nphi = nphi == 0 ? 4 : nphi; + nphi = nphi == 0 ? 4 : nphi; // Handle edge case int reverse_ring_index = total_rings - ring_index; offset_ring = total_pixels - (reverse_ring_index * (reverse_ring_index + 1) * 2); } - // Equatorial ring + // Step 2c: Equatorial rings (nside-1 to 3*nside-1) else { nphi = 4 * nside; offset_ring = upper_pixels + (ring_index - nside + 1) * 4 * nside; } } +/** + * @brief Converts HEALPix pixel index to ring coordinates and pixel information. + * + * This function maps a HEALPix pixel index to its corresponding ring index, + * offset within the ring, number of pixels in the ring, and the start index + * of the ring. It correctly handles all three HEALPix regions: upper polar cap, + * equatorial belt, and lower polar cap. + * + * @param p The HEALPix pixel index (0-based). + * @param nside The HEALPix Nside parameter. + * @param r Reference to store the ring index. + * @param o Reference to store the offset within the ring. + * @param nphi Reference to store the number of pixels in the ring. + * @param r_start Reference to store the starting pixel index of the ring. + */ +__device__ void pixel_to_ring_offset_nphi(long long int p, int nside, int& r, int& o, int& nphi, + int& r_start) { + // Step 1: Compute HEALPix parameters + long long int Ncap = ncap(nside); + long long int Npix = npix(nside); + int Rmax = rmax(nside); + + // Step 2: Determine which region the pixel belongs to and compute coordinates + if (p < Ncap) { + // Step 2a: Upper Polar Cap + double p_d = static_cast(p); + // Use inverse triangular number formula to find ring + int k = static_cast(floor(0.5 * (sqrt(1.0 + 2.0 * p_d) - 1.0))); + r = k; + o = p - 2 * k * (k + 1); + r_start = 2 * k * (k + 1); + nphi = 4 * (k + 1); + } else if (p < Npix - Ncap) { + // Step 2b: Equatorial Belt + long long int q = p - Ncap; + int k = q / (4 * nside); + r = (nside - 1) + k; + o = q % (4 * nside); + o = o < 0 ? 4 * nside + o : o; // Ensure positive offset + r_start = Ncap + 4 * nside * k; + nphi = 4 * nside; + } else { + // Step 2c: Lower Polar Cap (use symmetry with upper cap) + long long int pprime = Npix - 1 - p; + double pprime_d = static_cast(pprime); + int k_south = static_cast(floor(0.5 * (sqrt(1.0 + 2.0 * pprime_d) - 1.0))); + r = Rmax - k_south; + long long o_prime = pprime - 2 * k_south * (k_south + 1); + int nphi_lo = 4 * (k_south + 1); + o = nphi_lo - 1 - o_prime; + r_start = Npix - (2 * k_south * (k_south + 1) + nphi_lo); + nphi = nphi_lo; + } +} + +/** + * @brief Generic inline swap function for device code. + * + * This function swaps the values of two variables of any type T. + * It's used within CUDA kernels for efficient data manipulation. + * + * @tparam T The type of the variables to swap. + * @param a Reference to the first variable. + * @param b Reference to the second variable. + */ template __device__ void inline swap(T& a, T& b) { T c(a); a = b; b = c; } + +// ============================================================================ +// GLOBAL KERNELS +// ============================================================================ + +/** + * @brief CUDA kernel for spectral folding in spherical harmonic transforms. + * + * This kernel performs spectral folding operations on ring-ordered data, + * transforming from Fourier coefficient space to HEALPix pixel space. + * It handles both positive and negative frequency components and applies + * optional FFT shifting. + * + * @tparam complex The complex type (cufftComplex or cufftDoubleComplex). + * @param data Input data array containing Fourier coefficients per ring. + * @param output Output array for folded HEALPix pixel data. + * @param nside The HEALPix Nside parameter. + * @param L The harmonic band limit. + * @param shift Flag indicating whether to apply FFT shifting. + */ template __global__ void spectral_folding(complex* data, complex* output, int nside, int L, bool shift) { - // Which ring are we working on + // Step 1: Determine which ring this thread is processing int current_indx = blockIdx.x * blockDim.x + threadIdx.x; if (current_indx >= (4 * nside - 1)) { return; } + // Step 2: Initialize ring parameters int ring_index = current_indx; - // Compute nphi of current ring int nphi(0); int ring_offset(0); - computeNphi(nside, ring_index, L, nphi, ring_offset); - - // ring index - - int ftm_offset = ring_index * (2 * L); - // offset for original healpix ring - // Sum of all elements from 0 to n is n * (n + 1) / 2 in o(1) time .. times 4 to get the number of - // elements before current ring + compute_nphi_offset_from_ring(nside, ring_index, L, nphi, ring_offset); - int slice_start = (L - nphi / 2); - int slice_end = slice_start + nphi; + // Step 3: Compute indices for Fourier coefficient and HEALPix data + int ftm_offset = ring_index * (2 * L); // Offset for this ring's FTM data + int slice_start = (L - nphi / 2); // Start of central slice + int slice_end = slice_start + nphi; // End of central slice - // Fill up the healpix ring + // Step 4: Copy the central part of the spectrum directly for (int i = 0; i < nphi; i++) { int folded_index = i + ring_offset; int target_index = i + ftm_offset + slice_start; - output[folded_index] = data[target_index]; } - // fold the negative part of the spectrum + + // Step 5: Fold the negative part of the spectrum for (int i = 0; i < slice_start; i++) { int folded_index = -(1 + i) % nphi; folded_index = folded_index < 0 ? nphi + folded_index : folded_index; @@ -85,7 +214,8 @@ __global__ void spectral_folding(complex* data, complex* output, int nside, int output[folded_index].x += data[target_index].x; output[folded_index].y += data[target_index].y; } - // fold the positive part of the spectrum + + // Step 6: Fold the positive part of the spectrum for (int i = 0; i < L - nphi / 2; i++) { int folded_index = i % nphi; folded_index = folded_index < 0 ? nphi + folded_index : folded_index; @@ -97,9 +227,9 @@ __global__ void spectral_folding(complex* data, complex* output, int nside, int output[folded_index].y += data[target_index].y; } + // Step 7: Apply FFT shifting if requested if (shift) { int half_nphi = nphi / 2; - // Shift the spectrum for (int i = 0; i < half_nphi; i++) { int origin_index = i + ring_offset; int shifted_index = origin_index + half_nphi; @@ -107,99 +237,53 @@ __global__ void spectral_folding(complex* data, complex* output, int nside, int } } } -template -__global__ void spectral_folding_parallel(complex* data, complex* output, int nside, int L) { - // Which ring are we working on - int current_indx = blockIdx.x * blockDim.x + threadIdx.x; - - // Compute nphi of current ring - int nphi(0); - int offset_ring(0); - computeNphi(nside, current_indx, L, nphi, offset_ring); - - // ring index - int ring_index = current_indx / (2 * L); - // offset for the FTM slice - int offset = current_indx % (2 * L); - int ftm_offset = ring_index * (2 * L); - // offset for original healpix ring - // Sum of all elements from 0 to n is n * (n + 1) / 2 in o(1) time .. times 4 to get the number of - // elements before current ring - - int slice_start = (L - nphi / 2); - int slice_end = slice_start + nphi; - - // Fill up the healpix ring - if (offset >= slice_start && offset < slice_end) { - int center_offset = offset - slice_start; - int indx = center_offset + offset_ring; - - output[indx] = data[current_indx]; - } - __syncthreads(); - // fold the negative part of the spectrum - if (offset < slice_start && true) { - int folded_index = -(1 + offset) % nphi; - folded_index = folded_index < 0 ? nphi + folded_index : folded_index; - int target_index = slice_start - (1 + offset); - - folded_index = folded_index + offset_ring; - target_index = target_index + ftm_offset; - atomicAdd(&output[folded_index].x, data[target_index].x); - atomicAdd(&output[folded_index].y, data[target_index].y); - } - // fold the positive part of the spectrum - __syncthreads(); - if (offset >= slice_end && true) { - int folded_index = (offset - slice_end) % nphi; - folded_index = folded_index < 0 ? nphi + folded_index : folded_index; - int target_index = slice_end + (offset - slice_end); - - folded_index = folded_index + offset_ring; - target_index = target_index + ftm_offset; - atomicAdd(&output[folded_index].x, data[target_index].x); - atomicAdd(&output[folded_index].y, data[target_index].y); - } -} +/** + * @brief CUDA kernel for spectral extension in spherical harmonic transforms. + * + * This kernel performs the inverse operation of spectral folding, extending + * HEALPix pixel data back to full Fourier coefficient space. It maps folded + * frequency components back to their appropriate positions in the extended spectrum. + * + * @tparam complex The complex type (cufftComplex or cufftDoubleComplex). + * @param data Input array containing folded HEALPix pixel data. + * @param output Output array for extended Fourier coefficients per ring. + * @param nside The HEALPix Nside parameter. + * @param L The harmonic band limit. + */ template __global__ void spectral_extension(complex* data, complex* output, int nside, int L) { - // few inits + // Step 1: Initialize basic parameters int ftm_size = 2 * L; - // Which ring are we working on int current_indx = blockIdx.x * blockDim.x + threadIdx.x; if (current_indx >= (4 * nside - 1) * ftm_size) { return; } - // Compute nphi of current ring - int nphi(0); - int offset_ring(0); - // ring index + + // Step 2: Determine ring and frequency offset int ring_index = current_indx / (2 * L); - computeNphi(nside, ring_index, L, nphi, offset_ring); + int offset = current_indx % (2 * L); // Frequency offset within this ring - // offset for the FTM slice - int offset = current_indx % (2 * L); - // offset for original healpix ring - // Sum of all elements from 0 to n is n * (n + 1) / 2 in o(1) time .. times 4 to get the number of - // elements before current ring + // Step 3: Get ring parameters + int nphi(0); + int offset_ring(0); + compute_nphi_offset_from_ring(nside, ring_index, L, nphi, offset_ring); + // Step 4: Map frequency components based on their position in spectrum if (offset < L - nphi / 2) { + // Step 4a: Negative frequency part int indx = (-(L - nphi / 2 - offset)) % nphi; indx = indx < 0 ? nphi + indx : indx; indx = indx + offset_ring; output[current_indx] = data[indx]; - } - - // Compute the central part of the spectrum - else if (offset >= L - nphi / 2 && offset < L + nphi / 2) { - int center_offset = offset - /*negative part offset*/ (L - nphi / 2); + } else if (offset >= L - nphi / 2 && offset < L + nphi / 2) { + // Step 4b: Central part of the spectrum (direct mapping) + int center_offset = offset - (L - nphi / 2); int indx = center_offset + offset_ring; output[current_indx] = data[indx]; - } - // Compute the positive part of the spectrum - else { + } else { + // Step 4c: Positive frequency part int reverse_offset = ftm_size - offset; int indx = (L - (int)((nphi + 1) / 2) - reverse_offset) % nphi; indx = indx < 0 ? nphi + indx : indx; @@ -208,33 +292,171 @@ __global__ void spectral_extension(complex* data, complex* output, int nside, in } } +/** + * @brief CUDA kernel for FFT shifting and normalization of HEALPix data. + * + * This kernel applies per-ring normalization and optional FFT shifting to HEALPix + * pixel data. It processes each pixel independently, computing its ring coordinates + * and applying the appropriate transformations based on the ring geometry. + * + * @tparam complex The complex type (cufftComplex or cufftDoubleComplex). + * @tparam T The floating-point type (float or double) for normalization. + * @param data Input/output array of HEALPix pixel data. + * @param nside The HEALPix Nside parameter. + * @param apply_shift Flag indicating whether to apply FFT shifting. + * @param norm Normalization type (0=by nphi, 1=by sqrt(nphi), 2=no normalization). + */ +template +__global__ void shift_normalize_kernel(complex* data, int nside, bool apply_shift, int norm) { + // Step 1: Get pixel index and check bounds + long long int p = blockIdx.x * blockDim.x + threadIdx.x; + long long int Npix = npix(nside); + + if (p >= Npix) return; + + // Step 2: Convert pixel index to ring coordinates + int r, o, nphi, r_start; + pixel_to_ring_offset_nphi(p, nside, r, o, nphi, r_start); + + // Step 3: Read and normalize the pixel data + complex element = data[p]; + + if (norm == 0) { + // Step 3a: Normalize by nphi + element.x /= nphi; + element.y /= nphi; + } else if (norm == 1) { + // Step 3b: Normalize by sqrt(nphi) + T norm_val = sqrt((T)nphi); + element.x /= norm_val; + element.y /= norm_val; + } + // Step 3c: No normalization for norm == 2 + __syncthreads(); // Ensure all threads have completed normalization + + // Step 4: Apply FFT shifting if requested + if (apply_shift) { + // Step 4a: Compute shifted position within ring + long long int shifted_o = (o + nphi / 2) % nphi; + shifted_o = shifted_o < 0 ? nphi + shifted_o : shifted_o; + long long int dest_p = r_start + shifted_o; + // printf(" -> CUDA: Applying shift: p=%lld, dest_p=%lld, shifted_o=%lld\n", p, dest_p, shifted_o); + data[dest_p] = element; + } else { + // Step 4b: Write back to original position + data[p] = element; + } +} + +// ============================================================================ +// C++ LAUNCH FUNCTIONS +// ============================================================================ + +/** + * @brief Launches the spectral folding CUDA kernel. + * + * This function configures and launches the spectral_folding kernel with + * appropriate grid and block dimensions. It performs error checking and + * returns the execution status. + * + * @tparam complex The complex type (cufftComplex or cufftDoubleComplex). + * @param data Input data array containing Fourier coefficients per ring. + * @param output Output array for folded HEALPix pixel data. + * @param nside The HEALPix Nside parameter. + * @param L The harmonic band limit. + * @param shift Flag indicating whether to apply FFT shifting. + * @param stream CUDA stream for kernel execution. + * @return HRESULT indicating success or failure. + */ template HRESULT launch_spectral_folding(complex* data, complex* output, const int& nside, const int& L, const bool& shift, cudaStream_t stream) { + // Step 1: Configure kernel launch parameters int block_size = 128; int ftm_elements = (4 * nside - 1); int grid_size = (ftm_elements + block_size - 1) / block_size; + // Step 2: Launch the kernel spectral_folding<<>>(data, output, nside, L, shift); + + // Step 3: Check for kernel launch errors checkCudaErrors(cudaGetLastError()); return S_OK; } +/** + * @brief Launches the spectral extension CUDA kernel. + * + * This function configures and launches the spectral_extension kernel with + * appropriate grid and block dimensions. It performs error checking and + * returns the execution status. + * + * @tparam complex The complex type (cufftComplex or cufftDoubleComplex). + * @param data Input array containing folded HEALPix pixel data. + * @param output Output array for extended Fourier coefficients per ring. + * @param nside The HEALPix Nside parameter. + * @param L The harmonic band limit. + * @param stream CUDA stream for kernel execution. + * @return HRESULT indicating success or failure. + */ template HRESULT launch_spectral_extension(complex* data, complex* output, const int& nside, const int& L, cudaStream_t stream) { - // Launch the kernel + // Step 1: Configure kernel launch parameters int block_size = 128; int ftm_elements = 2 * L * (4 * nside - 1); int grid_size = (ftm_elements + block_size - 1) / block_size; + // Step 2: Launch the kernel spectral_extension<<>>(data, output, nside, L); + // Step 3: Check for kernel launch errors checkCudaErrors(cudaGetLastError()); return S_OK; } -// Specializations +/** + * @brief Launches the shift/normalize CUDA kernel for HEALPix data processing. + * + * This function configures and launches the shift_normalize_kernel with appropriate + * grid and block dimensions. It handles both single and double precision complex types + * and applies the requested normalization and shifting operations. + * + * @tparam complex The complex type (cufftComplex or cufftDoubleComplex). + * @param stream CUDA stream for kernel execution. + * @param data Input/output array of HEALPix pixel data. + * @param nside The HEALPix Nside parameter. + * @param apply_shift Flag indicating whether to apply FFT shifting. + * @param norm Normalization type (0=by nphi, 1=by sqrt(nphi), 2=no normalization). + * @return HRESULT indicating success or failure. + */ +template +HRESULT launch_shift_normalize_kernel(cudaStream_t stream, complex* data, int nside, bool apply_shift, + int norm) { + // Step 1: Configure kernel launch parameters + long long int Npix = 12 * nside * nside; + int block_size = 256; + int grid_size = (Npix + block_size - 1) / block_size; + + // Step 2: Launch kernel with appropriate precision + if constexpr (std::is_same_v) { + shift_normalize_kernel + <<>>((cufftComplex*)data, nside, apply_shift, norm); + } else { + shift_normalize_kernel + <<>>((cufftDoubleComplex*)data, nside, apply_shift, norm); + } + + // Step 3: Check for kernel launch errors + checkCudaErrors(cudaGetLastError()); + return S_OK; +} + +// ============================================================================ +// C++ TEMPLATE SPECIALIZATIONS +// ============================================================================ + +// Explicit template specializations for spectral folding functions template HRESULT launch_spectral_folding(cufftComplex* data, cufftComplex* output, const int& nside, const int& L, const bool& shift, cudaStream_t stream); @@ -243,10 +465,19 @@ template HRESULT launch_spectral_folding(cufftDoubleComplex* const int& L, const bool& shift, cudaStream_t stream); +// Explicit template specializations for spectral extension functions template HRESULT launch_spectral_extension(cufftComplex* data, cufftComplex* output, const int& nside, const int& L, cudaStream_t stream); template HRESULT launch_spectral_extension(cufftDoubleComplex* data, cufftDoubleComplex* output, const int& nside, const int& L, cudaStream_t stream); -} // namespace s2fftKernels \ No newline at end of file +// Explicit template specializations for shift/normalize functions +template HRESULT launch_shift_normalize_kernel(cudaStream_t stream, cufftComplex* data, + int nside, bool apply_shift, int norm); + +template HRESULT launch_shift_normalize_kernel(cudaStream_t stream, + cufftDoubleComplex* data, int nside, + bool apply_shift, int norm); + +} // namespace s2fftKernels diff --git a/notebooks/JAX_CUDA_HEALPix.ipynb b/notebooks/JAX_CUDA_HEALPix.ipynb index 76392d2c..f0401a90 100644 --- a/notebooks/JAX_CUDA_HEALPix.ipynb +++ b/notebooks/JAX_CUDA_HEALPix.ipynb @@ -1,350 +1,560 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# __S2FFT CUDA Implementation__\n", - "---\n", - "\n", - "[![colab image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/astro-informatics/s2fft/blob/main/notebooks/JAX_HEALPix_frontend.ipynb)" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "IN_COLAB = 'google.colab' in sys.modules\n", - "\n", - "# Install s2fft and data if running on google colab.\n", - "if IN_COLAB:\n", - " !pip install s2fft &> /dev/null" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Short comparaison between the pure JAX implementation and the CUDA implementation of the S2FFT algorithm." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import jax\n", - "from jax import numpy as jnp\n", - "import argparse\n", - "import time\n", - "\n", - "jax.config.update(\"jax_enable_x64\", True)\n", - "\n", - "from s2fft.utils.healpix_ffts import healpix_fft_jax, healpix_ifft_jax, healpix_fft_cuda, healpix_ifft_cuda\n", - "\n", - "import numpy as np\n", - "import s2fft \n", - "\n", - "from jax._src.numpy.util import promote_dtypes_complex\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "def run_fwd_test(nside):\n", - " L = 2 * nside \n", - "\n", - " total_pixels = 12 * nside**2\n", - " arr = jax.random.normal(jax.random.PRNGKey(0), (total_pixels, ))\n", - "\n", - " start = time.perf_counter()\n", - " cuda_res = healpix_fft_cuda(arr, L, nside,reality=False).block_until_ready()\n", - " end = time.perf_counter()\n", - " cuda_jit_time = end - start\n", - "\n", - " start = time.perf_counter()\n", - " cuda_res = healpix_fft_cuda(arr, L, nside,reality=False).block_until_ready()\n", - " end = time.perf_counter()\n", - " cuda_run_time = end - start\n", - "\n", - " start = time.perf_counter()\n", - " jax_res = healpix_fft_jax(arr, L, nside,reality=False).block_until_ready()\n", - " end = time.perf_counter()\n", - " jax_jit_time = end - start\n", - "\n", - " start = time.perf_counter()\n", - " jax_res = healpix_fft_jax(arr, L, nside,reality=False).block_until_ready()\n", - " end = time.perf_counter()\n", - " jax_run_time = end - start\n", - "\n", - " method = \"jax_healpy\"\n", - " sampling = \"healpix\"\n", - " (arr,) = promote_dtypes_complex(arr)\n", - " start = time.perf_counter()\n", - " flm = s2fft.forward(arr, L, nside=nside, sampling=sampling, method=method).block_until_ready()\n", - " end = time.perf_counter()\n", - " healpy_jit_time = end - start\n", - "\n", - " start = time.perf_counter()\n", - " flm = s2fft.forward(arr, L, nside=nside, sampling=sampling, method=method).block_until_ready()\n", - " end = time.perf_counter()\n", - " healpy_run_time = end - start\n", - "\n", - " print(f\"For nside {nside}\")\n", - " print(f\" -> FWD\")\n", - " print(f\" -> -> cuda_jit_time: {cuda_jit_time}, cuda_run_time: {cuda_run_time}\")\n", - " print(f\" -> -> jax_jit_time: {jax_jit_time}, jax_run_time: {jax_run_time}\")\n", - " print(f\" -> -> healpy_jit_time: {healpy_jit_time}, healpy_run_time: {healpy_run_time}\")\n", - "\n", - " return cuda_jit_time , cuda_run_time, jax_jit_time, jax_run_time , healpy_jit_time, healpy_run_time\n", - "\n", - "\n", - "def run_bwd_test(nside):\n", - "\n", - " L = 2 * nside\n", - " ftm_shape = (4 * nside - 1, 2 * L)\n", - " ftm_size = ftm_shape[0] * ftm_shape[1]\n", - "\n", - " arr = jax.random.normal(jax.random.PRNGKey(0), ftm_shape)\n", - "\n", - " start = time.perf_counter()\n", - " cuda_res = healpix_ifft_cuda(arr, L, nside,reality=False).block_until_ready()\n", - " end = time.perf_counter()\n", - " cuda_jit_time = end - start\n", - "\n", - " start = time.perf_counter()\n", - " cuda_res = healpix_ifft_cuda(arr, L, nside,reality=False).block_until_ready()\n", - " end = time.perf_counter()\n", - " cuda_run_time = end - start\n", - "\n", - " start = time.perf_counter()\n", - " jax_res = healpix_ifft_jax(arr, L, nside,reality=False).block_until_ready()\n", - " end = time.perf_counter()\n", - "\n", - " jax_jit_time = end - start\n", - " \n", - " start = time.perf_counter()\n", - " jax_res = healpix_ifft_jax(arr, L, nside,reality=False).block_until_ready()\n", - " end = time.perf_counter()\n", - " jax_run_time = end - start\n", - "\n", - " method = \"jax_healpy\"\n", - " sampling = \"healpix\"\n", - " rng = np.random.default_rng(23457801234570)\n", - " flm = s2fft.utils.signal_generator.generate_flm(rng, L)\n", - "\n", - " start = time.perf_counter()\n", - " f = s2fft.inverse(flm, L, nside=nside, sampling=sampling, method=method)\n", - " end = time.perf_counter()\n", - " healpy_jit_time = end - start\n", - "\n", - " start = time.perf_counter()\n", - " f = s2fft.inverse(flm, L, nside=nside, sampling=sampling, method=method)\n", - " end = time.perf_counter()\n", - " healpy_run_time = end - start\n", - "\n", - " print(f\"For nside {nside}\")\n", - " print(f\" -> BWD\")\n", - " print(f\" -> -> cuda_jit_time: {cuda_jit_time}, cuda_run_time: {cuda_run_time}\")\n", - " print(f\" -> -> jax_jit_time: {jax_jit_time}, jax_run_time: {jax_run_time}\")\n", - " print(f\" -> -> healpy_jit_time: {healpy_jit_time}, healpy_run_time: {healpy_run_time}\")\n", - "\n", - " return cuda_jit_time , cuda_run_time, jax_jit_time, jax_run_time , healpy_jit_time, healpy_run_time" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "jax.clear_caches()" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ + "cells": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "For nside 4\n", - " -> FWD\n", - " -> -> cuda_jit_time: 0.0005623459999242186, cuda_run_time: 0.0002589869998246286\n", - " -> -> jax_jit_time: 0.00023036399988995981, jax_run_time: 0.0001553519998651609\n", - " -> -> healpy_jit_time: 0.003654524000012316, healpy_run_time: 0.00570670499996595\n", - "For nside 4\n", - " -> BWD\n", - " -> -> cuda_jit_time: 0.0003901920001680992, cuda_run_time: 0.0005790029999843682\n", - " -> -> jax_jit_time: 0.0004877889998624596, jax_run_time: 0.00042751199998747325\n", - " -> -> healpy_jit_time: 0.004256186000020534, healpy_run_time: 0.004342149000194695\n", - "For nside 8\n", - " -> FWD\n", - " -> -> cuda_jit_time: 0.0005613310001990612, cuda_run_time: 0.0010512769999877492\n", - " -> -> jax_jit_time: 0.0015170009999110334, jax_run_time: 0.0028007529999740655\n", - " -> -> healpy_jit_time: 0.01888900099993407, healpy_run_time: 0.020618764999881023\n", - "For nside 8\n", - " -> BWD\n", - " -> -> cuda_jit_time: 0.0009404789998370688, cuda_run_time: 0.0007269820000601612\n", - " -> -> jax_jit_time: 0.001543406999871877, jax_run_time: 0.0008582420000493585\n", - " -> -> healpy_jit_time: 0.005325634999962858, healpy_run_time: 0.006471215000146913\n", - "For nside 16\n", - " -> FWD\n", - " -> -> cuda_jit_time: 0.0004737690001093142, cuda_run_time: 0.00029633700000886165\n", - " -> -> jax_jit_time: 0.0011566660000426054, jax_run_time: 0.0006750920001650229\n", - " -> -> healpy_jit_time: 0.017174200999988898, healpy_run_time: 0.011208771000156048\n", - "For nside 16\n", - " -> BWD\n", - " -> -> cuda_jit_time: 0.00030138499982967915, cuda_run_time: 0.0003267360000336339\n", - " -> -> jax_jit_time: 0.0005259600000044884, jax_run_time: 0.0003649550001227908\n", - " -> -> healpy_jit_time: 0.005033792000176618, healpy_run_time: 0.01343913400000929\n", - "For nside 32\n", - " -> FWD\n", - " -> -> cuda_jit_time: 0.0007112130001587502, cuda_run_time: 0.0005518440000287228\n", - " -> -> jax_jit_time: 0.005327952000016012, jax_run_time: 0.002135986999974193\n", - " -> -> healpy_jit_time: 0.05451428600008512, healpy_run_time: 0.045718837000094936\n", - "For nside 32\n", - " -> BWD\n", - " -> -> cuda_jit_time: 0.0007191470001544076, cuda_run_time: 0.0011659209999379527\n", - " -> -> jax_jit_time: 0.0011368859998128755, jax_run_time: 0.001248700999894936\n", - " -> -> healpy_jit_time: 0.015641461000086565, healpy_run_time: 0.027776794999908816\n" - ] - } - ], - "source": [ - "fwd_times = []\n", - "bwd_times = []\n", - "nsides = [4 , 8, 16, 32]\n", - "for nside in nsides:\n", - " fwd_times.append(run_fwd_test(nside))\n", - " bwd_times.append(run_bwd_test(nside))" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import seaborn as sns\n", - "sns.plotting_context(\"poster\")\n", - "sns.set(font_scale=1.4)\n", - "\n", - "\n", - "def plot_times(title, nsides, chrono_times):\n", - "\n", - " # Extracting times from the chrono_times\n", - " cuda_jit_times = [times[0] for times in chrono_times]\n", - " cuda_run_times = [times[1] for times in chrono_times]\n", - " jax_jit_times = [times[2] for times in chrono_times]\n", - " jax_run_times = [times[3] for times in chrono_times]\n", - " healpy_jit_times = [times[4] for times in chrono_times]\n", - " healpy_run_times = [times[5] for times in chrono_times]\n", - "\n", - " # Create subplots\n", - " fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 7))\n", - "\n", - " f2 = lambda a: np.log2(a)\n", - " g2 = lambda b: b**2\n", - "\n", - "\n", - " # Plot for JIT times\n", - " ax1.plot(nsides, cuda_jit_times, 'g-o', label='ours')\n", - " ax1.plot(nsides, jax_jit_times, 'b-o', label='s2fft base')\n", - " ax1.plot(nsides, healpy_jit_times, 'r-o', label='Healpy')\n", - " ax1.set_title('Compilation Times (first run)')\n", - " ax1.set_xlabel('nside')\n", - " ax1.set_ylabel('Time (seconds)')\n", - " ax1.set_xscale('function', functions=(f2, g2))\n", - " ax1.set_xticks(nsides)\n", - " ax1.set_xticklabels(nsides)\n", - " ax1.legend()\n", - " ax1.grid(True, which=\"both\", ls=\"--\")\n", - "\n", - " # Plot for Run times\n", - " ax2.plot(nsides, cuda_run_times, 'g-o', label='ours')\n", - " ax2.plot(nsides, jax_run_times, 'b-o', label='s2fft base')\n", - " ax2.plot(nsides, healpy_run_times, 'r-o', label='Healpy')\n", - " ax2.set_title('Execution Times')\n", - " ax2.set_xlabel('nside')\n", - " ax2.set_ylabel('Time (seconds)')\n", - " ax2.set_xscale('function', functions=(f2, g2))\n", - " ax2.set_xticks(nsides)\n", - " ax2.set_xticklabels(nsides)\n", - " ax2.legend()\n", - " ax2.grid(True, which=\"both\", ls=\"--\")\n", - "\n", - " # Set the overall title for the figure\n", - " fig.suptitle(title, fontsize=16)\n", - "\n", - " # Show the plots\n", - " plt.tight_layout(rect=[0, 0, 1, 0.96]) # Adjust rect to make space for the suptitle\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# __S2FFT CUDA Implementation__\n", + "---\n", + "\n", + "[![colab image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/astro-informatics/s2fft/blob/main/notebooks/JAX_HEALPix_frontend.ipynb)" + ] + }, { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAABW0AAAKzCAYAAABlBC9iAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeZyN5f/H8dc5s6/2fRsJI2sJIdmXaJFEq7VosaVFJEuypLKUUlSoX4uIFiQqInuYUEiyJdnCYGbMnHPu3x/nOyfHLGY4c5a538/HYx7dc5373OdznfNx+sx1rnNdFsMwDERERERERERERETEL1h9HYCIiIiIiIiIiIiI/EeDtiIiIiIiIiIiIiJ+RIO2IiIiIiIiIiIiIn5Eg7YiIiIiIiIiIiIifkSDtiIiIiIiIiIiIiJ+RIO2IiIiIiIiIiIiIn5Eg7YiIiIiIiIiIiIifkSDtiIiIiIiIiIiIiJ+RIO2IiIiIiIiIiIiIn5Eg7YiIiIiuRQXF4fFYsn2Z8qUKb4OMyCsXLkSi8VCs2bNcnW/yz3/FouFL774IsPjXO7n9OnTOTrv0p/s4t+/f/8VXbNHjx7Af/m2f//+XD+/IiIiIhKYgn0dgIiIiEigaty4Mddee22mt1133XVejsac2rZtS8mSJTO9rXz58pm2d+/ePcvrhYaGZnr7P//8w7fffpvl/ePj47O8ZnR0dKb3+eOPP1izZg1RUVF07tw5w+0333xzltcUERERkfzNYhiG4esgRERERAJJXFwcBw4cYNasWa7ZkHJlVq5cSfPmzWnatCkrV67M8f0sFgsAK1asyNEs3fTHAbiS8vdq75+Z2bNn07NnTypUqJDtLNq9e/eSlpZGpUqVCAkJ8chji4iIiIh/00xbERERERE/VqlSJV+HICIiIiJepjVtRURERLzgr7/+on///lSuXJnw8HAKFChA48aNeeedd7Db7RnOnz17tmtd03///ZdBgwZRqVIlwsLCaNasGadPnyYoKIhChQrhcDjc7vvZZ5+51kVdsmSJ220XLlwgMjKS8PBwkpOTXe2//fYbI0eOpHHjxpQpU4bQ0FCKFClCq1at+OyzzzLt08Xr0SYlJTFixAiqVatGZGQkcXFxbud+8MEH1KtXj8jISAoXLky7du1YvXr1FT6b5pLVmrbNmjXDYrGwcuVK1q9fT4cOHShSpAgxMTE0bdrU7fldunQpLVu2pFChQkRHR9O6dWu2bNmS5WOeOnWKkSNHUqdOHWJiYoiMjKRmzZq89NJLJCUlZTjf4XAwY8YMGjduTMGCBQkJCaF48eLUrl2b/v37az1eERERkVzSTFsRERGRPLZp0ybatWvHv//+S/ny5enYsSNnzpxh5cqVrF27loULF/LVV18RGhqa4b4nTpzgxhtv5PTp0zRp0oS6desSGhpKwYIFqVu3Lps2beLnn3+mfv36rvt89913bsft27d3/b5mzRqSk5Np3rw5ERERrvZJkybx3nvvER8fT82aNSlYsCAHDx5kxYoVfP/996xfv55JkyZl2r+UlBSaNWvGb7/9xi233ELt2rU5efKk6/aBAwfy+uuvY7VaufnmmyldujTbtm2jWbNm9O/f/6qeW4HFixczZcoUatasSevWrdm9ezerVq2idevW/PDDD2zdupUBAwZw00030aZNGxISEvjuu+9o2rQpW7duzbAu82+//Ua7du04dOgQpUqV4uabbyYkJISNGzfywgsv8Pnnn7Ny5UoKFCjgus/DDz/MrFmzCA8P5+abb6ZYsWL8+++//Pnnn0ybNo2WLVtmGMgXERERkaxp0FZEREQkD124cIF77rmHf//9l0cffZTXX3/dtS7pn3/+ScuWLfn2228ZPXo0Y8eOzXD/xYsX07JlSxYsWEBsbKzbba1atWLTpk189913GQZtS5cuzYULF9wGcNNvS7/vxR566CGGDRvGNddc49a+e/duWrVqxeTJk7n33nvdHifdhg0bqFWrFn/88UeGTcEWL17M66+/TlRUFN988w1NmjRx3TZ+/HiGDRuW5XMnOfPaa6/xwQcf8OCDD7rannrqKSZNmkSvXr04fPgwy5Yto2XLlgDY7Xa6du3K559/zssvv8zMmTNd90tOTuaOO+7g0KFDDB8+nBdeeMH1YUJSUhIPP/wwn3zyCU8++STvv/8+AAcPHmTWrFmULVuWTZs2ZciBnTt3EhUVlddPg4iIiEi+ouURRERERK5Qz549XcsQXPxz8cZY8+bN48CBA5QuXZopU6a4bSR1zTXX8OqrrwLwxhtvkJKSkuExQkJCmDFjRoYBW/hv4HX58uWutj///JN9+/bRunVrWrRowfbt2zl69Kjr9qwGbZs2bZphwBagatWqvPDCCwDMnz8/y+di2rRpGQbrAKZMmQJAv3793AZsAYYOHUqdOnWyvGZONG/ePNPXILsN4jI732KxMHv27KuKxVc6d+7sNmAL8PzzzwPOQffHHnvMNWALEBQU5Bos//77793uN2fOHPbu3cttt93GmDFj3GZ/R0ZGMmPGDIoXL86HH37IqVOnAFz5dcMNN2SaA9WqVaN8+fIe6KmIiIiIeWimrYiIiMgVaty4cYavlgPEx8e7jleuXAnAvffeS1hYWIZzO3XqRKFChTh16hSbN2+mcePGbrdff/31mQ6mpj9+REQE69atIykpicjISNegbOvWrTl//jzz5s3ju+++44EHHuD06dNs3ryZggULcuONN2a43rlz5/jmm2/YunUrJ06cIDU1FYAjR44AzgHAzBQvXjzDgCyAzWbjp59+AsgwqJiuW7duJCQkZHpbTrRt2zbTgcKbb745y/t079490/bMXstAcPHyF+kKFy5MkSJFOHnyZKa3V65cGYC///7brX3x4sUAdO3aNdPHio6O5sYbb2TJkiVs2rSJNm3aEB8fT0xMDEuWLGHs2LHcf//9VKxY8Wq7JSIiImJqGrQVERERuUIPP/xwtjM6AQ4fPgyQ5SCWxWKhYsWKnDp1ynXuxbJbBzQsLIybb76Z5cuXs3r1atq2bct3332HxWKhVatWnD9/HsA1aPvDDz/gcDho3rw5Vqv7F66+/vprevbs6bYW7aUSExMzbc8qxpMnT7pmD2fV/6sd3HvuuefcZjbnRKDOqM1KVrNYo6OjOXnyZKa3x8TEAM7lOy72559/As7lMh566KFsH/f48eOua82aNYuePXsyfPhwhg8fTqlSpbjpppto164d999/P9HR0bnul4iIiIiZadBWRERExI9dvFlYZlq1asXy5ctZvnw5bdq04YcffqBmzZqUKFECcA6Kps++zWpphMOHD9O1a1eSk5N59tlneeCBB4iLiyM6Ohqr1cqyZcto27YthmFcUYySty4dgM/t7RdzOBwAtGvXzpVDWalQoYLr+O6776ZVq1Z89dVXrF69mjVr1rBw4UIWLlzIiBEjWL58OTVr1sxxHCIiIiJmp0FbERERkTxUpkwZ4L8ZjJnZt2+f27m5kT4A+91337F161ZOnjzp9vX/Vq1aMXPmTHbt2pXloO3XX39NcnIyd911Fy+//HKGx9izZ0+u4wIoUqQIYWFhXLhwgf3791O9evUM5+zfv/+Kri15o1y5cuzatYvevXvTuXPnXN23QIECbjN0Dx06RP/+/fnyyy/p168fP/74Y16ELCIiIpIvaSMyERERkTyU/tX9uXPnZrrR2MKFCzl16hQxMTHUrVs319e//vrrKVKkCNu2bePjjz8GnOvZpksfoH3vvffYs2cP5cqVo0qVKm7X+PfffwH3mZPpDMNwXTe3goODXWv0fvTRR5me8+GHH17RtSVv3HrrrQB89tlnV32tcuXKMXr0aICrWrdYRERExIw0aCsiIiKSh+655x7Kly/P33//zeDBg7HZbK7b9u3bx1NPPQVA//79CQ8Pz/X1LRYLLVq0wDAM3nzzTUJDQ7nllltct7ds2RKLxcK0adOAjLNsAapVqwbA/PnzXZuOAdjtdkaMGMHatWtzHVe6QYMGAfDGG29kuM7EiRPZsmXLFV9bPK9Pnz5UqFCBefPmMWTIEM6ePZvhnH/++YeZM2e6ft+6dStz584lOTk5w7lff/01kPkHAiIiIiKSNS2PICIiIpKHwsLCmD9/Pu3atWP69OksWbKEm266ibNnz/LDDz+QkpJC27ZtGTly5BU/RqtWrZg3bx4pKSk0b96cyMhI121FihShTp06bN261XXupW6//Xbq1q3L5s2bqVKlCk2bNiUqKooNGzbw999/M2TIkEyXTciJ22+/nSeeeII333yTJk2acMstt1CqVCm2bdvGzp07GThwIFOnTr2yjovHRUVFsXjxYm677TYmTpzIjBkzqFWrFmXLliUpKYnff/+dnTt3Urx4cR555BEADhw4wL333ktERAQ33HAD5cqVw2azsX37dnbv3k1oaCgTJ070cc9EREREAotm2oqIiIjksXr16pGQkMATTzxBUFAQCxcuZPXq1Vx//fVMnz6dRYsWERoaesXXv3ggNrNB2fQ2i8VCy5YtM9weHBzMypUrGTZsGGXKlOH7779n5cqVXH/99axbt4527dpdcWwA06ZN4/333+f6669n/fr1LFmyhFKlSvH999/TsWPHq7q2eF716tXZtm0bEydOpFq1amzbto158+axYcMGoqKiePrpp1m4cKHr/JtuuokJEybQvHlz/v77b7766iuWLVtGUFAQTzzxBNu2bbvqHBIRERExG4uR1TbAIiIiIiIiIiIiIuJ1mmkrIiIiIiIiIiIi4kc0aCsiIiIiIiIiIiLiRzRoKyIiIiIiIiIiIuJHNGgrIiIiIiIiIiIi4kc0aCsiIiIiIiIiIiLiRzRoKyIiIiIiIiIiIuJHNGgrIiIiIiIiIiIi4kc0aCsiIiIiIiIiIiLiRzRoKyIiIiIiIiIiIuJHNGgrIiIiIiIiIiIi4kc0aCsiIiIiIiIiIiLiRzRoKyIiIiIiIiIiIuJHNGgrIiIiIiIiIiIi4kc0aCsiIiIiIiIiIiLiRzRoKyIiIiIiIiIiIuJHNGgrIiIiIiIiIiIi4kc0aCsiIiIiIiIiIiLiRzRoKyIiIiIiIiIiIuJHNGgrIiIiIiIiIiIi4kc0aCsiIiIiIiIiIiLiRzRoKyIiIiIiIiIiIuJHNGgrIiIiIiIiIiIi4kc0aCsiIiIiIiIiIiLiRzRoKyIiIiIiIiIiIuJHNGgrIiIiIiIiIiIi4kc0aCsiIiIiIiIiIiLiRzRoKyIiIiIiIiIiIuJHNGgrIiIiIiIiIiIi4kc0aCsiIiIiIiIiIiLiRzRoKyIiIiIiIiIiIuJHNGgrIiIiIiIiIiIi4kc0aCsiIiIiIiIiIiLiRzRoKyIiIiIiIiIiIuJHNGgrIiIiIiIiIiIi4kc0aCsiIiIiIiIiIiLiRzRoKyIiIiIiIiIiIuJHNGgrIiIiIiIiIiIi4kc0aCsiIiIiIiIiIiLiRzRoKyIiIiIiIiIiIuJHNGgrIiIiIiIiIiIi4kc0aCsiIiIiIiIiIiLiRzRoKyIiIiIiIiIiIuJHNGgrIiIiIiIiIiIi4kc0aCsiIiIiIiIiIiLiRzRoKyIiIiIiIiIiIuJHNGgrIiIiIiIiIiIi4kc0aCsiIiIiIiIiIiLiRzRoKyI+YbFYGDVqlOv32bNnY7FY2L9/v8ceY//+/VgsFmbPnu2xa3paXFwcPXr08HUYueZwOKhRowZjx451a9+0aRONGjUiKioKi8VCQkICo0aNwmKx+ChS77v33nvp0qWLr8MQERER8bmVK1disVhYuXKlr0PJ0qV/l4iI+AsN2ooEsL1799K3b1+uueYawsPDiY2NpXHjxkydOpXk5GRfh+c1H3/8MVOmTPF1GMB/hWlOfgLZJ598wqFDh+jXr5+rLS0tjXvuuYd///2XyZMn8+GHH1KhQgWPP/Zvv/3GqFGjPDrA70lDhgzh888/55dffvF1KCIiIuLn0icuZPWzfv16X4eYI2+99ZbfTJS43HOa/hMXF+frUEVEsmUxDMPwdRAiknuLFy/mnnvuISwsjG7dulGjRg1SU1P56aef+Pzzz+nRowczZszwdZhZSklJITg4mODgYMBZXPXs2ZN9+/bluoC67bbb2LFjR4ZBPMMwuHDhAiEhIQQFBXko8uwdPXqU5cuXu7UNHTqU6Ohonn/+ebf2Bx98kAsXLmC1WgkJCfFKfJ5Sp04dGjRowDvvvONq27VrF9WqVWPmzJk8/PDDrnabzYbNZiM8PNwjjz1//nzuueceVqxYQbNmzTxyTU9r0KABVatW5YMPPvB1KCIiIuLH0mvgF198kYoVK2a4vV27dhQtWtQHkeVOjRo1KFq0aIYZtQ6Hg9TUVEJDQ7FavTNn7M8//2Tt2rVubQ8//DD169enT58+rrbo6Gg6duyY4e8SERF/oXclkQC0b98+7r33XipUqMAPP/xAqVKlXLc98cQT/PHHHyxevNiHEV6epwbwsmOxWLzyOBcrUaIEDz74oFvbhAkTKFq0aIZ2gLCwMG+F5jFbt27ll19+4bXXXnNrP3bsGAAFCxZ0a89JEZxe0Hvr9Tp//jxRUVF5dv0uXbowcuRI3nrrLaKjo/PscURERCR/uPXWW7nxxht9HYbHWa1Wr9fj11xzDddcc41b26OPPso111yTaT3u7fhERHJKyyOIBKCJEydy7tw53nvvPbcB23TXXnstAwcOdP1us9kYM2YMlSpVIiwsjLi4OIYNG8aFCxfc7hcXF8dtt93GypUrufHGG4mIiKBmzZquT8wXLFhAzZo1CQ8Pp27dumzdutXt/j169CA6Opo///yTtm3bEhUVRenSpXnxxRe5dFJ/TtaO+vLLL+nQoQOlS5cmLCyMSpUqMWbMGOx2u+ucZs2asXjxYg4cOJDhq05ZrWn7ww8/0KRJE6KioihYsCB33nknO3fudDsnfR3WP/74gx49elCwYEEKFChAz549SUpKyjbu3Lh0Tdv0r3P99NNPDBgwgGLFilGwYEH69u1Lamoqp0+fplu3bhQqVIhChQrx7LPPZnhuHQ4HU6ZMoXr16oSHh1OiRAn69u3LqVOn3M77+eefadu2LUWLFiUiIoKKFSvSq1evy8b8xRdfEBoayi233OJq69GjB02bNgXgnnvuwWKxuGbBZramrcVioV+/fnz00UdUr16dsLAwli5dCsCnn35K3bp1iYmJITY2lpo1azJ16lTX83PPPfcA0Lx5c9drnt06ael5uXfvXtq3b09MTAwPPPBAps9/umbNmrnN4k1f9uKzzz5j7NixlC1blvDwcFq2bMkff/yR4f6tW7fm/PnzGWZdi4iIiFyJkSNHYrVa+f77793a+/TpQ2hoqNuyTBs2bKBdu3YUKFCAyMhImjZtypo1azJc8/Dhw/Tu3dtVa1esWJHHHnuM1NRUIPMaDjLuRREXF8evv/7Kjz/+6KrN0uuorNa0nTdvHnXr1iUiIsI1ueHw4cNu56TXcIcPH6Zjx45ER0dTrFgxnn76abe/B67WpX+XpPf7999/58EHH6RAgQIUK1aMF154AcMwOHToEHfeeSexsbGULFkyw0QGgAsXLjBy5EiuvfZawsLCKFeuHM8++2yGv7+WL1/OzTffTMGCBYmOjqZq1aoMGzbMY30TkcCmmbYiAejrr7/mmmuuoVGjRjk6/+GHH2bOnDl07tyZp556ig0bNjB+/Hh27tzJwoUL3c79448/uP/+++nbty8PPvggr776Krfffjtvv/02w4YN4/HHHwdg/PjxdOnShd27d7t91clut9OuXTtuuukmJk6cyNKlSxk5ciQ2m40XX3wxV/2cPXs20dHRDB48mOjoaH744QdGjBhBYmIir7zyCgDPP/88Z86c4a+//mLy5MkA2c5s/O6777j11lu55pprGDVqFMnJybzxxhs0btyYLVu2ZFiaoUuXLlSsWJHx48ezZcsW3n33XYoXL87LL7+cq77kVv/+/SlZsiSjR49m/fr1zJgxg4IFC7J27VrKly/PuHHjWLJkCa+88go1atSgW7durvv27dvX9VW7AQMGsG/fPqZNm8bWrVtZs2YNISEhHDt2jDZt2lCsWDGee+45ChYsyP79+1mwYMFlY1u7di01atRwW9Khb9++lClThnHjxjFgwADq1atHiRIlsr3ODz/8wGeffUa/fv0oWrQocXFxLF++nPvuu4+WLVu6nuOdO3eyZs0aBg4cyC233MKAAQN4/fXXGTZsGNWqVQNw/TcrNpuNtm3bcvPNN/Pqq68SGRl52X5mZsKECVitVp5++mnOnDnDxIkTeeCBB9iwYYPbeddddx0RERGsWbOGu+6664oeS0RERMzjzJkznDhxwq3NYrFQpEgRAIYPH87XX39N79692b59OzExMXz77bfMnDmTMWPGULt2bcBZX916663UrVvXNdA7a9YsWrRowerVq6lfvz4Af//9N/Xr1+f06dP06dOH+Ph4Dh8+zPz580lKSiI0NDTHsU+ZMoX+/fu7LQeWXR2YXqfWq1eP8ePHc/ToUaZOncqaNWvYunWr27e27HY7bdu2pUGDBrz66qt89913vPbaa1SqVInHHnssxzFeia5du1KtWjUmTJjA4sWLeemllyhcuDDvvPMOLVq04OWXX+ajjz7i6aefpl69eq4JDQ6HgzvuuIOffvqJPn36UK1aNbZv387kyZP5/fff+eKLLwD49ddfue2226hVqxYvvvgiYWFh/PHHH5kOsIuISRkiElDOnDljAMadd96Zo/MTEhIMwHj44Yfd2p9++mkDMH744QdXW4UKFQzAWLt2ravt22+/NQAjIiLCOHDggKv9nXfeMQBjxYoVrrbu3bsbgNG/f39Xm8PhMDp06GCEhoYax48fd7UDxsiRI12/z5o1ywCMffv2udqSkpIy9Kdv375GZGSkkZKS4mrr0KGDUaFChQzn7tu3zwCMWbNmudrq1KljFC9e3Dh58qSr7ZdffjGsVqvRrVs3V9vIkSMNwOjVq5fbNe+66y6jSJEiGR4rO9WrVzeaNm2a6W0VKlQwunfv7vo9/Xlo27at4XA4XO0NGzY0LBaL8eijj7rabDabUbZsWbdrr1692gCMjz76yO1xli5d6ta+cOFCAzA2bdqUq74YhmGULVvWuPvuuzO0r1ixwgCMefPmubWnP5cXAwyr1Wr8+uuvbu0DBw40YmNjDZvNluXjz5s3L0PuZSc9L5977rkMt136/Kdr2rSp2/Oa3rdq1aoZFy5ccLVPnTrVAIzt27dnuEaVKlWMW2+9NUcxioiIiDml136Z/YSFhbmdu337diM0NNR4+OGHjVOnThllypQxbrzxRiMtLc0wDGfdXbly5Qx1ZFJSklGxYkWjdevWrrZu3boZVqs101ow/b6Z1XAXx3xx3Z5VvZteQ6XXbampqUbx4sWNGjVqGMnJya7zFi1aZADGiBEjXG3pNdyLL77ods3rr7/eqFu3bobHyk5UVFSmNZ9hZPy7JL3fffr0cbWl190Wi8WYMGGCq/3UqVNGRESE27U//PBDw2q1GqtXr3Z7nLffftsAjDVr1hiGYRiTJ082ALe/kURELqblEUQCTGJiIgAxMTE5On/JkiUADB482K39qaeeAsiw9u11111Hw4YNXb83aNAAgBYtWlC+fPkM7X/++WeGx+zXr5/rOP1r8KmpqXz33Xc5ijldRESE6/js2bOcOHGCJk2akJSUxK5du3J1LYAjR46QkJBAjx49KFy4sKu9Vq1atG7d2vVcXezRRx91+71JkyacPHnS9Trkld69e7t9Ha1BgwYYhkHv3r1dbUFBQdx4441ur8G8efMoUKAArVu35sSJE66funXrEh0dzYoVK4D/1p1dtGgRaWlpuYrt5MmTFCpU6Cp659S0aVOuu+46t7aCBQvm2bICnpiN0bNnT7eZJ02aNAEy/3dQqFChDDNmRERERDLz5ptvsnz5crefb775xu2cGjVqMHr0aN59913atm3LiRMnmDNnjmvvgISEBPbs2cP999/PyZMnXXXg+fPnadmyJatWrcLhcOBwOPjiiy+4/fbbM11HN7MlETzl559/5tixYzz++ONua8l26NCB+Pj4TPflyKwez6z28rSLN9ZNr7svrccLFixI1apVM9Tj1apVIz4+3q0eb9GiBUCGevzLL7/E4XDkeX9EJPBoeQSRABMbGws4BzFz4sCBA1itVq699lq39pIlS1KwYEEOHDjg1n7xwCxAgQIFAChXrlym7Zeuk2q1WjMs/F+lShUA17pXOfXrr78yfPhwfvjhhwyDpGfOnMnVtQBXX6tWrZrhtmrVqvHtt99m2KDq0ucjfbDy1KlTrtciL+Tmdbj4NdizZw9nzpyhePHimV43fbOwpk2bcvfddzN69GgmT55Ms2bN6NixI/fff3+ONkczLllH90pktkPy448/zmeffcatt95KmTJlaNOmDV26dKFdu3ZX9VjBwcGULVv2qq4B2efDpQzDyNM/ekRERCT/qF+/fo42InvmmWf49NNP2bhxI+PGjXP7AHzPnj0AdO/ePcv7nzlzhtTUVBITE6lRo8bVB55L2dXj8fHx/PTTT25t4eHhFCtWzK2tUKFCmdZenpZZPR4eHk7RokUztJ88edL1+549e9i5c2eGuNOl1+Ndu3bl3Xff5eGHH+a5556jZcuWdOrUic6dO7stPyci5qVBW5EAExsbS+nSpdmxY0eu7pfTwaOgoKBctXti8C4zp0+fpmnTpsTGxvLiiy9SqVIlwsPD2bJlC0OGDPHap9He7vflHjez9otjcTgcFC9enI8++ijT+6cXjxaLhfnz57N+/Xq+/vprvv32W3r16sVrr73G+vXrs10XuEiRIh4plC+eSZ2uePHiJCQk8O233/LNN9/wzTffMGvWLLp168acOXOu+LHCwsIyLX6z+ndht9szfa5zkw+nTp2icuXKuYxUREREJGt//vmna3B2+/btbrel18evvPIKderUyfT+0dHR/Pvvvzl6rOzqJG/Jqvby1WPnpBZ0OBzUrFmTSZMmZXpu+iSMiIgIVq1axYoVK1i8eDFLly5l7ty5tGjRgmXLlvm07yLiHzRoKxKAbrvtNmbMmMG6devcljLITIUKFXA4HOzZs8dts6ajR49y+vRpKlSo4NHYHA4Hf/75p2t2LcDvv/8OkGGTr+ysXLmSkydPsmDBAtei/gD79u3LcG5OB6TT+7p79+4Mt+3atYuiRYu6zbINRJUqVeK7776jcePGmQ6KXuqmm27ipptuYuzYsXz88cc88MADfPrpp25fB7tUfHx8pq+Dp4SGhnL77bdz++2343A4ePzxx3nnnXd44YUXuPbaaz06e7VQoUKcPn06Q/uBAwcyzBjPDZvNxqFDh7jjjjuuIjoRERGR/zgcDnr06EFsbCyDBg1i3LhxdO7cmU6dOgHOOhCckzxatWqV5XWKFStGbGzsZSeBpH+j6PTp026bg136TT24sno8fbmAdLt37/b43ya+UKlSJX755Rdatmx52efFarXSsmVLWrZsyaRJkxg3bhzPP/88K1asyPY1FBFz0Jx7kQD07LPPEhUVxcMPP8zRo0cz3L53716mTp0KQPv27QHnrq4XS//kt0OHDh6Pb9q0aa5jwzCYNm0aISEhtGzZMsfXSP9k+eJPrVNTU3nrrbcynBsVFZWj5RJKlSpFnTp1mDNnjttA3Y4dO1i2bJnruQpkXbp0wW63M2bMmAy32Ww2V79PnTqVYXZo+oyMCxcuZPsYDRs2ZMeOHZc970pc/NUycBaytWrVcosrfWA9s8HW3KpUqRLr168nNTXV1bZo0SIOHTp0Vdf97bffSElJoVGjRlcbooiIiAjgrN/Xrl3LjBkzGDNmDI0aNeKxxx5zraFft25dKlWqxKuvvsq5c+cy3P/48eOAs77q2LEjX3/9NT///HOG89JrxPRB4FWrVrluO3/+fKbffoqKispRbXbjjTdSvHhx3n77bbda8ptvvmHnzp158reJt3Xp0oXDhw8zc+bMDLclJydz/vx5gExnPOe0HhcRc9BMW5EAVKlSJT7++GO6du1KtWrV6NatGzVq1CA1NZW1a9cyb948evToAUDt2rXp3r07M2bMcC05sHHjRubMmUPHjh1p3ry5R2MLDw9n6dKldO/enQYNGvDNN9+wePFihg0bluW6Tplp1KgRhQoVonv37gwYMACLxcKHH36Y6dfQ69aty9y5cxk8eDD16tUjOjqa22+/PdPrvvLKK9x66600bNiQ3r17k5yczBtvvEGBAgUYNWrUlXbbbzRt2pS+ffsyfvx4EhISaNOmDSEhIezZs4d58+YxdepUOnfuzJw5c3jrrbe46667qFSpEmfPnmXmzJnExsZedvD6zjvvZMyYMfz444+0adPGo/E//PDD/Pvvv7Ro0YKyZcty4MAB3njjDerUqeOaKV6nTh2CgoJ4+eWXOXPmDGFhYbRo0SLLdXwv93jz58+nXbt2dOnShb179/J///d/rj9SrtTy5cuJjIykdevWV3UdERERMYdvvvkm0412GzVqxDXXXMPOnTt54YUX6NGjh6vOnT17NnXq1HHtCWC1Wnn33Xe59dZbqV69Oj179qRMmTIcPnyYFStWEBsby9dffw3AuHHjWLZsGU2bNqVPnz5Uq1aNI0eOMG/ePH766ScKFixImzZtKF++PL179+aZZ54hKCiI999/n2LFinHw4EG3OOvWrcv06dN56aWXuPbaaylevHiGmbQAISEhvPzyy/Ts2ZOmTZty3333cfToUaZOnUpcXBxPPvlkHjy73vXQQw/x2Wef8eijj7JixQoaN26M3W5n165dfPbZZ3z77bfceOONvPjii6xatYoOHTpQoUIFjh07xltvvUXZsmW5+eabfd0NEfEDGrQVCVB33HEH27Zt45VXXuHLL79k+vTphIWFUatWLV577TUeeeQR17nvvvsu11xzDbNnz2bhwoWULFmSoUOHMnLkSI/HFRQUxNKlS3nsscd45plniImJYeTIkYwYMSJX1ylSpAiLFi3iqaeeYvjw4RQqVIgHH3yQli1b0rZtW7dzH3/8cRISEpg1axaTJ0+mQoUKWQ7atmrViqVLl7piCgkJoWnTprz88suZbowViN5++23q1q3LO++8w7BhwwgODiYuLo4HH3yQxo0bA7gG7z/99FOOHj1KgQIFqF+/Ph999NFln4e6detSq1YtPvvsM48P2j744IPMmDGDt956i9OnT1OyZEm6du3KqFGjXGvSlixZkrfffpvx48fTu3dv7HY7K1asuKJB27Zt2/Laa68xadIkBg0axI033ujKu6sxb948OnXqRExMzFVdR0RERMwhq1p51qxZVKhQge7du1O0aFG3b89VrlyZ8ePHM3DgQD777DO6dOlCs2bNWLduHWPGjGHatGmcO3eOkiVL0qBBA/r27eu6b5kyZdiwYQMvvPACH330EYmJiZQpU4Zbb72VyMhIwDnAunDhQh5//HFeeOEFSpYsyaBBgyhUqBA9e/bMEP+BAweYOHEiZ8+epWnTppkO2gL06NGDyMhIJkyYwJAhQ4iKiuKuu+7i5ZdfdluGIVBZrVa++OILJk+ezAcffMDChQuJjIzkmmuuYeDAga5l5O644w7279/P+++/z4kTJyhatChNmzZl9OjRrk2IRcTcLEZe76YjIqbRo0cP5s+fn+nXsSR/+fDDD3niiSc4ePBgviiuPSkhIYEbbriBLVu2ZLkJiIiIiIiIiEh2tKatiIjk2gMPPED58uV58803fR2K35kwYQKdO3fWgK2IiIiIiIhcMS2PICIiuWa1Wi+747BZffrpp74OQURERERERAKcZtqKiIiIiIiIiIiI+BGtaSsiIiIiIiIiIiLiRzTTVkRERERERERERMSPaNBWRERERERERERExI8E9EZkDoeDv//+m5iYGCwWi6/DERERETE1wzA4e/YspUuXxmrV3ABvUD0sIiIi4j88WQ8H9KDt33//Tbly5XwdhoiIiIhc5NChQ5QtW9bXYZiC6mERERER/+OJejigB21jYmIA5xMRGxvr42hEREREzC0xMZFy5cq5ajTJe6qHRURERPyHJ+vhgB60Tf8KWGxsrN8XqXa7nT179lC5cmWCgoJ8HY6I1yj3xayU+2JGdrsdQF/T9yLVwyL+T7kvZqS8F7PyZD2sxca8xOFwsHv3bhwOh69DEfEq5b6YlXJfzEj5LtnR+6KYlXJfzEh5L2blyZzXoK2IiIiIiIiIiIiIH9GgrYiIiIiIiIiIiIgfCeg1bXPKbreTlpbm8xjKlStHamqqa30LMZeQkBBTruVjtVopX748Vqs+IxJzUe6LGSnf/ZfqYQkk+a1uVk0gZqS8F7PyZM5bDMMwPHY1L0tMTKRAgQKcOXMm040XDMPgn3/+4fTp094PTiQTBQsWpGTJktqgRURE8qXL1WbieaqHJb9S3SwiIoHIk/Vwvp5pm16gFi9enMjISJ/+D98wDFJSUggPD1fhYUKGYZCUlMSxY8cAKFWqlI8j8h673c62bduoVatWvpoxIXI5yn0xI82e9D+qhyXQ5Me6WTWBmJHyXszKk/Vwvh20tdvtrgK1SJEivg4Hh8PBhQsXCAsL09cDTCoiIgKAY8eOUbx4cdP8j8vhcHDw4EFq1Khhmj6LgHJfzEk7RPsX1cMSqPJb3ayaQMxIeS9m5cl6ON9WS+lrdkVGRvo4EpH/pOejr9eUExERkfxP9bAEMtXNIiJidvl20Dadvnol/kT5KCIiIt6m+kMCkfJWRETMLt8P2voLi8VCWFiYig8xHavVStWqVfU1SDEd5b6YkfJdsqN6WMxKNYGYkfJezMqTOa9/PV5isViIiIhQkXoZK1euxGKxaIfjfCQoKIj4+HitYySmo9wXM1K+S3ZUD3uOaubAoppAzEh5L2blyZzXoG0O2B12Vu5fySfbP2Hl/pXYHbnfCc4wDM6dO4dhGHkQYd7Yv38/FouFhISEDL+PGjUKi8WS7c+levToke35cXFxNGrUiCNHjlCgQAEv91byis1mY+3atdhsNl+HIuJVyn0xI+V7/uSJWhgCrx5Or33TfwoXLkzTpk1ZvXp1nj6uaub8RzWBmJHyXszKkzmvQdvLWLBzAXFT42g+pzn3L7if5nOaEzc1jgU7F+TqOoZhYLPZfFakpqamevR6Tz/9NEeOHHH9lC1blhdffNGt7VJTp07NcPusWbNcv2/atInQ0FBKliypGRj5iGEYHD9+PGD+QBPxFOW+mJHyPf/xVC0Mvq2Hr6YW/u677zhy5AirVq2idOnS3HbbbRw9etSD0blTzZz/qCYQM1Lei1l5Muc1aJuNBTsX0PmzzvyV+Jdb++HEw3T+rPMVFas5ceHCBQYMGEDx4sUJDw/n5ptvZtOmTa7bZ8+eTcGCBd3u88UXX7gVbaNGjaJOnTq8++67VKxYkfDwcADmz59PzZo1iYiIoEiRIrRq1Yrz58/nOsbo6GhKlizp+gkKCiImJsat7VIFChTIcHvBggVdvxcrVizDV73S+7po0SKqVq1KZGQknTt3JikpiTlz5hAXF0ehQoUYMGAAdvt/sz4uXLjA008/TZkyZYiKiqJBgwasXLky1/0UERERMSvVwk5FihShZMmS1KhRg2HDhpGYmMiGDRtyHcuHH35IXFwcBQoU4N577+Xs2bOZPp5qZhEREQEI9nUA3mQYBklpSTk61+6wM+CbARhkHCE3MLBgYeA3A2lVsRVB1uzXq4gMicxVnM8++yyff/45c+bMoUKFCkycOJG2bdvyxx9/ULhw4Rxf548//uDzzz9nwYIFBAUFceTIEe677z4mTpzIXXfdxdmzZ1m9erXff/KVlJTE66+/zqeffsrZs2fp1KkTd911FwULFmTJkiX8+eef3H333TRu3JiuXbsC0K9fP3777Tc+/fRTSpcuzcKFC2nXrh3bt2+ncuXKPu6RiIiIiPf5qhaG3NXD/loLJycn88EHHwAQGhqa4zgA9u7dyxdffMGiRYs4deoUXbp0YcKECYwdOzZX18mOamYREZH8xVSDtklpSUSPj/bItQwM/jr7FwVevvw6UueGniMyJDJHGy+cP3+e6dOnM3v2bG699VYAZs6cyfLly3nvvfd45plnchxjamoqH3zwAcWKFQNgy5Yt2Gw2OnXqRIUKFQCoWbNmjq/nK2lpaUyfPp1KlSoB0LlzZz788EOOHj1KdHQ01113Hc2bN2fFihV07dqVgwcPMmvWLA4ePEjp0qUB53IOS5cuZdasWYwbN86X3TGdoKAg6tSpowXoxXSU+2JGynf/5qtaGHJeD/tjLdyoUSOsVitJSUkYhkHdunVp2bJljuMAcDgczJ49m5iYGAAeeughvv/+e48O2qpm9l+qCcSMlPdiVp7MeVMN2vqSxWIhLCzssuft3buXtLQ0Gjdu7GoLCQmhfv367Ny5M1ePWaFCBVeRClC7dm1atmxJzZo1adu2LW3atKFz584UKlQoV9f1tsjISFfxCVCiRAni4uKIjo52azt27BgA27dvx263U6VKFbfrXLhwgSJFingnaHGxWq2uP4xEzES5L2ZktWrlLclaTuphf6yF586dS3x8PDt27ODZZ59l9uzZhISE5CqWuLg414AtQKlSpVy1q6eoZvZfqgnEjJT3YlaerIdNNWgbGRLJuaHncnTuqgOraP9x+8uet+T+JdxS4ZbLPq5hGJw9e5aYmJir3jDAarVm+BpXWlpahvOioqLcfg8KCmL58uWsXbuWZcuW8cYbb/D888+zYcMGKlaseFUx5aVLi2KLxZJpm8PhAODcuXMEBQWxefPmDJ9wXFy0infYbDZWrVrFLbfcQnCwqd5yxOSU+2I6djt2rYXp13xVC6c/tqfqYW/XwuXKlaNy5cpUrlwZm83GXXfdxY4dOwgLC8txLNnVrp6imtl/qSYQM1Leiyl5uB421XQIi8VCVGhUjn7aVGpD2diyWMi8oLRgoVxsOdpUanPZa1ksFgzDwOFwXHbNrEqVKhEaGsqaNWtcbWlpaWzatInrrrsOgGLFinH27Fm3TRMSEhJy/Bw0btyY0aNHs3XrVkJDQ1m4cGGO7hsorr/+eux2O8eOHePaa691+8lsgzTJW+l/oPn72skinqbcF1NZsADi4gi6805fRyLZ8FUtnJt62N9r4c6dOxMcHMxbb7111bH4mmpm71FNIGakvBfTyYN62FSDtrkRZA1iarupABmK1fTfp7SbkqONF3IjKiqKxx57jGeeeYalS5fy22+/8cgjj5CUlETv3r0BaNCgAZGRkQwbNoy9e/fy8ccfM3v27Mtee8OGDYwbN46ff/6ZgwcPsmDBAo4fP061atU82gdfq1KlCg888ADdunVjwYIF7Nu3j40bNzJ+/HgWL17s6/BERETylwULoHNn+OsvX0ciHqRaOHMWi4UBAwYwYcIEkpKSrjgWf6CaWURExEPyqB7WoG02OlXrxPwu8ykTW8atvWxsWeZ3mU+nap3y5HEnTJjA3XffzUMPPcQNN9zAH3/8wbfffutab6tw4cL83//9H0uWLKFmzZp88sknjBo16rLXjY2NZdWqVbRv354qVaowfPhwXnvtNdcmD5dK/+pUIH6VYdasWXTr1o2nnnqKqlWr0rFjRzZt2kT58uV9HZqIiEj+YbfDwIGgWTT5ktlr4ax0796dtLQ0pk2bdsWx+AvVzCIiIlcpD+thixHAc9UTExMpUKAAZ86cITY21u22lJQU9u3bR8WKFQkPD7+qx7E77Kw+uJojZ49QKqYUTco3yfWsAsMwsNlsBAcHX/Watt6yfv16GjZsyPHjxylatKivw8kXPJmXgcLhcHDixAmKFi2qDWrEVJT7YgorV0Lz5q5fE4ECkGltJnnDG/WwJ2phCMx6WHwnP9XNqgnEjJT3Yhp5WA8H3hRKHwiyBtEsrtlVXSOzjQD8lc1mY//+/bzyyivUrl1bA7ZyVaxWK8WLF/d1GCJep9wXUzhyxNcRiBd4ohaGwKqHRTxJNYGYkfJeTCMP62F93OElDoeD06dPe3yX2LywY8cOatWqxZEjR/jggw98HY4EuLS0NBYvXpzpTsoi+ZlyX0yhVClfRyABJJDqYRFPUk0gZqS8F9PIw3pYM20lgzp16pCUlOTrMCQfsdlsvg5BxCeU+5LvNWkCZctqEzIRkctQTSBmpLwXU8jDelgzbUVERETkygQFOTdeEBERERExo6AgGDMmTy6tQVsRERERuTKpqfB//+c8jojwbSwiIiIiIr6wcaPzvx5eu1+Dtl5isViIiYnRTrliOsHBwTRv3pzgYK3GIuai3BdTGD8efvkFihSBvXsxvv7a1xGJH1M9LGalmkDMSHkvprF9O7zzjvP4m288Wg/rX48XqUAVs4rQ7CsxKeW+5GvbtsFLLzmPp01zbsLQpIlvYxK/p3pYzEo1gZiR8l7yPcOAQYPA4YDOnaFlSzhzxmOX10xbLzEMg8TERAzD8HUoIl5ls9lYsmSJFqEX01HuS76WlgY9e4LNBh07QteugDYckeypHhazUk0gZqS8F1P48kv44QcIC4NXXgE8Ww9r0FZEREREcmfiRNiyBQoVgunTQbMnRURERMRMUlLgqaecx08/DXFxHn8IDdrmQ0lJSdx9993ExsZisVg4ffp0pm2e0qxZMwYNGpTtOXFxcUyZMsVjj+kJPXr0oGPHjr4OQ0REJLDs2AGjRzuPX38dSpb0bTwil1At7FmqmUVERDIxZQr8+adzibDnnsuTh9CgbQ7Y7bByJXzyifO/drvvYpk5cyZNmjShUKFCFCpUiFatWrExfZe6/5kzZw6rV69m7dq1HDlyhAIFCmRoO3XqFBaLhYSEBN905BKjRo2iTp06mf4eFxeHxWLJ8qdHjx4Zrpfd+RaLhVGjRjF16lRmz57tlf6JiIjkCzabc1mEtDS4/XZ44AFfRyReoFo4740aNcpVpwYFBVGuXDn69OnDv//+m6ePq5pZRETkChw5AmPHOo9ffhmio/PkYbQR2WUsWAADB8Jff/3XVrYsTJ0KnTrl/DoWi8X1yf7VWLlyJffddx+NGjUiPDycl19+mTZt2vDrr79SpkwZAPbu3Uu1atWoUaOG636Xtu3fv/+q4vCmTZs2Yf/fXwdr167l7rvvZvfu3cTGxgKZL25+5MgR1/HcuXMZMWIEu3fvdrVFR0cTnUf/qMRdcHAw7du3166hYjrKfcmXXnsNfv4ZChaEt9/OsCyC8j3/8VQtDJ6ph/NzLVy9enW+++477HY7O3fupFevXpw5c4a5c+fm2WOqZvYO1QRiRsp7ydeGDYNz56BBgwyTGDyZ85ppm40FC5ybv11cpAIcPuxsX7Agd9fL6aYL8+fPp2bNmkRERFCkSBFatWrF+fPnAfjoo494/PHHqVOnDvHx8bz77rs4HA6+//57wPn1rNdee41Vq1ZhsVho1qxZpm0VK1YE4Prrr3e1ZeXHH3+kfv36hIWFUapUKZ577rlsF1Y+duwYt99+OxEREVSsWJGPPvooh89Q5ooVK0bJkiUpWbIkhQsXBqB48eKutgIFCmS4T/pt6bdbLBa3tujo6Axf9WrWrBn9+/dn0KBBFCpUiBIlSjBz5kzOnz9Pz549iYmJ4dprr+Wbb75xe6wdO3Zw6623Eh0dTYkSJXjooYc4ceLEVfU5v0lOTvZ1CCI+odyXfGXnThgxwnk8ZQqULu3TcCTveboWhpzVw2athYODgylZsiRlypShVatW3HPPPSxfvtx1e2bLMHTs2NHtW2dxcXGMGzeOXr16ERMTQ/ny5ZkxY0aWj6ma2XtUE4gZKe8lX9q0CdK/gTJ1KljzbmjVVIO2hgHnz+fsJzERBgxw3iez64Bz1kFi4uWvZRjOAvXs2bOXLVSPHDnCfffdR69evdi5cycrV66kU6dOWd4vKSmJtLQ012DmggULeOSRR2jYsCFHjhxhwYIFmbalf43su+++c7Vl5vDhw7Rv35569erxyy+/MH36dN577z1eeumlLPvQo0cPDh06xIoVK5g/fz5vvfUWx44dy7bf/mLOnDkULVqUjRs30r9/fx577DHuueceGjVqxJYtW2jTpg0PPfQQSUlJAJw+fZoWLVpw/fXX8/PPP7N06VKOHj1Kly5dfNwT/2Gz2VixYoV2DRXTUe5LvmK3O5dFSE2FW2+Fbt0yPU357t98VQvnph5WLey0f/9+vv32W0JDQ3N1P4DXXnuNG2+8ka1bt/L444/z2GOPuc2e9QTVzLmjmkDMSHkv+ZJhOAsggIcecs60vYQnc95U89STkjy3zIRhOGcdZDLJM4Nz5yCTb/Bn6siRI9hsNjp16kSFChUAqFmzZpbnDxkyhNKlS9OqVSsAChcuTGRkJKGhoZS8aGOQS9sSExMBKFKkiNt5l3rrrbcoV64c06ZNw2KxEB8fz99//82QIUMYMWIE1ks+Ufj999/55ptv2LhxI/Xq1QPgvffeo1q1ajl7Anysdu3aDB8+HIChQ4cyYcIEihYtyiOPPALAiBEjmD59Otu2beOmm25i2rRpXH/99YwbN851jffff59y5crx+++/U6VKFZ/0Q0RExKMmT4YNGyA2FmbMyLAsggQGX9XCkPN62My18Pbt24mOjsZut5OSkgLApEmTLnu/S7Vv357HH38ccD4/kydPZsWKFVStWjXX18qKamYRETGlTz6BdesgKgrGj8/zhzPVoG0gqF27Ni1btqRmzZq0bduWNm3a0LlzZwoVKpTh3AkTJvDpp5+ycuVKwsPD8ySenTt30rBhQ7e1xxo3bsy5c+f466+/KF++fIbzg4ODqVu3rqstPj6eggUL5kl8nlarVi3XcVBQEEWKFHH7Q6FEiRIArtkSv/zyCytWrMh0ra+9e/eqABURkcC3eze88ILzeNIk54KmInnEzLVw1apV+eqrr0hJSeH//u//SEhIoH///rmO+eJ6Nn25A09/6001s4iImM758zBkiPN46FD431r6eclUyyNERjo/5c/Jz5IlObvmkiWXv1ZkZM5jDAoKYvny5XzzzTdcd911vPHGG1StWpV9+/a5nffqq68yYcIEli1b5lY0ydUJCQlx+91isbi1pRfsDocDgHPnznH77beTkJDg9rNnzx5uueUW7wXu57T4vJiVcl8Cnt0OvXpBSgq0aeM8loDlq1o4N/WwmWvh0NBQrr32WmrUqMGECRMICgpi9OjRrtutVmuGZSLS0tIyXCezeja9dvUU1cy5p5pAzEh5L/nKxInOrxnFxcHgwV55SFMN2loszhnMOflp08Y5kSSrb/9ZLFCunPO8y13LYnEWWQULFszwFarMr22hcePGjB49mq1btxIaGsrChQtdt0+cOJExY8awdOlSbrzxxit6LtLXx7Lb7dmeV61aNdatW+dWIK5Zs4aYmBjKZjLTJj4+HpvNxubNm11tu3fv5vTp01cUp7+74YYb+PXXX4mLi+Paa691+4mKivJ1eH4hJCSEDh06ZCjuRfI75b7kC2+8AWvXQkwMzJx52WURlO/+zVe1cG7rYdXCTsOHD+fVV1/l77//Bpyb8x45csR1u91uZ8eOHbm+ri+YvWZWTSBmpLyXfOXAAeegLcArr2S75pMnc95Ug7a5ERTk3AQOMhar6b9PmeI8LycMwyAtLe2yG5Ft2LCBcePG8fPPP3Pw4EEWLFjA8ePHXetgvfzyy7zwwgu8//77xMXF8c8///DPP/9w7ty5XPQOihcvTkREhGsTgDNnzmR63uOPP86hQ4fo378/u3bt4ssvv2TkyJEMHjw404K7atWqtGvXjr59+7JhwwY2b97Mww8/TEROF/UNME888QT//vsv9913H5s2bWLv3r18++239OzZ87J/BJiFw+Hg2LFjHp/hIeLvlPsS8P74A4YNcx6/8gpc8jXwzCjf8w9P18KQs3pYtfB/GjZsSK1atVzrwLZo0YLFixezePFidu3axWOPPRYwEyPMXjOrJhAzUt5LvjJkiPObZ02bwt13Z3uqJ3Neg7bZ6NQJ5s/PuExF2bLO9k6dcn4twzA4f/78ZQdtY2NjWbVqFe3bt6dKlSoMHz6c1157jVtvvRWA6dOnk5qaSufOnSlVqpTr59VXX81V34KDg3n99dd55513KF26NHfeeWem55UpU4YlS5awceNGateuzaOPPkrv3r1dGw9kZtasWZQuXZqmTZvSqVMn+vTpQ/HixbONx+FwBORXJ0qXLs2aNWuw2+20adOGmjVrMmjQoBzPqjYDu93OunXrTFGQi1xMuS8BzeGA3r0hORlatIA+fXJ0N+V7/uLJWhhyVg+btRbOypNPPsm7777LoUOH6NWrF927d6dbt240bdqUa665hubNm1/Rdb3N7DWzagIxI+W95BurV8PcuWC1Oj+xvsw3zzyZ8xbjcqOIfiwxMZECBQpw5swZYmNj3W5LSUlh3759VKxY8ao3JrDbna/RkSNQqhQ0aZK7WQXgHJRMTEwkNjbWFIVJbj366KP89ddfLFq0yNeh5ClP5mWgSEtLY8mSJbRv315fjRFTUe5LQJs2Dfr3d36vfccO59pdOXDy5EmKFi2aaW0mecMb9bAnamFQPSy5k5/qZtUEYkbKe8kX7HaoVw+2bnVOYnjnncvexZP1cOBNbfSBoCBo1szXUeRPZ8+eZevWrSxYsIBh6V/BFBEREd/588//dsadODHHA7aSf6kWFhEREVOaPds5YFugALz0ktcfXh9xe4nFYsFqtbp2UhWnESNG0LlzZ+666y4effRRX4cjecBisRATE6PcF9NR7ktASl8WISnJOUqXy/83K98lO6qHxaxUE4gZKe8l4CUm/re/w4gRUKxYju7myZzXTFsvsVgs+ppgJiZPnszkyZN9HYbkoeDgYFq0aOHrMES8TrkvAemdd2DlSoiMhHffda7dlQuBuD69eI/qYTEr1QRiRsp7CXgvvQTHjkGVKtCvX47v5sl6WDNtvcQwDC5cuHDZjchE8huHw8GBAwe0a6iYjnJfAs7+/fDss87j8eOhUqVcX0L5LtlRPSxmpZpAzEh5LwFtzx7npmMAkydDaGiO7+rJnNegrZcYhkFycrKKVDEdu91OQkKCdg0V01HuS0AxDHjkETh3zrnLVC5mE1xM+S7ZUT0sZqWaQMxIeS8B7amnIC0N2rWD9u1zdVdP5rwGbUVERETM7t134bvvIDwc3nsv18siiIiIiIjkC8uWwddfQ3AwTJrk01BUkYuIiIiY2cGDztkEAOPGQeXKvo1HRERERMQXbDZ48knncb9+UK2aT8PRoK2XWCwWgoODtXOimI7FYqFYsWLKfTEd5b4EBMOAPn3g7Flo2BAGDLiqyynfJTuqh8WsVBOIGSnvJSC9/Tb89hsUKQIjRlzRJTyZ89ri10ssFgvR0dG+DkPE64KDg2nUqJGvwxDxOuW+BITZs+HbbyEsDN5/H4KCrupyntwtV/If1cNiVqoJxIyU9xJwTp78b6D2pZegUKEruown62HNtPUSf914YeXKlVgsFk6fPu3rUDwuLi6OKem7/YnP2O12du3apQXoxXSU++L3Dh/+7+tfY8ZAfPxVX1L5Ltnxx3o4P9fCOaWaOe+pJhAzUt5LwBk5Ek6dgpo14eGHr/gy2ojM2+x2WLkSPvnE+d8reAEMw+DChQs5KlJ79OhBx44dM7QHclHZrFkzBg0alOH3/fv3Y7FYsv2ZPXu227XSn4fsflauXMmmTZvo06ePdzsqGTgcDnbv3o3D4fB1KCJepdwXv2YY0LcvnDkD9evD4MEeuazyPZ/yQC0MOa+H82stnF6nhoeHU6VKFcaPH5+nA9iqmf2HagIxI+W9BJQdO5xLIwBMmeLchOwKeTLnNWh7OQsWQFwcNG8O99/v/G9cnLNdrlq5cuU4cuSI6+epp56ievXqbm1du3Z1u0+jRo3cbu/SpQvt2rVza2vUqBHFihUjMjLSRz0TERHxYx9+CIsXQ2gozJp11csi5BdvvvkmcXFxhIeH06BBAzZu3Jjt+fPmzSM+Pp7w8HBq1qzJkiVL3G4fNWoU8fHxREVFUahQIVq1asWGDRvczomLi8swkDZhwgSP9+2KqRb2mEceeYQjR46we/duhg4dyogRI3g7/Q/EPKCaWUREJAcMAwYNcn4o3akTtGjh64hcNGibnQULoHNn+Osv9/bDh53tflCs/vTTTzRp0oSIiAjKlSvHgAEDOH/+vOv2Dz/8kBtvvJGYmBhKlizJ/fffz7Fjx7K83uzZsylYsCBffPEFlStXJjw8nLZt23Lo0CEA9u/fj9Vq5eeff3a735QpU6hQoUKuP1EICgqiZMmSrp/o6GiCg4Pd2iIiItzuExoamuH2sLAwt7bQ0NAMX/WyWCy888473HbbbURGRlKtWjXWrVvHH3/8QbNmzYiKiqJRo0bs3bvX7fG+/PJLbrjhBsLDw7nmmmsYPXo0NpstV/0UERHxG0eOwMCBzuNRo+C663wajr+YO3cugwcPZuTIkWzZsoXatWvTtm3bLOumtWvXct9999G7d2+2bt1Kx44d6dixIzt27HCdU6VKFaZNm8b27dv56aefiIuLo02bNhw/ftztWi+++KLbQFr//v3ztK85plrYo7VwZGQkJUuWpEKFCvTs2ZNatWqxfPly1+0Wi4UvvvjC7T4FCxZ0fess/RtqCxYsoHnz5kRGRlK7dm3WrVuX6eOpZhYREcmBr76C7793TmZ45RVfR+PGXIO2hgHnz+fsJzHRuYNyZl9ZSm8bONB53uWuZRhYLBZCQ0M9uovc3r17adeuHXfffTfbtm1j7ty5/PTTT/Tr1891TlpaGmPGjOGXX37hiy++YP/+/fTo0SPb6yYlJTF27Fg++OAD1qxZw+nTp7n33nsB52yQVq1aMWvWLLf7zJo1ix49emC1+ndKjRkzhm7dupGQkEB8fDz3338/ffv2ZejQofz8888YhuH2/K1evZpu3boxcOBAfvvtN9555x1mz57N2LFjfdiLwGK1Wilfvrzf54aIpyn3xS8ZBjz6KJw+DXXrwjPPePTygZzvkyZN4pFHHqFnz55cd911vP3220RGRvL+++9nev7UqVNp164dzzzzDNWqVWPMmDHccMMNTJs2zXXO/fffT6tWrbjmmmuoXr06kyZNIjExkW3btrldK31AMf0nKioqbzrpq1o4j+rhQK2FDcNg9erV7Nq1i9DQ0Fz3+/nnn+fpp58mISGBKlWqcN9993l8cFQ1s2epJhAzUt5LQLhwAZ56ynn81FNwzTVXfUmP5rwRwM6cOWMAxpkzZzLclpycbPz2229GcnLyf43nzhmGs8z07s+5c7nqV/fu3Y2goCAjKirK7Sc8PNwAjFOnThmGYRi9e/c2+vTp43bf1atXG1ar1b3fF9m0aZMBGGfPnjUMwzBWrFjhds1Zs2YZgLF+/XrXfXbu3GkAxoYNGwzDMIy5c+cahQoVMlJSUgzDMIzNmzcbFovF2LdvX5Z9atq0qTFw4MAsf083cuRIo3bt2tk8Oxl1797duPPOOzO0V6hQwZg8ebLrd8AYPny46/d169YZgPHee++52j755BMjPDzc9XvLli2NcePGuV33ww8/NEqVKpWrGNNlmpciIiLe8tFHztokJMQwtm/3+OWzq8382YULF4ygoCBj4cKFbu3dunUz7rjjjkzvU65cObc6wzAMY8SIEUatWrWyfIxXXnnFKFCggHH8+HFXe4UKFYwSJUoYhQsXNurUqWNMnDjRSEtLyzLWlJQU48yZM66fQ4cOGYBx4sQJIzU11UhNTTVsNpthGIZx7tw549dffzXOnz9v2O12w3H2rG9qYTAcZ88adrvd9eNwOAzDMAyHw5Gh/XK18MmTJw273W706tXLeOSRR9yu8+OPPxpWq9VISkrK9PobN2505ajdbje+//57AzD+/fdfw+FwGO+9954BGGvXrnXF89tvvxmAsW7dOsNutxuffPKJUahQISM5Odmw2+3Gpk2bDIvFYvz5559Z9qlp06ZGSEiIERUVZYSEhBiAER4ebqxZs8Z1PmB8/vnnbs9NgQIFjPfee8+w2+3Gn3/+aQDGzJkzXdfevn27ARg7d+50e8z0x704lvR8vrS9QoUKxqRJk1yPCRjPP/+86xrpNfPFj/vxxx8b4eHhruu0bNnSGDt2rFvsc+bMMUqVKnXZ1zur9uTkZOPXX381EhMTXbmdHnv675drNwzDsNvtbm3p/76yarfZbG7t6f+esmpPS0tza7fb7dm25zR29Ul9Up/UJ/Up7/tkHz/eWaeUKmWknjzpkT6dOHHCY/Xwla+sK7li/G+33IiIiBzNLmjevDnTp093a9uwYQMPPvig6/dffvmFbdu28dFHH7k9jsPhYN++fVSrVo3NmzczatQofvnlF06dOuX6ytbBgwe5LouvQwYHB1OvXj3X7/Hx8RQsWJCdO3dSv359OnbsyBNPPMHChQu59957mT17Ns2bNycuLi43T4lP1KpVy3VcokQJAGrWrOnWlpKSQmJiIrGxsfzyyy+sWbPGbZaA3W4nJSWFpKQkrf+VA3a7nW3btlGrVi2CtGaimIhyX/zOP/9A+tfuR4yAGjU8/hCBukP0iRMnsNvtrtogXYkSJdi1a1em9/nnn38yPf+ff/5xa1u0aBH33nsvSUlJlCpViuXLl1O0aFHX7QMGDOCGG26gcOHCrF27lqFDh3LkyBEmTZqU6eOOHz+e0aNHZ2hftmyZqy4pX748119/Pbt378YwDM6dO0dqaiphNhsRGe7pHTabjfOJia7frVYrsbGxpKamkpyc7GoP/t/GH02bNuWVi76iGBISwrZt23jwwQc5e/YsVquVrVu38uuvv/Lxxx+7zkuvhX///Xdq167NqlWrGD9+PDt27ODMmTOuWvi3334jPj6epKQkwLlpiGEYpKSkEBwcTNWqVV31YJUqVShQoABbt24lPj6eFi1aEBQUxPz587ntttuYOXMmTZo0oUiRIgBZ9qlr1648+eSTnD59mvHjx9OoUSMaNWpEUlISqampACQnJ5OSkkJERATnz593xZSYmOg6p1KlSiT+77mMjo4G4NixY5QsWdLtOY+JicFisbjOTUtLw2azYRgGhmFw9uxZV99TUlJcrxPAtddeS2JiIlar1ZXn11xzjetaBQoUICUlhWPHjhEWFkZCQgJr1qxh3LhxrsdPr5n/+ecfChUq5OrTxbOC05dsOHv2rNvSEumzzVNSUli1apXrPs2bNyciIiLD+tHt27cnOTmZFStWuD3vHTp04MSJE25LSMTExNCiRQsOHTpEQkKCq71YsWI0atSIPXv2sHv3bld7+r+nbdu2cfDgQVd71apViY+PZ+PGjW5LntSpU4cKFSqwatUq13Ocfv0GDRqwbNkyt+cgkPvUsGFDihcvrj6pT5n2qUGDBqxatcr1vpEf+pQfXycz9yns1Clav/QSAH/06sVvq1d7pE/ptYUnmGvQNjISzp3L2bmrVkH79pc/b8kSuOWWyz6uYRikpqYSHh6eo0HbqKgorr32Wre2vy5ZT+zcuXP07duXAQMGZLh/+fLlOX/+PG3btqVt27Z89NFHFCtWjIMHD9K2bVtX0XclQkND6datG7NmzaJTp058/PHHTJ069Yqv500hISGu4/TXIbO29KLx3LlzjB49mk6dOmW4Vnh4eF6Gmm84HA4OHjxIjRo1NHAlpqLcF79iGPDEE/Dvv3D99TBkSJ48jHaIzqh58+YkJCRw4sQJZs6cSZcuXdiwYQPFixcHYPDgwa5za9WqRWhoKH379mX8+PGEhYVluN7QoUPd7pOYmEi5cuVo06YNsbGxwH9fy6tatSoHDhwgOjraWYMCnDvnGrRLl74Bmlv76tVYO3S4bP8cixdDkyau6wBu106PJzgigpj/DRTGxMS4YgwNDc20FouJiaFOnTpu7SdPnnTdFhsbS3JyMn369GHgwIEZ+lShQgXOnz/P3XffTZs2bVy18KFDh2jXrh2hoaHExsa6BrqtVisWi8VV38XGxrra0n/Cw8Ndz3G3bt348MMP6dy5M59//jmTJ08mJiYm2z4VLlzY1acbb7yRKlWq0KRJE1q2bOn6GyE8PNwVQ1RUFDabzfW4p06dApzr3KbHkf5vzuFwuNoufdz09pCQEIKDg139uThf0h8zfYA5NjaW2NhYLBYL//77b4bHTd9zIv15PH/+PKNGjaJTp04Zcql48eKu/w9GRUVlyL301/TSdrvdTnh4OLfcckuG+Npf8ndacHAwMTExGdoBihYt6tae/pjlypWjdOnSGdorV65MpUqVXO3puVqrVi1qXPRhV3p7/fr13WJP7+stt9ziak9LS2P58uU4HA7atGmTIfZA7NPF7eqT+pRZnxwOB4mJibRu3dr1nhjofYL89zqZuU9BffpgPX8e6tWj4ogRxF00Vnc1fUqvVzzBXIO2FgvkdI2wNm2gbFnnRguZreVlsThvb9MmZzsuZ3aNq3TDDTfw22+/ZRjcTbd9+3ZOnjzJhAkTKFeuHECGTRMyY7PZ+Pnnn6lfvz4Au3fv5vTp01SrVs11zsMPP0yNGjV46623sNlsmQ5q5gc33HADu3fvzvI5FhERCQjz5jk3jQoOhlmz4KIBJXH+IRAUFMTRo0fd2o8ePZph9mK6kiVL5uj89A/ir732Wm666SYqV67Me++9x9ChQzO9boMGDbDZbOzfv5+qVatmuD0sLCzTwdyQkBC3gUJw/mFhsViwWq3/ra8WFYUFyGwKgVt727Y5qoWtbdtmqIWzvPb/BhjTBw0vPb5UVmvCpffnhhtuYOfOnVnWaTt27ODkyZO8/PLLrlp4y5YtbtdIf4z0OKxWKzabjS1btmSohatXr+46P70Wfvvtt7HZbHTu3PmyfUq/PjgHRQcOHMjTTz/N1q1bsVqtFCtWjKNHj7ru+8cff5CUlOSK8+Lrp1/n4ucoq+fr0rgu/j2zcy5+fjJ73i89P/21+P3336lcuXKmMVz8OFk9N1m1Z5bbl/6eXXtmfcmuPSgoKNMPW7NqT/9DPaftuYk9q3b1SX0C/+9TWlqa6zqZ/f8pEPsE+e91ApP2afNmmDPH2Th1KsFZrDF/JX3K6rYroRWhsxIUBOmzRy8tINJ/nzIlZwO2eWTIkCGsXbuWfv36kZCQwJ49e/jyyy9dmwKUL1+e0NBQ3njjDf7880+++uorxowZc9nrhoSE0L9/fzZs2MDmzZvp0aMHN910k6twBahWrRo33XQTQ4YM4b777nN92p7fjBgxgg8++IDRo0fz66+/snPnTj799FOGDx/u69BERERy5vhx5yxbgOefh9q1fRuPHwoNDaVu3bp8//33rjaHw8H3339Pw4YNM71Pw4YN3c4HWL58eZbnX3zdCxcuZHl7QkICVqvVNRPXZ1QL53kt3LdvX37//Xc+//xzAFq0aMG0adPYunUrP//8M48++qhH//DLS6qZRUQkYBiGczNVw4AHHoDL1G6+pEHb7HTqBPPnQ5ky7u1lyzrbczG71GKxEBYW5rHdcsE5rfvHH3/k999/p0mTJlx//fWMGDHCNTW8WLFizJ49m3nz5nHdddcxYcIEXn311cteNzIykiFDhnD//ffTuHFjoqOjmTt3bobzevfuTWpqKr169brsNR0OR5afUPiztm3bsmjRIpYtW0a9evW46aabmDx5MhUqVPB1aAHDarVStWpV7RoqpqPcF7/Rrx+cOAG1asGwYXn6UIGc74MHD2bmzJnMmTOHnTt38thjj3H+/Hl69uwJOL8Of/Hs2IEDB7J06VJee+01du3axahRo/j5559dA4bnz59n2LBhrF+/ngMHDrB582Z69erF4cOHueeeewBYt24dU6ZM4ZdffuHPP//ko48+4sknn+TBBx+kUKFC3n8SLuXBWhg8Xw8HUi2cmcKFC9OtWzdGjRqFw+Hgtddeo1y5cjRp0oT777+fp59+OmD2T1DNnD3VBGJGynvxW3Pnwpo1ziVUJ0zw+OU9mfMW49JFpwJIYmIiBQoU4MyZMxnWcEpJSWHfvn1UrFjx6tcetdth9Wo4cgRKlXKu25VP1yecPXs2gwYN4vTp05c9d8yYMcybN49t27Zd9tz4+Hgefvhhnn76aQ9EGbg8mpciIiKX8/nn0Lmzs27ZuBFuuCFPHy672iwQTJs2jVdeeYV//vmHOnXq8Prrr9OgQQMAmjVrRlxcHLNnz3adP2/ePIYPH87+/fupXLkyEydOdK23lpKSwv3338+GDRs4ceIERYoUoV69egwfPty14euWLVt4/PHH2bVrFxcuXKBixYo89NBDDB48ONMlEDLjlXpYtXCmclMLS+6pbhYREY9LSoL4eDh0CF58EV54weMP4cl6OPCmPvpCUBA0a3ZVlzAMg/PnzxMVFeXR2ba+cO7cOfbv38+0adN46X877WXl2LFjfPPNN+zevZuWLVt6KULxJzabjY0bN1K/fv2AnG0tcqWU++JzJ07A4487j597Ls8HbAG3XXQDUb9+/VwzZS+1cuXKDG333HOPa9bspcLDw1mwYEG2j3fDDTewfv36XMfpdR6ohSH/1MO5qYVFQDWBmJPyXvzSK684B2zLl4c8mlToyXpY/3K8xDAMbDYbhmEEdJEKzj9oPvnkEzp27HjZr4O1a9eOU6dO8frrr3P99dd7KULxJ4ZhcPz48Qw7SYvkd8p98bkBA+DYMahePU9mEWRG+S7ZyS/1cG5qYRFQTSDmpLwXv3PoELz8svP41Vchj/Zm8mTOa3ERcdOjR4/Lfh1s9uzZXLhwgblz52a6G9/FtmzZwr59++jfv78HoxQREZFsffEFfPIJWK0waxbk8Kv2Imbn6VpYRERE/MSQIZCc7FzmqXNnX0eTIxq0FREREclP/v0XHn3Uefzss/C/9VNFREREREzpp5+cExosFpg61fnfAKBBWy+xWCxEREQE9FfBRK5EUFAQderU0UwUMR3lvvjMoEFw9ChUqwYjR3r1oZXvkh3Vw2JWqgnEjJT34jccDmd9DNC7N+Tx0p2ezPl8v6atw+HwdQiAs0jN6S7Akn/5Sz56k9VqpUKFCr4OQ8TrlPviE4sWwYcfOpdFeP998PKO61ar5gP4I3+pP1QPS274S956gmoCMSPlvfiNOXNg82aIjQUvbCDqyXo43w7ahoaGYrVa+fvvvylWrBihoaE+/VQ/v+yWK1fGMAxSU1M5fvw4VquV0NBQX4fkNTabjVWrVnHLLbdo11AxFeW+eN3p09C3r/N48GC46Savh+DJ3XLl6qkelkCUH+tm1QRiRsp78QuJiTB0qPN4xAgoUSLPH9KT9XC+/ZdjtVqpWLEiR44c4e+///Z1OBiGQXJysr4SZnKRkZGUL1/eVDORDMPg7Nmz2jVUTEe5L143eDD8/TdUqQIvvuiTEJTv/kX1sASy/FQ3qyYQM1Lei18YN865bFjlytC/v1ce0pM5n28HbcE5u6B8+fLYbDbsdrtPY0lLS3N9yhQSEuLTWMQ3goKCCA4O1h8pIiLied98A7NmOTdVmDULIiJ8HZH4CdXDEohUN4uIyFX74w+YPNl5PGkSBOA3N/L1oC04184KCQnxeWEYFBSEzWYjPDzc57GIiIhIPnLmDPTp4zweOBAaNfJtPOJ3VA+LiIiI6Tz9NKSmQps20KGDr6O5IoH/XZMAERQURMOGDbVzopiOcl/MSrkvXvP00/DXX1CpEowd69NQlO+SHb0vilkp98WMlPfiU999B19+CUFBztm2XvzmhidzPt/PtPUXVquV4sWL+zoMEa9T7otZKffFK5Ytg3ffdR6//z5ERvo0nPyw9qTkHb0vilkp98WMlPfiMzYbDBrkPH7iCbjuOq8+vCfrYVXWXpKWlsbixYtJS0vzdSgiXqXcF7NS7kueO3sWHnnEedy/P9xyi2/jAeW7ZEvvi2JWyn0xI+W9+Mw778Cvv0LhwjBypNcf3pM5r0FbL7LZbL4OQcQnlPtiVsp9yVPPPgsHD0LFijB+vK+jEckRvS+KWSn3xYyU9+J1//4LI0Y4j8eMcQ7cBjC/GbSdMGECFouFQelTmEVEREQkcz/8AG+/7Tx+7z2IivJtPCIiIiIivjZqlHPgtkaN/zbqDWB+MWi7adMm3nnnHWrVquXrUERERET827lz0Lu38/ixx6B5c9/GIyIiIiLia7/9Bm+95TyeMgWCA38bL58P2p47d44HHniAmTNnUqhQIV+Hk2eCg4Np3rw5wfkgaURyQ7kvZqXclzzz3HOwfz9UqAAvv+zraNwo3yU7el8Us1Luixkp78WrDMO5+ZjdDh07QsuWPgvFkznv80HbJ554gg4dOtCqVStfh5LnIiIifB2CiE8o98WslPvicT/+CG++6Tx+912IifFtPCK5pPdFMSvlvpiR8l68ZtEiWL4cQkPh1Vd9HY3H+PQjj08//ZQtW7awadOmHJ1/4cIFLly44Po9MTERcO7Mlr47m9VqJSgoCLvdjsPhcJ2b3m6z2TAMw9UeFBSE1WrNsv3SXd/SR8wvXVA7q/aQkBAcDgcpKSksX76c1q1bExoaSnBwMA6HA7vd7jrXYrEQHBycZez+1qfMYlef1KdL+5SWlsby5ctp3749lwrUPl3cnl9eJ/XJ8326cOEC3377La1btyYkJCRf9Ck/vk4B1afz5wnu3RsLQJ8+OFq0wH7R9f2hT8nJyYhkxWazsWTJEtq3b09ISIivwxHxGuW+mJHyXrwmNRUGD3YeP/kkVKrk03A8uQGfzwZtDx06xMCBA1m+fDnh4eE5us/48eMZPXp0hvZly5YRGRkJQPny5bn++uvZtm0bBw8edJ1TtWpV4uPj2bhxI8ePH3e116lThwoVKrBq1SrOnj3ram/YsCHFixdn2bJlbk948+bNiYiIYMmSJW4xtG/fnuTkZFasWOFqCw4OpkOHDpw4cYJ169YBsHz5cmJiYmjRogWHDh0iISHBdX6xYsVo1KgRe/bsYffu3a52f+8ToD6pT9n2KV1+6lN+fJ3UJ8/2ae/evYDzfT+/9Ck/vk6B1Kca775Lpb17SSlenPBXXvHLPiUlJSEiIiIi4jWvvw5//AElS8Lzz/s6Go+yGBdPofCiL774grvuuougoCBXm91ux2KxYLVauXDhgtttkPlM23LlynHixAliY2MB/50ho5m26pNZ+6SZtuqTWfuUkpKimbbqk8f65Fi1iqAWLbAYBvbFiwlq394v+3Ty5ElKlSrFmTNnXLWZ5K3ExEQKFCgQEM95WlqaZl2JKSn3xYyU9+IVR49ClSqQmAjvvw89e/o6Ik6ePEnRokU9Upv5bKZty5Yt2b59u1tbz549iY+PZ8iQIRkGbAHCwsIICwvL0B4SEpLhTSAoKCjTa2S1IHBW7Vm9ueSm3Wq1utpDQkJcj2W1WrFaMy4rnFXs/tanzGJXn9Sn7NrVJ/XJbH1Kv9bF1wv0PuXH18nv+5SSgrVPH+cGC716EfS/D8H8sU/6o0xEREREvGb4cOeAbd260L27r6PxOJ/NtM1Ms2bNqFOnDlOmTMnR+YE0s8AwDGw2G8HBwVgsFl+HI+I1yn0xK+W+eMzTT8Nrr0Hp0vDrr1CwoK8jytKZM2coWLBgQNRm+YXqYRH/p9wXM1LeS57butU5WGsYsGYNNGrk64gAz9bDGadnSJ7R5hxiVsp9MSvlvly1detg0iTn8YwZfj1gK5ITel8Us1Luixkp7yXPGAYMHOj87333+c2Araf51aDtypUrczzLNtDYbDZWrFjh0V3kRAKBcl/MSrkvVy052bkul2FAt27QoYOvI7os5btkR++LYlbKfTEj5b3kqXnzYPVqiIiAl1/2dTRuPJnzfjVoKyIiIiL/M2oU7N4NpUpBPv1QW0REREQkV5KT4ZlnnMdDhkC5cr6NJw9p0FZERETE32zcCK++6jx++20oVMi38YiIiIiI+INXX4WDB52DtemDt/mUBm29KKtdmEXyO+W+mJVyX67IhQvOZREcDnjgAbjjDl9HJOIxel8Us1Luixkp78Xj/voLJkxwHr/yCkRG+jaePGYxDMPwdRBXKpB2yxURERHJkeefh3HjoEQJ+PVXKFLE1xHlmGoz79NzLiIiIqbx4IPw0Udw882wahVYLL6OKANP1maaaeslDoeDY8eO4XA4fB2KiFcp98WslPtyRTZv/m8zhbfeCqgBW0D5LtnS+6KYlXJfzEh5Lx63dq1zwNZice734IcDtuDZeliDtl5it9tZt24ddrvd16GIeJVyX8xKuS+5lprqXBbBboeuXaFTJ19HlGvKd8mO3hfFrJT7YkbKe/EohwMGDnQe9+oFdev6Np5seDLnNWgrIiIi4g/GjoXt26FYMXjjDV9HIyIiIiLiHz78EH7+GWJinDWzSWjQVkRERMTXEhKc69gCvPmmc+BWRERERMTszp6F555zHr/wgnPfB5PQoK2XWCwWYmJisPjpmhsieUW5L2al3JccS0uDHj3AZoO774Z77vF1RFdM+S7Z0fuimJVyX8xIeS8eM24c/PMPVKoEAwb4OprL8mTOWwzDMDx2NS/TbrkiIiIS8F58EUaOdG469uuvAT17QLWZ9+k5FxERkXzrzz+hWjXn3g9ffgl33OHriC7Lk7WZZtp6icPh4MCBA9o5UUxHuS9mpdyXHNm2DV56yXn8xhsBPWALnt0tV/IfvS+KWSn3xYyU9+IRTz/tHLBt3Rpuv93X0eSIJ3Neg7ZeYrfbSUhI0M6JYjrKfTEr5b5cVloa9Ozp/G/HjnDvvb6O6Kop3yU7el8Us1Luixkp7+Wq/fADLFwIQUEweTIEyFIbnsx5DdqKiIiI+MIrr8CWLVCoEEyfHjCFqIiIiIhInrLZYNAg5/Fjj0H16j4Nx1c0aCsiIiLibb/+CqNHO49ffx1KlvRtPCIiIiIi/mLmTNi+3Tm5YdQoX0fjMxq09RKLxUKxYsW0c6KYjnJfzEq5L1my2ZzLIqSmwm23wQMP+Doij1G+S3b0vihmpdwXM1LeyxU7dQpeeMF5/OKLzs16A4gnc95iGIbhsat5mXbLFRERkYAzcSIMGQIFCjhn3JYp4+uIPEa1mffpORcREZF8ZdAgmDrVuSRCQgIEB/s6olzxZG2mmbZeYrfb2bVrlxbhFtNR7otZKfclU7t2wYgRzuMpU/LVgC1oIzLJnt4XxayU+2JGynu5Ijt3wptvOo8nTw64AVvQRmQByeFwsHv3bhwOh69DEfEq5b6YlXJfMrDbncsiXLgAt94K3bv7OiKPU75LdvS+KGal3BczUt5LrhkGPPmkcymxO+6A1q19HdEV8WTOa9BWRERExBumTIH16yE2Ft55B7TGm4iIiIiI05Il8O23EBICr73m62j8ggZtRURERPLa77/D8OHO49deg3LlfBuPiIiIiIi/SE2FwYOdx4MGwbXX+jQcf6FBWy+xWq2UL18eq1VPuZiLcl/MSrkvLnY79OoFKSnOr3n17u3riPKM8l2yo/dFMSvlvpiR8l5yZdo05ySH4sX/m+gQoDyZ8xbDMAyPXc3LtFuuiIiI+L2pU50zBqKj4ddfoXx5X0eUZ1SbeZ+ecxEREQlox45B5cqQmAjvvhvwExw8WZvpIw8vsdvtbN26VTsniuko98WslPsCwB9/wNChzuNXX83XA7bg2d1yJf/R+6KYlXJfzEh5Lzn2wgvOAdsbboAePXwdzVXzZM5r0NZLHA4HBw8e1M6JYjrKfTEr5b7gcDhnCiQnQ4sW0KePryPKc8p3yY7eF8WslPtiRsp7yZGEBJg503k8dSoEBfk0HE/wZM5r0FZEREQkL7z1FqxaBVFRzq96WSy+jkhERERExD8YhnMJMcOArl3h5pt9HZHf0aCtiIiIiKft2wfPPec8fvllqFjRt/GIiIiIiPiTzz+HH3+E8HCYONHX0fglDdp6idVqpWrVqto5UUxHuS9mpdw3sfRlEc6fh6ZN4bHHfB2R1yjfJTt6XxSzUu6LGSnvJVvJyfD0087jIUPy1b4Pnsz5YI9dSbIVFBREfHy8r8MQ8TrlvpiVct/EZsyAFSsgIgLeew9M9MdKUD5Yh0zyjt4XxayU+2JGynvJ1qRJcOAAlC0Lzz7r62g8ypP1sHn+ivAxm83G2rVrsdlsvg5FxKuU+2JWyn2TOnAAnnnGeTx+PFSq5Nt4vEz5LtnR+6KYlXJfzEh5L1k6fBjGjXMeT5wIkZG+jcfDPJnzGrT1EsMwOH78OIZh+DoUEa9S7otZKfdNyDDgkUfg3DnnRgr9+/s6Iq9Tvkt29L4oZqXcFzNS3kuWnnsOkpKgUSO4915fR+Nxnsx5DdqKiIiIeMJ778Hy5c7NFN5/31TLIuQXb775JnFxcYSHh9OgQQM2btyY7fnz5s0jPj6e8PBwatasyZIlS9xuHzVqFPHx8URFRVGoUCFatWrFhg0b3M75999/eeCBB4iNjaVgwYL07t2bc+fOebxvIiIiIj63fj383/85j6dOBYvFt/H4Of01ISIiInK1Dh2CwYOdx2PHQuXKvo1Hcm3u3LkMHjyYkSNHsmXLFmrXrk3btm05duxYpuevXbuW++67j969e7N161Y6duxIx44d2bFjh+ucKlWqMG3aNLZv385PP/1EXFwcbdq04fjx465zHnjgAX799VeWL1/OokWLWLVqFX369Mnz/oqIiIh4lcMBAwc6j3v2hBtv9G08AcBiBPBc9cTERAoUKMCZM2eIjY31dTjZcjgcHDp0iHLlymn3RDEV5b6YlXLfRAwD2reHpUuhYUNYvRpMuiHX6dOnKVSoUEDUZpdq0KAB9erVY9q0aYDz33C5cuXo378/zz33XIbzu3btyvnz51m0aJGr7aabbqJOnTq8/fbbmT5Geu363Xff0bJlS3bu3Ml1113Hpk2buPF/f7gsXbqU9u3b89dff1G6dOnLxq16WMT/KffFjJT3ksGHH0K3bhAdDXv2QMmSvo4oT3iyHta/HC+xWq1UqFBBb1ZiOsp9MSvlvonMnu0csA0Lcy6LYNIBWyBg8z01NZXNmzfTqlUrV5vVaqVVq1asW7cu0/usW7fO7XyAtm3bZnl+amoqM2bMoECBAtSuXdt1jYIFC7oGbAFatWqF1WrNsIxCfqD3RTEr5b6YkfJe3Jw7B0OGOI+HD8+3A7bg2Xo42GNXkmzZbDZWrVrFLbfcQnCwnnYxD+W+mJVy3yQOH4Ynn3Qev/gixMf7Nh4fC9Qdok+cOIHdbqdEiRJu7SVKlGDXrl2Z3ueff/7J9Px//vnHrW3RokXce++9JCUlUapUKZYvX07RokVd1yhevLjb+cHBwRQuXDjDddJduHCBCxcuuH5PTEwEIC0tjbS0NMD5x0JQUBB2ux2Hw+E6N73dZrO5bZIRFBSE1WrNsj39uhfHCBlf76zaQ0JCcDgcXLhwgbVr19KoUSNCQkIIDg7G4XBgt9td51osFoKDg7OM3d/6lFns6pP6dGmfbDYba9eupWnTphk2qAnUPl3cnl9eJ/XJs30C+PHHH2nUqJHrmoHep/z4OnmrT9aXXiLoyBGMSpWcSyQYRsD3KavYk5OT8RT9FeklhmFw9uxZ7ZwopqPcF7NS7puAYUDfvnDmDNSv/9+atiamfM+oefPmJCQkcOLECWbOnEmXLl3YsGFDhsHanBo/fjyjR4/O0L5s2TIiIyMBKF++PNdffz3btm3j4MGDrnOqVq1KfHw8GzdudFtXt06dOlSoUIFVq1Zx9uxZV3vDhg0pXrw4y5Ytc/sDqnnz5kRERGTYeK19+/YkJyezYsUKV1twcDAdOnTgxIkTrlnIy5YtIyYmhhYtWnDo0CESEhJc5xcrVoxGjRqxZ88edu/e7Wr39z4B6pP6lG2fwPkemZ/6lB9fJ/XJc32qV68e586dY9myZfmmT/nxdfJGn1Z/8AFNJ00CYOM991AtNZWIoKCA7lN2r1NSUhKeojVtvSQtLY0lS5bQvn17QkJCfB2OiNco98WslPsmkL4uV2gobN0K113n64h87uTJkxQtWjQgarOLpaamEhkZyfz58+nYsaOrvXv37pw+fZovv/wyw33Kly/P4MGDGTRokKtt5MiRfPHFF/zyyy9ZPlblypXp1asXQ4cO5f333+epp57i1KlTrtttNhvh4eHMmzePu+66K8P9M5tpW65cOU6cOOF6zv111k9KSgrLly+ndevWhIaG+u0Mmdz0KVBm/ahPvu1TWloay5cvp3379lwqUPt0cXt+eZ3UJ8/2yTAMlixZQuvWrV21cKD3KT++Tt7ok6NTJ6wLF+Jo0QL7N98Q/L98COQ+Zfc6nTx5klKlSnmkHtZMWxEREZHcOnLkv91vR47UgG2ACw0NpW7dunz//feuQVuHw8H3339Pv379Mr1Pw4YN+f77790GbZcvX07Dhg2zfaz0ZQLSr3H69Gk2b95M3bp1Afjhhx9wOBw0aNAg0/uHhYURFhaWoT0kJCTDB0RBQUEEZbLGclZLtmTVntUHT7lpt1qtrvb0pRHS2zNb+y2r2P2tT5nFrj6pT9m1q0/qk1n6lD6oFQj/fzLz6wR53KeVK7EuXAhWK9YpU7CGhmYbe1btftWnLGJMb/fkhB0N2npJUFAQDRs2zDQ5RPIz5b6YlXI/HzMMeOwxOHUK6taFZ5/1dUR+I5DzffDgwXTv3p0bb7yR+vXrM2XKFM6fP0/Pnj0B6NatG2XKlGH8+PEADBw4kKZNm/Laa6/RoUMHPv30U37++WdmzJgBwPnz5xk7dix33HEHpUqV4sSJE7z55pscPnyYe+65B4Bq1arRrl07HnnkEd5++23S0tLo168f9957L6VLl/bNE5GH9L4oZqXcFzNS3gt2+3+THB59FGrW9G08XuLJnNegrZdYrdYrXrtMJJAp98WslPv52KefwpdfQkgIzJoF2mjOJbOZDoGia9euHD9+nBEjRvDPP/9Qp04dli5d6tps7ODBg279a9SoER9//DHDhw9n2LBhVK5cmS+++IIaNWoAzoJ9165dzJkzhxMnTlCkSBHq1avH6tWrqV69uus6H330Ef369aNly5ZYrVbuvvtuXn/9de923kv0vihmpdwXM1LeC+++C9u2QaFC8OKLvo7GazxZD2tNWy9JS0tj2bJltGnTRmsbiqko98WslPv51NGjzqUQ/v3XWXy+8IKvI/IrgbqmbSBTPSzi/5T7YkbKe5M7fRoqV4YTJ2DqVBgwwNcReY0n6+HAnQ4RgC5dTFnELJT7YlbK/XzGMODxx50DtnXqwHPP+ToikYCj90UxK+W+mJHy3sRefNE5YFutmnNZMbkiGrQVERERyYl582DBAudyCLNmOZdHEBERERGR/+zaBW+84TyePFk181XQoK2IiIjI5Rw/Dk884TweNsw501ZERERERNw99RTYbHDbbdC2ra+jCWha09ZLDMPg7NmzxMTEYLFYfB2OiNco98WslPv5TNeu8Nlnzl1vf/4ZQkN9HZFfOnPmDAULFgyI2iy/UD0s4v+U+2JGynuT+uYbaN/eObt2xw6oUsXXEXmdJ+thzbT1ooiICF+HIOITyn0xK+V+PrFggXPANigIZs/WgK3IVdD7opiVcl/MSHlvMmlp8OSTzuMBA0w5YOtpGrT1EpvNxpIlS7QQt5iOcl/MSrmfT5w8+d/mCc89Bzfc4Nt4/JzyXbKj90UxK+W+mJHy3oTefBN274ZixeCFF3wdjc94Muc1aCsiIiKSlQED4NgxqF7d1MWniIiIiEiWjh+HUaOcx+PGQYECPg0nv9CgrYiIiEhmvvwSPv4YrFaYNQvCwnwdkYiIiIiI/xkxAs6ccW7W27Onr6PJNzRoKyIiInKpf/+FRx91Hj/zDNSr59t4RERERET80S+/wIwZzuOpU537QIhHWAzDMHwdxJUKtN1ybTYbwcHB2jlRTEW5L2al3A9w3bvDBx9AfDxs3Qrh4b6OKCB4crdcyRnVwyL+T7kvZqS8NwnDgBYtYOVKuOce5+a9JufJelgzbb0oOTnZ1yGI+IRyX8xKuR+gFi92DtimL4ugAVsRj9H7opiVcl/MSHlvAgsXOgdsw8Nh4kRfR5PvaNDWS2w2GytWrNDOiWI6yn0xK+V+gDp9Gvr0cR4/+STcdJNPwwk0ynfJjt4XxayU+2JGynsTSEmBp55yHj/zDMTF+TQcf+HJnNegrYiIiEi6wYPh77+hShUYM8bX0YiIiIiI+KfJk2H/fihTBoYM8XU0+ZIGbUVEREQAli51LodgscD770NEhK8jEhERERHxP3//DWPHOo9ffhmionwbTz6lQVsvCg4O9nUIIj6h3BezUu4HkDNn4JFHnMcDB0Ljxr6NRySf0vuimJVyX8xIeZ+PDR0K589Dw4Zw//2+jibfshiGYfg6iCsVSLvlioiIiB/r0wdmzoRKlWDbNoiM9HVEAUm1mffpORcRERGv2rgRGjT477hePd/G42c8WZtppq2XOBwOjh07hsPh8HUoIl6l3BezUu4HkO++cw7YgnNZBA3YXjHlu2RH74tiVsp9MSPlfT7lcMCAAc7j7t01YJsJT+a8Bm29xG63s27dOux2u69DEfEq5b6YlXI/QJw9C717O4/79YNbbvFtPAFO+S7Z0fuimJVyX8xIeZ9PffwxbNjgXMN23DhfR+OXPJnzGrQVERER8xoyBA4ehIoVYfx4X0cjIiIiIuKfzp1z1s4Azz8PpUv7Nh4T0KCtiIiImNMPP8D06c7jd9+F6GjfxiMiIiIi4q9efhn+/ts52eHJJ30djSlo0NZLLBYLMTExWCwWX4ci4lXKfTEr5b6fO3cOHn7Yefzoo9CihW/jySeU75IdvS+KWSn3xYyU9/nM/v3w6qvO41dfhfBwn4bjzzyZ8xbDMAyPXc3LtFuuiIiIXJH+/WHaNChfHnbsgJgYX0eUL6g28z495yIiIpLnunSBefOgeXP4/nvQYHyWPFmbaaatlzgcDg4cOKCdE8V0lPtiVsp9P/bjj84BW3Aui6ABW49Rvkt29L4oZqXcFzNS3ucjP/7oHLC1WmHKFA3YXoYnc16Dtl5it9tJSEjQzoliOsp9MSvlvp9KSoLevZ3HjzwCrVv7Np58Rvku2dH7opiVcl/MSHmfT9jtMGiQ87hPH6hVy6fhBAJP5rwGbUVERMQ8nn8e9u6FsmXhlVd8HY2IiIiIiP96/31ISICCBeHFF30djelo0FZERETMYc0amDrVeTxzJhQo4Nt4RERERET81ZkzzgkPACNHQrFivo3HhDRo6yUWi4VixYpp50QxHeW+mJVy388kJ0OvXmAY0LMntGvn64jyJeW7ZEfvi2JWyn0xI+V9PjBmDBw/DvHx8MQTvo4mYHgy5y2GYRgeu5qXabdcERERyZFnnoFXX4XSpeHXX51f8RKPU23mfXrORURExON+/x2qVwebDb75RhMecsGTtZlm2nqJ3W5n165dWoRbTEe5L2al3Pcj69fDpEnO4xkzNGCbh5Tvkh29L4pZKffFjJT3Ae6pp5wDtu3ba8A2l7QRWQByOBzs3r0bh8Ph61BEvEq5L2al3PcTKSnO5RAcDnjoIejQwdcR5WvKd8mO3hfFrJT7YkbK+wC2dCksWgTBwf9NfJAc82TOa9BWRERE8q9Ro2DXLihZEqZM8XU0IiIiIiL+Ky0NnnzSedy/P1St6tt4TE6DtiIiIpI/bdoEr7ziPH77bShc2LfxiIiIiIj4s+nTnRMeihaFESN8HY3padDWS6xWK+XLl8dq1VMu5qLcF7NS7vvYhQvQo4dzWYT774c77/R1RKagfJfs6H1RzEq5L2akvA9AJ07AyJHO47FjtQ/EFfJkzlsMwzA8djUv0265IiIikqnhw53FZvHi8NtvUKSIryMyBdVm3qfnXERERDziiSfgrbegdm3YvBmCgnwdUUDyZG2mjzy8xG63s3XrVu2cKKaj3BezUu770ObNMGGC83j6dA3YepHyXbKj90UxK+W+mJHyPsBs3+5cTgyc+0BowPaKeTLnNWjrJQ6Hg4MHD2rnRDEd5b6YlXLfR1JToWdPsNuhSxfo1MnXEZmK8l2yo/dFMSvlvpiR8j6AGAYMGuRcVqxzZ2jWzNcRBTRP5rwGbUVERCT/GDvWOVOgaFGYNs3X0YiIiIiI+Lcvv4QffoCwsP828RW/oEFbERERyR8SEmDcOOfxm29CsWI+DUdERERExK+lpMBTTzmPn34a4uJ8Go6406Ctl1itVqpWraqdE8V0lPtiVsp9L0tLcy6LYLM5l0S45x5fR2RKynfJjt4XxayU+2JGyvsAMWUK/PknlCoFzz3n62jyBU/mvMUwDMNjV/My7ZYrIiIiAIwZAyNGQOHC8NtvUKKEryMyJdVm3qfnXERERK7IkSNQpQqcOwcffAAPPeTriPIFT9Zm+sjDS2w2G2vXrsVms/k6FBGvUu6LWSn3vWj7duegLcAbb2jA1oeU75IdvS+KWSn3xYyU9wFg2DDngG2DBvDAA76OJt/wZM5r0NZLDMPg+PHjBPDEZpErotwXs1Lue4nN5lwWIS0N7rwT7rvP1xGZmvJdsqP3RTEr5b6YkfLez23aBLNnO4+nTgUtY+Exnsx5vSoiIiISuF55BTZvhkKFYPp0sFh8HZGIiIiIiP8yDBg40Hn80EPOmbbilzRoKyIiIoHp119h1Cjn8dSpzg0UREREREQka598AuvWQVQUjB/v62gkGxq09ZKgoCDq1KlDUFCQr0MR8SrlvpiVcj+P2WzQqxekpkKHDvDgg76OSED5LtnS+6KYlXJfzEh576fOn4chQ5zHQ4dCmTK+jScf8mTOB3vsSpItq9VKhQoVfB2GiNcp98WslPt5bNIk2LgRChSAd97Rsgh+wqr10CQbel8Us1Luixkp7/3UxInw118QFweDB/s6mnzJk/WwKmsvsdls/PDDD9o5UUxHuS9mpdzPQ7t2wYgRzuPJkzVDwI8o3yU7el8Us1Luixkp7/3QgQPOQVtw7gsREeHbePIpT+a8Bm29xDAMzp49q50TxXSU+2JWyv08YrdDz55w4QK0awc9evg6IrmI8l2yo/dFMSvlvpiR8t4PDRkCKSnQtCncfbevo8m3PJnzGrQVERGRwDF1KqxfDzExMGOGlkUQEREREbmc1ath7lywWmHKFNXQAUKDtiIiIhIY9uyB5593Hr/2GpQr59t4RERERET8nd0OAwc6jx9+GOrU8Wk4knMatPWSoKAgGjZsqJ0TxXSU+2JWyn0PczigVy/nV7patXIWnOJ3lO+SHb0vilkp98WMlPd+ZPZs2LrVuYHvSy/5Opp8z5M5H+yxK0m2rFYrxYsX93UYIl6n3BezUu572LRp8NNPEB0N776rr3T5KU/uliv5j94XxayU+2JGyns/kZgIw4Y5j0eOhGLFfBuPCXiyHlZl7SVpaWksXryYtLQ0X4ci4lXKfTEr5b4H7d0Lzz3nPH7lFahQwbfxSJaU75IdvS+KWSn3xYyU937ipZfg2DGoUgWeeMLX0ZiCJ3Neg7ZeZLPZfB2CiE8o98WslPse4HBA796QnAzNm0OfPr6OSESugt4XxayU+2JGynsf27PHuekYwOTJEBrq03Ak9zRoKyIiIv5r+nT48UeIioL33nPueCsiIiIiItl7+mlIS4N27aB9e19HI1dAf/mIiIiIf9q3D4YMcR5PmAAVK/o2Hsn33nzzTeLi4ggPD6dBgwZs3Lgx2/PnzZtHfHw84eHh1KxZkyVLlrhuS0tLY8iQIdSsWZOoqChKly5Nt27d+Pvvv92uERcXh8VicfuZMGFCnvRPRERETGLZMvjqKwgOhkmTfB2NXCEN2npJcHAwzZs3JzhYe7+JuSj3xayU+1fJMODhh+H8ebjlFnj8cV9HJDkQyPk+d+5cBg8ezMiRI9myZQu1a9embdu2HDt2LNPz165dy3333Ufv3r3ZunUrHTt2pGPHjuzYsQOApKQktmzZwgsvvMCWLVtYsGABu3fv5o477shwrRdffJEjR464fvr375+nffUVvS+KWSn3xYyU9z5ks8GTTzqP+/WDatV8G4/JeDLnLYZhGB67Wi5Nnz6d6dOns3//fgCqV6/OiBEjuPXWW3N0/8TERAoUKMCZM2eIjY3Nw0ivnmEY2Gw2goODsWjHazER5b6YlXL/Kr3zDjz6KEREwLZtcO21vo5IcuDMmTMULFgwIGqzSzVo0IB69eoxbdo0ABwOB+XKlaN///48l74R3kW6du3K+fPnWbRokavtpptuok6dOrz99tuZPsamTZuoX78+Bw4coHz58oBzpu2gQYMYNGjQFcWteljE/yn3xYyU9z40bRr07w9FijjXtS1UyNcRmYon62GffuRRtmxZJkyYQOXKlTEMgzlz5nDnnXeydetWqlev7svQPM5ms7FkyRLat29PSEiIr8MR8RrlvpiVcv8qHDjgXIMLYPx4DdgGkEDdcCQ1NZXNmzczdOhQV5vVaqVVq1asW7cu0/usW7eOwYMHu7W1bduWL774IsvHOXPmDBaLhYIFC7q1T5gwgTFjxlC+fHnuv/9+nnzyySxnaVy4cIELFy64fk9MTAScyzGk71ZstVoJCgrCbrfjcDjc+hQUFITNZuPieRtBQUFYrdYs2y/dBTk9tktf76zaQ0JCcDgcpKSksHz5clq3bk1oaCjBwcE4HA7sdrvrXIvFQnBwcJax+1ufMotdfVKfLu1TWloay5cvp30ma0oGap8ubs8vr5P65Nk+GYbBkiVLaN26tasWDvQ+BcTrlJiIMWIEFsA+ejSO6Ggs/xs8D9g+BdjrlJycjKf4dND29ttvd/t97NixTJ8+nfXr1+e7QVsRERHJAcOARx6Bc+egcWPnLAGRPHbixAnsdjslSpRway9RogS7du3K9D7//PNPpuf/888/mZ6fkpLCkCFDuO+++9xmXQwYMIAbbriBwoULs3btWoYOHcqRI0eYlMX6c+PHj2f06NEZ2pctW0ZkZCQA5cuX5/rrr2fbtm0cPHjQdU7VqlWJj49n48aNHD9+3NVep04dKlSowKpVqzh79qyrvWHDhhQvXpxly5a5/QHVvHlzIiIi3NbwBWjfvj3JycmsWLHC1RYcHEyHDh04ceKEawB8+fLlxMTE0KJFCw4dOkRCQoLr/GLFitGoUSP27NnD7t27Xe3+3idAfVKfsu1TuvzUp/z4OqlPnutTvXr1AOd7fn7pU0C8Tt98g+XUKc5UqMCPpUphLFkS+H0KsNcpKSkJT/Hp8ggXs9vtzJs3j+7du7N161auu+66DOdkNrOgXLlynDhxwlX8+usnAppZoD6ZtU+aWaA+mbVPKSkpfPvtt67ZBfmhT954nSyzZhHcty9GeDgkJECVKgHfp5zEnl/6dPLkSUqVKhUQX9W/2N9//02ZMmVYu3YtDRs2dLU/++yz/Pjjj2zYsCHDfUJDQ5kzZw733Xefq+2tt95i9OjRHD161O3ctLQ07r77bv766y9WrlyZ7XPz/vvv07dvX86dO0dYWFiG21UP+1efAuXfpvqkelivk/qkmbYmeJ127CCkXj2w27F9+y1G8+aB3ycC73XyZD3s8xWht2/fTsOGDUlJSSE6OpqFCxdmOmALmlngr30C//6UQ33yfZ/S5ac+5cfXSX3ybJ/27t0L/De7ID/0Ka9fp1Uff8wt/9s04dd776V8qVJE2GwB3af8+Dp5a2aBNxUtWpSgoKAMg61Hjx6lZMmSmd6nZMmSOTo/LS2NLl26cODAAX744YfLFu8NGjTAZrOxf/9+qlatmuH2sLCwTAdzQ0JCMizFEhQURFBQUIZz0/+wyGl7Vku85KbdarW62kNCQlyPZbVasVoz7o2cVez+1qfMYlef1Kfs2tUn9cksfUof1AqE/z/li9fJMODZZ8Fuh06dCG7TJsO5AdeniwTS6+TJpfF8PtM2NTWVgwcPcubMGebPn8+7777Ljz/+mO9m2tpsNmw25zoiVqvVbz8RyE2fAuVTDvXJt30yDAPDMAgNDc1x7P7ep4vb88vrpD55vk82m43U1FSCg52bL+SHPuXp62Sx4GjfHuvSpTgaNMC+ciXB/xuYCtg+5cfX6TJ9OnPmDMWKFQu4mbbgHCytX78+b7zxBuDciKx8+fL069cvy43IkpKS+Prrr11tjRo1olatWq6NyNIHbPfs2cOKFSsoVqzYZeP46KOP6NatGydOnKBQDjYO0UZkIv5PuS9mpLz3sq++gjvvhNBQ2LkTrrnG1xGZlic3IvP5oO2lWrVqRaVKlXjnnXcue26gFalnz54lJiZGb1hiKsp9MSvlfi7Nng09e0JYGGzdCtWq+ToiuQKeLFK9be7cuXTv3p133nmH+vXrM2XKFD777DN27dpFiRIl6NatG2XKlGH8+PEArF27lqZNmzJhwgQ6dOjAp59+yrhx49iyZQs1atQgLS2Nzp07s2XLFhYtWuS2/m3hwoUJDQ1l3bp1bNiwgebNmxMTE8O6det48sknufXWW5kzZ06O4lY9LOL/lPtiRsp7L7pwAapXh717YehQGDfO1xGZmifr4YxziH3M4XC4zabNL2w2GytWrMgwg0Ykv1Pui1kp93Ph8GEYNMh5PHq0BmwDWCDne9euXXn11VcZMWIEderUISEhgaVLl7oGWw8ePMiRI0dc5zdq1IiPP/6YGTNmULt2bebPn88XX3xBjRo1ADh8+DBfffUVf/31F3Xq1KFUqVKun7Vr1wLOpQ4+/fRTmjZtSvXq1Rk7dixPPvkkM2bM8P4T4AV6XxSzUu6LGSnvvWjqVOeAbalSzkFb8SlP5rxP17QdOnQot956K+XLl+fs2bN8/PHHrFy5km+//daXYYmIiIi3GAY8+iicOQP16sFTT/k6IjGxfv360a9fv0xvW7lyZYa2e+65h3vuuSfT8+Pi4rjcF9puuOEG1q9fn+s4RURERAD45x946SXn8fjxEBPj23jEo3w6aHvs2DG6devGkSNHKFCgALVq1XLttC0iIiIm8NFHsGiRc/2tWbMgi8X+RURERETkEs8/D2fPOic/PPSQr6MRD/PpX0bvvfeeLx/e67LadU4kv1Pui1kp9y/jyBEYMMB5PHKkcy0uEcnX9L4oZqXcFzNS3uexzZudkx7AuUSC1e9WQJWr5HcbkeVGIG28ICIiIhcxDOjUCb74Am64Adavh5AQX0clV0m1mffpORcRETEhw4AmTWDNGnjgAfi///N1RPI/nqzNNAzvJQ6Hg2PHjuFwOHwdiohXKffl/9m77/ioqvSP458poYcgvRcVCYgQlLKggiCKwv5WFgu67oplbStKsRdEbIgFwb66a9ldWV0RcVVEUGkKK9JEpQhI7wFJQhKSKff3xzWTDJmEJNyZOzP3+3698mLmzJ2Z52Qebs48c+Ycp1LuH8O775oF25QUc4aACrZJQfku5dF5UZxKuS9OpLyPsnffNQu2tWrBE0/YHY2UYGXOq2gbI4FAgCVLlhAIBOwORSSmlPviVMr9cuzdC0WbPT3wAHTpYm88Yhnlu5RH50VxKuW+OJHyPory8uCuu8zL99wDLVvaG4+EsTLnVbQVERGR2Bo5Eg4cgK5d4d577Y5GRERERCRxPPUUbN8OrVvDHXfYHY1EkYq2IiIiEjvvvQfTp4PXC2++qWURREREREQqavt2mDTJvPz001Czpr3xSFSpaBsjLpeL1NRUXC6X3aGIxJRyX5xKuR/B/v1wyy3m5XvvhYwMW8MR6ynfpTw6L4pTKffFiZT3UXL33ZCfb25CdskldkcjEViZ8y7DMAzLHi3GtFuuiIhIArn8cnPThNNOg2XLoFo1uyMSi2lsFnv6nYuIiDjE11/DWWeBywXLl0O3bnZHJBFYOTbTTNsYCQaDbN26VTsniuMo98WplPtHmTHDLNh6PPDGGyrYJinlu5RH50VxKuW+OJHy3mLBIIwaZV6+7joVbOOYlTmvom2MBAIBVq1apZ0TxXGU++JUyv0SDhyAm282L999N5xxhr3xSNQo36U8Oi+KUyn3xYmU9xZ76y1zdm3duvDoo3ZHI+WwMudVtBUREZHoGjUK9u2DTp3gwQftjkZEREREJHFkZ5v7QYA5lm7SxN54JGZUtBUREZHo+e9/4e23we02l0WoXt3uiEREREREEsfjj8PevdC+Pdx6q93RSAypaBsjLpeLRo0aaedEcRzlvjiVch/45Re46Sbz8h13QM+e9sYjUefofJdj0nlRnEq5L06kvLfIxo3w7LPm5cmTtS9EArAy512GYRiWPVqMabdcERGROHb11eb6W+npsHIl1Khhd0QSZRqbxZ5+5yIiIkls6FD48EMYNAg+/RRUBI97Vo7NNNM2RgKBAOvWrdMi3OI4yn1xKsfn/iefmAVblwtef10FW4dwbL5LhTj+vCiOpdwXJ1LeW+Dzz82CrcdjzrJVwTYhaCOyBBQMBlm/fj3BYNDuUERiSrkvTuXo3D90CG680bw8Zgz07m1rOBI7jsx3qTBHnxfF0ZT74kTK++Pk98Po0eblW24xN/SVhGBlzqtoKyIiIta6/XbYudPcLOGRR+yORkREREQksbz6Kvz4I9SvD+PH2x2N2ERFWxEREbHOZ5+ZyyEULYtQq5bdEYmIiIiIJI6DB2HcOPPyI4+YhVtxJBVtY8TtdtO6dWvcbv3KxVmU++JUjsz97Gz485/Ny7fdBmedZW88EnOOynepNEeeF0VQ7oszKe+Pw0MPmYXbzp3hhhvsjkYqycqcdxmGYVj2aDGm3XJFRETiyI03ml/lOvFEWL0aate2OyKJMY3NYk+/cxERkSSyZg106QKBgLkR2bnn2h2RVJKVYzNvVe60efNmFi1axNatW8nLy6NRo0Z069aN3r17U0O7Q0cUCARYvXo1Xbp0wePx2B2OSMwo98WpHJf7n39uFmzBXBZBBVtHivUO0RqTJhbHnRdFfqXcFydS3leBYZibjwUCMHSoCrYJysrxcKWKtm+//TZTp05l2bJlNGnShObNm1OzZk0OHjzIpk2bqFGjBldeeSV33303bdq0sSzIZBAMBtm2bRudO3fWCUscRbkvTuWo3M/JKV4W4ZZboF8/e+MR28Rqh2iNSROTo86LIiUo98WJlPdV8PHHMHcuVKsGTz9tdzRSRVaOhytctO3WrRvVqlXj6quv5v3336dVq1ZhtxcUFLBkyRLeeecdunfvzksvvcSll15qWaAiIiISp+6+G7ZuhbZt4Ykn7I5GkpzGpCIiIpJ0Cgth7Fjz8pgxcNJJ9sYjcaHCRdsnnniCQYMGlXl79erVOeecczjnnHN47LHH2LJlixXxiYiISDybNw9eftm8/Pe/Q5069sYjSU9jUhEREUk6zz0HGzdC06Zw//12RyNxosJF2/IGx0dr0KABDRo0qFJAycrtdtOhQwftnCiOo9wXp3JE7ufmwnXXmZdvugkGDLA3HrFdLPJdY9LE5YjzokgEyn1xIuV9JezdC488Yl5+/HFITbU3HjkuVuZ8lR5pxYoVfP/996HrH374IUOHDuW+++6jsLDQsuCSicfjIT09XWu5iOMo98WpHJH7994LmzdD69bw5JN2RyNxINb5rjFpYnHEeVEkAuW+OJHyvhIeeACys+GMM2DECLujkeNkZc5XqWh744038tNPPwHw888/c/nll1OrVi3ee+897rrrLsuCSyZ+v5/Fixfj9/vtDkUkppT74lRJn/sLF8Lzz5uXX3tNMwIEIOb5rjFpYkn686JIGZT74kTK+wpaudJcYgzMJRI0MznhWZnzVcqGn376iYyMDADee+89+vbty7Rp03jzzTd5//33LQsumRiGwf79+zEMw+5QRGJKuS9OldS5n5cH115rXv7zn+H88+2NR+JGrPNdY9LEktTnRZFyKPfFiZT3FWAYMGqU+e8VV0CfPnZHJBawMuerVLQ1DINgMAjA559/zuDBgwFo1aoVmZmZlgUnIiIiceiBB2DTJmjZEp5+2u5oxME0JhUREZGE9d57sGgR1KwJkybZHY3EoSoVbbt3786jjz7KP//5TxYsWMCQIUMA2Lx5M02aNLE0QBEREYkjX38NU6aYl199FdLSbA1HnE1jUhEREUlI+flw553m5bvvhlat7I1H4lKVirZTpkxhxYoVjBw5kvvvv5+TTz4ZgOnTp9NH07kj8ng8ZGRkaBFucRzlvjhVUuZ+fr65LIJhwNVXw4UX2h2RxJlY57vGpIklKc+LIhWg3BcnUt4fw9NPw7ZtZrG2qHgrScHKnHcZFi62cOTIETweDykpKVY9ZLmys7NJS0sjKyuLunXrxuQ5RUREHOvOO80BZvPm8MMPcMIJdkckcSZexmaxHpPaKV5+5yIiIlJBO3ZAhw7mPhHvvAPDh9sdkVjIyrGZpdvS1ahRwxGD46rw+/18+eWX2jlRHEe5L06VdLn/v//B5Mnm5b/+VQVbiShe8l1j0viUdOdFkQpS7osTKe/Lcc89ZsH2rLPgssvsjkYsZmXOeyt64AknnIDL5arQsQcPHqxyQMnKMAxycnK0c6I4jnJfnCqpcv/IEbjmGggG4U9/gt/+1u6IJE7FIt81Jk1cSXVeFKkE5b44kfK+DIsXw9tvg8tl7hNRwTGNJA4rc77CRdspRZuOAAcOHODRRx9l0KBB9O7dG4AlS5bw2WefMW7cOMuCExERkTgwYQKsWwdNmxZvQiZiE41JRUREJCEFgzBqlHn52mvhjDPsjUfiXoWLtiNGjAhdvvjii3n44YcZOXJkqO22227jhRde4PPPP2fMmDHWRikiIiL2+PZbePJJ8/Irr0D9+vbGI46nMamIiIgkpH/+E5Ytg9RUeOwxu6ORBFCljcjq1KnDqlWrQjv0Ftm4cSMZGRkcPnzYsgDLk0gbLwSDQTIzM2nYsCFut6VLCYvENeW+OFVS5H5BgTkD4Mcf4YorYNo0uyOSOHfo0CFOOOGEmI3N4mVMaieNh0Xin3JfnEh5f5ScHDjlFNizx5wQceeddkckUWLleLhK/3MaNGjAhx9+WKr9ww8/pEGDBscVULJyu900btxYJytxHOW+OFVS5P6jj5oF28aN4bnn7I5GEkCs811j0sSSFOdFkSpQ7osTKe+PMnGiWbA96SS47Ta7o5EosjLnK7w8QkkTJkzgz3/+M/Pnz6dXr14AfPPNN8yePZvXXnvNsuCSic/nY86cOZx//vnazVgcRbkvTpXwub9ihTm4BHjpJWjY0N54JCH4fL6YPp/GpIkl4c+LIlWk3BcnUt6X8PPP8Mwz5uXJk6F6dXvjkaiycjxcpaLt1VdfTceOHXnuueeYMWMGAB07duSrr74KDZilNL/fb3cIIrZQ7otTJWzuFxbCNddAIACXXgoXX2x3RCIRaUyaeBL2vChynJT74kTK+1/dcYc5vj7vPPi//7M7GkkgVSraAvTq1Yu3337bylhEREQkHjz+OKxebc6ufeEFu6MRKZfGpCIiIhK3vvwSPvgAPB549llwueyOSBJIlYu2wWCQjRs3sm/fPoLBYNhtffv2Pe7ARERExAbffVe8m+0LL5jr2YrEMY1JRUREJC75/TB6tHn55pvh1FNtDUcSj8swDKOyd/rf//7HH/7wB7Zu3crRd3e5XAQCAcsCLE8i7ZZrGAY5OTmkpqbi0icr4iDKfXGqhMx9nw969oRVq2DYMJg+XbMBpFKysrKoV69ezMZm8TImtZPGwyLxT7kvTqS8B15+Gf7yF6hfHzZsMP+VpGfleLhKM21vuukmunfvzieffEKzZs2c+x+wkmrWrGl3CCK2UO6LUyVc7k+aZBZs69c3Nx/T33eJcxqTJp6EOy+KWES5L07k6Lz/5RcYN868PGGCCrZSJe6q3GnDhg08/vjjdOzYkXr16pGWlhb2I6X5/X5mzZqlhbjFcZT74lQJl/s//AAPP2xefv55aNLE3ngkIcU63zUmTSwJd14UsYhyX5zI8Xk/YQIcOGAuiXDTTXZHIzFkZc5XqWjbq1cvNm7caFkQIiIiYiO/H66+2lwe4Xe/gyuusDsikQrRmFRERETiztq18OKL5uVnnwVvlbeTEoerUubceuut3H777ezZs4fTTjuNlJSUsNu7dOliSXAiIiISA08/DcuXQ7168MorWhZBEobGpCIiIhJXDAPGjDEnRfzud3DeeXZHJAmsSkXbiy++GIBrr7021OZyuTAMwzGbPoiIiCSFNWtg/Hjz8tSp0KyZvfGIVILGpCIiIhJXZs2Czz6DlBR45hm7o5EE5zKO3mq3ArZu3Vru7W3atKlyQJWRaLvl+v1+vF6vNskQR1Hui1MlRO77/XDmmbB0KQweDB9/rFm2clys3C23IuJlTGonjYdF4p9yX5zIkXlfWAinnQY//QR33glPPml3RGIDK8fDVZpp64QBcDTk5+eTmppqdxgiMafcF6eK+9x/9lmzYJuWBq++qoKtJByNSRNP3J8XRaJEuS9O5Li8f+EFs2DbuDE88IDd0UgSqNJGZACbNm3i1ltvZeDAgQwcOJDbbruNTZs2WRlbUvH7/cybN8+5OyeKYyn3xaniPvfXrYNx48zLkydDixb2xiNJwY5815g0ccT9eVEkSpT74kSOy/t9+2DCBPPy449DnH/7RaLHypyvUtH2s88+o1OnTixdupQuXbrQpUsXvvnmG0499VTmzp1rWXAiIiISBYEAXHstFBTAoEFwzTV2RyRSJRqTioiISFwYNw6ys+H00+Hqq+2ORpJElZZHuOeeexgzZgxPPPFEqfa7776b87Q7noiISPx67jlYsgRSU+G117QsgiQsjUlFRETEdqtWmWNqMDf29XhsDUeSR5Vm2q5du5brrruuVPu1117LmjVrjjuoZOX1VqlGLpLwlPviVHGZ+xs2wH33mZefeQZatbI3HpHjYPWY9MUXX6Rt27bUqFGDXr16sXTp0nKPf++990hPT6dGjRqcdtppzJo1K3Sbz+fj7rvv5rTTTqN27do0b96cq666il27doU9xsGDB7nyyiupW7cu9erV47rrruPw4cOVjj1RxOV5USQGlPviRI7Ie8OA0aPNf4cPh7POsjsiSSJVKto2atSIVatWlWpftWoVjRs3Pt6YklJKSgpDhgwhJSXF7lBEYkq5L04Vl7kfDMJ118GRIzBwIPz5z3ZHJEkm1vlu5Zj03XffZezYsYwfP54VK1bQtWtXBg0axL59+yIev3jxYq644gquu+46Vq5cydChQxk6dCg//PADAHl5eaxYsYJx48axYsUKZsyYwfr16/nd734X9jhXXnklP/74I3PnzuXjjz9m4cKF3HDDDZWKPVHE5XlRJAaU++JEjsn799+HBQugRg148km7o5E4YGXOuwzDMCp7p4cffphnn32We+65hz59+gDw9ddfM2nSJMaOHcu4oo1Noiw7O5u0tDSysrKoG+eLPAeDQTIzM2nYsCFud5X3fxNJOMp9caq4zP3nn4fbboM6deD776FtW7sjkiRz6NAhTjjhhJiNzawck/bq1YsePXrwwgsvAOb/4VatWnHrrbdyzz33lDp++PDh5Obm8vHHH4fafvOb35CRkcErr7wS8Tm+/fZbevbsydatW2ndujVr166lU6dOfPvtt3Tv3h2A2bNnM3jwYHbs2EHz5s2PGbfGwyLxT7kvTuSIvM/Ph44dYetWGD8eHnrI7ogkDlg5Hq7S/5xx48bx4IMP8vzzz9OvXz/69evHCy+8wEMPPcQDDzxwXAElq0AgwJIlSwgEAnaHIhJTyn1xqrjL/U2boKjw9OSTKthKVMQ6360akxYWFrJ8+XIGDhwYanO73QwcOJAlS5ZEvM+SJUvCjgcYNGhQmccDZGVl4XK5qFevXugx6tWrFyrYAgwcOBC3280333xT4fgTRdydF0ViRLkvTuSIvJ882SzYtmwJd91ldzQSJ6zM+SotMOJyuRgzZgxjxowhJycHgNTUVMuCEhEREQsFg+ZSCHl5cM45cOONdkckYgmrxqSZmZkEAgGaNGkS1t6kSRPWrVsX8T579uyJePyePXsiHn/kyBHuvvturrjiitCsiz179pRaxsHr9VK/fv0yH6egoICCgoLQ9ezsbMBcQ9fn8wFmwdnj8RAIBAgGg6Fji9r9fj8lv2zn8Xhwu91lthc9bskYAfx+f4XaU1JSCAaDocfx+Xy4XC68Xi/BYDDszU1Re1mxx1ufIsWuPqlPR/epZD+SpU8l29Un9SlSe5GSz5vofQp7nXbuxPv447gAnnySQPXqBCP0NaH6dFTsSfE62dCno287HlUq2m7evBm/30/79u3DBsYbNmwgJSWFtpq9IyIiEj9eeQXmz4dateDvf4dk/YqaOE6ijEl9Ph+XXXYZhmHw8ssvH9djTZw4kQkTJpRqnzNnDrVq1QKgdevWdOvWjdWrV7Nt27bQMR06dCA9PZ2lS5eyf//+UHtGRgZt2rRh4cKFoeI3QO/evWncuDFz5swJewPVv39/atasGbbxGsDgwYPJz89n3rx5oTav18uQIUPIzMwMzUKeO3cuqampDBgwgO3bt4etS9yoUSP69OnDhg0bWL9+fag93vsEqE/qU7l9KpJMfUrG10l9sq5PPXr0AMxzfrL0qeTrdPqUKbTKy+Nw167UufxyVq9alfB9KpJMr5MdfcrLy8MqVVrTtl+/flx77bWMGDEirP1f//oXf/vb35g/f75V8ZUrkdbw8vv9LFy4kL59+zpjB0WRXyn3xaniJve3bIHOnSE3F557Dm691b5YJOkdPHiQBg0axGxsZtWYtLCwkFq1ajF9+nSGDh0aah8xYgSHDh3iww8/LHWf1q1bM3bsWEaPHh1qGz9+PDNnzuS7774LtRUVbH/++We+/PJLGjRoELrt9ddf5/bbb+eXX34Jtfn9fmrUqMF7773H73//+1LPG2mmbatWrcjMzAz9zuN1hkxBQQGLFy+mT58+pKSkxO0Mmcr0KVFm/ahP9vbJ7/ezePFi+vXrx9FvvxO1TyXbk+V1Up+s7RPAggUL6NOnT+gxE71PRa+TsWQJ3rPPBiDwv//h6dUr4fuUTLlnd58OHjxI06ZNLRkPV6loW7duXVasWMHJJ58c1r5x40a6d+/OoUOHjiuoikqkoq2IiEjMGQacdx588QWcfbY521azbCWKYj02s3JM2qtXL3r27Mnzzz8PmBuotG7dmpEjR5a5EVleXh4fffRRqK1Pnz506dIltBFZUcF2w4YNzJs3j0aNGoU9RtFGZMuWLeOMM84AzFl4F1xwQVJuRCYiIpLwgkHo3RuWLoVrroHXX7c7IokzVo7NqvTOzeVyRfyqR1ZWVnIvMn0cgsEgW7duDavoiziBcl+cKi5y/7XXzIJtzZrmgFIFW4myWOe7lWPSsWPH8tprr/HWW2+xdu1abr75ZnJzc7nmmmsAuOqqq7j33ntDx48aNYrZs2fzzDPPsG7dOh566CGWLVvGyJEjAbNge8kll7Bs2TLefvttAoEAe/bsYc+ePRQWFgLQsWNHLrjgAq6//nqWLl3K119/zciRI7n88ssrVLBNNHFxXhSxgXJfnChp8/7tt82CbZ068PjjdkcjccjKnK/Su7e+ffsyceLEsMFwIBBg4sSJnHXWWZYFl0wCgQCrVq1SUVscR7kvTmV77m/bBnfcYV5+/HE4aiaiSDTEOt+tHJMOHz6cp59+mgcffJCMjAxWrVrF7NmzQ5uNbdu2jd27d4eO79OnD9OmTePVV1+la9euTJ8+nZkzZ9K5c2cAdu7cyX//+1927NhBRkYGzZo1C/0sXrw49Dhvv/026enpnHvuuQwePJizzjqLV1999Xh+LXHL9vOiiE2U++JESZn3hw/D3Xeblx94AJo2tTceiUtW5nyVFtmbNGkSffv2pUOHDpz96zoeixYtIjs7my+//NKy4ERERKQKDAOuvx5ycqBPH61jK0nL6jHpyJEjQzNljxZpfdxLL72USy+9NOLxbdu2LbV2ZST169dn2rRplYpTREREbDBxIuzeDSedBCXWtBeJlirNtO3UqROrV6/msssuY9++feTk5HDVVVexbt260OwCERERscnrr8OcOVCjhnnZ47E7IpGo0JhUREREYmLzZnjmGfPy009D9er2xiOOUOXtrJs3b87jWr+jwlwuF40aNQrtoijiFMp9cSrbcn/HDhg71rz8yCPQoUNsn18czY5zvcakiUNjAnEq5b44UdLl/Z13QkEBnHsuXHSR3dFIHLMy56u8I8miRYv44x//SJ8+fdi5cycA//znP/nqq68sCy6ZeL1e+vTpg9db5Tq5SEJS7otT2ZL7hgE33gjZ2fCb38CYMbF7bhGw5VyvMWni0JhAnEq5L06UVHk/fz68/765qe+zz0KyFKIlKqzM+SoVbd9//30GDRpEzZo1WbFiBQUFBYC5U69mOkQWCARYt25dci3CLVIByn1xKlty/x//gFmzzK9raVkEsUGsz/UakyYWjQnEqZT74kRJk/eBAIwaZV6+6SY47TR745G4Z2XOV6lo++ijj/LKK6/w2muvkZKSEmo/88wzWbFihWXBJZNgMMj69esJBoN2hyISU8p9caqY5/6uXcUbIjz0EHTsGJvnFSkh1ud6jUkTi8YE4lTKfXGipMn7v/0NVq+GE06Ahx+2OxpJAFbmfJWKtuvXr6dv376l2tPS0jh06NDxxiQiIiKVYRjmJ/+HDkH37nDHHXZHJBITGpOKiIhI1Bw6BA88YF5+6CFo0MDOaMSBqlS0bdq0KRs3bizV/tVXX3HiiSced1AiIiJSCdOmwUcfQUoKvPEGJMPaYSIVoDGpiIiIRM3DD0NmpvkNtptvtjsacaAqFW2vv/56Ro0axTfffIPL5WLXrl28/fbb3HHHHdysRI7I7XbTunVr3O4q7/0mkpCU++JUMcv9PXvg1lvNy+PHQ+fO0X0+kXLE+lyvMWli0ZhAnEq5L06U8Hm/bh08/7x5+dlnzckRIhVgZc67DMMwKnsnwzB4/PHHmThxInl5eQBUr16dO+64g0ceecSy4I4lOzubtLQ0srKyqFu3bsyeV0REJC4YBgwbBjNnQrdu8M03GlCKrWI9NouXMamdNB4WERGJgiFDzA1+f/tb8xttIhVk5disSkXbIoWFhWzcuJHDhw/TqVMn6tSpc1zBVFYiDVIDgQCrV6+mS5cueLSbtziIcl+cKia5/847cMUV5nIIy5dDly7ReR6RCvrll1+oX79+zMdmdo9J7aTxsEj8U+6LEyV03n/6KQwebE6G+OEHOOUUuyOSBGLlePi45uxWq1aNTp06kZ6ezueff87atWuPK5hkFgwG2bZtW+LvnChSScp9caqo5/6+fTBypHn5gQdUsJW4YNe5XmPSxKAxgTiVcl+cKGHz3ueDMWPMy7fdpoKtVJqVOV+lou1ll13GCy+8AEB+fj49evTgsssuo0uXLrz//vuWBSciIiJluOUWOHAAunaFe++1OxoRW2hMKiIiIpZ68UVYvx4aNYJx4+yORhyuSkXbhQsXcvbZZwPwwQcfEAwGOXToEM899xyPPvqopQGKiIjIUaZPN3+8XnjjDahWze6IRGyhMamIiIhYZv9+eOgh8/Ljj0Namq3hiFSpaJuVlUX9+vUBmD17NhdffDG1atViyJAhbNiwwdIAk4Xb7aZDhw6Ju3OiSBUp98Wpopb7mZnwl7+Yl++919yATCROxPpcrzFpYtGYQJxKuS9OlJB5/+CDkJUFGRlwzTV2RyMJysqcr9IjtWrViiVLlpCbm8vs2bM5//zzAXOx3Ro1algWXDLxeDykp6cn3gLcIsdJuS9OFbXcv/VWcxZA587mWrYicSTW53qNSROLxgTiVMp9caKEy/vvvoNXXzUvT50KiRK3xB0rc75KRdvRo0dz5ZVX0rJlS5o3b84555wDmF9RO+200ywLLpn4/X4WL16M3++3OxSRmFLui1NFJfc/+ADeecccRGpZBIlDsT7Xa0yaWDQmEKdS7osTJVTeGwaMHg3BIFx6KfTta3dEksCszHlvVe70l7/8hV69erFt2zbOO++80NTfE088UeuHlcEwDPbv349hGHaHIhJTyn1xKstz/8ABuPlm8/Jdd0H37tY8roiFYn2u15g0sWhMIE6l3BcnSqi8/+ADmD8fatSAJ5+0OxpJcFbmfJWKtgBnnHEGZ5xxRljbkCFDjjsgERERiWD0aNi7Fzp2NNfbEhFAY1IRERE5DkeOwO23m5fvvBPatrU1HJGSKrw8whNPPEF+fn6Fjv3mm2/45JNPqhyUiIiIlPDRR/Cvf4HbbS6LoLU6xcE0JhURERHLPPssbNkCLVrA3XfbHY1ImAoXbdesWUPr1q35y1/+wqeffsr+/ftDt/n9flavXs1LL71Enz59GD58OKmpqVEJOFF5PB4yMjISZxFuEYso98WpLMv9X36BG280L99+O/TqdfzBiURJLM71GpMmLo0JxKmU++JECZH3u3bBY4+ZlydNgtq17Y1HkoKVOe8yKrHYwnfffccLL7zA9OnTyc7OxuPxUL16dfLy8gDo1q0bf/7zn7n66qtjsmNvdnY2aWlpZGVlUbdu3ag/n4iISMxdfTW89RZ06AArV0LNmnZHJFKmWI3N4m1MaieNh0VERKpoxAj4xz+gd2/4+mtwueyOSJKAlWOzShVtiwSDQVavXs3WrVvJz8+nYcOGZGRk0LBhw+MKprISaZDq9/tZuHAhffv2xeut8lLCIglHuS9OZUnuz5oFQ4aYA8ivvoI+fawNUsRiBw8epEGDBjEbm8XLmNROGg+LxD/lvjhR3Of90qXF32BbuhR69LA3HkkaVo6Hq/Q/x+12k5GRQUZGxnE9uZMYhkFOTk5i7JwoYiHlvjjVced+VhbccIN5ecwYFWwlIcT6XK8xaWLRmECcSrkvThTXeW8YMGqUeXnECBVsxVJW5nyF17QVERGRGLr9dti5E04+GR55xO5oRERERESSw7Rp8L//mWvYPv643dGIlElFWxERkXgzZw78/e/msgivvw61atkdkYiIiIhI4jt8GO66y7x8//3QvLm98YiUQ0XbGPF4PPTu3Tu+d04UiQLlvjhVlXM/Oxv+/Gfz8q23wtlnWx+cSJToXC/l0ZhAnEq5L04Ut3k/aRLs2gXt2plLkIlYzMqcj8PVoJOT2+2mcePGdochEnPKfXGqKuf+XXfB9u1w4on6upYkHLdb8wGkbBoTiFMp98WJ4jLvt2yBp582Lz/9NNSoYWs4kpysHA8f1yNt3LiRzz77jPz8fCD2m08kEp/PxyeffILP57M7FJGYUu6LU1Up97/4Av76V/Py3/9urrMlkkDsOtdrTJoYNCYQp1LuixPFZd7fdRccOQL9+8Pvf293NJKkrMz5KhVtDxw4wMCBAznllFMYPHgwu3fvBuC6667j9ttvtyy4ZOP3++0OQcQWyn1xqkrlfk4OXHedefkvf4FzzolKTCLJRGPSxKMxgTiVcl+cKK7yfsECeO89cLthyhRz7wiROFelou2YMWPwer1s27aNWiU2Rxk+fDizZ8+2LDgRERHHuOce2LoV2rY119oSkWPSmFRERESOKRCA0aPNyzfcAF262BqOSEVVaU3bOXPm8Nlnn9GyZcuw9vbt27N161ZLAhMREXGM+fPhpZfMy3/7G9SpY2s4IolCY1IRERE5ptdfh1WroF49ePhhu6MRqbAqzbTNzc0Nm81Q5ODBg1SvXv24g0pGXq+X/v374/Vq7zdxFuW+OFWFcz83t3hZhBtvhHPPjX5wIlES63O9xqSJRWMCcSrlvjhR3OR9Vhbcf795efx4aNTI3ngk6VmZ81Uq2p599tn84x//CF13uVwEg0GefPJJ+vfvb1lwyaZmzZp2hyBiC+W+OFWFcv++++Dnn6FVK3jyyegHJZJENCZNPBoTiFMp98WJ4iLvH3kE9u+H9HS45Ra7oxGplCoVbZ988kleffVVLrzwQgoLC7nrrrvo3LkzCxcuZJLW4YvI7/cza9as+FqIWyQGlPviVBXK/UWL4LnnzMt/+xvUrRub4ESiJNbneo1JE4vGBOJUyn1xorjI+59+gqlTzcvPPgspKfbFIo5hZc5XqWjbuXNnfvrpJ8466ywuuugicnNzGTZsGCtXruSkk06q8ONMnDiRHj16kJqaSuPGjRk6dCjr16+vSkgiIiKJJS8Prr3WvHzddXD++fbGI5KArBqTioiISBK6/Xbw+2HwYLjgArujEam0Ki+0kJaWxv1F64JU0YIFC7jlllvo0aMHfr+f++67j/PPP581a9ZQu3bt43psERGRuDZuHGzcCC1bwjPP2B2NSMKyYkwqIiIiSWb2bPj4Y/B6YfJku6MRqZIqF22PHDnC6tWr2bdvH8FgMOy23/3udxV6jNmzZ4ddf/PNN2ncuDHLly+nb9++VQ1NREQkvi1ebH5FC+DVVyEtzd54RBKYFWNSERERSSI+H4wZY16+7Tbo0MHeeESqyGUYhlHZO82ePZurrrqKzMzM0g/ochEIBKoUzMaNG2nfvj3ff/89nTt3Pubx2dnZpKWlkZWVRd04XwfQMAz8fj9erxeXy2V3OCIxo9wXpyoz9/PzoVs3WL8eRoyAN9+0LUYRq2VlZVGvXr2Yjc2iNSZNJBoPi8Q/5b44ka15/9xzMGoUNGwIGzZAvXqxfX5xNCvHw1WaaXvrrbdy6aWX8uCDD9KkSZPjCqBIMBhk9OjRnHnmmWUWbAsKCigoKAhdz87OBsDn8+Hz+QBwu914PB4CgUDYbIuidr/fT8k6tcfjwe12l9le9LhFvF7zV3b0wsJltaekpBAMBvH7/Rw+fJg6dergdrvxer0Eg8GwNxMulwuv11tm7PHWp0ixq0/q09F9MgyDvLw80tLSkqZPJdvVJ/WpvD7l5ORQp04dXC5XKPbguHG416/HaNYM/5NP4g4EEqpPyfg6qU/W9eno26ItGmNSia78/HxSU1PtDkMk5pT74kS25H1mJowfb15+7DEVbCWhValou3fvXsaOHWvp4PiWW27hhx9+4KuvvirzmIkTJzJhwoRS7XPmzKFWrVoAtG7dmm7durF69Wq2bdsWOqZDhw6kp6ezdOlS9u/fH2rPyMigTZs2LFy4kJycnFB77969ady4MXPmzAl7A9W/f39q1qzJrFmzwmIYPHgw+fn5zJs3L9Tm9XoZMmQImZmZLFmyJNSemprKgAED2L59O6tWrQq1N2rUiD59+rBhw4awDdnUJ/Up0ftUFGcy9SkZXyf1ydo+rV+/no0bN4b3qbAQ16/LInxzzTXsXbIkofqUjK+T+mRtn/Ly8oilaIxJJXr8fj/z5s1j8ODBpGgHb3EQ5b44kW15P348HDoEXbuam/2KxNjRkzCOR5WWR7j22ms588wzuc6i/wAjR47kww8/ZOHChbRr167M4yLNtG3VqhWZmZmhKcfxOkPmyJEjzJ07l/POO49q1arF7QyZyvQpUWb9qE/29snn8zF37lwGDx7M0RK1TyXbk+V1Up+s79ORI0f47LPPOO+880hJScFdWIinRw9Yu5bgH/5A4NdlERKpT8n4OqlP1vbpwIEDNGvWLGZf1bd6TJqIEml5BJ/Px6xZs1S4EsdR7osT2ZL3338PGRkQDMK8eXDOObF5XpESDhw4QMOGDe1bHuGFF17g0ksvZdGiRZx22mml/gPedtttFXocwzC49dZb+eCDD5g/f365BVuA6tWrU7169VLtKSkppWLweDx4PJ5Sxxa9sahoe1knl8q0u93uUHtKSkroudxuN263u9TxZcUeb32KFLv6pD6V164+qU9O61PRY6WkpJif+q9dC02a4H7+edxHPUei9CkZXyf1ybo+xboYYdWYVERERBKcYcDo0WbB9pJLVLCVpFClou2///1v5syZQ40aNZg/f37YotIul6vCA+RbbrmFadOm8eGHH5KamsqePXsASEtLo2bNmlUJLa6V9eZHJNkp98WpQrm/bBk8+aR5+ZVXoH59+4ISSSJWjUkldjQmEKdS7osTxTTvP/wQvvwSqleHp56K3fOKRFGVlkdo2rQpt912G/fcc0/EGR4VfvIydhB84403uPrqq495/0T6OpiIiDhUQQF07w4//ACXXw7//rfdEYlETazHZlaNSROZxsMiIuJ4BQXQqRP8/DPcfz88+qjdEYmDWTk2q9LHHoWFhQwfPvy4B8dVqBcnrGAwSGZmJg0bNnTsmwpxJuW+OFIgQHDBAnJ++om6X3+N64cfoFEjeP55uyMTiaqSa+bGglVjUokNjQnEqZT74kQxzfspU8yCbbNmcM890X0ukWOwcjxcpf85I0aM4N1337UsCCcIBAIsWbIkbMMQESdQ7ovjzJgBbdviPvdc0m6+Gde//mW2jxgBDRvaG5tIlMX6XK8xaWLRmECcSrkvThSzvN+9u3hm7aRJUKdOdJ9P5BiszPkqzbQNBAI8+eSTfPbZZ3Tp0qXUpg+TJ0+2JDgREZGEMmOGufFBpG+SPPMM9O4Nw4bFPi6RJKUxqYiIiMPddx8cPgy9esGVV9odjYilqlS0/f777+nWrRsAP/zwQ9htZa1TKyIiktQCARg1KnLBtsjo0XDRReDxxCwskWSmMamIiIiDffstvPmmeXnqVNDyI5JkqlS0nTdvntVxJD2Xy0VqaqreQIjjKPfFMRYtgh07yr7dMGD7dvO4c86JWVgisRTrc73GpIlFYwJxKuW+OFHU894wzAkTAH/6kznTViQOWJnzVSraSuV5vV4GDBhgdxgiMafcF8fYvdva40QSkNeroaWUTWMCcSrlvjhR1PP+3/+GJUugdm2YODF6zyNSSVaOhyv8SMOGDePNN9+kbt26DDvGenwzZsw47sCSTTAYZPv27bRq1Uo7hoqjKPfFEbZsgVdfrdixzZpFNRQRO1m5W25ZNCZNXBoTiFMp98WJopr3ublw993m5XvvhRYtrH18keNg5Xi4wkXbtLS00BTftLQ0ywJwikAgwKpVq2jevLn+UIujKPclqR0+bH6y/8wzUFBQ/rEuF7RsCWefHZvYRGwQi53RNSZNXBoTiFMp98WJopr3Tz5pLkvWti2MHWvtY4scJyvHwxUu2r7xxhs8/PDD3HHHHbzxxhuWBSAiIpJwgkH4xz/MT/b37DHb+veH3/4W7rjDvF5yQ7KidY2mTNEmZCLHSWNSERERB9u2zSzaAjz1FNSsaW88IlFUqY87JkyYwOHDh6MVi4iISPz76ivo2ROuucYs2J50EnzwAXzxhflJ//Tppb+i1bKl2X6Mr3KLSMVoTCoiIuJQd90FR45Av35w8cV2RyMSVZUq2holZw1JpbhcLho1aqQdQ8VxlPuSNLZsgeHDzeUNli+HunXNT/d//BGGDi2eTTtsGGzZQuDzz/npoYcIfP45bN6sgq04QqzO9dEak7744ou0bduWGjVq0KtXL5YuXVru8e+99x7p6enUqFGD0047jVmzZoXdPmPGDM4//3waNGiAy+Vi1apVpR7jnHPOweVyhf3cdNNNVnYrbmhMIE6l3BcnikreL1oE774Lbrf5DTb9n5I4ZGXOV3phEf2hqRqv10ufPn20q7I4jnJfEt7hw3D//ZCeDv/5jzlIvOEG2LDBXAqhevXS9/F48Jx7LqeMH4/n3HO1JII4RizP9VaPSd99913Gjh3L+PHjWbFiBV27dmXQoEHs27cv4vGLFy/miiuu4LrrrmPlypUMHTqUoUOH8sMPP4SOyc3N5ayzzmLSpEnlPvf111/P7t27Qz9PFn3tM8loTCBOpdwXJ7I87wMBGDXKvPznP0NGhjWPK2IxK8/1LqMSUxXcbnfY5g9lOXjw4HEHVhHZ2dmkpaWRlZVF3bp1Y/KcVRUIBNiwYQPt27fHozfv4iDKfUlYZa1b++yz0LXrMe+u3Bcn+uWXX6hfv37Ux2bRGJP26tWLHj168MILLwDmzr+tWrXi1ltv5Z577il1/PDhw8nNzeXjjz8Otf3mN78hIyODV155JezYLVu20K5dO1auXEnGUW8yzznnHDIyMpgyZUqFYy1J42GR+KfcFyeyPO///nezWJuWZk6eaNTo+B9TJAqsHA9Xuvw7YcIE7dRbBcFgkPXr13PSSSfpD7U4inJfEtKiRTBmjLkMApjr1j79NFx0UYW/hqXcFycKBoMxey4rx6SFhYUsX76ce++9N9TmdrsZOHAgS5YsiXifJUuWMPaoHasHDRrEzJkzK/38b7/9Nv/6179o2rQp//d//8e4ceOoVatWpR8n3um8KE6l3BcnsjTvs7PhvvvMy+PHq2Arcc3K8XCli7aXX345jRs3tiwAERGRuLFli7m5wXvvmdfr1oVx4+DWWyMvgyAitrFyTJqZmUkgEKBJkyZh7U2aNGHdunUR77Nnz56Ix+8pmplfQX/4wx9o06YNzZs3Z/Xq1dx9992sX7+eGTNmRDy+oKCAgoKC0PXs7GwAfD4fPp8PMAvOHo+HQCAQ9sahqN3v94etC+zxeHC73WW2Fz1ukaKv/fn9/gq1p6SkEAwGQ4/j8/lwuVx4vV6CwSCBQCB0bFF7WbHHW58ixa4+qU9H96lkP5KlTyXb1Sf1KVJ7kZLPW9U+BR9+GPe+fRjt2+O/4QY8waBeJ/Upbvt09G3Ho1JFW61nKyIiSSknB554Ap55BgoKzHVrr78eHn4Y9EGlSNxJpjHpDTfcELp82mmn0axZM84991w2bdrESSedVOr4iRMnMmHChFLtc+bMCc3Obd26Nd26dWP16tVs27YtdEyHDh1IT09n6dKl7N+/P9SekZFBmzZtWLhwITk5OaH23r1707hxY+bMmRP2Bqp///7UrFmz1MZrgwcPJj8/n3nz5oXavF4vQ4YMITMzMzRree7cuaSmpjJgwAC2b98etkFbo0aN6NOnDxs2bGD9+vWh9njvE6A+qU/l9qlIMvUpGV8n9cm6PvXo0QMwz/nH1afCQpg6FYD/DR/Ovs8/1+ukPsV1n/Ly8rBKpde03bNnT9zMtE20NbxWr15Nly5d9JUYcRTlvsS1YBDeesv8ulXJdWunTIEuXY7roZX74kSxXNPWyjFpYWEhtWrVYvr06QwdOjTUPmLECA4dOsSHH35Y6j6tW7dm7NixjB49OtQ2fvx4Zs6cyXfffRd2bHlr2h4tNzeXOnXqMHv2bAYNGlTq9kgzbVu1akVmZmbodx6vM2QKCwv58ccfOfXUU/F6vXE7Q6YyfUqUWT/qk719CgQCrFmzhq5du5b62myi9qlke7K8TuqTtX1yuVx89913dOrUKTQWrlKffv97+O9/CQ4aROCjj2ztUzK+TuqT9X365ZdfaNKkiSXj4UoVbeNNIhVtRUQkzixaBKNHw4oV5vWTTjJn2v7udxVet1ZEwiXy2KxXr1707NmT559/HjDXI2vdujUjR44scyOyvLw8Pvr1DSRAnz596NKlS6U2Ijva119/zVlnncV3331Hlwp8eJTIv3MREZFyzZkDgwaB1wurV0PHjnZHJHJMVo7N3BbFJMcQCARYuXJl2CcAIk6g3Je4s2ULXHYZ9O1rFmzr1oWnnoIff6zURmPHotwXJ0rkfB87diyvvfYab731FmvXruXmm28mNzeXa665BoCrrroqbKOyUaNGMXv2bJ555hnWrVvHQw89xLJlyxg5cmTomIMHD7Jq1SrWrFkDwPr161m1alVo3dtNmzbxyCOPsHz5crZs2cJ///tfrrrqKvr27Vuhgm2i0XlRnEq5L0503Hnv95sbAwOMHKmCrSQMK8/1KtrGSDAYZNu2bTHdVVkkHij3JW7k5MD990N6urnRmNsNN94IGzbAHXdYvtGYcl+cKJHzffjw4Tz99NM8+OCDZGRksGrVKmbPnh3abGzbtm3s3r07dHyfPn2YNm0ar776Kl27dmX69OnMnDmTzp07h47573//S7du3RgyZAhgbp7WrVu30EzcatWq8fnnn3P++eeTnp7O7bffzsUXXxw2ezeZ6LwoTqXcFyc67rx/5RVYswYaNIAHH7Q2OJEosvJcX6mNyERERBJOFNetFZHkMnLkyLCZsiXNnz+/VNull17KpZdeWubjXX311Vx99dVl3t6qVSsWLFhQ2TBFRESS24EDxYXaRx+FE06wNx4Rm6hoKyIiyUvr1oqIiIiIJJaHHoJffjEnWFx/vd3RiNhGyyPEiNvtpkOHDrjd+pWLsyj3xRabN5det/bppy1ft7Y8yn1xIuW7lEfnRXEq5b44UZXz/ocf4OWXzctTpoDHY3lsItFk5bneZRiGYdmjxZh2yxURkTA5OTBxIkyeDAUF5rq1118PDz8MjRvbHZ1I0tPYLPb0OxcRkaRhGHD++fD55zBsGLz/vt0RiVSalWMzfdQXI36/n8WLF+P3++0ORSSmlPsSE8EgvPEGnHKKWbQtKIABA2DlSnMTAxsKtsp9cSLlu5RH50VxKuW+OFGV8v6jj8yCbbVq8NRT0QtOJIqsPNdrTdsYMQyD/fv3k8ATm0WqRLkvURen69Yq98WJlO9SHp0XxamU++JElc77ggIYO9a8fPvtcOKJ0QtOJIqsPNdrpq2IiCSmOFi3VkRERERELDB1KmzaBM2awb332h2NSFzQTFsREUksWrdWRERERCR57NkDjz5qXp44EVJT7Y1HJE6oaBsjHo+HjIwMPNr5UBxGuS+WCQbhrbfgvvvMgR2Y69Y++yx06WJvbBEo98WJlO9SHp0XxamU++JElcr7++83J2b06AF/+lP0gxOJIivP9Sraxojb7aZNmzZ2hyESc8p9sUScrltbHuW+OJHbrZW3pGw6L4pTKffFiSqc98uXmxsKg7lEgsYSkuCsHA/rf0OM+P1+vvzyS+0YKo6j3JfjksDr1ir3xYmU71IenRfFqZT74kQVynvDgFGjzH+vvBJ6945dgCJRYuW5XjNtY8QwDHJycrRjqDiOcl+qJNK6tTfcYK5b26iR3dFViHJfnEj5LuXReVGcSrkvTlShvH/3Xfj6a6hVC554InbBiUSRled6FW1FRCR+BIPw5pvmurV795pt555rrlt72mm2hiYiIiIiIhbJy4O77jIv33MPtGxpbzwicUhFWxERiQ8LF5rr1q5caV4/+WRz3dr/+7+4XgZBREREREQq6amnYPt2aN0a7rjD7mhE4pLWtI0Rj8dD7969tWOoOI5yX45p82a49FLo188s2JZctzaONxo7FuW+OJHyXcqj86I4lXJfnKjcvN++HSZNMi8//TTUrBnb4ESiyMpzvWbaxojb7aZx48Z2hyESc8p9KVMSrFtbHuW+OJGVu+VK8tF5UZxKuS9OVG7e33035OfD2WfDJZfENjCRKLNyPKyRdYz4fD4++eQTfD6f3aGIxJRyX0oJBuH116F9e7NoW1Bgrlu7ahW8/HJSFGxBuS/OpHyX8ui8KE6l3BcnKjPvv/4a/v1v89t0U6cm7LfqRMpi5bleM21jyO/32x2CiC2U+xLisHVrlfsiIuF0XhSnUu6LE5XK+2AQRo0yL193HXTrFvugRBKIZtqKiEj0Hb1ubVqaWaxN8HVrRURERESkgt56C5YvN/ewePRRu6MRiXuaaSsiItGTkwOPPw7PPpuU69aKiIiIiEgFZGfDvfealx98EJo0sTcekQTgMgzDsDuIqsrOziYtLY2srCzq1q1rdzjlMgyDnJwcUlNTcWlGmTiIct+hAgHzk/T77oO9e822c881i7ennWZvbDGi3BcnysrKol69egkxNksWGg+LxD/lvjhRqby/5x6YNMnc1+KHH6BaNbtDFIkKK8fDmmkbQzVr1rQ7BBFbKPcd5uh1a9u3N5dC+O1vHbcMgnJfRCSczoviVMp9caJQ3m/caE7eAJg8WQVbkQrSmrYx4vf7mTVrlhagF8dR7jtIWevW/vBD0m40Vh7lvjiR8l3Ko/OiOJVyX5woLO/vuAMKC2HQIBgyxO7QRKLKynO9ZtqKiMjxKVq3dvJkczCmdWtFRERERJwrEMC1YAEtFi7EvX49fPgheDzm+wWHTeQQOR4q2oqISNVo3VoRERERESlpxgwYNQrvjh10L9k+aBB06mRXVCIJScsjiIhI5S1cCD16wHXXmQXb9u3hv/+FuXNVsBURERERcaIZM+CSS2DHjtK3ffqpebuIVJjLMAzD7iCqKtF2y/X7/Xi9Xu0YKo6i3E8ymzfDXXfB9Onm9bQ0ePBBGDlSGwocRbkvTmTlbrlSMRoPi8Q/5b44QiAAbdtGLtiCuSxCy5bm+wmPJ6ahicSSleNhzbSNofz8fLtDELGFcj8J5OTAvfdCerpZsHW74eabYcMGGDtWBdsyKPdFRMLpvChOpdyXpLdoUdkFWwDDgO3bzeNEpEJUtI0Rv9/PvHnztGOoOI5yP8EFAvD3v5vLHzzxhLnR2MCB8N138NJL2misHMp9cSLlu5RH50VxKuW+OMKuXRU7bvfu6MYhYjMrz/XaiExERCJbsADGjIGVK83r7dvDM8/Ab3+rXV9FRERERMScQfv55/DIIxU7vlmz6MYjkkQ001ZERML9/LO5gcA555gF27Q0mDwZfvgB/u//VLAVERERERFYvBgGDIDzz4d168p/n+ByQatWcPbZsYtPJMGpaBtDXq8mNoszKfcTRHa2uW5tx47w/vvh69aOGaN1a6tAuS8iEk7nRXEq5b4klZUrYcgQOPNMmD/ffJ8wahT87W9mcfbo4m3R9SlTtAmZSCW4DMMw7A6iqhJpt1wRkbgVCMCbb8L998PevWbbwIHw7LPQubOtoYlIYtHYLPb0OxcRkZhZtw4efBDee8+87vHAtdfCuHHmLFqAGTPMAm7JTclatTILtsOGxTxkkVizcmymmbYxEgwG2bdvH8Fg0O5QRGJKuR/nFiyAHj3gz382C7bt28N//wtz5qhge5yU++JEyncpj86L4lTKfUl4W7bANdfAqaeaBVuXC/7wB1i7Fl59tbhgC2ZhdssWgl98QdbLLxP84gvYvFkFW3EMK8/1KtrGSCAQYMmSJQQCAbtDEYkp5X6c0rq1UafcFydSvkt5dF4Up1LuS8LavRtGjoRTTjG/mRcMwkUXwXffwdtvmxM+IvF4CJx9NvObNSNw9tlaEkEcxcpzvRbWERFxkuxsmDjRLNAWFprr1t50E0yYAA0b2h2diIiIiIjY7cABePJJeP55yM832wYOhEcfhV697I1NxEFUtBURcQKtWysiIiIiIuXJzjbXnn3mGfMyQO/e8Nhj0L+/raGJOJGKtjHicrlITU3Fpa8ci8Mo9+PAggUwejSsWmVeb9/enGk7ZIiWQYgi5b44kfJdyqPzojiVcl/iXn4+vPSS+Y28AwfMtq5dzWLt4MFVes+gvBensjLnXYZhGJY9Woxpt1wRkXL8/DPceae5gyuY69aOHw+33ALVqtkbm4gkJY3NYk+/cxERqbLCQvj7381lD3btMts6dICHHzb3v3BrGySRyrJybKb/gTESDAbZunWrdgwVx1Hu2yA7G+65Bzp2NAu2bjf85S+wcSOMGaOCbYwo98WJlO9SHp0XxamU+xJ3AgH4xz8gPd18n7BrF7RpA6+/bm5MfNllx12wVd6LU1mZ8yraxkggEGDVqlXaMVQcR7kfQ4GA+Un5KafApEnmJ+cDB5q7u774ojYaizHlvjiR8l3Ko/OiOJVyX+KGYcD770OXLjBiBGzeDE2amBuOrV8P11wDXmtW0VTei1NZmfNa01ZEJBlo3VoREREREYnEMOCzz8xNiVesMNtOOAHuvhtGjoTate2NT0QiUtFWRCSRHb1ubb165rq1f/mLlkEQEREREXG6RYvgvvvgq6/M63XqmEum3X67ueeFiMQtFW1jxOVy0ahRI+2cKI6j3I+S7Gx4/HF49llzGQS3G266CSZM0DIIcUK5L06kfJfy6LwoTqXcF1ssX27OrP3sM/N69ermrNq774ZGjaL+9Mp7cSorc95lGIZh2aPFmHbLFRHHCQTgzTfNAdjevWbbeeeZSyF07mxraCIiGpvFnn7nIiISZs0aGDeu+Jt4Xi/8+c/wwAPQooW9sYk4gJVjM21EFiOBQIB169ZpEW5xHOW+hRYsgO7dzUHX3r3mhmMffWR+eq6CbdxR7osTKd+lPDovilMp9yUmfv4ZrrrKfF8wY4a5r8Wf/mRuMPbyyzEv2CrvxamszHkVbWMkGAyyfv16gsGg3aGIxJRy3wI//wwXXwznnGNuNFavnrkswvffw29/q43G4pRyX5xI+S7l0XlRnEq5L1G1cyfcfDN06AD//Ke56diwYeZ7hX/8A0480ZawlPfiVFbmvNa0FRGJV9nZ8NhjMGWK1q0VEREREZFimZnwxBPw4otw5IjZNmgQPPqo+e08EUl4KtqKiMSbQADeeMNct3bfPrNN69aKiIiIiEhWFjzzjPnNu8OHzbazzjIne/Tta29sImIpFW1jxO1207p1a9xurUghzqLcr6T582HMGHMZBDDXrX3mGRgyRMsgJBjlvjiR8l3Ko/OiOJVyXyyRlwfPPw+TJsEvv5htp59uFmsHDYq79wrKe3EqK3PeZRiGYdmjxZh2yxWRpPHzz3DnncW7vNarB+PHw1/+AtWq2RqaiEhFaWwWe/qdi4gkuYICeO01szi7Z4/Z1rEjPPKIuXZtnBVrRZzOyrGZPvKIkUAgwMqVK7VzojiOcv8YsrPh7rvNgdeMGeDxwC23wIYNMHq0CrYJTLkvTqR8l/LovChOpdyXKvH7zSXTOnSAW281C7bt2sFbb5mbjF18cVwXbJX34lRW5ryKtjESDAbZtm2bdk4Ux1HulyEQgL/9Ddq3hyefNDcaO+88+O47eOEFbTSWBJT74kTKdymPzoviVMp9qZRgEP7zH3Mvi2uvha1boVkzeOklWLcOrrrKnOgR55T34lRW5rzWtBURibVI69ZOngyDB8f1p+UiIiIiIhIlhgGzZpmbEX/3ndnWoAHcc4/5TbyaNe2NT0RiTkVbEZFY0bq1IiIiIiJytPnz4b77YMkS83pqKtxxh7lcmtYrF3EsFW1jxO1206FDB+2cKI6j3Mdct/axx2DKFHMZBI8HbroJHnpIyyAkMeW+OJHyXcqj86I4lXJfyrR0qTmz9vPPzes1a5rr1951lznLNoEp78WprMx5/e+JEY/HQ3p6Op4EWHtGxEqOzv1AwNzpVevWOpKjc18cK9Hz/cUXX6Rt27bUqFGDXr16sXTp0nKPf++990hPT6dGjRqcdtppzJo1K+z2GTNmcP7559OgQQNcLheripbFKeHIkSPccsstNGjQgDp16nDxxRezd+9eK7sVN3ReFKdS7ksp338PQ4dCr15mwTYlxVwCYdMmmDQp4Qu2oLwX57Iy51W0jRG/38/ixYvx+/12hyISU47N/fnz4Ywz4IYbYN8+c93ajz+Gzz6DU0+1OzqJAcfmvjhaIuf7u+++y9ixYxk/fjwrVqyga9euDBo0iH379kU8fvHixVxxxRVcd911rFy5kqFDhzJ06FB++OGH0DG5ubmcddZZTJo0qcznHTNmDB999BHvvfceCxYsYNeuXQwbNszy/sUDnRfFqZT7ErJhA/zhD9C1K3z4IbjdcPXV8NNP5qSOZs3sjtAyyntxKitzXkXbGDEMg/3792MYht2hiMSU43J/0yYYNgz69zdn1NarZy6L8P33MGSINhpzEMflvggkdL5PnjyZ66+/nmuuuYZOnTrxyiuvUKtWLV5//fWIx0+dOpULLriAO++8k44dO/LII49w+umn88ILL4SO+dOf/sSDDz7IwIEDIz5GVlYWf//735k8eTIDBgzgjDPO4I033mDx4sX873//i0o/7aTzojiVcl/Yvh2uvx46doR//9vcdOzSS+HHH+GNN6BtW7sjtJzyXpzKypzXmrYiIlYoa93aCROS4utNIiLJrLCwkOXLl3PvvfeG2txuNwMHDmRJ0aYwR1myZAljx44Naxs0aBAzZ86s8PMuX74cn88XVtRNT0+ndevWLFmyhN/85jel7lNQUEBBQUHoenZ2NgA+nw+fzxeK3ePxEAgECAaDYX3yeDz4/f6wNxQejwe3211me9HjFvF6zbcQR88kKas9JSWFYDAYehyfz4fL5cLr9RIMBgkEAqFji9rLij3e+hQpdvVJfTq6TyX7kSx9KtmuPpXTp927cU+ahPuvf8VVWAiAceGF+B96CLp1M/vk9ydWnyr4OhUp+byJ3qdkfJ3UJ+v7dPRtx0NFWxGR4xEIwOuvwwMPmMsgAJx/PkyerGUQREQSRGZmJoFAgCZNmoS1N2nShHXr1kW8z549eyIev2fPngo/7549e6hWrRr16tWr8ONMnDiRCRMmlGqfM2cOtWrVAqB169Z069aN1atXs23bttAxHTp0ID09naVLl7J///5Qe0ZGBm3atGHhwoXk5OSE2nv37k3jxo2ZM2dO2Buo/v37U7NmzVJr+A4ePJj8/HzmzZsXavN6vQwZMoTMzMxQAXzu3LmkpqYyYMAAtm/fHrbWb6NGjejTpw8bNmxg/fr1ofZ47xOgPqlP5fapSDL1KRlfJ8v61KcP+RMmUP2vf8Vz5AgAv3Ttygkvvsi2li3NPu3enVh9quTr1KNHD8A85ydLn5LxdVKfrO9TXl4eVnEZCTxXPTs7m7S0NLKysqhbt67d4ZQrGAyyfft2WrVqpd0TxVGSOvfnz4fRo81lEMBct3byZBg8WMsgSHLnvkgZDh06xAknnJAQY7OSdu3aRYsWLVi8eDG9e/cOtd91110sWLCAb775ptR9qlWrxltvvcUVV1wRanvppZeYMGFCqY3EtmzZQrt27Vi5ciUZGRmh9mnTpnHNNdeEzZwF6NmzJ/3794+4Fm6kmbatWrUiMzMz9DuP1xkyPp+PnTt30qJFCzweT9zOkKlMnxJl1o/6ZG+fgsEgu3fvpnXr1mHPmch9KtmeLK+TJX06fBj3Cy/gmTwZDh0CINi9O8GHH4aBA/GW09e47dOvKvs6ud1utm7dSvPmzUNj4UTvUzK+TuqT9X06dOgQjRs3tmQ8rJm2MeJ2u2nTpo3dYYjEXFLm/qZNcOed8MEH5vV69eChh+AvfzF3fhUhSXNf5BgS9QOKhg0b4vF4ShVb9+7dS9OmTSPep2nTppU6vqzHKCws5NChQ2Gzbct7nOrVq1O9evVS7SkpKaQc9TfI4/FE3MG46I1FRduPftyqtLvdbqpXr86JJ55Yqj1S3pQVe7z1KVLs6pP6FKm97a9rlpZ1nkzEPpUXY2XbE75PgQD89a/w+OPF377r3BkeeQT3RRfhLjGhI2H6ZMHr1K5du4iPnch9SsbXSX2ytk+RxmlVlZgj6wTk9/v58ssvtXOiOE5S5X52Ntx9N3TqZBZsPR645RbYuBFGjVLBVsIkVe6LVFCi5nu1atU444wz+OKLL0JtwWCQL774ImzmbUm9e/cOOx7Mr4CWdXwkZ5xxBikpKWGPs379erZt21apx0kUOi+KUyn3k5jPB3/7G7Rvb34Db98+OOkkePttWLUKhg517DfwlPfiVFbmvGbaxohhGOTk5GjnRHGcpMh9rVsrVZAUuS9SSYmc72PHjmXEiBF0796dnj17MmXKFHJzc7nmmmsAuOqqq2jRogUTJ04EYNSoUfTr149nnnmGIUOG8M4777Bs2TJeffXV0GMePHiQbdu2sWvXLoDQ2mpNmzaladOmpKWlcd111zF27Fjq169P3bp1ufXWW+ndu3fETcgSnc6L4lTK/SQUDMK778KDD5oTOABatIDx4+HqqzWZA+W9OJeVOa+irYhIeebNMz81X73avK51a0VEktLw4cPZv38/Dz74IHv27CEjI4PZs2eHNhvbtm1b2Nfv+vTpw7Rp03jggQe47777aN++PTNnzqRz586hY/773/+Gir4Al19+OQDjx4/noYceAuDZZ5/F7XZz8cUXU1BQwKBBg3jppZdi0GMREak0w4CPPjInc3z/vdnWqBHcdx/cdBPUqGFvfCKSVFS0FRGJROvWiog4zsiRIxk5cmTE2+bPn1+q7dJLL+XSSy8t8/Guvvpqrr766nKfs0aNGrz44ou8+OKLlQlVRERiyTDgiy/g/vth6VKzLS3NfL8wahTUqWNvfCKSlFS0jRGPx0Pv3r0jLngskswSLvezs+HRR2HqVCgsNNetvflms2DboIHd0UkCSbjcF7GA8l3Ko/OiOJVyP8EtXmwWa4s+vKtVyyzU3nknnHCCraHFM+W9OJWVOa+ibYy43W4aN25sdxgiMZcwuR9p3dpBg8ylEDp1sjc2SUgJk/siFiprV3QR0HlRnEu5n6BWrTLfG3zyiXm9WjVzCYT77oNfl86RsinvxamsHA9rZB0jPp+PTz75BJ/PZ3coIjGVELk/bx6cfjrccINZsO3QwRycffqpCrZSZQmR+yIWU75LeXReFKdS7ieY9eth+HDo1s18T+DxwHXXwYYN5rfxVLCtEOW9OJWVOa+ibQz5/X67QxCxRdzm/qZNMGwYDBhgbjRWrx5MmWJuKqCNxsQCcZv7IiI20XlRnEq5nwC2boVrrzUnbfznP2bb5ZfDmjXwt79B69b2xpeAlPcix0fLI4iI82jdWhERERERAdizBx57DP76VyiaIfe738Ejj0CXLvbGJiKOpqKtiDiH1q0VERERERGAgwfhySfhuecgP99sO/dcc3LHb35jb2wiIoDLMAzD7iCqKjs7m7S0NLKysqhbt67d4ZTLMAxycnJITU3Fpa9ci4PETe7PmwejR5vLIIC5bu3kyXDhhVoGQaIibnJfJIaysrKoV69eQozNkoXGwyLxT7kfZ3JyzCXRnn7a/AYemEXaxx4zl00TSyjvxamsHA9rpm0M1axZ0+4QRGxha+5v3Ah33gkzZ5rXTzjBXAbh5pshJcW+uMQRdN4XEQmn86I4lXI/DuTnw8svw8SJkJlptnXpYhZrhwzRRI4oUN6LHB9tRBYjfr+fWbNmaSFucRzbcj8rC+66y1z2YOZMc93akSPNXV9vu00FW4k6nffFiZTvUh6dF8WplPs28/nM9WpPPhluv90s2LZvD++8AytXwm9/q4JtFCjvxamszHnNtBWR5BIIwN//bq5bu3+/2aZ1a0VEoi4QgK++0pteERGJE4EATJtmfsvu55/NttatYfx4uOoq8KocIiLxTWcpEUkeR69bm55evG6tiIhEzYwZMGoU7NihoaWIiNjMMOCDD2DcOFizxmxr0gTuvx9uuAGqV7c3PhGRCtLIWkQSn9atFRGxzYwZcMkl5ntkERER2xgGzJljFmeXLzfbTjjBXDLt1luhdm174xMRqSSXYdg3xF64cCFPPfUUy5cvZ/fu3XzwwQcMHTq0wvdPtN1y/X4/Xq9XOyeKo0Q197OyzI0Dpkwx16ryeMxC7UMPQYMG1j6XSCXpvC9OEAhA27awY0dRSzaQGGOzZKHxsEj8U+7HwKJFZrF20SLzeu3aMGaMuYZtvXq2huZUyntxqqysLOrVq2fJ2MzWjchyc3Pp2rUrL774op1hxEx+fr7dIYjYwvLcDwTg1VfNDQSeesos2A4aZC6L8PzzKthK3NB5X5Ld/PklC7Yix6bzojiVcj9Kli83l0Lr29cs2FavbhZrf/4ZHnlEBVubKe9Fjo+tRdsLL7yQRx99lN///vd2hhETfr+fefPmaedEcRzLc3/ePDj9dLjxRnOjsfR0mDULZs/WRmMSV3Tel2S1fz/8619w5ZVQiS9Iiei8KI6l3I+CNWvMtXm6dzffB3i95vuDjRvNPS0aN7Y7QsdT3otTWZnzCbWmbUFBAQUFBaHr2dnZAPh8Pnw+HwButxuPx0MgECAYDIaOLWr3+/2UXBHC4/HgdrvLbC963CLeX3eYPPpFKKs9JSWFYDAYehyfz4fL5cLr9RIMBgkEAqFji9rLij3e+hQpdvVJfTq6TyX7cVx92riRlPvuC61ba5xwAsFx4zBuuglvzZp6ndSnuOwTEPd/n/Q6qU/H6lMwCMuXu/jsMw+zZ7v49lsDw9DXHEVExAY//wwTJpifHgaD4HKZnyI+9BCcdJLd0YmIWCqhirYTJ05kwoQJpdrnzJlDrVq1AGjdujXdunVj9erVbNu2LXRMhw4dSE9PZ+nSpezfvz/UnpGRQZs2bVi4cCE5OTmh9t69e9O4cWPmzJkT9gaqf//+1KxZk1mzZoXFMHjwYPLz85k3b16ozev1MmTIEDIzM1myZAkAc+fOJTU1lQEDBrB9+3ZWrVoVOr5Ro0b06dOHDRs2sH79+lB7vPcJUJ/Up3L7VKQqffLm5nLKe+9x0scfg9+P4fGw+YILWDd8OL66dUldskSvk/oUl33atGkTYJ73k6VPyfg6qU+R+1S7dhuefXY9X39dl5UrG5OdXXKnbRdt22Zxxhl7ycjYx8sv92H3bpcKuSIiEj27dsGjj8Jrr0HR39Xf/x4efhg6d7Y3NhGRKLF1I7KSXC7XMTciizTTtlWrVmRmZoYW943XGTJHjhzhyy+/ZMCAAVSrVk2zftQnx/TJ5/Px5ZdfMmjQII5Wbp8CAQKvvYZn/HhcRcWECy4g+PTTBE45xdY+lWxPltdJfbK+T0eOHOHzzz9nwIABpKSkJEWfkvF1Up/MPgUCBitWuPj0Uxdz5rhZutRFyRFi3boGAwfC4MEuzj3XR4sWxbf9979eLr3UvGwYOWgjsthKpI3IfD4fc+bM4fzzzyclJcXucERiRrl/HDIzYdIkeOEFOHLEbDv/fLOA26OHvbFJuZT34lQHDhygYcOGlozNEqpoe7REGqSKSCV8+aW5gcDq1eb19HRzbaoLL7Q3LhGRJHLgAMyZYy4L/tln5lq1JXXpYp52L7wQ+vSB8t5vzZgBo0bBjh3ZqGgbWxoPi0hSysoyx//PPgtF32Q580x47DHo18/e2EREymHl2CyhlkdIZMFgkMzMTBo2bIjbbev+byIxVanc37gR7rwztG4tJ5xgrll1003lVwtE4pDO+xJvgkFYsQI+/dQs1C5darYVSU2F884zi7QXXAAtW1b8sYcNg4suglmzgvzud9bHLslB50VxKuV+JeTlmbNqJ02CgwfNtm7dzGLtBReYa9hKQlDei1OV/Gbb8bL1f87hw4dZtWpVaD22zZs3s2rVqrD11pJFIBBgyZIlYV9jFHGCCuV+VpZZrO3UySzYejxw662wYYP5rwq2koB03pd4cPAgvPMOjBgBzZqZ3yR98EH43//Mgu1pp8Fdd8G8eeY3UN9/H/7858oVbIt4PNCnj/JdyqbzojiVcr8CCgvhxRfNzcTuvtv8A5aeDu+9B8uWmZ8oqmCbUJT34lRW5rytM22XLVtG//79Q9fHjh0LwIgRI3jzzTdtikpEYiYQgL//HR54oPh7uRdcYH4VqmNHe2MTEUlAwSCsXFk8m/abb8Jn09apEz6btlUr+2IVERHB74d//cv8dt2WLWZb27bw0EPwxz+anwiKiDiUrUXbc845hzhZUldEoiEQwLVgAS0WLsRVuzb071888NK6tSIilvjlF3Nt2k8/hdmzYe/e8Ns7dy5em/bMM6FaNXviFBERCQkGza93jBsH69ebbc2amZM5/vxn/bESEUFr2saMy+UiNTUVl77SIU7x66403h076A5mQbZlS/PrTp9/Dh9+aB6ndWslSem8L9ESDMKqVWaR9tNPYcmS0rNpBw4snk3bunXsYlO+S3l0XhSnUu6XYBjmH6/77zf/mAHUrw/33AO33AK1atkanlhHeS9OZWXOu4wEnuqq3XJF4tSMGXDJJeagrCwejzkwGz/eHKiJiEiZfvkF5s4tLtQePZu2UycYPNgs1J51ln0TlDQ2iz39zkUkYSxYAPfdB4sXm9dTU2HsWPNH5y8RSRJWjs000zZGgsEg27dvp1WrVto5UZJbIACjRpVfsK1RA7791vzOrkiS0nlfjodhlJ5NW3JPg9q14dxzzULtBRdAmza2hRrGyt1yJfnovChO5fjc//Zbc2bt3Lnm9Ro1zM2G77oLGja0NzaJGsfnvTiWleNhFW1jJBAIsGrVKpo3b64TliQvw4APPoAdO8o/7sgRc5tykSSm875U1qFDxbNpZ8+G3bvDb+/YMXw2bfXqtoRZLu0QLeXReVGcyrG5/8MP5pq1M2ea11NS4PrrzQJu8+a2hibR59i8F8ezcjysoq2IVM2RI7B2LXz3nbmZWNG/FS3GHl2NEBFxGMMwT51Fs2kXLw6fTVurVvhs2rZtbQtVRESk4jZuhIcegmnTzD92bjf86U/msmjt2tkdnYhIwlDRVkTKZxhmgbVkYfa772DduvDqQhGXq/ylEYo0a2Z9rCIicS4rK3w27a5d4benp5szaQcPhrPPjs/ZtCIiIhFt3w6PPAKvv178PuGSS+Dhh82vi4iISKWoaBsjLpeLRo0aaedEiW+VnT17wgnQtWvxT5cu0KGDOSjbuTNy8dblgpYtzWqESBLTeV/APA1+/z3MmlU8m9bvL769Vi0YMMAs1F54YeJPQFK+S3l0XhSnSvrc37cPJk6El1+GggKz7cIL4dFH4fTT7Y1NbJP0eS9SBitz3mUYFZkSF5+0W65IFRXNnj26OFvW7Fm32yzGFhVmi/5t0cIswh5txgzzU/Wi5ypSdOz06TBsmPX9EhGJA9nZ8PnnZqF29mzzM6ySOnQIn01bo4Y9cUaDxmaxp9+5iNjm0CF4+mmYMgVyc822vn3hscfMxddFRBzIyrGZZtrGSCAQYMOGDbRv3x6Px2N3OOIklZ09W79+6eJsp05Qs2bFn3PYMLMwO2pU+KZkLVuagzoVbMUBdN53DsMw91opmk379dfhs2lr1jRn0xZtIpbos2nLo43IpDw6L4pTJV3u5+bCc8/Bk0+ahVuA7t3NYu1550We1CGOk3R5L1JB2ogsAQWDQdavX89JJ52kE5ZEhxWzZ7t2NXdytWKgNWwYXHQR/nnzWPXpp2RceCHe/v1B+S8OofN+csvOhi++KJ5NW/LzKYBTTimeTdu3b3LNpi1PMBi0OwSJYzovilMlTe4XFMBf/2oWZ/ftM9tOPdVcx3boUBVrJUzS5L1IJVk5HlbRViQR2TF7tio8Hox+/diZm0vXfv1UsBWRhGUY8OOP5kzaWbPgq69Kz6bt3794bdqTTrIvVhEREUv5/fDWWzBhgrnZGMCJJ5rXr7hCY3wRkShR0VYknsXb7FkREQfJyTFn0376qflT9D61SPv24bNpo/05mIiISEwFg/Cf/8CDD8KGDWZbixbm9WuugZQUe+MTEUlyKtrGiNvtpnXr1rjdbrtDkXiVKLNnK0m5L06l3E88hgFr1oTPpvX5im+vUSN8Nu3JJ9sXa7xSvkt5dF4Up0q43DcM+PhjeOAB8/0IQMOGcN99cNNNcfd+Q+JTwuW9iEWszHmXYZTc2j2xaLdcSUiVnT3r8ZizZ0sWZzV7VkTEEocPh8+m3bYt/PaTTy4u0p5zjt6nHovGZrGn37mIWOqLL+D+++Gbb8zraWlwxx3mBsOpqfbGJiKSAKwcm2mmbYwEAgFWr15Nly5dtAi3k1gxe/bUUxN6BxvlvjiVcj8+GYZ5Wi6aTbtoUenZtOecU1yobd/etlATkpW75Ury0XlRnCohcn/JErNYO2+eeb1WLbjtNrjzTvM9ikglJUTei0SBleNhFW1jJBgMsm3bNjp37qwTVjLS7NkyKffFqZT78ePwYfjyy+JC7dGzaU880VyXtmg2ba1atoSZFKzcLVeSj86L4lRxnfvffWcug/Dxx+b1atXgxhvNpRCaNrU3NklocZ33IlFk5XhYRVuRyio5e7ZkkfbAgcjHJ+HsWRGReGYY5mdmJWfTFhYW3169eunZtEn2eZmIiEj5fvrJ3FDs3XfN6243XH212damja2hiYiISUVbkbJo9qyISMLIzQ2fTbt1a/jt7doVz6bt31+zaUVExKG2boWHH4a33ip+TzN8OEyYYL6XERGRuKGibYy43W46dOignRPjVVVnz5YsznbqpNmzESj3xamU+9FlGLB+ffEGYgsWhM+mrVYtfDbtKafo87NYUL5LeXReFKeKi9zfswcefxz++tfiP5i//S088ghkZNgXlyStuMh7ERtYmfMuwzAMyx4txrRbrlSaZs+KiCSs3Fxzf5SiQu3mzeG3t20bPpu2dm1bwnQ0jc1iT79zESnXwYPw1FPw3HOQl2e2DRgAjz4KvXvbG5uISBKycmymmbYx4vf7Wbp0KT179sTr1a89JjR7Ni4o98WplPvHzzDMJfdKzqYtKCi+vVo16Nu3uFDboYM+T7Ob3++3OwSJYzovilPZkvs5OTB1qlmwzc4223r1gsceg3PPjU0M4mg654tTWTke1v+cGDEMg/3795PAE5vjl2bPxjXlvjiVcr9q8vLCZ9P+/HP47W3ahM+mrVPHnjglMuW7lEfnRXGqmOb+kSPw8svmUgiZmWbbaaeZxdrf/lbvdyRmdM4Xp7Iy51W0lcRy5AisWRNenNXsWRGRhLZhg7l52Kefwvz54bNpU1KgX7/itWnT0/V+U0REpBSfD954w9xkbOdOs619e/P6ZZeB1hUVEUk4KtpKfIo0e/a778xdZzR7VkQkoeXlmcXZotm0mzaF3966dfFs2gEDNJtWRESkTIEAvPMOjB9f/Ae1VSvz+ogRoK+li4gkLJ3BY8Tj8ZCRkYHH47E7lPij2bNJTbkvTqXcD7dxY/hs2iNHim9LSYGzzy4u1HbsqM/bEpXyXcqj86I4VVRy3zBg5kwYNw5+/NFsa9wY7r8fbrhB743Edjrni1NZmfMq2saI2+2mTZs2dodhr+OdPVtUpNXs2YSi3Bencnru5+ebG4cVFWo3bgy/vVWr8Nm0qan2xCnWcuvrt1IOp58XxbkszX3DgLlzzeLssmVmW716cNddcNttULu2Nc8jcpx0zhensnI8rKJtjPj9fhYuXEjfvn2dsXOiZs/KrxyX+yK/cmLub9pUXKSdN6/0bNqzziou1HbqpM/fkpGVu+VK8nHieVEELMz9r74yi7ULF5rXa9eG0aPhjjvMwq1IHNE5X5zKyvGw/ufEiGEY5OTkJN/OiZo9K8eQtLkvcgxOyP0jR8Jn027YEH57y5ZmgXbwYDj3XM2mdYJkznc5fk44L4pEcty5v2IFPPCA+ccWoHp1uPlmuPdec0kEkTikc744lZU5r6KtVJxmz4qION7PP5vvGWfNMmfT5ucX3+b1mrNpiwq1p56qz+NERESqbO1aePBBmD7dvO7xwLXXmuvYtmplb2wiIhJ1KtpKaVWdPVuyOKvZsyIiSeHIEfNbmEWF2p9+Cr+9RYvw2bR169oTp4iISNLYvBkmTIB//hOCQfM91RVXmG0nn2x3dCIiEiMq2saIx+Ohd+/e8bdzombPSpTFbe6LRFki5/7mzeGzafPyim/zeuHMM4sLtZ076/M5KZaI+S6xk8jnRZHjUeHc37ULHnsMXnsNfD6zbehQePhhOO20qMcpYiWd88WprMx5FW1jxO1209jO9YYMwxwEHF2c1exZiTLbc1/EJomU+wUF4bNp168Pv715c7NIe+GFMHAgpKXZE6fEPyt3y7XDiy++yFNPPcWePXvo2rUrzz//PD179izz+Pfee49x48axZcsW2rdvz6RJkxg8eHDodsMwGD9+PK+99hqHDh3izDPP5OWXX6Z9+/ahY9q2bcvWrVvDHnfixIncc8891nfQZol0XhSx0jFz/8ABmDQJnn++eBfP886DRx+Fcs5BIvFM53xxKivHwyraxojP52POnDmcf/75pKSkRPfJNHtW4khMc18kjsR77m/ZYhZpP/0UvvgifDatx1M8m/bCC80/Dfq8TirCVzQzLAG9++67jB07lldeeYVevXoxZcoUBg0axPr16yO+6Vy8eDFXXHEFEydO5Le//S3Tpk1j6NChrFixgs6dOwPw5JNP8txzz/HWW2/Rrl07xo0bx6BBg1izZg01SoyzHn74Ya6//vrQ9dQk3bUv3s+LIlERCOCfN4/v58zhtPPPx9u/v/mHFiA7G559Fp55BnJyzLY+fczZtuecY1vIIlbQOV+cysrxsIq2MeT3+619QM2elQRhee6LJIh4yv2CAli0qLhQu3Zt+O3NmoXPpq1Xz5YwRWwzefJkrr/+eq655hoAXnnlFT755BNef/31iLNep06dygUXXMCdd94JwCOPPMLcuXN54YUXeOWVVzAMgylTpvDAAw9w0UUXAfCPf/yDJk2aMHPmTC6//PLQY6WmptK0adMY9NJ+8XReFIm6GTNg1Ci8O3bQDeCpp6BlS3jySdixw5xdWzSxJiPDLNZeeKHem0nS0Dlf5PioaBsLgQCuBQtosXAhrtq1oeSnqxVV2dmzDRqULs5q9qyIiKNs3Ro+mzY3t/g2j8eczFNUqO3aVe8RxbkKCwtZvnw59957b6jN7XYzcOBAlixZEvE+S5YsYezYsWFtgwYNYubMmQBs3ryZPXv2MHDgwNDtaWlp9OrViyVLloQVbZ944gkeeeQRWrduzR/+8AfGjBmD16thukhCmzEDLrnEnGhT0o4d8Ic/FF/v0AEeeQQuvhgSfIkZERGxlkaD0Vbi09XuAJMnm5+uTp0Kw4aVPl6zZ0VEpIoKC8Nn065ZE35706bFRdrzztNsWpEimZmZBAIBmjRpEtbepEkT1q1bF/E+e/bsiXj8nj17QrcXtZV1DMBtt93G6aefTv369Vm8eDH33nsvu3fvZvLkyRGft6CggIKCgtD17OxswPwqXtHX8dxuNx6Ph0AgQDAYDB1b1O73+zFKFJI8Hg9ut7vM9qO/5ldUUD56BlVZ7SkpKQSDwdDj+Hw+XC4XXq+XYDBIoMQYt6i9rNjjrU+RYlef1Ce3YWCMGgWGQVnvyAyPh8DLL+O5+mrwes3YS8QTd31KxtdJfYpqn4qUfN5E71Myvk7qk/V90vIIiaKsT1d37jTbp02DU07R7FlJal6vl/79+2vGkDhOrHJ/27bw2bSHDxff5vFA797FhdqMDH2eJ9Glc33llZyt26VLF6pVq8aNN97IxIkTqV69eqnjJ06cyIQJE0q1z5kzh1q1agHQunVrunXrxurVq9m2bVvomA4dOpCens7SpUvZv39/qD0jI4M2bdqwcOFCcorW1QR69+5N48aNmTNnTtgbqP79+1OzZk1mzZoVFsPgwYPJz89n3rx5oTav18uQIUPIzMwMzVqeO3cuqampDBgwgO3bt7Nq1arQ8Y0aNaJPnz5s2LCB9SV2RYz3PgHqk/pU3KdNm3Dt2EF5XIEA/9u7ly75+YnRp2R8ndSnqPapd+/etG3blrlz5yZNn5LxdVKfrO9TXsnNQo6TyzCOrigmjuzsbNLS0sjKyqJu3bp2hxMuEIC2bc2vv1SWZs9KEjEMA7/fj9frDfvEVSTZRSv3Cwvhq6+KC7U//hh+e9OmcMEFxbNpTzjBsqcWOaasrCzq1asXn2OzchQWFlKrVi2mT5/O0KFDQ+0jRozg0KFDfPjhh6Xu07p1a8aOHcvo0aNDbePHj2fmzJl89913/Pzzz5x00kmsXLmSjIyM0DH9+vUjIyODqVOnRozlxx9/pHPnzqxbt44OHTqUuj3STNtWrVqRmZkZ+p3H6wwZv98fOi+63e64nSFTmT4lyqwf9SmKffL5MLZtw7VsGa5vv8W9fDmu//3PXN7uGPz/+AeeP/4x/vqUjK+T+hTzPnk8HgoKCnC73aGxcKL3KRlfJ/XJ+j5lZWXRqFEjS8bDmg4RLYsWVaxgW7cudO+u2bOStPx+P7NmzWLw4MHaNVQcxcrc3769uEj7+efhs2nd7tKzabUkntglUTccqVatGmeccQZffPFFqGgbDAb54osvGDlyZMT79O7dmy+++CKsaDt37lx69+4NQLt27WjatClffPFFqGibnZ3NN998w80331xmLKtWrcLtdtO4ceOIt1evXj3iDNyUlJRS5xqPx4Mnwj4KZc2ILqu9rHNYZdqL3rTPnTuXwYMHh57L7XbjjnDSKiv2eOtTpNjVpyTv04EDsGwZLF0KS5fi/fZb2Ls3YjzH4m3VKjQpR6+T+pRsffL5fHz22WcRx8KJ2idIvtcJ1Cewtk9WTthR0TZadu+u2HEvvxy+EL2IiDheYSF8/XVxofaHH8Jvb9w4fG3a+vXtiVMkmYwdO5YRI0bQvXt3evbsyZQpU8jNzeWaa64B4KqrrqJFixZMnDgRgFGjRtGvXz+eeeYZhgwZwjvvvMOyZct49dVXAXOmx+jRo3n00Udp37497dq1Y9y4cTRv3jxUGF6yZAnffPMN/fv3JzU1lSVLljBmzBj++Mc/coKmyYvYLy8PVq4MFWj59lvYtKn0cR4PnHYa9Oxp/px+Ovzud+ayeJG+2OpymfucnH129PsgIiIJS0XbaGnWrGLHNW8e3ThERCTmAgFYsMDFwoUtqF3bRf/+5vu58uzYET6btsTSSbjd8JvfFBdqu3XTbFoRqw0fPpz9+/fz4IMPsmfPHjIyMpg9e3ZoI7Ft27aFzeTo06cP06ZN44EHHuC+++6jffv2zJw5k86dO4eOueuuu8jNzeWGG27g0KFDnHXWWcyePZsav36jqnr16rzzzjs89NBDFBQU0K5dO8aMGRO2zq2IxIjfb645VLJA+8MPkTeDbt8eevQwC7Q9ephfc/l1TemQqVPNfUxcrvDCbdEMrClTjj04EBERR9OattFStKbtsT5d3bxZf6wlqfl8Pi2PII4yYwaMGhW+Qk7LluZ7t2HDitt8vvDZtN9/H/44jRoVF2nPP1+zaSUxHDhwgIYNG8bn2CxJxfV4+CgaE0jcMAz4+efwAu2KFZCfX/rYpk2LZ9D26GEubVfRP8qRBgWtWpkF25KDApEkpHO+OJWV42EVbaNpxgzz01WI/Onq9On6Yy1JTxuRiZMUnfaP/stalPp//at5uWg2bXZ2+DG9esHgwWah9vTTNZtWEk+ibkSWyOJ+PFyCxgRimz17zMJsUYH222/h4MHSx6Wmhs+g7dkTWrQ4vs2gAwGMhQsJ7NiBp2VLXH37atKOOILO+eJUVo6HtTxCNA0bZhZmI0250qer4iD5+fmkpqbaHYZIVAUC5uk+0kehRW033BDe3qgRXHBB8WzaBg2iH6eIiJ00JpCoy86G5cuLC7RLl5o7eh6tWjVzWYOSBdpTTrH+E1OPB845h7ycHDP3VbwSB9E5X+T4qGgbbcOGwUUX4Z83j1WffkrGhRfircjihiJJwu/3M2/ePH0tRpLe55+Hfz5Xlo4d4fLLzULtGWdoNq0kF7/fb3cIEsc0JhDLFRTA6tXFxdmlS2HdushfeenYMbxA26WLWbiNAeW+OJHyXpzKyvGwirax4PFg9OvHztxcuvbrp4KtiEgCO3AA1q413xOW/Hfz5ordf9w4uOKK6MYoIiKSdIJBWL8+vED73XdQWFj62Natwwu0p58Ocb58iIiIyNFUtBURETlKMAjbtpUuzK5dC5mZx/fYzZpZE6OIiEjSMgzz6ysllzhYvjx8Mfgi9euHF2h79IAmTWIfs4iIiMVUtI0hr1e/bnEm5b7Eq4IC2LChuCBbVJxdvz7yBtJFWrc2v2XZsSOkp5v/tm9vvlfcuTPyurYul7mk+dlnR68/IiLxTmMCiejgQVi2rHgG7bffmpuHHa1mTXNtoaICbc+e0K5dQqwTq9wXJ1Leixwfl2FEemuZGBJpt1wREbHPoUOlZ8yuWwc//2zOqo0kJcXcj6RkYTY9HTp0gNq1I99nxgy45BLzcsm/rkXvJadP1x6Uktw0Nos9/c4l4eTnw8qV4QXajRtLH+fxwGmnhc+gPfVUUBFIRETimJVjM/3Fi5FgMEhmZiYNGzbErV1nxEGU+xIrRd+kPLowu3Yt7N1b9v3q1i2eNVuyQNuuXeXfFw4bZhZmR40K35SsZUuYMkUFW0l+wbI+BRFBYwJH8vthzZrwAu3330MgUPrYk08OL9B26wa1asU+5ihQ7osTKe/FqawcD6toGyOBQIAlS5YwePBgnbDEUZT7YrXCQti0qXRhdt06yM0t+34tWpQuzKanQ9Om1n6rctgwuOgimDfPz6efruLCCzPo39+rPSjFEQKRCjEiv9KYIMkZhrkrZ8kC7YoVkJdX+tgmTYqXN+jZE7p3N9emTVLKfXEi5b04lZXjYRVtRUQkLmVnm4XYo2fObtpkTtyJxOs1J+ocXZhNT4fU1NjF7vFAv34Gubk76devqwq2IiKSfPbuLd4k7NtvzZ8DB0ofl5pqFmVLbhbWsmVCrEMrIiJiJxVtRUTENoYBu3dHXtJg166y71enTnFRtmSB9qSTzLVoRURExEI5ObB8eXGBdulS2Lat9HHVqkHXruEF2g4dQLPsREREKk1F2xhxuVykpqbi0ifK4jDKfQFzZuzPP0de0iA7u+z7NW0aeUmDFi3if4KOcl+cSPku5dF5MUEUFsLq1eEF2rVrw3fYBPMPcXp6eIG2SxeoXt2euOOYcl+cSHkvTmVlzrsM4+i/volDu+WKiMSXw4dh/frSM2c3bACfL/J93G5zhmykJQ3q1Ytp+CJynDQ2iz39zuW4BIPw00/hBdpVq8zC7dFatQov0J5xhrmbp4iIiIRYOTbTTNsYCQaDbN++nVatWmkRbnEU5X7yMQzYty/ykgbbt5d9v1q1zG9IHj1z9uSTk3NSjnJfnMjK3XIl+ei8aDPDgJ07wwu0y5ZF/srLCSeEF2h79DC//iJVotwXJ1Lei1NZOR5W0TZGAoEAq1atonnz5jphiaMo9xNXIABbtkRe0uCXX8q+X6NGkZc0aNXKWUvaKffFiazcLVeSj86LMfbLL8UbhBUVanfvLn1czZpw+unFBdqePeHEE+N/HaIEotwXJ1Lei1NZOR5W0VZExOHy8yMvafDTT1BQEPk+Lhe0axd5SYMGDWIbv4iIiOPl55vLGixdWlyg3bCh9HEeD3TuHF6gPfVU8OptoYiISLzRX2cREYfIzIy8pMHWraX3FilSvXrkJQ3atzcn5oiIiEiM+f2wZk34DNrvvzfbj3bSSeEF2m7dzPWKREREJO6paBsjLpeLRo0aaedEcRzlfmwFg7BtW3hRtuhyZmbZ96tfP/KSBm3amJNypPKU++JEyncpj86LVWAYsHlzeIF2+XLIyyt9bOPGxcXZnj2he3d9/SVOKPfFiZT34lRW5rzLMMqaXxX/tFuuiDjVkSPmtx6PLsyuX29+Q7IsbdoUF2VLFmgbNtTSdSJy/DQ2iz39zpPMvn3hBdqlS+HAgdLH1aljFmWLCrQ9epiLx+uPuYiIiK2sHJtppm2MBAIBNmzYQPv27fFo2po4iHL/+PzyS+QlDTZvNmfVRlKtmrl8wdGF2VNOgdq1Yxu/kyn3xYm0EZmUR+fFoxw+bM6aLVmg3bq19HEpKdC1a3iBtkMHfRUmgSj3xYmU9+JU2ogsAQWDQdavX89JJ52kE5Y4inL/2AwDduyIvKTB3r1l3y8tLfKSBu3aaT+ReKDcFycKlvVpkggOPy8WFprrzpYs0K5dG/kT2PT08AJt167mIvOSsByd++JYyntxKivHw3pbLyISI4WFsHFj6cLsunWQm1v2/Vq2jLykQZMm+hakiIhI3AkGzTWMShZoV62CgoLSx7ZsGV6gPeMM81NZERERiYUzFAAAK/hJREFUcTwVbUVELJadHXlJg02boKxvSni9cPLJpQuzHTpAamps4xcREZFK2LkzvEC7bBlkZZU+7oQTzMJsUYG2Rw9o1iz28YqIiEhCUNE2RtxuN61bt8btdtsdikhMJWvuGwbs3h15SYNdu8q+X2pq+FIGRf+edJK5ZJ0kj2TNfZHyKN+lPElxXvzlF7MoW3KzsEh/+GvUgNNPLy7Q9uxp/rHXV2QcKSlyX6SSlPfiVFbmvMswDMOyR4sx7ZYrItHm95szZCMtaZCdXfb9mjWLvKRB8+Z6vyYiyUtjs9jT7zyK8vPNZQ1KFmh/+qn0cW43dO4cXqA99VR9GisiIuJAVo7NNNM2RgKBAKtXr6ZLly5ahFscJVFy//BhWL++9JIGGzeCzxf5Pm63OWkm0pIG9erFNHyJQ4mS+yJWsnK3XEk+cX1eDARgzZrwAu3q1eant0c78cTwAm23blC7duxjloQR17kvEiXKe3EqK8fDKtrGSDAYZNu2bXTu3FknLHGUeMp9w4B9+0oXZtetg+3by75frVqRlzQ4+WRt5ixli6fcF4kVK3fLleQTN+dFw4AtW8ILtMuXR94VtHHj8AJt9+7QsGHMQ5bEFje5LxJDyntxKivHwyraikjSCQRg8+bShdm1a+HQobLv17hx5CUNWrY0Z9WKiIhIAtq/P7xAu3QpZGaWPq5OHbMoW1Sg7dkTWrXSukYiIiJiCxVtRSRh5eWZSxocXZj96ScoLIx8H5cL2rUrXZhNT4f69WMbv4iIiFjs8GFYsSK8QLtlS+njUlKga9fiAm2PHuZgQLPBREREJE6oaBsjbrebDh06aOdEcZRAABYudLNxY3cWLnRzzjlVey+UmRl5SYOtW81vOEZSo4a5tuzRhdlTTjFvE4k2nffFiZTvUh7Lz4s+H3z/fXiBds0aiPS1xPT08AJt164aEEjMaEwgTqS8F6eyMudVtI0Rj8dDenq63WGIxMyMGTBqFOzY4QFaAOYyA1OnwrBhpY8PBs0ibKQlDQ4cKPt56tcvnjVbskDburUmy4i9dN4XJ9KadVKmQADPokWk794Ne/bA2WdX7g91MGjuDlqyQLtyJRQUlD62ZcvwAm337pCWZl1fRCpJYwJxIuW9OJWV42EVbWPE7/ezdOlSevbsiderX7sktxkz4JJLSs+C3bnTbH/mGWjRIrwwu349HDlS9mO2bRt5M7BGjaLaFZEq03lfnMjv99sdgsSj4k9yi9vK+yQXYNeu8ALtt99CVlbp4+rVCy/Q9ugBzZtHpRsiVaUxgTiR8l6cysrxsP7nxIhhGOzfvx+jrO9yiyQgwzDXlT18uPgnOxtuvDHysgVFbWPHRn68atXM5QuOLsx26AC1akWvHyLRoPO+OJHyXUo51ie506fDgAGwbFl4gXbnztKPVaMGdOtWXKDt2RNOPlkbhUnc05hAnEh5L05lZc6raCviAIZhbsxVsrh6+DDk5pZuK6s9UlteXtlryh5Lp07Qq1d4gbZtW9CHsCIiIkkiEDBn2Jb3Se7w4RBpRorbDaeeGl6g7dzZ3EBMRERExAFUHhGJM35/5QqnFW0LBKIbd5065k8wCPv2Hfv4Bx6AK66IbkwiIiJio0WLwpdEiKSoYNuuXXiB9vTToXbt6McoIiIiEqdUtI0Rj8dDRkaGNuhIIsFg8dIAx1NMPbo90n4aVqpRo7jAWrt28eXy2o51bM2a5oQYgPnzoX//Y8fRrFlUuyliO533xYmU7xJm9+6KHffXv8INN0Q3FhEbaUwgTqS8F6fSRmQJyO1206ZNG7vDcCTDMAuhVi0JUNSWmxvduL3e4y+mHt1Wu3b0lx84+2xzb5GdOyN/G9LlMm8/++zoxiFiN533xYncRZ/giUDFP6E95ZToxiFiM40JxImU9+JUVo6HVbSNgUAA5s8PMG/eOvr3T+ecczzow6bIfL7ShVIrlgmI5tIALlflCqcVbatWLTH31fB4zM2gL7nEjL9k4baoP1OmoP8DkvT8fj8LFy6kb9++2jFXHMPK3XIlCeiTXBFAYwJxJuW9OJWV42H9z4myGTPM/Rd27PAAp/LYY+bYdOpUGDbM7uiqruTSAFZubFVYGN24a9a0vsBas2ZiFlejadgwczNoM/eL21u2NAu2iZz7IhVlGAY5OTnaMVccRfkuYfRJrgigMYE4k/JenMrKnFfRNopmzDDHqEe/Xjt3mu3Tp0e/eGUYcOSI9Rtb5eVFN26vF1JTrS2w1q6t9wSxNGwYXHQRzJvn59NPV3HhhRn07+/VayAiIuIk+iRXREREHCIQDPDVtq8sezwVbaMkEDDHppEK7IZhTi4YPdosahUVsXw+a5cEKPoJBqPXT5fLmtmqR7dXqxa9mCV2PB7o188gN3cn/fp1VcFWRETEiX79JNc/bx6rPv2UjAsvxNu/vz5NFxERkaQxY+0MRs0exY59O459cAWpaBslixaFTyY4mmHA9u3m/gyBQOyWBrByUystDSAV4fF46N27t3YNFcdR7osTKd+lTB4P7gEDaN2lC+6GDUGb1omDaEwgTqS8FyeZsXYGl/znEoygC7aeCXxtyeOqaBslu3dX7Lj9+0u3paRYv6mVlgYQu7jdbho3bmx3GCIxp9wXJ7Jyt1xJPjovilMp98WJlPfiFIFggFGzR2GsGQqzp0J2GpBmyWOraBslzZpV7LiXX4Z+/cILrFoaQJKJz+djzpw5nH/++aSkpNgdjkjMKPfFiXw+n90hSBzTeVGcSrkvTqS8l2QRNILkFORw6MihiD8r96xkx/96wH+m/3qPw5Y9t4q2UXL22eb+Cjt3Rl7X1uUyb7/+es2AleTn9/vtDkHEFsp9EZFwOi+KUyn3xYmU9xIPgkaQ7ILsiAXXrCNZxdcLIhdls45kYRChsBd6AjfM3vLrFWu/daaibZR4PDB1KlxyiVmgLVm4LVoDdsoUFWxFREREREREREQiKa/oWpGfrPzD4KsBvlrgq/3rv2X9tC3dVmjex+WvgzeQitufistvthmFNfDl1yDoqx6VvqtoG0XDhsH06TBqVPimZC1bmgXbYcNsC01ERERERERERCSqAsFAxKLrL/mHOHA4m8ysPDKz8jiYc4RfsgvJyvGRddhHzuEAuXkG+Xku8NWMXGQtbAq+E8svxAZqWNIPA4j1QmAuw4j05f3EkJ2dTVpaGllZWdStW9fucMoUCMDChQabN+fTrl1N+vZ1aYatOIZhGOTk5JCamoqraJq5iAMo98WJsrKyqFevXtyPzZJJwoyHgwEWbl3I5v2badeoHX3b9MXj1oBYnEFjAnEi5X1iCwbhyBHIy4OcwwH2/JLDvkOHyczKJTMrn4PZR/gl+wiHcnxkH/aTfdjP4bwgubkG+fkujuS5OJLvxXfEi7+gWtkFVSO2c0lr1jT3kqpVq+I/xzp+9Wq48sqSz5INWDM200zbGPB44Jxz4KyzUvB6i5dHEHGKmjVr2h2CiC2U+yIiMGPtDEbNHsWO7OKvnrWs25KpF0xlWEd99UycQWMCcRJz4hrs2FGLli2hb18tDWklv98sppb8yc0t3Vb0czg3wC85hRzKLvx1Bquf7MNBcvMM8vIwC6xHPBQe8eI7kkKgsBrBwpKzUz1AvV9/osPlDlCtpp/qNQLUrBmkZi2D2rVc1K7tJrW2h7qpXlJreypdUC35U6MGuK1dchaAjh3h7rthx04DDGsLfiraxojf72fWrFkMHjxYOyeKoyj3xamU++JE2nBEjjZj7Qwu+c8lpTbw2Jm9k0v+cwnTL5uuwq0kPY0JxElmzChaItJFUcmpZUtzz59kXyLSMKCgoHIF1Ug/pY83OJxrkJdnzmL1+ypbefQANX/9qQLPEaiWi6taPu5qBaRULySlho9qNQLUqBGgZi2DWrVc1Kntok5tD3XreKlXN4UTUqtTv24NGtStScO0WqSlppRZUE1J8eByJWZlv3hPKxe4DKxcz0BFWxERERERsVwgGGDU7FERd1wuarv+o+vJLcylurc61TzVSHGnkOJJCV2u5qlGiielzMtFx3ndXn39VkTEZjNmmJuxH1202rnTbJ8+3b7CbSAA+fnRKKiG/0RnAVLXrz9HC0K1XEjJq9CPt7rv1wKrQe3aLnMGax0PaXWqUS+1GifUrR4qsDaqV5sm9VJpUq8uDWrXI61GGtU8DaLRuaRQvKeVK2xPq+Oloq2IiIjFAsEAC7YuYOEvC6m9tTb9T+yvtRsl6QWCAb7a9pXdYUgcWbRtUfGSCEE3bD0bDjeDOruhzSJwBzmYf5CrZl5lyfMdXfAt6/IxC8HuihWKq1JcLuuy/kYkL40JxCkCAXOGbaSipWGYy0SOHg0XXRS+VIJhgM9XtQJpZY4vKIjZrwIAl8dnFlS9eRgpFS2sln9czVqQVieFtNQU6tetwQl1anFCzXrUq1HWTxvq1ahHWvW0X4uu1WL7S3CYYcPM/P74Yz9Dh1rzmCraxoD+UItTKffFiY5eu3Hy1slau1GSXijv91k4tUAS3u6c3eaFNb+H2VMhu1XxjXW3wwWjoNMHdG7cmQY1G1AYKMQX9OEL+CJeLgwU4gv48AV9+IOll+LwBc3b8nx5MeqhddwutzWF4EoUiqtSXC4rJs1yjmzG2hncNmsMO79vB4ebMXn+o7Q47TqeG/ysxgQSE4GAWRAt+vH7w69b+bNpE+XOMDQM2L4dTj7ZLOCWLKgGArH7nQBUrxmgWg0/KdUL8VYvxF3tCK6UfIyUPIyUwwQ8h/F7svF5DlHgOoTfk3XMgmr4Tz6Gp/TfqTrV6kQurlY3Z7LWq9G4zAJsWvU0UjxaXiXeeTxw1lnWTbd2GUZ0Jm/HQiLslquNF8SplPviRGWt3ej69etMWrtRklFY3h8BniCux2bJJp7Hw/O3zKf/3c/Bf6b/2lJyDb6g+c9llzBv0m2c0/acSj22YRjlFnUjXT5WIbii9znm7cd4/MJAYcQlIxKZ1+21rhBsQfG6MvfxuDxRKTrPWDuDix96G2ZPifCBxWjef+hKjQniUNGsz0g/0Sx4Rus5Eq3a4/FE3liqZs0gKTV8eKoX4qlWgKtaPqTkYXhzCXpz8Huz8buzKfQc4ojrIEc4SB6ZHGYfh419HOFAcUHVewTcVfvFlFd0LXu2q/lTt3pdFV0dIisri3r16lkyNlPRNor05l2cSrkvThQIBmg7tW3YBxUluXDRsm5LNo/arBnnUWYYBgZG2L9BI1iq7Vj/6j7Hvk8gGODuz+/mlyO/mL98FW1jLp7Hw4W+ALUa7SWQ1ZTwgm2RIJ56u8nb15RqKc46LwaCgcoViq0oTlv0mL6gz+5fn6VcuKwpBJdYVsPr9vLKP/dxZNo/f32W0h9Y1LpyBBP+0hW3y40LFy6XKzRWLrpcVEyuzO1VuU9lby/ZhuHCCLrx+10E/R4Cfnepn2DAg9/nIhgobvP7XaWP9bnwB9wEfEcfY268FLruc+EP3V7ytl+v+3+93ffrZZ/Z7vMVHW+2+3wu/H7Cjwsk/6xxrxdSUqr+U9b9d+8O8v77x94g69Zxmzmpyz4K3YcocP/CEZdZYM019pMTOMChI4dK/eT78y3pe2q11GMWV8srunrd+rK6HJuVRdu4yLgXX3yRp556ij179tC1a1eef/55evbsaXdYx+VYGy+4cDF69mgu6nCRrW/ei2r2RXGWdb0ixyTK9XiIIZn74A/6ufHjG8vMfYAbP74RFy48bk+lcq6ibeX9nsqLPVqPr+dMrucs67F25ewqs2BbdOz27O0MnjaYJrWbEG/FNyvuEzSClXr8aNxH5HhUdkz63nvvMW7cOLZs2UL79u2ZNGkSgwcPDt1uGAbjx4/ntdde49ChQ5x55pm8/PLLtG/fPnTMwYMHufXWW/noo49wu91cfPHFTJ06lTp16kS1r7Gw+GsPgazm5RzhJnCoBal1zCKA221+ZdblKr5cXls0bovd83hwuz24XDWi8jzV3VCjqn11g8tT9vOAgeEKEDT8BAkQNAIE8RH49XIAX+jfgOEvvhz0E3T58Qd9BDH/DeAnYPjwGyXaDB9+fKHLAcOP3yjEb/jMf4O+sMsBfPiC5mWfUWDeFiw02wzzX1+wED+FFAYKMPCDywBX8Ne+GBS6ghRikOvyg8sHrsOAEXYcrqB5PdReTmoH3fDxllCeH533ECTvo8e5s+WpgAcCKRBMOca/3goccxz/Br1VvK8D1sh0BcDtA4+vcv+6/ZW/T1X/LfO5/Li8fnD7cHn8uDwBXC4XfiDgclFQhUJ9WcX9gmZ+mPsdZLegrA/qqLuD510nw4/BKr0UdavXPcbyAiq6iv38/tJLY1SV7Rn77rvvMnbsWF555RV69erFlClTGDRoEOvXr6dx48Z2h1dlYRsvRFD05r3uE+aJw46in4hdMvMyGfYfzbQVZ5qzaY7dIUgVuXCZM6KOehNT3r9Hz6A61r+J9vi7cnaxcs9Ku18aS1R2TLp48WKuuOIKJk6cyG9/+1umTZvG0KFDWbFiBZ07dwbgySef5LnnnuOtt96iXbt2jBs3jkGDBrFmzRpq1KgBwJVXXsnu3buZO3cuPp+Pa665hhtuuIFp06bFtP/RsHt3xY4rLDR/JFG4MN9G2v5W0nYud/DXArhZyHUBuIMEgxD0lVfMdJtLJkzKjlGksefymEVCPP7Q5dCPO7wdt7/4WLcvdFvo318LkuaxZjHSVVSULNFWVJg0rxdilLhueApLFTcNd2HovkWXDU8hLo8fw1WI4SnEcBXicpf+oBgo9aFxpNvjgVHyXwMC0QzrglG/LokTJOKSOBeMpkHtE2hRt0WVlhfQt9XEaWxfHqFXr1706NGDF154AYBgMEirVq249dZbueeee8q9bzx/Hezf3/+bP8z4g91hJLWwr+VU4npV7mP19WSO4WD+QbYc2sKxnHTCSTSs1bDcxz2etvL6acVzVvbxE+45E+i1sOM5Iz3Wll+28MZ3b3AsN5x+AyfXPznuinNVuY/dBcNYPr5ENn/LfPq/1b+4IYGXR6jsmHT48OHk5uby8ccfh9p+85vfkJGRwSuvvIJhGDRv3pzbb7+dO+64AzB/L02aNOHNN9/k8ssvZ+3atXTq1Ilvv/2W7t27AzB79mwGDx7Mjh07aN68vFmqpngeD8+fD/37H/Mw/v1v6NULDAOCQfPfkpeP/le36bbjvc3ed8CReTxV/1p6NH6O57k8HtCfzmKRiryVKfyWdXtV7hPNOJbuWMp1H11XxuaT2+CC0dDpA+aNmFfpdcxFEsmBAwdo2LBh4i+PUFhYyPLly7n33ntDbW63m4EDB7JkyZJSxxcUFFBQUBC6np1tfjLp8/nw+Xyh+3s8HgKBAMFg8ZT7ona/30/JOrXH48HtdpfZXvS4Rbxe81d29HTno9sb1WxUod/BG//3Br1a9ArF6PV4CQaDBIPB0JtEj9uDx+MhGAgSNIKhYkFReyAQwDCM0PFejxe3203AH8DAKD7e48Hj9uDz+cIKD16vF5fLhd/vDytMeL1eXPzaXuL4lJQUDMMgEAiEjne73Xi9Xozgr+2/Hu92me1FfSp6jKLjA4EARtAIPWfR76Dk79eFy4y9nNevqq9TkZSUFILBIIES21YW/Q7Kai8rx+zOPbv7tHDbwvA38GV4ZfArnHvSuQnRp2R8ndQn6/sUCAaY8/McduXsijizwoW5pu3zFzwfVvSN5z4d3Z4Mr1O5fTJ+jd1t9qnk8QnbJ6L7Ov2m2W9oWbclO7N3xs2Moqqo7JgUYMmSJYwdOzasbdCgQcycOROAzZs3s2fPHgYOHBi6PS0tjV69erFkyRIuv/xylixZQr169UIFW4CBAwfidrv55ptv+P3vf1/qeRNpPPyb30CLFl527XJFLJK5XNCypcHQoX48nqI2h5xv1Cfb+xQIBPH7AyUKuS48Hi8+X4BAIBgq8rpcblwuD35/AL8/WKLo68bt9lBY6CcYNEKP43Z7WLzE4E9/PPaMwJkfFnLBoBS8XggGk+N1Mm/2YhjKvXL7FDTbiz50DvXJ+DV2V+L16eS0k3lw/oPs6jQTI/1D2Ho2HG4GdXZDm0W43AYt6rbk7NZnJ0yfkjL31Keo9+no246HrUXbzMxMAoEATZo0CWtv0qQJ69atK3X8xIkTmTBhQqn2OXPmUKtWLQBat25Nt27dWL16Ndu2bQsd06FDB9LT01m6dCn79+8PtWdkZNCmTRsWLlxITk5OqL137940btyYOXPmhCVM//79qVmzJrNmzQqLYfDgweTn5zNv3jwAAkaAhikNOeA7UOab96a1mlJ3W11+2v4TAI0aNaJPnz6sW7eO9evXh44t6tPKlSsj9mnx4sWl+tSyTUu+/PLLiH365JNPIvfp88h9WjRvUajN6/UyZMgQ9u3bF/YmJjU1lQEDBrB161ZWrVoVaj9mn1ZH7tM3y76JyetUsk+ZmZkR+7R9+/aIfdqwYUPEPtmde3b36ezWZ9OoWiP2Fxa3H61hSkOyf8gmp3FOQvQpGV8n9Sk6ffpjgz8yKWcSLlwRz/1TLpjCjz/8mFB9guR7ndQna/s0se9Ervr4KhJZZcekAHv27Il4/J49e0K3F7WVd8zRSy94vV7q168fOuZoiTQeBhgxoiUTJ56By2VgGCWn3hmAi/vv389nnxXnt/5vqk+x6tOOHZH7tGVL5D6tXBm5T4sXl+7TFZe34Zbbssg+mEpZa3um1s8m4F9AQUF/3G69TupTcvSpaCyM24B2C8Ie1wBGnTIKj9tTZm0jHvtUJJleJ/Upun3Ky8vDKrYuj7Br1y5atGjB4sWL6d27d6j9rrvuYsGCBXzzzTdhx0eaWdCqVSsyMzNDU47j6ROBD9Z9wOUzLgfC17MpmmH1n0v+w0WnXFTcHgefCFSkPdE+5VCfYt+n9354j+HvDwci5/47w97h9+m/T6g+JePrpD5Fp08z189k7NyxYeuat6zbkmfPf5ZLTr0kIfuUjK+T+mRtn2asncGo2aPYuX9nQi6PUNkxKUC1atV46623uOKKK0JtL730EhMmTGDv3r0sXryYM888k127dtGsWbPQMZdddhkul4t3332Xxx9/nLfeeivsTQVA48aNmTBhAjfffHOp50208TDARx+lMGqUwY4dxUXbli0Npk51MXRo/ORxZfqUKP831Scbx8PvBbhsuPvXqadHre3pcvHuOwF+/3sjofqUjK+T+mR9nz5Y9wG3z72dHTnhY+FnBj7DxZ0uTsg+lWxPltdJfYpenw4dOkTjxo0Tf3mEhg0b4vF42Lt3b1j73r17adq0aanjq1evTvXq1Uu1p6SkkJKSEtbm8Zhfpz9a0S+you1HP25l2i877TK8Xi+jZo8q9eZ9ygVTGNYx8kZMZcUeD30q4na7cbtLf2pcVrv65Kw+Xdr5UjweT4VzPxH6lIyvk/oUnT5d2vlShnUaxryf5/HpV59y4VkX0v/E/qGNExKxT0WS6XUqoj5Z06eLO13M0PShfPzdxwx9YmjE+8Wzyo5JAZo2bVru8UX/7t27N6xou3fvXjIyMkLH7Nu3L+wx/H4/Bw8eLPN5E208DDBsGFx0kYt58/x8+ukqLrwwg/79vZihxU8eV7Y9Ef5vVrZdfbJwPHyph/c9MGoU7CgeDtOylYupU1wMGxZ+v0ToUzK+TuqT9X267LTLuPjUi8scC0Pi9amkZHmdSlKfrO1TpHiqyrpHqoJq1apxxhln8MUXX4TagsEgX3zxRdgsh0Q2rOMwtozawtwr5zK2zVjmXjmXzaM2l1mwFUkWyn1xMo/bQ782/eh7Ql/6temnnW7FETxuD2e1PsvuMKqkKmPS3r17hx0PMHfu3NDx7dq1o2nTpmHHZGdn880334SO6d27N4cOHWL58uWhY7788kuCwSC9evWyrH/xwOOBfv0M+vbdSb9+BhHeL4kknWHDYMsWF3Pn+hk7dhlz5/rZstnFMA2HJclpLCxiDVtn2gKMHTuWESNG0L17d3r27MmUKVPIzc3lmmuusTs0yxSdsHJ/zNUJSxxFuS8iIoniWGPSq666ihYtWjBx4kQARo0aRb9+/XjmmWcYMmQI77zzDsuWLePVV18FzK/njR49mkcffZT27dvTrl07xo0bR/PmzRk6dCgAHTt25IILLuD666/nlVdewefzMXLkSC6//HKaN29uy+9BRKxV9IFFbu5O+vXrqg8sRESkwmwv2g4fPpz9+/fz4IMPsmfPHjIyMpg9e3apTRsSncvlIjU1FZfLdeyDRZKIcl+cSrkvTpTI+X6sMem2bdvCvu7Wp08fpk2bxgMPPMB9991H+/btmTlzJp07dw4dc9ddd5Gb+//t3XtsFPUaxvFnl16Fbptyb+iCF6AXraQETIUgBKQ0WGtKVAwmFarEiNZqgmAMAgGkBKNoiEgUpbEloCRLJEZIIRZKVFK1rWi4CCKUFNN4220LJbi75w8jJz3l9Hh0dmbg9/381+1Q3908Th7eDjNdWrhwoX777TdNnjxZe/bsUVJS0pVjamtr9eSTT2r69Onyer2aM2eOXn/9dfveuI04L8JUZB8mIvcwlZWZd/RBZP9UKBRSamrqNfewCwAAgOsR3cx+fOYAAADuYWU3c/SetiaJRCI6c+ZMj6fUASYg+zAV2YeJyDv6wnkRpiL7MBG5h6mszDxLW5uEw2E1NzcrHA47PQpgK7IPU5F9mIi8oy+cF2Eqsg8TkXuYysrMs7QFAAAAAAAAABdhaQsAAAAAAAAALsLS1iYej0eDBw/myYkwDtmHqcg+TETe0RfOizAV2YeJyD1MZWXmPdFoNGrZT7MZT8sFAABwD7qZ/fjMAQAA3MPKbsaVtjYJh8M6duwYN+GGccg+TEX2YSLyjr5wXoSpyD5MRO5hKh5Edg2KRCI6fvy4IpGI06MAtiL7MBXZh4nIO/rCeRGmIvswEbmHqazMPEtbAAAAAAAAAHARlrYAAAAAAAAA4CIsbW3i9Xrl9/vl9fKRwyxkH6Yi+zAReUdfOC/CVGQfJiL3MJWVmfdEo9GoZT/NZjwtFwAAwD3oZvbjMwcAAHAPK7sZv/KwSTgcVlNTE09OhHHIPkxF9mEi8o6+cF6Eqcg+TETuYSorM8/S1iaRSERnz57lyYkwDtmHqcg+TETe0RfOizAV2YeJyD1MZWXmWdoCAAAAAAAAgIvEOT3AP/Hn7XhDoZDDk/xvly9f1oULFxQKhRQfH+/0OIBtyD5MRfZhoo6ODkn/7miIPfow4H5kHyYi9zCVlX34ml7a/vlBZGZmOjwJAAAA/vTzzz8rNTXV6TGMQB8GAABwHyv6sCd6DV8KEYlE1NbWppSUFHk8HqfH6VMoFFJmZqZaW1t5si+MQvZhKrIPEwWDQfn9fv36669KS0tzehwj0IcB9yP7MBG5h6ms7MPX9JW2Xq9XI0aMcHqM/4vP5+OEBSORfZiK7MNEXi+PTbALfRi4dpB9mIjcw1RW9GEaNQAAAAAAAAC4CEtbAAAAAAAAAHARlrY2SUxM1PLly5WYmOj0KICtyD5MRfZhInKPvpAPmIrsw0TkHqayMvvX9IPIAAAAAAAAAOB6w5W2AAAAAAAAAOAiLG0BAAAAAAAAwEVY2gIAAAAAAACAi7C0tVFVVZU8Ho8qKyudHgWIqXA4rGXLlunGG29UcnKybr75Zq1atUrcQhvXm4MHD6q4uFgZGRnyeDzatWtXr2OOHj2qe++9V6mpqerfv78mTJigs2fP2j8sYKFNmzYpLy9PPp9PPp9PBQUF+vjjjyVJv/zyi5566imNHTtWycnJ8vv9qqioUDAYdHhquAF9GKagD8MU9GGYyo4+HBeLwdFbY2OjNm/erLy8PKdHAWJu3bp12rRpk6qrq5Wbm6svvvhC8+fPV2pqqioqKpweD7BMV1eXbr/9di1YsEClpaW9vn/q1ClNnjxZ5eXlWrlypXw+n7799lslJSU5MC1gnREjRqiqqkqjR49WNBpVdXW1SkpK1NTUpGg0qra2Nr388svKycnRmTNn9Pjjj6utrU07d+50enQ4iD4Mk9CHYQr6MExlRx/2RPlVX8x1dnYqPz9fb7zxhlavXq1x48Zpw4YNTo8FxMw999yjoUOHasuWLVdemzNnjpKTk1VTU+PgZEDseDweBQIB3XfffVdemzt3ruLj4/Xee+85Nxhgk/T0dK1fv17l5eW9vvfBBx/o4YcfVldXl+LiuGbARPRhmIY+DBPRh2E6q/swt0ewwaJFizR79mzNmDHD6VEAW9x5553av3+/Tpw4IUlqaWnRoUOHVFRU5PBkgH0ikYg++ugjjRkzRoWFhRoyZIjuuOOOq/6TMeBaFg6HtX37dnV1damgoOCqxwSDQfl8Pha2BqMPwzT0YYA+DHPEqg/TnGNs+/bt+uqrr9TY2Oj0KIBtli5dqlAopKysLPXr10/hcFhr1qzRvHnznB4NsE17e7s6OztVVVWl1atXa926ddqzZ49KS0v1ySef6K677nJ6ROAfOXLkiAoKCtTd3a0BAwYoEAgoJyen13E//fSTVq1apYULFzowJdyAPgwT0YcB+jCuf7HuwyxtY6i1tVVPP/206urquF8LjPL++++rtrZW27ZtU25urpqbm1VZWamMjAyVlZU5PR5gi0gkIkkqKSnRM888I0kaN26cPv30U7355puUVFzzxo4dq+bmZgWDQe3cuVNlZWU6cOBAj6IaCoU0e/Zs5eTkaMWKFc4NC8fQh2Eq+jBAH8b1L9Z9mKVtDH355Zdqb29Xfn7+ldfC4bAOHjyojRs36tKlS+rXr5+DEwKxsXjxYi1dulRz586VJN122206c+aM1q5dS0mFMQYNGqS4uLhev2nNzs7WoUOHHJoKsE5CQoJuueUWSdL48ePV2Nio1157TZs3b5YkdXR0aNasWUpJSVEgEFB8fLyT48Ih9GGYij4M0Idx/Yt1H2ZpG0PTp0/XkSNHerw2f/58ZWVlacmSJRRUXLcuXLggr7fnLbP79et35TetgAkSEhI0YcIEHT9+vMfrJ06c0MiRIx2aCoidSCSiS5cuSfrjioLCwkIlJibqww8/5ApLg9GHYSr6MEAfhnms7sMsbWMoJSVFt956a4/X+vfvr4EDB/Z6HbieFBcXa82aNfL7/crNzVVTU5NeeeUVLViwwOnRAEt1dnbq5MmTV74+ffq0mpublZ6eLr/fr8WLF+vBBx/UlClTNG3aNO3Zs0e7d+9WfX29c0MDFnj++edVVFQkv9+vjo4Obdu2TfX19dq7d69CoZBmzpypCxcuqKamRqFQSKFQSJI0ePBglnSGoQ/DVPRhmII+DFPZ0Yc90Wg0Gss3gZ6mTp2qcePGacOGDU6PAsRMR0eHli1bpkAgoPb2dmVkZOihhx7Siy++qISEBKfHAyxTX1+vadOm9Xq9rKxMW7dulSS98847Wrt2rc6dO6exY8dq5cqVKikpsXlSwFrl5eXav3+/zp8/r9TUVOXl5WnJkiW6++67/+v/F9Iff5EbNWqUvcPCdejDMAF9GKagD8NUdvRhlrYAAAAAAAAA4CLe/30IAAAAAAAAAMAuLG0BAAAAAAAAwEVY2gIAAAAAAACAi7C0BQAAAAAAAAAXYWkLAAAAAAAAAC7C0hYAAAAAAAAAXISlLQAAAAAAAAC4CEtbAAAAAAAAAHARlrYA4KBRo0Zpw4YNfR7j8Xi0a9cuW+YBAAAA7EQfBoCri3N6AAAwWWNjo/r37+/0GAAAAIAj6MMAcHUsbQHAQYMHD3Z6BAAAAMAx9GEAuDpujwAA/9DUqVNVUVGh5557Tunp6Ro2bJhWrFghSYpGo1qxYoX8fr8SExOVkZGhioqKK3/2P/852HfffacpU6YoKSlJOTk5qqur6/Xfa21t1QMPPKC0tDSlp6erpKREP/zwQ4zfJQAAAHB19GEAsB5X2gKABaqrq/Xss8/q8OHD+uyzz/TII49o0qRJCgaDevXVV7V9+3bl5ubqxx9/VEtLy1V/RiQSUWlpqYYOHarDhw8rGAyqsrKyxzGXL19WYWGhCgoK1NDQoLi4OK1evVqzZs3S119/rYSEBBveLQAAANATfRgArMXSFgAskJeXp+XLl0uSRo8erY0bN2r//v0aMmSIhg0bphkzZig+Pl5+v18TJ0686s/Yt2+fjh07pr179yojI0OS9NJLL6moqOjKMTt27FAkEtHbb78tj8cjSXr33XeVlpam+vp6zZw5M8bvFAAAAOiNPgwA1uL2CABggby8vB5fDx8+XO3t7br//vt18eJF3XTTTXrssccUCAT0+++/X/VnHD16VJmZmVcKqiQVFBT0OKalpUUnT55USkqKBgwYoAEDBig9PV3d3d06deqU9W8MAAAA+AvowwBgLa60BQALxMfH9/ja4/EoEokoMzNTx48f1759+1RXV6cnnnhC69ev14EDB3r9mb+is7NT48ePV21tba/v8RAHAAAAOIU+DADWYmkLADGWnJys4uJiFRcXa9GiRcrKytKRI0eUn5/f47js7Gy1trbq/PnzGj58uCTp888/73FMfn6+duzYoSFDhsjn89n2HgAAAIC/iz4MAP8/bo8AADG0detWbdmyRd98842+//571dTUKDk5WSNHjux17IwZMzRmzBiVlZWppaVFDQ0NeuGFF3ocM2/ePA0aNEglJSVqaGjQ6dOnVV9fr4qKCp07d86utwUAAAD8JfRhAPh7WNoCQAylpaXprbfe0qRJk5SXl6d9+/Zp9+7dGjhwYK9jvV6vAoGALl68qIkTJ+rRRx/VmjVrehxzww036ODBg/L7/SotLVV2drbKy8vV3d3NlQYAAABwHfowAPw9nmg0GnV6CAAAAAAAAADAH7jSFgAAAAAAAABchKUtAAAAAAAAALgIS1sAAAAAAAAAcBGWtgAAAAAAAADgIixtAQAAAAAAAMBFWNoCAAAAAAAAgIuwtAUAAAAAAAAAF2FpCwAAAAAAAAAuwtIWAAAAAAAAAFyEpS0AAAAAAAAAuAhLWwAAAAAAAABwEZa2AAAAAAAAAOAi/wKk8yXPW2G5jwAAAABJRU5ErkJggg==", - "text/plain": [ - "
" + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "IN_COLAB = 'google.colab' in sys.modules\n", + "\n", + "# Install s2fft and data if running on google colab.\n", + "if IN_COLAB:\n", + " !pip install s2fft &> /dev/null" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAABW0AAAKzCAYAAABlBC9iAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeZyN5f/H8dc5s4/ZrEmMQYWyjIhQWcvWIqGvFksUSkihRdaKkq0ktKBVkVRIlJEsWbJEIcm+ZIsZ28ycc+7fH+c3J8fMMDgz9zlzv5+Px3l0z3Xuc5/Pdc7H6TPXXOe6bIZhGIiIiIiIiIiIiIiIX7CbHYCIiIiIiIiIiIiI/EeDtiIiIiIiIiIiIiJ+RIO2IiIiIiIiIiIiIn5Eg7YiIiIiIiIiIiIifkSDtiIiIiIiIiIiIiJ+RIO2IiIiIiIiIiIiIn5Eg7YiIiIiIiIiIiIifkSDtiIiIiIiIiIiIiJ+RIO2IiIiIiIiIiIiIn5Eg7YiIiIi50hISMBms2W6RUVFUbVqVZ5//nmOHj3qFzHu3LnT1Djy0s6dO7HZbCQkJFzS47J7P8+9jR07NtPzXOy2fv36HF37/NvF4r/U69lsNurXrw9A/fr1sdlsLF68+JJeIxERERHxP8FmByAiIiLij+rWrcu1114LgMvlYv/+/SxfvpwRI0bw4Ycf8vPPP1O2bFmTo5ScOvf9PN8NN9yQZfv9999PVFRUlvcVKlSI1q1bc+TIEa/2kydP8uWXX2b7+CJFilwwzg4dOmRqO3jwIN9//32291eoUOGC1xQRERGRwGMzDMMwOwgRERERf5GQkMCuXbuYMmUKHTt29Lrv4MGD1KtXjz///JP777+fmTNnmhrjjh07LnnmaaDauXMnZcqUoXTp0pc0w/hC7+eFnge4rNf3Sh+flcWLF9OgQQMALlS67969m9OnTxMfH09kZOQVP6+IiIiImEfLI4iIiIjkUPHixenbty8AP/74o8nRiHiLj4+nQoUKGrAVERERyQc0aCsiIiJyCYoXLw6Aw+HIdN+uXbt47bXXaNiwIfHx8YSFhREXF8ett97KpEmTcLlc2V7333//ZejQodSoUYPY2FgiIiIoW7Ysbdu25bvvvstxfC+//DI2m41SpUqxceNGDMOgSJEi2O32TGvxrlq1yrMu6oQJEzJdq2zZsthsNv7+++8r6uO569E6nU5Gjx5NtWrViIqKwmazeZ07Z84c6tWrR3R0NLGxsdx22218/fXXOe6/lWW3pm3Hjh2x2WxMnTqVrVu38sADD1CsWDEKFCjAzTff7PX6rly5knvuuYeiRYsSERFB7dq1L/gHijNnzjBq1ChuueUW4uLiCA8Pp3z58vTr1y/btZ9nzJhB48aNKVy4MCEhIRQuXJgbbriBxx57jN9++80nr4WIiIhIoNOgrYiIiMglWLVqFQA33nhjpvs++ugjnnvuOXbu3Mn1119Pq1atSExMZPXq1XTr1o02bdpk+fX2DRs2ULlyZQYNGsRff/3Frbfeyr333kvx4sWZM2cOr7322kXjSk9P59FHH+Wll14iMTGRlStXUrlyZWw2Gw0bNsQwjEyDbz/88EOWxwB///03O3bsoEyZMl5r915uH8H91f5WrVrx/PPPU7hwYe655x6qVKniuX/MmDHcfffdLFmyhBtuuIEWLVpw9uxZWrZsyVtvvXXR10AubO3atVSvXp0NGzbQqFEjqlatypo1a7jvvvuYOXMms2fP5rbbbmPv3r00atSI8uXL88svv9C0aVOWLl2a6Xr79++nVq1aPPvss2zbto2bb76Z5s2bk5qaysiRI6lRowa7du3yeszQoUNp27YtP/30E5UqVaJNmzbccsstBAUF8f7777No0aK8ejlERERE/JshIiIiIh6lS5c2AGPKlCmeNqfTaezdu9d46623jLCwMCMoKMj49ttvMz121apVxsaNGzO179u3z6hataoBGF988YXXfSdPnjRKlSplAEb79u2NlJQUr/uPHz9uLFy4MMsYd+zY4TmncePGBmA0a9Ys0zUmTZpkAMZjjz3m1d6gQQMjNDTUqFChghEXF2c4HI6LPuZy+rhjxw4DMACjZMmSxtatWzM9fsOGDUZQUJBht9uNGTNmeN338ccfGzabzQCM0qVLZ3rshWT1fl7IubFmvL6X4kofn5WkpCTPNS+kXr16BmAkJSV5tXfo0MHz+JdfftlwuVye+958803P+1KwYEHjww8/9Hps7969DcBo3LixV7vL5TLq1q1rAEbnzp2N5ORkz33p6enGM888YwBGgwYNPO1nz541IiIijKioKGPLli2Z4t+5c6exefPmi74eIiIiIlagQVsRERGRc2QM8mV3u/nmm42lS5de8nW///57AzDatGnj1T527FgDMBITE70GTXMS444dO4xdu3YZlSpVMgCja9euWV5j+/btBmCUKVPG03b69GkjLCzMqFevntG3b18DMH755RfP/W3atDEA4/PPP7/iPp47kHn+oGCGLl26GIDxwAMPZHn/vffee0WDttnd6tWrl22s2d0GDRqU7fP586BtzZo1vQZsDcM9wFqoUKEs3zfDMIwjR44YgBEaGmqkpaV52r/77jtP3qanp2d6nNPp9ORlxiD/oUOHDMCoUqVKTrsuIiIiYlnBVzJLV0RERCS/qlu3Ltdee63n5yNHjvDbb7+xevVqnn76aT755BOuu+66TI9LTU1lwYIFrF69mkOHDpGamophGKSkpACwdetWr/Pnz58PQOfOnQkKCrqkGNeuXUuPHj04ePAgI0aMoH///lmeV7ZsWcqUKcOOHTvYvn075cqV4+effyY1NZU77riDm2++mZEjR/LDDz9Qq1YtDMNg0aJF2Gw2GjVqdMV9PNf999+fZXvGOqwPP/xwlvd36NDhita2Pf/9zFChQoVsH3P//fcTFRWVqT0xMfGy4zBTs2bNMq0hHBwcTJkyZTh27BjNmzfP9JjChQtTqFAhjh07xtGjRz1rOs+dOxdwv0bBwZl/pbDb7dx+++1s2rSJ5cuXU6lSJYoWLUpCQgK//fYbzzzzDJ07d+aGG27IhZ6KiIiIBD4N2oqIiIhkoUuXLnTs2NGrzeFwMHDgQIYPH069evXYunUr0dHRnvt/+eUXHnjgAXbv3p3tdZOTk71+zljz80KDh9l54IEHcDgcvPzyy9kO2GZo3Lgx7777Lj/88APlypXzrGF7xx13ULlyZcLCwvjhhx948cUXWbduHUePHqVatWoULlzY6zqX08cMxYoVIzIyMsv79u7dC0CZMmWyvD+79pzK6v28mDfeeIOEhIQrel5/Eh8fn2V7xsB0dvdHR0dz7Ngxzp4962nL2JzupZde4qWXXrrg8x4+fNhz/OGHH9K6dWtGjx7N6NGjKVSoELVq1eKOO+7gkUceoUiRIpfUJxEREZH8ShuRiYiIiORQcHAwL7/8MkWKFOHAgQN8+OGHnvtOnz5Ny5Yt2b17N506dWLVqlUcO3YMh8OBYRie2adGNpt0XY4OHToA7g28fv311wue27hxYwAWLlwIuDceK1iwIDVq1CAiIoI6deqwfPlyTp8+7RnQzXiMr/oYERFx+Z2VK2a3X7j0v9j953K5XADceuutdOjQ4YK3czftu+2229i5cyczZsygR48eJCQk8P3339OnTx/Kli2babM8EREREavSTFsRERGRS2C320lISODIkSNs3rzZ075kyRL++ecfbrrpJj744INMj9u2bVuW14uPj2fz5s1s2bIl0yDpxQwYMIAbbriBZ555hoYNGzJ37lxuvfXWLM9t1KgRNpuNpKQkDh06xPr167nvvvs8A3WNGzcmKSmJJUuWZDtoe7l9zIlrrrmG7du3s3PnTq9Bvgw7d+687GuL75UqVQqAe++9l2efffaSHhsREUHr1q1p3bo14J6JO2DAACZPnsyjjz7qmX0uIiIiYmWaaSsiIiJyCVwul2cA8dz1To8dOwZk/xXzjz/+OMv2pk2bAvDBBx/gdDovOZ4+ffowefJkTp48SZMmTTwzac9XuHBhEhMTOXbsGCNHjsQwDO644w7P/RkDtHPmzGHp0qWEhYVx2223eV3jcvuYE/Xq1QPgk08+yfL+c2c1i/maNWsGwIwZM6549njRokV5/fXXAdi9ezf//vvvFccnIiIiEug0aCsiIiKSQw6HgwEDBnDkyBEA7rnnHs99FStWBODHH3/kjz/+8Hrc5MmT+fzzz7O8ZpcuXShZsiTr1q3jscce49SpU173Jycne2a+Zuexxx7j448/Ji0tjbvvvpvZs2dneV7GwOz48eMBvAZta9SoQVxcHO+//z5nzpyhTp06mZYzuNw+5sRTTz1FUFAQX3zxBV999ZXXfdOnT8+2T2KOe++9l5tvvplVq1bRqVMnr3VrM/z7779MnDgRh8MBuNdvfu+997Jc8/jbb78FoGDBgsTExORu8CIiIiIBQMsjiIiIiGThvffeY/HixZ6fjx49yoYNG9izZw8AL774InXq1PHcX61aNe69916+/vprqlWrRv369SlUqBDr169n69atvPDCC7zyyiuZnicqKopvvvmG5s2bM2XKFL766ivq1q1LVFQUe/bsYd26ddSsWfOiSye0a9eOAgUK0LZtW9q0acPUqVN56KGHvM5p3LgxI0eO5OzZs5QpU4Zy5cp57rPb7TRo0MAzYJrV811uH3MiMTGR4cOH069fP1q1akWtWrUoV64c27ZtY/Xq1Tz99NOMGTPmsq4tvme325k9ezYtWrRg2rRpzJw5k6pVqxIfH09aWhp///03GzduxOl00rFjR4KDg/n333957LHHeOKJJ0hMTPRsLrdt2zbWrVuHzWZj5MiRBAUFmdw7EREREfNppq2IiIhIFpYtW8a0adM8twULFmC323nggQdISkri5ZdfzvSYGTNmMHLkSMqXL8/SpUtZsGAB8fHxfP/993Tp0iXb56pWrRobN25kwIABlCpVisWLF/PNN99w8OBB7rnnHp5//vkcxXzPPfcwd+5cwsLCaN++PZMnT/a6/7bbbiMsLAzIelD23LbsBokvt4850bdvX77++mtuvfVWNm3axDfffENISAgzZ86kZ8+eV3Rt8b0SJUrwyy+/MHHiRGrWrMnWrVuZOXMmS5cuBaBbt258//33hIeHA1CuXDnGjh3LXXfdxfHjx5k3bx5z587l1KlTtG/fntWrV9O5c2czuyQiIiLiN2yGL7cwFhEREREREREREZEropm2IiIiIiIiIiIiIn5Eg7YiIiIiIiIiIiIifkSDtiIiIiIiIiIiIiJ+RIO2IiIiIiIiIiIiIn5Eg7YiIiIiIiIiIiIifkSDtiIiIiIiIiIiIiJ+RIO2IiIiIiIiIiIiIn5Eg7YiIiIiIiIiIiIifkSDtiIiIiIiIiIiIiJ+RIO2IiIiIiIiIiIiIn5Eg7YiIiIiIiIiIiIifkSDtiIiIiIiIiIiIiJ+RIO2IiIiIiIiIiIiIn5Eg7YiIiIiIiIiIiIifkSDtiIiIiIiIiIiIiJ+RIO2IiIiIiIiIiIiIn5Eg7YiIiIiIiIiIiIifkSDtiIiIiIiIiIiIiJ+RIO2IiIiIiIiIiIiIn5Eg7YiIiIiIiIiIiIifkSDtiIiIiIiIiIiIiJ+RIO2IiIiIiIiIiIiIn5Eg7YiIiIiIiIiIiIifkSDtiIiIiIiIiIiIiJ+RIO2IiIiIiIiIiIiIn5Eg7YiIiIiIiIiIiIifkSDtiIiIiIiIiIiIiJ+RIO2IiIiIiIiIiIiIn5Eg7YiIiIiIiIiIiIifkSDtiIiIiIiIiIiIiJ+RIO2IiIiIiIiIiIiIn5Eg7YiIiIiIiIiIiIifkSDtiIiIiIiIiIiIiJ+RIO2IiIiIiIiIiIiIn5Eg7YiIiIiIiIiIiIifkSDtiIiIiIiIiIiIiJ+RIO2IiIiIiIiIiIiIn5Eg7YiIiIiIiIiIiIifkSDtiIiIiIiIiIiIiJ+RIO2IiIiIiIiIiIiIn5Eg7YiIiIiIiIiIiIifkSDtiIiIiIiIiIiIiJ+RIO2IiIiIiIiIiIiIn5Eg7YiIiIiIiIiIiIifkSDtiIiIiIiIiIiIiJ+RIO2IiIiIiIiIiIiIn5Eg7YiIiIiIiIiIiIifkSDtiLi12w2G4MHD/b8PHXqVGw2Gzt37vTZc+zcuRObzcbUqVN9dk1fS0hIoGPHjmaHcclcLheVKlXilVde8WpfvXo1derUoUCBAthsNtavX8/gwYOx2WwmRZr3/ve//9G2bVuzwxARERHJM4sXL8Zms7F48WKzQ8nW+b9/iIiYRYO2Ihawfft2unbtStmyZQkPDycmJoa6desybtw4zpw5Y3Z4eebTTz9l7NixZocB/Few5uQWyD777DP27NlDjx49PG3p6em0adOGY8eOMWbMGD766CNKly7t8+f+448/GDx4sE8H+H2pf//+fPnll2zYsMHsUERERMRPZExQyO72yy+/mB1ijkyYMMFvJkRc7DXNuCUkJJgdqoiIF5thGIbZQYhI7pk7dy5t2rQhLCyM9u3bU6lSJdLS0li6dClffvklHTt2ZPLkyWaHma2zZ88SHBxMcHAw4C66OnXqxI4dOy65sLrrrrvYtGlTpkE8wzBITU0lJCSEoKAgH0V+Yf/88w8LFy70anv++eeJiorixRdf9Gp/+OGHSU1NxW63ExISkifx+UpiYiK1atVi0qRJnrYtW7ZQsWJF3n33Xbp06eJpdzgcOBwOwsPDffLcM2fOpE2bNiQlJVG/fn2fXNPXatWqRfny5fnwww/NDkVERET8QEatO3ToUMqUKZPp/qZNm1KkSBETIrs0lSpVokiRIplm1LpcLtLS0ggNDcVuz5s5ZH///TfLly/3auvSpQs1a9bk8ccf97RFRUXRsmXLTL9/iIiYRZ9CIvnYjh07+N///kfp0qVZtGgRV199tee+J598kr/++ou5c+eaGOHF+WoA70JsNluePM+5rrrqKh5++GGvthEjRlCkSJFM7QBhYWF5FZrPrFu3jg0bNjBq1Civ9kOHDgEQFxfn1Z6T4jij0M+r9+vUqVMUKFAg167ftm1bBg0axIQJE4iKisq15xEREZHA0qxZM2rUqGF2GD5nt9vzvO4uW7YsZcuW9Wrr1q0bZcuWzbLuzuv4RESyo+URRPKx119/nZMnT/L+++97DdhmuPbaa+nVq5fnZ4fDwbBhwyhXrhxhYWEkJCTwwgsvkJqa6vW4hIQE7rrrLhYvXkyNGjWIiIigcuXKnr+kz5o1i8qVKxMeHk716tVZt26d1+M7duxIVFQUf//9N02aNKFAgQKUKFGCoUOHcv7k/5ysKfX111/TokULSpQoQVhYGOXKlWPYsGE4nU7POfXr12fu3Lns2rUr01egslvTdtGiRdx2220UKFCAuLg47r33XjZv3ux1TsY6rH/99RcdO3YkLi6O2NhYOnXqxOnTpy8Y96U4f03bjK95LV26lJ49e1K0aFHi4uLo2rUraWlpHD9+nPbt21OwYEEKFixIv379Mr22LpeLsWPHcuONNxIeHs5VV11F165d+ffff73OW7NmDU2aNKFIkSJERERQpkwZHn300YvGPHv2bEJDQ7n99ts9bR07dqRevXoAtGnTBpvN5pkFm9WatjabjR49evDJJ59w4403EhYWxvz58wGYPn061atXJzo6mpiYGCpXrsy4ceM8r0+bNm0AaNCggec9v9D6aRl5uX37dpo3b050dDQPPfRQlq9/hvr163vN4s1Y9uKLL77glVdeoWTJkoSHh9OoUSP++uuvTI+/4447OHXqVKZZ1yIiIiIXMmjQIOx2Oz/++KNX++OPP05oaKjX8ksrV66kadOmxMbGEhkZSb169Vi2bFmma+7bt4/OnTt7auoyZcrQvXt30tLSgKxrNci850RCQgK///47P/30k6cGy6iXslvTdsaMGVSvXp2IiAjPJIZ9+/Z5nZNRq+3bt4+WLVsSFRVF0aJFefbZZ73q/it1/u8fGf3+888/efjhh4mNjaVo0aK89NJLGIbBnj17uPfee4mJiaF48eKZJiwApKamMmjQIK699lrCwsIoVaoU/fr1y/R71sKFC7n11luJi4sjKiqK8uXL88ILL/isbyISWDTTViQf+/bbbylbtix16tTJ0fldunRh2rRptG7dmmeeeYaVK1cyfPhwNm/ezFdffeV17l9//cWDDz5I165defjhh3njjTe4++67mThxIi+88AJPPPEEAMOHD6dt27Zs3brV6ytQTqeTpk2bcsstt/D6668zf/58Bg0ahMPhYOjQoZfUz6lTpxIVFUWfPn2Iiopi0aJFDBw4kOTkZEaOHAnAiy++yIkTJ9i7dy9jxowBuODMxh9++IFmzZpRtmxZBg8ezJkzZ3jrrbeoW7cua9euzbQ0Q9u2bSlTpgzDhw9n7dq1vPfeexQrVozXXnvtkvpyqZ566imKFy/OkCFD+OWXX5g8eTJxcXEsX76c+Ph4Xn31VebNm8fIkSOpVKkS7du39zy2a9eunq/g9ezZkx07djB+/HjWrVvHsmXLCAkJ4dChQ9x5550ULVqU5557jri4OHbu3MmsWbMuGtvy5cupVKmS15IOXbt25ZprruHVV1+lZ8+e3HzzzVx11VUXvM6iRYv44osv6NGjB0WKFCEhIYGFCxfSrl07GjVq5HmNN2/ezLJly+jVqxe33347PXv25M033+SFF16gYsWKAJ7/ZsfhcNCkSRNuvfVW3njjDSIjIy/az6yMGDECu93Os88+y4kTJ3j99dd56KGHWLlypdd5N9xwAxERESxbtoz77rvvsp5LRERE8p8TJ05w5MgRrzabzUbhwoUBGDBgAN9++y2dO3dm48aNREdH8/333/Puu+8ybNgwqlatCrjrqGbNmlG9enXPQO+UKVNo2LAhP//8MzVr1gRg//791KxZk+PHj/P4449ToUIF9u3bx8yZMzl9+jShoaE5jn3s2LE89dRTXst+Xajey6hHb775ZoYPH84///zDuHHjWLZsGevWrfP6dpbT6aRJkybUqlWLN954gx9++IFRo0ZRrlw5unfvnuMYL8cDDzxAxYoVGTFiBHPnzuXll1+mUKFCTJo0iYYNG/Laa6/xySef8Oyzz3LzzTd7Ji64XC7uueceli5dyuOPP07FihXZuHEjY8aM4c8//2T27NkA/P7779x1111UqVKFoUOHEhYWxl9//ZXlALuIWIQhIvnSiRMnDMC49957c3T++vXrDcDo0qWLV/uzzz5rAMaiRYs8baVLlzYAY/ny5Z6277//3gCMiIgIY9euXZ72SZMmGYCRlJTkaevQoYMBGE899ZSnzeVyGS1atDBCQ0ONw4cPe9oBY9CgQZ6fp0yZYgDGjh07PG2nT5/O1J+uXbsakZGRxtmzZz1tLVq0MEqXLp3p3B07dhiAMWXKFE9bYmKiUaxYMePo0aOetg0bNhh2u91o3769p23QoEEGYDz66KNe17zvvvuMwoULZ3quC7nxxhuNevXqZXlf6dKljQ4dOnh+zngdmjRpYrhcLk977dq1DZvNZnTr1s3T5nA4jJIlS3pd++effzYA45NPPvF6nvnz53u1f/XVVwZgrF69+pL6YhiGUbJkSeP+++/P1J6UlGQAxowZM7zaM17LcwGG3W43fv/9d6/2Xr16GTExMYbD4cj2+WfMmJEp9y4kIy+fe+65TPed//pnqFevntfrmtG3ihUrGqmpqZ72cePGGYCxcePGTNe4/vrrjWbNmuUoRhEREcnfMmq8rG5hYWFe527cuNEIDQ01unTpYvz777/GNddcY9SoUcNIT083DMNdX1933XWZ6sXTp08bZcqUMe644w5PW/v27Q273Z5lzZfx2KxqtXNjPrc+z66uzaiVMuqztLQ0o1ixYkalSpWMM2fOeM6bM2eOARgDBw70tGXUakOHDvW6ZrVq1Yzq1atneq4LKVCgQJa1nWFk/v0jo9+PP/64py2jvrbZbMaIESM87f/++68RERHhde2PPvrIsNvtxs8//+z1PBMnTjQAY9myZYZhGMaYMWMMwOt3IRGxNi2PIJJPJScnAxAdHZ2j8+fNmwdAnz59vNqfeeYZgExr395www3Url3b83OtWrUAaNiwIfHx8Zna//7770zP2aNHD89xxtfg09LS+OGHH3IUc4aIiAjPcUpKCkeOHOG2227j9OnTbNmy5ZKuBXDgwAHWr19Px44dKVSokKe9SpUq3HHHHZ7X6lzdunXz+vm2227j6NGjnvcht3Tu3Nnra2q1atXCMAw6d+7saQsKCqJGjRpe78GMGTOIjY3ljjvu4MiRI55b9erViYqKIikpCfhv3dk5c+aQnp5+SbEdPXqUggULXkHv3OrVq8cNN9zg1RYXF5drywr4YpZGp06dvGak3HbbbUDW/w4KFiyYaSaNiIiIWNvbb7/NwoULvW7fffed1zmVKlViyJAhvPfeezRp0oQjR44wbdo0zx4B69evZ9u2bTz44IMcPXrUU++dOnWKRo0asWTJElwuFy6Xi9mzZ3P33XdnuY5uVksi+MqaNWs4dOgQTzzxhNdasi1atKBChQpZ7r+RVd2dVY3la+duoJtRX59fd8fFxVG+fPlMdXfFihWpUKGCV93dsGFDgEx199dff43L5cr1/oiI/9PyCCL5VExMDOAexMyJXbt2Ybfbufbaa73aixcvTlxcHLt27fJqP3dgFiA2NhaAUqVKZdl+/jqpdrs904YA119/PYBnPayc+v333xkwYACLFi3KNEh64sSJS7oW4Olr+fLlM91XsWJFvv/++0wbVJ3/emQMVv7777+e9yI3XMr7cO57sG3bNk6cOEGxYsWyvG7GZmH16tXj/vvvZ8iQIYwZM4b69evTsmVLHnzwwRxtjmact47u5chq5+QnnniCL774gmbNmnHNNddw55130rZtW5o2bXpFzxUcHEzJkiWv6Bpw4Xw4n2EYufrLkIiIiASemjVr5mgjsr59+zJ9+nRWrVrFq6++6vWH7m3btgHQoUOHbB9/4sQJ0tLSSE5OplKlSlce+CW6UN1doUIFli5d6tUWHh5O0aJFvdoKFiyYZY3la1nV3eHh4RQpUiRT+9GjRz0/b9u2jc2bN2eKO0NG3f3AAw/w3nvv0aVLF5577jkaNWpEq1ataN26tdcycyJiHRq0FcmnYmJiKFGiBJs2bbqkx+V08CgoKOiS2n0xeJeV48ePU69ePWJiYhg6dCjlypUjPDyctWvX0r9//zz7K3Ve9/tiz5tV+7mxuFwuihUrxieffJLl4zOKSpvNxsyZM/nll1/49ttv+f7773n00UcZNWoUv/zyywXXBS5cuLBPCuhzZ1JnKFasGOvXr+f777/nu+++47vvvmPKlCm0b9+eadOmXfZzhYWFZVkUZ/fvwul0ZvlaX0o+/Pvvv1x33XWXGKmIiIiI+1s8GYOzGzdu9Lovow4eOXIkiYmJWT4+KiqKY8eO5ei5LlQP5ZXsaiyznjsnNZ/L5aJy5cqMHj06y3MzJltERESwZMkSkpKSmDt3LvPnz+fzzz+nYcOGLFiwwNS+i4g5NGgrko/dddddTJ48mRUrVngtZZCV0qVL43K52LZtm9dmTf/88w/Hjx+ndOnSPo3N5XLx999/e2bXAvz5558AmTb5upDFixdz9OhRZs2a5VnsH2DHjh2Zzs3pgHRGX7du3Zrpvi1btlCkSBGvWbaBqFy5cvzwww/UrVs3y0HR891yyy3ccsstvPLKK3z66ac89NBDTJ8+3etrYuerUKFClu+Dr4SGhnL33Xdz991343K5eOKJJ5g0aRIvvfQS1157rU9nrxYsWJDjx49nat+1a1emGeOXwuFwsGfPHu65554riE5ERESsyOVy0bFjR2JiYujduzevvvoqrVu3plWrVoC73gP3ZI7GjRtne52iRYsSExNz0ckeGd8cOn78uNfmYOd/Iw8ur+7OWC4gw9atW33+O4gZypUrx4YNG2jUqNFFXxe73U6jRo1o1KgRo0eP5tVXX+XFF18kKSnpgu+hiORPmmMvko/169ePAgUK0KVLF/75559M92/fvp1x48YB0Lx5c8C92+u5Mv4i3KJFC5/HN378eM+xYRiMHz+ekJAQGjVqlONrZPzF+dy/ZqelpTFhwoRM5xYoUCBHyyVcffXVJCYmMm3aNK+Buk2bNrFgwQLPaxXI2rZti9PpZNiwYZnuczgcnn7/+++/mWaHZszUSE1NveBz1K5dm02bNl30vMtx7lfOwF3gVqlSxSuujIH1rAZbL1W5cuX45ZdfSEtL87TNmTOHPXv2XNF1//jjD86ePUudOnWuNEQRERGxmNGjR7N8+XImT57MsGHDqFOnDt27d/eslV+9enXKlSvHG2+8wcmTJzM9/vDhw4C7jmrZsiXffvsta9asyXReRi2YMQi8ZMkSz32nTp3K8ltOBQoUyFENVqNGDYoVK8bEiRO9asbvvvuOzZs358rvIHmtbdu27Nu3j3fffTfTfWfOnOHUqVMAWc54zmndLSL5k2baiuRj5cqV49NPP+WBBx6gYsWKtG/fnkqVKpGWlsby5cuZMWMGHTt2BKBq1ap06NCByZMne5YcWLVqFdOmTaNly5Y0aNDAp7GFh4czf/58OnToQK1atfjuu++YO3cuL7zwQrbrPWWlTp06FCxYkA4dOtCzZ09sNhsfffRRll9Dr169Op9//jl9+vTh5ptvJioqirvvvjvL644cOZJmzZpRu3ZtOnfuzJkzZ3jrrbeIjY1l8ODBl9ttv1GvXj26du3K8OHDWb9+PXfeeSchISFs27aNGTNmMG7cOFq3bs20adOYMGEC9913H+XKlSMlJYV3332XmJiYiw5e33vvvQwbNoyffvqJO++806fxd+nShWPHjtGwYUNKlizJrl27eOutt0hMTPTMFE9MTCQoKIjXXnuNEydOEBYWRsOGDbNdx/dizzdz5kyaNm1K27Zt2b59Ox9//LHnl5fLtXDhQiIjI7njjjuu6DoiIiKSv3z33XdZbqhbp04dypYty+bNm3nppZfo2LGjp56dOnUqiYmJnrX/7XY77733Hs2aNePGG2+kU6dOXHPNNezbt4+kpCRiYmL49ttvAXj11VdZsGAB9erV4/HHH6dixYocOHCAGTNmsHTpUuLi4rjzzjuJj4+nc+fO9O3bl6CgID744AOKFi3K7t27veKsXr0677zzDi+//DLXXnstxYoVyzSTFiAkJITXXnuNTp06Ua9ePdq1a8c///zDuHHjSEhI4Omnn86FVzdvPfLII3zxxRd069aNpKQk6tati9PpZMuWLXzxxRd8//331KhRg6FDh7JkyRJatGhB6dKlOXToEBMmTKBkyZLceuutZndDREygQVuRfO6ee+7ht99+Y+TIkXz99de88847hIWFUaVKFUaNGsVjjz3mOfe9996jbNmyTJ06la+++orixYvz/PPPM2jQIJ/HFRQUxPz58+nevTt9+/YlOjqaQYMGMXDgwEu6TuHChZkzZw7PPPMMAwYMoGDBgjz88MM0atSIJk2aeJ37xBNPsH79eqZMmcKYMWMoXbp0toO2jRs3Zv78+Z6YQkJCqFevHq+99lqWG2MFookTJ1K9enUmTZrECy+8QHBwMAkJCTz88MPUrVsXwDN4P336dP755x9iY2OpWbMmn3zyyUVfh+rVq1OlShW++OILnw/aPvzww0yePJkJEyZw/PhxihcvzgMPPMDgwYM9a9IWL16ciRMnMnz4cDp37ozT6SQpKemyBm2bNGnCqFGjGD16NL1796ZGjRqevLsSM2bMoFWrVkRHR1/RdURERCR/ya4mnjJlCqVLl6ZDhw4UKVLE61ty1113HcOHD6dXr1588cUXtG3blvr167NixQqGDRvG+PHjOXnyJMWLF6dWrVp07drV89hrrrmGlStX8tJLL/HJJ5+QnJzMNddcQ7NmzYiMjATcA6xfffUVTzzxBC+99BLFixend+/eFCxYkE6dOmWKf9euXbz++uukpKRQr169LAdtATp27EhkZCQjRoygf//+FChQgPvuu4/XXnvNaxmGQGW325k9ezZjxozhww8/5KuvviIyMpKyZcvSq1cvz3Jx99xzDzt37uSDDz7gyJEjFClShHr16jFkyBDPZsMiYi02I7d3yREROU/Hjh2ZOXNmll/Tkvzlo48+4sknn2T37t35ouj2pfXr13PTTTexdu3abDcHEREREREREWvSmrYiIpJrHnroIeLj43n77bfNDsXvjBgxgtatW2vAVkRERERERDLR8ggiIpJr7Hb7RXcitqrp06ebHYKIiIiIiIj4Kc20FREREREREREREfEjWtNWRERERERERERExI9opq2IiIiIiIiIiIiIH9GgrYiIiIiIiIiIiIgfsdxGZC6Xi/379xMdHY3NZjM7HBERERFLMAyDlJQUSpQogd2ueQN5QXWviIiISN7yZc1ruUHb/fv3U6pUKbPDEBEREbGkPXv2ULJkSbPDsATVvSIiIiLm8EXNa7lB2+joaMD94sXExJgcjYiIiIg1JCcnU6pUKU8tJrlPda+IiIhI3vJlzWu5QduMr4bFxMT4ffHqdDrZtm0b1113HUFBQWaHI5KrlO9iNcp5sRqn0wmgr+nnoUCpe/V5KFajnBcrUb6L1fiy5tWCYn7M5XKxdetWXC6X2aGI5Drlu1iNcl6sRrku2dHnoViNcl6sRPkuVuPLXNegrYiIiIiIiIiIiIgf0aCtiIiIiIiIiIiIiB+x3Jq2OeV0OklPTzc9hlKlSpGWluZZE0OsJSQkxDLr/tjtduLj47Hb9bcksQblvFiNct0/qeaVQBMaGhrwnyeqAcRKlO9iNb7MdZthGIbPrhYAkpOTiY2N5cSJE1luyGAYBgcPHuT48eN5H5xIFuLi4ihevLg2bhERkYB2sRpMfO9Cr7lqXglUdrudMmXKEBoaanYoIiIimfiy5tVM2/NkFK/FihUjMjLS1IEywzA4e/Ys4eHhGrCzIMMwOH36NIcOHQLg6quvNjmi3OV0Ovntt9+oUqWKZWYXi7Up58VqNIPSv6jmlUDkcrnYv38/Bw4cID4+PmDzRTWAWInyXazGlzWvBm3P4XQ6PcVr4cKFzQ4Hl8tFamoqYWFh+iqBRUVERABw6NAhihUrlq//J+dyudi9ezeVKlXK1/0UyaCcF6vRrtH+QzWvBLKiRYuyf/9+HA4HISEhZodzWVQDiJUo38VqfFnzqio6R8Z6XpGRkSZHIvKfjHw0e705ERERyR9U80ogy1gWQbP3RUQkv9OgbRYC9Ws2kj8pH0VERCQ3qMaQQKS8FRERq9CgrR+z2WyEhYWpMBFLsNvtlC9fXl+LFMtQzovVKNclO6p5xWpUA4iVKN/FanyZ6/pX48dsNhsREREqYC9i8eLF2Gw27X4c4IKCgqhQoYLWORLLUM6L1SjXJTuqeX1LtbH/Uw0gVqJ8F6vxZa5r0DaXOF1OFu9czGcbP2PxzsU4XZe+5pJhGJw8eRLDMHIhwtyxc+dObDYb69evz/Tz4MGDsdlsF7ydr2PHjhc8PyEhgTp16nDgwAFiY2PzuLfiSw6Hg+XLl+NwOMwORSRPKOfFapTr+ZPVa96MW6FChahXrx4///xzrj6vauP8QTWAWInyXazGl7muQdtcMGvzLBLGJdBgWgMenPUgDaY1IGFcArM2z7qk6xiGgcPhMK2ATUtL8+n1nn32WQ4cOOC5lSxZkqFDh3q1nW/cuHGZ7p8yZYrn59WrVxMaGkrx4sU1OyPAGYbB4cOHA+oXNpEroZwXq1Gu5z+qeeGHH37gwIEDLFmyhBIlSnDXXXfxzz//+DA6b6qN8wfVAGIlynexGl/mugZtfWzW5lm0/qI1e5P3erXvS95H6y9aX3IRm1Opqan07NmTYsWKER4ezq233srq1as990+dOpW4uDivx8yePdurmBs8eDCJiYm89957lClThvDwcABmzpxJ5cqViYiIoHDhwjRu3JhTp05dcoxRUVEUL17ccwsKCiI6Otqr7XyxsbGZ7o+Li/P8XLRo0UxfAcvo65w5cyhfvjyRkZG0bt2a06dPM23aNBISEihYsCA9e/b02nU2NTWVZ599lmuuuYYCBQpQq1YtFi9efMn9FBEREcnvVPO6FS5cmOLFi1OpUiVeeOEFkpOTWbly5SXH8tFHH5GQkEBsbCz/+9//SElJyfL5VBuLiIhYR7DZAfg7wzA4nX46R+c6XU56ftcTg8yj6gYGNmz0+q4Xjcs0Jsh+8TUuwoPCcxxnv379+PLLL5k2bRqlS5fm9ddfp0mTJvz1118UKlQox9f566+/+PLLL5k1axZBQUEcOHCAdu3a8frrr3PfffeRkpLCzz//7Pd/JTt9+jRvvvkm06dPJyUlhVatWnHfffcRFxfHvHnz+Pvvv7n//vupW7cuDzzwAAA9evTgjz/+YPr06ZQoUYKvvvqKpk2bsnHjRq677jqTeyQiIiKSe1TzXlnNe+bMGT788EMAQkNDcxwHwPbt25k9ezZz5szh33//pW3btowYMYJXXnnlkq5zIaqNRUREAo8GbS/idPppooZH+eRaBgZ7U/YS+1rO1pdKeS4lR5synDp1infeeYepU6fSrFkzAN59910WLlzI+++/T9++fXMcY1paGh9++CFFixYFYO3atTgcDlq1akXp0qUBqFy5co6vZ5b09HTeeecdypUrB0Dr1q356KOP+Oeff4iKiuKGG26gQYMGJCUl8cADD7B7926mTJnC7t27KVGiBOBezmH+/PlMmTKFV1991czuWEJQUBCJiYlaoF4sQzkvVqNc92+qeS+v5q1Tpw52u53Tp09jGAbVq1enUaNGOY4DwOVyMXXqVKKjowF45JFH+PHHH306aKva2L+oBhArUb6L1fgy1zVo68dsNhthoWEXPW/79u2kp6dTt25dT1tISAg1a9Zk8+bNl/ScpUuX9hSvAFWrVqVRo0ZUrlyZJk2acOedd9K6dWsKFix4SdfNa5GRkZ6iFOCqq64iISGBqKgor7ZDhw4BsHHjRpxOJ9dff73XdVJTUylcuHDeBG1xdrvd80uSiBUo58Vq7HatyiVZC+Sa9/PPP6dChQps2rSJfv36MXXqVEJCQi4ploSEBM+ALcDVV1/tqVF9RbWxf1ENIFaifBer8WXNq0Hbi4gMieTk8ydzdO6SXUto/mnzi54378F53F769oueFxEcQXJyMtHR0Ve8kYDdbs/09a709PRM5xUoUMDr56CgIBYuXMjy5ctZsGABb731Fi+++CIrV66kTJkyVxRTbjq/WLbZbFm2uVwuAE6ePElQUBC//vprpr+KnFvMSu5xOBwsWbKE22+/neBgfTRJ/qecF0txOnFqLUy/ppr38mreUqVKcd1113HdddfhcDi477772LRpE2FhYTmO5UI1qq+oNvYvqgHESpTvYik+rnk15eEibDYbBUIL5Oh2Z7k7KRlTEhtZF5s2bJSKKcWd5e7M0fXA/XWpi62lVa5cOUJDQ1m2bJmnLT09ndWrV3PDDTcAULRoUVJSUrw2U1i/fn2OX4O6desyZMgQ1q1bR2hoKF999VWOHhsoqlWrhtPp5NChQ1x77bVet6w2SBPfMwyDlJQUv18vWcRXlPNiGbNmQUICQffea3YkcgGqea+85m3dujXBwcFMmDDhimMxm2rj3KUaQKxE+S6WkQs1rwZtfSjIHsS4puMAMhWxGT+PbTo2RxsyXIoCBQrQvXt3+vbty/z58/njjz947LHHOH36NJ07dwagVq1aREZG8sILL7B9+3Y+/fRTpk6detFrr1y5kldffZU1a9awe/duZs2axeHDh6lYsaJP+2C266+/noceeoj27dsza9YsduzYwapVqxg+fDhz5841OzwREZHANGsWtG4Ne/eaHYn4kGrerNlsNnr27MmIESM4ffr0ZcfiD1Qbi4iIXIJcqnk1aOtjrSq2YmbbmVwTc41Xe8mYksxsO5NWFVvlyvOOGDGC+++/n0ceeYSbbrqJv/76i++//96zDlehQoX4+OOPmTdvHpUrV+azzz5j8ODBF71uTEwMS5YsoXnz5lx//fUMGDCAUaNGeTZ/OF/GV6oC8WsPU6ZMoX379jzzzDOUL1+eli1bsnr1auLj480OTUREJPA4ndCrF2hmTb5k9Zo3Ox06dCA9PZ3x48dfdiz+QrWxiIhIDuRizWszLDZHPTk5mdjYWE6cOEFMTIzXfWfPnmXHjh2UKVOG8PDwK3oep8vJz7t/5kDKAa6Ovprb4m+75NkGhmHgcDgIDg6+4vW98sovv/xC7dq1OXz4MEWKFDE7nHzBl3npz1wuF0eOHKFIkSLarEYsQTkv+d7ixdCggefHZCAWsqzBJHdkV/eq5pVAlh9qY9UAYiXKd8n3crHmDbzpkAEiyB5E/YT6V3SNrDYI8FcOh4OdO3cycuRIqlatqgFbuWR2u51ixYqZHYZInlHOS7534IDZEUgesFrNK+ILqgHESpTvku/lYs2rP3P4MZfLxfHjx32+e2xu2LRpE1WqVOHAgQN8+OGHZocjASg9PZ25c+dmuauySH6knJd87+qrzY5AAkQg1bwivqAaQKxE+S75Xi7WvJppKz6RmJjI6dOnzQ5DApzD4TA7BJE8pZyXfO2226BkSW1CJiKSBdUAYiXKd8nXcrHm1UxbEREREfG9oCBo1MjsKEREREREck9QEIwblyuX1qCtiIiIiPje7t0wc6b7uGBBc2MREREREcktV12VK5fVoK0fs9lsREdHaxddsYTg4GAaNGhAcLBWbRFrUM5LvvfUU3DqFNStC//8g/Htt2ZHJH5KNa9YjWoAsRLlu1jCoEHu/3bu7NOaV/9q/JyKV7GSiIgIs0MQyVPKecm3Zs+Gb76B4GCYNAlCQtzrfYlkQzWvWI1qALES5bvka0uWwI8/uuvdl16CuDifXVozbf2YYRgkJydjGIbZoYjkOofDwbx587RIvViGcl7yrZQU9yxbgL594cYbAW1CItlTzStWoxpArET5LvneObNsKV3ap7muQVsRERER8Z2BA92755YpAwMGmB2NiIiIiEjuSEqCxYshNBReeMHnl9egrQBw+vRp7r//fmJiYrDZbBw/fjzLNl+pX78+vXv3vuA5CQkJjB071mfP6QsdO3akZcuWZochIiLin9auhTffdB9PmACRkebGI3Ie1by+pdpYREQsyzD+m2X72GNQqpTPn0KDtrnE6XQPtn/2mfu/Tqd5sbz77rvcdtttFCxYkIIFC9K4cWNWrVrldc60adP4+eefWb58OQcOHCA2NjZT27///ovNZmP9+vXmdOQ8gwcPJjExMcufExISsNls2d46duyY6XoXOt9mszF48GDGjRvH1KlT86R/IiIiAcXphK5dweWCBx6Apk3NjkjygGre3Dd48GBPPRoUFESpUqV4/PHHOXbsWK4+r2pjERGRC/jxR/j5ZwgLg+efz5Wn0EZkuWDWLOjVy/3NwAwlS8K4cdCqVc6vY7PZPH/xvxKLFy+mXbt21KlTh/DwcF577TXuvPNOfv/9d6655hoAtm/fTsWKFalUqZLncee37dy584riyEurV6/G+f+/NSxfvpz777+frVu3EhMTA2S9EPqBAwc8x59//jkDBw5k69atnraoqCiioqJyOXLrCg4Opnnz5tpVVCxDOS/5zoQJsGYNxMTAmDGZ7lau5z+qefPOjTfeyA8//IDT6WTz5s08+uijnDhxgs8//zzXnlO1ce5RDSBWonyXfOncWbbdusH/1xng25pXM219bNYsaN3au3gF2LfP3T5r1qVdL6cbMsycOZPKlSsTERFB4cKFady4MadOnQLgk08+4YknniAxMZEKFSrw3nvv4XK5+PHHHwH317ZGjRrFkiVLsNls1K9fP8u2MmXKAFCtWjVPW3Z++uknatasSVhYGFdffTXPPffcBRdjPnToEHfffTcRERGUKVOGTz75JIevUNaKFi1K8eLFKV68OIUKFQKgWLFinrbY2NhMj8m4L+N+m83m1RYVFZXpK2D169fnqaeeonfv3hQsWJCrrrqKd999l1OnTtGpUyeio6O59tpr+e6777yea9OmTTRr1oyoqCiuuuoqHnnkEY4cOXJFfc4Pzpw5Y3YIInlKOS/5xr598OKL7uMRI+Dqq82NR3Kdal63vKp5g4ODKV68ONdccw2NGzemTZs2LFy40HN/VsswtGzZ0uvbZQkJCbz66qs8+uijREdHEx8fz+TJk7N9TtXGuUs1gFiJ8l3ynQULYPlyCA+H/v1z7Wk0aHsRhgGnTuXslpwMPXu6H5PVdcA9GyE5OWfXc7kMUlJSLlrEHjhwgHbt2vHoo4+yefNmFi9eTKtWrbJ93OnTp0lPT/cMZs6aNYvHHnuM2rVrc+DAAWbNmpVlW8bXy3744QdPW1b27dtH8+bNufnmm9mwYQPvvPMO77//Pi+//HK2fejYsSN79uwhKSmJmTNnMmHCBA4dOnTBfvuLadOmUaRIEVatWsVTTz1F9+7dadOmDXXq1GHt2rXceeedPPLII5w+fRqA48eP07BhQ6pVq8aaNWuYP38+//zzD23btjW5J+ZyOBwkJSVpV1GxDOW85Cu9e0NKCtSq5V4iIQvKdf+mmjdwat6dO3fy/fffExoaekmPAxg1ahQ1atRg3bp1PPHEE3Tv3t1r9qwvqDa+ONUAYiXKd8l3DMO98S7AE09kmqzgy1zX/PSLOH0afPWtH8Nwz0bIYpJnlpKTc3begQMHcDgctGrVitKlSwNQuXLlbM/v378/JUqUoHHjxgAUKlSIyMhIQkNDKV68uOe889uS/z+gwoULe513vgkTJlCqVCnGjx+PzWajQoUK7N+/n/79+zNw4EDsdu+/Ffz555989913rFq1iptvvhmA999/n4oVK+bsBTBZ1apVGfD/u2M///zzjBgxgiJFivDYY48BMHDgQN555x1+++03brnlFsaPH0+1atV49dVXPdf44IMPKFWqFH/++SfXX3+9Kf0QERG5LHPnwsyZEBQEkyaBXXMCApFqXv+ueTdu3EhUVBROp5OzZ88CMHr06Is+7nzNmzfniSeeANyvz5gxY0hKSqJ8+fKXfK3sqDYWEZF87bvvYNUqiIiAfv1y9ak0aJsPVK1alUaNGlG5cmWaNGnCnXfeSevWrSlYsGCmc0eMGMH06dNZvHgx4eHhuRLP5s2bqV27tte6ZHXr1uXkyZPs3buX+Pj4TOcHBwdTvXp1T1uFChWIi4vLlfh8rUqVKp7joKAgChcu7PULxFVXXQXgmUWxYcMGkpKSslwDbPv27SpMRUQkcJw6BU8+6T5++mmoWtXceCRfs3LNW758eb755hvOnj3Lxx9/zPr163nqqacuOeZz69aM5Q58/e021cYiIpJvnbuWbY8e8P//T8stmgpxEZGRcPJkzm7z5uXsmvPm5ex6kZE5u15QUBALFy7ku+++44YbbuCtt96ifPny7Nixw+u8N954gxEjRrBgwQKvYkquTEhIiNfPNpvNqy2jkHe5XACcPHmSu+++m/Xr13vdtm3bxu233553gfshLU4vVqOcl4A3dCjs2gXx8TB4sNnRyBVQzevfQkNDufbaa6lUqRIjRowgKCiIIUOGeO632+2ZlolIT0/PdJ2s6taMGtVXVBvnjGoAsRLlu+Qbc+a4N94tUAD69s31p9Og7UXYbO73Iie3O+9075ib3ca3NhuUKuU+LyfXCwqyExcXl+mrVVlf20bdunUZMmQI69atIzQ0lK+++spz/+uvv86wYcOYP38+NWrUuKzXImPdLKfTecHzKlasyIoVK7wKx2XLlhEdHU3JkiUznV+hQgUcDge//vqrp23r1q0cP378suL0dzfddBO///47CQkJXHvttV63AgUKmB2eaUJCQmjRokWmQl8kv1LOS8DbuBEyvp49fry7eLkA5bp/U837n0CoeQcMGMAbb7zB/v37AfcmvAcOHPDc73Q62bRp0yVf1wxWrI1VA4iVKN8l3zh3LdunnoKiRbM8zZe5rkFbHwoKgnHj3MfnF7EZP48d6z4vJwzDID09/aKbMqxcuZJXX32VNWvWsHv3bmbNmsXhw4c962O99tprvPTSS3zwwQckJCRw8OBBDh48yMmTJy+hd1CsWDEiIiI8mwOcOHEiy/OeeOIJ9uzZw1NPPcWWLVv4+uuvGTRoEH369MmyGC9fvjxNmzala9eurFy5kl9//ZUuXboQERFxSfEFiieffJJjx47Rrl07Vq9ezfbt2/n+++/p1KnTRX85yM9cLheHDh3y+WwPEX+lnJeA5nLB44+DwwGtWsHdd+fgIcr1/EI1r5uZNW/t2rWpUqWKZx3Yhg0bMnfuXObOncuWLVvo3r17wEyAsGJtrBpArET5LvnG7Nmwfr17E4Bnn832NF/mugZtfaxVK/deHNdc491esqS7vVWrnF/LMAxOnTp10QI2JiaGJUuW0Lx5c66//noGDBjAqFGjaNasGQDvvPMOaWlptG7dmquvvtpze+ONNy6pb8HBwbz55ptMmjSJEiVKcO+992Z53jXXXMO8efNYtWoVVatWpVu3bnTu3NmzIUFWpkyZQokSJahXrx6tWrXi8ccfp1ixYheMx+VyBeTXLEqUKMGyZctwOp3ceeedVK5cmd69e+d4hkl+5XQ6WbFiRb4tzkXOp5yXgPbuu/DLL+6iNWP07iKU6/mLat68q3mz8/TTT/Pee++xZ88eHn30UTp06ED79u2pV68eZcuWpUGDBpd13bxmxdpYNYBYifJd8gWX67+lwHr3hsKFsz3Vl7luMy5WHeUzycnJxMbGcuLECWJiYrzuO3v2LDt27KBMmTJXvGGB0wk//wwHDsDVV8Ntt+V8tkEGl8tFcnIyMTEx+bZguRLdunVj7969zJkzx+xQcpUv89KfpaenM2/ePJo3b66vzoglKOclYB08CBUrwvHj7umUvXrl6GFHjx6lSJEiWdZgkjuyq3tV80ogyw+1sWoAsRLlu+QLM2dCmzYQEwM7d0IWm6Bm8GXNG3jTFANEUBDUr292FPlTSkoK69atY9asWbzwwgtmhyMiImItffq4B2xvusm9a65YmmpeERERydfOnWX79NMXHLD1Nf0p24/ZbDbsdrtnh1VxGzhwIK1bt+a+++6jW7duZocjPmKz2YiOjla+i2Uo5yUgLVgAn30GdjtMnnxJUyqV65Id1bxiNaoBxEqU7xLwZsyA33+H2Fj30ggX4ctc10xbP2az2fT1wSyMGTOGMWPGmB2G+FhwcDANGzY0OwyRPKOcl4Bz5gw88YT7uEcPqF79kh4eiOvQS95QzStWoxpArET5LgHN6fxvlu0zz0Bc3EUf4suaVzNt/ZhhGKSmpl50UwaR/MDlcrFr1y7tKiqWoZyXgPPKK7B9O5QoAcOGXfLDleuSHdW8YjWqAcRKlO8S0D7/HLZscS+JkMN9HHyZ6xq09WOGYXDmzBkVsGIJTqeT9evXa1dRsQzlvASUzZvh9dfdx2+95d6E4RIp1yU7qnnFalQDiJUo3yVgORwwZIj7+Nlnc1z/+jLXNWgrIiIiItlzuaBrV0hPh7vugvvuMzsiEREREZHc9emn8OefULgwPPWUKSFo0FZEREREsjd1Kvz8M0RGwvjxoI1ERERERCQ/czhg6FD3cb9+EB1tShgatPVjNpuN4OBg7bIolmCz2ShatKjyXSxDOS8B4fBh6NvXfTxkCJQufdmXUq5LdlTzitWoBhArUb5LQProI/deDkWLwpNPXtJDfZnr2sbXj9lsNqKioswOQyRPBAcHU6dOHbPDEMkzynkJCH37wrFjUKVKjjdfyI4vd9KV/EU1r1iNagCxEuW7BJz09P823e3fHwoUuKSH+7Lm1UxbP+avmzIsXrwYm83G8ePHzQ7F5xISEhg7dqzZYViS0+lky5YtWqBeLEM5L34vKQmmTXMvhzBpEoSEXNHlAjnX3377bRISEggPD6dWrVqsWrXqgufPmDGDChUqEB4eTuXKlZk3b57nvvT0dPr370/lypUpUKAAJUqUoH379uzfv9/rGgkJCdhsNq/biBEjcqV/ZlPN679UG+cO1QBiJcp3CTjTpsGOHXDVVdC9+yU/XBuRBQKnExYvhs8+c//3Mt40wzBITU3NUQHbsWNHWrZsmak9kIvN+vXr07t370w/79y5M9MvMeffpk6d6nWtjNfhQrfFixezevVqHn/88bztqADgcrnYunUrLpfL7FBE8oRyXvxaaip06+Y+7tYNbrnlii8ZqLn++eef06dPHwYNGsTatWupWrUqTZo04dChQ1mev3z5ctq1a0fnzp1Zt24dLVu2pGXLlmzatAmA06dPs3btWl566SXWrl3LrFmz2Lp1K/fcc0+maw0dOpQDBw54bk+ZtAnGBanmvWL169f31KPh4eFcf/31DB8+PFcHsVUbm0s1gFiJ8l0CSlraf7Nsn3vOvafDJfJlrut7arlh1iz3Vwj37v2vrWRJGDcOWrUyL658olSpUhw4cMDz8xtvvMH8+fP54YcfPG2xsbFej6lTp47XY3r16kVycjJTpkzxtBUqVIjQ0NBcjFxERCRAvPaae7fc4sXh1VfNjsZUo0eP5rHHHqNTp04ATJw4kblz5/LBBx/w3HPPZTp/3LhxNG3alL7/vxbwsGHDWLhwIePHj2fixInExsaycOFCr8eMHz+emjVrsnv3buLj4z3t0dHRFC9ePBd7d4VU8/rMY489xtChQ0lNTWXRokU8/vjjxMXF0f0yZvjkhGpjERGRLHzwAezeDVdfDV27mh2NZtr63KxZ0Lq1d/EKsG+fu33WLHPiOsfSpUu57bbbiIiIoFSpUvTs2ZNTp0557v/oo4+oUaOG5xeFBx98MNvZJABTp04lLi6O2bNnc9111xEeHk6TJk3Ys2cPADt37sRut7NmzRqvx40dO5bSpUtf8l8hgoKCKF68uOcWFRVFcHCwV1tERITXY0JDQzPdHxYW5tUWGhqa6StgNpuNSZMmcddddxEZGUnFihVZsWIFf/31F/Xr16dAgQLUqVOH7du3ez3f119/zU033UR4eDhly5ZlyJAhOByOS+qniIiIKf7887+B2jFjIC7O1HDMlJaWxq+//krjxo09bXa7ncaNG7NixYosH7NixQqv8wGaNGmS7fkAJ06cwGazEXfeaz1ixAgKFy5MtWrVGDlypH/VEqp5fVrzRkZGUrx4cUqXLk2nTp2oUqWK1+C+zWZj9uzZXo+Ji4vzfLss45tos2bNokGDBkRGRlK1atVs8061sYiIyHlSU+GVV9zHzz8P540rmUGDthdjGHDqVM5uycnQs6f7MVldB9yzEZKTc3Q9G+6Cypc7z23fvp2mTZty//3389tvv/H555+zdOlSevTo4TknPT2dYcOGsWHDBmbPns3OnTvp2LHjBa97+vRpXnnlFT788EOWLVvG8ePH+d///ge418Jq3Lix11/uAaZMmULHjh2x2/07DYcNG0b79u1Zv349FSpU4MEHH6Rr1648//zzrFmzBsMwvF6/n3/+mfbt29OrVy/++OMPJk2axNSpU3kl4x+/ZMlutxMfH+/3+SDiK8p58UuG4V67KzUVmjSBBx7w2aUDMdePHDmC0+nkqquu8mq/6qqrOHjwYJaPOXjw4CWdf/bsWfr370+7du2IiYnxtPfs2ZPp06eTlJRE165defXVV+nXr98F401NTSU5OdnrBu7aLuOWsc6aYRi4XC7PzXC54NQpjJMncaWkeG7GyZOZ248fx7hAzWsARs+euI4f97rO+dd2paS4617DIOT/10x2uVyeZQEyxXjO812sfdu2bZ6ad8OGDXz22WcsXbqUJ5980nN+WloaQ4YMYd26dcyaNctT85577Yw4MmLJqHmnTp3K0qVLPTWvy+UiPj6eRo0aMWXKFK/Yp0yZQocOHbDb7dn2KaPd6XTy008/sWXLFkJDQz3tWb025/Y34+cXX3yRPn36sHbtWq677jratWuHw+Hwes6M88+NJePn89vPfa6M/w4bNoyHH36YdevWedXG/fv3Z9WqVZ7aOOM6P/30E+3bt6dnz5788ccfTJw4kalTp/Lyyy/n6P3Orv3cvE5PT/fEntP2jNfv3LaMgeTs2p1OZ5b/nrJrdzgcXu3nvpclS5b0PC6jPZD7lF27+qQ+AV75nh/6lB/fJ/UJXJMnw969GNdcQ3rHjpfdJ1+uaavlES7m9Gnw1W62huGejXDeV/ezYzt5kshL2KVuzpw5mXbePT9Zhg8fzkMPPeRZK/a6667jzTffpF69erzzzjuEh4fz6KOPes4vW7Ysb775JjfffDMnT57Mdmff9PR0xo8fT61atQCYNm0aFStWZNWqVdSsWZMuXbrQrVs3Ro8eTVhYGGvXrmXjxo18/fXXOe6fWTp16kTbtm0B6N+/P7Vr1+all16iSZMmgPvrZBlfmQQYMmQIzz33HB06dADcr+GwYcPo168fgwYNyvsOBIigoCCqVatmdhgieUY5L37pk09g0SIID4cJE9ybkPlIUFCQz66VX6Snp9O2bVsMw+Cdd97xuq9Pnz6e4ypVqhAaGkrXrl0ZPnw4YWFhWV5v+PDhDBkyJFP7ggULiPz/Ndni4+OpWLEi6enpnDx5krS0NADCHA4iihbFBmT1rmfXnhWbYcC+fdgKFszyOudz/vsv6eD55ctutxMTE0NaWhpnzpzxnJexG/OcOXO8Brjhv5o3JSUFu93OsGHDeOCBB+jduzcnT56kWLFivPLKK9x1112MGzeO2NhY2rRp4/kFq0iRIowePZratWuzf/9+oqKiOH36NPDfYOnZs2dJT09n+PDh3HjjjcTExPDBBx9QqVIlkpKSqF69Og8++CDPPPMMr732Gg6Hgw0bNrBx40Y+/vhjgGz79M477/D++++TlpZGeno64eHh9OzZkzNnznjeozNnznD27FkiIiI4deqUJ6bk5GTPOU8++SS33XYbAM8++yy1a9fmr7/+yrTMRnR0NDabzWtg3+FweH4pTklJ8fT97NmzAJ5fdtu1a0fTpk2x2+2e2viZZ56hdu3aAHTt2pXu3btz9uxZUlNTGTRoEL169aJt27aeGcXPPfccgwcPpnfv3oSFhXn6dO4gT8bs35SUFK9ZyhmfJcuWLSM1NdXT3qBBAyIiIrw2/gNo3rw5Z86cISkpyet1b9GiBUeOHPGajRwdHU3Dhg3Zs2cP69ev97QXLVqUOnXqsG3bNrZu3eppj4+Pp1q1avz222/s3r3b016+fHkqVKjAqlWrOHz4sKc9MTGR0qVLs2zZMlJSUtj7/zPVa9euTbFixViwYIHXaxBIfVqyZIknb9Qn9encPu3fv5+9e/d68j0/9Ck/vk+W71PFijiGDSMU+O2uu9i5aNFl9+mnn37CZwyLOXHihAEYJ06cyHTfmTNnjD/++MM4c+bMf40nTxqGe7g1z2+ulBTj1KlThsvlumi/OnToYDRu3NjYtm2b1+3jjz82AOPff/81DMMwatSoYYSGhhoFChTw3CIjIw3A+OOPPwzDMIw1a9YYd911l1GqVCkjKirKc//vv/9uGIZhJCUleV1zypQpRnBwsOF0Or1iiouLM6ZOnWoYhmGkpqYaRYoUMT777DPDMAzjqaeeMho2bHjBPtWrV8/o1atXtj9nGDRokFG1atWLvkbnv1733ntvpvbSpUsbY8aM8fwMGF988YXn57///tsAjFWrVnnaFi1a5JVTRYoUMcLDw71e4/DwcAMwTp06dUlxGkY2eZkPORwOY+3atYbD4TA7FJE8oZwXv3P0qGEULequQ155xeeXP3bsWLY1mL9KTU01goKCjK+++sqrvX379sY999yT5WNKlSrlVUsYhmEMHDjQqFKlildbWlqa0bJlS6NKlSrGkSNHLhrLpk2bDMDYsmVLtuecPXvWOHHihOe2Z88eAzCOHDlipKWlGWlpaYbD4TDOnDlj/P7778apU6cMp9NpOJ1Ow5WSYlrN60xONk6ePPlfLP9f+7pcLk9bRntGzbt161bP7c8///TUvEePHjWcTucFa96Mmnb16tVGixYtMtW8GzduNJxOp/Hjjz8agHHs2DHD5XIZ77//vhEcHGykp6d74nG5XEZcXJzxwQcfGE6n0zhz5oxRpEgR49NPPzWcTqfRo0cPo2HDhhfsU7169YyOHTsaW7duNX755RejadOmxtChQ73OB4wvv/zS6zqxsbHG+++/bzidTk+NunLlSs+1jxw5YgDGTz/95PWc58ae8XNGTp/fXrp0aWP06NGe5wSM6dOne66R8by//PKL5zEZr9vx48cNp9N5wdo4JSXlgq9NVu2nT582/vjjDyM5OdmT12lpaZ7Yz227ULthGIbT6fRqS09Pv2C7w+Hwas/4f3h27enp6V7tGb8vnT171lizZo1x5swZr/acxu6PfcquXX1Sn9LS0rzyPT/0KT++T5bv07hx7nG4UqWMtJSUK+rTP//847OaVzNtLyYyEk6ezNm5S5ZA8+YXP2/ePLj99oueZoSHk5aSQnh4eI6WSChQoADXXnutV9ve89YZO3nyJF27dqVnz56ZHh8fH8+pU6do0qQJTZo04ZNPPqFo0aLs3r2bJk2aeP6CfzlCQ0Np3749U6ZMoVWrVnz66aeMGzfusq+XlzK+rgd43oes2jJmAJw8eZIhQ4bQKosNOMLDw3Mz1IDmcrnYvXs3lSpV0mwssQTlvPid/v3h8GG44QZ49lmfX/7cmXKBIjQ0lOrVq/Pjjz/SsmVLwN2PH3/80WtppHPVrl2bH3/80fOtJoCFCxd6ZiHCfzNst23bRlJSEoULF75oLOvXr8dut1OsWLFszwkLC8tyFm5ISIhX7ZKeno7NZsNut/+3bEWBAqbVvISHk56SQkREhNcyGjabLcsauECBAlx//fVebfv27QPw9CknNW/Tpk2zrHkdDofXa5MRR8bPXq/b/8toCw8Pp3379kydOpX777+fzz77jHHjxnn6kV2fYmNjPX2aMWMG1157LbVr16Zx48aex5z7WJvNRnp6uud5M9pDQ0M9sWX8v8XlcmW7PMn5cZ37c1bngDvPzn1tzm/L+K9hGJ73IrvaODIy8qKvzfnt59bj5+Z1hqzasmvP6r28UHtQUFCW/8/Orj1jJnVW19+7dy9VqlTxiutSYs+u3aw+ZdeuPqlPQJb5Hsh9yo/vk6X7dOYMDB8OgG3AAELO+4b5pfbJl7/badD2Ymw2dxGbE3fe6d4xd9++rNf4stnc9995J+TkTcyFX25uuukm/vjjj0yDuxk2btzI0aNHGTFiBKVKlQLItJlCVhwOB2vWrKFmzZoAbN26lePHj1OxYkXPOV26dKFSpUpMmDABh8ORZeGWH9x0001s3bo129dYRETE7yxdCu+95z6eOBG0Y7xHnz596NChAzVq1KBmzZqMHTuWU6dOeZZGat++Pddccw3D/7/Y79WrF/Xq1WPUqFG0aNGC6dOns2bNGiZPngy4B0xbt27N2rVrmTNnDk6n07PebaFChQgNDWXFihWsXLmSBg0aEB0dzYoVK3j66ad5+OGHKZjFkgM+oZrXb2reqKgoevXqxbPPPsu6deuw2WwULVqUAwcOeM7Ztm2bZ/kGf6faWERE/NrEiXDwICQkwEX2c8prgbcjhD8LCoKM2aPn/5U44+exY3NWvOaS/v37s3z5cnr06MH69evZtm0bX3/9tWe2SHx8PKGhobz11lv8/ffffPPNNwwbNuyi1w0JCeGpp55i5cqV/Prrr3Ts2JFbbrnFU9ACVKxYkVtuucWz2UaEH+zElxsGDhzIhx9+yJAhQ/j999/ZvHkz06dPZ8CAAWaHJiIikllaGnTr5j7u3Bn+fz1McXvggQd44403GDhwIImJiaxfv5758+d7NhvbvXu312BanTp1+PTTT5k8eTJVq1Zl5syZzJ49m0qVKgHuGaHffPMNe/fuJTExkauvvtpzW758OeCetTh9+nTq1avHjTfeyCuvvMLTTz/tGfg1nWreXK95u3btyp9//smXX34JQMOGDRk/fjzr1q1jzZo1dOvWLdtZSP5GtbGIiPitU6dgxAj38YABfjdxQYO2vtaqFcycCddc491esqS7/RL+0m6z2QgLC8vR0gg5VaVKFX766Sf+/PNPbrvtNqpVq8bAgQMpUaIE4F6YeerUqcyYMYMbbriBESNG8MYbb1z0upGRkfTv358HH3yQunXrEhUVxeeff57pvM6dO5OWlua12Vl2XC5XttPN/VmTJk2YM2cOCxYs4Oabb+aWW25hzJgxlC5d2uzQ/Jrdbqd8+fIBubu4yOVQzovfGDUKfv8dihSB117LtacJ5Fzv0aMHu3btIjU1lZUrV3o2XgVYvHgxU6dO9Tq/TZs2bN26ldTUVDZt2kTzc5YSSEhI8Gz2dP6tfv36gHtm4i+//MLx48c5c+YMf/zxB88//3y2G5CZQjWvz2rerBQqVIj27dszePBgXC4Xo0aNolSpUtx22208+OCDPPvss57N5fydauPMVAOIlSjfxa9NmACHDkHZstC+vU8u6ctctxlGVt9pyr+Sk5OJjY3lxIkTmXadPXv2LDt27KBMmTJXvvao0wk//wwHDsDVV7tnreTTNQunTp1K7969OX78+EXPHTZsGDNmzOC333676LkVKlSgS5cuPJsL6+oFEp/mpYiIyLn+/htuvBHOnoUPP4RHHsm1p7pQDSa5I7vXXDXv5cmtmlcujWpjERHxiZMnoUwZOHIEpkzx2dIIvqx5A28aY6AICoL/ny1xuQzD4NSpUxQoUMCnMw/McPLkSXbu3Mn48eN5+eWXL3juoUOH+O6779i6dSuNGjXKowjFbA6Hg1WrVlGzZs2AnGEtcqmU82I6w4AnnnAP2DZsCA8/nKtP53A4cvX6YhLVvF4upeYV61INIFaifBe/NX68e8D22mt9Wgf7subVvxg/ZhgGDocDwzACvoDt0aMHn332GS1btrzo18SaNm3Kv//+y5tvvkm1atXyKEIxm2EYHD58GItN/hcLU86L6b74Ar7/3r1214QJmdcm9THlumTHqjWvWJdqALES5bv4peRkGDnSfTxoEPjwDwq+zHUtKiJXrGPHjhf9mtjUqVNJTU3l888/J+giX5lbu3YtO3bs4KmnnvJhlCIiIuJx/Dj07u0+fuEFKF/ezGhEAoKva14RERExyVtvwbFj7hq4XTuzo8mWBm1FRERErObFF+HgQbj+enjuObOjERERERHJGydOQMbmo4MG+fVa/Bq09WM2m42IiIiA/5qYSE4EBQWRmJioWSliGcp5Mc3KlfDOO+7jiRMhLCxPnla5LtlRzStWoxpArET5Ln5n3Dj3t85uuAHatvX55X2Z61rTNgsul8vsEAB3ARuWR79Iif/yl3zMbXa7ndKlS5sdhkieUc6LKRwO6NrVvQlZ+/bQoEGePbXdrrkC/sZfagzVvHIp8sO6mKoBxEqU7+JX/v0XRo92H+fSLFtf1rwatD1HaGgodrud/fv3U7RoUUJDQ039i39+2klXLp1hGKSlpXH48GHsdjuhoaFmh5SrHA4HS5Ys4fbbb9euomIJynkxxbhxsGEDFCr039fC8ogvd9KVK6OaVwJVxoZGNpuNkJAQs8O5bKoBxEqU7+JXxoxxL49QqRK0bp0rT+HLmlf/Ys5ht9spU6YMBw4cYP/+/WaHg2EYnDlzRl8Xs7jIyEji4+Pz/QwlwzBISUnJF7MnRHJCOS95btcuGDjQffz661C0aJ4+vXLdf6jmlUBms9koWbJkQH/VWjWAWInyXfzGsWMwdqz7ePBgyKUxFl/mugZtzxMaGkp8fDwOhwOn02lqLOnp6Z6/SAXyX5Ll8gUFBREcHKxfYERE5MoYBjz1FJw+DbfdBp06mR2RmEw1rwSqkJCQgB6wFRERk4waBSkpULUq3Hef2dHkiAZts5DxdRuzi8agoCAcDgfh4eGmxyIiIiIBbPZs+PZbCAlxbz6Wz7+9ITmjmldEREQs4cgRePNN9/GQIQFTCwdGlBYVFBRE7dq19ZdksQTlu1iNcl7yTEqKe5YtQN++7p1yTaBcl+zo81CsRjkvVqJ8F7/wxhtw8iTcdBPcc0+uPpUvc10zbf2Y3W6nWLFiZochkieU72I1ynnJMy+9BPv2QdmyMGCAaWHk97XZ5fLp81CsRjkvVqJ8F9MdOgRvveU+HjIEcnn5SV/WvKqe/Vh6ejpz584lPT3d7FBEcp3yXaxGOS95Yu3a/4rUCRMgIsK0UJTrkh19HorVKOfFSpTvYrqRI937Otx8M7RoketP58tc16Ctn3M4HGaHIJJnlO9iNcp5yVVOJzz+OLhc8L//QZMmZkckki19HorVKOfFSpTvYpqDB+Htt93HeTDL1tc0aCsiIiKSH739Nvz6K8TGwpgxZkcjIiIiIpK3XnsNzpyBW26Bpk3NjuaSadBWREREJL/Zt++/9WtHjIDixc2NR0REREQkL+3fDxMnuo8DcJYtaNDWrwUHB9OgQQOCg7VfnOR/ynexGuW85KpevSAlxT2r4PHHzY4GQLku2dLnoViNcl6sRPkuphkxAs6ehbp14Y478uxpfZnrGrT1cxEmbhgikteU72I1ynnJFXPmwJdfQlAQTJoEPtzBViS36PNQrEY5L1aifJc8t3evuw4GGDo0IGfZggZt/ZrD4WDevHlatFssQfkuVqOcl1xx6hQ8+aT7uE8fqFLF3HjOoVyX7OjzUKxGOS9WonwXUwwfDmlpcPvt0KBBnj61L3Ndg7YiIiIi+cWQIbB7N5QuDYMGmR2NiIiIiEje2r0b3n3XfRzAs2xBg7YiIiIi+cNvv8Ho0e7j8eOhQAFz4xERERERyWuvvgrp6dCwIdSrZ3Y0V0SDtiIiIiKBzuWCrl3B6YT774e77jI7IhERERGRvLVzJ7z/vvt4yBBTQ/EFm2EYhtlB5KXk5GRiY2M5ceIEMTExZodzQYZh4HA4CA4OxhbA07lFckL5LlajnBefmjgRuneH6GjYvBmuucbsiDI5ceIEcXFxAVGD5ReBUvfq81CsRjkvVqJ8lzzVpYt70PaOO2DBAlNC8GXNq5m2fu7MmTNmhyCSZ5TvYjXKefGJgwfhuefcxy+/7JcDtiIXo89DsRrlvFiJ8l3yxPbtMHWq+zgfzLIFDdr6NYfDQVJSknZZFEtQvovVKOfFZ55+Gk6cgOrV4cknzY4mW8p1yY4+D8VqlPNiJcp3yTMvv+xeKqxpU6hd27QwfJnrGrQVERERCVQLFsD06WC3w+TJEBRkdkQiIiIiInlr2zb48EP3cT6ZZQsatBUREREJTGfOuNexBXjqKbjpJnPjERERERExw7Bh7o15W7SAmjXNjsZnNGjr54KDg80OQSTPKN/FapTzckVefhn+/tu9hu2wYWZHI3JF9HkoVqOcFytRvkuu2roVPvnEfZyPZtmCyYO2w4cP5+abbyY6OppixYrRsmVLtm7detHHzZgxgwoVKhAeHk7lypWZN29eHkSb90JCQmjRogUhISFmhyKS65TvYjXKebkif/wBI0e6j996C6KjzY0nB5Trkh19HorVKOfFSpTvkuuGDnXPsr33XvceDybzZa6bOmj7008/8eSTT/LLL7+wcOFC0tPTufPOOzl16lS2j1m+fDnt2rWjc+fOrFu3jpYtW9KyZUs2bdqUh5HnDZfLxaFDh3C5XGaHIpLrlO9iNcp5uWwuF3TtCunpcPfd0LKl2RHliHJdsqPPQ7Ea5bxYifJdctUff8Bnn7mPBw82NZQMvsx1Uwdt58+fT8eOHbnxxhupWrUqU6dOZffu3fz666/ZPmbcuHE0bdqUvn37UrFiRYYNG8ZNN93E+PHj8zDyvOF0OlmxYgVOp9PsUERynfJdrEY5L5dtyhRYuhQiI92zbG02syPKEeW6ZEefh2I1ynmxEuW75KohQ8AwoFUrSEw0OxrAtzWvXy0scuLECQAKFSqU7TkrVqygT58+Xm1NmjRh9uzZWZ6fmppKamqq5+fk5GQA0tPTSU9PB8ButxMUFITT6fQaEc9odzgcGIbhaQ8KCsJut2fbnnHdDBnrtzgcjhy1h4SE4HK5PNdJT0/HZrMRHByMy+XySoCM9uxi97c+ZRW7+qQ+ZfQJ8PQhv/QpP75P6pNv+pTVZ3yg9yk/vk9+16fDhwnu1w8bYAwZAvHxOM4731/7dH6/REREREQu28aNMGOG+3jQIHNjySV+M2jrcrno3bs3devWpVKlStmed/DgQa666iqvtquuuoqDBw9mef7w4cMZksVCxAsWLCAyMhKA+Ph4qlWrxm+//cbu3bs955QvX54KFSqwatUqDh8+7GlPTEykdOnSLFmyhJSUFE977dq1KVasGAsWLPD6pbBBgwZERERkWnu3efPmnDlzhqSkJE9bcHAwLVq04MiRI6xYsQKAhQsXEh0dTcOGDdmzZw/r16/3nF+0aFHq1KnDtm3bvNYD9vc+AeqT+uTVp+XLlwPufM8vfcqP75P65Ps+LVy4MN/1CfLf++Qvfao2bhzxx45xIiGByCef5ExKSsD0KeMP5yIiIiIiVyxjlm2bNlClitnR5Aqbce50CBN1796d7777jqVLl1KyZMlszwsNDWXatGm0a9fO0zZhwgSGDBnCP//8k+n8rGbalipViiNHjhATEwP476yf1NRUli9fTp06dQgJCTF9howv+uRvs37UJ//p09mzZ1m2bBl16tQhODg4X/QpP75P6pPv+pTVZ3yg9yk/vk/+1CfnDz8QfOedGDYbzp9/JqhOnYDq07FjxyhevDgnTpzw1GCSu5KTk4mNjfX719zhcLBkyRJuv/127TAulqCcFytRvkuuWL8eqlVzLxO2cSPceKPZEXkcO3aMwoUL+6T+8otB2x49evD111+zZMkSypQpc8Fz4+Pj6dOnD7179/a0DRo0iNmzZ7Nhw4aLPlegFK8iIiIiHqmp7hkEf/4JTzwBb79tdkSXTDVY3tNrLiIiIvnSfffB7Nnwv//9txGZn/Bl/WXqRmSGYdCjRw+++uorFi1adNEBW3B/ze7HH3/0alu4cCG1a9fOrTBN43K52LVrl3ZZFEtQvovVKOflkowY4R6wLV4cXn3V7Ggui3JdsqPPQ7Ea5bxYifJdfO7XX90Dtna7X65l68tcN3XQ9sknn+Tjjz/m008/JTo6moMHD3Lw4EHOnDnjOad9+/Y8//zznp979erF/PnzGTVqFFu2bGHw4MGsWbOGHj16mNGFXOV0Olm/fr12WRRLUL6L1SjnJcf+/PO/gdqxYyE21tRwLpdyXbKjz0OxGuW8WInyXXxu8GD3fx98ECpUMDWUrPgy100dtH3nnXc4ceIE9evX5+qrr/bcPv/8c885u3fv5sCBA56f69Spw6effsrkyZOpWrUqM2fOZPbs2RfcvExEREQkIBkGdO8OaWnQtCm0bWt2RCIiIiIi5li9GubMcc+yfekls6PJdaauAp2T5XQXL16cqa1Nmza0adMmFyISERER8SMffwyLFkF4uHsdW5vN7IhERERERMyRsRzCI4/A9debG0seMHWmrVyYzWajaNGi2PQLmliA8l2sRjkvF3X0KPTp4z4eOBDKljU3niukXJfs6PNQrEY5L1aifBef+eUX+O47CAry61m2vsx1m5GT6a75iHbRFRERkYDQpQu8/z7ceCOsXQuhoWZHdEVUg+U9veYiIiKSbzRpAgsWwKOPumtkP+XL+kszbf2Y0+lky5YtWrBbLEH5LlajnJcL+vnn/4rRSZMCfsAWtBGZZE+fh2I1ynmxEuW7+MSyZe4B2+BgGDDA7GguKN9sRCYX5nK52Lp1Ky6Xy+xQRHKd8l2sRjkv2UpLg27d3MddukDduubG4yPKdcmOPg/FapTzYiXKd/GJjLVsO3WCMmXMjeUifJnrGrQVERER8SdvvAF//AFFi8Jrr5kdjYiIiIiIeX76CX78EUJC4MUXzY4mT2nQVkRERMRfbN8Ow4a5j0ePhkKFzI1HRERERMRMGbNsu3SB0qXNjSWPadDWj9ntduLj47Hb9TZJ/qd8F6tRzksmhgFPPglnz0KjRvDQQ2ZH5FPKdcmOPg/FapTzYiXKd7kiSUnumbahofDCC2ZHkyO+zHWbYRiGz64WALSLroiIiPilzz+H//3PXZRu3AjXX292RD6lGizv6TUXERGRgGUYcPvtsHQp9OgBb71ldkQ54sv6S3/q8GNOp5N169Zpl0WxBOW7WI1yXrwcPw69ermPX3wx3w3Ygm930pX8RZ+HYjXKebES5btcth9/dA/YhoXB88+bHU2O+TLXNWjrx1wuF7t379Yui2IJynexGuW8eHnhBfjnHyhfHvr3NzuaXKFcl+zo81CsRjkvVqJ8l8tiGDBwoPu4WzcoUcLceC6BL3Ndg7YiIiIiZvrlF5g40X38zjvu2QQiIiIiIlb1/fewYgVERMBzz5kdjWk0aCsiIiJilvR06NrVPZugQwdo0MDsiEREREREzGMYMGiQ+7h7dyhe3Nx4TKRBWz9mt9spX768dlkUS1C+i9Uo5wWAcePgt9+gUCF44w2zo8lVynXJjj4PxWqU82Ilyne5ZPPmwapVEBkJ/fqZHc0l82WuB/vsSuJzQUFBVKhQwewwRPKE8l2sRjkv7Nr13yyCkSOhSBFz48llQUFBZocgfkqfh2I1ynmxEuW7XJJzZ9n26AFXXWVuPJfBlzWv/tThxxwOB8uXL8fhcJgdikiuU76L1SjnLc4w3IXo6dNw++3QqZPZEeU65bpkR5+HYjXKebES5btckm+/hV9/hQIFoG9fs6O5LL7MdQ3a+jHDMDh8+DCGYZgdikiuU76L1SjnLe6rr2DOHAgJcW9CZrOZHVGuU65LdvR5KFajnBcrUb5Ljp07y7Znz4D9Fpovc12DtiIiIiJ5KTnZXYiCe52uihXNjUdERERExGyzZ8P69RAdDc88Y3Y0fkGDtiIiIiJ56aWXYN8+KFcOXnzR7GhERERERMzlcv03y7ZXLyhc2Nx4/IQGbf1YUFAQiYmJ2rhDLEH5LlajnLeoX3+F8ePdx++8AxER5saTh5Trkh19HorVKOfFSpTvkiNffgkbN0JMDPTpY3Y0V8SXuR7ssyuJz9ntdkqXLm12GCJ5QvkuVqOctyCnE7p2dc8kaNcO7rjD7IjylN2uuQKSNX0eitUo58VKlO9yUU4nDB7sPn76aShY0NRwrpQva15Vz37M4XCwaNEi7bIolqB8F6tRzlvQ22+7Z9rGxsLo0WZHk+eU65IdfR6K1SjnxUqU73JRM2bAH39AXBz07m12NFfMl7muQVs/ZhgGKSkp2mVRLEH5LlajnLeYvXv/W7/2tdegeHFz4zGBcl2yo89DsRrlvFiJ8l0uyOmEIUPcx8884x64DXC+zHUN2oqIiIjktl694ORJqF0bHnvM7GhERERERMw3fTps2QKFCkHPnmZH43c0aCsiIiKSm779FmbNgqAgmDgRtLariIiIiFidw/HfLNtnn3VvQiZe9FuDHwsKCqJ27draZVEsQfkuVqOct4hTp6BHD/fxM89AlSrmxmMi5bpkR5+HYjXKebES5btk65NPYNs2KFLkv3o5H/Blrgf77Eric3a7nWLFipkdhkieUL6L1SjnLWLwYNi9G0qXhoEDzY7GVL7cSVfyF30eitUo58VKlO+SpfR0GDbMfdy3L0RHmxuPD/my5lX17MfS09OZO3cu6enpZocikuuU72I1ynkL2LABxoxxH7/9NhQoYG48JlOuS3b0eShWo5wXK1G+S5Y++gi2b4eiReHJJ82Oxqd8mesatPVzDofD7BBE8ozyXaxGOZ+POZ3Qtav7v61bQ4sWZkck4tf0eShWo5wXK1G+i5dzZ9n272/5iQ0XokFbEREREV+bPBlWrnR/1WvcOLOjERERERHxD1Onws6dcNVV0L272dH4NQ3aioiIiPjSgQPw/PPu41degRIlzI1HRERERMQfpKXByy+7j59/HiIjzY3Hz9kMwzDMDiIvJScnExsby4kTJ4iJiTE7nAsyDIOUlBSio6Ox2WxmhyOSq5TvYjXK+Xzsf/+Dzz+HGjXgl19AuyUDcOLECeLi4gKiBssvAqXu1eehWI1yXqxE+S5eJk50z669+mr3mrYREWZH5HO+rHk109bPReTDBBbJjvJdrEY5nw99/717wNZuh0mTNGArkkP6PBSrUc6LlSjfBYCzZ93fQgN44YV8OWDraxq09WMOh4N58+Zp0W6xBOW7WI1yPh86ffq/dbl69oSbbjI3Hj+jXJfs6PNQrEY5L1aifBeP996DvXuhZEno0sXsaHKNL3Ndg7YiIiIivvDyy7Bjh7sQHTrU7GhERERERPzDmTMwfLj7+IUXIDzc3HgChAZtRURERK7U77/DyJHu47feguhoc+MREREREfEXkyfD/v0QHw+PPmp2NAFDg7YiIiIiV8Llgm7dwOGAe+6Bli3NjkhERERExD+cPg0jRriPX3wRwsLMjSeA2AzDMMwOIi8Fyi664N5l0eFwEBwcrF0WJd9TvovVKOfzkfffd6/LVaAA/PGHewaBZOLLnXQlZwKl7tXnoViNcl6sRPkujB4NzzwDCQmwdSuEhpodUa7yZc2rmbZ+7syZM2aHIJJnlO9iNcr5fODQIejb1308dKgGbEUukz4PxWqU82IlyncLO3Xqv1m2L72U7wdsfU2Dtn7M4XCQlJSkXRbFEpTvYjXK+Xzi2Wfh338hMRF69jQ7Gr+mXJfs6PNQrEY5L1aifLe4CRPg8GEoWxYeecTsaPKEL3Ndg7YiIiIil2PRIvjoI7DZYNIkCA42OyIREREREf+QkgKvveY+HjgQQkLMjScAadBWRERE5FKdPevefAzgiSegZk1z4xERERER8Sfjx8PRo3DddfDQQ2ZHE5A0aOvngjVrRyxE+S5Wo5wPYCNGwLZtcPXV8MorZkcjEvD0eShWo5wXK1G+W1ByMrzxhvt44EB9I+0yadDWj4WEhNCiRQtCNIVcLED5LlajnA9gW7fC8OHu47FjITbW1HACRSDn+ttvv01CQgLh4eHUqlWLVatWXfD8GTNmUKFCBcLDw6lcuTLz5s3z3Jeenk7//v2pXLkyBQoUoESJErRv3579+/d7XePYsWM89NBDxMTEEBcXR+fOnTl58mSu9M9s+jwUq1HOi5Uo3y3qzTfh2DEoXx7atTM7mjzly1zXoK0fc7lcHDp0CJfLZXYoIrlO+S5Wo5wPUIbhXhYhLQ2aNYM2bcyOKGAEaq5//vnn9OnTh0GDBrF27VqqVq1KkyZNOHToUJbnL1++nHbt2tG5c2fWrVtHy5YtadmyJZs2bQLg9OnTrF27lpdeeom1a9cya9Ystm7dyj333ON1nYceeojff/+dhQsXMmfOHJYsWcLjjz+e6/01gz4PxWqU82IlyncLOnECRo1yHw8eDEFBpoaT13yZ6xq09WNOp5MVK1bgdDrNDkUk1ynfxWqU8wHqo49g8WKIiIC333ZvQiY5Eqi5Pnr0aB577DE6derEDTfcwMSJE4mMjOSDDz7I8vxx48bRtGlT+vbtS8WKFRk2bBg33XQT48ePByA2NpaFCxfStm1bypcvzy233ML48eP59ddf2b17NwCbN29m/vz5vPfee9SqVYtbb72Vt956i+nTp2eakZsf6PNQrEY5L1aifLegsWPh+HG44QZLTnDwZa5rUQkRERGRnDh6FJ55xn08cCCUKWNuPJLr0tLS+PXXX3n++ec9bXa7ncaNG7NixYosH7NixQr69Onj1dakSRNmz56d7fOcOHECm81GXFyc5xpxcXHUqFHDc07jxo2x2+2sXLmS++67L8vrpKamkpqa6vk5OTkZcC/JkJ6e7ok/KCgIp9PpNRMko93hcGAYhqc9KCgIu92ebXvGdTNkrFvocDhy1B4SEuKJI+NaNpuN4OBgXC6X1y8+Ge3Zxe5vfcoqdvVJfcroE/yX8/mlT/nxfVKffNMnyPwZH+h9yo/vk0/6dPgwQaNHYwMcAwZgt9mw/38sAdunK3yfroQGbUVERERyol8/OHIEKlX6b/BW8rUjR47gdDq56qqrvNqvuuoqtmzZkuVjDh48mOX5Bw8ezPL8s2fP0r9/f9q1a0dMTIznGsWKFfM6Lzg4mEKFCmV7HYDhw4czZMiQTO0LFiwgMjISgPj4eKpVq8Zvv/3mmdkLUL58eSpUqMCqVas4fPiwpz0xMZHSpUuzZMkSUlJSPO21a9emWLFiLFiwwOsXqAYNGhAREeG1ji9A8+bNOXPmDElJSV59atGiBUePHgVg4cKFAERHR9OwYUP27NnD+vXrPecXLVqUOnXqsG3bNrZu3epp98c+HTlyxGtgX31Sn87t0/Lly4H/cj4/9Ck/vk/qk2/6tG/fPuC/fM8PfcqP75Ov+nSgXz9KJSdzonRpFoeHk7hnT8D36VLfp59//hlfsRnnDgtbQHJyMrGxsZw4ccJTGPsrh8PBkiVLuP3227XbouR7ynexGuV8gPn5Z7j9dvfxsmVQp4658QSgY8eOUbhw4YCowTLs37+fa665huXLl1O7dm1Pe79+/fjpp59YuXJlpseEhoYybdo02p2z6caECRMYMmQI//zzj9e56enp3H///ezdu5fFixd7XpdXX32VadOmef1CAVCsWDGGDBlC9+7ds4w3q5m2pUqV4siRI55r++Osn7S0NH7++Wfq1KlDcHCw382QyY+zftQnc/t09uxZli1b5sn5/NCn/Pg+qU++6VN2n/GB3Kf8+D75pE/Hj2OUKYMtJQXH559j3Hdf4PfpMt6nf/75h+LFi/uk5tVviX4sODiYhg0bmh2GSJ5QvovVKOcDSFoadO3qPn7sMQ3YXqZA/ONEkSJFCAoKyjTYmlGMZ6V48eI5Oj89PZ22bduya9cuFi1a5FXUFy9ePNNGZw6Hg2PHjmX7vABhYWGEhYVlag8JCcm0k3FQUBBBWWwMkt37lF17djskX0p7aGgojRo1ytRut9ux2zNvwZFd7P7Up+xiV5/UJ4Dw8PAscz6Q+5Qf3yf1yTd9yu4zPpD7lB/fJ5/0adQobCkpkJhIcOvWcM71ArZPXPr7FBERkWX75dBGZH7M5XKxa9cu7bIolqB8F6tRzgeQN96AzZuhWDEYMcLsaAJWIOZ6aGgo1atX58cff/S0uVwufvzxR6+Zt+eqXbu21/ng/krouednDNhu27aNH374gcKFC2e6xvHjx/n11189bYsWLcLlclGrVi1fdM2v6PNQrEY5L1aifLeII0fgzTfdx4MHew3YWo0vc926r2IAcDqdrF+/XrssiiUo38VqlPMBYvt2GDbMfTx6NBQqZG48ASxQc71Pnz68++67TJs2jc2bN9O9e3dOnTpFp06dAGjfvr3XRmW9evVi/vz5jBo1ii1btjB48GDWrFlDjx49APeAbevWrVmzZg2ffPIJTqeTgwcPcvDgQdLS0gCoWLEiTZs25bHHHmPVqlUsW7aMHj168L///Y8SJUrk/YuQy/R5KFajnBcrUb5bxMiRcOoU3HQT3HOP2dGYype5HnjfUxMRERHJC4YBTzwBZ89C48bw4INmRyQmeOCBBzh8+DADBw7k4MGDJCYmMn/+fM9mY7t37/b66l2dOnX49NNPGTBgAC+88ALXXXcds2fPplKlSgDs27ePb775BnBvZHGupKQk6tevD8Ann3xCjx49aNSoEXa7nfvvv583M2awiIiIiPiLQ4dg/Hj38dChYLOZG08+okFbERERkax8/jksWABhYTBhggpQC+vRo4dnpuz5Fi9enKmtTZs2tGnTJsvzExISyMk+wIUKFeLTTz+9pDhFRERE8tzrr8Pp01CzJjRvbnY0+YqWR/BjNpuNokWLYtMviWIBynexGuW8nzt+HHr3dh+/+CJcd52Z0eQLynXJjj4PxWqU82Ilyvd87uBB9+QGgCFDNMkB39a8NiMnf+rPR5KTk4mNjeXEiRNeu/SKiIiIeHTvDhMnQvnysGGDe7atXBHVYHlPr7mIiIjkqqefhrFj4ZZbYPlyDdri2/pLM239mNPpZMuWLVqwWyxB+S5Wo5z3YytWwKRJ7uOJEzVg6yPKdcmOPg/FapTzYiXK93xs/3545x33sday9fBlrmvQ1o+5XC62bt2Ky+UyOxSRXKd8F6tRzvup9HTo2tW9CVnHjvD/m0LJlVOuS3b0eShWo5wXK1G+52PDh0NqKtx6q3vTXgF8W/Nq0FZEREQkw9ixsHEjFC4MI0eaHY2IiIiIiP/ZuxcmT3Yfay3bXKNBWxERERGAnTth8GD38ciRUKSImdGIiIiIiPinV1+FtDSoVw8aNDA7mnxLg7Z+zG63Ex8fj92ut0nyP+W7WI1y3s8YBvToAadPw+23u5dGEJ9Srkt29HkoVqOcFytRvudDu3bBe++5jzXLNhNf5rrNMAzDZ1cLANpFV0RERDL58kto3RpCQuC336BCBbMjyndUg+U9veYiIiLic127updGaNgQfvzR7Gj8ji/rL/2pw485nU7WrVunXRbFEpTvYjXKeT+SnAw9e7qP+/fXgG0uUa5LdvR5KFajnBcrUb7nMzt2wAcfuI+HDDE3Fj/ly1zXoK0fc7lc7N69W7ssiiUo38VqlPN+5KWXYP9+uPZaeOEFs6PJt5Trkh19HorVKOfFSpTv+cwrr4DDAXfcAbfeanY0fsmXua5BWxEREbGuNWvgrbfcx++8AxER5sYjIiIiIuKPtm+HqVPdx5plmyc0aCsiIiLW5HC41+QyDHjwQWjc2OyIRERERET807Bh4HRCs2ZQu7bZ0ViCBm39mN1up3z58tplUSxB+S5Wo5z3A2+/DWvXQlwcjB5tdjT5nnJdsqPPQ7Ea5bxYifI9n9i2DT76yH08eLCpofg7X+Z6sM+uJD4XFBREBW2GIhahfBerUc6bbO9eGDDAffzaa3DVVebGYwFBQUFmhyB+Sp+HYjXKebES5Xs+MXQouFxw111Qs6bZ0fg1X9a8+lOHH3M4HCxfvhyHw2F2KCK5TvkuVqOcN1nPnnDyJNSpA126mB2NJSjXJTv6PBSrUc6LlSjf84EtW+DTT93HWsv2onyZ6xq09WOGYXD48GEMwzA7FJFcp3wXq1HOm+ibb+CrryA4GCZOBH1dL08o1yU7+jwUq1HOi5Uo3/OBjFm2994LN91kdjR+z5e5rt9SRERExDpOnoQePdzHzzwDlSubG4+IiIiIiL/6/XeYPt19rLVs85wGbUVERMQ6Bg+GPXsgIQEGDjQ7GhERERER/zV0KBgGtGoFiYlmR2M5GrT1Y0FBQSQmJmrjDrEE5btYjXLeBOvXw9ix7uO334bISDOjsRzlumRHn4diNcp5sRLlewDbuBG++MJ9rFm2OebLXA/22ZXE5+x2O6VLlzY7DJE8oXwXq1HO5zGnE7p2df+3TRto3tzsiCzHrrWDJRv6PBSrUc6LlSjfA1jGQG3btlpS7BL4suZV9ezHHA4HixYt0i6LYgnKd7Ea5XwemzQJVq2CmJj/ZttKnlKuS3b0eShWo5wXK1G+B6j162HWLLDZYNAgs6MJKL7MdQ3a+jHDMEhJSdEui2IJynexGuV8HjpwAJ5/3n38yitQooS58ViUcl2yo89DsRrlvFiJ8j1AZcyy/d//4IYbTA0l0Pgy1zVoKyIiIvlb796QnAw33wzdu5sdjYiIiIiI//r1V/j6a7DbtXGvyTRoKyIiIvnX/PnuDRTsdvcSCdoEQ0REREQkexmzbB98ECpUMDUUq9OgrR8LCgqidu3a2mVRLEH5LlajnM8Dp0/DE0+4j3v1gmrVzI3H4pTrkh19HorVKOfFSpTvAWbVKpgzxz3RQbNsL4svcz3YZ1cSn7Pb7RQrVszsMETyhPJdrEY5nwdefhl27IBSpWDoULOjsTxf7qQr+Ys+D8VqlPNiJcr3AJOx6djDD8N115kbS4DyZc2r6tmPpaenM3fuXNLT080ORSTXKd/FapTzuWzTJhg50n381lsQFWVuPKJcl2zp81CsRjkvVqJ8DyArVriXFgsKgpdeMjuagOXLXNegrZ9zOBxmhyCSZ5TvYjXK+VzickG3buBwwL33um8i4tf0eShWo5wXK1G+B4iMWbYdO0K5cqaGIm4atBUREZH85YMPYNkyKFDAPctWRERERESyt3QpLFwIwcEwYIDZ0cj/06CtiIiI5B+HDkG/fu7jYcPc69mKiIiIiEj2MmbZPvooJCSYGor8x2YYhmF2EHkpOTmZ2NhYTpw4QUxMjNnhXJBhGKSkpBAdHY3NZjM7HJFcpXwXq1HO55JHHoGPP4bERFi92j1bQPzCiRMniIuLC4gaLL8IlLpXn4diNcp5sRLlewD46SeoXx9CQuCvvyA+3uyIApova17NtPVzERERZocgkmeU72I1ynkf+/FH94CtzQaTJmnAViSA6PNQrEY5L1aifPdjhgEDB7qPu3TRgK2f0aCtH3M4HMybN0+LdoslKN/FapTzPnb2LHTv7j5+8kmoWdPceCQT5bpkR5+HYjXKebES5bufS0qCJUsgNBReeMHsaPIFX+a6Bm1FREQk8A0fDtu2wdVXw8svmx2NiIiIiIh/O3eWbdeuULKkufFIJhq0FRERkcC2ZQuMGOE+HjcOYmPNjUdERERExN/98AMsWwbh4fDcc2ZHI1nQoK2IiIgELsOAbt0gLQ2aN4fWrc2OSERERETEv507y7ZbNyhRwtx4JEs2wzAMs4PIS4Gyiy64d1l0OBwEBwdrl0XJ95TvYjXKeR+ZNg06doSICPjjD0hIMDsiyYYvd9KVnAmUulefh2I1ynmxEuW7n5o/H5o1c9fQf/8NxYubHVG+4cuaVzNt/dyZM2fMDkEkzyjfxWqU81fo6FF49ln38aBBGrAVCWD6PBSrUc6LlSjf/cy5s2yfeEIDtn5Mg7Z+zOFwkJSUpF0WxRKU72I1ynkf6NcPjhyBSpWgTx+zo5GLUK5LdvR5KFajnBcrUb77oblzYfVqiIx019PiU77MdQ3aioiISOBZsgQ++MB9PGkShISYG4+IiIiIiL8zDPc31AB69IBixcyNRy5Ig7YiIiISWNLS3BsmADz+ONSpY248IiIiIiKB4JtvYO1aiIqCvn3NjkYuwtRB2yVLlnD33XdTokQJbDYbs2fPvuD5ixcvxmazZbodPHgwbwI2QXBwsNkhiOQZ5btYjXL+Mo0cCZs3u2cGjBhhdjQi4gP6PBSrUc6LlSjf/YTL9d8s2549oUgRc+ORizL1X86pU6eoWrUqjz76KK1atcrx47Zu3eq1A1uxfDqdOyQkhBYtWpgdhkieUL6L1SjnL9Nff8GwYe7jMWOgYEFz45EcC9ESFpINfR6K1SjnxUqU735k9mzYsAGio+GZZ8yOJt/yZc1r6qBts2bNaNas2SU/rlixYsTFxfk+ID/jcrk4cuQIRYoUwW7XShaSvynfxWqU85fBMNw73KamQuPG0K6d2RHJJXC5XGaHIH5Kn4diNcp5sRLlu584d5Zt795QqJCp4eRnvqx5A/JfTGJiIldffTV33HEHy5YtMzucXON0OlmxYgVOp9PsUERynfJdrEY5fxmmT4eFCyEsDN55B2w2syOSS6Bcl+zo81CsRjkvVqJ89xNffgmbNkFsLDz9tNnR5Gu+zPWAWljk6quvZuLEidSoUYPU1FTee+896tevz8qVK7npppuyfExqaiqpqamen5OTkwFIT08nPT0dALvdTlBQEE6n02tEPKPd4XBgGIanPSgoCLvdnm17xnUzZKzf4nA4ctQeEhKCy+XyXCc9PR2bzUZwcDAul8srATLas4vd3/qUVezqk/qU0SfA04f80qf8+D6pT77pU1af8YHep1x9nw4fJqh3b2yA8/nnsZUtix0Cu0/58X26QJ/O75eIiIiI5AGnEwYPdh8//bSWFwsgATVoW758ecqXL+/5uU6dOmzfvp0xY8bw0UcfZfmY4cOHM2TIkEztCxYsIDIyEoD4+HiqVavGb7/9xu7du72er0KFCqxatYrDhw972hMTEyldujRLliwhJSXF0167dm2KFSvGggULvH6BatCgAREREcybN88rhubNm3PmzBmSkpI8bcHBwbRo0YIjR46wYsUKABYuXEh0dDQNGzZkz549rF+/3nN+0aJFqVOnDtu2bWPr1q2edn/vE6A+qU9efVq+fDngzvf80qf8+D6pT77v08KFC/Ndn8D379M/jz7KNYcOkVKyJItvvJFaR44EfJ/y4/t0oT5l/OFcRERERPLQF1/AH39AXJx7aQQJGDbj3OkQJrLZbHz11Ve0bNnykh7Xt29fli5d6vWLybmymmlbqlQpjhw54tnMzF9nyKSmprJ8+XLq1KlDSEiI6TNkfNEnf5v1oz75T5/Onj3LsmXLqFOnDsHBwfmiT/nxfVKffNenrD7jA71PufY+rVwJdeq4Y/rhB4zbbw/8PuXH9+kifTp27BjFixfnxIkTXhvKSu5JTk4mNjbW719zh8PBkiVLuP3227XDuFiCcl6sRPluMqcTbrwRtm51b+Y7YIDZEeV7x44do3Dhwj6pvwJ+0PaOO+4gOjqaWbNm5ej8QCleRUREBEhPh5tucq/B1akTfPCB2RHJZVINlvf0mouIiFjcxx/DI4+4Nx7bsQNUD+Q6X9Zfpm5EdvLkSdavX+/5mt6OHTtYv36952t4zz//PO3bt/ecP3bsWL7++mv++usvNm3aRO/evVm0aBFPPvmkGeHnOpfLxa5du7TbsliC8l2sRjmfQ2PGuAdsCxeG1183Oxq5Asp1yY4+D8VqlPNiJcp3EzkckLFcaN++GrDNI77MdVMHbdesWUO1atWoVq0aAH369KFatWoMHDgQgAMHDnito5aWlsYzzzxD5cqVqVevHhs2bOCHH36gUaNGpsSf25xOJ+vXr9cui2IJynexGuV8Duzc+d+mCW+8AUWKmBmNXCHlumRHn4diNcp5sRLlu4k++QT++stdQ/foYXY0luHLXDd1QZH69etzodUZpk6d6vVzv3796NevXy5HJSIiIqYzDHjySThzBurVgw4dzI5IRERERCQwpKfD0KHu4379ICrK3Hjkspg601ZEREQkS19+CfPm/R979x0eVbX1cfw7M+mU0HtTRLBhEBVBacoVBQsXUbABFq6Nbm8oNuyComKvcC1cLK8ggnQE4UoR9QqC0puEkhAISaa8f2zSIAkpZ+bMzPl9nicPMzuTmXUyy+POmn3WhthYmDABXC67IxIRERERiQwffgh//QV16sDtt9sdjZSTirZhzOVyUbt2bVz6Q1UcQPkuTqOcL0F6OgwbZm7fdx+0amVvPGIJ5boUR+dDcRrlvDiJ8t0G2dnwxBPm9r33QqVK9sbjMFbmuitQUn+CKKRddEVERMLc0KHwyitwwgnwyy+QkGB3RGIBzcFCT79zERERB3rzTbjlFqhXD/78E5KS7I7IUaycf2mlbRjz+XysXr1aDbvFEZTv4jTK+WL8978wfry5/frrKthGEeW6FEfnQ3Ea5bw4ifI9xLKy4Mknze377lPB1gZW5rqKtmHM7/ezZs0a/H6/3aGIBJ3yXZxGOV8Er9esCggE4NproVs3uyMSCynXpTg6H4rTKOfFSZTvIfbuu7BpEzRoAP/6l93ROJKVua6irYiIiISH8eNhxQqoXh1efNHuaEREREREIsehQ/mrbB94ABIT7Y1HKkxFWxEREbHf5s3w0EPm9jPPmJ1uRURERESkdN5+G7ZuhUaN4Oab7Y5GLKCibRhzu900adIEt1tvk0Q/5bs4jXL+CEOHwoEDcO65cNNNdkcjQaBcl+LofChOo5wXJ1G+h0hmJjz1lLn94IMQH29vPA5mZa67AoFAwLJniwDaRVdERCTMfPUV9OoFMTGmPcKpp9odkQSB5mChp9+5iIiIQ4wdCyNGQJMmsHYtxMXZHZFjWTn/iinPD61fv54FCxawceNGDh48SO3atWnTpg3t27cnQbs8W8bn87Fq1Spat26Nx+OxOxyRoFK+i9Mo5w/LyIAhQ8ztu+5SwTaKhXLXaM1VI4vOh+I0ynlxEuV7CBw8CE8/bW4/9JAKtjazcs5bpqLtxIkTGTduHD/99BN169alQYMGJCYmsmfPHv78808SEhK49tpruffee2natKllQTqV3+9n06ZNnHrqqTq5SdRTvovTKOcPe+QR08/2uOPg4YftjkaCKBS7RmuuGpl0PhSnUc6LkyjfQ+D112HnTjOfHjjQ7mgcz8o5b6mLtm3atCEuLo6BAwfyn//8h8aNGxf6flZWFosXL+aTTz7hzDPP5LXXXuPKK6+0LFARERGJMitWwLhx5varr0JSkr3xSETTXFVEREQc58ABs4kvmFW2sbH2xiOWKnXR9umnn6Z79+7Ffj8+Pp4uXbrQpUsXnnzySTZs2GBFfCIiIhKNfD645Rbz71VXwcUX2x2RRDjNVUVERMRxXn0Vdu2C5s3h+uvtjkYsVuqibUmT4CPVrFmTmjVrlisgyed2u2nZsqV2WRRHUL6L0zg+5ydMgP/+F6pWNRsnSNQLdq5rrhq5HH8+FMdRzouTKN+DaP9+ePZZc3vUKK2yDRNW5nq5nmn58uX88ssvefe/+uorevXqxQMPPEB2drZlwTmdx+OhVatW6vsijqB8F6dxdM5v2wYPPGBuP/UU1K9vbzwSEqHMdc1VI4ujz4fiSMp5cRLlexCNHw+7d0OLFnDNNXZHI4dZmevlKtrecsst/PHHHwD89ddf9OvXj6SkJD7//HPuuecey4JzOq/Xy6JFi/B6vXaHIhJ0yndxGkfn/IgRkJ4OZ58Nt95qdzQSIqHMdc1VI4ujz4fiSMp5cRLle5Ckp8Nzz5nbjzwCMaW+kF6CzMpcL1fR9o8//iAlJQWAzz//nE6dOjFp0iTef/99/vOf/1gWnNMFAgF27dpFIBCwOxSRoFO+i9M4Nue//RY++ww8HnjjDfOvOEIoc11z1cji2POhOJZyXpxE+R4k48bB3r3QqhX062d3NFKAlblerqJtIBDA7/cD8P3339OjRw8AGjduTGpqqmXBiYiISBQ5eBBuv93cHjYMDhfVRKymuaqIiIhErX374MUXze1HHtEiiChWrqLtmWeeyRNPPMFHH33EvHnz6NmzJwDr16+nbt26lgYoIiIiUeLxx2HDBmjcGEaPtjsaiWKaq4qIiEjUGjvWFG5POQWuvNLuaCSIylW0HTt2LMuXL2fw4ME8+OCDnHDCCQBMnjyZDh06WBqgk3k8HlJSUtSwWxxB+S5O47ic//VXeP55c3v8eKhc2d54JORCmeuaq0YWx50PxfGU8+IkyneL7d0LL71kbmuVbViyMtddAQubLRw6dAiPx0NsbKxVT2m59PR0kpOTSUtLo2rVqnaHIyIiEv38fujYERYtgl694Isv7I5IbBAOc7BImKtaKRx+5yIiImKhhx+GJ56A006DlSvBXa61mBJEVs6/LH13ExISHDMJDgWv18vs2bO1y6I4gvJdnMZROf/OO6ZgW7kyvPyy3dGITcIh1zVXDU+OOh+KoJwXZ1G+W2j3btMaAUyrMRVsw5KVuR5T2gdWr14dl8tVqsfu2bOn3AFJvkAgwP79+7XLojiC8l2cxjE5v3Mn3HOPuf3446afrThSsHNdc9XI5ZjzochhynlxEuW7hV54ATIyzGa+vXrZHY0Uw8pcL3XRdmxuNR/YvXs3TzzxBN27d6d9+/YALF68mO+++46HH37YsuBEREQkwt15p9kooU0bGDzY7mjEJj4fLFxYuoJqeWmuKiIiIlFr1678K9ZGj4ZSflAtka3URdsBAwbk3b7iiit47LHHGFzgj6+hQ4cyfvx4vv/+e0aMGGFtlCIiIhJ5vv8eJk40l269+SbElHraIVFkyhQYNgy2bAnu+6+5qoiIiESt556DAwegbVu49FK7o5EQKddGZJUrV2blypV5O/HmWrduHSkpKWRkZFgWoNUiaUMGv99PamoqtWrVwq1eJRLllO/iNFGf84cOmQ0S1q2DIUPUy9ahpkyBPn3AzDbTgdDMwSJ5rmqlSJn3Rv35UOQIynlxEuW7BXbuhOOPh4MH4ZtvoGdPuyOSEuzbt4/q1avbtxFZzZo1+eqrr44a/+qrr6hZs2aFApJ8brebOnXq6MQmjqB8F6eJ+px/6ilTsG3QwOxwK47j85kVtna0sNNcNbJE/flQ5AjKeXES5bsFnn3WFGzPPht69LA7GjkGK3O9XNepjR49mptvvpm5c+fSrl07AJYsWcL06dN56623LAvO6XJycpgxYwYXXnihdjqWqKd8F6eJ6pxfvRqeftrcHjcOwniFnwTPggWwZYs9r625amSJ6vOhSBGU8+IkyvcK2r4dXnvN3FYv24iQk5Nj2XOVq2g7cOBATjrpJF5++WWmTJkCwEknncTChQvzJsZiDa/Xa3cIIiGjfBenicqcDwTg1lshJ8dcunXFFXZHJDbZvt2+19ZcNfJE5flQpATKeXES5XsFPPOMaTvWvj107253NBJi5d4Rol27dkycONHKWERERCTSffABzJsHiYkwfrxWAziY3V0INFcVERGRiLZ1K0yYYG4/9pjm1Q5U7kYLfr+fP/74g4ULFzJ//vxCXyIiIuJAqalw113m9qOPQrNmdkYjNlq7Fu69194YrJyrvvrqqzRr1oyEhATatWvH0qVLS3z8559/TqtWrUhISOC0005j2rRphb4/ZcoULrzwQmrWrInL5WLlypVHPUeXLl1wuVyFvm699dYyxy4iIiIR6umnISsLzjsPLrjA7mjEBuVaafvjjz9yzTXXsHHjRgJH7C7hcrnw+XyWBOd0MTExdO3alZiYci+IFokYyndxmqjM+Xvugd274bTTYMQIu6MRm3z8Mdx2G2RkQOXK5l+XK7Qbklk5V/30008ZOXIkEyZMoF27dowdO5bu3buzZs0a6tSpc9TjFy1axNVXX82YMWO45JJLmDRpEr169WL58uWceuqpABw4cIDzzjuPq666ikGDBhX72oMGDeKxxx7Lu5+UlFTquCNJVJ4PRUqgnBcnUb6X0+bN8Oab5rZW2UYUK3PdFThyJlsKKSkpnHjiiYwePZr69evjOiJ5kpOTLQvQaunp6SQnJ5OWlkbVMN8YJRAI4PV6iYmJOep3LBJtlO/iNFGX8/PmQZcuZkL5ww+m75Y4SkYG3HEHfPihud+5M0ycCEuWwLBhsGVLOhCaOZiVc9V27dpx1llnMX78eMCs4G3cuDFDhgzhvvvuO+rxffv25cCBA3zzzTd5Y+eccw4pKSlMyL3E8bANGzZw3HHHsWLFClJSUgp9r0uXLqSkpDB27NhSx3qkSJn3Rt35UOQYlPPiJMr3crrtNtMaoUsXmDPH7mikDNLS0qhWrZol869ylX/Xrl3L5MmTOeGEEyr04lIyr9fLtGnT6NGjh3ZZlKinfBeniaqcz8oym48B/OtfKtg60IoV0LevaYvgdsMjj8CDD4LHA717w+WXwzffeOnVKzTxWDVXzc7OZtmyZdx///15Y263m27durF48eIif2bx4sWMHDmy0Fj37t358ssvy/z6EydO5OOPP6ZevXpceumlPPzwwyWuts3KyiIrKyvvfnp6OmB2Mc7dydjtduPxePD5fPj9/kLH5fF48Hq9hVYnezwe3G53seNH7pCcu7rkyE1nihuPjY0lOzub6dOn849//IPY2FhcLhcxMTH4/f5Cq6Jzx4uLPZyOqbjYdUw6JrfbzaFDh5gxY0ZezkfDMUXj+6RjsuaYijvHR/IxBf19+vNPPO+8gwvwPvwwbr8/8o8pGt+nYsYzMzOxSrmKtu3atWPdunUq2oqIiAg89xysXg1168KYMXZHIyEUCMArr8Ddd0N2NjRqZFbXdupU+HEeD5x3Xuj6I1g1V01NTcXn81G3bt1C43Xr1mX16tVF/syOHTuKfPyOHTvK9NrXXHMNTZs2pUGDBqxatYp7772XNWvWMGXKlGJ/ZsyYMYwePfqo8RkzZuQVe5s0aUKbNm1YtWoVmzZtyntMy5YtadWqFUuXLmXXrl154ykpKTRt2pT58+ezf//+vPH27dtTp04dZsyYUegPqK5du5KYmHhUH98ePXqQmZnJnAKrhWJiYujZsye7d+8GYObMmQBUqVKF888/n82bNxfq91u7dm06dOjA2rVrWbNmTd54OB5TampqocK+jknHVPCYFi1aBOTnfDQcUzS+Tzoma45p69atQH6+R8MxBft92jlsGA1zctjVujWL9u+nfWpqxB9TNL5PxR3TggULsEq52iN88cUXPPTQQ9x9992cdtppR60Qat26tWUBWi1SLhMDsyoialZhiRyD8l2cJmpyfu1a08M2KwsmTYKrr7Y7IgmR3bvhhhvg//7P3L/sMnj3XahZs7jH76ZWrVohmYNZNVfdtm0bDRs2ZNGiRbQvsIL8nnvuYd68eSxZsuSon4mLi+ODDz7g6gL/Lbz22muMHj2anTt3FnpsSe0RjjR79mwuuOAC1q1bR/PmzYt8TFErbRs3bkxqamre7zwcV8hkZWVppa2OyVHHlJmZqZW2OibHHFNx5/hIPqagvk8bNxI48URcXi/euXMJdOgQ+ccUje9TCce0Y8cO6tevb197hCuuuAKAG2+8MW/M5XIRCAS0EZmIiIhTBAJw++2mYPuPf0C/fnZHJCEyfz5ccw1s3QpxcfD88zB4cPjskWHVXLVWrVp4PJ6jiq07d+6kXr16Rf5MvXr1yvT40mrXrh1AiUXb+Ph44uPjjxqPjY09qnDt8XjweDxHPba4zTOKGy/uQ6eyjLvd7iLjdLvded8rqLjYw+2Yiopdx6RjKjh+ZM5HwzGVJsayjuuYIv+Ycn+m4M9F+jEF7X164glcXi9ceCExnTuXO/bixp2We+F0TOVRrqLt+vXrLQtAihcTE0OPHj2KTQSRaKJ8F6eJipz/97/h++8hIQFefz18KnYSND4fPPEEPPYY+P1w4onwySfQps2xfzaUuW7VXDUuLo62bdsya9Yseh1uyOv3+5k1axaDBw8u8mfat2/PrFmzGD58eN7YzJkzC63ULY/cy/vq169foecJR1FxPhQpA+W8OInyvQzWrYMPPjC3i2h3JJHBylwv1zM1bdrUsgCkZJmZmVSpUsXuMERCQvkuThPROb93L4wYYW4/9BAUs/JPoseWLXDddTBvnrk/YACMHw+VK9sbV1GsnKuOHDmSAQMGcOaZZ3L22WczduxYDhw4wA033ABA//79adiwIWMO93MeNmwYnTt35oUXXqBnz5588skn/PTTT7z55pt5z7lnzx42bdrEtm3bAPL6qtWrV4969erx559/MmnSJHr06EHNmjVZtWoVI0aMoFOnTmHdhqwiIvp8KFIOynlxEuV7KT3xhPmE/OKL4Zxz7I5GwsDR64RL6c8//2TIkCF069aNbt26MXToUP78808rY3M8r9fLnDlzjurZIRKNlO/iNBGf8/fdB3//DSedZHahkqj2zTeQkmIKtpUqwYcfwvvvl61gG+pct2qu2rdvX55//nlGjRpFSkoKK1euZPr06XmbjW3atInt27fnPb5Dhw5MmjSJN998k9NPP53Jkyfz5Zdfcuqpp+Y95uuvv6ZNmzb07NkTgH79+tGmTRsmTJgAmBW+33//PRdeeCGtWrXizjvv5IorruD/chsIR5mIPx+KlJFyXpxE+V5Kf/wBH31kbmuVbUSzMtfLtdL2u+++47LLLiMlJYVzzz0XgB9++IFTTjmF//u//+Mf//iHZQGKiIhImFm0CHJXDU6YYJqaSlTKyoJ774Vx48z9Nm1MO4QTT7Q3rmOxeq46ePDgYtshzJ0796ixK6+8kiuvvLLY5xs4cCADBw4s9vuNGzdmXu6SZhEREYl+ub2nLr0UzjrL7mgkTJSraHvfffcxYsQInn766aPG7733XhVtRUREolVODtxyi7l9443QqZO98UjQrF0LffvCihXm/vDh8PTTUMQ+V2FHc1URERGJGL//bvaKAHj0UVtDkfBSrvYIv//+OzfddNNR4zfeeCP/+9//KhyU5FOzbnES5bs4TUTm/Isvwq+/Qq1a8OyzdkcjQfLxx3DGGaZgW7Mm/N//wUsvRUbBFjRXjUQReT4UqQDlvDiJ8v0YclfZ9uplJmAih5WraFu7du28HWwLWrlyJXXq1KloTHJYbGwsPXv2JDY21u5QRIJO+S5OE5E5v359fo+t55831TyJKhkZZoOx6683tzt3hp9/hksuqfhzhzLXNVeNLBF5PhSpAOW8OIny/Rh++w0+/dTc1irbqGBlrpfr445Bgwbxr3/9i7/++osOHToApk/YM888w8iRIy0Lzun8fj+pqanUqlULt7vce8aJRATluzhNxOV8IACDB0NmJnTpAv372x2RWGzFCujXz+yD4XbDI4/Agw+Cx2PN8/v9fmueqBQ0V40sEXc+FKkg5bw4ifL9GEaPNvPsK66A00+3OxqxgJVz3nL9F/Pwww8zatQoXnnlFTp37kznzp0ZP348jz76KA899JBlwTmdz+dj8eLF+Hw+u0MRCTrluzhNxOX85MkwbZrZdGzCBHC57I5ILBIIwCuvwDnnmIJto0YwZw6MGmVdwRYIaa5rrhpZIu58KFJBynlxEuV7CVatgs8/N/NqrbKNGlbmerlW2rpcLkaMGMGIESPYv38/AFWqVLEsKBEREQkjaWkwbJi5fd990LKlvfGIZXbvNvvJff21uX/ZZfDuu5Hf+UJzVREREQl7uW3HrrwSTj3V3lgkLJWraLt+/Xq8Xi8tWrQoNAFeu3YtsbGxNGvWzKr4RERExG4PPQTbt0OLFnD//XZHIxaZPx+uvRa2bDELqJ9/3nTAiIZF1JqrioiISFhbsQKmTDETr0cesTsaCVPlao8wcOBAFi1adNT4kiVLGDhwYEVjksNcLhdVqlTBFQ1/PYkcg/JdnCZicn7pUnj1VXN7wgRISLA3Hqkwn88s7Oja1RRsTzwRfvwRhgwJbsE2lLmuuWpkiZjzoYhFlPPiJMr3YuS2Q7j6ajj5ZFtDEWtZmeuuQCAQKOsPVa1aleXLl3PCCScUGl+3bh1nnnkm+/btsyo+y6Wnp5OcnExaWhpVq1a1OxwREZHw5fXCWWfBypVw3XXw0Ud2RyQVtGWLeSvnzTP3BwyA8eOhcuXgv3Yo52CRPFe1kua9IiIiYWjZMjjzTLPz6//+p9ZjUcbK+Ve5Vtq6XK68/mAFpaWlqbm0hfx+Pxs3bgzpbssidlG+i9NERM6/8oop2FavDi+8YHc0UkHffAMpKaZgW6kSfPghvP9+aAq2YO1OuseiuWpkiYjzoYiFlPPiJMr3IuS2Q7j2WhVso5CVuV6uom2nTp0YM2ZMoUmvz+djzJgxnHfeeZYF53Q+n4+VK1fqjwtxBOW7OE3Y5/ymTfDww+b2s89CnTr2xiPllpUFw4fDpZeajcfatIHly+H660MbRyhzXXPVyBL250MRiynnxUmU70dYsgSmTgWPJ3+uLVHFylwv10ZkzzzzDJ06daJly5Z07NgRgAULFpCens7s2bMtC05ERERsMnQoHDgA554LN95odzRSTmvXQt++Zq8LMMXbp5+G+Hhbwwo6zVVFREQkLOX2sr3+erPJr0gJyrXS9uSTT2bVqlVcddVV/P333+zfv5/+/fuzevVqTj31VKtjFBERkVD68kv46iuIiYE33jD9tiTifPwxnHGGKdjWrAlffw0vvRT9BVvQXFVERETC0KJFMH26VtlKqZVrpS1AgwYNeOqpp6yMRY7gcrmoXbu2dlkUR1C+i9OEbc7v3w9Dhpjbd98Np5xibzxSZhkZcMcdpmctQKdOMHEiNGpkb1yhznXNVSNH2J4PRYJEOS9OonwvILeX7cCBcPzxtoYiwWNlrpd76cyCBQu47rrr6NChA1u3bgXgo48+YuHChZYF53QxMTF06NCBmJhy19ZFIobyXZwmbHP+kUdgyxY47jh46CG7o5EyWrEC2rY1BVu321yBN3u2/QVbIOS5rrlq5Ajb86FIkCjnxUmU74ctWADff2+uZNMcO6pZmevlKtr+5z//oXv37iQmJrJ8+XKysrIAsyOvVjRYx+fzsXr1ajXsFkdQvovThGXOr1gB48aZ26+9BklJ9sYjpRYIwCuvwDnnwB9/QMOGMGeOqcF7PHZHZ4Qy1zVXjSxheT4UCSLlvDiJ8v2w3FW2N90EzZrZGooEl5W5Xq6i7RNPPMGECRN46623iI2NzRs/99xzWb58uWXBOZ3f72fNmjX4/X67QxEJOuW7OE3Y5bzPB7fcAn6/2bnqoovsjkhKafdu6NXL7B2XnQ2XXgo//2zaIoSTUOa65qqRJezOhyJBppwXJ1G+A3Pnmk/TY2PhgQfsjkaCzMpcL1fRds2aNXQq4i+B5ORk9u3bV9GYREREJNRefx3++1+oWtXsViURYf58SEkxm4zFxcHLL5s95GrWtDsye2muKiIiImEhEMhfZTtoEDRpYm88ElHKVbStV68e69atO2p84cKFHK9myiIiIpFl27b8T/3HjIH69e2NR47J54PHHoOuXU0L4hNPhB9/NHvIaZ8PzVVFREQkTMyebT5lj4uD+++3OxqJMOUq2g4aNIhhw4axZMkSXC4X27ZtY+LEidx1113cdtttVsfoWG63myZNmuB2l3u/OJGIoXwXpwmrnB8+HPbvh3btTIsECWtbt8IFF5hFG34/DBgAy5ZBmzZ2R1ayUOa65qqRJazOhyIhoJwXJ3F0vhdcZXvLLeGxM6wEnZW57goEAoGy/lAgEOCpp55izJgxHDx4EID4+HjuuusuHn/8ccuCC4b09HSSk5NJS0ujatWqdocjIiJir2nToGdPs1vVsmVw+ul2RyQl+OYbGDjQ9LGtVMl0tbj+erujKp1QzsEiea5qJc17RUREbDRjBnTvDgkJ8Oef0KCB3RFJCFg5/ypX+dflcvHggw+yZ88efv31V3788Ud27drlqElwKPh8PlasWKFdFsURlO/iNGGR8wcPwh13mNvDh6tgG8aysmDECLPJ2O7dZlXt8uWRU7AFa3fSPRbNVSNLWJwPRUJIOS9O4th8DwRg1Chz+9ZbVbB1ECtzvUJrduPi4jj55JNp1aoV33//Pb///rtVcQlmx7lNmzY5e5dFcQzluzhNWOT8Y4/Bhg1mQ4RHH7UvDinR2rXQoQOMHWvuDx8OixebPraRxI5c11w1MoTF+VAkhJTz4iSOzffp02HJEkhMhHvvtTsaCSErc71cRdurrrqK8ePHA5CZmclZZ53FVVddRevWrfnPf/5jWXAiIiISJL/8Ai+8YG6PHw+VK9sbjxTp44/hjDPMqtqaNeHrr+GllyA+3u7IwpvmqiIiImKbgqts77gD6tWzNx6JWOUq2s6fP5+OHTsC8MUXX+D3+9m3bx8vv/wyTzzxhKUBioiIiMX8frMZgtcL//ynueZewkpGhtlg7Prrze1OnWDlSr1VpaW5qoiIiNhm6lT46SdISoK777Y7Golg5SrapqWlUaNGDQCmT5/OFVdcQVJSEj179mTt2rWWBuhkbrebli1bOnOXRXEc5bs4ja05//bb5vr6ypXh5ZdD//pSohUroG1b+PBDcLtN54rZsyN/w+FQ5rrmqpFFcwBxGuW8OInj8r3gKtshQ6BOHXvjkZCzMtfL9UyNGzdm8eLFHDhwgOnTp3PhhRcCsHfvXhISEiwLzuk8Hg+tWrXC4/HYHYpI0CnfxWlsy/mdO/P7aj3xRORXAqNIIACvvALnnAN//AENG8KcOfDIIxANp8ZQ5rrmqpFFcwBxGuW8OInj8v2rr8wn8JUrw1132R2N2MDKXC9X0Xb48OFce+21NGrUiAYNGtClSxfAXIp22mmnWRac03m9XhYtWoTX67U7FJGgU76L09iW8yNHwr59plHq4MGhfW0p1u7d0KsXDB0K2dmmDcLPP5u2CNEilLmuuWpk0RxAnEY5L07iqHz3+/M39x06FGrVsjUcsYeVuR5Tnh+6/fbbadeuHZs2beIf//hH3tLf448/Xn3CLBQIBNi1axeBQMDuUESCTvkuTmNLzs+cCZMmmWvu33wzOpZvRoH58+Haa2HLFoiLg+eeM1fTuVx2R2atUOa65qqRRXMAcRrlvDiJo/L9iy/Mp+5VqsCdd9odjdjEylwvV9EWoG3btrRt27bQWM+ePSsckIiIiARBZibcfru5PXiwaZoqtvL54MknYfRoszCjRQv49FNo08buyKKD5qoiIiISMgVX2Q4fDod764tURKnbIzz99NNkZmaW6rFLlixh6tSp5Q5KRERELPbUU7BuHTRoAI8/bnc0jrd1K1xwgelX6/dD//6wbJkKthWhuaqIiIjYZvJk+PVXSE6GESPsjkaiRKmLtv/73/9o0qQJt99+O99++y27du3K+57X62XVqlW89tprdOjQgb59+1KlSpWgBOwkHo+HlJQU5zTsFkdTvovThDTnf/8dnnnG3H7lFahaNfivKcX65hs4/XSYNw8qVYIPP4QPPjBX0kWzYOe65qqRS3MAcRrlvDiJI/Ld58tfZTtyJFSvbms4Yi8rc90VKEOzhZ9//pnx48czefJk0tPT8Xg8xMfHc/DgQQDatGnDzTffzMCBA8N2Z9709HSSk5NJS0ujqv5oFRGRaBcIQJcupnHqJZfA119HX7PUCJGVBffdB2PHmvtt2sAnn8CJJ9oaVsiEYg4WDXNVK2neKyIiEgKTJpkNCqpVgw0bzGpbcSwr519lKtrm8vv9rFq1io0bN5KZmUmtWrVISUmhVgTsjBdJk1ev18v8+fPp1KkTMTHlbj8sEhGU7+I0Icv5996DG2+EpCT43/+gadPgvZYUa+1a6NcPli8394cPh6efhvh4W8MKqT179lCzZs2QzMEiea5qpUiZ92oOIE6jnBcnifp893rhlFPgjz/giSfgwQftjkhsZuWct1z/xbjdblJSUkhJSanQi0vJAoEA+/fvd8Yui+J4yndxmpDkfGoq3H23uf3ooyrY2uTjj+G22yAjA2rWNHX0Sy+1O6rQC+X5XXPVyKI5gDiNcl6cJOrz/d//NgXbGjVg6FC7o5EwYGWuR+HHHCIiIgKYgu3u3dC6tVnaKSGVkQGDB5t+tQCdOsHEidCokb1xiYiIiIgFvF547DFz++67o3+DAgm5Um9EJiIiIhFk7lx4/33Tv/aNNyA21u6IHGXlSmjb1hRs3W6z0Hn2bBVsRURERKLGxx/DunVQq5b5pF7EYlppG8Y8Hg/t27eP7l0WRQ5TvovTBDXns7Lg1lvN7VtugXPOsf41pEiBAIwfD3fdBdnZ0LCh2ZuiUye7I7Ofzu9SHM0BxGmU8+IkUZvvOTn5q2zvuQcqV7Y3HgkbVua6irZhzO12U6dOHbvDEAkJ5bs4TVBz/tlnYc0aqFsXxowJzmvIUXbvNnu+ff21uX/ppaZ/bc2a9sYVLtxuXeAlRdMcQJxGOS9OErX5/uGHsH491KkDt99udzQSRqyc81bomdatW8d3331HZmYmENoNJpwgJyeHqVOnkpOTY3coIkGnfBenCVrOr10LTz5pbo8dC9WqWfv8UqT58yElxRRs4+Jg3Dj46isVbAuy4/yuuWpk0BxAnEY5L04SlfmenQ2PP25u33cfVKpkbzwSVqzM9XIVbXfv3k23bt048cQT6dGjB9u3bwfgpptu4s4777QsOAGv12t3CCIho3wXp7E85wMBuO020x7hwguhb19rn1+O4vOZK+O6doUtW6BFC/jxR7N5sMtld3TOpblq5NEcQJxGOS9OEnX5/t57sHEj1KuX35JMJAjKVbQdMWIEMTExbNq0iaSkpLzxvn37Mn36dMuCExERkTKYNAlmzYKEBHjtNVUNg2zrVrjgAnjkEfD7oX9/WLYM2rSxOzLRXFVERESCIisr/6q2+++HxER745GoVq6etjNmzOC7776j0RFbILdo0YKNGzdaEpiIiIiUwZ49MHKkuf3ww9C8ub3xRLlvvoGBA00f20qVTI28f3+7o5JcmquKiIhIULzzDmzeDA0awL/+ZXc0EuXKtdL2wIEDhVYt5NqzZw/x8fEVDkqMmJgYunbtSkyM9ouT6Kd8F6exPOfvuw/+/htOPhnuusua55SjZGXBiBFmk7Hdu82q2uXLVbAtjVCe3zVXjSyaA4jTKOfFSaIq3w8dgqeeMrcfeMBc3SZyBCtzvVxF244dO/Lhhx/m3Xe5XPj9fp599lm6du1qWXACiVpqLw6ifBensSznf/gB3nrL3J4wweyEJZZbuxY6dDD7uwEMGwaLF8OJJ9oalhRBc9XIozmAOI1yXpwkavL9rbdMf6xGjeDmm+2ORhygXEXbZ599ljfffJOLL76Y7Oxs7rnnHk499VTmz5/PM888Y3WMjuX1epk2bVr0Ne0WKYLyXZzGspzPycnfAOGmm6Bjx4oHJ0f5+GM44wyzqrZGDfjqK1O81aLN0gvl+V1z1ciiOYA4jXJenCRq8j0zM3+V7YMPahIoxbIy18tVtD311FP5448/OO+887j88ss5cOAAvXv3ZsWKFTRXDz0REZHQeeEF+PVXqFULVIyyXEaG6V17/fXmdqdO8PPPcNlldkcmJdFcVURERCz1xhuwYwc0aQI33mh3NOIQ5W60kJyczIMPPmhlLCIiIlIW69fDY4+Z2y++CDVr2htPlFm5Evr2hT/+ALcbRo2Chx4Cj8fuyKQ0NFcVERERSxw4AGPGmNsPP6xWZBIy5S7aHjp0iFWrVvH333/j9/sLfe8yLT8REREJrkAAbr/dXKp1/vlw3XV2RxQ1AgF49VW4807IzoaGDWHSJLPKViKH5qoiIiJiiddfNxv+HnccDBhgdzTiIK5AIBAo6w9Nnz6d/v37k5qaevQTulz4fD5LgguG9PR0kpOTSUtLo2rVqnaHU6JAIIDX6yUmJgaXy2V3OCJBpXwXp6lwzn/2mVkGGhcHq1ZBy5bWB+lAu3eb1sBffWXuX3opvPeeFjFbIS0tjWrVqoVkDhbJc1UrRcq8V3MAcRrlvDhJxOd7RgYcfzzs2gXvvgs33GB3RBLmrJzzlqun7ZAhQ7jyyivZvn07fr+/0JdTJsGhkpmZaXcIIiGjfBenKXfOp6XBsGHm9v33q2BrkQULICXFFGzj4mDcOHNbBdvIo7lq5NEcQJxGOS9OEtH5/uqrpmDbvLnZ5EAkhMpVtN25cycjR46kbt26VscjBXi9XubMmRP5uyyKlILyXZymQjn/4INmI4QTT4T77rM+OIfx+eDxx6FLF9iyBVq0gB9/hKFDIRIXhISrUJ7fNVeNLJoDiNMo58VJIjrf9++H554zt0eNgphydxgVB7Ey18tVtO3Tpw9z5861LAgREREppaVL4bXXzO0JEyAhwd54ItzWrdCtm5mH+/3Qvz8sWwZt2tgdmVSE5qoiIiJSYa+8YnpnnXgiXHON3dGIA5XrY4Lx48dz5ZVXsmDBAk477TRiY2MLfX/o0KGWBCciIiIFeL1wyy1mp6zrr4euXe2OKKJ98w0MHGjm4pUqmVp4//52RyVW0FxVREREKiQtDZ5/3tzWKluxSbmy7t///jczZswgISGBuXPnFmom7XK5Sj0Rnj9/Ps899xzLli1j+/btfPHFF/Tq1avEn5k7dy4jR47kt99+o3Hjxjz00EMMHDiwPIcREWJ0YhAHUb6L05Q5519+GVauhBo14IUXghKTE2Rlma4SY8ea+23awCefmEUUEh2smqtK6GgOIE6jnBcnich8f/ll2LsXWrWCfv3sjkYcyhUIBAJl/aF69eoxdOhQ7rvvPtzucnVYAODbb7/lhx9+oG3btvTu3fuYRdv169dz6qmncuutt3LzzTcza9Yshg8fztSpU+nevXupXjNSdtEVEREpZNMmOOkkOHgQ3n4bbrrJ7ogi0tq1Zt69fLm5P2wYPPMMxMfbG5cThHIOZtVcNdJp3isiIlIO+/ZBs2Zmte0nn0DfvnZHJBHEyvlXuT7uyM7Opm/fvhWeBF988cVcfPHFpX78hAkTOO6443jh8Oqik046iYULF/LSSy+VumgbSfx+P6mpqdSqVcvRf3CIMyjfxWnKnPNDhpiC7XnnwQ03BD/AKPTxx3DbbZCRYRYrv/ceXHaZ3VE5h9/vD9lrWTVXldDQHECcRjkvThKR+f7SS6Zge8opcOWVdkcjEcbKOW+5irYDBgzg008/5YEHHrAskNJYvHgx3bp1KzTWvXt3hg8fXuzPZGVlkZWVlXc/PT0dgJycHHJycgBwu914PB58Pl+hX27uuNfrpeCCZI/Hg9vtLnY893lz5V4KcOQOcsWNx8bG4vf7OXToEIsXL+Yf//gHcXFxxMTE4Pf78fl8eY91uVzExMQUG3u4HVNRseuYdExut5usrKy8fI+NjY2KY4rG90nHZN0xFXWOL/aYpkzB8/XXBGJj8Y4fjzsQwHP4NcPpmML1fcrIgBEjYvjgA3OJfMeOfj74wEejRgCReUwljYfr+3To0CFCxa65qpSPz+dj8eLF9OjRI3L+oBepAOW8OEnE5fuePfk9tB59FCIhZgkrBefaFVWuoq3P5+PZZ5/lu+++o3Xr1kdt7vDiiy9aEtyRduzYQd26dQuN1a1bl/T0dDIzM0lMTDzqZ8aMGcPo0aOPGp8xYwZJSUkANGnShDZt2rBq1So2bdqU95iWLVvSqlUrli5dyq5du/LGU1JSaNq0KfPnz2f//v154+3bt6dOnTrMmDGj0B9QXbt2JTExkWnTphWKoUePHmRmZjJnzpy8sZiYGHr27ElqaiqLFy8GYObMmVSpUoXzzz+fzZs3s3LlyrzH165dmw4dOrB27VrWrFmTNx7uxwTomHRMhY5p0aJFgMn3aDmmaHyfdEzWH9PMmTNLPqYTTsB3++14gLWXX87vGzbQMj4+rI8pnN6nv/6qyvPPn8W2bbG43QGuumoNV165hlWr4H//i8xjitT3KfeD81Cwa64qIiIiEe7FFyE9HVq3ht697Y5GHK5cPW27lrBbtcvlYvbs2WUPxOU6Zk/bE088kRtuuIH7778/b2zatGn07NmTgwcPFlm0LWqlbePGjUlNTc3rLRGuK2QOHTrEzJkztdJWx+SIY8rMzGTGjBlaaatjcswxFXWOLzL2u+6CsWMJHH883hUrIDExbI8pnN6nnBwvr7/u5p573GRnu2jYMMDHH8O550buMUX6+7R7927q168fkv6qwZirRqJI6Wmbk5PDtGnT6NGjx1EFdpFopJwXJ4mofE9NheOOM5dpTZkC//yn3RFJBNq9eze1atWyr6dtwRUdoVSvXj127txZaGznzp1UrVq1yIItQHx8PPFF7C4SGxt71AnD4/Hg8XiOemxxOx0WN17ciags4263m7i4OKpUqZL3x3zueFGXFBQXe7gdU1Gx65h0TLmxHJnvZY29uHG9TzomCL9jKuocf1Tsy5ebnWsB12uvEXvE//TD7ZjC5X3avRtuuimWr74y9y+5BN57z0WtWgCReUxlHQ/HY4qLiyvye8Fg11xVysflclGlShVcLpfdoYiEhHJenCSi8v2FF0zBtk0bKGFBoUhJrMz1chVt7dK+ffujLg2cOXMm7du3tymi4IqJieH888+3OwyRkFC+i9McM+d9PrjlFvD7oV8/iMINN4NhwQK45hrYsgXi4uC558webpHwd0K0K66gK6I5gDiNcl6cJGLyfdcueOUVc/vRRzV5lHKzcs5b6mfq3bs377//PlWrVqX3Mfp6TJkypVTPmZGRwbp16/Lur1+/npUrV1KjRg2aNGnC/fffz9atW/nwww8BuPXWWxk/fjz33HMPN954I7Nnz+azzz5j6tSppT2MiOL3+9m8eTONGzeOjIbdIhWgfBenOWbOv/Ya/PQTJCebHWylRD4fPPWUmWP7/dCiBXz6qVkoIeHByp10ixKMuaqEhuYA4jTKeXGSiMn3556DAwegbVu49FK7o5EIZuWct9RF2+Tk5LwlvsnJyZa8+E8//VSo59jIkSMBs+Pv+++/z/bt2wttfnHccccxdepURowYwbhx42jUqBFvv/023aN09ZHP52PlypU0aNAgvE9uIhZQvovTlJjzW7fCgw+a208/DfXqhT7ACLJ1K1x3Hcyda+737w/jx0OVKraGJUewcifdogRjriqhoTmAOI1yXpwkIvJ9504zeQR47DGtspUKsXLOW+qi7Xvvvcdjjz3GXXfdxXvvvWfJi3fp0oWS9kF7//33i/yZFStWWPL6IiIiYWn4cNi/H845B/71L7ujCWtTp8KAAaaPbaVKZoFy//52RyV2CMZcVURERBzgmWcgMxPatYOLL7Y7GpE8ZfqYY/To0WRkZAQrFhEREZk6FSZPBo8H3ngDwnVFgs2ysmDkSLPJ2O7dpg3C8uUq2Dqd5qoiIiJSJtu3w+uvm9ujR2uVrYSVMnXHLWlVrFjP5XJRu3btyNhlUaSClO/iNEXm/IEDcMcd5vaIEdC6tT3Bhbl168zebMuWmfvDhpkFEvHx9sYlJQvF+V1z1cikOYA4jXJenCTs8/3pp+HQIejQAS680O5oJApYmeuuQBlmt263m507d1K7dm3LAgi19PR0kpOTSUtLo2rVqnaHIyIiku/ee+HZZ6FJE/jf/8z1/lLIxIlw662QkQE1asB778Fll9kdlZRGKOZg0TBXtZLmvSIiIiXYuhWaNzeXcM2cCd262R2RRAEr519lWmkLcOKJJx6zarxnz55yByT5fD4fa9eupUWLFng8HrvDEQkq5bs4zVE5v2oVvPCC+earr6pge4SMDBg8GD74wNzv1MkUcBs1sjcuKb1gb0SWS3PVyKM5gDiNcl6cJKzzfcwYU7Dt2BEuuMDuaCRK2LIRWa7Ro0drR94Q8fv9rFmzhubNm4ffyU3EYsp3cZpCOe9ywS23gM8HvXubRq2SZ+VK6NsX/vjDtPgdNQoeesi0/ZXI4ff7Q/I6mqtGHs0BxGmU8+IkYZvvmzfDW2+Z2+plKxaycs5b5qJtv379qFOnjmUBiIiION5bb8GPP0KVKvDyy3ZHEzYCAbPo+M47ITsbGjY0q2s7d7Y7MglnmquKiIjIMT31lJlgdukCXbvaHY1IkcpUtA3bxtEiIiKRxOfDNW8eDefPx5WVZXrZAjzxhKlMCrt3w003wVdfmfuXXGL619aqZW9cEt40VxUREZFj2rAB3nnH3B492tZQREpSpqKtduQNLbfbTZMmTXC73XaHIhJ0yndxjClTYNgwYrZs4cyC48cfD3fcYVdUYWXBArjmGtiyBeLi4LnnYMgQXbUW6UJxftdcNTJpDiBOo5wXJwnLfH/yScjJMX1sO3WyOxqJMlbmuivgsNmtdtEVERHbTJkCffqY6/6L8p//mJ62DuXzmSvVHn0U/H5o0QI++QTOOMPuyMQKmoOFnn7nIiIiR/jrL2jZErxeWLgQzj3X7ogkylg5/wqjjzrkSD6fjxUrVoRst2UROynfJer5fDBsWPEFW5cLhg83j3OgrVuhWzezyZjfD9dfD8uWqWAbTXR+l+JoDiBOo5wXJwm7fH/iCVOw7d5dBVsJCitzXUXbMOb3+9m0aVPIdlsWsZPyXaLeggXmev/iBAJmF9sFC0IXU5iYOhVOPx3mzoVKleCDD+DDD82+bBI9dH6X4mgOIE6jnBcnCat8X7fOTDJBvWwlaKzMdRVtRUREQmH7dmsfFwWysmDkSLPJ2O7d0KYNLF8O/fvbHZmIiIiIRJ3HHzdXtfXoAe3a2R2NyDGVaSMyERERKYe0NPjuu9I9tn794MYSJtatg379TAsEMJ0jnnkG4uPtjUtEREREotCaNfDxx+a2VtlKhFDRNoy53W5atmwZXrssigSJ8l2iUkYGvPIKPPcc7N1b8mNdLmjUCDp2DE1sNpo4EW691fx6atSA996Dyy6zOyoJNp3fpTiaA4jTKOfFScIm3x9/3GyccOmlcOaZ9sYiUc3KXNf/JcKYx+OhVatWeDweu0MRCTrlu0SVgwfhhRfg+OPhgQdMwbZVK9MLwOUyXwXl3h87FqL4v4GMDLjhBrjuOnO7Uyf4+WcVbJ1C53cpjuYA4jTKeXGSsMj333+HSZPMba2ylSCzMtdVtA1jXq+XRYsW4fV67Q5FJOiU7xIVDh2Cl1+G5s3hrrtg1y444QT46CP49VdTyJ08GRo2LPxzjRqZ8d697Yk7BFauNIsa3n8f3G549FGYPdscujiDzu9SHM0BxGmU8+IkYZHvo0ebTX979TKbKIgEkZW5rqJtGAsEAuzatYtAIGB3KCJBp3yXiJadDRMmQIsWpjnrjh3QtCm88475ZP+66/JX0PbuDRs24J05k59GjsQ7cyasXx+1BdtAAMaPh3POMa3EGjY0xdpHHonqRcVShEg+v7/66qs0a9aMhIQE2rVrx9KlS0t8/Oeff06rVq1ISEjgtNNOY9q0aYW+P2XKFC688EJq1qyJy+Vi5cqVRz3HoUOHuOOOO6hZsyaVK1fmiiuuYOfOnVYeVtjQHECcRjkvTmJ7vv/6K3z2mbn96KP2xCCOYmWuq2grIiJSXjk5pjB74olw222wZYupSr7+OvzxB9x4I8QU0T7e4yHQuTNbO3Ui0Llz1FYv9+yBf/4ThgyBrCy45BKz4rZzZ7sjEym9Tz/9lJEjR/LII4+wfPlyTj/9dLp3787ff/9d5OMXLVrE1VdfzU033cSKFSvo1asXvXr14tdff817zIEDBzjvvPN45plnin3dESNG8H//9398/vnnzJs3j23bttE7Sj/cERERCZrcVbZ9+sDpp9sdjUiZqGgrIiJSVj6faXlw0klw882wcSPUq2daI6xbZ3bZiouzO0pbLVgAKSnw1VfmVzFuHHz9NdSqZXdkImXz4osvMmjQIG644QZOPvlkJkyYQFJSEu+++26Rjx83bhwXXXQRd999NyeddBKPP/44Z5xxBuPHj897zPXXX8+oUaPo1q1bkc+RlpbGO++8w4svvsj5559P27Ztee+991i0aBE//vhjUI5TREQk6vz8s2lB5nKZy7xEIoyKtmHM4/GQkpKiBvXiCMp3iQh+P3z6KZx6KvTvD3/+aaqQzz9vbg8ZAgkJpXqqaM15n89sztulC2zebDpGLF4MQ4cevf+aOEsk5np2djbLli0rVFx1u91069aNxYsXF/kzixcvPqoY271792IfX5Rly5aRk5NT6HlatWpFkyZNyvQ8kSJaz4cixVHOi5PYmu+5m45ddZWZv4uEgJW5XsQ1mxIu3G43TZs2tTsMkZBQvktYCwTgiy/MJ/S5lzhXrw733AODB0PlymV+ymjM+a1bTfveuXPN/euvh1dfhSpVbA1LwoTbHXlrBVJTU/H5fNStW7fQeN26dVm9enWRP7Njx44iH79jx45Sv+6OHTuIi4ujWrVqZXqerKwssrKy8u6np6cDkJOTQ05ODmDeB4/Hg8/nw+/35z02d9zr9RbqxebxeHC73cWO5z5vrpjDLWGO3ISjuPHY2FgAGjRogM/nw+fz4XK5iImJwe/34/P58h6bO15c7OF0TMXFrmPSMbndbvx+f6Gcj4Zjisb3ScdkzTFB0ef4oB/T0qXEfvEFAZcL7wMPEHP4MXqfdEzBPqaCr1lRKtqGMa/Xy/z58+nUqVNegopEK+W7hKVAAKZOhVGjYMUKM1a1Ktx5Jwwfbm6XU7Tl/NSpMGAA7N4NlSrBa6+ZxcgiubRLevCNGTOG0bmrigqYMWMGSUlJADRp0oQ2bdqwatUqNm3alPeYli1b0qpVK5YuXcquXbvyxlNSUmjatCnz589n//79eePt27enTp06zJgxo9B727VrVxITE4/afK1Hjx5kZmYyZ86cvLGYmBh69uzJzp07C23uVqVKFc4//3w2b95caJO22rVr06FDB9auXcuaNWvyxsPxmFJTUwutitYx6ZgKHtO8efPIyMiIqmOKxvdJx2TNMW3YsIFffvkl5Me0Z8gQ6gJbOnZk+fr1dG3WTO+TjikkxzR79mys4go4bMvK9PR0kpOTSUtLo2oF/tgOhZycHKZNm0aPHj3yViGIRCvlu4SVQABmzjTF2iVLzFjlyjBsmCnYVq9e4ZeIlpzPyoL774eXXjL3U1JMB4kTT7Q1LAlDu3fvplatWhExB8uVnZ1NUlISkydPplevXnnjAwYMYN++fXz11VdH/UyTJk0YOXIkw4cPzxt75JFH+PLLL/n5558LPXbDhg0cd9xxrFixgpSUlLzx2bNnc8EFF7B3795Cq22bNm3K8OHDGTFiRJHxFrXStnHjxqSmpub9zsNxhUxWVhbTp0/nH//4B7GxsbavkInGVT86pvA6pszMTGbMmJGX89FwTNH4PumYrDmm4s7xQT2m5cvhrLMIuN14f/4ZWrbU+6RjCtkx7dixg/r161sy5438pT0iIiJWmjsXHn4YFi409xMTTa/au+/WLlpHWLcO+vWDZcvM/aFD4dlnIT7e3rhErBIXF0fbtm2ZNWtWXtHW7/cza9YsBg8eXOTPtG/fnlmzZhUq2s6cOZP27duX+nXbtm1LbGwss2bN4oorrgBgzZo1bNq0qcTniY+PJ76I/wBjY2OP+nDI4/EU2XOtuJX/xY0X96FTWcZzW2ccGafb7S6yrUZxsYfbMRUVu45Jx1Rw/Micj4ZjKk2MZR3XMUX+MeX+TMGfC+oxHd50zHXttcQe0ctW75OOya5jKg8VbUVERAB++MGsrM29nCU+Hm67De67D47oTykwcSLceitkZECNGvDee3DZZXZHJWK9kSNHMmDAAM4880zOPvtsxo4dy4EDB7jhhhsA6N+/Pw0bNmTMmDEADBs2jM6dO/PCCy/Qs2dPPvnkE3766SfefPPNvOfcs2cPmzZtYtu2bQB5l+jVq1ePevXqkZyczE033cTIkSOpUaMGVatWZciQIbRv355zzjknxL8BERGRCLJkCUybBh6PWYghEsFUtA1jHo+H9u3bF1npF4k2ynexzdKlplj73XfmfmwsDBoEDzwADRsG7WUjNeczMszC4/ffN/c7doRJk6BRI1vDkggQabmeq2/fvuzatYtRo0axY8cOUlJSmD59et5mY5s2bSq0iqNDhw5MmjSJhx56iAceeIAWLVrw5ZdfcmqBlT5ff/11XtEXoF+/foBpo/Doo48C8NJLL+F2u7niiivIysqie/fuvPbaayE44tCL1POhSHkp58VJQp7vh1fZ0r8/tGgRmtcUKcDKXFdPWxERcaaVK02x9v/+z9z3eOCGG+Chh6BpU1tDC1crV5p2CGvWgNttFi889BBEwT5qEgKag4WefuciIuIoixbBueeayemaNXD88XZHJA5k5fzr6OYOEjZycnKYOnXqUY2ZRaKR8l1C5tdf4YoroE0bU7B1u2HAADOxe+utkBVsIynnAwEYPx7OOcf8mho2NF0kHn1UBVspvUjIdbFHJJ0PRaygnBcnCWm+566yHThQBVuxjZW5rj+1wtyRO+OJRDPluwTVmjWmyvjpp6YK6XKZZaOPPAItW9oSUiTk/J49cOON8NVX5v4ll5j+tdqTTUSsFAnnQxErKefFSUKS7/Pnw/ffm1ZnDz4Y/NcTCQGttBURkej2559mJe3JJ8Mnn5iC7RVXwKpVphmrTQXbSLBwIaSkmIJtXByMGwdff62CrYiIiIiEmdxVtjfeCM2a2RqKiFW00lZERKLTxo3w+ONmxyyfz4xddhmMHm0qkVIsnw+eesosTPb7zR4On3wCZ5xhd2QiIiIiIkeYMwfmzjWrDB54wO5oRCyjom0Yi4mJoWvXrsSoYaA4gPJdLLN1Kzz5JLz9NuT2E7roInjsMTjrLHtjKyBcc37bNrj2WjPvBbj+enj1VahSxdawJAqEW65L+AjX86FIsCjnxUmCnu+BQP4q25tvhiZNgvM6IqVkZa7r/xJhLjEx0e4QREJG+S4VsmMHPP00TJgAWVlm7IILTLG2Qwd7YytGuOX81Klm34bUVKhUCV57Dfr3tzsqEXGCcDsfigSbcl6cJKj5Pns2LFgA8fFw//3Bex0RG6inbRjzer1MmzZNTerFEZTvUm67dsE995gdYseNMwXbjh3NUtHvvw/bgm045XxWFowcaTYZS0013SOWL1fBVqwVDrku4SmczocioaCcFycJar4HAjBqlLl9yy3QqJH1ryFSRlbmulbaiohIZNqzB154AV5+GTIyzFi7dqaPbbdu4HLZG1+EWLcO+vWDZcvM/aFD4dlnzWIFEREREZGwNWMGLFoECQlw3312RyNiORVtRUQksqSlwdix8OKLkJ5uxs44w7RB6NFDxdoymDgRbr3V1Lxr1ID33jN7tYmIiIiIhLWCvWxvuw3q17c3HpEgUNFWREQiQ0YGvPIKPPcc7N1rxk47zRRrL79cxdoyyMiAIUPg/ffN/Y4dYdIkXVEmIiIiIhHi229hyRJITIR777U7GpGgcAUCgYDdQYRSeno6ycnJpKWlUbVqVbvDKVEgEMDr9RITE4NLxQiJcsp3KdbBg2ZHrGeeMQ1XAVq1gtGjoU8fcEdme3a7cn7lStMOYc0a86t7+GF46CHQBtYSbGlpaVSrVi0i5mDRIlLmvZoDiNMo58VJgpLvgQCcfTb89BPcdZdZ1CESJqyc80bmX7oOkpmZaXcIIiGjfJdCDh0y/WqbN4e77zYF2xNOgI8+gl9/hauuitiCba5Q5nwgAOPHwznnmIJtgwYwaxY8+qgKtiJiP80BxGmU8+Ikluf7N9+Ygm2lSmZDYpEoFdl/7UY5r9fLnDlztKuoOILyXfJkZ8Prr5sC7bBhsGMHNGsG77wDv/8O110HHo/dUVZYKHN+zx745z9NS4SsLLjkEvj5Z+jSJegvLZJH53cpjuYA4jTKeXESy/O9YC/bwYOhdm1rnlfEIlae27W2RkREwkNODnz4ITz+OGzcaMYaNTLX7t9wA8TF2RtfhFq4EK65BjZvhthYc/XY0KFqASwiIiIiEeirr2DFCqhc2bRGEIliKtqKiIi9fD6zC9bo0fDnn2asXj144AEYNAgSEuyNL0L5fPDUU6b9gd8PLVrAJ5/AGWfYHZmIiIiISDn4/fmrbIcNg1q17I1HJMhUtA1zMWo0KA6ifHcYvx8++8xUFdesMWO1a8N998Gtt0JSkq3hhUKwcn7bNrj2Wpg719y//np49VWoUiUoLyciUmGaA4jTKOfFSSzL9ylTYNUqqFoVRo605jlFwpgrEAgE7A4ilCJlF10RkagVCMAXX5hPyX/91YzVqGE2Gxs82FzqJOU2dSoMHGj2batUCV57Dfr3tzsqEc3B7KDfuYiIRA2/H1q3ht9+g1GjzFV6ImHIyvmXNiILY36/n7///hu/3293KCJBp3x3gEDA7PTati1ccYUp2CYnmwnX+vVmha2DCrZW53x2Ntx5p9lkLDUVUlJg+XIVbCV86PwuxdEcQJxGOS9OYlm+f/65KdgmJ8OIEdYEJxIEVp7bVbQNYz6fj8WLF+Pz+ewORSTolO9RLBCAGTOgfXu49NL8jQMeesgUa0eNMpc4OYyVOb9uHXToAC++aO4PHQo//ggnnljhpxaxjM7vUhzNAcRplPPiJJbku89nWqqBaYtQrZoVoYkEhZXndjXSERGR4JkzxxRlFy4095OSTAuEu+/WxgEWmTQJbrkFMjJMl4n33oPLLrM7KhERERERi3z6KaxeDdWrmw3IRBxCRVsREbHeDz/Aww+boi1AfDzcfjvcey/UrWtvbFHiwAEYMsQUaQE6djQF3EaN7I1LRERERMQyXm9+/9o77zTtEUQcQkXbMOZyuahSpQoul8vuUESCTvkeJZYuNStrv/vO3I+NhX/9Cx54ABo0sDe2MFORnF+5Evr1gzVrwO029fGHHgJtRC3hTOd3KY7mAOI0ynlxkgrn+6RJ8Mcf5pKyoUOtDU4kCKw8t7sCgUDAsmeLANpFV0QkCFasMMXab74x92Ni4IYbTCWxSRN7Y4sigQC8+ircdRdkZZk6+MSJ0KWL3ZGJHJvmYKGn37mIiEQ0rxdatYI//4SnnzZX7YmEOSvnX9qILIz5/X42btyoXUXFEZTvEerXX+GKK+CMM0zB1u2GAQPMEtA331TBtgRlzfk9e+Cf/zQtEbKy4JJL4OefVbCVyKHzuxRHcwBxGuW8OEmF8v2jj0zBtlYtuOMO64MTCQIrz+0q2oYxn8/HypUrtauoOILyPcKsXg1XXw2tW8OUKeBywTXXwP/+B++/D8cfb3eEYa8sOb9wIaSkwFdfmY4TY8fC119rLzeJLDq/S3E0BxCnUc6Lk5Q733Ny4PHHze1774XKla0PTiQIrDy3q/udiIiU3rp18Nhj5pr83E8Q+/SBRx+FU06xNbRo5PPBU0+ZX6/fDyecAJ98Am3b2h2ZiIiIiEgQffABrF9vNjG+/Xa7oxGxhYq2IiJybBs3mk+633/fVBIBLrvM7OSakmJnZFFr2za49lqYO9fcv+46eO01qFLF1rBERERERIIrO7vwKtukJHvjEbGJirZhzOVyUbt2be0qKo6gfA9TW7fCk0/C22+bS5QALr7YFGvPOsve2CJcSTk/dSoMHAipqVCpkinW9u8f+hhFrKTzuxRHcwBxGuW8OEm58v2992DTJqhXD269NXjBiQSBled2VyAQCFj2bBFAu+iKiJTCjh1mh9YJE8yuVwAXXGBaI3ToYG9sUSw7G+6/H1580dxPSTHtEFq2tDUsEUtoDhZ6+p2LiEjEycoyPcG2bIGXXza78IpEECvnX9qILIz5fD5Wr16tBvXiCMr3MLFrF9x9t9lIbNw4M2nq2NFco//99yrYWsTng1mzfLzwwlZmzfLh85l2wR065Bdshw6FxYtVsJXoofO7FEdzAHEa5bw4SZnz/e23TcG2YUMYNCi4wYkEgTYicwi/38+aNWto3rw5Ho/H7nBEgkr5brM9e+CFF0yh9sABM3bOOaaX1AUXgC7fs8yUKTBsGGzZ4gEaAlCjBhw8CIcOmdvvvWdaBotEE3/u5oUiR9AcQJxGOS9OUqZ8P3TI7MIL8MADkJAQ/ABFLGblnFdFWxERJ0tLg5deMl/p6WasbVvTBuHii1WstdiUKdCnDxzZmGjPHvPvSSfBjBnQqFHoYxMRERERsdWbb5rdeBs3hptusjsaEdupaCsi4kQZGaZH1PPPw969Zqx1a7PB2OWXq1gbBD6fWWFbUif5/fuhfv3QxSQiIiIiEhYyM2HMGHP7wQchPt7eeETCgHrahjG3202TJk1wu/U2SfRTvofIwYOmUHvccWYytHevWd752WewYgX06qWCbZAsWGDac5VkyxbzOJFopPO7FEdzAHEa5bw4SanzfcIEsxly06Zwww2hCU4kCKw8t2ulbRjzeDy0adPG7jBEQkL5HmSHDsEbb5hPr3fuNGMnnACPPgr9+oH6qQVVZiZ8/HHpHrt9e3BjEbGL+jZKcTQHEKdRzouTlCrfDxyAp582tx96COLigh+YSJBYOefVR3thzOfzsWLFCu0qKo6gfA+S7Gx4/XVToB0+3BRsmzWDd9+F33+Ha69VwTaIUlNNe+CmTeGdd0r3M2qPINFK53cpjuYA4jTKeXGSUuX766/D33+bqwEHDAhdcCJBYOW5XUXbMOb3+9m0aZN2WxZHUL5bLCcH3n4bWrSA22+HrVvN7lYTJsCaNeaSoxhdbBEs69bBHXdAkybwyCOwa5fZTyE5ufjuEy6XeUzHjqGNVSRUdH6X4mgOIE6jnBcnOWa+Z2TAM8+Y26NGQWxs6IITCQIrz+0q2oqIRBOfDz780PSpHTQINm0ySzdfecVUEm+5RZcbBdHixXDFFXDiifDaa6YtwhlnwKRJ8NdfZoEzHF24zb0/dqwWPouIiIiIg4wfby5PO+EEuO46u6MRCSsq2oqIRAO/Hz75BE45xVxS9OefULs2vPCCuT14sHZgDRKfD774As49Fzp0gClTIBCAHj1g9mz46Se4+mqzsLl3b5g8GRo2LPwcjRqZ8d697TkGEREREZGQS0+H554zt0eN0pWAIkfQfxFhzO1207JlS+0qKo6gfC+nQMBUDB95BH791YzVqAH33GOuz69c2d74otjBg/DBB/Dii2YRM5hFzNddByNHmvp5UXr3hssvh7lzfaxcuYOUlHp06eLRCluJejq/S3E0BxCnUc6Lk5SY76+8Anv2mMvUrr469MGJBIGV53ZXIBAIWPZsESA9PZ3k5GTS0tKoWrWq3eGIiJRPIADffGOKtStWmLHkZLjzThg2DHR+C5q//zatD1591VzJBVC9Otx2m1nQrI3ERIqmOVjo6XcuIiJhKy3NbDy2dy9MnAjXXGN3RCKWsHL+pY/2wpjX62XRokV4vV67QxEJOuV7KQUC8N13cM45cNllpmBbuTI89BCsXw8PP6yCbZD88Qfceis0bQqjR5uCbbNm8PLLpnXwk0+WrWCrnBenUa5LcXQ+FKdRzouTFJvv48aZgu1JJ0HfvvYEJxIEVp7b1R4hjAUCAXbt2oXDFkOLQynfS2HOHFOU/eEHcz8pCYYMgbvuglq17I0tSgUC5tf9/PPw9dfmPsCZZ8Ldd5tWB+VtvaWcF6dRrktxdD4Up1HOi5MUme/79pkeY2CuHFSfMIkiVp7bVbQVEQl3Cxeaxvxz5pj7CQnmWvx774W6de2NLUr5fPDll6ZY++OP+eOXXmpq5B07gstlW3giIiIiIpHrpZdMe4RTToErr7Q7GpGwpaKtiEi4WrrUrKydMcPcj4uDQYPggQegQQN7Y4tSBw7A+++bD/7/+suMxcdD//5mc7FWrWwNT0REREQksu3ZY4q2YHqOaUM+kWKpaBvGPB4PKSkpeHSpgDiA8r2AFSvMytpvvjH3Y2LgxhvhwQehSRN7Y4tSO3fC+PFmg7E9e8xYjRpw++1mc7FgLGhWzovTKNelODofitMo58VJjsr3F16A/fuhdWv45z/tDU4kCKw8t6toG8bcbjdNmza1OwyRkFC+A7/8Yno6ffGFue92myWeDz8Mxx9vb2xR6vffzarajz6CrCwzdvzxZlXtwIFQqVLwXls5L07j1koaKYbOh+I0ynlxkkL5nppqdvEFrbKVqGXlnFf/hYQxr9fL7NmztauoOIKj8331aujXD04/3RRsXS645hpTUXzvPRVsLRYIwPz5pj/tySfD22+bgm27djB5MvzxB9xxR3ALtuDwnBdHUq5LcXQ+FKdRzouTFMr355+HjAxo0wYuv9zu0ESCwspzu1bahrFAIMD+/fu1q6g4giPzfd06eOwxmDgR/H4z1qcPPPqoacovlvJ6YcoUM1f873/NmMtl5ot33QUdOoR2czFH5rw4mnJdiqPzoTiNcl6cJC/fd+6EV14xg6NHa1dfiVpWnttVtBURCbWNG+Hxx82OVz6fGbv8cjN5Of10W0OLRhkZ8O67Zr+DDRvMWEICDBhg2iCceKKt4YmIiIiIRCefD9e8eTScPx/3xIlw8CCceSZccondkYlEBBVtRURCZcsWePJJeOcdyMkxYxdfbFbbnnmmvbFFoe3bzYf5r78O+/aZsZo1zcZit98OderYGp6IiIiISPSaMgWGDSNmyxYK/aVz4YVaZStSSirahjGPx0P79u21q6g4QlTn+44dMGYMvPFG/m5X3bqZYm379vbGFoV++81sLvbxx5CdbcZOOAHuvNPs65aUZG98uaI650WKoFyX4uh8KE6jnJeoN2WKaftW1GXiY8ZA27bQu3fo4xIJASvP7a6AwxrppKenk5ycTFpaGlWrVrU7HBGJZrt2wbPPwquvQmamGevUyRRrO3e2N7YoEwjA3LmmX+20afnjHTrA3XebTcf0d5GIvTQHCz39zkVEJOR8PmjWzFxlWBSXCxo1gvXrNUGXqGTl/MttUUwSBDk5OUydOpWc3MuoRaJYVOX7nj3wwANw3HGmipiZCeecAzNnmsqiCraWycmBf//bdJc4/3xTsHW5zAf3P/xgvnr1Cs/5YFTlvEgpKNelODofitMo5yWqLVhQfMEWzGqLzZvN40SikJXndrVHCHNer9fuEERCJuLzPS3N7Hb10kuQnm7G2rY1m45ddJF6N1lo/354+20YOxY2bTJjiYlwww0wYoRphxAJIj7nRUQsovOhOI1yXqLSgQMwaVLpHrt9e3BjEYkCKtqKiFTU/v3w8stmVW3ujletW5s2CJddpmKthbZuNZuLTZhgauQAtWvDkCFw221Qq5a98YmIiIiIOM6vv5r9Oz78MH/xyrHUrx/cmESigIq2IiLldfCg6Vf77LOQmmrGTjoJRo+GK64AtzrQWOWXX+CFF8wH97lXm7RsaTYXu+46s8pWRERERERC5NAh+M9/zGqKhQvzx48/HnbvNsXborZQyu1p27Fj6GIViVAq2oaxmJgYunbtSkyM3iaJfhGV74cOmU+Sx4yBnTvNWIsW8Oij0LdveDZQjUCBAMyaZRYwf/dd/njHjnDXXXDJJZFdF4+onBexgHJdiqPzoTiNcl4i2rp15m+h994zxVkwf//06gW33mo2mvjyS+jTxxRoCxZuc69AHDtWfzNJ1LLy3K7/S4S5RC0fEwcJ+3zPyoJ33oEnn4Rt28zYccfBqFFmuacm3pbIyYFPPzXF2p9/NmNut1m8fOed0K6dvfFZKexzXkQkRHQ+FKdRzktEycmB//s/s6p25sz88UaN4F//gptuggYN8sd794bJk2HYsMKbkjVqZAq2vXuHLHSRSBbBa5Sin9frZdq0aWpSL44Q1vmek2N2vTrxRLjjDlOwbdzYfMK8ejUMHKiCrQXS000LhOOPh+uvNwXbpCTTr3btWvjss+gq2IZ1zosEgXJdiqPzoTiNcl4ixqZNZoFK06ZmBcXMmWa1bI8e8PXXsH49PPxw4YJtrt69YcMGvDNn8tPIkXhnzjSPV8FWopyV53ZVGUREiuPzwcSJZkOxP/80Y/XrwwMPwKBBEB9vb3xRYvNms4/bm2/m71tQt27+5mI1atgbn4iIiIiIY/h8pjfZhAkwdSr4/Wa8Th24+Wbzd1CzZqV7Lo+HQOfObD1wgNM7d1ZLBJEyUtFWRORIfr9Z1vnoo7BmjRmrUwfuu8/0adLlbJZYudKsrP3kE8j9MPKkk0wLhGuvhYQEW8MTEREREXGOHTvg3XfNSoqNG/PHzz/f/A10+eUQF2dffCIOpKKtiEguvx+++AIeeQR++82M1agB99wDgwdDpUr2xhcFAgFzVdVzz8H33+ePd+liNhe7+OLI3lxMRERERCRiBAIwd65ZVTtlSv5KiurV4YYbTL/ali1tDVHEyVyBQMGt/KJfeno6ycnJpKWlUbVqVbvDKVEgEMDr9RITE4Mrd5dFkShla74HAvDNN6Zf08qVZiw52Sz5HDYMwvxcEQmys82K2uefh19+MWNuN1x5pSnWnnmmvfHZQed4cZq0tDSqVasWEXOwaBEp816dD8VplPNiuz174IMPTLH2jz/yx9u3N6tqr7zSsqsLle/iNFbOebXSNsxlZmZSpUoVu8MQCYmQ53sgADNmmGLt0qVmrEoVGD4cRo6EatVCF0uU2rfPXGE1bpzZvw3MguWbbza/5tK2w4pWOseLiBg6H4rTKOcl5AIB+PFHU6j99FPIyjLjlSubXYBvuQVOPz0oL618FykfXYQaxrxeL3PmzNGuouIIIc/32bOhY0e46CJTsE1KgnvvNTuaPvaYCrYVtHGjqXs3bmx+rdu2mT3cxowxG4+NHauCrc7x4jTKdSmOzofiNMp5Can0dHj9dUhJgQ4d4MMPTcE2JQXeeMNM1F97LWgFW+W7OI2Vua6VtiLiLAsXwsMPm95NYHa7uv1207e2bl1bQ4sGy5ebzcU+/dRsPAtwyimmBcLVV0N8vL3xiYiIiIg4wooVpig7cSJkZJixhAQzKb/1VjjrLFC7ApGwFhYrbV999VWaNWtGQkIC7dq1Y2nuZcpFeP/993G5XIW+ErTFuIgcy5Il0L27WV07d67Z+XTwYPjzT1NlVMG23AIB+PZbuOACaNsWJk0yBdsLLjDjv/wCAweqYCsiIiIiElQHD8L778M558AZZ5iibUYGtGplLnXbtg3efRfOPlsFW5EIYPtK208//ZSRI0cyYcIE2rVrx9ixY+nevTtr1qyhTp06Rf5M1apVWbNmTd79aG5mHRNj+1skEjJByffly03P2qlTc18EbrwRHnwQmjSx/vUcJCvLFGhfeAF++82MeTzQt6/Zw+2MM+yNLxLoHC8iYuh8KE6jnBdL/f67KdB+8IHZVAIgNhauuMKsqu3UydYirfJdpHxcgUAgYGcA7dq146yzzmL8+PEA+P1+GjduzJAhQ7jvvvuOevz777/P8OHD2Zd7IiqjSNlFV0Qq6Jdf4JFH4IsvzH2PB/r3N60RjjvO3tgi3N69Zv+Cl1+GHTvMWOXK8K9/wbBhqoWLSNE0Bws9/c5FRKJYVpb5W2fCBJg3L3/8uOPMpmI33ADFLIQTkeCxcv5la3uE7Oxsli1bRrdu3fLG3G433bp1Y/HixcX+XEZGBk2bNqVx48Zcfvnl/Ja7xCvK+P1+/v77b/x+v92hiASdZfm+ejX062ca6X/xhflE+dprzafP776rgm0FbNgAw4ebzcUeeMAUbBs2hGefNZuLvfCCCrZloXO8OI1yXYqj86E4jXJeKuSvv+D++82k/OqrTcHW7YZevWD6dFi3zuwEHCYFW+W7OI2VuW7rGvXU1FR8Ph91j+glWbduXVavXl3kz7Rs2ZJ3332X1q1bk5aWxvPPP0+HDh347bffaNSo0VGPz8rKIisrK+9+eno6ADk5OeTk5ACmUOzxePD5fIV+ubnjXq+XgguSPR4Pbre72PHc582VeynAkTvIFTceGxuL3+/n0KFDLF68mH/84x/ExcURExOD3+/Hl7u7D6Y1RExMTLGxh9sxFRW7jknH5Ha7ycrKysv32NjYsh/Thg3w+OMwcSKuw6/rv+IK3I89hr9VKxP74efS+1S2Y1q2zMVLL3mYPNlF7kNPPTXAiBE+rrnGTUKC+/A5NXKOCex/n4o6x0f6MUXj+6Rjsu6YDh06hEhRfD4fixcvpkePHrjdYbHlhkhQKeelzLxe0+5twgT47juzqQRAgwYwaBDcfDMUUQ8JB8p3cZqCc+2KirjGIu3bt6d9+/Z59zt06MBJJ53EG2+8weOPP37U48eMGcPo0aOPGp8xYwZJSUkANGnShDZt2rBq1So2bdqU95iWLVvSqlUrli5dyq5du/LGU1JSaNq0KfPnz2f//v2FYqtTpw4zZswo9AdU165dSUxMZNq0aYVi6NGjB5mZmcyZMydvLCYmhp49e5Kampq32njmzJlUqVKF888/n82bN7Ny5cq8x9euXZsOHTqwdu3aQn1+w/2YAB2TjqnQMS1atAgw+V6WY0rcuZOWn31Gk7lzcR0+OW4/+2xWX301B1u0oOfJJ5P69996n8p4TD/+uJTp0918+eUJ/PZbrbzvt227h0suWU1Kyi5cLkhPb09CQmQcU7i+TzNnzoy6Y4Loe590TBU/ptwPzkVERKSUtm6Ft9+Gt94yt3N172561V5yidmzQ0Sikq09bbOzs0lKSmLy5Mn06tUrb3zAgAHs27ePr776qlTPc+WVVxITE8O///3vo75X1Erbxo0bk5qamtdbIlxXyBw6dIiZM2dqpa2OyRHHlJmZyYwZM0q/0nbDBtxPP437vfdwHX5MoEcPvKNGFdoBS+9T2Y4pO9vNv//t4YUXAqxe7TocR4B+/eCuu1ycckrkHVO4vk9FneMj/Zii8X3SMVl3TLt376Z+/frqrxpCkdLTNicnh2nTptGjRw9iY2PtDkck6JTzUiK/H2bONKtq/+//IPf/27Vrmw2VBw2C5s3tjbEMlO/iNLt376ZWrVqWzL9s/UgmLi6Otm3bMmvWrLyird/vZ9asWQwePLhUz+Hz+fjll1/o0aNHkd+Pj48nPj7+qPHY2NijThgejwePx3PUY4vb6bC48eJORGUZd7vdxMXFUaVKlbw/5nPHi7qkoLjYw+2Yiopdx6Rjyo3lyHwvMsbt22HMGGLfeAOys81Yt27w2GO42renqCPV+3TsY9q928wLX3kFdu4EcFG1qtnDYOhQV4GrrSLnmI4Ubu9TUef4SD+maHyfdEzWHVNcXFyR3xNxuVxUqVIFl407m4uEknJeirRrl9mD4803Td/aXJ06wW23wT//CUXUNsKd8l2cxspct3WlLcCnn37KgAEDeOONNzj77LMZO3Ysn332GatXr6Zu3br079+fhg0bMmbMGAAee+wxzjnnHE444QT27dvHc889x5dffsmyZcs4+eSTj/l6kbLiQESOsGsXPPMMvPYaZGaasU6dTB/bTp3sjS2C/fUXvPSSmR8ePGjGGjc2G47dfDPoNCkiVtEcLPT0OxcRCXOBACxYYFZPTJ6ctwcHyckwYIBZQVGKOoeIhA8r51+2Nz/p27cvu3btYtSoUezYsYOUlBSmT5+etznZpk2bCq0I2bt3L4MGDWLHjh1Ur16dtm3bsmjRolIVbCON3+9n8+bNNG7cWA27Jbr5fPjnzWP3r79S89RTcXfuDLkru/bsgeefh5dfhgMHzFj79qZYe/75oE9sy2XJEvNrnTKFvM3FUlLgrrvgqqtAVy4Fn87x4jTaNVqKo/OhOI1yXti7Fz76yBRrf/89f/zss02v2r594fAePJFO+S5OY+Wc1/aiLcDgwYOLbYcwd+7cQvdfeuklXnrppRBEZT+fz8fKlStp0KCBTm4SvaZMgWHDcG/ZQu3csUaN4KmnYN06sww0d1ObM8+Exx6Diy5SsbYc/H745htTrF2wIH/8ootMsVY18NDSOV6cxsqddCW66HwoTqOcd6hAAP77X1Oo/eST/KsHK1WCa681q2oL7M0RLZTv4jRWznnDomgrIg41ZQr06WMmMAVt2QL9++ffb93aFGsvu0xVxXLIzDQf5L/wAvzxhxmLjTVzw5Ej4bTT7I1PRERERCRqZWTApEmmWLtiRf74aaeZXrXXXqueZCJSJBVtRcQePh8MG3Z0wbagmBj4+GO48krQp7JllppqWgCPH29aAoNpj3XbbTBkCDRoYG98IiIiIiJRa9UqU6j9+OP8Kwfj400vsltvNS3ftCBFREqgom0Yc7lc1K5dW7ssSnT6/HOzorYkXi/UrauCbRmtXWu6Srz/fv5VV02awIgRcNNNUKWKreHJYTrHi9Mo16U4Oh+K0yjno1hmptlQbMIEWLQof7xFC1OoHTAAata0Lz4bKN/FaazMdRVtw1hMTAwdOnSwOwyRigsEYP16mDfPfM2dCxs3lu5nt28PamjRZPFieO45+PLL/AXMZ5wBd99tulDE6IwfVnSOF6eJ0UlIiqHzoTiNcj4K/fEHvPGGWTWxZ48Zi4mBf/7TFGu7dnXsqlrluziNlXNeLV8LYz6fj9WrV2vjDok8gYDZROztt+H666FpU2jeHG68ET74wBRsS7t6tn794MYa4Xw++OILOPdc6NDB3A4EoGdPmDMHfvoJ+vVTwTYc6RwvThPJuf7qq6/SrFkzEhISaNeuHUuXLi3x8Z9//jmtWrUiISGB0047jWnTphX6fiAQYNSoUdSvX5/ExES6devG2rVrCz2mWbNmuFyuQl9PP/205ccWDnQ+FKdRzkeJ7Gxz9eAFF0DLlvDii6Zg27QpPPkkbN4Mn33m+N1+le/iNFbmuoq2Yczv97NmzRr8fr/doYiULBCANWvgzTfhmmugUSNzCdCgQaaH0+bNpmrYvj3cfz9Mnw67d5vHFTeBcbmgcWPo2DG0xxIhDh6E11+HVq2gd29z9VVcnKmL//YbfPMNdOni6Plh2NM5XpwmUnP9008/ZeTIkTzyyCMsX76c008/ne7du/P3338X+fhFixZx9dVXc9NNN7FixQp69epFr169+PXXX/Me8+yzz/Lyyy8zYcIElixZQqVKlejevTuHDh0q9FyPPfYY27dvz/saMmRIUI/VLjofitMo5yPchg3w4IOm/9hVV8Hs2WbSfcklMHUq/PknPPAA1Ktnd6RhQfkuTmNlrmvtlYiUXSAAv/+e3+5g3jzYsaPwY2JjoV076NzZVA/bt4dKlQo/Ztw4c92+y1V4Q7LcSuPYseDxBPNIIs7ff8Orr5qv3bvNWPXqZnOxwYO1MFlExGovvvgigwYN4oYbbgBgwoQJTJ06lXfffZf77rvvqMePGzeOiy66iLvvvhuAxx9/nJkzZzJ+/HgmTJhAIBBg7NixPPTQQ1x++eUAfPjhh9StW5cvv/ySfv365T1XlSpVqKc/+kVE7Ofzwbffml6106bl/+1Srx7cfLP5atrU3hhFJOqoaCsix+b3w//+Z3rRzpsH8+eb6mFB8fFwzjmmSNu5s7mdlFTy8/bubRr1DxtWeFOyRo1MwbZ3b6uPJGKtWWM2F/vgA8hdiNWsGYwcCTfcAJUr2xqeiEhUys7OZtmyZdx///15Y263m27durF48eIif2bx4sWMHDmy0Fj37t358ssvAVi/fj07duygW7dued9PTk6mXbt2LF68uFDR9umnn+bxxx+nSZMmXHPNNYwYMaLEPmlZWVlkZWXl3U9PTwcgJyeHnJycvPg9Hg8+n6/QSpDcca/XS6DAB6kejwe3213seO7z5sqNz+v1lmo8NjY2L47c53K5XMTExOD3+wtdYpg7Xlzs4XZMRcWuY9Ix5R4T5Od8tBxTNL5PgW3bcL/3Hu533sG1eXPeY/wXXIB/0CACl15KTGKiib2Ux2r7MdnwPsHR5/hIP6ZofJ90TME5popQ0TaMud1umjRpgru0vT9FrOL3wy+/5G8aNn9+/rLOXAkJZvVsly6mSNuunRkrq9694fLL8c2dy+alS2l89tl4unTRClvMB/g//ADPPw9ff53/gf5ZZ5nNxf75T/WqjWQ6x4vTRGKup6am4vP5qFu3bqHxunXrsnr16iJ/ZseOHUU+fsfhK1Jy/y3pMQBDhw7ljDPOoEaNGixatIj777+f7du38+KLLxYb75gxYxg9evRR4zNmzCDp8AepTZo0oU2bNqxatYpNmzblPaZly5a0atWKpUuXsmvXrrzxlJQUmjZtyvz589m/f3/eePv27alTpw4zZswo9AdU165dSUxMPKqPb48ePcjMzGTOnDl5YzExMfTs2ZM9hzftmTlzJmBWGJ9//vls3ryZlStX5j2+du3adOjQgbVr17JmzZq88XA8ptTU1EKFfR2TjqngMS1atAjIz/loOKaoep9at6bpn3+y57HHqLVoEe7DRR9/jRq4b7yROS1akJ57Dp85MzKOycb3adu2bYd/VTOj5pii8X3SMVl3TAsWLMAqrkDBsrADpKenk5ycTFpaGlWrVrU7HJHw4PPBzz/ntzqYPx/27i38mKQks9NVbruDs84yq2vFcrmbiz3/PCxZkj9+6aVw112mza961YpIpInEOdi2bdto2LAhixYton379nnj99xzD/PmzWNJwZP0YXFxcXzwwQdcffXVeWOvvfYao0ePZufOnSxatIhzzz2Xbdu2Ub9AT5urrroKl8vFp59+WmQs7777LrfccgsZGRnEF/P/36JW2jZu3JjU1NS837lWyOiYdEw6Jh1TMceUmor7ww9xv/02rnXr8h7n79AB/6BBuK+6CndSUmQd02FR9T7pmHRMYX5Mu3fvplatWpbMebVGK4z5fD5WrVpF69at8WjVoVjJ64WVK/PbHSxYAGlphR9TqRKcd15+u4MzzzQ7XQWJ8h0OHID33zcbz/71lxmLj4f+/U0bhFatbA1PLKacF6eJxF2ja9WqhcfjYefOnYXGd+7cWWyv2Xr16pX4+Nx/d+7cWahou3PnTlJSUoqNpV27dni9XjZs2EDLli2LfEx8fHyRBd3Y2FhiY2MLjXk8niLPPcW1Xyhu/MjnLc94IBDg119/Pep86Ha7i1yhXVzs4XRMxcWuY9IxgSkc/PLLL0flfCQfU8S+T4EAMUuWmF61n38OuR98ValiJuG33IL7tNMK7eAe9sdEeL1PxZ3jI/mYovF90jFZd0xWXl0WedepOYjf72fTpk3aZVEqLifHLNl89lno0QNq1Mi/xv6bb0zBtkoVuPhieOYZ+PFHs9J2+nS4/36zwjaIBVtwdr7v3AkPP2w2oB082BRsa9QwYxs3wptvqmAbjZyc8+JMkZjrcXFxtG3bllmzZuWN+f1+Zs2aVWjlbUHt27cv9Hgwl4TmPv64446jXr16hR6Tnp7OkiVLin1OgJUrV+J2u6lTp05FDiks6XwoTqOcDwNpaWZn39atzUKVjz82BdszzoC33oJt22D8eDjtNLsjjXjKd3EaK3NdK21FolF2Nvz0U367g4ULzTLOgpKTzXX2ue0OUlLUIDXEfv/drKr96KP8D/WbNzeragcMMIudRUTEXiNHjmTAgAGceeaZnH322YwdO5YDBw5www03ANC/f38aNmzImDFjABg2bBidO3fmhRdeoGfPnnzyySf89NNPvPnmm4BZYTd8+HCeeOIJWrRowXHHHcfDDz9MgwYN6NWrF2A2M1uyZAldu3alSpUqLF68mBEjRnDddddRvXp1W34PIiJRYdkys6p20iQ4eNCMJSbCNdfArbeaqwtFRMKEKjQi0SArC/773/x2B4sW5U9CclWvDp065bc7OP10bfZlg0DAtAx+/nmzyDlXu3Zm4XOvXnpbRETCSd++fdm1axejRo1ix44dpKSkMH369LyNxDZt2lToMrgOHTowadIkHnroIR544AFatGjBl19+yamnnpr3mHvuuYcDBw7wr3/9i3379nHeeecxffp0Eg5v6BkfH88nn3zCo48+SlZWFscddxwjRoxg5MiRoT14EZFocOAAfPKJKdb+9FP++Mknw223wXXXQbVqtoUnIlIcbUQWxnw+H2vXrqVFixbqdyiFHTpk2h3Mm2cKtYsXm7GCatY0RdouXUyR9rTTIIx37o72fPd6YcoUU6z973/NmMsFl19uNhfr0EGbizlNtOe8yJH27t1LjRo1ImIOFi0iZd6r86E4jXI+RH77zRRqP/wQ0tPNWFwc9OljVtWed54m4CGgfBensXLOq6KtSCTIzDSF2dx2Bz/+mH89fa7atfNX0XbpYj45DuMirVNkZMC778JLL8GGDWYsIQEGDoQRI+DEE+2MTkQkdDQHCz39zkXEcbKy4D//McXaBQvyx5s3h1tuMZPw2rVtC09Eop+V8y+1RwhjXq+XpUuXcvbZZxe7K51EqQMHTJE2t93B0qWmT21Bdevmr6Lt3BlOOimiPymOtnzfvh1eeQVefx327TNjtWrBHXeYL80VJdpyXuRYvF6v3SFImNL5UJxGOR8E69aZ3Xvfew9SU82Yx2Mua7v1VrjgAi1osYnyXZzGyjmv/osJY4FAgF27duGwxdDOlJEBP/yQ3+7gv/8119MX1KBB/irazp3NEs0ILtIeKVry/bff4IUXYOLE/Dp7ixZmc7H+/SEpyd74JHxES86LlJZyXYqj86E4jXLeIjk58H//Z1bVzpyZP96oEfzrX3DTTeZvKLGV8l2cxspcV9FWxA7p6bBwYX67g59+Ap+v8GMaNy7c7qB586gq0kaTQADmzDH9ar/9Nn/83HNNv9pLL9XmYiIiIiIilti8Gd56C95+21zeBubvpIsuMqtqe/QAregUkSigM5lIKOzbZ4q0ue0Oli8Hv7/wY5o1K1ykbdZMRdowl5MDkyebYu3y5WbM5YJ//tMUa9u3tzc+EREREZGo4PPBjBmm99jUqfl/S9WpY1bUDhoExx1nb4wiIhZT0TaMeTweUlJStMNiJNqzxzS+z213sHKlWY5Z0PHHF2530LSpDYGGj0jK9/37zQf7Y8fCpk1mLDERbrjBbC52wgm2hicRIpJyXsQKynUpjs6H4jTK+TLYudPs6vvmm/m7+gJ07WpW1fbqBXFxdkUnpaB8F6exMtdVtA1jbrebpg4v5EWM1FSYPz+/3cGqVUcXaVu0yF9J27mzaX8geSIh37duhZdfhjfegLQ0M1a7NgwZArfdZjYaEymtSMh5ESu5tQGMFEPnQ3Ea5fwxBAJm4cuECTBlSv5eH9Wrw8CBpl9tq1Z2RihloHwXp7FyzquibRjzer3Mnz+fTp06aZfFcPP336ZIm9vu4Ndfj35My5b5q2g7d1YT/GMI53z/5RezudikSaYlApi398474brrzCpbkbIK55wXCQYrd9KV6KLzoTiNcr4Ye/bABx+YFRJr1uSPn3OOWVV71VWaeEcg5bs4jZVzXv0XE8YCgQD79+/XLovhYMeO/FW0c+fC778f/ZiTT84v0nbqBPXqhTrKiBZu+R4IwKxZpl/td9/lj3fqZPrV9uwJWjQmFRFuOS8SbMp1KY7Oh+I0yvkCAgFYssT0qv3sMzh0yIxXrmxWR9xyC6Sk2BqiVIzyXZzGylxX0VakKFu35hdp580r/ElvrtNOy+9J26mTuU5eIl5ODnz6qSnW/vyzGXO74YorzMradu3sjU9EREREJOLt3w8TJ5oWCLmTboDTTzd9x665BqpUsS8+EZEwoKKtCMDmzfmraOfNg3XrCn/f5YLWrfNX0nbsqAamUSYtDd56C8aNgy1bzFhSktmMdvhws2+ciIiIiIhUwMqVplA7cSJkZJixhATo29cUa88+2/ztJSIiKtqGM4/HQ/v27bXLYjBs2FC43cH69YW/73aby3ByV9J27Gga30vQ2JXvmzfnby62f78Zq1sXhg41rbNq1AhpOOIgOseL0yjXpTg6H4rTOC7nDx40rQ8mTDCtEHK1bGkm3P37a9IdxRyX7+J4Vua6irZhzO12U6dOHbvDiHyBgCnK5q6inTcPNm4s/Bi3G9q2zd807LzzoFo1O6J1rFDn+8qVZnOxTz7J35D2pJNMv9prrjEf+IsEk87x4jRW7qQr0UXnQ3Eax+T877+blREffAD79pmx2Fjo3dsUazt31qpaB3BMvoscZuWcV0XbMJaTk8OMGTO48MILiY2NtTucyBEImPYGBYu0ude75/J44Kyz8ou0554LVavaEq4Yocj3QABmzDD9ar//Pn+8Sxe4+2646CJtLiaho3O8OE1OTo7dIUiY0vlQnCaqcz47G774wqyqnTs3f7xZM7Op2A03mMvaxDGiOt9FimDlnFdF2zDnzV0CKMULBMxGYQXbHWzfXvgxsbGmSJvbk7ZDB7MjqYSVYOV7djb8+99mZe0vv5gxjweuvNJsLnbmmUF5WZFj0jleRMTQ+VCcJupyfv16ePNNePdd+PtvM+Z2w6WXmlW1F16o1REOFnX5LhIiKtpK5AkEzKU2BVfS7txZ+DFxcdCuXX5P2vbtza5S4ij79pm547hxsG2bGatUCQYNgmHDzAf+IiIiIiJSDl4vTJ1qVtV+9535Ow2gfn0z4b75Zmjc2N4YRUQimIq2Ev78fvjtt/xVtPPnw65dhR8TH28Ks7ntDs45BxITbQlX7LdxoynUvvVW/qa09eubQu2//qU95UREREREym3rdOAglQAANpVJREFUVnj7bTPZ3ro1f/zCC82q2ksuMVc6iohIhbgCgdyPw5whPT2d5ORk0tLSqBrmPUwDgQD79++nSpUquJzUoN3vh1Wr8lfRzp8Pu3cXfkxioinS5rY7OPts7RwV4azI9+XLTb/azz4Dn8+MnXKK2Vzs6qtNbV8kXDj2HC+OlZaWRrVq1SJiDhYtImXeq/OhOE1E5rzfbzaFmDABvv46f7JdqxbceKNZGdG8ub0xSliKyHwXqQAr57xaaRvmEp2wWtTng59/zm93sGAB7N1b+DFJSWazsNx2B2edZVogSFQpT74HAjB9uinWzp6dP37BBaZY2727NqWV8OWIc7yISCnofChOEzE5v2sXvPcevPEG/PVX/ninTmZVbe/eWhkhxxQx+S4SZlS0DWNer5dp06bRo0eP6Npl0euFFSvy2x0sXAhpaYUfU7kynHdefruDM8/UJTZRrqz5npUFkyaZYu3//mfGPB7o189sLtamTZADFqmgqD3HixRDm5BIcXQ+FKcJ+5wPBMxCmgkT4D//Mbv6AiQnQ//+cMst5nI2kVII+3wXsZiVc14VbSX4cnJg2bL8dgcLF8L+/YUfU7WqKdLmtjs44wyIUXrK0fbuNfPHl1+GHTvMWJUq5oqsoUOhSRN74xMRERERiUj79sGHH5rJ9u+/54+fdZZZVdu3r9nVV0REQkJVMbFedjb89FN+u4MffoADBwo/plo16Ngxv91BSopZJimO5PPBvHku5s9vSKVKLrp2PTod1q+HsWPhnXfy06lhw/zNxZKTQx62iIiIiEhkCwTgv/81hdpPPoHMTDOelATXXmtW1bZta2+MIiIOpaKtVFxWFixdml+kXbQo/3/2uapXz2910LkztG6tIq0AMGWKKbxu2RIDnMmLL0KjRjBunGmR9dNP8NxzMHmy2f8ATPrcdZf5sF+tjUVEREREyigjA/79b1OsXb48f/zUU+G220zBVqsiRERs5QoEAgG7gwilSNlFF8wui16vl5iYmPDaZfHQIfjxx/x2B4sXm7GCatUyzelz2x2ceiq43baEK+FryhTo08d8wF+Qy2XGTj45v18twIUXmmJtt27aXEwiX9ie40WCxMqddKV0ImXeq/OhOI2tOf/LL6ZQ+9FH+S3r4uPhqqtMC4T27TXRFkvpHC9OY+WcVyttw1xmZiZVqlSxN4iDB02RNncl7ZIlZnVtQXXq5K+i7dIFTjpJRVopkc9nVtgW9bFR7tj//mcWZF9zjdlc7PTTQxujSLCFxTleRCQM6HwoThPSnD90CD7/3BRrFy3KHz/hBFOoHTDALLoRCRKd40XKR0XbMOb1epkzZ07od1k8cMD8z3zePFOoXbrUbCZWUL16+atoO3eGVq30iayUWk4OfPYZbNly7Mf++99w5ZXBj0kk1Gw7x4vYxMqddCW66HwoThOynP/jD3jjDXj/fdizx4zFxECvXqZY27WrFtpI0OkcL05j5ZxXRVsxl8X88EN+u4P//heOTLKGDfNX0XbuDC1aqEgrJTp0yGwe9uefsG5d4a8NG8xK29LQ3/giIiIiIqWUkwNffQWvvw6zZ+ePN2lidu+98UaoX9+++EREpNRUtHWi9HRYuDC/3cGyZUdX0Bo3zi/QdukCxx+vIq0c5cAB+Ouvo4uy69bB5s1Ftz7IFRcH2dnHfg3NKUVEREREjmHjRnjrLXjnHdixw4y5XNCjh1lVe/HF2ghaRCTCqGgb5mJiLHiL9u2DBQvy2x2sWAF+f+HHNGtWuN1Bs2Yq0gpgavxHFmRzV89u21byz1aubBZln3CC+WrePP92nTrms4CtW4su7rpc0KgRdOwYnOMSCQeWnONFRKKAzofiNJbkvM8H335retVOm5Y/qa5bF26+GQYNgqZNK/46IhWkc7xI+bgCgZLWwkWfSNlFF5/PFFq3bzdLDTt2LP0no3v2wPz5+e0OVq48uirWvHnhdgdNmlh9BBIhAgGTMgWLsQW/du0q+eerV88vxB75Vbt2ybX/KVOgT5/8OHLl/szkydC7d8WOT0REwkPEzMGiiH7nIlFq+3azovatt2DTpvzxCy4wq2ovvxzUO1RExBZWzr/0cUc4mjIFhg0rvEtTo0YwblzRFazUVFOkzW138MsvRxdpTzwxfxVt587m+cQxAgH4+++i2xisW2cWY5ekTp2ii7LNm0ONGuWPq3dvU5gtKt3HjlXBVqKb3+8nNTWVWrVq4dYmIOIA/iOv8hE5TOdDcZpy5bzfb3rUTphgetbmbvxQowbccIPpV3viicELWqScdI4Xp7FyzquibbjJXXp4ZNF161YzPnkynHtu4SLtb78d/TytWhVud6DGoFHP7zftCoprZZCRUfLPN2x4dEE2999gLs7p3dssBpgzx8u3367k4otT6No1Ri23JOr5fD4WL15Mjx49NIEVR/CVdgdKcRydD8VpypTzu3fD++/DG2/A2rX54x06wG23mb8RExKCGq9IRegcL05j5ZxXRdtw4vOZJYdFdazIHevbN/9T1YJOOSW/SNupk+ljJFHH5zNXQBXVyuDPP+HQoeJ/1uUyXTCKWjF7/PGQlBS64ziSxwOdOwc4cGArnTufroKtiIiIiEQnnw/XvHk0nD8fV6VK0LXr0W3wAgFYtMisqv38c8jKMuNVqsD118Mtt0Dr1qGPXUREQkpF23CyYEHha8SLkluwbd06fxVtp06meahEhZwc2LCh6DYG69eb7xfH44Hjjiu6MNusGcTHh+ooRERERESkkMNt8GK2bOFMgBdfLNwGLy0NPv7YFGt//TX/59q0Matqr77a7PQrIiKOoKJtONm+vXSPe+MN07NIItahQ/DXX0W3Mti40ayoLU5cXH7rgiNbGTRpErl7DrhcLqpUqYKrpJ3LRKKIcl6cRrkuxdH5UByhpDZ4V1xhNhFbvBgOHjTjiYmmSHvrrXDmmSXv7isSxnSOF6exMtddgUBR1+JHr7DeRXfuXHN5zLHMmWNaIUhYy8gwRdgj2xisW2cWVJf0X15S0tGF2dyvhg2PvoJKREQk3IX1HCxK6XcuEiZ8PnPZ27GuqgQ46SSzqvb666FatWBHJiIiFrNy/qWVtuGkY0dzeczWrUVX9Fwu8/2OHUMfmxRp376ii7Lr1sGOHSX/bNWqRRdlTzgB6tVz3ofpfr+fzZs307hxYzWoF0dQzovTWLmTrkQXnQ8lKgUCpt3Brl0wfXrpCrZjx8LQoc77Q0Cims7x4jRWznlVtA0nHo/pZ9Snj/kfdcHCbe7/uMeO1TLLEAoEzIatRRVl160z3ytJzZpFF2WbN4datTQfK8jn87Fy5UoaNGig/5mLIyjnxWms3ElXoovOhxIxDhyAv/82hdi//87/Kni/4O2SNqMoSp06+gNBoo7O8eI0Vs55VbQNN717w+TJMGxY4U9jGzUyBdvevW0LLVoFAmZVbG5P2SMLs2lpJf98vXpFtzJo3hyqVw/NMYiIiIiISBllZRVfcC2qGJvbb7YsqlQxm4eVZv+S+vXL/vwiIhK1VLQNR717w+WX450zh5XffkvKxRcT07WrVthWgN9vuk4UtVr2zz/Nh+YladSo6BWzxx9v5mEiIiIiImIzrxdSU0te/Vrwdnp62V8jIcGsiK1TB2rXzr995P3atc1XYmJ+T1u1wRMRkTJQ0TZceTzQpQvZSUlw9tkq2JaC1wubNhVdmP3rL/NBenHcbmjatOjC7HHHmbmWBJfL5aJ27draVVQcQzkvTqNcl+LofCjF8vth797StSLYtevYvcuKEhNTdMG1uGJspUplb2GgNnjiYDrHi9NYmeuuQKCkPeyjj3bRjWzZ2bB+feFVsrm31683hdvixMSYlbG5rQsKFmabNYO4uJAdhoiIiONoDhZ6+p1L2AkEzOrWkoqvBe+npppVqmXhcpnNI4pb/Xrk7WrVQtdHdsqUo9vgNW6sNngiIlHEyvmXVtqGMZ/Px9q1a2nRogUeB33qmplpVsYWtWJ20ybzgXtx4uOL7i97wglmPhSjjA9bTs13cS7lvDiNNiKT4uh8GOEOHChdK4Lc+9nZZX+N6tWP3Yog93aNGuG7YvVwGzzf3LnsWLmSeikpeLp0Cd94RSygc7w4jTYicwi/38+aNWto3rx51J3c9u8vetOvdetMq6eSVKpUdFG2eXNo2NC0OpDIE835LlIU5bw4jb+kT13F0XQ+DDO5m3Mda1Ou3Nvl2ZyrcuVjF19z79eqFV2XxHk8+Dt14qeMDHp06qScl6inc7w4jZVzXhVtJWj27j26hUHu186dJf9scjK0aFF0K4O6dUN3BZOIiIiISETzek2v19KuhE1LK/trxMebSfqxWhHk3taGESIiIsekoq2UWyBg2kwVtVp23TrYs6fkn69Vq+gVsyecYK5qUmFWREREROQIBTfnKk1bgj17Cm98VRoxMaUvwNapY1bOavIuIiJiKRVtw5jb7aZJkya4bbzePxCA7duLL8zu31/yz9evX3wrg+Tk0ByDRIZwyHeRUFLOi9Mo16U4jj8fBgJmUn2sTblyb+/aVf7NuY7ViiD3dig353Igx+e8OIryXZzGylx3BQJl/dg1smkX3aP5fGYD09xCbMF2Bn/+WXKbKpfLbPBVsBhb8HalSqE7DhEREQlfmoOFnn7nNjp4sHStCHJvl2dzrmrVSt8XtmZNbXYlIiISAlbOv7TSNkz5fDB3ro+lSzdz9tmN6dLFU6F5ltcLGzcWvVr2r79Knid6PNC0adErZo87DhISyh+XSC6fz8eqVato3bq1GtSLIyjnxWms3ElXoojPh2/uXDYvXUrjs8/G06VLeBYXs7NL14og9/aBA2V/jdzNuUrTliDaNudyGM0BxEmU7+I0Vs55VbQNQ1OmwLBhsGWLB2gGQKNGMG4c9O5d/M9lZcH69UUXZjduNIXb4sTGwvHHF12YbdrUfF8kmPx+P5s2beLUU0/V/8zFEZTz4jRW7qQrUeLwpNezZcvhGS+lm/RaIXdzrtKuhC3v5lylaUVQu7b5Skqy/jglLGkOIE6ifBensXLOq6JtmJkyBfr0OXqvgK1bzfjEiXDqqUW3Mti0qeQ9BhISiu4te8IJpsWBzp8iIiIiEhLHmvROnly2wq3fD/v2HXslbO793bvLvjmXx1O6VgS5t7U5l4iIiFSAirZhxOczK2yLmj/mjl1zTcnPUbly0atlTzjBbAqm3t8iIiIiYqtjTXpdLhg+HLp2hT17St6UK/d2amrJl5UVxeUyvV5L2xe2WjVNpkVERCRkVLQNIwsWmA3BjqVSJTj55KILs7Vr6wN9iUxut5uWLVtqV1FxDOW8OI1yXfIca9IbCMDmzVCjRtmfu1q1Y7ciyL2tzbnEJpoDiJMo38VprMx1FW3DyPbtpXvcW2/B1VcHNxaRUPN4PLRq1cruMERCRjkvTqM+dpKntJNeMKsVytIXVptzSQTQHECcRPkuTmPlnFdF2zBSv761jxOJJF6vl6VLl3L22WcTE6NTk0Q/5bw4jbesl65L9CrtZPbbb+Gii4Ibi4gNNAcQJ1G+i9NYOefV+vQw0rGj2TC3uPYGLpfZMKxjx9DGJRIKgUCAXbt2ESjrpiAiEUo5L06jXJc8pZ30/uMfoY1LJEQ0BxAnUb6L01iZ6yrahhGPB8aNM7ePnMPm3h87Vq23RERERCSCadIrIiIickwq2oaZ3r1h8mRo2LDweKNGZrx3b3viEhERERGxjCa9IiIiEmV8Pli4sJgricpBDUXCUO/ecPnlMG+en19/3c2pp9akc2e3FhtIVPN4PKSkpGijGnEM5bw4jXJdjnJ40uufN4/dv/5KzVNPxd25s1bYStTTHECcRPkuTjFlCgwbBlu2WFdqdQUc1lgkPT2d5ORk0tLSqFq1qt3hiIiIiDiC5mChp9+5iIiISPBNmQJ9+oCpsKYD1sy/1B4hjHm9XmbPnq3dlsURlO/iNMp5cRrluhRH50NxGuW8OIXP72PWn7N46JOHmPXnLHx+n90hiVjO5zMrbIOxJFbtEcJYIBBg//792mVRHEH5Lk6jnBenUa5LcXQ+FKdRzosTTPl9CsOmD2NL+hYAnlzzJI2qNmLcRePofZL6lkv4CgTg4EHYu7d0Xxs2wJYtwYlFRVsREREREREREbHElN+n0OezPgQo/MHE1vSt9PmsD5OvmqzCrQRVSYXXPXuOXYjNybH7CAwVbUVEREREREREpMwCgQBZviwOZB8gIzuDtENp3Db1tqMKtkDe2K3f3ErNxJpUiqtEQkxCkV8xbpWrnC4QgAMHSr/i1erCq8cD1asf+2vrVnjkEWuO+UjaiCyM+f1+UlNTqVWrFm632g9LdFO+i9Mo58Vp9u3bR/Xq1SNiDhYtImXeq/OhOI1yXuxwZHE1IzuDAzkFbh8eP2osp/D3Cz4md8wXsL5XrcflKbagmxibePS4p+jHlvar4HPGe+LxuD2WH5MTlbXwWnAV7L591hZea9QoXRE296tyZXC5jv0aPh80a2aKt1ZvRKaPLsKY2+2mTp06dochEhLKd3Ea5bw4jQoTUhydD8VplPNSkkgrrhYU74kn1h1LRk7GMR9br1I9Yj2xHPIeyvvK8edX6HwBHwdyDnAg50AwQy5WrDu2QkXghJgEEmOKKC6X4is+Jh63K3zmTbmF19K0FSjqq6J7LsbElK3YWp7Ca0V4PDBuHPTpY17LyqWxKtqGsZycHGbMmMGFF15IbGys3eGIBJXyXZxGOS9OkxMuzcEk7Oh8KE6jnI8OJRVXSyy45hyj+Bqi4mrluMpUjqtMpbhK+bdjj75d7PePGK8UV4kYdwxzN8yl6wddjxnDv/v8my7NuhQa8/l9ZPmyChVyS/OVmZNZ9Pd8ZXserz+/upjjzyEnO4f92fut/vWXSpwnztJicLwnAbIrk3OgMtkZlcg+UIlD+xPJ3J9A5v54DqTHkZEeS3paDOn73Ozd6wpZ4fVYK2ArVQp+4bWieveGyZNh2LCApZuSqWgb5rwV/S9DJIIo38VplPMiIobOh+I0yvnQUXG16OJqsHRs0pFGVRuxNX1rkX1tXbhoVLURHZt0POp7HreHJHcSSbFJQYuvJF6/lyxv4aJxpreYgrAFX0c+tz/gz4sl25dNti+b9Kz0/AADQHZlyKwOh6oX+DcJMisfMXbEv4eqgb9iHxK5PDnEJO0nplIGcZUPkFD5IAmVM0msmkVSlSwqV82mclUvVap5qVYtQHI1/+GirIvkKjEkxpauwBzrjsUV7lXaopw0hcDwEbCsMfzbmqdU0VZEREREREQkBHx+H/M2zmP+3vlU2liJrsd3Ve/Mw1RcDW1xNVg8bg/jLhpHn8/64MJVqHDrwhTixl40NizzPsYdQ0xcDJXiKgXtNQIB2L+/qF6uAXbv8ZO620vqHj979voPf89F2j436WkeMtI9+LwVa5vg8uTgTkrDnZgGiXsJJOwlEL8bX/xuSNwDCXshce/hf/cUuL2XQOxBclyQA2QCacW9iA/YffirPDHisq2fcXk3wZvy+xT6fNbH5HvTTeU78CJE3hlARERERESigs8H8+a5mD+/IZUqueja1fSGE4lGU36fwrDpw9iSbq6dfXHjizSq2ohxF42j90m9bY6u9FRcjY7iajD1Pqk3k6+aXCjfARpVbcTYi8ZGVL4XpbjCa2k31/IVmeYuwHP4q2SxsWXfVCv3KykpFperFlDriGMKkOPPsaYtRTlaUxzyHsqPhQCZ3kwyvZkVeJfKr6RN8Ir6ivPE8dlvnxW5sryiXIGAlS1yw1+k7KIL5j+a/fv3U6VKlchcGi5SBsp3cRrlvDhNWloa1apVi4g5WLQI93nvlCkwbBiFer81amQ28+gd2X/Pixyl0CqsAnJXHk6+arLlhSwVV1VctVt2jo9XJ69i9bo0Wp2QzB19WhMXGx6fzOUWXsuzuVbxhdfSi4sr/+ZaSUnh3+O1rAKBANm+7JC0pSjq+bN92dYdzCHgaSyZf+mMFeYSExPtDkEkZJTv4jTKeRFxqilTzC7LRy4f2brVjE+erMKtRA+f38ew6cOKXIUVIIALF8OmD+O8xueR6c0sskhaXHG1xO+ruCo2Mh/MediypU3e2Iv3WPvBXCAA6ellL7oGq/BaltWviYnRV3itCJfLRXxMPPEx8SSTHPLX9wf8R/UzLu3Xki1L+PR/nwYlLq20DWM5OTlMmzaNHj16aFdRiXrKd3Ea5bw4ze7du6lVq1ZEzMGiRbjOe30+aNaMYndXdrnMitv169UqQezl8/vMJbo55jLdgzkH825n5hy+X4rvb0zbyNwNc209FhVXJZSK+2Aut0hZ8IM5v7/4VgPHWgW7b5/5+YqIjy//ilcVXgVg7oa5dP2ga/6AVtqKiIiISKTw+X0s3LTQ7jAkTCxYUHzBFswf+Zs3w9ChcMopkJBgvuLji75d1PdiY/WHdDQKBAJ5l7VWtJBamsdberlsKam4KuEkEACvFw4dgqws82/B20f+e+gQZGbCnXceXbDNfT6Afv2gcWNTdA1V4bW4VbC68E0qqmOTjjSq2oit6Vst72urM3eY0q6iIiLRS+d4cZK8jXf+LqFKJ46yfXvpHvfaa+V/DZerdMXd0haBy/O4+PjoLxznblxTrqJpwe+X4fF2iffEkxibSGJMIomxiSTFJuXdTow5fD/3+0fc35K+hfH/HX/M1/j++u+54PgLQnA0Egl8vtIVSUszVpHHB+Pa7Jwc+OuvwmMJCRVb8SpiF4/bw7iLxtHnsz64cFlauFXRNgxFy66iIqWlApY4ic7x4iTFbbwjzla/fuke160bJCcfXUAortiQXWBRZCBgVntl2lfjA/ILuaEsFsfG+QjEZBLwZOL3ZHLIF7yVqJk5mUHvmVocj8tTqDBaqGhaXFG1uCLrMYqwCTEJFZqb+vw+vlzzpVmF5XfBxo6QUR8qb4emC3C5AzSq2oguzbpY9wuScvP77S2S5t6uaL/VYIiNLfr8U/DfPXvg55+P/VyPPAJXXpm/AjYhIfjxiwRL75N6M/mqyebvvEPWLVRQT9swY8euoiJ2OrKABaiAJVFL53gJV/6A/5hfPr+v5O8HCn8/x5fDxRMvZueBneZFLOzvJaUTrvPevJ62WwMQKGIpqitA40auMve0zS20FFX8KG3h14rHhR13NsQcgpisw/8eAk+B2yV+ryw/k0VCIiTEu0hMcJGU5CYxwUWlxBiS4uNKXIl6zJWrRXw/1hNZ/eCn/D6FKx6dCNPHQnrj/G9U3QwXDec/j17r+DlAIGA+fLGrSJr7b06O3b+Jo7ndRX9wU1IBtbiiakUe73YfO9a5c6Fr12M+jDlzoEuXiv5mRMKLz+/j21+/5dLTL7Vk/qWibRjx+X00G9esUPGqIBcuGlVtxPph67UKUaKCCljiJJF8jg8EAiUW5ypa3AvX5zjWz5fqOSIkzpBQ0Tbkwnnee8/LP/LcsLMP3ytYBTD5ePe4pTw79JygvHYwL+k/mJPJwUwvBw/5yDwUIDMzQFZWALwJBb7i82/74oseL+/3fPGQk3jE79R+Hk9oVhmHc5/jKVPgij6Bw9eaH5HzLhf/mezK25gp1Ar2LbWjSJp7OyvLnuMvyZGtVoJZEC1pLCaCrpHO/WBu69aiWytos0mJdmlpaVSrVk0bkUWbBZsWFPvHPECAAJvTN3PNf66hSXITXIdnHS5cuFyuo/4t6Xu5RTF9r2zfq8jvPFq/V14+v49h04cVeclsgAAuXAyfPpzLW15e5gJWIBAgQKDYf0vzmEj4N/d3FQ6xVOTfaHk/jjqWI76/PWN7qc7x5713HjUSawSviFiOAqFILhcu3C53sV8etyfvdpY3i7SsNLtDljDk8/v4t+9KuOosmD7uiFWHW+CiEXyUs4irtv4f2f7syL2k3wNUPvx1WIw7pvQrS2PKtxI1ISaRWBLBm4A3x2PbqmOvt8B77oMDB8yXXezscxwbazbWMyvLj5w/u3Fhvn/WWWaVpx09TMNxKVdub+hQrCgt7jnsLvZHIo8Hxo2DPn3M765gbuX+LseOVcFWope34P8AK0hF2zCyfX/pdmX47H+fBTkSkbIpTyHcF/BxyHuo2OfMLWBVfqoybre7TEUzkUj245Yf7Q6hQkoq6LldbjwuT6mLfrb8fLjHZ/PPl+XDurkb5tL1g1JcHymOk7dQ4eQt0Oqro/p74vazIwPOevusoMfiwlXuzaXKenl/JF7SXxFeb8ntKkJRSA7HPsdFCQTMqsQmTeyOxDhW39JQrDKNi1OxNJL17g2TJ8OwYbClwJqFRo1MwdauVeUikSYsiravvvoqzz33HDt27OD000/nlVde4eyzzy728Z9//jkPP/wwGzZsoEWLFjzzzDP06NEjhBEHR/0qpduVoe8pfWlctfExV6iV9L0SV+kF63kr8L2KxBRp8VoRU6gVPKbDA5Y55DsEYdiEv7jVyKH+N5xiqci/UPIK9xL/DYP4S3Mcf+39iwnLJhwzt+5qfxcn1z45bIt2Jf187vGKAHRs0pFGVRuZjXf0oZoUUGihgtsPx80r8nHV4qtRM6lm0DaXSoxNJN4Tr/NWkMTEmK9KleyLoagNpexYdVxaLhckJYX2svuixkrTt1TkWHr3hssvhzlzvHz77UouvjiFrl1jtMJWpAxsL9p++umnjBw5kgkTJtCuXbv/b+/uw6qu7z+Ov86Bw403aKigJMesTFFHTKcObd4khkoOw5l41SUzV9da3k3LbJt3paKrbbh5Zc7Z5JpyaXKF01ooulBLMi0x1oS0nDcpmamgKKjnnN8f/DxJQFoeON9zvs/HdfmHH77g+815+73evM+H70cZGRlKTExUSUmJIiIial2/a9cujR07Vunp6XrwwQeVlZWlkSNH6sMPP1T37t29kIHn3OiHG4uqn3e4JmWN4Z53COMx+oD6vePv6dGcR2+Yx+qHVis+Ot44A0J+sMP35HA69MbBN254j1+UsIh7PPxCgDVAS4Yu0c9e+5kssvj04NbTGwxcLpfmzJmjFStW6Ny5c+rXr5+WLVumTp06ua85c+aMJk2apE2bNslqtWrUqFFasmSJmjVrVtc/6VNudqNCTmqOBt4xsGGDgV+zWqXQ0Oo/3uJySXl5UmLija/9978lDmaCPwkIkAYMcKmq6gsNGOBiYAt8R14/iKxPnz7q1auXli5dKklyOp2Kjo7WpEmTNHPmzFrXjxkzRhUVFXrjjTfcaz/+8Y8VFxenV1658Q4mIx/IIH19MJNUc8fktUETBzPBX1w7lOlGAywjHsoEfF/c42FGrx94XVNyp+j4qeM+eRDZunXrNG7cuBobDNavX/+tGwz69+9fY4PB4sWLa2wwWLx4sdLT05WZmamOHTtq1qxZKioq0n//+1+FhIRIkoYNG6aTJ09q+fLlunLlisaPH69evXopKyvrpmM3at9LDwCz4WAmADAPT/ZfXh3aXr58WU2aNFF2drZGjhzpXk9LS9O5c+f0z3/+s9bn2O12TZs2TVOnTnWvzZkzRxs2bND+/ftrXV9VVaWq634npby8XNHR0Tp9+rT7m2e1WhUQECCHwyGn8+sDV66tX716Vdd/mwICAmS1Wutdv3LlSo0YAv//qMdvPoy4vvVNhzZpyltTdPz81w9/aR/WXkuGLtHIziPlcHz9u+IWi0WBgYH1xm6UnGw2m5xOZ52x17dOTv6fU05xjlJfT5VU/wBrxN0jfConyf9eJ3LybE4bD26s9x6ffE+yT+bkj68TOXk2p8qqSr318VtK6ZliuAHijXh6g4HL5VJUVJSmT5+up59+WlL1IDsyMlKrVq1SamqqDhw4oK5du2rPnj360Y9+JEnKzc3V8OHDdfz4cUVFRd1U7EYd2kq8iQXzef316oOZpLoPZsrO5jmf8E9Op1OnT59W69atZeXZGzCBc+fO6bbbbvNI/+XVxyOcPn1aDodDkZGRNdYjIyNVXFxc5+eUlpbWeX1paWmd16enp2vevHm11rds2aImTZpIqh4E//CHP9RHH32ko0ePuq/p3LmzunTpovfff19ffvmlez0uLk4dOnTQjh07dP78efd6fHy8IiIitGXLlho/QA0aNEihoaH617/+VSOG4cOH69KlS3r77bfda4GBgUpJSlH8bfFauXWlzl49q9sCb1Ofdn00JGaIjhw5osLCQvf1bdq0Ud++fXXw4EGVlJS4142WU1JSkk6fPq2CggL3evPmzXX//ffr2LFj5GTSnIIVrGWDlmn+B/OrDyT5f61srfT7+3+vlJgUvfnmmz6Vk+R/rxM5eTanlPtTFBcSp6x3s9z3+AEdB+gnMT9RcXGxT+bkj68TOXk2p21bt6myvP7DJ43q8uXL+uCDD/Tcc8+516xWqxISEmp8f69XUFCgadOm1VhLTEzUhg0bJEmHDx9WaWmpEhIS3B9v0aKF+vTpo4KCAqWmpqqgoEAtW7Z0D2wlKSEhQVarVbt379ZDDz3kwSy9IyUmRdkPZ1fvwi6v+SZWxtAMBrbwOxzMBLNyOBwqKCjQ8OHDGdrCFK7fIHGrvLrT9sSJE7r99tu1a9cuxcfHu9dnzJih7du3a/fu3bU+JygoSJmZmRo7dqx77eWXX9a8efP0xRdf1LreF3faXtshU1lZqby8PA0ZMkRBQUHs+iEnv83JJZe2HtyqLQVbNOTHQzSw40AF2YJ8Oid/fJ3IybM51XWP9/Wc/PF1IifP5fTVV1+pXbt2htz1WZ+G6FV37dqlfv366cSJE2rX7utnuz788MOyWCxat26dFi5cqMzMzBqDdEmKiIjQvHnz9OSTT9YZry/2vdYAq7Yd2qbNuza7e4DAgECv1+ut5GTU/4PkZIycXC6rtm6t0pYtRRoypLsGDgxQUJBv5+SPrxM5eS6nqqoq5ebmasiQIbLZbH6Rkz++TuTkuZxKS0s91vN6dadt69atFRAQUGvY+sUXX6ht27Z1fk7btm2/0/XBwcEKDg6utW6z2WSz2WqsBQQEKCCg9kOErhXHza5/8+t+n3Wr1epet9ls7n/LarXW+e5UfbEbLae6YicncpKk+++8X5XFlRp81+AacflyTv74OpGT53Kq6x7v6zn54+tETp7Lqb744Tm++BtmSUlJ6t6suy7ddklVJVXaXLLZ6zvD/XG3OzkZK6fAwHfUv/8FVVV9rs2b/SMnf3ydyMkzOX3++eeSpLy8PL/JyR9fJ3LyXE47d+6UpxjiILLevXvrL3/5i6Tq553Y7XZNnDix3ueEXbx4UZs2bXKv9e3bV7GxsX5xENn1rl69qh07dqh///71/hAE+AvqHWZDzcNszpw5o1atWvlED3ZNQ5y/8Nlnn+muu+7Svn37FBcX575mwIABiouL05IlS/Tqq69q+vTpOnv2rPvjV69eVUhIiNavX1/v4xF8caetzWbT5cuXtXPnTvXt21eBgYFe3yHjj7t+yMlYOVVWVurdd99117w/5OSPrxM5eSan+u7xvpyTP75O5OS5nK5tLPX5g8ik6hN509LStHz5cvXu3VsZGRl67bXXVFxcrMjISI0bN06333670tPTJVWfyDtgwAAtWrRISUlJWrt2rRYuXFjjRN5v40tDWwAAAH/hqz2YpzcYXDuI7Omnn9b06dMlVX9vIiIiah1EtnfvXvXs2VNS9W7ZoUOH+s1BZAAAAP7Ik/2X158CPWbMGL300kuaPXu24uLiVFhYqNzcXPdhY0ePHtXJkyfd1/ft21dZWVn661//qnvvvVfZ2dnasGHDTQ1sfY3T6dSRI0dqTPoBf0W9w2yoeZiNr9b6tGnTtGLFCmVmZurAgQN68sknVVFRofHjx0uSxo0bV+OgsilTpig3N1d/+MMfVFxcrLlz52rv3r2aOHGipOpdHlOnTtX8+fO1ceNGFRUVady4cYqKinLv5o2JidHQoUP1+OOP6/3339e7776riRMnKjU19aYHtr6E+yHMhpqHmVDvMBtP1rohfh9z4sSJ7kb2m/Lz82utjR49WqNHj27gqLzP4XCosLBQUVFRdT5/DvAn1DvMhpqH2Vz/62m+ZMyYMfryyy81e/ZslZaWKi4urtYGg+v/D1/bYPC73/1Ov/nNb9SpU6daGwxmzJihiooKPfHEEzp37pzuu+8+5ebmKiQkxH3NmjVrNHHiRA0ePFhWq1WjRo3Sn//858ZLvBFxP4TZUPMwE+odZuPJntcQQ1sAAADAqDy9wcBisej555/X888/X+814eHhysrK+s6xAgAAwD/wNgcAAAAAAAAAGAhDWwOzWCxq06aNLBaLt0MBGhz1DrOh5mE21Drqw/0QZkPNw0yod5iNJ2vd4nK5XB77aj6AU3QBAAAaHz1Y4+N7DgAA0Lg82X+x09bAHA6HiouLffbgDuC7oN5hNtQ8zIZaR324H8JsqHmYCfUOs/FkrTO0NTCn06mSkhI5nU5vhwI0OOodZkPNw2yoddSH+yHMhpqHmVDvMBtP1jpDWwAAAAAAAAAwEIa2AAAAAAAAAGAgDG0NzGq1ym63y2rlZYL/o95hNtQ8zIZaR324H8JsqHmYCfUOs/FkrVtcLpfLY1/NB3CKLgAAQOOjB2t8fM8BAAAalyf7L97qMDCHw6F9+/ZxyiJMgXqH2VDzMBtqHfXhfgizoeZhJtQ7zMaTtc7Q1sCcTqeOHj3KKYswBeodZkPNw2yoddSH+yHMhpqHmVDvMBtP1jpDWwAAAAAAAAAwkEBvB9DYrj3Ct7y83MuR3NiVK1d08eJFlZeXy2azeTscoEFR7zAbah5mc/78eUlf92JoeL7S93I/hNlQ8zAT6h1m48me13RD22vfvOjoaC9HAgAAYD5fffWVWrRo4e0wTIG+FwAAwDs80fNaXCbb7uB0OnXixAk1b95cFovF2+F8q/LyckVHR+vYsWOc+Au/R73DbKh5mE1ZWZnsdrvOnj2rli1bejscU/CVvpf7IcyGmoeZUO8wG0/2vKbbaWu1WtW+fXtvh/GdhIWFcXODaVDvMBtqHmZjtXKkQmPxtb6X+yHMhpqHmVDvMBtP9Lx0zQAAAAAAAABgIAxtAQAAAAAAAMBAGNoaWHBwsObMmaPg4GBvhwI0OOodZkPNw2yoedSH2oDZUPMwE+odZuPJmjfdQWQAAAAAAAAAYGTstAUAAAAAAAAAA2FoCwAAAAAAAAAGwtAWAAAAAAAAAAyEoa3BLVq0SBaLRVOnTvV2KECDcDgcmjVrljp27KjQ0FDdddddeuGFF8TjtuEvduzYoREjRigqKkoWi0UbNmyodc2BAwf005/+VC1atFDTpk3Vq1cvHT16tPGDBW7RsmXLFBsbq7CwMIWFhSk+Pl5vvfWWJOnMmTOaNGmSOnfurNDQUNntdk2ePFllZWVejhpGQM8Lf0fPC39HzwuzaYy+N7AhAodn7NmzR8uXL1dsbKy3QwEazOLFi7Vs2TJlZmaqW7du2rt3r8aPH68WLVpo8uTJ3g4PuGUVFRW699579dhjjyklJaXWxz/99FPdd999mjBhgubNm6ewsDB9/PHHCgkJ8UK0wK1p3769Fi1apE6dOsnlcikzM1PJycnat2+fXC6XTpw4oZdeekldu3bVkSNH9Mtf/lInTpxQdna2t0OHF9HzwgzoeeHv6HlhNo3R91pcvLVnSBcuXFCPHj308ssva/78+YqLi1NGRoa3wwI87sEHH1RkZKRWrlzpXhs1apRCQ0O1evVqL0YGeJ7FYlFOTo5GjhzpXktNTZXNZtM//vEP7wUGNKDw8HC9+OKLmjBhQq2PrV+/Xo8++qgqKioUGMheAjOi54VZ0PPCTOh5YVae7nt5PIJBPfXUU0pKSlJCQoK3QwEaVN++fbVt2zZ98sknkqT9+/frnXfe0bBhw7wcGdDwnE6n3nzzTd1zzz1KTExURESE+vTpU+evkwG+xuFwaO3ataqoqFB8fHyd15SVlSksLIyBrYnR88Is6HlhZvS88HcN1ffSIRvQ2rVr9eGHH2rPnj3eDgVocDNnzlR5ebm6dOmigIAAORwOLViwQI888oi3QwMa3KlTp3ThwgUtWrRI8+fP1+LFi5Wbm6uUlBS9/fbbGjBggLdDBL6zoqIixcfHq7KyUs2aNVNOTo66du1a67rTp0/rhRde0BNPPOGFKGEE9LwwE3pemBk9L/xVQ/e9DG0N5tixY5oyZYry8vJ4tgtM4bXXXtOaNWuUlZWlbt26qbCwUFOnTlVUVJTS0tK8HR7QoJxOpyQpOTlZv/71ryVJcXFx2rVrl1555RUaWPikzp07q7CwUGVlZcrOzlZaWpq2b99eo4EtLy9XUlKSunbtqrlz53ovWHgNPS/Mhp4XZkbPC3/V0H0vQ1uD+eCDD3Tq1Cn16NHDveZwOLRjxw4tXbpUVVVVCggI8GKEgGc988wzmjlzplJTUyVJP/jBD3TkyBGlp6fTwMLvtW7dWoGBgbXejY2JidE777zjpaiAWxMUFKS7775bktSzZ0/t2bNHS5Ys0fLlyyVJ58+f19ChQ9W8eXPl5OTIZrN5M1x4CT0vzIaeF2ZGzwt/1dB9L0Nbgxk8eLCKiopqrI0fP15dunTRs88+S/MKv3Px4kVZrTUfrx0QEOB+NxbwZ0FBQerVq5dKSkpqrH/yySfq0KGDl6ICPMvpdKqqqkpS9U6DxMREBQcHa+PGjeywNDF6XpgNPS/MjJ4XZuHpvpehrcE0b95c3bt3r7HWtGlTtWrVqtY64A9GjBihBQsWyG63q1u3btq3b5/++Mc/6rHHHvN2aIBHXLhwQYcOHXL//fDhwyosLFR4eLjsdrueeeYZjRkzRv3799egQYOUm5urTZs2KT8/33tBA9/Tc889p2HDhslut+v8+fPKyspSfn6+Nm/erPLycj3wwAO6ePGiVq9erfLycpWXl0uS2rRpw5DOZOh5YTb0vPB39Lwwm8boey0ul8vVkEng1g0cOFBxcXHKyMjwdiiAx50/f16zZs1STk6OTp06paioKI0dO1azZ89WUFCQt8MDbll+fr4GDRpUaz0tLU2rVq2SJL366qtKT0/X8ePH1blzZ82bN0/JycmNHClw6yZMmKBt27bp5MmTatGihWJjY/Xss89qyJAh9f5fkKp/sLvjjjsaN1gYDj0v/Bk9L/wdPS/MpjH6Xoa2AAAAAAAAAGAg1htfAgAAAAAAAABoLAxtAQAAAAAAAMBAGNoCAAAAAAAAgIEwtAUAAAAAAAAAA2FoCwAAAAAAAAAGwtAWAAAAAAAAAAyEoS0AAAAAAAAAGAhDWwAAAAAAAAAwEIa2AOAD7rjjDmVkZHzrNRaLRRs2bGiUeAAAAICGQN8LANUCvR0AAODG9uzZo6ZNm3o7DAAAAKBB0fcCQDWGtgDgA9q0aePtEAAAAIAGR98LANV4PAIANJKBAwdq8uTJmjFjhsLDw9W2bVvNnTtXkuRyuTR37lzZ7XYFBwcrKipKkydPdn/uN39N7ODBg+rfv79CQkLUtWtX5eXl1fr3jh07pocfflgtW7ZUeHi4kpOT9b///a+BswQAAIDZ0fcCwK1jpy0ANKLMzExNmzZNu3fvVkFBgX7+85+rX79+Kisr05/+9CetXbtW3bp1U2lpqfbv31/n13A6nUpJSVFkZKR2796tsrIyTZ06tcY1V65cUWJiouLj47Vz504FBgZq/vz5Gjp0qD766CMFBQU1QrYAAAAwK/peALg1DG0BoBHFxsZqzpw5kqROnTpp6dKl2rZtmyIiItS2bVslJCTIZrPJbrerd+/edX6NrVu3qri4WJs3b1ZUVJQkaeHChRo2bJj7mnXr1snpdOpvf/ubLBaLJOnvf/+7WrZsqfz8fD3wwAMNnCkAAADMjL4XAG4Nj0cAgEYUGxtb4+/t2rXTqVOnNHr0aF26dEl33nmnHn/8ceXk5Ojq1at1fo0DBw4oOjra3bhKUnx8fI1r9u/fr0OHDql58+Zq1qyZmjVrpvDwcFVWVurTTz/1fGIAAADAdeh7AeDWsNMWABqRzWar8XeLxSKn06no6GiVlJRo69atysvL069+9Su9+OKL2r59e63PuRkXLlxQz549tWbNmlof43AHAAAANDT6XgC4NQxtAcAgQkNDNWLECI0YMUJPPfWUunTpoqKiIvXo0aPGdTExMTp27JhOnjypdu3aSZLee++9Gtf06NFD69atU0REhMLCwhotBwAAAOBG6HsB4MZ4PAIAGMCqVau0cuVK/ec//9Fnn32m1atXKzQ0VB06dKh1bUJCgu655x6lpaVp//792rlzp37729/WuOaRRx5R69atlZycrJ07d+rw4cPKz8/X5MmTdfz48cZKCwAAAKiBvhcAbg5DWwAwgJYtW2rFihXq16+fYmNjtXXrVm3atEmtWrWqda3ValVOTo4uXbqk3r176xe/+IUWLFhQ45omTZpox44dstvtSklJUUxMjCZMmKDKykp2IAAAAMBr6HsB4OZYXC6Xy9tBAAAAAAAAAACqsdMWAAAAAAAAAAyEoS0AAAAAAAAAGAhDWwAAAAAAAAAwEIa2AAAAAAAAAGAgDG0BAAAAAAAAwEAY2gIAAAAAAACAgTC0BQAAAAAAAAADYWgLAAAAAAAAAAbC0BYAAAAAAAAADIShLQAAAAAAAAAYCENbAAAAAAAAADAQhrYAAAAAAAAAYCD/B/03Pgd/8pHZAAAAAElFTkSuQmCC", - "text/plain": [ - "
" + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install healpy matplotlib seaborn &> /dev/null" ] - }, - "metadata": {}, - "output_type": "display_data" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Short comparaison between the pure JAX implementation and the CUDA implementation of the S2FFT algorithm." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "from jax import numpy as jnp\n", + "import argparse\n", + "import time\n", + "from time import perf_counter\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "\n", + "jax.config.update(\"jax_enable_x64\", True)\n", + "\n", + "from s2fft.utils.healpix_ffts import healpix_fft_jax, healpix_ifft_jax, healpix_fft_cuda, healpix_ifft_cuda\n", + "from s2fft.sampling.reindex import flm_2d_to_hp_fast, flm_hp_to_2d_fast\n", + "import numpy as np\n", + "import s2fft \n", + "from s2fft import forward , inverse\n", + "import healpy as hp\n", + "import numpy as np\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Initial Setup and Forward Transform Comparison\n", + "\n", + "This section sets up the HEALPix parameters and performs a forward spherical harmonic transform using `s2fft`'s JAX CUDA implementation, comparing the results with `healpy`." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "shape of j_alms: (48, 95)\n", + "shape of healpix_order_alms: (1176,)\n", + "MSE between j_alms and alms_healpy: (-3.690730140133011e-30+3.982002422466866e-31j)\n" + ] + } + ], + "source": [ + "# Set up\n", + "nside = 16\n", + "npix = hp.nside2npix(nside)\n", + "map_random = jax.random.normal(jax.random.key(0) , shape=npix)\n", + "\n", + "# Compute alms (spherical harmonic coefficients)\n", + "lmax = 3 * nside - 1\n", + "L = lmax + 1 # So S2FFT covers ell=0 to lmax inclusive\n", + "\n", + "# healpy alms\n", + "alms_healpy = hp.map2alm(np.array(map_random), lmax=lmax , iter=3)\n", + "alm_healpy_2d = flm_hp_to_2d_fast(alms_healpy, L=L)\n", + "\n", + "j_alms = forward(map_random, nside=nside, L=L, sampling='healpix' , method='jax_cuda' , iter=3 )\n", + "healpix_order_alms = flm_2d_to_hp_fast(j_alms, L=L)\n", + "print(f\"shape of j_alms: {j_alms.shape}\")\n", + "print(f\"shape of healpix_order_alms: {healpix_order_alms.shape}\")\n", + "\n", + "\n", + "print(f\"MSE between j_alms and alms_healpy: {jnp.mean((healpix_order_alms - alms_healpy) ** 2)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### VMAP and JAX Transforms Test\n", + "\n", + "This cell demonstrates the use of `jax.vmap` with the forward transform and tests JAX's automatic differentiation capabilities (`jacfwd`, `jacrev`) with the CUDA implementation." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Shape of maps: (4, 3072)\n" + ] + } + ], + "source": [ + "# Set up\n", + "nside = 16\n", + "npix = hp.nside2npix(nside)\n", + "map_random = jax.random.normal(jax.random.key(0) , shape=npix)\n", + "# Compute alms (spherical harmonic coefficients)\n", + "lmax = 3 * nside - 1\n", + "L = lmax + 1 # So S2FFT covers ell=0 to lmax inclusive\n", + "\n", + "maps = jnp.stack([map_random, map_random, map_random , map_random], axis=0)\n", + "print(f\"Shape of maps: {maps.shape}\")\n", + "\n", + "def forward_maps(maps):\n", + " return forward(maps, nside=nside, L=L, sampling='healpix', method='jax_cuda').real\n", + "\n", + "alm_maps = jax.vmap(forward_maps)(maps)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Inverse Transform Comparison\n", + "\n", + "This cell performs an inverse spherical harmonic transform and compares the reconstructed map from `s2fft`'s JAX CUDA implementation with `healpy`'s reconstruction." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSE between reconstruction_healpy and reconstruction_jax: (1.8236620334440454e-27-8.008792862185043e-31j)\n" + ] + } + ], + "source": [ + "reconstruction_healpy = hp.alm2map(alms_healpy, nside=nside, lmax=lmax)\n", + "reconstruction_jax = inverse(j_alms, nside=nside, L=L, sampling='healpix', method='jax_cuda')\n", + "\n", + "print(f\"MSE between reconstruction_healpy and reconstruction_jax: {jnp.mean((reconstruction_healpy - reconstruction_jax) ** 2)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Performance Benchmarking Functions\n", + "\n", + "This section defines helper functions to benchmark the forward and backward spherical harmonic transforms across different `nside` values, comparing `s2fft`'s JAX CUDA, pure JAX, and `healpy` implementations." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "sampling = \"healpix\"\n", + "n_iter = 3 # Number of iterations for the forward and inverse transforms\n", + "\n", + "def mse(x, y):\n", + " return jnp.mean(jnp.abs(x - y)**2)\n", + "\n", + "\n", + "def run_fwd_test(nside):\n", + " L = 2 * nside \n", + "\n", + " total_pixels = 12 * nside**2\n", + " arr = jax.random.normal(jax.random.PRNGKey(0), (total_pixels, ))\n", + "\n", + " method = \"jax_cuda\"\n", + " start = time.perf_counter()\n", + " cuda_res = forward(arr, L, nside=nside,sampling=sampling, method=method, iter=n_iter ).block_until_ready()\n", + " end = time.perf_counter()\n", + " cuda_jit_time = end - start\n", + "\n", + " start = time.perf_counter()\n", + " cuda_res = forward(arr, L, nside=nside,sampling=sampling, method=method, iter=n_iter ).block_until_ready()\n", + " end = time.perf_counter()\n", + " cuda_run_time = end - start\n", + "\n", + " method = \"jax\"\n", + " start = time.perf_counter()\n", + " jax_res = forward(arr, L, nside=nside,sampling=sampling, method=method, iter=n_iter ).block_until_ready()\n", + " end = time.perf_counter()\n", + " jax_jit_time = end - start\n", + "\n", + " start = time.perf_counter()\n", + " jax_res = forward(arr, L, nside=nside,sampling=sampling, method=method, iter=n_iter ).block_until_ready()\n", + " end = time.perf_counter()\n", + " jax_run_time = end - start\n", + "\n", + " method = \"jax_healpy\"\n", + " arr += 0j\n", + " arr = jax.device_put(arr, jax.devices(\"cpu\")[0])\n", + " start = time.perf_counter()\n", + " flm = s2fft.forward(arr, L, nside=nside, sampling=sampling, method=method, iter=n_iter ).block_until_ready()\n", + " end = time.perf_counter()\n", + " healpy_jit_time = end - start\n", + "\n", + " start = time.perf_counter()\n", + " flm = s2fft.forward(arr, L, nside=nside, sampling=sampling, method=method, iter=n_iter ).block_until_ready()\n", + " end = perf_counter()\n", + " healpy_run_time = end - start\n", + "\n", + " print(f\"For nside {nside}\")\n", + " print(f\" -> FWD\")\n", + " print(f\" -> -> cuda_jit_time: {cuda_jit_time:.4f}, cuda_run_time: {cuda_run_time:.4f} mse against hp {mse(cuda_res, flm)}\")\n", + " print(f\" -> -> jax_jit_time: {jax_jit_time:.4f}, jax_run_time: {jax_run_time:.4f} mse against hp {mse(cuda_res, flm)}\")\n", + " print(f\" -> -> healpy_jit_time: {healpy_jit_time:.4f}, healpy_run_time: {healpy_run_time:.4f}\")\n", + "\n", + " return cuda_jit_time , cuda_run_time, jax_jit_time, jax_run_time , healpy_jit_time, healpy_run_time\n", + "\n", + "\n", + "def run_bwd_test(nside):\n", + " \n", + " sampling = \"healpix\"\n", + " L = 2 * nside\n", + " total_pixels = 12 * nside**2\n", + " arr = jax.random.normal(jax.random.PRNGKey(0), (total_pixels, )) + 0j\n", + " alm = forward(arr, L, nside=nside, sampling=sampling, method=\"jax_healpy\")\n", + " \n", + " method = \"jax\"\n", + " start = time.perf_counter()\n", + " jax_res = inverse(alm, L, nside=nside,sampling=sampling, method=method).block_until_ready()\n", + " end = time.perf_counter()\n", + " jax_jit_time = end - start\n", + " start = time.perf_counter()\n", + " jax_res = inverse(alm, L, nside=nside,sampling=sampling, method=method ).block_until_ready()\n", + " end = time.perf_counter()\n", + " jax_run_time = end - start\n", + " \n", + " method = \"jax_cuda\"\n", + " start = time.perf_counter()\n", + " cuda_res = inverse(alm, L, nside=nside,sampling=sampling, method=method ).block_until_ready()\n", + " end = time.perf_counter()\n", + " cuda_jit_time = end - start\n", + " start = time.perf_counter()\n", + " cuda_res = inverse(alm, L, nside=nside,sampling=sampling, method=method ).block_until_ready()\n", + " end = time.perf_counter()\n", + " cuda_run_time = end - start\n", + "\n", + "\n", + " method = \"jax_healpy\"\n", + " sampling = \"healpix\"\n", + "\n", + " alm = jax.device_put(alm, jax.devices(\"cpu\")[0])\n", + " start = time.perf_counter()\n", + " f = inverse(alm, L, nside=nside, sampling=sampling, method=method).block_until_ready()\n", + " end = time.perf_counter()\n", + " healpy_jit_time = end - start\n", + "\n", + " start = time.perf_counter()\n", + " f = inverse(alm, L, nside=nside, sampling=sampling, method=method ).block_until_ready()\n", + " end = time.perf_counter()\n", + " healpy_run_time = end - start\n", + "\n", + " print(f\"For nside {nside}\")\n", + " print(f\" -> BWD\")\n", + " print(f\" -> -> cuda_jit_time: {cuda_jit_time:.4f}, cuda_run_time: {cuda_run_time:.4f} mse against hp {mse(cuda_res, f)}\")\n", + " print(f\" -> -> jax_jit_time: {jax_jit_time:.4f}, jax_run_time: {jax_run_time:.4f} mse against hp {mse(jax_res, f)}\")\n", + " print(f\" -> -> healpy_jit_time: {healpy_jit_time:.4f}, healpy_run_time: {healpy_run_time:.4f} \")\n", + "\n", + " return cuda_jit_time , cuda_run_time, jax_jit_time, jax_run_time , healpy_jit_time, healpy_run_time" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Clear JAX Caches\n", + "\n", + "Clears JAX's internal caches to ensure fresh compilation for benchmarking." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "jax.clear_caches()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run Benchmarking\n", + "\n", + "Executes the benchmarking functions for various `nside` values to collect performance data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "For nside 128\n", + " -> FWD\n", + " -> -> cuda_jit_time: 4.4200, cuda_run_time: 0.6231 mse against hp 2.3766630166715178e-29\n", + " -> -> jax_jit_time: 38.6306, jax_run_time: 0.6253 mse against hp 2.3766630166715178e-29\n", + " -> -> healpy_jit_time: 0.8766, healpy_run_time: 0.4540\n", + "For nside 128\n", + " -> BWD\n", + " -> -> cuda_jit_time: 1.3143, cuda_run_time: 0.0907 mse against hp 2.5339123457221976e-25\n", + " -> -> jax_jit_time: 15.6730, jax_run_time: 0.1263 mse against hp 2.5339096506006936e-25\n", + " -> -> healpy_jit_time: 0.0512, healpy_run_time: 0.0041 \n", + "For nside 256\n", + " -> FWD\n", + " -> -> cuda_jit_time: 8.7759, cuda_run_time: 4.6370 mse against hp 4.332503429570958e-10\n", + " -> -> jax_jit_time: 88.8303, jax_run_time: 4.6417 mse against hp 4.332503429570958e-10\n", + " -> -> healpy_jit_time: 2.5950, healpy_run_time: 1.7487\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mXlaRuntimeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mFile \u001b[39m\u001b[32m~/micromamba/envs/s2fft/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py:2795\u001b[39m, in \u001b[36m_cached_compilation\u001b[39m\u001b[34m(computation, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, allow_prop_to_inputs, allow_prop_to_outputs, host_callbacks, backend, da, pmap_nreps, compiler_options_kvs, pgle_profiler)\u001b[39m\n\u001b[32m 2792\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m dispatch.log_elapsed_time(\n\u001b[32m 2793\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mFinished XLA compilation of \u001b[39m\u001b[38;5;132;01m{fun_name}\u001b[39;00m\u001b[33m in \u001b[39m\u001b[38;5;132;01m{elapsed_time:.9f}\u001b[39;00m\u001b[33m sec\u001b[39m\u001b[33m\"\u001b[39m,\n\u001b[32m 2794\u001b[39m fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):\n\u001b[32m-> \u001b[39m\u001b[32m2795\u001b[39m xla_executable = \u001b[43mcompiler\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcompile_or_get_cached\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 2796\u001b[39m \u001b[43m \u001b[49m\u001b[43mbackend\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcomputation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdev\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcompile_options\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhost_callbacks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2797\u001b[39m \u001b[43m \u001b[49m\u001b[43mpgle_profiler\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 2798\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m xla_executable\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/micromamba/envs/s2fft/lib/python3.11/site-packages/jax/_src/compiler.py:432\u001b[39m, in \u001b[36mcompile_or_get_cached\u001b[39m\u001b[34m(backend, computation, devices, compile_options, host_callbacks, pgle_profiler)\u001b[39m\n\u001b[32m 431\u001b[39m log_persistent_cache_miss(module_name, cache_key)\n\u001b[32m--> \u001b[39m\u001b[32m432\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_compile_and_write_cache\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 433\u001b[39m \u001b[43m \u001b[49m\u001b[43mbackend\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 434\u001b[39m \u001b[43m \u001b[49m\u001b[43mcomputation\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 435\u001b[39m \u001b[43m \u001b[49m\u001b[43mcompile_options\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 436\u001b[39m \u001b[43m \u001b[49m\u001b[43mhost_callbacks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 437\u001b[39m \u001b[43m \u001b[49m\u001b[43mmodule_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 438\u001b[39m \u001b[43m \u001b[49m\u001b[43mcache_key\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 439\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/micromamba/envs/s2fft/lib/python3.11/site-packages/jax/_src/compiler.py:694\u001b[39m, in \u001b[36m_compile_and_write_cache\u001b[39m\u001b[34m(backend, computation, compile_options, host_callbacks, module_name, cache_key)\u001b[39m\n\u001b[32m 693\u001b[39m start_time = time.monotonic()\n\u001b[32m--> \u001b[39m\u001b[32m694\u001b[39m executable = \u001b[43mbackend_compile\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 695\u001b[39m \u001b[43m \u001b[49m\u001b[43mbackend\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcomputation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcompile_options\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhost_callbacks\u001b[49m\n\u001b[32m 696\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 697\u001b[39m compile_time = time.monotonic() - start_time\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/micromamba/envs/s2fft/lib/python3.11/site-packages/jax/_src/profiler.py:334\u001b[39m, in \u001b[36mannotate_function..wrapper\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 333\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m TraceAnnotation(name, **decorator_kwargs):\n\u001b[32m--> \u001b[39m\u001b[32m334\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/micromamba/envs/s2fft/lib/python3.11/site-packages/jax/_src/compiler.py:330\u001b[39m, in \u001b[36mbackend_compile\u001b[39m\u001b[34m(backend, module, options, host_callbacks)\u001b[39m\n\u001b[32m 329\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m handler_result \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01me\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m330\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m e\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/micromamba/envs/s2fft/lib/python3.11/site-packages/jax/_src/compiler.py:324\u001b[39m, in \u001b[36mbackend_compile\u001b[39m\u001b[34m(backend, module, options, host_callbacks)\u001b[39m\n\u001b[32m 321\u001b[39m \u001b[38;5;66;03m# Some backends don't have `host_callbacks` option yet\u001b[39;00m\n\u001b[32m 322\u001b[39m \u001b[38;5;66;03m# TODO(sharadmv): remove this fallback when all backends allow `compile`\u001b[39;00m\n\u001b[32m 323\u001b[39m \u001b[38;5;66;03m# to take in `host_callbacks`\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m324\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mbackend\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcompile\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbuilt_c\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcompile_options\u001b[49m\u001b[43m=\u001b[49m\u001b[43moptions\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 325\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m xc.XlaRuntimeError \u001b[38;5;28;01mas\u001b[39;00m e:\n", + "\u001b[31mXlaRuntimeError\u001b[39m: INTERNAL: ptxas exited with non-zero error code 2, output: ", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[19]\u001b[39m\u001b[32m, line 6\u001b[39m\n\u001b[32m 4\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m nside \u001b[38;5;129;01min\u001b[39;00m nsides:\n\u001b[32m 5\u001b[39m fwd_times.append(run_fwd_test(nside))\n\u001b[32m----> \u001b[39m\u001b[32m6\u001b[39m bwd_times.append(\u001b[43mrun_bwd_test\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnside\u001b[49m\u001b[43m)\u001b[49m)\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[17]\u001b[39m\u001b[32m, line 68\u001b[39m, in \u001b[36mrun_bwd_test\u001b[39m\u001b[34m(nside)\u001b[39m\n\u001b[32m 66\u001b[39m method = \u001b[33m\"\u001b[39m\u001b[33mjax\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 67\u001b[39m start = time.perf_counter()\n\u001b[32m---> \u001b[39m\u001b[32m68\u001b[39m jax_res = \u001b[43minverse\u001b[49m\u001b[43m(\u001b[49m\u001b[43malm\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mL\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnside\u001b[49m\u001b[43m=\u001b[49m\u001b[43mnside\u001b[49m\u001b[43m,\u001b[49m\u001b[43msampling\u001b[49m\u001b[43m=\u001b[49m\u001b[43msampling\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m=\u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m)\u001b[49m.block_until_ready()\n\u001b[32m 69\u001b[39m end = time.perf_counter()\n\u001b[32m 70\u001b[39m jax_jit_time = end - start\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Projects/CMB/s2fft/s2fft/transforms/spherical.py:110\u001b[39m, in \u001b[36minverse\u001b[39m\u001b[34m(flm, L, spin, nside, sampling, method, reality, precomps, spmd, L_lower, _ssht_backend)\u001b[39m\n\u001b[32m 107\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 108\u001b[39m inverse_kwargs[\u001b[33m\"\u001b[39m\u001b[33mnside\u001b[39m\u001b[33m\"\u001b[39m] = nside\n\u001b[32m--> \u001b[39m\u001b[32m110\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_inverse_functions\u001b[49m\u001b[43m[\u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m]\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43minverse_kwargs\u001b[49m\u001b[43m)\u001b[49m\n", + " \u001b[31m[... skipping hidden 1 frame]\u001b[39m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/micromamba/envs/s2fft/lib/python3.11/site-packages/jax/_src/pjit.py:340\u001b[39m, in \u001b[36m_cpp_pjit..cache_miss\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 335\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m config.no_tracing.value:\n\u001b[32m 336\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mre-tracing function \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mjit_info.fun_sourceinfo\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m for \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 337\u001b[39m \u001b[33m\"\u001b[39m\u001b[33m`jit`, but \u001b[39m\u001b[33m'\u001b[39m\u001b[33mno_tracing\u001b[39m\u001b[33m'\u001b[39m\u001b[33m is set\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 339\u001b[39m (outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked, executable,\n\u001b[32m--> \u001b[39m\u001b[32m340\u001b[39m pgle_profiler) = \u001b[43m_python_pjit_helper\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mjit_info\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 342\u001b[39m maybe_fastpath_data = _get_fastpath_data(\n\u001b[32m 343\u001b[39m executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects,\n\u001b[32m 344\u001b[39m jaxpr.consts, jit_info.abstracted_axes,\n\u001b[32m 345\u001b[39m pgle_profiler)\n\u001b[32m 347\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/micromamba/envs/s2fft/lib/python3.11/site-packages/jax/_src/pjit.py:191\u001b[39m, in \u001b[36m_python_pjit_helper\u001b[39m\u001b[34m(fun, jit_info, *args, **kwargs)\u001b[39m\n\u001b[32m 189\u001b[39m args_flat = \u001b[38;5;28mmap\u001b[39m(core.full_lower, args_flat)\n\u001b[32m 190\u001b[39m core.check_eval_args(args_flat)\n\u001b[32m--> \u001b[39m\u001b[32m191\u001b[39m out_flat, compiled, profiler = \u001b[43m_pjit_call_impl_python\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs_flat\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mp\u001b[49m\u001b[43m.\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 192\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 193\u001b[39m out_flat = pjit_p.bind(*args_flat, **p.params)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/micromamba/envs/s2fft/lib/python3.11/site-packages/jax/_src/pjit.py:1809\u001b[39m, in \u001b[36m_pjit_call_impl_python\u001b[39m\u001b[34m(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs, *args)\u001b[39m\n\u001b[32m 1797\u001b[39m compiler_options_kvs = compiler_options_kvs + \u001b[38;5;28mtuple\u001b[39m(pgle_compile_options.items())\n\u001b[32m 1798\u001b[39m \u001b[38;5;66;03m# Passing mutable PGLE profile here since it should be extracted by JAXPR to\u001b[39;00m\n\u001b[32m 1799\u001b[39m \u001b[38;5;66;03m# initialize the fdo_profile compile option.\u001b[39;00m\n\u001b[32m 1800\u001b[39m compiled = \u001b[43m_resolve_and_lower\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 1801\u001b[39m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mjaxpr\u001b[49m\u001b[43m=\u001b[49m\u001b[43mjaxpr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43min_shardings\u001b[49m\u001b[43m=\u001b[49m\u001b[43min_shardings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1802\u001b[39m \u001b[43m \u001b[49m\u001b[43mout_shardings\u001b[49m\u001b[43m=\u001b[49m\u001b[43mout_shardings\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43min_layouts\u001b[49m\u001b[43m=\u001b[49m\u001b[43min_layouts\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1803\u001b[39m \u001b[43m \u001b[49m\u001b[43mout_layouts\u001b[49m\u001b[43m=\u001b[49m\u001b[43mout_layouts\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdonated_invars\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdonated_invars\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1804\u001b[39m \u001b[43m \u001b[49m\u001b[43mctx_mesh\u001b[49m\u001b[43m=\u001b[49m\u001b[43mctx_mesh\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m=\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeep_unused\u001b[49m\u001b[43m=\u001b[49m\u001b[43mkeep_unused\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1805\u001b[39m \u001b[43m \u001b[49m\u001b[43minline\u001b[49m\u001b[43m=\u001b[49m\u001b[43minline\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlowering_platforms\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m 1806\u001b[39m \u001b[43m \u001b[49m\u001b[43mlowering_parameters\u001b[49m\u001b[43m=\u001b[49m\u001b[43mmlir\u001b[49m\u001b[43m.\u001b[49m\u001b[43mLoweringParameters\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1807\u001b[39m \u001b[43m \u001b[49m\u001b[43mpgle_profiler\u001b[49m\u001b[43m=\u001b[49m\u001b[43mpgle_profiler\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1808\u001b[39m \u001b[43m \u001b[49m\u001b[43mcompiler_options_kvs\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcompiler_options_kvs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m-> \u001b[39m\u001b[32m1809\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcompile\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1811\u001b[39m \u001b[38;5;66;03m# This check is expensive so only do it if enable_checks is on.\u001b[39;00m\n\u001b[32m 1812\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m compiled._auto_spmd_lowering \u001b[38;5;129;01mand\u001b[39;00m config.enable_checks.value:\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/micromamba/envs/s2fft/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py:2462\u001b[39m, in \u001b[36mMeshComputation.compile\u001b[39m\u001b[34m(self, compiler_options)\u001b[39m\n\u001b[32m 2460\u001b[39m compiler_options_kvs = \u001b[38;5;28mself\u001b[39m._compiler_options_kvs + t_compiler_options\n\u001b[32m 2461\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._executable \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m compiler_options_kvs:\n\u001b[32m-> \u001b[39m\u001b[32m2462\u001b[39m executable = \u001b[43mUnloadedMeshExecutable\u001b[49m\u001b[43m.\u001b[49m\u001b[43mfrom_hlo\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 2463\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_hlo\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mcompile_args\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2464\u001b[39m \u001b[43m \u001b[49m\u001b[43mcompiler_options_kvs\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcompiler_options_kvs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 2465\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m compiler_options_kvs:\n\u001b[32m 2466\u001b[39m \u001b[38;5;28mself\u001b[39m._executable = executable\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/micromamba/envs/s2fft/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py:3004\u001b[39m, in \u001b[36mUnloadedMeshExecutable.from_hlo\u001b[39m\u001b[34m(***failed resolving arguments***)\u001b[39m\n\u001b[32m 3001\u001b[39m \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[32m 3003\u001b[39m util.test_event(\u001b[33m\"\u001b[39m\u001b[33mpxla_cached_compilation\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m-> \u001b[39m\u001b[32m3004\u001b[39m xla_executable = \u001b[43m_cached_compilation\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 3005\u001b[39m \u001b[43m \u001b[49m\u001b[43mhlo\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmesh\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mspmd_lowering\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3006\u001b[39m \u001b[43m \u001b[49m\u001b[43mtuple_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mauto_spmd_lowering\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mallow_prop_to_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3007\u001b[39m \u001b[43m \u001b[49m\u001b[43mallow_prop_to_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mtuple\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mhost_callbacks\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbackend\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mda\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpmap_nreps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 3008\u001b[39m \u001b[43m \u001b[49m\u001b[43mcompiler_options_kvs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpgle_profiler\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 3010\u001b[39m orig_out_shardings = out_shardings\n\u001b[32m 3012\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m auto_spmd_lowering:\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/micromamba/envs/s2fft/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py:2792\u001b[39m, in \u001b[36m_cached_compilation\u001b[39m\u001b[34m(computation, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, allow_prop_to_inputs, allow_prop_to_outputs, host_callbacks, backend, da, pmap_nreps, compiler_options_kvs, pgle_profiler)\u001b[39m\n\u001b[32m 2785\u001b[39m compiler_options = \u001b[38;5;28mdict\u001b[39m(compiler_options_kvs)\n\u001b[32m 2787\u001b[39m compile_options = create_compile_options(\n\u001b[32m 2788\u001b[39m computation, mesh, spmd_lowering, tuple_args, auto_spmd_lowering,\n\u001b[32m 2789\u001b[39m allow_prop_to_inputs, allow_prop_to_outputs, backend,\n\u001b[32m 2790\u001b[39m dev, pmap_nreps, compiler_options)\n\u001b[32m-> \u001b[39m\u001b[32m2792\u001b[39m \u001b[43m\u001b[49m\u001b[38;5;28;43;01mwith\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mdispatch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mlog_elapsed_time\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 2793\u001b[39m \u001b[43m \u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mFinished XLA compilation of \u001b[39;49m\u001b[38;5;132;43;01m{fun_name}\u001b[39;49;00m\u001b[33;43m in \u001b[39;49m\u001b[38;5;132;43;01m{elapsed_time:.9f}\u001b[39;49;00m\u001b[33;43m sec\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 2794\u001b[39m \u001b[43m \u001b[49m\u001b[43mfun_name\u001b[49m\u001b[43m=\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mevent\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdispatch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mBACKEND_COMPILE_EVENT\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 2795\u001b[39m \u001b[43m \u001b[49m\u001b[43mxla_executable\u001b[49m\u001b[43m \u001b[49m\u001b[43m=\u001b[49m\u001b[43m \u001b[49m\u001b[43mcompiler\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcompile_or_get_cached\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 2796\u001b[39m \u001b[43m \u001b[49m\u001b[43mbackend\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcomputation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdev\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcompile_options\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhost_callbacks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2797\u001b[39m \u001b[43m \u001b[49m\u001b[43mpgle_profiler\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 2798\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m xla_executable\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/micromamba/envs/s2fft/lib/python3.11/site-packages/jax/_src/dispatch.py:183\u001b[39m, in \u001b[36mLogElapsedTimeContextManager.__exit__\u001b[39m\u001b[34m(self, exc_type, exc_value, traceback)\u001b[39m\n\u001b[32m 180\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__enter__\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[32m 181\u001b[39m \u001b[38;5;28mself\u001b[39m.start_time = time.time()\n\u001b[32m--> \u001b[39m\u001b[32m183\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__exit__\u001b[39m(\u001b[38;5;28mself\u001b[39m, exc_type, exc_value, traceback):\n\u001b[32m 184\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m _on_exit:\n\u001b[32m 185\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m\n", + "\u001b[31mKeyboardInterrupt\u001b[39m: " + ] + }, + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n", + "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n", + "\u001b[1;31mClick here for more info. \n", + "\u001b[1;31mView Jupyter log for further details." + ] + } + ], + "source": [ + "fwd_times = []\n", + "bwd_times = []\n", + "nsides = [4 , 8 , 16 , 32 , 64 , 128 , 256 ]\n", + "for nside in nsides:\n", + " fwd_times.append(run_fwd_test(nside))\n", + " bwd_times.append(run_bwd_test(nside))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plotting Utility\n", + "\n", + "This cell defines a utility function to plot the compilation and execution times obtained from the benchmarking tests." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import seaborn as sns\n", + "sns.plotting_context(\"poster\")\n", + "sns.set(font_scale=1.4)\n", + "\n", + "\n", + "def plot_times(title, nsides, chrono_times):\n", + "\n", + " # Extracting times from the chrono_times\n", + " cuda_jit_times = [times[0] for times in chrono_times]\n", + " cuda_run_times = [times[1] for times in chrono_times]\n", + " jax_jit_times = [times[2] for times in chrono_times]\n", + " jax_run_times = [times[3] for times in chrono_times]\n", + " healpy_jit_times = [times[4] for times in chrono_times]\n", + " healpy_run_times = [times[5] for times in chrono_times]\n", + "\n", + " # Create subplots\n", + " fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 7))\n", + "\n", + " f2 = lambda a: np.log2(a)\n", + " g2 = lambda b: b**2\n", + "\n", + "\n", + " # Plot for JIT times\n", + " ax1.plot(nsides, cuda_jit_times, 'g-o', label='ours')\n", + " ax1.plot(nsides, jax_jit_times, 'b-o', label='s2fft base')\n", + " ax1.plot(nsides, healpy_jit_times, 'r-o', label='Healpy')\n", + " ax1.set_title('Compilation Times (first run)')\n", + " ax1.set_xlabel('nside')\n", + " ax1.set_ylabel('Time (seconds)')\n", + " ax1.set_xscale('function', functions=(f2, g2))\n", + " ax1.set_xticks(nsides)\n", + " ax1.set_xticklabels(nsides)\n", + " ax1.legend()\n", + " ax1.grid(True, which=\"both\", ls=\"--\")\n", + "\n", + " # Plot for Run times\n", + " ax2.plot(nsides, cuda_run_times, 'g-o', label='ours')\n", + " ax2.plot(nsides, jax_run_times, 'b-o', label='s2fft base')\n", + " ax2.plot(nsides, healpy_run_times, 'r-o', label='Healpy')\n", + " ax2.set_title('Execution Times')\n", + " ax2.set_xlabel('nside')\n", + " ax2.set_ylabel('Time (seconds)')\n", + " ax2.set_xscale('function', functions=(f2, g2))\n", + " ax2.set_xticks(nsides)\n", + " ax2.set_xticklabels(nsides)\n", + " ax2.legend()\n", + " ax2.grid(True, which=\"both\", ls=\"--\")\n", + "\n", + " # Set the overall title for the figure\n", + " fig.suptitle(title, fontsize=16)\n", + "\n", + " # Show the plots\n", + " plt.tight_layout(rect=[0, 0, 1, 0.96]) # Adjust rect to make space for the suptitle\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize Performance Results\n", + "\n", + "This cell calls the plotting function to visualize the benchmark results for forward and backward transforms." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_times(\"Forward FFT Times\", nsides, fwd_times)\n", + "plot_times(\"Backward FFT Times\", nsides, bwd_times)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Final Reconstruction and Error Check\n", + "\n", + "This cell performs a final inverse transform to reconstruct the map and calculates the Mean Squared Error (MSE) against the `healpy` reconstructed map to verify accuracy." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "shape of map_reconstructed: (3072,)\n", + "Mean Squared Error between reconstructed map and healpy map: (1.8236620334440454e-27-8.008792862185043e-31j)\n" + ] + } + ], + "source": [ + "# Test backward transform\n", + "map_reconstructed = inverse(j_alms, nside=nside, L=L, sampling='healpix', method='jax_cuda')\n", + "print(f\"shape of map_reconstructed: {map_reconstructed.shape}\")\n", + "hp_reconstructed = hp.alm2map(alms_healpy, nside=nside, lmax=lmax)\n", + "\n", + "# Compute the mean squared error between the two maps\n", + "mse = jnp.mean((map_reconstructed - hp_reconstructed) ** 2)\n", + "print(f\"Mean Squared Error between reconstructed map and healpy map: {mse}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" } - ], - "source": [ - "\n", - "plot_times(\"Forward FFT Times\", nsides, fwd_times)\n", - "plot_times(\"Backward FFT Times\", nsides, bwd_times)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.4" - } - }, - "nbformat": 4, - "nbformat_minor": 2 + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/pyproject.toml b/pyproject.toml index c797c2b1..304adba2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,8 +2,9 @@ requires = [ "setuptools", "setuptools-scm", - "scikit-build-core >=0.4.3", - "nanobind >=1.3.2" + "scikit-build-core >=0.11", + "nanobind >=2.0,<2.6", + "jax >= 0.4.0" ] build-backend = "scikit_build_core.build" @@ -28,7 +29,7 @@ classifiers = [ description = "Differentiable and accelerated spherical transforms with JAX" dependencies = [ "numpy>=1.20", - "jax>=0.3.13,<0.6.0", + "jax>=0.3.13", "jaxlib", ] dynamic = [ @@ -81,11 +82,12 @@ torch = [ [tool.scikit-build] # Protect the configuration against future changes in scikit-build-core -minimum-version = "0.4" +minimum-version = "0.8" # Setuptools-style build caching in a local directory build-dir = "build/{wheel_tag}" # Build stable ABI wheels for CPython 3.12+ wheel.py-api = "cp312" +cmake.build-type = "Release" metadata.version.provider = "scikit_build_core.metadata.setuptools_scm" sdist.include = ["s2fft/_version.py"] diff --git a/s2fft/transforms/c_backend_spherical.py b/s2fft/transforms/c_backend_spherical.py index 4ef7ae68..394bfb89 100644 --- a/s2fft/transforms/c_backend_spherical.py +++ b/s2fft/transforms/c_backend_spherical.py @@ -7,6 +7,7 @@ import jax.numpy as jnp import numpy as np from jax import core, custom_vjp +from jax.extend.core import Primitive from jax.interpreters import ad from s2fft.sampling import reindex @@ -342,7 +343,7 @@ def _healpy_map2alm_transpose(dflm: jnp.ndarray, L: int, nside: int): return (jnp.conj(healpy_alm2map(jnp.conj(dflm) / scale_factors, L, nside)),) -_healpy_map2alm_p = core.Primitive("healpy_map2alm") +_healpy_map2alm_p = Primitive("healpy_map2alm") _healpy_map2alm_p.def_impl(_healpy_map2alm_impl) _healpy_map2alm_p.def_abstract_eval(_healpy_map2alm_abstract_eval) ad.deflinear(_healpy_map2alm_p, _healpy_map2alm_transpose) @@ -397,7 +398,7 @@ def _healpy_alm2map_transpose(df: jnp.ndarray, L: int, nside: int) -> tuple: return (scale_factors * jnp.conj(healpy_map2alm(jnp.conj(df), L, nside)),) -_healpy_alm2map_p = core.Primitive("healpy_alm2map") +_healpy_alm2map_p = Primitive("healpy_alm2map") _healpy_alm2map_p.def_impl(_healpy_alm2map_impl) _healpy_alm2map_p.def_abstract_eval(_healpy_alm2map_abstract_eval) ad.deflinear(_healpy_alm2map_p, _healpy_alm2map_transpose) diff --git a/s2fft/utils/healpix_ffts.py b/s2fft/utils/healpix_ffts.py index b5a5efc8..eead27fe 100644 --- a/s2fft/utils/healpix_ffts.py +++ b/s2fft/utils/healpix_ffts.py @@ -1,14 +1,14 @@ from functools import partial +import jax import jax.numpy as jnp -import jaxlib.mlir.ir as ir import numpy as np from jax import jit, vmap # did not find promote_dtypes_complex outside _src from jax._src.numpy.util import promote_dtypes_complex -from jax.lib import xla_client -from jaxlib.hlo_helpers import custom_call +from jax.core import ShapedArray +from jax.interpreters import batching from s2fft_lib import _s2fft from s2fft.sampling import s2_samples as samples @@ -537,49 +537,29 @@ def ring_phase_shifts_hp_jax( phi_offsets = p2phi_rings_jax(t, nside) sign = -1 if forward else 1 m_start_ind = 0 if reality else -L + 1 + # Step 5: Calculate the exponent for the phase shifts using JAX einsum. exponent = jnp.einsum( "t, m->tm", phi_offsets, jnp.arange(m_start_ind, L), optimize=True ) + # Step 6: Return the complex exponential of the exponent. return jnp.exp(sign * 1j * exponent) # Custom healpix_fft_cuda primitive -def _healpix_fft_cuda_abstract(f, L, nside, reality, fft_type, norm): - # For the forward pass, the input is a HEALPix pixel-space array of size nside^2 * - # 12 and the output is a FTM array of shape (number of rings , width of FTM slice) - # which is (4 * nside - 1 , 2 * L ) - healpix_size = (nside**2 * 12,) - ftm_size = (4 * nside - 1, 2 * L) - if fft_type == "forward": - assert f.shape == healpix_size - return f.update(shape=ftm_size, dtype=f.dtype) - elif fft_type == "backward": - print(f"f.shape {f.shape}") - assert f.shape == ftm_size - return f.update(shape=healpix_size, dtype=f.dtype) - else: - raise ValueError(f"fft_type {fft_type} not recognised.") - - -def _healpix_fft_cuda_lowering(ctx, f, *, L, nside, reality, fft_type, norm): - (aval_out,) = ctx.avals_out - a_type = ir.RankedTensorType(f.type) - - out_dtype = aval_out.dtype +def _get_lowering_info(fft_type, norm, out_dtype): + # Step 1: Determine if double precision is used based on output dtype. if out_dtype == np.complex64: - out_type = ir.ComplexType.get(ir.F32Type.get()) is_double = False elif out_dtype == np.complex128: - out_type = ir.ComplexType.get(ir.F64Type.get()) is_double = True else: raise ValueError(f"Unknown output type {out_dtype}") - out_type = ir.RankedTensorType.get(aval_out.shape, out_type) - + # Step 2: Determine if it's a forward transform. forward = fft_type == "forward" + # Step 3: Determine if normalization should be applied. if (forward and norm == "backward") or (not forward and norm == "forward"): normalize = False elif (forward and norm == "forward") or (not forward and norm == "backward"): @@ -587,34 +567,243 @@ def _healpix_fft_cuda_lowering(ctx, f, *, L, nside, reality, fft_type, norm): else: raise ValueError(f"Unknown norm {norm}") - descriptor = _s2fft.build_healpix_fft_descriptor( - nside, L, reality, forward, normalize, is_double + # Step 4: Return the determined flags. + return is_double, forward, normalize + + +def _healpix_fft_cuda_abstract(f, L, nside, reality, fft_type, norm, adjoint): + """ + Abstract evaluation for the HEALPix FFT CUDA primitive. + This function defines the output shapes and dtypes for the JAX primitive. + + Args: + f: Input array. + L: Harmonic band-limit. + nside: HEALPix Nside resolution parameter. + reality: Whether the signal is real. + fft_type: Type of FFT ("forward" or "backward"). + norm: Normalization type. + adjoint: Whether it's an adjoint operation. + + Returns: + Tuple of ShapedArray objects for output, workspace, and callback parameters. + + """ + # Step 1: Get lowering information (double precision, forward/backward, normalize). + is_double, forward, normalize = _get_lowering_info(fft_type, norm, f.dtype) + + # Step 2: Determine workspace size and type based on precision. + if is_double: + # For double precision, build descriptor for C128 and calculate workspace size. + worksize = _s2fft.build_descriptor_C128( + nside, L, reality, forward, normalize, adjoint + ) + worksize //= 16 # 16 bytes per C128 element + workspace_shape = (worksize,) + workspace_dtype = np.complex128 + else: + # For single precision, build descriptor for C64 and calculate workspace size. + worksize = _s2fft.build_descriptor_C64( + nside, L, reality, forward, normalize, adjoint + ) + worksize //= 8 # 8 bytes per C64 element + workspace_shape = (worksize,) + workspace_dtype = np.complex64 + # Step 3: Define output shapes based on FFT type. + healpix_size = (nside**2 * 12,) + ftm_size = (4 * nside - 1, 2 * L) + if fft_type == "forward": + batch_shape = (f.shape[0],) if f.ndim == 2 else () + out_shape = batch_shape + ftm_size + assert (f.shape[-1],) == healpix_size + + elif fft_type == "backward": + batch_shape = (f.shape[0],) if f.ndim == 3 else () + out_shape = batch_shape + healpix_size + assert f.shape[-2:] == ftm_size + else: + raise ValueError(f"fft_type {fft_type} not recognised.") + + # Step 4: Create ShapedArray objects for output and workspace. + workspace_aval = ShapedArray( + shape=batch_shape + workspace_shape, dtype=workspace_dtype + ) + + # Step 5: Return the ShapedArray objects. + return ( + f.update(shape=out_shape, dtype=f.dtype), + workspace_aval, + ) + + +class MissingCUDASupport(Exception): # noqa : D107 + def __init__(self): # noqa : D107 + super().__init__(""" + S2FFT was compiled without CUDA support. Cuda functions are not supported. + Please make sure that nvcc is in your path and $CUDA_HOME is set then reinstall s2fft using pip. + """) + + +def _healpix_fft_cuda_lowering(ctx, f, *, L, nside, reality, fft_type, norm, adjoint): + """ + Lowering rule for the HEALPix FFT CUDA primitive. + This function translates the JAX primitive call into a call to the underlying CUDA FFI. + + Args: + ctx: Lowering context. + f: Input array. + L: Harmonic band-limit. + nside: HEALPix Nside resolution parameter. + reality: Whether the signal is real. + fft_type: Type of FFT ("forward" or "backward"). + norm: Normalization type. + adjoint: Whether it's an adjoint operation. + + Returns: + The result of the FFI call. + + """ + # Step 1: Check if CUDA support is compiled in. + if not _s2fft.COMPILED_WITH_CUDA: + raise MissingCUDASupport() + + # Step 2: Get the abstract evaluation results for the outputs. + (aval_out, _) = ctx.avals_out + + # Step 3: Get lowering information (double precision, forward/backward, normalize). + is_double, forward, normalize = _get_lowering_info(fft_type, norm, aval_out.dtype) + + # Step 4: Select the appropriate FFI lowering function based on precision. + if is_double: + ffi_lowered = jax.ffi.ffi_lowering("healpix_fft_cuda_c128") + else: + ffi_lowered = jax.ffi.ffi_lowering("healpix_fft_cuda_c64") + + # Step 5: Call the FFI lowering function with the context and parameters. + return ffi_lowered( + ctx, + f, + nside=nside, + harmonic_band_limit=L, + reality=reality, + normalize=normalize, + forward=forward, + adjoint=adjoint, + ) + + +def _healpix_fft_cuda_batching_rule( + batched_args, batched_axis, L, nside, reality, fft_type, norm, adjoint +): + """ + Batching rule for the HEALPix FFT CUDA primitive. + This function defines how the primitive behaves under JAX's automatic batching. + + Args: + batched_args: Tuple of batched arguments. + batched_axis: Tuple of axes along which arguments are batched. + L: Harmonic band-limit. + nside: HEALPix Nside resolution parameter. + reality: Whether the signal is real. + fft_type: Type of FFT ("forward" or "backward"). + norm: Normalization type. + adjoint: Whether it's an adjoint operation. + + Returns: + Tuple of (output, output_batch_axes). + + """ + # Step 1: Unpack batched arguments and batching axes. + (x,) = batched_args + (bd,) = batched_axis + + # Step 2: Assert correct input dimensions based on FFT type. + if fft_type == "forward": + assert x.ndim == 2 + elif fft_type == "backward": + assert x.ndim == 3 + else: + raise ValueError(f"fft_type {fft_type} not recognised.") + + # Step 3: Move the batching axis to the front. + x = batching.moveaxis(x, bd, 0) + + # Step 4: Bind the primitive with the batched input. + out = _healpix_fft_cuda_primitive.bind( + x, + L=L, + nside=nside, + reality=reality, + fft_type=fft_type, + norm=norm, + adjoint=adjoint, ) + # Step 5: Define batching axes for the outputs (all at axis 0). + batchout = (0,) * len(out) + + # Step 6: Return the output and their batching axes. + return out, batchout - layout = tuple(range(len(a_type.shape) - 1, -1, -1)) - out_layout = tuple(range(len(out_type.shape) - 1, -1, -1)) - - result = custom_call( - "healpix_fft_cuda", - result_types=[out_type], - operands=[f], - operand_layouts=[layout], - result_layouts=[out_layout], - has_side_effect=True, - backend_config=descriptor, + +def _healpix_fft_cuda_transpose( + df: jnp.ndarray, + L: int, + nside: int, + reality: bool, + fft_type: str, + norm: str, + adjoint: bool, +) -> jnp.ndarray: + """ + Transpose rule for the HEALPix FFT CUDA primitive. + This function defines how the adjoint of the primitive is computed for automatic differentiation. + + Args: + df: Tangent (gradient) of the output. + L: Harmonic band-limit. + nside: HEALPix Nside resolution parameter. + reality: Whether the signal is real. + fft_type: Type of FFT ("forward" or "backward"). + norm: Normalization type. + adjoint: Whether it's an adjoint operation. + + Returns: + The adjoint of the input. + + """ + # Step 1: Invert the FFT type and normalization for the adjoint operation. + fft_type = "backward" if fft_type == "forward" else "forward" + norm = "backward" if norm == "forward" else "forward" + + # Step 2: Bind the primitive with the tangent and inverted parameters. + # Access df[0] as df is a tuple of tangents for multiple outputs. + # Return [0] as the primitive also returns multiple outputs, and we only need the first one for the adjoint. + return ( + _healpix_fft_cuda_primitive.bind( + df[0], + L=L, + nside=nside, + reality=reality, + fft_type=fft_type, + norm=norm, + adjoint=not adjoint, + )[0], ) - return result.results # Register healpfix_fft_cuda custom call target for name, fn in _s2fft.registration().items(): - xla_client.register_custom_call_target(name, fn, platform="gpu") + jax.ffi.register_ffi_target(name, fn, platform="CUDA") +# Step 1: Register the HEALPix FFT CUDA primitive with JAX. _healpix_fft_cuda_primitive = register_primitive( "healpix_fft_cuda", - multiple_results=False, + multiple_results=True, # Indicates that the primitive returns multiple outputs. abstract_evaluation=_healpix_fft_cuda_abstract, lowering_per_platform={None: _healpix_fft_cuda_lowering}, + transpose=_healpix_fft_cuda_transpose, + batcher=_healpix_fft_cuda_batching_rule, + is_linear=True, ) @@ -642,10 +831,20 @@ def healpix_fft_cuda( jnp.ndarray: Array of Fourier coefficients for all latitudes. """ + # Step 1: Promote input data to complex dtype if necessary. (f,) = promote_dtypes_complex(f) - return _healpix_fft_cuda_primitive.bind( - f, L=L, nside=nside, reality=reality, fft_type="forward", norm=norm + # Step 2: Bind the input to the CUDA primitive. It returns multiple outputs (out, workspace). + out, _ = _healpix_fft_cuda_primitive.bind( + f, + L=L, + nside=nside, + reality=reality, + fft_type="forward", + norm=norm, + adjoint=False, ) + # Step 3: Return only the primary output (Fourier coefficients). + return out @partial(jit, static_argnums=(1, 2, 3)) @@ -672,10 +871,20 @@ def healpix_ifft_cuda( jnp.ndarray: HEALPix pixel-space array. """ + # Step 1: Promote input data to complex dtype if necessary. (ftm,) = promote_dtypes_complex(ftm) - return _healpix_fft_cuda_primitive.bind( - ftm, L=L, nside=nside, reality=reality, fft_type="backward", norm=norm + # Step 2: Bind the input to the CUDA primitive. It returns multiple outputs (out, workspace). + out, _ = _healpix_fft_cuda_primitive.bind( + ftm, + L=L, + nside=nside, + reality=reality, + fft_type="backward", + norm=norm, + adjoint=False, ) + # Step 3: Return only the primary output (pixel-space array). + return out _healpix_fft_functions = { diff --git a/s2fft/utils/jax_primitive.py b/s2fft/utils/jax_primitive.py index 6aac7c72..66c6822e 100644 --- a/s2fft/utils/jax_primitive.py +++ b/s2fft/utils/jax_primitive.py @@ -1,7 +1,7 @@ from functools import partial from typing import Callable, Dict, Optional, Union -from jax import core +from jax.extend import core from jax.interpreters import ad, batching, mlir, xla @@ -13,36 +13,79 @@ def register_primitive( batcher: Optional[Callable] = None, jacobian_vector_product: Optional[Callable] = None, transpose: Optional[Callable] = None, + is_linear: bool = False, ): """ Register a new custom JAX primitive. + This function provides a streamlined way to register custom JAX primitives, + including their implementation, abstract evaluation, lowering rules for different + platforms, and optional rules for batching and automatic differentiation. + Args: - name: Name for primitive. - multiple_results: Whether primitive returns multiple values. - abstract_evaluation: Abstract evaluation rule for primitive. - lowering_per_platform: Dictionary mapping from platform names (or `None` for - platform-independent) to lowering rules. - batcher: Optional batched evaluation rule for primitive. - jacobian_vector_product: Optional Jacobian vector product for primitive for - forward-mode automatic differentiation. - transpose: Optional rule for evaluation transpose rule for primitive for - reverse-mode automatic differentiation. + name (str): The name of the primitive. + multiple_results (bool): A boolean indicating whether the primitive returns multiple values. + abstract_evaluation (Callable): A callable that defines the abstract evaluation rule for the primitive. + It should take `ShapedArray` instances as inputs and return `ShapedArray` instances for the outputs. + lowering_per_platform (Dict[Union[None, str], Callable]): A dictionary mapping platform names + (e.g., "cpu", "gpu", or None for platform-independent) to their respective lowering rules. + A lowering rule translates the primitive into a sequence of MLIR operations. + batcher (Optional[Callable]): An optional callable that defines the batched evaluation rule for the primitive. + This is used by JAX's automatic batching (vmap). + jacobian_vector_product (Optional[Callable]): An optional callable that defines the Jacobian-vector product + (JVP) rule for the primitive. This is used for forward-mode automatic differentiation. + transpose (Optional[Callable]): An optional callable that defines the transpose rule for the primitive. + This is used for reverse-mode automatic differentiation (autograd). + is_linear (bool): A boolean indicating whether the primitive is linear. If True and a `transpose` rule + is provided, `ad.deflinear` is used, which can optimize linear operations. Returns: - Registered custom primtive. + jax.core.Primitive: The registered custom JAX primitive object. + + Raises: + ValueError: If an invalid platform is specified in `lowering_per_platform`. """ + # Step 1: Create a new JAX primitive with the given name. primitive = core.Primitive(name) + + # Step 2: Set the `multiple_results` attribute of the primitive. primitive.multiple_results = multiple_results + + # Step 3: Define the default implementation of the primitive using `xla.apply_primitive`. + # This means that by default, the primitive will be lowered to XLA. primitive.def_impl(partial(xla.apply_primitive, primitive)) + + # Step 4: Register the abstract evaluation rule for the primitive. + # This rule tells JAX how to infer the shape and dtype of the primitive's outputs + # given its inputs, without actually executing the computation. primitive.def_abstract_eval(abstract_evaluation) + + # Step 5: Register lowering rules for the primitive across different platforms. + # This step defines how the primitive is translated into lower-level operations + # (e.g., MLIR) for execution on specific hardware (CPU, GPU, etc.). for platform, lowering in lowering_per_platform.items(): mlir.register_lowering(primitive, lowering, platform=platform) + + # Step 6: Register the batching rule if provided. + # The batching rule enables JAX's `vmap` transformation to work with this primitive. if batcher is not None: batching.primitive_batchers[primitive] = batcher + + # Step 7: Register the Jacobian-vector product (JVP) rule if provided. + # The JVP rule is essential for forward-mode automatic differentiation. if jacobian_vector_product is not None: ad.primitive_jvps[primitive] = jacobian_vector_product + + # Step 8: Register the transpose rule if provided. + # The transpose rule is crucial for reverse-mode automatic differentiation (autograd). if transpose is not None: - ad.primitive_transposes[primitive] = transpose + if is_linear: + # If the primitive is linear, use `ad.deflinear` for optimized transpose registration. + ad.deflinear(primitive, transpose) + else: + # Otherwise, use `ad.primitive_transposes` for general transpose registration. + ad.primitive_transposes[primitive] = transpose + + # Step 9: Return the newly registered primitive. return primitive diff --git a/tests/test_healpix_ffts.py b/tests/test_healpix_ffts.py index 3ab75d9b..82969062 100644 --- a/tests/test_healpix_ffts.py +++ b/tests/test_healpix_ffts.py @@ -1,10 +1,12 @@ import healpy as hp import jax +import jax.numpy as jnp import numpy as np import pytest from numpy.testing import assert_allclose from packaging.version import Version as _Version +import s2fft from s2fft.sampling import s2_samples as samples from s2fft.utils.healpix_ffts import ( healpix_fft_cuda, @@ -92,3 +94,96 @@ def test_healpix_ifft_cuda(flm_generator, nside): atol=1e-7, rtol=1e-7, ) + + +@pytest.mark.skipif(not gpu_available, reason="GPU not available") +@pytest.mark.parametrize("nside", nside_to_test) +def test_healpix_fft_cuda_transforms(flm_generator, nside): + L = 2 * nside + + # Generate a random bandlimited signal + def generate_flm(): + flm = flm_generator(L=L, reality=False) + f = s2fft.inverse( + flm, L=L, nside=nside, reality=False, method="jax", sampling="healpix" + ) + return f + + f_stacked = jnp.stack([generate_flm() for _ in range(10)], axis=0) + + def healpix_jax(f): + return healpix_fft_jax(f, L, nside, False).real + + def healpix_cuda(f): + return healpix_fft_cuda(f, L, nside, False).real + + f = f_stacked[0] + # Test VMAP + assert_allclose( + jax.vmap(healpix_jax)(f_stacked), + jax.vmap(healpix_cuda)(f_stacked), + atol=1e-7, + rtol=1e-7, + ) + # test jacfwd + assert_allclose( + jax.jacfwd(healpix_jax)(f.real), + jax.jacfwd(healpix_cuda)(f.real), + atol=1e-7, + rtol=1e-7, + ) + # test jacrev + assert_allclose( + jax.jacrev(healpix_jax)(f.real), + jax.jacrev(healpix_cuda)(f.real), + atol=1e-7, + rtol=1e-7, + ) + + +@pytest.mark.skipif(not gpu_available, reason="GPU not available") +@pytest.mark.parametrize("nside", nside_to_test) +def test_healpix_ifft_cuda_transforms(flm_generator, nside): + L = 2 * nside + + # Generate a random bandlimited signal + def generate_flm(): + flm = flm_generator(L=L, reality=False) + f = s2fft.inverse( + flm, L=L, nside=nside, reality=False, method="jax", sampling="healpix" + ) + ftm = healpix_fft_jax(f, L, nside, False) + return ftm + + ftm_stacked = jnp.stack([generate_flm() for _ in range(10)], axis=0) + ftm = ftm_stacked[0].real + + def healpix_inv_jax(ftm): + return healpix_ifft_jax(ftm, L, nside, False).real + + def healpix_inv_cuda(ftm): + return healpix_ifft_cuda(ftm, L, nside, False).real + + # Test VMAP + assert_allclose( + jax.vmap(healpix_inv_jax)(ftm_stacked).flatten(), + jax.vmap(healpix_inv_cuda)(ftm_stacked).flatten(), + atol=1e-7, + rtol=1e-7, + ) + + # test jacfwd + assert_allclose( + jax.jacfwd(healpix_inv_jax)(ftm.real).flatten(), + jax.jacfwd(healpix_inv_cuda)(ftm.real).flatten(), + atol=1e-7, + rtol=1e-7, + ) + + # test jacrev + assert_allclose( + jax.jacrev(healpix_inv_jax)(ftm.real).flatten(), + jax.jacrev(healpix_inv_cuda)(ftm.real).flatten(), + atol=1e-7, + rtol=1e-7, + )