104 changes: 101 additions & 3 deletions backends/mediatek/preprocess.py
@@ -4,14 +4,16 @@
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

import collections
import contextlib
import struct

from typing import final, List
from typing import Dict, final, List

import mtk_converter
import mtk_neuron
import torch
from executorch.exir._serialize._named_data_store import NamedDataStore
from executorch.exir.backend.backend_details import (
BackendDetails,
ExportedProgram,
@@ -20,6 +22,9 @@
from executorch.exir.backend.compile_spec_schema import CompileSpec

SKIP_COMPILE_SPEC_KEYS = {"ImportForever"}
EXTRACT_SHARED_BLOB_KEY = "ExtractSharedBlobKey"
HEADER_SIZE = 13
HEADER_VERSION = 1
REQUIRED_COMPILE_SPEC_KEYS = {"platform-config"}
SUPPORTED_PLATFORM_CONFIGS = {"mt6989", "mt6991"}

@@ -41,6 +46,21 @@ def assert_default_dim_order(edge_graph_module: torch.fx.GraphModule) -> None:
)


def _pack_header(num_inputs, num_outputs, model_bytes_size):
header_bytes = struct.pack(
"<BIII", HEADER_VERSION, num_inputs, num_outputs, model_bytes_size
)
assert len(header_bytes) == HEADER_SIZE
return header_bytes


def _unpack_header(header_bytes):
assert len(header_bytes) == HEADER_SIZE
version, num_inputs, num_outputs, buffer_size = struct.unpack("<BIII", header_bytes)
assert version == HEADER_VERSION
return num_inputs, num_outputs, buffer_size


@final
class NeuropilotBackend(BackendDetails):

@@ -90,8 +110,14 @@ def preprocess(

compile_options = ["--relax-fp32", "--opt=3"]
for spec in module_compile_spec:
# Special compile spec handling
if spec.key in SKIP_COMPILE_SPEC_KEYS:
continue
if spec.key == EXTRACT_SHARED_BLOB_KEY:
compile_options.append("--dla-opt=0")
continue

# General compile spec handling
if spec.value == b"":
compile_options.append(f"--{spec.key}")
else:
@@ -112,5 +138,77 @@

num_inputs = len(input_names)
num_outputs = len(output_names)
header = struct.pack("<BIII", 1, num_inputs, num_outputs, len(model_bytes))
return PreprocessResult(processed_bytes=bytes(header + model_bytes))
header_bytes = _pack_header(num_inputs, num_outputs, len(model_bytes))
return PreprocessResult(processed_bytes=bytes(header_bytes + model_bytes))

@classmethod
def preprocess_multimethod(
cls,
edge_programs: Dict[str, List[ExportedProgram]],
compile_specs: Dict[str, List[List[CompileSpec]]],
    ) -> Dict[str, List[PreprocessResult]]:

# Follow the default behavior of `preprocess_multimethod`
preprocess_results = {}
for method_name, programs in edge_programs.items():
assert (
method_name in compile_specs
), f"Error: missing compile specs for {method_name}"
compile_specs_for_method = compile_specs[method_name]
assert len(compile_specs_for_method) == len(
programs
            ), f"Error: method {method_name} has {len(programs)} partitions but only {len(compile_specs_for_method)} compile specs"
results_for_method = []
for program, compile_spec_for_program in zip(
programs, compile_specs_for_method
):
preprocess_result = cls.preprocess(program, compile_spec_for_program)
results_for_method.append(preprocess_result)

preprocess_results[method_name] = results_for_method

        # Try to extract a shared data blob if requested
infos_dict = collections.defaultdict(list)
models_dict = collections.defaultdict(list)
result_dict = collections.defaultdict(list)
for method_name, method_results in preprocess_results.items():
for idx, result in enumerate(method_results):
shared_blob_key = None
for spec in compile_specs[method_name][idx]:
if spec.key == EXTRACT_SHARED_BLOB_KEY:
shared_blob_key = spec.value.decode("utf-8")

if shared_blob_key is None:
continue

header_bytes = result.processed_bytes[:HEADER_SIZE]
model_bytes = result.processed_bytes[HEADER_SIZE:]
num_inputs, num_outputs, model_bytes_size = _unpack_header(header_bytes)
assert len(model_bytes) == model_bytes_size
infos_dict[shared_blob_key].append((num_inputs, num_outputs))
models_dict[shared_blob_key].append(model_bytes)
result_dict[shared_blob_key].append(result)

data_store_output_dict = {}
for key, models in models_dict.items():
ndm = NamedDataStore()
blob, new_models = mtk_neuron.extract_shared_data(
models, options="-e union"
)
ndm.add_named_data(key, bytes(blob))
data_store_output_dict[key] = ndm.get_named_data_store_output()
models.clear()
models.extend(new_models)

for key, data_store_output in data_store_output_dict.items():
for idx, (model_info, model_bytes) in enumerate(
zip(infos_dict[key], models_dict[key])
):
num_inputs, num_outputs = model_info
header_bytes = _pack_header(num_inputs, num_outputs, len(model_bytes))
result_dict[key][idx].data_store_output = data_store_output
result_dict[key][idx].processed_bytes = bytes(
header_bytes + model_bytes
)

return preprocess_results
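
For reference, the payload header written by `_pack_header` above is a fixed 13-byte, little-endian `<BIII` record: a one-byte header version followed by three unsigned 32-bit fields (input count, output count, and the size of the compiled model bytes). A minimal round-trip sketch of that layout, mirroring the constants in preprocess.py (standalone illustration, not part of the diff):

import struct

HEADER_VERSION = 1
HEADER_SIZE = 13  # 1-byte version + 3 x 4-byte uint32; "<" means no alignment padding

def pack_header(num_inputs: int, num_outputs: int, model_bytes_size: int) -> bytes:
    # Same "<BIII" layout used by _pack_header in preprocess.py
    return struct.pack("<BIII", HEADER_VERSION, num_inputs, num_outputs, model_bytes_size)

def unpack_header(header_bytes: bytes):
    version, num_inputs, num_outputs, model_bytes_size = struct.unpack("<BIII", header_bytes)
    assert version == HEADER_VERSION
    return num_inputs, num_outputs, model_bytes_size

header = pack_header(2, 1, 4096)
assert len(header) == HEADER_SIZE
assert unpack_header(header) == (2, 1, 4096)
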
121 changes: 85 additions & 36 deletions backends/mediatek/runtime/NeuronBackend.cpp
@@ -12,8 +12,8 @@
#include "NeuronPayloadHeader.h"
#include "api/NeuronAdapter.h"

#include <executorch/runtime/executor/pte_data_map.h>
#include "executorch/runtime/core/error.h"
#include "executorch/runtime/core/exec_aten/util/dim_order_util.h"

#include <algorithm>
#include <memory>
@@ -24,6 +24,7 @@ namespace executorch {
namespace backends {
namespace neuron {

using executorch::ET_RUNTIME_NAMESPACE::NamedDataMap;
using executorch::runtime::ArrayRef;
using executorch::runtime::BackendExecutionContext;
using executorch::runtime::BackendInitContext;
@@ -38,12 +39,22 @@ using executorch::runtime::Span;

const char kHighAddrKey[] = "HighAddr";
const char kImportForeverKey[] = "ImportForever";
const char kSharedWeightsKey[] = "ExtractSharedBlobKey";

Result<DelegateHandle*> NeuronBackend::init(
BackendInitContext& context,
FreeableBuffer* processed,
ArrayRef<CompileSpec> compile_specs) const {
NeuronDelegateSetting setting;
MemoryAllocator* runtime_allocator = context.get_runtime_allocator();
NeuronExecuTorchDelegate* delegate =
runtime_allocator->allocateInstance<NeuronExecuTorchDelegate>();
if (delegate == nullptr) {
return Error::MemoryAllocationFailed;
}

new (delegate) NeuronExecuTorchDelegate();

for (auto& compile_spec : compile_specs) {
if (std::strcmp(compile_spec.key, kHighAddrKey) == 0) {
setting.mHighAddr = *static_cast<char*>(compile_spec.value.buffer);
@@ -54,11 +65,62 @@ Result<DelegateHandle*> NeuronBackend::init(
"NeuronBackend",
"IsImportForever Enable : %d",
setting.mImportForever);
} else if (std::strcmp(compile_spec.key, kSharedWeightsKey) == 0) {
setting.mSharedWeights = true;
std::string shared_weights_key(
static_cast<char*>(compile_spec.value.buffer),
compile_spec.value.nbytes);
LogInfo(
"NeuronBackend",
"SharedWeights Enabled for %s",
shared_weights_key.c_str());
std::shared_ptr<NeuronSharedWeights> neuron_shared_weights;
if (neuron_shared_weights_cache_.find(shared_weights_key) !=
neuron_shared_weights_cache_.end()) {
neuron_shared_weights =
neuron_shared_weights_cache_.at(shared_weights_key).lock();
if (neuron_shared_weights) {
LogInfo(
"NeuronBackend",
"Reusing cached shared weights with key %s",
shared_weights_key.c_str());
delegate->SetSharedWeights(neuron_shared_weights);
continue;
} else {
LogInfo(
"NeuronBackend",
"Shared weights cache expired: %s",
shared_weights_key.c_str());
neuron_shared_weights_cache_.erase(shared_weights_key); // Expired
}
}
const NamedDataMap* named_data_map = context.get_named_data_map();
Result<FreeableBuffer> shared_weights =
named_data_map->get_data(shared_weights_key.c_str());

if (shared_weights.ok()) {
LogInfo(
"NeuronBackend",
"Loaded shared weights from named_data_map. Size: %zu",
shared_weights.get().size());
FreeableBuffer& buffer = shared_weights.get();
neuron_shared_weights =
std::make_shared<NeuronSharedWeights>(std::move(buffer));
delegate->SetSharedWeights(neuron_shared_weights);
neuron_shared_weights_cache_[shared_weights_key] =
neuron_shared_weights;
} else {
LogError(
"NeuronBackend",
"Failed to load shared weights from named_data_map.");
return Error::Internal;
}
} else {
LogWarn("NeuronBackend", "unknown compile spec: %s", compile_spec.key);
}
}
auto Payload = NeuronPayload(processed->data(), processed->size());

LogInfo(
"NeuronBackend",
"version %u, input %u, output %u, length %u, payload size: %zu",
@@ -68,19 +130,7 @@
Payload.Header.DataLen,
processed->size());

MemoryAllocator* runtime_allocator = context.get_runtime_allocator();
NeuronExecuTorchDelegate* delegate =
runtime_allocator->allocateInstance<NeuronExecuTorchDelegate>();
if (delegate == nullptr) {
return Error::MemoryAllocationFailed;
}

new (delegate) NeuronExecuTorchDelegate();

if (delegate == nullptr) {
return nullptr;
}
auto res = delegate->LoadCompiledNetwork(Payload, setting);
int res = delegate->LoadCompiledNetwork(Payload, setting);
return res == NEURON_NO_ERROR ? delegate : nullptr;
}

@@ -112,21 +162,22 @@ Error NeuronExecuTorchDelegate::execute(
return Error::InvalidState;
};

ET_CHECK_OR_RETURN_ERROR(
CheckDimOrder(args) == NEURON_NO_ERROR,
Internal,
"Expecting default dim_order but got a non default dim_order tensor input");

PrepareInputsOuputs(args);

auto allocator =
dynamic_cast<neuron::BufferAllocator*>(context.get_temp_allocator());
size_t inputCount = mInputSizes.size(), outputCount = mOutputSizes.size();

for (int i = 0; i < inputCount; i++) {
auto tensor_in = args[i]->toTensor();
ET_CHECK_OR_RETURN_ERROR(
runtime::is_contiguous_dim_order(
tensor_in.dim_order().data(), tensor_in.dim()),
Internal,
"Expecting default dim_order but got a non default dim_order tensor for external input %u",
i);

auto data_ptr = args[i]->toTensor().data_ptr();
auto data_size = args[i]->toTensor().nbytes();

size_t inputCount = mInputSizes.size() + neuron_shared_weights_.size();
size_t outputCount = mOutputSizes.size();

for (size_t i = 0; i < inputCount; i++) {
auto data_ptr = mPreparedInputs[i].data_ptr;
auto data_size = mPreparedInputs[i].size;
if (IsCached</*isInput=*/true>(i, data_ptr)) {
continue;
};
@@ -141,22 +192,20 @@
}
}

for (int o = inputCount; o < inputCount + outputCount; o++) {
auto data_ptr = args[o]->toTensor().data_ptr();
auto data_size = args[o]->toTensor().nbytes();
auto output_index = o - inputCount;
if (IsCached</*isInput=*/false>(output_index, data_ptr)) {
for (size_t o = 0; o < outputCount; o++) {
auto data_ptr = mPreparedOutputs[o].data_ptr;
auto data_size = mPreparedOutputs[o].size;
if (IsCached</*isInput=*/false>(o, data_ptr)) {
continue;
};
auto unit = allocator != nullptr ? allocator->Find(data_ptr) : nullptr;
if (unit) {
UpdateCache</*isInput=*/false>(output_index, data_ptr);
UpdateCache</*isInput=*/false>(o, data_ptr);
size_t offset = (char*)data_ptr - (char*)unit->GetAddress();
mExecutor.SetInputOutputFromMemory</*isInput*/ false>(
output_index, unit->GetNeuronMemory(), offset, data_size);
o, unit->GetNeuronMemory(), offset, data_size);
} else {
mExecutor.SetInputOutput</*isInput=*/false>(
output_index, data_ptr, data_size);
mExecutor.SetInputOutput</*isInput=*/false>(o, data_ptr, data_size);
}
}

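
Taken together, the two files work like this: every method that should share weights attaches the same `ExtractSharedBlobKey` compile spec value; `preprocess_multimethod` groups those partitions, pulls the common weights out with `mtk_neuron.extract_shared_data`, and stores the blob in a `NamedDataStore` under that key; at runtime `NeuronBackend::init` resolves the same key through the named data map and caches the buffer in `neuron_shared_weights_cache_` so later methods reuse it. A sketch of how such compile specs might be assembled — the blob name and the surrounding partitioner/lowering calls are assumptions, not shown in this diff:

from executorch.exir.backend.compile_spec_schema import CompileSpec

# Hypothetical blob name; it becomes both the NamedDataStore key and the key
# NeuronBackend::init looks up in the named data map at runtime.
shared_blob_key = "llm_shared_weights"

common_specs = [
    # "platform-config" is required by preprocess(); mt6989/mt6991 are the supported values.
    CompileSpec("platform-config", "mt6991".encode("utf-8")),
    # Same key for every method that should share the extracted weights blob.
    CompileSpec("ExtractSharedBlobKey", shared_blob_key.encode("utf-8")),
]

# Shape expected by preprocess_multimethod: one list of CompileSpec per partition, per
# method. How these specs reach the partitioner is outside the scope of this change.
compile_specs = {
    "prefill": [common_specs],
    "decode": [common_specs],
}
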