Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 53 additions & 14 deletions tools/pnnx/src/load_torchscript.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -583,8 +583,10 @@ int load_torchscript(const std::string& ptpath, Graph& pnnx_graph,
const std::string& device,
const std::vector<std::vector<int64_t> >& input_shapes,
const std::vector<std::string>& input_types,
const std::vector<std::vector<char> >& input_contents,
const std::vector<std::vector<int64_t> >& input_shapes2,
const std::vector<std::string>& input_types2,
const std::vector<std::vector<char> >& input_contents2,
const std::vector<std::string>& customop_modules,
const std::vector<std::string>& module_operators,
const std::string& foldable_constants_zippath,
Expand Down Expand Up @@ -646,31 +648,68 @@ int load_torchscript(const std::string& ptpath, Graph& pnnx_graph,
}

std::vector<at::Tensor> input_tensors;
for (size_t i = 0; i < traced_input_shapes.size(); i++)
if (input_contents.size() != 0)
{
const std::vector<int64_t>& shape = traced_input_shapes[i];
const std::string& type = traced_input_types[i];
for (size_t i = 0; i < traced_input_shapes.size(); i++)
{
const std::vector<int64_t>& shape = traced_input_shapes[i];
const std::string& type = traced_input_types[i];

at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type));
if (device == "gpu")
t = t.cuda();
at::TensorOptions options(input_type_to_c10_ScalarType(type));
at::IntArrayRef shape2(shape);
at::Tensor t = torch::from_blob((void*)input_contents[i].data(), shape2, options);
if (device == "gpu")
t = t.cuda();

input_tensors.push_back(t);
input_tensors.push_back(t);
}
}
else
{
for (size_t i = 0; i < traced_input_shapes.size(); i++)
{
const std::vector<int64_t>& shape = traced_input_shapes[i];
const std::string& type = traced_input_types[i];

at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type));
if (device == "gpu")
t = t.cuda();

input_tensors.push_back(t);
}
}

std::vector<at::Tensor> input_tensors2;
for (size_t i = 0; i < input_shapes2.size(); i++)
if (input_contents2.size() != 0)
{
const std::vector<int64_t>& shape = input_shapes2[i];
const std::string& type = input_types2[i];
for (size_t i = 0; i < input_shapes2.size(); i++)
{
const std::vector<int64_t>& shape = input_shapes2[i];
const std::string& type = input_types2[i];

at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type));
if (device == "gpu")
t = t.cuda();
at::TensorOptions options(input_type_to_c10_ScalarType(type));
at::IntArrayRef shape2(shape);
at::Tensor t = torch::from_blob((void*)input_contents2[i].data(), shape2, options);
if (device == "gpu")
t = t.cuda();

input_tensors2.push_back(t);
input_tensors2.push_back(t);
}
}
else if (input_shapes2.size() != 0)
{
for (size_t i = 0; i < input_shapes2.size(); i++)
{
const std::vector<int64_t>& shape = input_shapes2[i];
const std::string& type = input_types2[i];

at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type));
if (device == "gpu")
t = t.cuda();

input_tensors2.push_back(t);
}
}
torch::jit::Module mod;

try
Expand Down
4 changes: 3 additions & 1 deletion tools/pnnx/src/load_torchscript.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@

