Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions include/cppflow/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class model {

explicit model(const std::string& filename,
const TYPE type = TYPE::SAVED_MODEL);
explicit model(const void* data, size_t size);
model(const model &model) = default;
model(model &&model) = default;

Expand Down Expand Up @@ -144,6 +145,40 @@ inline model::model(const std::string &filename, const TYPE type) {
status_check(this->status.get());
}

inline model::model(const void* data, size_t size)
{
this->status = { TF_NewStatus(), &TF_DeleteStatus };
this->graph = { TF_NewGraph(), TF_DeleteGraph };

// Create the session.
std::unique_ptr<TF_SessionOptions, decltype(&TF_DeleteSessionOptions)>
session_options = { TF_NewSessionOptions(), TF_DeleteSessionOptions };

auto session_deleter = [this](TF_Session* sess) {
TF_DeleteSession(sess, this->status.get());
status_check(this->status.get());
};

this->session = { TF_NewSession(this->graph.get(), session_options.get(),
this->status.get()), session_deleter };
status_check(this->status.get());

// Import the graph definition
TF_Buffer* def = TF_NewBufferFromString(data, size);
if (def == nullptr) {
throw std::runtime_error("Failed to import graph def from input data");
}

std::unique_ptr<TF_ImportGraphDefOptions,
decltype(&TF_DeleteImportGraphDefOptions)> graph_opts = {
TF_NewImportGraphDefOptions(), TF_DeleteImportGraphDefOptions };
TF_GraphImportGraphDef(this->graph.get(), def, graph_opts.get(),
this->status.get());
TF_DeleteBuffer(def);

status_check(this->status.get());
}

inline std::vector<std::string> model::get_operations() const {
std::vector<std::string> result;
size_t pos = 0;
Expand Down
12 changes: 12 additions & 0 deletions include/cppflow/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,18 @@ class tensor {
tensor &operator=(const tensor &other) = default;
tensor &operator=(tensor &&other) = default;


/**
* Static method to create a tensor from a raw buffer
* @param data A pointer to the raw data used to initialize the tensor
* @param len The length of the buffer in bytes
* @param shape The shape of the requested tensor
* @param type The type contained in the data buffer
*/
static tensor create_from_raw_data(const void* data, size_t len, const std::vector<int64_t>& shape, enum TF_DataType type) {
return tensor(type, data, len, shape);
}

/**
* @return Shape of the tensor
*/
Expand Down