diff --git a/include/cppflow/model.h b/include/cppflow/model.h index 368e145..4f19ce0 100644 --- a/include/cppflow/model.h +++ b/include/cppflow/model.h @@ -66,6 +66,7 @@ class model { explicit model(const std::string& filename, const TYPE type = TYPE::SAVED_MODEL); + explicit model(const void* data, size_t size); model(const model &model) = default; model(model &&model) = default; @@ -144,6 +145,40 @@ inline model::model(const std::string &filename, const TYPE type) { status_check(this->status.get()); } +inline model::model(const void* data, size_t size) +{ + this->status = { TF_NewStatus(), &TF_DeleteStatus }; + this->graph = { TF_NewGraph(), TF_DeleteGraph }; + + // Create the session. + std::unique_ptr + session_options = { TF_NewSessionOptions(), TF_DeleteSessionOptions }; + + auto session_deleter = [this](TF_Session* sess) { + TF_DeleteSession(sess, this->status.get()); + status_check(this->status.get()); + }; + + this->session = { TF_NewSession(this->graph.get(), session_options.get(), + this->status.get()), session_deleter }; + status_check(this->status.get()); + + // Import the graph definition + TF_Buffer* def = TF_NewBufferFromString(data, size); + if (def == nullptr) { + throw std::runtime_error("Failed to import graph def from input data"); + } + + std::unique_ptr graph_opts = { + TF_NewImportGraphDefOptions(), TF_DeleteImportGraphDefOptions }; + TF_GraphImportGraphDef(this->graph.get(), def, graph_opts.get(), + this->status.get()); + TF_DeleteBuffer(def); + + status_check(this->status.get()); +} + inline std::vector model::get_operations() const { std::vector result; size_t pos = 0; diff --git a/include/cppflow/tensor.h b/include/cppflow/tensor.h index 51cd940..33fc248 100644 --- a/include/cppflow/tensor.h +++ b/include/cppflow/tensor.h @@ -96,6 +96,18 @@ class tensor { tensor &operator=(const tensor &other) = default; tensor &operator=(tensor &&other) = default; + + /** + * Static method to create a tensor from a raw buffer + * @param data A pointer to the raw data used to initialize the tensor + * @param len The length of the buffer in bytes + * @param shape The shape of the requested tensor + * @param type The type contained in the data buffer + */ + static tensor create_from_raw_data(const void* data, size_t len, const std::vector& shape, enum TF_DataType type) { + return tensor(type, data, len, shape); + } + /** * @return Shape of the tensor */