7 changes: 0 additions & 7 deletions include/onnxruntime/core/graph/graph.h
@@ -1454,13 +1454,6 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
return Resolve(default_options);
}

/// <summary>
/// This function converts all the graph TensorProto initializers into OrtValues
/// and creates an in-memory external data reference for each OrtValue.
/// </summary>
/// <returns></returns>
Status ConvertInitializersIntoOrtValues();

/**
* @brief Converts a subset of graph TensorProto initializers into OrtValues and updates the graph proto.
*
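An aside on the mechanism this header documents: converting an initializer "into an OrtValue with an in-memory external data reference" means the TensorProto stops carrying its payload inline and instead records the address of a buffer that an OrtValue owns. A minimal stand-alone sketch of that idea, using hypothetical SimpleTensorProto/SimpleOrtValue stand-ins rather than the real ONNX/ORT types:

#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for ONNX TensorProto and ORT's OrtValue.
struct SimpleTensorProto {
  std::string name;
  std::vector<uint8_t> raw_data;       // inline payload
  const void* external_ptr = nullptr;  // "external data" location once converted
  size_t external_size = 0;
};

struct SimpleOrtValue {
  std::vector<uint8_t> buffer;  // owns the tensor bytes
};

// Move the payload out of the proto and leave behind a pointer-based reference.
SimpleOrtValue ConvertToOrtValue(SimpleTensorProto& proto) {
  SimpleOrtValue value;
  value.buffer = std::move(proto.raw_data);
  proto.raw_data.clear();
  proto.external_ptr = value.buffer.data();
  proto.external_size = value.buffer.size();
  return value;  // caller must keep this alive as long as the proto references it
}

The graph stores the OrtValue next to the proto precisely because the proto now points into memory the OrtValue owns.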
33 changes: 20 additions & 13 deletions onnxruntime/core/framework/session_state_utils.cc
@@ -140,12 +140,12 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st
std::move(tensor), ort_value);
}
} else {
// for internal initializer, always allocate memory on device - tensor
ORT_RETURN_IF_ERROR(AllocateTensor(memory_buffer, tensor, type, tensor_shape,
use_device_allocator_for_initializers, alloc));

if (device == default_cpu_device) {
// deserialize directly to CPU tensor
// Do not use arena for internal initializer, just like we do for OrtValue initializers
ORT_RETURN_IF_ERROR(AllocateTensorOnDeviceOrMemory(/* use_device_allocator_for_initializers =*/true,
tensor_shape, type,
default_cpu_alloc, tensor));
ORT_RETURN_IF_ERROR(utils::TensorProtoToTensor(env, proto_path.c_str(), tensor_proto, tensor));
Tensor::InitOrtValue(std::move(tensor), ort_value);
return common::Status::OK();
@@ -154,13 +154,19 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "string tensor is not supported for copying between allocators");
}

// Allocate on the device either from the planned memory buffer or directly with the
// device allocator, depending on use_device_allocator_for_initializers
ORT_RETURN_IF_ERROR(AllocateTensor(memory_buffer, tensor, type, tensor_shape,
use_device_allocator_for_initializers, alloc));

// deserialize to CPU first for non-CPU allocator, then copy
// for internal initializer
// 1. allocate memory on CPU - deserialized_tensor
// 2. deserialize tensor_proto into a preallocated tensor (deserialized_tensor)
// 1. allocate memory on CPU - deserialized_tensor. Do not use the arena, so temporary buffers do not waste arena space.
// 2. deserialize tensor_proto into a pre-allocated tensor (deserialized_tensor)
// 3. copy tensor from CPU to device - deserialized_tensor -> tensor (allocated above) -> ort_value
Tensor deserialized_tensor;
ORT_RETURN_IF_ERROR(AllocateTensorOnDeviceOrMemory(use_device_allocator_for_initializers, tensor_shape, type,
ORT_RETURN_IF_ERROR(AllocateTensorOnDeviceOrMemory(/* use_device_allocator_for_initializers =*/true,
tensor_shape, type,
default_cpu_alloc, deserialized_tensor));

ORT_RETURN_IF_ERROR(utils::TensorProtoToTensor(env, proto_path.c_str(), tensor_proto, deserialized_tensor));
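The three-step comment above describes a classic staging pattern: deserialize into a plain CPU buffer (no arena), copy to the device, then let the staging buffer die. A rough sketch of the pattern under those assumptions; DeviceBuffer, CopyToDevice and LoadInitializer are hypothetical stand-ins, not ORT APIs:

#include <cstddef>
#include <cstring>
#include <vector>

// Hypothetical device-side buffer; stands in for a Tensor on a GPU/NPU allocator.
struct DeviceBuffer {
  std::vector<unsigned char> bytes;  // pretend this memory lives on the device
};

void CopyToDevice(DeviceBuffer& dst, const void* src, size_t n) {
  dst.bytes.resize(n);
  std::memcpy(dst.bytes.data(), src, n);  // real code would use e.g. cudaMemcpy
}

// 1. stage on CPU, 2. "deserialize" into the staging buffer, 3. copy to device.
DeviceBuffer LoadInitializer(const unsigned char* serialized, size_t n) {
  std::vector<unsigned char> staging(n);       // plain CPU allocation, no arena
  std::memcpy(staging.data(), serialized, n);  // stands in for TensorProtoToTensor
  DeviceBuffer device;
  CopyToDevice(device, staging.data(), staging.size());
  return device;  // staging is freed on return, exactly the lifetime we want
}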
@@ -346,6 +352,13 @@ common::Status SaveInitializedTensors(
<< i.second << " bytes for " << i.first.ToString() << std::endl;
}

// ??? Should we ignore this session option if the EP is explicitly providing the read only allocator?
// bool have_readonly_initializer_allocator = alloc->Info().alloc_type == OrtReadOnlyAllocator;
// This option also means that the arena, if present, is ignored and Reserve() is used.
const bool use_device_allocator_for_initializers =
session_options.config_options.GetConfigOrDefault(
kOrtSessionOptionsUseDeviceAllocatorForInitializers, "0") == "1";

// 3. create weight tensors based on weights buffer
for (const auto& entry : id_to_initialized_tensor) {
// We check for cancellation for every initializer since mapping from disk can be costly
@@ -375,12 +388,6 @@ common::Status SaveInitializedTensors(
// TODO: if the tensor needs to be copied, does it have enough room?
ORT_RETURN_IF_ERROR(planner.GetPreallocatedBuffer(ort_value_index, name, memory_buffer, alloc));

// ??? Should we ignore this session option if the EP is explicitly providing the read only allocator?
// bool have_readonly_initializer_allocator = alloc->Info().alloc_type == OrtReadOnlyAllocator;
const bool use_device_allocator_for_initializers =
session_options.config_options.GetConfigOrDefault(
kOrtSessionOptionsUseDeviceAllocatorForInitializers, "0") == "1";

// Check if we already have an OrtValue for this initializer on CPU
if (OrtValue ort_value_from_graph;
graph.GetOrtValueInitializer(name, ort_value_from_graph)) {
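Note that this change also hoists the flag lookup out of the per-initializer loop: it is now read once per SaveInitializedTensors call. A generic sketch of that refactor; ConfigOptions below is a hypothetical stand-in, and the key string is illustrative rather than the exact value of kOrtSessionOptionsUseDeviceAllocatorForInitializers:

#include <map>
#include <string>

// Hypothetical stand-in for session_options.config_options.
struct ConfigOptions {
  std::map<std::string, std::string> entries;
  std::string GetConfigOrDefault(const std::string& key, const std::string& def) const {
    auto it = entries.find(key);
    return it == entries.end() ? def : it->second;
  }
};

void SaveTensors(const ConfigOptions& cfg, int initializer_count) {
  // Look the flag up once, not once per initializer.
  const bool use_device_allocator =
      cfg.GetConfigOrDefault("use_device_allocator_for_initializers", "0") == "1";
  for (int i = 0; i < initializer_count; ++i) {
    // ... allocate each initializer with or without the arena based on the flag ...
    (void)use_device_allocator;
  }
}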
4 changes: 2 additions & 2 deletions onnxruntime/core/framework/tensor.cc
@@ -93,14 +93,14 @@ Tensor::Tensor(MLDataType elt_type, const TensorShape& shape, std::shared_ptr<IA
if (len > 0) {
p_data = allocator->Alloc(len);
}
Init(elt_type, shape, p_data, allocator, 0L);
Init(elt_type, shape, p_data, std::move(allocator), 0L);
}

Tensor::Tensor(MLDataType elt_type, const TensorShape& shape, void* p_data, std::shared_ptr<IAllocator> deleter,
ptrdiff_t offset, gsl::span<const int64_t> strides)
: alloc_info_(deleter->Info()) {
ORT_ENFORCE(elt_type != nullptr);
Init(elt_type, shape, p_data, deleter, offset, strides);
Init(elt_type, shape, p_data, std::move(deleter), offset, strides);
}

void Tensor::InitOrtValue(MLDataType elt_type, const TensorShape& shape, std::shared_ptr<IAllocator> allocator,
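The std::move changes above are a small but real optimization: a std::shared_ptr parameter taken by value can be moved into its final destination, skipping one atomic ref-count increment/decrement pair. A self-contained illustration:

#include <cstdio>
#include <memory>

struct Allocator { int id = 42; };

// Taking the shared_ptr by value and moving it onward transfers ownership
// without touching the atomic reference count a second time.
void InitMoved(std::shared_ptr<Allocator> alloc) {
  std::shared_ptr<Allocator> stored = std::move(alloc);  // no ref-count bump
  std::printf("use_count after move: %ld\n", stored.use_count());  // prints 2
}

int main() {
  auto alloc = std::make_shared<Allocator>();
  InitMoved(alloc);  // one copy into the parameter, then moves all the way down
  return 0;
}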
76 changes: 43 additions & 33 deletions onnxruntime/core/graph/graph.cc
@@ -1231,6 +1231,28 @@ Graph::Graph(const Model& owning_model,
ArgNameToTypeMap name_to_type_map;
const auto& model_path = ModelPath();

// If the tensor proto data is large enough, move the data from the TensorProto into an OrtValue
// and add an in-memory external data reference to the TensorProto that points at the OrtValue.
// This lambda must not be used on initializers that already have an external data reference.
// For tensors below the size threshold it does nothing.
auto put_large_tensor_in_ort_value = [this, &model_path](ONNX_NAMESPACE::TensorProto& tensor_proto) {
size_t size_in_bytes = 0;
ORT_THROW_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &size_in_bytes));
if (size_in_bytes > utils::kSmallTensorExternalDataThreshold) {
OrtValue ort_value;
ORT_THROW_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), model_path, tensor_proto,
CPUAllocator::DefaultInstance(), ort_value));
constexpr const bool use_tensor_buffer_true = true;
auto tensor_proto_to_add = utils::TensorToTensorProto(ort_value.Get<Tensor>(), tensor_proto.name(),
use_tensor_buffer_true);
assert(ort_value.IsAllocated());
auto ins_result = ortvalue_initializers_.insert_or_assign(tensor_proto_to_add.name(), std::move(ort_value));
ORT_ENFORCE(ins_result.second, "Unexpected duplicate insert or assign OrtValue for tensor: ", tensor_proto_to_add.name(),
" in the initializer list.");
tensor_proto = std::move(tensor_proto_to_add);
}
};

// Process 'Constant' nodes
// Put the 'TensorProto' stored in each 'Constant' node's attribute into the graph's initializer list
for (auto& node : graph_proto_->node()) {
Expand All @@ -1250,6 +1272,8 @@ Graph::Graph(const Model& owning_model,
}
}

put_large_tensor_in_ort_value(*tensor);

// Ensure initializers are also graph inputs.
if (ir_version_ < 4) {
TypeProto t{utils::TypeProtoFromTensorProto(*tensor)};
@@ -1326,7 +1350,25 @@ Graph::Graph(const Model& owning_model,
}

// Copy initial tensors to a map.
for (auto& tensor : graph_proto_->initializer()) {
for (int i = 0, lim = graph_proto_->initializer_size(); i < lim; ++i) {
auto& tensor = *graph_proto_->mutable_initializer(i);
// If data is on disk, it will be loaded either by optimizers
// or during session state finalization.
// If data is already in memory, do nothing.
if (!utils::HasExternalData(tensor)) {
// sparse_tensor_names_ stores references to the initializers' name strings to save memory.
// Because we may replace the tensor_proto below, remove the old reference first
// and re-add it afterwards so it refers to the new string.
const bool is_sparse = sparse_tensor_names_.count(tensor.name());
if (is_sparse) {
sparse_tensor_names_.erase(tensor.name());
}
put_large_tensor_in_ort_value(tensor);
if (is_sparse) {
sparse_tensor_names_.emplace(tensor.name());
}
}

auto p = name_to_initial_tensor_.emplace(tensor.name(), &tensor);
if (!p.second) {
LOGS(logger_, WARNING) << "Duplicate initializer (dense, sparse or ConstantNode): '" << tensor.name()
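The erase/re-insert dance above exists because the set holds references into the initializer's name string, and replacing the tensor_proto can invalidate them. A self-contained illustration of the same hazard using std::string_view (the real member uses reference wrappers, so treat this as an analogy):

#include <string>
#include <string_view>
#include <unordered_set>

struct Proto { std::string name; std::string data; };

int main() {
  // Long name defeats small-string optimization, so the view targets heap memory.
  Proto tensor{"transformer.layer_0.attention.weight", "old-bytes"};
  // The set stores views into the proto's name to avoid copying the string.
  std::unordered_set<std::string_view> sparse_names;
  sparse_names.insert(tensor.name);

  // Replacing the proto may reallocate the name's heap buffer and leave a
  // dangling view in the set, so remove the old reference first and re-add it.
  const bool is_sparse = sparse_names.count(tensor.name) > 0;
  if (is_sparse) sparse_names.erase(tensor.name);
  tensor = Proto{"transformer.layer_0.attention.weight", "new-bytes"};
  if (is_sparse) sparse_names.insert(tensor.name);
  return 0;
}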
@@ -3415,38 +3457,6 @@ Status Graph::Resolve(const ResolveOptions& options) {
return ForThisAndAllSubgraphs(all_subgraphs, finalize_func);
}

Status Graph::ConvertInitializersIntoOrtValues() {
std::vector<Graph*> all_subgraphs;
FindAllSubgraphs(all_subgraphs);

auto put_weights_maybe_in_memory_func = [&](Graph& graph) -> Status {
// if we have any initializers that are not in memory, put them there.
const auto& model_path = graph.ModelPath();
auto& graph_proto = *graph.graph_proto_;
for (int i = 0, lim = graph_proto.initializer_size(); i < lim; ++i) {
auto& tensor_proto = *graph_proto.mutable_initializer(i);
if (utils::HasExternalData(tensor_proto)) {
continue; // ignore data on disk, that will be loaded either by EP or at session_state finalize
}

size_t size_in_bytes = 0;
ORT_RETURN_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &size_in_bytes));
if (size_in_bytes > utils::kSmallTensorExternalDataThreshold) {
OrtValue ort_value;
ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), model_path, tensor_proto,
CPUAllocator::DefaultInstance(), ort_value));
constexpr const bool use_tensor_buffer_true = true;
auto tensor_proto_to_add = utils::TensorToTensorProto(ort_value.Get<Tensor>(), tensor_proto.name(),
use_tensor_buffer_true);
ORT_RETURN_IF_ERROR(graph.ReplaceInitializedTensor(tensor_proto_to_add, ort_value));
}
}
return Status::OK();
};

return ForThisAndAllSubgraphs(all_subgraphs, put_weights_maybe_in_memory_func);
}

void Graph::SetName(const std::string& name) {
graph_proto_->set_name(name);
}
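The new lambda's ORT_ENFORCE leans on the return value of insert_or_assign: the bool is true only when a brand-new element was inserted, false when an existing one was overwritten. A quick refresher using std::map, which has the same contract as the container used here:

#include <cassert>
#include <map>
#include <string>

int main() {
  std::map<std::string, int> values;
  // insert_or_assign returns {iterator, bool}.
  auto first = values.insert_or_assign("weight", 1);
  assert(first.second);    // true: inserted a new entry
  auto second = values.insert_or_assign("weight", 2);
  assert(!second.second);  // false: assigned over the existing entry
  assert(values["weight"] == 2);
  return 0;
}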
12 changes: 6 additions & 6 deletions onnxruntime/core/graph/graph_utils.cc
@@ -285,19 +285,19 @@ NodeArg& AddInitializer(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_ini
return GetOrCreateNodeArg(graph, new_initializer);
}

NodeArg& AddInitializerWithExternalData(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer) {
NodeArg& AddInitializerWithOrtValue(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer) {
const bool has_external_data = utils::HasExternalData(new_initializer);
ORT_ENFORCE(!has_external_data, "Expecting an initializer that contains data inline");

Tensor tensor;
ORT_THROW_IF_ERROR(utils::CreateTensorFromTensorProto(Env::Default(), graph.ModelPath(),
new_initializer, tensor));
auto tensor_proto_with_ptr = utils::TensorToTensorProto(tensor, new_initializer.name(), true);
return AddInitializerWithExternalData(graph, tensor_proto_with_ptr, std::move(tensor));
return AddInitializerWithOrtValue(graph, tensor_proto_with_ptr, std::move(tensor));
}

NodeArg& AddInitializerWithExternalData(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer,
Tensor&& tensor) {
NodeArg& AddInitializerWithOrtValue(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer,
Tensor&& tensor) {
OrtValue ort_value;
if (utils::HasExternalDataInMemory(new_initializer)) {
Tensor::InitOrtValue(std::move(tensor), ort_value);
@@ -307,8 +307,8 @@ NodeArg& AddInitializerWithExternalData(Graph& graph, const ONNX_NAMESPACE::Tens
return GetOrCreateNodeArg(graph, new_initializer);
}

NodeArg& AddInitializerWithExternalData(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer,
OrtValue ort_value) {
NodeArg& AddInitializerWithOrtValue(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer,
OrtValue ort_value) {
ORT_THROW_IF_ERROR(graph.AddInitializedOrtValue(new_initializer, ort_value));
return GetOrCreateNodeArg(graph, new_initializer);
}
8 changes: 4 additions & 4 deletions onnxruntime/core/graph/graph_utils.h
@@ -45,8 +45,8 @@ NodeArg& AddInitializer(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_ini
/// <param name="new_initializer">TensorProto with external data contained in ort_value</param>
/// <param name="ort_value">ort_value with data</param>
/// <returns></returns>
NodeArg& AddInitializerWithExternalData(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer,
OrtValue ort_value);
NodeArg& AddInitializerWithOrtValue(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer,
OrtValue ort_value);

/** Add a new initializer to 'graph'.
* Checks that new_initializer does not already exist in 'graph' before adding it.
@@ -55,7 +55,7 @@ NodeArg& AddInitializerWithExternalData(Graph& graph, const ONNX_NAMESPACE::Tens
* @returns The NodeArg for the new initializer.
* @remarks No matching graph input is created, so the initializer will be constant.
*/
NodeArg& AddInitializerWithExternalData(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer, Tensor&& tensor);
NodeArg& AddInitializerWithOrtValue(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer, Tensor&& tensor);

/** Add a new initializer to 'graph'.
* The function unpacks data into a tensor and converts new_initializer to a TensorProto with external data in memory.
@@ -67,7 +67,7 @@ NodeArg& AddInitializerWithExternalData(Graph& graph, const ONNX_NAMESPACE::Tens
* @returns The NodeArg for the new initializer.
* @remarks No matching graph input is created, so the initializer will be constant.
*/
NodeArg& AddInitializerWithExternalData(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer);
NodeArg& AddInitializerWithOrtValue(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer);

/// <summary>
/// If the initializer with the given name does not exist in the destination graph, but exists in the
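A toy model of the three renamed overloads, to show which one a caller selects depending on what it already holds; the stub types below are placeholders, and only the overload shapes are taken from this diff:

#include <cstdio>
#include <utility>

struct Graph {};
struct TensorProto {};
struct Tensor {};
struct OrtValue {};
struct NodeArg {};

// Caller has only the proto; the function unpacks the inline data itself.
NodeArg& AddInitializerWithOrtValue(Graph&, const TensorProto&) {
  static NodeArg arg; std::puts("proto-only overload"); return arg;
}
// Caller already materialized a Tensor and hands over ownership.
NodeArg& AddInitializerWithOrtValue(Graph&, const TensorProto&, Tensor&&) {
  static NodeArg arg; std::puts("Tensor&& overload"); return arg;
}
// Caller already has an OrtValue; passed by value (cheap, ref-counted in ORT).
NodeArg& AddInitializerWithOrtValue(Graph&, const TensorProto&, OrtValue) {
  static NodeArg arg; std::puts("OrtValue overload"); return arg;
}

int main() {
  Graph g; TensorProto p; Tensor t; OrtValue v;
  AddInitializerWithOrtValue(g, p);
  AddInitializerWithOrtValue(g, p, std::move(t));
  AddInitializerWithOrtValue(g, p, v);
  return 0;
}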
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/attention_fusion.cc
@@ -111,7 +111,7 @@ static NodeArg& MergeQkvWeights(Graph& graph, int64_t hidden_size,
utils::SetRawDataInTensorProto(initializer, result.data(), gsl::narrow<size_t>(element_count) * sizeof(MLFloat16));
}

return graph_utils::AddInitializer(graph, initializer);
return graph_utils::AddInitializerWithOrtValue(graph, initializer);
}

static NodeArg* ConvertMaskToInt32(Graph& graph, NodeArg* mask_input, ProviderType provider_type,
@@ -189,7 +189,7 @@ NodeArg* CreateInitializerFromVector(Graph& graph,
"total_count: ", total_count, " values.size(): ", values.size());

utils::SetRawDataInTensorProto(const_tensor, values.data(), values.size() * sizeof(int64_t));
return &graph_utils::AddInitializer(graph, const_tensor);
return &graph_utils::AddInitializerWithOrtValue(graph, const_tensor);
}

NodeArg* InsertNodesForValidIndices(Graph& graph,
13 changes: 9 additions & 4 deletions onnxruntime/core/optimizer/constant_folding.cc
@@ -95,7 +95,7 @@ static bool ConstantFoldShapeNode(Graph& graph, Node& node) {
ONNX_NAMESPACE::TensorShapeProto result_shape;
result_shape.add_dim()->set_dim_value(clamped_slice_length);
constant_arg_out->SetShape(result_shape);
graph_utils::AddInitializer(graph, shape_constant);
graph_utils::AddInitializerWithOrtValue(graph, shape_constant);
}

return is_concrete_shape; // convert to constant if this is true
@@ -317,19 +317,24 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// Build the TensorProto that corresponds to the computed OrtValue and add it as initializer to the graph.
auto* constant_arg_out = node->MutableOutputDefs()[fetch_idx];
const Tensor& out_tensor = ort_value.Get<Tensor>();
constexpr const bool use_tensor_buffer_false = false;
constexpr const bool use_tensor_buffer_true = true;
ONNX_NAMESPACE::TensorProto out_tensorproto = utils::TensorToTensorProto(
out_tensor,
constant_arg_out->Name(),
use_tensor_buffer_false);
use_tensor_buffer_true);

ONNX_NAMESPACE::TensorShapeProto result_shape;
for (auto& dim : out_tensor.Shape().GetDims()) {
result_shape.add_dim()->set_dim_value(dim);
}

constant_arg_out->SetShape(result_shape);
graph.AddInitializedTensor(out_tensorproto);
// If the data was small it has been inlined into the proto; no backing OrtValue is needed.
if (!utils::HasExternalData(out_tensorproto)) {
ORT_THROW_IF_ERROR(graph.AddInitializedOrtValue(out_tensorproto, OrtValue()));
} else {
ORT_THROW_IF_ERROR(graph.AddInitializedOrtValue(out_tensorproto, ort_value));
}
}
}
}
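The folded result is now produced with use_tensor_buffer_true, so small outputs stay inline in the proto while large ones keep the backing OrtValue. A generic sketch of that size-based dispatch; the threshold value below is hypothetical (ORT's is utils::kSmallTensorExternalDataThreshold):

#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical threshold for "small enough to inline".
constexpr size_t kSmallTensorThreshold = 127;

struct FoldedInitializer {
  std::string name;
  std::vector<uint8_t> inline_bytes;           // populated for small results
  std::optional<std::vector<uint8_t>> buffer;  // populated for large results
};

FoldedInitializer StoreFoldedResult(std::string name, std::vector<uint8_t> bytes) {
  FoldedInitializer out{std::move(name), {}, std::nullopt};
  if (bytes.size() <= kSmallTensorThreshold) {
    out.inline_bytes = std::move(bytes);  // small: keep the data inline
  } else {
    out.buffer = std::move(bytes);        // large: keep a buffer the proto references
  }
  return out;
}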
4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/conv_add_fusion.cc
@@ -79,7 +79,7 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie
auto new_name = graph.GenerateNodeArgName("ConvAddFusion_B_" + B_input_name);
new_conv_B_tensor_proto.set_name(new_name);

NodeArg& new_conv_B_node_arg = graph_utils::AddInitializer(graph, new_conv_B_tensor_proto);
NodeArg& new_conv_B_node_arg = graph_utils::AddInitializerWithOrtValue(graph, new_conv_B_tensor_proto);
graph_utils::ReplaceNodeInput(node, 2, new_conv_B_node_arg);

} else {
@@ -94,7 +94,7 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie
auto new_name = graph.GenerateNodeArgName("ConvAddFusion_Add_B_" + add_B_tensor_proto->name());
new_conv_B_tensor_proto.set_name(new_name);

NodeArg& new_add_B_node_arg = graph_utils::AddInitializer(graph, new_conv_B_tensor_proto);
NodeArg& new_add_B_node_arg = graph_utils::AddInitializerWithOrtValue(graph, new_conv_B_tensor_proto);
graph_utils::AddNodeInput(node, 2, new_add_B_node_arg);
}

4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/conv_bn_fusion.cc
@@ -120,10 +120,10 @@ Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff
new_conv_W_tensor_proto.set_name(new_W_name);
new_conv_B_tensor_proto.set_name(new_B_name);

NodeArg& new_conv_W_node_arg = graph_utils::AddInitializer(graph, new_conv_W_tensor_proto);
NodeArg& new_conv_W_node_arg = graph_utils::AddInitializerWithOrtValue(graph, new_conv_W_tensor_proto);
graph_utils::ReplaceNodeInput(node, 1, new_conv_W_node_arg);

auto& new_conv_B_node_arg = graph_utils::AddInitializer(graph, new_conv_B_tensor_proto);
auto& new_conv_B_node_arg = graph_utils::AddInitializerWithOrtValue(graph, new_conv_B_tensor_proto);

if (conv_inputs.size() == 3) {
graph_utils::ReplaceNodeInput(node, 2, new_conv_B_node_arg);