namespace pnnx {

int load_torchscript(const std::string& ptpath, Graph& g,
int load_torchscript(const std::string& ptpath, Graph& pnnx_graph,
const std::string& device,
const std::vector<std::vector<int64_t> >& input_shapes,
const std::vector<std::string>& input_types,
const std::vector<std::vector<char> >& input_contents,
const std::vector<std::vector<int64_t> >& input_shapes2,
const std::vector<std::string>& input_types2,
const std::vector<std::vector<char> >& input_contents2,
const std::vector<std::string>& customop_modules,
const std::vector<std::string>& module_operators,
const std::string& foldable_constants_zippath,
Expand Down
65 changes: 62 additions & 3 deletions tools/pnnx/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include <string>
#include <vector>

#include "utils.h"

#if defined _WIN32
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
Expand Down Expand Up @@ -152,6 +154,40 @@ static void print_shape_list(const std::vector<std::vector<int64_t> >& shapes, c
}
}

static bool file_maybe_numpy(const std::string& path)
{
    // Returns true if the file at `path` starts with the numpy .npy magic
    // bytes "\x93NUMPY". Returns false (with a message on stderr) if the
    // file cannot be opened.
    FILE* fp = fopen(path.c_str(), "rb");
    if (!fp)
    {
        fprintf(stderr, "open failed %s\n", path.c_str());
        return false;
    }

    char signature[6];
    size_t nread = fread(signature, sizeof(char), 6, fp);

    fclose(fp);

    // A file shorter than the 6-byte magic cannot be a numpy file. Checking
    // fread's result also avoids comparing uninitialized buffer bytes
    // (the original ignored the return value).
    if (nread != 6)
        return false;

    // memcmp, not strcmp: the magic is fixed-length binary data starting
    // with a non-ASCII byte, so byte-wise comparison over exactly 6 bytes
    // is the correct check and needs no NUL terminator.
    return memcmp(signature, "\x93NUMPY", 6) == 0;
}

static void parse_numpy_file_list(char* s, std::vector<std::vector<int64_t> >& shapes, std::vector<std::string>& types, std::vector<std::vector<char> >& contents)
{
    // Splits the comma-separated list of .npy paths in `s`, validates each
    // file's numpy magic header, and appends each file's shape, type and
    // raw content to the output vectors via pnnx::parse_numpy_file.
    // Stops at the first invalid file, leaving any earlier results intact.
    std::vector<std::string> list;
    parse_string_list(s, list);

    // `path` instead of `s`: the original loop variable shadowed the
    // char* parameter of the same name.
    for (const std::string& path : list)
    {
        if (!file_maybe_numpy(path))
        {
            // trailing '\n' added so the stderr line is terminated
            fprintf(stderr, "%s is not a valid numpy file\n", path.c_str());
            return;
        }

        pnnx::parse_numpy_file(path.c_str(), shapes, types, contents);
    }
}

static bool model_file_maybe_torchscript(const std::string& path)
{
FILE* fp = fopen(path.c_str(), "rb");
Expand Down Expand Up @@ -213,6 +249,9 @@ static void show_usage()
fprintf(stderr, " device=cpu/gpu\n");
fprintf(stderr, " inputshape=[1,3,224,224],...\n");
fprintf(stderr, " inputshape2=[1,3,320,320],...\n");
fprintf(stderr, " input=file1.npy,file2.npy,...(conflict with inputshape)\n");
fprintf(stderr, " input2=file1.npy,file2.npy,...(conflict with inputshape2)\n");

#if _WIN32
fprintf(stderr, " customop=C:\\Users\\nihui\\AppData\\Local\\torch_extensions\\torch_extensions\\Cache\\fused\\fused.dll,...\n");
#else
Expand Down Expand Up @@ -260,8 +299,10 @@ int main(int argc, char** argv)
std::string device = "cpu";
std::vector<std::vector<int64_t> > input_shapes;
std::vector<std::string> input_types;
std::vector<std::vector<char> > input_contents;
std::vector<std::vector<int64_t> > input_shapes2;
std::vector<std::string> input_types2;
std::vector<std::vector<char> > input_contents2;
std::vector<std::string> customop_modules;
std::vector<std::string> module_operators;

Expand Down Expand Up @@ -310,6 +351,24 @@ int main(int argc, char** argv)
parse_string_list(value, customop_modules);
if (strcmp(key, "moduleop") == 0)
parse_string_list(value, module_operators);
if (strcmp(key, "input") == 0)
{
if (input_shapes.size() != 0)
{
fprintf(stderr, "parameter conflict: input and input_shape cannot be used at the same time.");
exit(1);
}
parse_numpy_file_list(value, input_shapes, input_types, input_contents);
}
if (strcmp(key, "input2") == 0)
{
if (input_shapes2.size() != 0)
{
fprintf(stderr, "parameter conflict: input2 and input_shape2 cannot be used at the same time.");
exit(1);
}
parse_numpy_file_list(value, input_shapes2, input_types2, input_contents2);
}
}

// print options
Expand Down Expand Up @@ -363,9 +422,9 @@ int main(int argc, char** argv)
else
#endif
{
load_torchscript(ptpath, pnnx_graph,
device, input_shapes, input_types,
input_shapes2, input_types2,
load_torchscript(ptpath, pnnx_graph, device,
input_shapes, input_types, input_contents,
input_shapes2, input_types2, input_contents2,
customop_modules, module_operators,
foldable_constants_zippath, foldable_constants);
}
Expand Down
Loading
Loading