diff --git a/backends/mediatek/preprocess.py b/backends/mediatek/preprocess.py index 6ffaebb7589..b2a79dafabe 100644 --- a/backends/mediatek/preprocess.py +++ b/backends/mediatek/preprocess.py @@ -4,14 +4,16 @@ # except in compliance with the License. See the license file in the root # directory of this source tree for more details. +import collections import contextlib import struct -from typing import final, List +from typing import Dict, final, List import mtk_converter import mtk_neuron import torch +from executorch.exir._serialize._named_data_store import NamedDataStore from executorch.exir.backend.backend_details import ( BackendDetails, ExportedProgram, @@ -20,6 +22,9 @@ from executorch.exir.backend.compile_spec_schema import CompileSpec SKIP_COMPILE_SPEC_KEYS = {"ImportForever"} +EXTRACT_SHARED_BLOB_KEY = "ExtractSharedBlobKey" +HEADER_SIZE = 13 +HEADER_VERSION = 1 REQUIRED_COMPILE_SPEC_KEYS = {"platform-config"} SUPPORTED_PLATFORM_CONFIGS = {"mt6989", "mt6991"} @@ -41,6 +46,21 @@ def assert_default_dim_order(edge_graph_module: torch.fx.GraphModule) -> None: ) +def _pack_header(num_inputs, num_outputs, model_bytes_size): + header_bytes = struct.pack( + " Dict[str, list[PreprocessResult]]: + + # Follow the default behavior of `preprocess_multimethod` + preprocess_results = {} + for method_name, programs in edge_programs.items(): + assert ( + method_name in compile_specs + ), f"Error: missing compile specs for {method_name}" + compile_specs_for_method = compile_specs[method_name] + assert len(compile_specs_for_method) == len( + programs + ), f"Error: method {method_name} has {len(programs)} partitions but only {len(compile_specs_for_method)}" + results_for_method = [] + for program, compile_spec_for_program in zip( + programs, compile_specs_for_method + ): + preprocess_result = cls.preprocess(program, compile_spec_for_program) + results_for_method.append(preprocess_result) + + preprocess_results[method_name] = results_for_method + + # Try extract shared data blob if necessary + infos_dict = collections.defaultdict(list) + models_dict = collections.defaultdict(list) + result_dict = collections.defaultdict(list) + for method_name, method_results in preprocess_results.items(): + for idx, result in enumerate(method_results): + shared_blob_key = None + for spec in compile_specs[method_name][idx]: + if spec.key == EXTRACT_SHARED_BLOB_KEY: + shared_blob_key = spec.value.decode("utf-8") + + if shared_blob_key is None: + continue + + header_bytes = result.processed_bytes[:HEADER_SIZE] + model_bytes = result.processed_bytes[HEADER_SIZE:] + num_inputs, num_outputs, model_bytes_size = _unpack_header(header_bytes) + assert len(model_bytes) == model_bytes_size + infos_dict[shared_blob_key].append((num_inputs, num_outputs)) + models_dict[shared_blob_key].append(model_bytes) + result_dict[shared_blob_key].append(result) + + data_store_output_dict = {} + for key, models in models_dict.items(): + ndm = NamedDataStore() + blob, new_models = mtk_neuron.extract_shared_data( + models, options="-e union" + ) + ndm.add_named_data(key, bytes(blob)) + data_store_output_dict[key] = ndm.get_named_data_store_output() + models.clear() + models.extend(new_models) + + for key, data_store_output in data_store_output_dict.items(): + for idx, (model_info, model_bytes) in enumerate( + zip(infos_dict[key], models_dict[key]) + ): + num_inputs, num_outputs = model_info + header_bytes = _pack_header(num_inputs, num_outputs, len(model_bytes)) + result_dict[key][idx].data_store_output = data_store_output + 
result_dict[key][idx].processed_bytes = bytes( + header_bytes + model_bytes + ) + + return preprocess_results diff --git a/backends/mediatek/runtime/NeuronBackend.cpp b/backends/mediatek/runtime/NeuronBackend.cpp index 6319089dd3d..7b4084f66b8 100644 --- a/backends/mediatek/runtime/NeuronBackend.cpp +++ b/backends/mediatek/runtime/NeuronBackend.cpp @@ -12,8 +12,8 @@ #include "NeuronPayloadHeader.h" #include "api/NeuronAdapter.h" +#include #include "executorch/runtime/core/error.h" -#include "executorch/runtime/core/exec_aten/util/dim_order_util.h" #include #include @@ -24,6 +24,7 @@ namespace executorch { namespace backends { namespace neuron { +using executorch::ET_RUNTIME_NAMESPACE::NamedDataMap; using executorch::runtime::ArrayRef; using executorch::runtime::BackendExecutionContext; using executorch::runtime::BackendInitContext; @@ -38,12 +39,22 @@ using executorch::runtime::Span; const char kHighAddrKey[] = "HighAddr"; const char kImportForeverKey[] = "ImportForever"; +const char kSharedWeightsKey[] = "ExtractSharedBlobKey"; Result NeuronBackend::init( BackendInitContext& context, FreeableBuffer* processed, ArrayRef compile_specs) const { NeuronDelegateSetting setting; + MemoryAllocator* runtime_allocator = context.get_runtime_allocator(); + NeuronExecuTorchDelegate* delegate = + runtime_allocator->allocateInstance(); + if (delegate == nullptr) { + return Error::MemoryAllocationFailed; + } + + new (delegate) NeuronExecuTorchDelegate(); + for (auto& compile_spec : compile_specs) { if (std::strcmp(compile_spec.key, kHighAddrKey) == 0) { setting.mHighAddr = *static_cast(compile_spec.value.buffer); @@ -54,11 +65,62 @@ Result NeuronBackend::init( "NeuronBackend", "IsImportForever Enable : %d", setting.mImportForever); + } else if (std::strcmp(compile_spec.key, kSharedWeightsKey) == 0) { + setting.mSharedWeights = true; + std::string shared_weights_key( + static_cast(compile_spec.value.buffer), + compile_spec.value.nbytes); + LogInfo( + "NeuronBackend", + "SharedWeights Enabled for %s", + shared_weights_key.c_str()); + std::shared_ptr neuron_shared_weights; + if (neuron_shared_weights_cache_.find(shared_weights_key) != + neuron_shared_weights_cache_.end()) { + neuron_shared_weights = + neuron_shared_weights_cache_.at(shared_weights_key).lock(); + if (neuron_shared_weights) { + LogInfo( + "NeuronBackend", + "Reusing cached shared weights with key %s", + shared_weights_key.c_str()); + delegate->SetSharedWeights(neuron_shared_weights); + continue; + } else { + LogInfo( + "NeuronBackend", + "Shared weights cache expired: %s", + shared_weights_key.c_str()); + neuron_shared_weights_cache_.erase(shared_weights_key); // Expired + } + } + const NamedDataMap* named_data_map = context.get_named_data_map(); + Result shared_weights = + named_data_map->get_data(shared_weights_key.c_str()); + + if (shared_weights.ok()) { + LogInfo( + "NeuronBackend", + "Loaded shared weights from named_data_map. 
Size: %zu", + shared_weights.get().size()); + FreeableBuffer& buffer = shared_weights.get(); + neuron_shared_weights = + std::make_shared(std::move(buffer)); + delegate->SetSharedWeights(neuron_shared_weights); + neuron_shared_weights_cache_[shared_weights_key] = + neuron_shared_weights; + } else { + LogError( + "NeuronBackend", + "Failed to load shared weights from named_data_map."); + return Error::Internal; + } } else { LogWarn("NeuronBackend", "unknown compile spec: %s", compile_spec.key); } } auto Payload = NeuronPayload(processed->data(), processed->size()); + LogInfo( "NeuronBackend", "version %u, input %u, output %u, length %u, payload size: %zu", @@ -68,19 +130,7 @@ Result NeuronBackend::init( Payload.Header.DataLen, processed->size()); - MemoryAllocator* runtime_allocator = context.get_runtime_allocator(); - NeuronExecuTorchDelegate* delegate = - runtime_allocator->allocateInstance(); - if (delegate == nullptr) { - return Error::MemoryAllocationFailed; - } - - new (delegate) NeuronExecuTorchDelegate(); - - if (delegate == nullptr) { - return nullptr; - } - auto res = delegate->LoadCompiledNetwork(Payload, setting); + int res = delegate->LoadCompiledNetwork(Payload, setting); return res == NEURON_NO_ERROR ? delegate : nullptr; } @@ -112,21 +162,22 @@ Error NeuronExecuTorchDelegate::execute( return Error::InvalidState; }; + ET_CHECK_OR_RETURN_ERROR( + CheckDimOrder(args) == NEURON_NO_ERROR, + Internal, + "Expecting default dim_order but got a non default dim_order tensor input"); + + PrepareInputsOuputs(args); + auto allocator = dynamic_cast(context.get_temp_allocator()); - size_t inputCount = mInputSizes.size(), outputCount = mOutputSizes.size(); - - for (int i = 0; i < inputCount; i++) { - auto tensor_in = args[i]->toTensor(); - ET_CHECK_OR_RETURN_ERROR( - runtime::is_contiguous_dim_order( - tensor_in.dim_order().data(), tensor_in.dim()), - Internal, - "Expecting default dim_order but got a non default dim_order tensor for external input %u", - i); - - auto data_ptr = args[i]->toTensor().data_ptr(); - auto data_size = args[i]->toTensor().nbytes(); + + size_t inputCount = mInputSizes.size() + neuron_shared_weights_.size(); + size_t outputCount = mOutputSizes.size(); + + for (size_t i = 0; i < inputCount; i++) { + auto data_ptr = mPreparedInputs[i].data_ptr; + auto data_size = mPreparedInputs[i].size; if (IsCached(i, data_ptr)) { continue; }; @@ -141,22 +192,20 @@ Error NeuronExecuTorchDelegate::execute( } } - for (int o = inputCount; o < inputCount + outputCount; o++) { - auto data_ptr = args[o]->toTensor().data_ptr(); - auto data_size = args[o]->toTensor().nbytes(); - auto output_index = o - inputCount; - if (IsCached(output_index, data_ptr)) { + for (size_t o = 0; o < outputCount; o++) { + auto data_ptr = mPreparedOutputs[o].data_ptr; + auto data_size = mPreparedOutputs[o].size; + if (IsCached(o, data_ptr)) { continue; }; auto unit = allocator != nullptr ? 
allocator->Find(data_ptr) : nullptr; if (unit) { - UpdateCache(output_index, data_ptr); + UpdateCache(o, data_ptr); size_t offset = (char*)data_ptr - (char*)unit->GetAddress(); mExecutor.SetInputOutputFromMemory( - output_index, unit->GetNeuronMemory(), offset, data_size); + o, unit->GetNeuronMemory(), offset, data_size); } else { - mExecutor.SetInputOutput( - output_index, data_ptr, data_size); + mExecutor.SetInputOutput(o, data_ptr, data_size); } } diff --git a/backends/mediatek/runtime/include/NeuronBackend.h b/backends/mediatek/runtime/include/NeuronBackend.h index 529b11d48ee..1d2e8563ab3 100644 --- a/backends/mediatek/runtime/include/NeuronBackend.h +++ b/backends/mediatek/runtime/include/NeuronBackend.h @@ -18,6 +18,7 @@ #include #include #include +#include "executorch/runtime/core/exec_aten/util/dim_order_util.h" #include #include @@ -27,6 +28,50 @@ namespace executorch { namespace backends { namespace neuron { +using executorch::runtime::EValue; +using executorch::runtime::FreeableBuffer; +using executorch::runtime::Result; +using executorch::runtime::Span; + +class NeuronSharedWeights { + public: + explicit NeuronSharedWeights(const FreeableBuffer& shared_weights_buffer) { + auto& buffer_allocator = GET_NEURON_ALLOCATOR; + nbytes_ = shared_weights_buffer.size(); + data_ = buffer_allocator.Allocate(nbytes_); + ET_CHECK_MSG( + data_ != nullptr, + "Error: Failed to allocate memory for shared weights of size %zu", + nbytes_); + std::memcpy(data_, shared_weights_buffer.data(), nbytes_); + } + + explicit NeuronSharedWeights(FreeableBuffer&& shared_weights_buffer) + : NeuronSharedWeights(shared_weights_buffer) { + shared_weights_buffer.Free(); + } + + ~NeuronSharedWeights() { + if (data_ == nullptr || nbytes_ == 0) { + return; + } + auto& buffer_allocator = GET_NEURON_ALLOCATOR; + buffer_allocator.RemoveBuffer(data_); + } + + void* data() const { + return data_; + } + + size_t size() const { + return nbytes_; + } + + private: + void* data_ = nullptr; + size_t nbytes_ = 0; +}; + class NeuronBackend final : public ::executorch::runtime::BackendInterface { public: ::executorch::runtime::Result<::executorch::runtime::DelegateHandle*> init( @@ -44,6 +89,10 @@ class NeuronBackend final : public ::executorch::runtime::BackendInterface { void destroy(::executorch::runtime::DelegateHandle* handle) const override; bool is_available() const override; + + private: + mutable std::unordered_map> + neuron_shared_weights_cache_; }; extern const char kHighAddrKey[]; @@ -54,6 +103,8 @@ struct NeuronDelegateSetting { bool mImportForever = false; + bool mSharedWeights = false; + std::string ToRuntimeOption() { if (mHighAddr && mImportForever) { return "--apusys-config \"{ \\\"high_addr\\\": true, \\\"import_forever\\\": true }\""; @@ -69,6 +120,13 @@ struct NeuronDelegateSetting { class NeuronExecuTorchDelegate { public: + struct InputOutputInfo { + void* data_ptr; + size_t size; + + InputOutputInfo(void* ptr, size_t sz) : data_ptr(ptr), size(sz) {} + }; + class MemoryCache { public: template @@ -104,16 +162,22 @@ class NeuronExecuTorchDelegate { auto res = mExecutor.LoadFromCompiledNetwork( payload.CompiledNetwork, payload.Header.DataLen, - payload.Header.InputCount, + mSettings.mSharedWeights ? 
payload.Header.InputCount + 1 + : payload.Header.InputCount, payload.Header.OutputCount, runtimeOption); CHECK_NO_ERROR(res); CHECK_TRUE(mExecutor.IsValid()); - SummaryIoCounts(); + SummarizeIoSizes(payload.Header.InputCount, payload.Header.OutputCount); mPLock = std::unique_ptr(new ScopePerformancer); return NEURON_NO_ERROR; } + int SetSharedWeights(std::shared_ptr sharedWeights) { + neuron_shared_weights_.push_back(sharedWeights); + return NEURON_NO_ERROR; + } + ::executorch::runtime::Error execute( ET_UNUSED ::executorch::runtime::BackendExecutionContext& context, ::executorch::runtime::Span<::executorch::runtime::EValue*> args) const; @@ -129,19 +193,19 @@ class NeuronExecuTorchDelegate { mCache.UpdateCache(index, ptr); } - int SummaryIoCounts() { - for (int i = 0;; i++) { + int SummarizeIoSizes(uint32_t input_count, uint32_t output_count) { + for (int i = 0; i < input_count; i++) { size_t size = mExecutor.GetInputOutputPaddedSize(i); if (size == 0) { - break; + LogWarn("NeuronBackend", "Model input:%d got size: %lu", i, size); } LogInfo("NeuronBackend", "Model input:%d size: %lu", i, size); mInputSizes.push_back(size); } - for (int o = 0;; o++) { + for (int o = 0; o < output_count; o++) { size_t size = mExecutor.GetInputOutputPaddedSize(o); if (size == 0) { - break; + LogWarn("NeuronBackend", "Model output:%d got size: %lu", o, size); } LogInfo("NeuronBackend", "Model output:%d size: %lu", o, size); mOutputSizes.push_back(size); @@ -149,14 +213,70 @@ class NeuronExecuTorchDelegate { return NEURON_NO_ERROR; } - int HintNeuronBackend( - ::executorch::runtime::Span<::executorch::runtime::EValue*> args) const; + int CheckDimOrder(Span args) const { + size_t data_input_count = mInputSizes.size(); + for (int i = 0; i < data_input_count; i++) { + auto tensor_in = args[i]->toTensor(); + LogInfo("NeuronBackend", "Checking dim order for input %d", i); + if (!runtime::is_contiguous_dim_order( + tensor_in.dim_order().data(), tensor_in.dim())) { + return NEURON_BAD_DATA; + } + } + + return NEURON_NO_ERROR; + } + + int PrepareInputsOuputs(Span args) const { + bool has_shared_weights_input = neuron_shared_weights_.size() > 0; + + size_t data_input_count = mInputSizes.size(); + size_t data_output_count = mOutputSizes.size(); + + mPreparedInputs.clear(); + mPreparedOutputs.clear(); + mPreparedInputs.reserve(data_input_count); + mPreparedOutputs.reserve(data_output_count); + + // Prepare input data + for (int i = 0; i < data_input_count; i++) { + auto tensor_in = args[i]->toTensor(); + auto data_ptr = tensor_in.data_ptr(); + auto data_size = tensor_in.nbytes(); + mPreparedInputs.push_back(InputOutputInfo{data_ptr, data_size}); + } + + // Prepare shared weights if any as the last model inputs + if (has_shared_weights_input) { + for (const auto& shared_weights : neuron_shared_weights_) { + mPreparedInputs.push_back( + InputOutputInfo{shared_weights->data(), shared_weights->size()}); + } + } + + // Prepare output data + for (int o = data_input_count; o < data_input_count + data_output_count; + o++) { + auto tensor_out = args[o]->toTensor(); + auto data_ptr = tensor_out.data_ptr(); + auto data_size = tensor_out.nbytes(); + mPreparedOutputs.push_back(InputOutputInfo{data_ptr, data_size}); + } + + return NEURON_NO_ERROR; + } + + int HintNeuronBackend(Span args) const; private: std::vector mInputSizes; std::vector mOutputSizes; + mutable std::vector mPreparedInputs; + + mutable std::vector mPreparedOutputs; + mutable MemoryCache mCache; std::unique_ptr mPLock; @@ -167,6 +287,9 @@ class 
NeuronExecuTorchDelegate { mutable std::unordered_set mHasImported; + mutable std::vector> + neuron_shared_weights_; + private: NeuronExecuTorchDelegate(const NeuronExecuTorchDelegate&); diff --git a/examples/mediatek/aot_utils/llm_utils/utils.py b/examples/mediatek/aot_utils/llm_utils/utils.py index a7c242559a1..e7fd039bb9e 100644 --- a/examples/mediatek/aot_utils/llm_utils/utils.py +++ b/examples/mediatek/aot_utils/llm_utils/utils.py @@ -336,10 +336,10 @@ def generate_mask( return combined_mask.copy() -def get_dest_path(output_folder, exp_name, shape, chunk_idx): - dest_folder_root = output_folder + f"_{shape}" +def get_dest_path(output_folder, exp_name, shape=None, chunk_idx=0): + dest_folder_root = output_folder + f"{f'_{shape}' if shape is not None else ''}" os.makedirs(dest_folder_root, exist_ok=True) - fname = f"{exp_name}_{shape}_{chunk_idx}.pte" + fname = f"{exp_name}{f'_{shape}' if shape is not None else ''}_{chunk_idx}.pte" dest_path = os.path.join(dest_folder_root, fname) return dest_path diff --git a/examples/mediatek/executor_runner/llama_runner/LlamaConfig.h b/examples/mediatek/executor_runner/llama_runner/LlamaConfig.h index f512d59b5c5..549c5e3f3cc 100644 --- a/examples/mediatek/executor_runner/llama_runner/LlamaConfig.h +++ b/examples/mediatek/executor_runner/llama_runner/LlamaConfig.h @@ -40,6 +40,7 @@ struct LlamaModelPaths { std::string token_embedding_path; std::vector prompt_model_paths; std::vector gen_model_paths; + std::vector model_package_paths; }; } // namespace example diff --git a/examples/mediatek/executor_runner/llama_runner/LlamaModelChunk.cpp b/examples/mediatek/executor_runner/llama_runner/LlamaModelChunk.cpp index bf4b3eefdde..861c7911c89 100644 --- a/examples/mediatek/executor_runner/llama_runner/LlamaModelChunk.cpp +++ b/examples/mediatek/executor_runner/llama_runner/LlamaModelChunk.cpp @@ -21,6 +21,7 @@ #include "LlamaConfig.h" #include "LlamaModelChunk.h" +#include "Utils.h" #include "llm_helper/include/llm_types.h" #include "llm_helper/include/mask_builder.h" @@ -42,11 +43,13 @@ inline std::vector getIndexRange( LlamaModelChunk::LlamaModelChunk( const ModelPathMap& modelPathMap, const LlamaModelOptions& modelOptions, + const bool useSharedWeights, const size_t initBatchSize, const size_t numCache, const size_t numRotEmbInputs, const RotaryEmbeddingMasterLut* rotEmbMasterLut) : ModelChunk(modelPathMap, initBatchSize), + kIsSharedWeightsUsed(useSharedWeights), kMaxTokenLength(modelOptions.max_token_length), kCacheLength(modelOptions.cache_size), kMaskType(modelOptions.mask_type), @@ -61,6 +64,31 @@ LlamaModelChunk::LlamaModelChunk( LlamaModelChunk::~LlamaModelChunk() {} +std::string LlamaModelChunk::SelectMethod( + const std::vector& methodNames) const { + const size_t curTokenSize = GetModelId(); + for (const auto& methodName : methodNames) { + const auto matches = utils::extract_substr(methodName, "([0-9]+)t[0-9]+c"); + ET_CHECK_MSG( + matches.size() == 2, "Invalid method name: %s", methodName.c_str()); + // Extract the first match group as token size + const size_t methodTokenSize = + static_cast(std::atol(matches[1].c_str())); + if (curTokenSize == methodTokenSize) { + ET_LOG( + Debug, + "Selected method \"%s\" for token size %zu", + methodName.c_str(), + curTokenSize); + return methodName; + } + } + ET_LOG( + Error, + "Unable to find suitable method, fallback to use the first method."); + return {}; +} + size_t LlamaModelChunk::GetExpectedInputCount() const { const size_t rotEmbInputCount = kRotEmbInputIndexes.size(); const size_t 
cacheInputCount = kCacheInputIndexes.size(); diff --git a/examples/mediatek/executor_runner/llama_runner/LlamaModelChunk.h b/examples/mediatek/executor_runner/llama_runner/LlamaModelChunk.h index 0a5002199db..5e791a4fa75 100644 --- a/examples/mediatek/executor_runner/llama_runner/LlamaModelChunk.h +++ b/examples/mediatek/executor_runner/llama_runner/LlamaModelChunk.h @@ -44,6 +44,7 @@ class LlamaModelChunk : public ModelChunk { explicit LlamaModelChunk( const ModelPathMap& modelPathMap, const LlamaModelOptions& modelOptions, + const bool useSharedWeights, const size_t initBatchSize, const size_t numCache, const size_t numRotEmbInputs, @@ -104,6 +105,17 @@ class LlamaModelChunk : public ModelChunk { size_t GetExpectedOutputCount() const; private: + bool AllowModelsCoexist() const override { + return kIsSharedWeightsUsed; + } + + std::string SelectMethod( + const std::vector& methodNames) const override; + + private: + // Whether shared weights is used + bool kIsSharedWeightsUsed = false; + // Input/Output Indexes const size_t kMaskInputIndex; const std::vector kRotEmbInputIndexes; diff --git a/examples/mediatek/executor_runner/llama_runner/LlamaRuntime.cpp b/examples/mediatek/executor_runner/llama_runner/LlamaRuntime.cpp index 2254241a001..620df6cf2ff 100644 --- a/examples/mediatek/executor_runner/llama_runner/LlamaRuntime.cpp +++ b/examples/mediatek/executor_runner/llama_runner/LlamaRuntime.cpp @@ -24,9 +24,6 @@ void LlamaRuntime::Initialize( const LlamaModelOptions& modelOptions, const LlamaModelPaths& modelPaths) { mModelOptions = modelOptions; - const size_t numChunk = modelPaths.gen_model_paths.size(); - const size_t numCache = 2 * modelOptions.num_layer / numChunk; - ET_CHECK_MSG(numChunk > 0, "No model to initialize"); // Initialize rotary embedding master lookup table const size_t rotEmbDim = modelOptions.hidden_size / modelOptions.num_head; @@ -37,12 +34,36 @@ void LlamaRuntime::Initialize( modelOptions.rot_emb_base); mRotEmbMasterLut->generate(); + const bool useSharedWeights = !modelPaths.model_package_paths.empty(); + + ET_CHECK_MSG( + !useSharedWeights || + modelPaths.prompt_model_paths.empty() && + modelPaths.gen_model_paths.empty(), + "The paths for both prompt and gen model paths should be empty when shared weights is used."); + + const size_t numChunk = useSharedWeights + ? modelPaths.model_package_paths.size() + : modelPaths.gen_model_paths.size(); + ET_CHECK_MSG(numChunk > 0, "No model to initialize"); + const size_t numCache = 2 * modelOptions.num_layer / numChunk; + constexpr size_t numRotEmbInputs = 1; - const bool usePromptModel = !modelPaths.prompt_model_paths.empty(); + const bool usePromptModel = !modelPaths.prompt_model_paths.empty() || + !modelPaths.model_package_paths.empty(); const size_t initBatchSize = usePromptModel ? 
modelOptions.prompt_token_batch_size : 1; mTokenBatchSize = initBatchSize; + // Get effective prompt and gen model paths + const auto& [prompt_model_paths, gen_model_paths] = [&] { + if (useSharedWeights) { + return std::pair{ + modelPaths.model_package_paths, modelPaths.model_package_paths}; + } + return std::pair{modelPaths.prompt_model_paths, modelPaths.gen_model_paths}; + }(); + for (size_t chunkIdx = 0; chunkIdx < numChunk; chunkIdx++) { ModelPathMap modelPathMap; auto addModelPath = [&](const auto& modelPaths, const size_t batchSize) { @@ -50,12 +71,12 @@ void LlamaRuntime::Initialize( return; modelPathMap[batchSize] = modelPaths[chunkIdx]; }; - addModelPath( - modelPaths.prompt_model_paths, modelOptions.prompt_token_batch_size); - addModelPath(modelPaths.gen_model_paths, 1); + addModelPath(prompt_model_paths, modelOptions.prompt_token_batch_size); + addModelPath(gen_model_paths, 1); auto llamaChunk = std::make_unique( modelPathMap, modelOptions, + useSharedWeights, initBatchSize, numCache, numRotEmbInputs, diff --git a/examples/mediatek/executor_runner/llama_runner/ModelChunk.cpp b/examples/mediatek/executor_runner/llama_runner/ModelChunk.cpp index 7f67fb4ca79..0b4d38ee851 100644 --- a/examples/mediatek/executor_runner/llama_runner/ModelChunk.cpp +++ b/examples/mediatek/executor_runner/llama_runner/ModelChunk.cpp @@ -41,20 +41,171 @@ using executorch::runtime::Tag; static constexpr size_t kMethodAllocatorPoolSize = 4 * 1024U * 1024U; // 4MB -// ExecuTorch model instance -// The member ordering affects the order of destruction. -struct ModelInstance { - std::unique_ptr program; +// ExecuTorch model instance with cacheable program. +class ModelInstance { + public: + explicit ModelInstance(const std::string& modelPath) { + if (mCachedPrograms.find(modelPath) != mCachedPrograms.end()) { + auto cachedProgram = mCachedPrograms.at(modelPath).lock(); + if (cachedProgram) { + mProgramInstance = cachedProgram; + ET_LOG( + Debug, "Loaded existing program from cache: %s", modelPath.c_str()); + return; + } else { + mCachedPrograms.erase(modelPath); // Expired + } + } + ET_LOG(Debug, "Loading model from scratch: %s", modelPath.c_str()); + mProgramInstance = std::make_shared(); + + // Create a loader to get the data of the program file. There are other + // DataLoaders that use mmap() or point to data that's already in memory, + // and users can create their own DataLoaders to load from arbitrary + // sources. + Result loader = FileDataLoader::from(modelPath.c_str()); + ET_CHECK_MSG( + loader.ok(), + "FileDataLoader::from() failed: 0x%" PRIx32, + loader.error()); + // Extract the data loader out to a persistent storage before loading the + // program. + mProgramInstance->dataLoader = + std::make_unique(std::move(loader.get())); + + // Parse the program file. This is immutable, and can also be reused between + // multiple execution invocations across multiple threads. + Result program_loaded = + Program::load(mProgramInstance->dataLoader.get()); + ET_CHECK_MSG( + program_loaded.ok(), + "Failed to parse model file %s", + modelPath.c_str()); + ET_LOG(Debug, "Model file %s is loaded.", modelPath.c_str()); + + // Extract program out to a persistent storage before calling any of its + // methods. 
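+    // Register the freshly loaded program in the cache under its path (held as
+    // a weak reference) so that other chunks created from the same .pte file
+    // can reuse it for as long as some ModelInstance keeps it alive.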
+ mProgramInstance->program = + std::make_unique(std::move(program_loaded.get())); + mCachedPrograms.emplace(modelPath, mProgramInstance); + } + + Method& GetMethod() { + ET_CHECK_MSG(mMethod != nullptr, "Method is not loaded."); + return *mMethod; + } + + const Method& GetMethod() const { + ET_CHECK_MSG(mMethod != nullptr, "Method is not loaded."); + return *mMethod; + } + + Program& GetProgram() { + return *(mProgramInstance->program); + } + + const Program& GetProgram() const { + return *(mProgramInstance->program); + } - std::vector> planned_buffers; - std::vector> planned_spans; + std::vector GetMethodNames() const { + std::vector methodNames; + for (size_t i = 0; i < GetProgram().num_methods(); i++) { + const auto method_name_result = GetProgram().get_method_name(i); + ET_CHECK_MSG(method_name_result.ok(), "Program has no method %zu", i); + methodNames.emplace_back(*method_name_result); + } + return methodNames; + } - std::vector method_allocator_pool; - std::unique_ptr method_allocator; - std::unique_ptr planned_memory; - std::unique_ptr memory_manager; + void LoadFirstMethod() { + // Use the first method in the program. + const char* method_name = nullptr; + const auto method_name_result = GetProgram().get_method_name(0); + ET_CHECK_MSG(method_name_result.ok(), "Program has no methods"); + method_name = *method_name_result; + ET_LOG(Debug, "Loading the first method."); + LoadMethod(method_name); + } - std::unique_ptr method; + void LoadMethod(const std::string& method_name) { + ET_CHECK_MSG(!mMethod, "Method is already loaded."); + const auto method_name_cstr = method_name.c_str(); + + // MethodMeta describes the memory requirements of the method. + Result method_meta = GetProgram().method_meta(method_name_cstr); + ET_CHECK_MSG( + method_meta.ok(), + "Failed to get method_meta for %s: 0x%x", + method_name_cstr, + (unsigned int)method_meta.error()); + + mMethodAllocatorPool.resize(kMethodAllocatorPoolSize); + mMethodAllocator = std::make_unique( + kMethodAllocatorPoolSize, mMethodAllocatorPool.data()); + mMethodAllocator->enable_profiling("method allocator"); + + size_t num_memory_planned_buffers = + method_meta->num_memory_planned_buffers(); + for (size_t id = 0; id < num_memory_planned_buffers; ++id) { + // .get() will always succeed because id < num_memory_planned_buffers. + size_t buffer_size = static_cast( + method_meta->memory_planned_buffer_size(id).get()); + ET_LOG( + Debug, "Setting up planned buffer %zu, size %zu.", id, buffer_size); + mPlannedBuffers.push_back(std::make_unique(buffer_size)); + mPlannedSpans.push_back({mPlannedBuffers.back().get(), buffer_size}); + } + mPlannedMemory = std::make_unique( + Span>{mPlannedSpans.data(), mPlannedSpans.size()}); + + // Assemble all of the allocators into the MemoryManager that the Executor + // will use. 
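+    // The Neuron buffer allocator is handed in as the temp allocator so that,
+    // at execute time, the delegate can Find() buffers it allocated and bind
+    // them via SetInputOutputFromMemory() instead of passing raw pointers.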
+ auto& neuron_allocator = GET_NEURON_ALLOCATOR; + mMemoryManager = std::make_unique( + mMethodAllocator.get(), + mPlannedMemory.get(), + dynamic_cast(&neuron_allocator)); + + ET_LOG(Debug, "Loading method %s", method_name_cstr); + Result method = + GetProgram().load_method(method_name_cstr, mMemoryManager.get()); + ET_CHECK_MSG( + method.ok(), + "Loading of method %s failed with status 0x%" PRIx32, + method_name_cstr, + method.error()); + + mMethod = std::make_unique(std::move(method.get())); + } + + private: + ModelInstance(const ModelInstance&) = delete; + ModelInstance& operator=(ModelInstance&&) = delete; + ModelInstance& operator=(const ModelInstance&) = delete; + + private: + // The member ordering below affects the order of destruction. + + struct ProgramInstance { + std::unique_ptr dataLoader; + std::unique_ptr program; + }; + std::shared_ptr mProgramInstance; + + std::vector> mPlannedBuffers; + std::vector> mPlannedSpans; + + std::vector mMethodAllocatorPool; + std::unique_ptr mMethodAllocator; + std::unique_ptr mPlannedMemory; + std::unique_ptr mMemoryManager; + + std::unique_ptr mMethod; + + // Maps .pte file paths to the cached program instances. + inline static std::unordered_map> + mCachedPrograms; }; void ModelChunk::Initialize() { @@ -487,94 +638,18 @@ void ModelChunk::ReleaseIoBuffers() { Method& ModelChunk::GetModelMethod() { auto modelInstance = reinterpret_cast(GetModelInstance()); - return *(modelInstance->method); + return modelInstance->GetMethod(); } // Override the virtual functions void* ModelChunk::CreateModelInstance(const std::string& modelPath) { - auto modelInstance = new ModelInstance; - - // Create a loader to get the data of the program file. There are other - // DataLoaders that use mmap() or point to data that's already in memory, and - // users can create their own DataLoaders to load from arbitrary sources. - Result loader = FileDataLoader::from(modelPath.c_str()); - ET_CHECK_MSG( - loader.ok(), "FileDataLoader::from() failed: 0x%" PRIx32, loader.error()); - - // Parse the program file. This is immutable, and can also be reused between - // multiple execution invocations across multiple threads. - Result program_loaded = Program::load(&loader.get()); - if (!program_loaded.ok()) { - ET_LOG(Error, "Failed to parse model file %s", modelPath.c_str()); - return nullptr; - } - ET_LOG(Debug, "Model file %s is loaded.", modelPath.c_str()); - - // Extract program out to a persistent storage before calling any of its - // methods. - modelInstance->program = - std::make_unique(std::move(program_loaded.get())); - auto& program = modelInstance->program; - - // Use the first method in the program. - const char* method_name = nullptr; - { - const auto method_name_result = program->get_method_name(0); - ET_CHECK_MSG(method_name_result.ok(), "Program has no methods"); - method_name = *method_name_result; - } - ET_LOG(Debug, "Using method %s", method_name); - - // MethodMeta describes the memory requirements of the method. 
- Result method_meta = program->method_meta(method_name); - ET_CHECK_MSG( - method_meta.ok(), - "Failed to get method_meta for %s: 0x%x", - method_name, - (unsigned int)method_meta.error()); - - modelInstance->method_allocator_pool.resize(kMethodAllocatorPoolSize); - modelInstance->method_allocator = std::make_unique( - kMethodAllocatorPoolSize, modelInstance->method_allocator_pool.data()); - auto& method_allocator = modelInstance->method_allocator; - method_allocator->enable_profiling("method allocator"); - - auto& planned_buffers = modelInstance->planned_buffers; // Owns the memory - auto& planned_spans = modelInstance->planned_spans; // Passed to the allocator - - size_t num_memory_planned_buffers = method_meta->num_memory_planned_buffers(); - for (size_t id = 0; id < num_memory_planned_buffers; ++id) { - // .get() will always succeed because id < num_memory_planned_buffers. - size_t buffer_size = - static_cast(method_meta->memory_planned_buffer_size(id).get()); - ET_LOG(Debug, "Setting up planned buffer %zu, size %zu.", id, buffer_size); - planned_buffers.push_back(std::make_unique(buffer_size)); - planned_spans.push_back({planned_buffers.back().get(), buffer_size}); + auto modelInstance = new ModelInstance(modelPath); + const auto selectedMethod = SelectMethod(modelInstance->GetMethodNames()); + if (!selectedMethod.empty()) { + modelInstance->LoadMethod(selectedMethod); + } else { + modelInstance->LoadFirstMethod(); // Load the first available method } - modelInstance->planned_memory = std::make_unique( - Span>{planned_spans.data(), planned_spans.size()}); - auto& planned_memory = modelInstance->planned_memory; - - // Assemble all of the allocators into the MemoryManager that the Executor - // will use. - auto& neuron_allocator = GET_NEURON_ALLOCATOR; - modelInstance->memory_manager = std::make_unique( - method_allocator.get(), - planned_memory.get(), - dynamic_cast(&neuron_allocator)); - auto& memory_manager = modelInstance->memory_manager; - - ET_LOG(Debug, "Begin loading method %s", method_name); - Result method = - program->load_method(method_name, memory_manager.get()); - ET_CHECK_MSG( - method.ok(), - "Loading of method %s failed with status 0x%" PRIx32, - method_name, - method.error()); - ET_LOG(Debug, "Method loaded."); - - modelInstance->method = std::make_unique(std::move(method.get())); return modelInstance; } diff --git a/examples/mediatek/executor_runner/llama_runner/ModelChunk.h b/examples/mediatek/executor_runner/llama_runner/ModelChunk.h index 67d9e30b5f1..66a1f4586bb 100644 --- a/examples/mediatek/executor_runner/llama_runner/ModelChunk.h +++ b/examples/mediatek/executor_runner/llama_runner/ModelChunk.h @@ -94,6 +94,11 @@ class ModelChunk : protected MultiTokenSizeModelLoader { executorch::runtime::Method& GetModelMethod(); private: + virtual std::string SelectMethod( + const std::vector& methodNames) const { + return {}; // Default choose the first available one. 
+ } + // Override the virtual functions void* CreateModelInstance(const std::string& modelPath) override; diff --git a/examples/mediatek/executor_runner/llama_runner/Utils.h b/examples/mediatek/executor_runner/llama_runner/Utils.h index 24e8a4d6e50..9afd036a21e 100644 --- a/examples/mediatek/executor_runner/llama_runner/Utils.h +++ b/examples/mediatek/executor_runner/llama_runner/Utils.h @@ -59,6 +59,22 @@ static std::vector split(const std::string& str, const char sep) { return tokens; } +static std::vector extract_substr( + const std::string& str, + const std::string& pattern) { + std::vector tokens; + const std::regex token_pattern(pattern); + std::smatch matches; + auto cur = str.cbegin(); + while (std::regex_search(cur, str.cend(), matches, token_pattern)) { + for (const auto& match : matches) { + tokens.push_back(match.str()); + } + cur = matches.suffix().first; + } + return tokens; +} + static std::string read_file(const std::string& filepath) { std::ifstream file(filepath); std::stringstream buffer; diff --git a/examples/mediatek/executor_runner/mtk_llama_executor_runner.cpp b/examples/mediatek/executor_runner/mtk_llama_executor_runner.cpp index 733cc8c3465..f1c5de2f40a 100644 --- a/examples/mediatek/executor_runner/mtk_llama_executor_runner.cpp +++ b/examples/mediatek/executor_runner/mtk_llama_executor_runner.cpp @@ -66,14 +66,12 @@ DEFINE_string( token_embedding_path, "embedding.bin", "Input token embedding lookup table path."); +DEFINE_string(prompt_model_paths, "", "Comma-separated prompt model paths."); +DEFINE_string(gen_model_paths, "", "Comma-separated generative model paths."); DEFINE_string( - prompt_model_paths, - "model_128t.pte", - "Comma-separated prompt model paths."); -DEFINE_string( - gen_model_paths, - "model_1t.pte", - "Comma-separated generative model paths."); + model_package_paths, + "", + "Comma-separated weight-shared model package paths."); // Tokenizer DEFINE_string(tokenizer_path, "tokenizer.model", "tokenizer.model vocab path."); @@ -132,7 +130,9 @@ LlamaModelPaths get_model_paths() { .tokenizer_path = FLAGS_tokenizer_path, .token_embedding_path = FLAGS_token_embedding_path, .prompt_model_paths = split(FLAGS_prompt_model_paths, ','), - .gen_model_paths = split(FLAGS_gen_model_paths, ',')}; + .gen_model_paths = split(FLAGS_gen_model_paths, ','), + .model_package_paths = split(FLAGS_model_package_paths, ','), + }; return model_paths; } @@ -311,7 +311,8 @@ int main(int argc, char** argv) { LlamaModelOptions model_options = get_model_options(); LlamaModelPaths model_paths = get_model_paths(); - if (model_paths.prompt_model_paths.empty()) { + if (model_paths.prompt_model_paths.empty() && + model_paths.model_package_paths.empty()) { model_options.prompt_token_batch_size = 1; ET_LOG( Info, diff --git a/examples/mediatek/executor_runner/run_llama3_sample.sh b/examples/mediatek/executor_runner/run_llama3_sample.sh index d9d6ea43e3c..eeb1089d792 100644 --- a/examples/mediatek/executor_runner/run_llama3_sample.sh +++ b/examples/mediatek/executor_runner/run_llama3_sample.sh @@ -36,18 +36,12 @@ TOKENIZER_PATH="/data/local/tmp/llama3/tokenizer.model" TOKEN_EMBEDDING_PATH="/data/local/tmp/llama3/embedding_llama3_8b_instruct_fp32.bin" # Comma-Separated Paths -PROMPT_MODEL_PATHS="\ -/data/local/tmp/llama3/llama3_8b_SC_sym4W_sym16A_4_chunks_Overall_128t512c_0.pte,\ -/data/local/tmp/llama3/llama3_8b_SC_sym4W_sym16A_4_chunks_Overall_128t512c_1.pte,\ -/data/local/tmp/llama3/llama3_8b_SC_sym4W_sym16A_4_chunks_Overall_128t512c_2.pte,\ 
-/data/local/tmp/llama3/llama3_8b_SC_sym4W_sym16A_4_chunks_Overall_128t512c_3.pte," - -# Comma-Separated Paths -GEN_MODEL_PATHS="\ -/data/local/tmp/llama3/llama3_8b_SC_sym4W_sym16A_4_chunks_Overall_1t512c_0.pte,\ -/data/local/tmp/llama3/llama3_8b_SC_sym4W_sym16A_4_chunks_Overall_1t512c_1.pte,\ -/data/local/tmp/llama3/llama3_8b_SC_sym4W_sym16A_4_chunks_Overall_1t512c_2.pte,\ -/data/local/tmp/llama3/llama3_8b_SC_sym4W_sym16A_4_chunks_Overall_1t512c_3.pte," +WEIGHT_SHARED_MODEL_PACKAGE_PATHS="\ +/data/local/tmp/llama3/llama3_8b_SC_sym4W_sym16A_4_chunks_Overall_0.pte,\ +/data/local/tmp/llama3/llama3_8b_SC_sym4W_sym16A_4_chunks_Overall_1.pte,\ +/data/local/tmp/llama3/llama3_8b_SC_sym4W_sym16A_4_chunks_Overall_2.pte,\ +/data/local/tmp/llama3/llama3_8b_SC_sym4W_sym16A_4_chunks_Overall_3.pte,\ +" PROMPT_FILE=/data/local/tmp/llama3/sample_prompt.txt @@ -75,6 +69,5 @@ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$PWD --tokenizer_type=$TOKENIZER_TYPE \ --tokenizer_path=$TOKENIZER_PATH \ --token_embedding_path=$TOKEN_EMBEDDING_PATH \ - --prompt_model_paths=$PROMPT_MODEL_PATHS \ - --gen_model_paths=$GEN_MODEL_PATHS \ + --model_package_paths=$WEIGHT_SHARED_MODEL_PACKAGE_PATHS \ --prompt_file=$PROMPT_FILE \ No newline at end of file diff --git a/examples/mediatek/model_export_scripts/llama.py b/examples/mediatek/model_export_scripts/llama.py index 6a098e2a9b1..60c57850d00 100644 --- a/examples/mediatek/model_export_scripts/llama.py +++ b/examples/mediatek/model_export_scripts/llama.py @@ -42,6 +42,10 @@ NeuropilotQuantizer, Precision, ) +from executorch.exir.backend.backend_api import ( + MethodProgramsPartitionerSpec, + to_backend, +) from executorch.exir.backend.backend_details import CompileSpec from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from tqdm import tqdm @@ -331,52 +335,64 @@ def export_to_et_ir( prepared_graph(*example_inputs) # dummy calibration converted_graph = convert_pt2e(prepared_graph, fold_quantize=False) - print("Getting ATen Dialect Graph") + method_to_edge_program = {} + method_to_partitioner = {} + edge_compile_config = exir.EdgeCompileConfig(_check_ir_validity=False) + + model_shared_key_name = f"{exp_name}_{chunk_idx}" + # Fixed Shape Export Here for shape, ntok_and_cache in export_shapes.items(): - dest_path = get_dest_path(output_folder, exp_name, shape, chunk_idx) - print(f"Exporting Shape {shape} to:\n{dest_path}") + model_fname = f"{exp_name}_{shape}_{chunk_idx}" example_inputs = model.get_example_inputs(*ntok_and_cache) + print(f"Getting ATen Dialect Graph for {exp_name} {shape} chunk {chunk_idx}") aten_dialect: exir.ExportedProgram = torch.export.export( converted_graph, example_inputs, strict=True ) - print("Lowering to Edge Dialect Graph") - edge_program: exir.EdgeProgramManager = exir.to_edge( - aten_dialect, - compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), - ) + method_to_edge_program[f"{model_fname}"] = exir.to_edge( + aten_dialect + ).exported_program() del aten_dialect - print("Delegating Edge Program to Neuropilot Backend") compile_spec = [ CompileSpec("gno", b"LTS"), CompileSpec("gno-exp", b""), CompileSpec("gno-non-4d-tiling", b""), CompileSpec("ImportForever", struct.pack("?", True)), CompileSpec("platform-config", b"mt6989"), + CompileSpec("ExtractSharedBlobKey", model_shared_key_name.encode()), ] - partitioner = NeuropilotPartitioner(compile_spec) - delegated_program = edge_program.to_backend(partitioner) - print("Exported Delegated Program:") - print(delegated_program.exported_program()) - del edge_program - - 
print("Transforming delegated program to executorch backend") - executorch_program = delegated_program.to_executorch( - config=exir.ExecutorchBackendConfig( - memory_planning_pass=exir.passes.MemoryPlanningPass( - alloc_graph_input=False, - alloc_graph_output=False, - ), - extract_delegate_segments=True, - ) - ) + method_to_partitioner[f"{model_fname}"] = NeuropilotPartitioner(compile_spec) - print(f"ET Model Dest: {dest_path}\n") - os.makedirs(dest_path.rsplit("/", 1)[0], exist_ok=True) - with open(dest_path, "wb") as file: - file.write(executorch_program.buffer) + print("Delegating Edge Program to Neuropilot Backend") + delegated_program = to_backend( + MethodProgramsPartitionerSpec(method_to_edge_program, method_to_partitioner) + ) + + edge_manager = exir.EdgeProgramManager( + delegated_program, compile_config=edge_compile_config + ) + del delegated_program + + print("Transforming delegated program to executorch backend") + executorch_program = edge_manager.to_executorch( + config=exir.ExecutorchBackendConfig( + memory_planning_pass=exir.passes.MemoryPlanningPass( + alloc_graph_input=False, + alloc_graph_output=False, + ), + extract_delegate_segments=True, + ) + ) + del edge_manager + print(f"\n Model Size: {len(executorch_program.buffer)}") + + dest_path = get_dest_path(output_folder, exp_name, None, chunk_idx) + print(f"{exp_name} ET Model chunk {chunk_idx} Dest: {dest_path}\n") + os.makedirs(dest_path.rsplit("/", 1)[0], exist_ok=True) + with open(dest_path, "wb") as file: + file.write(executorch_program.buffer) def main():