// Patch summary: defines adamw, apply_penalty, dest_index_copy_kv, rms_norm,
// rms_norm_backward, apply_rotary and two example functions in
// csrc/extensions.cpp and registers them under the `deeplink_ext_` torch
// library; adds a second setup() call building `deeplink_ext_ops` in setup.py;
// adds test_dispatch.py, which loads the extension .so and calls the `example` op.
// Review fixes applied in place (hunk line counts are unchanged):
//  - rms_norm is registered only when diopiRMSNorm is available; it was
//    previously guarded by diopiDestIndexCopyKV.
//  - the dest_index_copy_kv schema lists its arguments in the order of the C++
//    function (k, dest_loc, out), so the Tensor(a!) annotation lands on `out`.
//  - extra_compile_args passes '-g' without surrounding spaces.
// NOTE(review): text inside angle brackets (#include targets, template
// arguments) appears to be missing from this copy of the patch - confirm
// against the original before applying.
diff --git a/csrc/extensions.cpp b/csrc/extensions.cpp index bab6b7bc..5aa6abf3 100644 --- a/csrc/extensions.cpp +++ b/csrc/extensions.cpp @@ -1,10 +1,12 @@ // Copyright (c) 2023, DeepLink. #include +#include #include #include #include +#include "torch/library.h" #include #include #include @@ -363,4 +365,138 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { } } +std::tuple adamw( + at::Tensor& param, at::Tensor& exp_avg, at::Tensor& exp_avg_sq, + const c10::optional& max_exp_avg_sq_opt, const at::Tensor& grad, + double lr, double beta1, double beta2, double epsilon, double weight_decay, + int64_t step, bool amsgrad) { + // the diopiAdamW func has no "maximize" param + at::Tensor& grad_ref = + const_cast(grad); // todo: grad is const value + at::Tensor max_exp_avg_sq_opt_value = + max_exp_avg_sq_opt.value_or(at::Tensor()); + callDiopi(diopiAdamW, param, grad_ref, exp_avg, exp_avg_sq, + max_exp_avg_sq_opt_value, lr, beta1, beta2, epsilon, weight_decay, + step, amsgrad); + return std::tie(param, exp_avg, exp_avg_sq); +} + +at::Tensor& apply_penalty(at::Tensor& logits, + const at::Tensor& presence_penalty, + const at::Tensor& frequency_penalty, + const at::Tensor& p_token_ids, + const at::Tensor& p_token_counts, + const at::Tensor& p_cumsum_seq_len, + int64_t p_max_len_in_batch) { + callDiopi(diopiApplyPenalty, logits, presence_penalty, frequency_penalty, + p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch); + return logits; +} + +at::Tensor& dest_index_copy_kv(const at::Tensor& k, const at::Tensor& dest_loc, + at::Tensor& out) { + callDiopi(diopiDestIndexCopyKV, out, k, dest_loc); + return out; +} + +std::tuple rms_norm( + at::Tensor& output, at::Tensor& inv_rms, const at::Tensor& input, + const OptionalIntArray& normalized_shape, const at::Tensor& weight, + const c10::optional& bias_opt, double eps) { + callDiopi(diopiRMSNorm, output, inv_rms, input, normalized_shape, weight, + bias_opt, eps); + return std::tie(output, inv_rms); +} + +std::tuple 
rms_norm_backward( + at::Tensor& grad_input, at::Tensor& grad_weight, at::Tensor& grad_bias_opt, + const at::Tensor& grad_output, const at::Tensor& input, + const at::Tensor& weight, const c10::optional& bias_opt, + const at::Tensor& inv_rms, const OptionalIntArray& normalized_shape, + double eps) { + callDiopi(diopiRMSNormBackward, grad_input, grad_weight, grad_bias_opt, + grad_output, input, weight, bias_opt, inv_rms, normalized_shape, + eps); + return std::tie(grad_input, grad_weight, grad_bias_opt); +} + +at::Tensor& apply_rotary(at::Tensor& output, const at::Tensor& input, + const at::Tensor& cos, const at::Tensor& sin, + const bool conj, const bool interleaved) { + callDiopi(diopiRotaryEmbedding, output, input, cos, sin, conj, interleaved); + return output; +} + +at::Tensor& example_for_all_backend(at::Tensor& inout) { + std::cout << __FUNCTION__ << ": " << inout.options() << "\n"; + return inout; +} + +at::Tensor& example_only_for_xpu(at::Tensor& inout) { + std::cout << __FUNCTION__ << ": " << inout.options() << "\n"; + return inout; +} + +// By default, all backends (XPU, AutocastXPU, AutoGradXPU, CUDA, PrivateUse1, +// AutogradPrivateUse1 etc) are registered. If you need to register separately +// for a certain backend, separate registration for a certain backend is also +// supported. +TORCH_LIBRARY(deeplink_ext_, m) { + if (&diopiAdamW != nullptr) { + m.def( + "adamw(Tensor(a!) param, Tensor(b!) exp_avg, Tensor(c!) exp_avg_sq, " + "Tensor? max_exp_avg_sq_opt, Tensor grad, float lr, float beta1, float " + "beta2, float epsilon, float weight_decay, int step, bool " + "amsgrad)->(Tensor(a!), Tensor(b!), Tensor(c!))", + adamw); + } + if (&diopiApplyPenalty != nullptr) { + m.def( + "apply_penalty(Tensor(a!) 
logits, Tensor presence_penalty, Tensor " + "frequency_penalty, Tensor p_token_ids, Tensor p_token_counts, Tensor " + "p_cumsum_seq_len, int p_max_len_in_batch)->Tensor(a!)", + apply_penalty); + } + if (&diopiDestIndexCopyKV != nullptr) { + m.def( + "dest_index_copy_kv(Tensor k, Tensor dest_loc, Tensor(a!) " + "out)->Tensor(a!)", + dest_index_copy_kv); + } + if (&diopiRMSNorm != nullptr) { + m.def( + "rms_norm(Tensor(a!) output, Tensor(b!) inv_rms, Tensor input, int[]? " + "normalized_shape, Tensor weight, Tensor? bias_opt, float eps) -> " + "(Tensor(a!), Tensor(b!))", + rms_norm); + } + + if (&diopiRMSNormBackward != nullptr) { + m.def( + "rms_norm_backward(Tensor(a!) grad_input, Tensor(b!) grad_weight, " + "Tensor(c!) grad_bias_opt, Tensor grad_output, Tensor input, Tensor " + "weight, Tensor? bias_opt, Tensor inv_rms, int[]? normalized_shape, " + "float eps) -> (Tensor(a!), Tensor(b!), Tensor(c!))", + rms_norm_backward); + } + if (&diopiRotaryEmbedding != nullptr) { + m.def( + "apply_rotary(Tensor(a!) output, Tensor input, Tensor cos, Tensor sin, " + "bool conj, bool interleaved) -> Tensor(a!)", + apply_rotary); + } + + m.def("example(Tensor(a!) inout)->Tensor(a!)", example_for_all_backend); +} + +// only impl for dipu +TORCH_LIBRARY_IMPL(deeplink_ext_, XPU, m) { + // m.impl("example", example_only_for_xpu); +} + +int n = [](){ + std::cout << "deeplink_ext_ loaded" << std::endl; + return 0; +}(); + } // namespace dipu::dipu_ext diff --git a/setup.py b/setup.py index 0f7fb640..352ca023 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,8 @@ # Copyright (c) 2024, DeepLink. 
from setuptools import find_packages, setup, Extension -from torch.utils.cpp_extension import BuildExtension, include_paths, library_paths +from torch.utils.cpp_extension import BuildExtension, CppExtension, include_paths, library_paths + import glob import os import subprocess @@ -86,3 +87,16 @@ def get_ext(): cmdclass={"build_ext": BuildExtensionWithCompdb}, install_requires=["einops"], ) + + +setup( + name='deeplink_ext_ops', + ext_modules=[ + CppExtension( + name='deeplink_ext_ops', + sources=glob.glob("./csrc/*.cpp"), + extra_compile_args=['-g']), + ], + cmdclass={ + 'build_ext': BuildExtension + }) \ No newline at end of file diff --git a/test_dispatch.py b/test_dispatch.py new file mode 100644 index 00000000..4e2b62f5 --- /dev/null +++ b/test_dispatch.py @@ -0,0 +1,26 @@ +import torch +import torch_dipu +import deeplink_ext + +so_path = deeplink_ext.__path__[0] + "/cpp_extensions.cpython-39-x86_64-linux-gnu.so" +torch.ops.load_library(so_path) +print(f"torch.ops.loaded_libraries:{torch.ops.loaded_libraries}") + +#print(torch.ops.deeplink_ext_.dest_index_copy_kv) + +def code_to_profile(): + x = torch.randn(3,4) + y = torch.ops.deeplink_ext_.example(x) + y = torch.ops.deeplink_ext_.example(x.cuda()) + + +with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ] +) as p: + code_to_profile() +print(p.key_averages().table( + sort_by="self_cuda_time_total", row_limit=-1)) +