From 40db6708afe493ab6984403d841a7d7c83ed7fba Mon Sep 17 00:00:00 2001 From: Alfredo Rodriguez Date: Thu, 15 Aug 2024 15:38:19 -0400 Subject: [PATCH 1/3] Added method to create tensor from a raw data buffer. --- include/cppflow/tensor.h | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/include/cppflow/tensor.h b/include/cppflow/tensor.h index 51cd940..33fc248 100644 --- a/include/cppflow/tensor.h +++ b/include/cppflow/tensor.h @@ -96,6 +96,18 @@ class tensor { tensor &operator=(const tensor &other) = default; tensor &operator=(tensor &&other) = default; + + /** + * Static method to create a tensor from a raw buffer + * @param data A pointer to the raw data used to initialize the tensor + * @param len The length of the buffer in bytes + * @param shape The shape of the requested tensor + * @param type The type contained in the data buffer + */ + static tensor create_from_raw_data(const void* data, size_t len, const std::vector& shape, enum TF_DataType type) { + return tensor(type, data, len, shape); + } + /** * @return Shape of the tensor */ From 57f8d4c3d2b05acd0375d3b0fe3e9af8a1dd11e3 Mon Sep 17 00:00:00 2001 From: Alfredo Rodriguez Date: Thu, 15 Aug 2024 17:32:23 -0400 Subject: [PATCH 2/3] Added method to build model from a buffer in memory containing the graph definition of a saved model. --- include/cppflow/model.h | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/include/cppflow/model.h b/include/cppflow/model.h index 368e145..9cce677 100644 --- a/include/cppflow/model.h +++ b/include/cppflow/model.h @@ -66,6 +66,7 @@ class model { explicit model(const std::string& filename, const TYPE type = TYPE::SAVED_MODEL); + explicit model(const void* data, size_t size); model(const model &model) = default; model(model &&model) = default; @@ -144,6 +145,34 @@ inline model::model(const std::string &filename, const TYPE type) { status_check(this->status.get()); } +inline model::model(const void* data, size_t size) +{ + this->graph = { TF_NewGraph(), TF_DeleteGraph }; + + // Create the session. + std::unique_ptr session_options = { TF_NewSessionOptions(), TF_DeleteSessionOptions }; + + auto session_deleter = [](TF_Session* sess) { + TF_DeleteSession(sess, context::get_status()); + status_check(context::get_status()); + }; + + this->session = { TF_NewSession(this->graph.get(), session_options.get(), context::get_status()), session_deleter }; + status_check(context::get_status()); + + // Import the graph definition + TF_Buffer* def = TF_NewBufferFromString(data, size); + if (def == nullptr) { + throw std::runtime_error("Failed to import graph def from input data"); + } + + std::unique_ptr graph_opts = { TF_NewImportGraphDefOptions(), TF_DeleteImportGraphDefOptions }; + TF_GraphImportGraphDef(this->graph.get(), def, graph_opts.get(), context::get_status()); + TF_DeleteBuffer(def); + + status_check(context::get_status()); +} + inline std::vector model::get_operations() const { std::vector result; size_t pos = 0; From f90de6053c1a4c5003808eec7c9ef86061a59cdb Mon Sep 17 00:00:00 2001 From: Alfredo Rodriguez Date: Fri, 16 Aug 2024 16:35:13 -0400 Subject: [PATCH 3/3] Fixed omission of initialization of status member of model. --- include/cppflow/model.h | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/include/cppflow/model.h b/include/cppflow/model.h index 9cce677..4f19ce0 100644 --- a/include/cppflow/model.h +++ b/include/cppflow/model.h @@ -147,18 +147,21 @@ inline model::model(const std::string &filename, const TYPE type) { inline model::model(const void* data, size_t size) { + this->status = { TF_NewStatus(), &TF_DeleteStatus }; this->graph = { TF_NewGraph(), TF_DeleteGraph }; // Create the session. - std::unique_ptr session_options = { TF_NewSessionOptions(), TF_DeleteSessionOptions }; + std::unique_ptr + session_options = { TF_NewSessionOptions(), TF_DeleteSessionOptions }; - auto session_deleter = [](TF_Session* sess) { - TF_DeleteSession(sess, context::get_status()); - status_check(context::get_status()); + auto session_deleter = [this](TF_Session* sess) { + TF_DeleteSession(sess, this->status.get()); + status_check(this->status.get()); }; - this->session = { TF_NewSession(this->graph.get(), session_options.get(), context::get_status()), session_deleter }; - status_check(context::get_status()); + this->session = { TF_NewSession(this->graph.get(), session_options.get(), + this->status.get()), session_deleter }; + status_check(this->status.get()); // Import the graph definition TF_Buffer* def = TF_NewBufferFromString(data, size); @@ -166,11 +169,14 @@ inline model::model(const void* data, size_t size) throw std::runtime_error("Failed to import graph def from input data"); } - std::unique_ptr graph_opts = { TF_NewImportGraphDefOptions(), TF_DeleteImportGraphDefOptions }; - TF_GraphImportGraphDef(this->graph.get(), def, graph_opts.get(), context::get_status()); + std::unique_ptr graph_opts = { + TF_NewImportGraphDefOptions(), TF_DeleteImportGraphDefOptions }; + TF_GraphImportGraphDef(this->graph.get(), def, graph_opts.get(), + this->status.get()); TF_DeleteBuffer(def); - status_check(context::get_status()); + status_check(this->status.get()); } inline std::vector model::get_operations() const {