diff --git a/README_zh-CN.md b/README_zh-CN.md
index fc1e68d..fc216de 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -185,6 +185,27 @@ python entry.py
> 当前版本已支持从numpy文件中读取initializer数据。点击“Open *.npy”按钮,在弹出的对话框中选择numpy文件,数据便会自动解析并呈现在上方的输入框中,也支持在读取的数据基础上进一步编辑。
+## 合并两个onnx模型
+
+1. 初始页面加载两个模型
+ 将两个模型文件拖动到页面,或在模型选择页面选择两个模型。等待加载完成。模型会并列显示。
+ 修改完成后点击`download`下载。
+
+2. 在操作页面进行合并
+ 点击load model mode:下方的下拉框,可以选择`new`,`tail`,`parallel`三种模式。
+ `new`会打开新的模型;`tail`会将新模型追加到上一个模型的末尾;`parallel`会将新模型和旧模型并列。
+ 拖动或点击按钮选择模型后会显示合并后的模型。
+ 修改完成后点击`download`下载。
+
+> **已知问题:**
+> 1. 只能复制已知节点,未知节点需要先修改`metadata.json`,即先变为已知节点。
+> 2. 调试模式下按下Alt,进入断点或出现网页跳转后,松开Alt,返回该网页后Alt会被认为一直是按下状态。需要重新按一次解除。
+> 3. 合并模型暂时不支持回退。
+
+## 选择多个连续的节点,并执行复制或删除操作
+
+按住`Alt`键,点击第一个节点,`Alt`不松手点击第二个节点,完成选择。按下`j`键并松开,实现复制。按下`l`键并松开,实现删除。松开`Alt`取消选中状态。
+
`onnx-modifer`正在活跃地更新中:hammer_and_wrench:。 欢迎使用,提issue,如果有帮助的话,感谢给个:star:~
diff --git a/onnx_modifier/flask_server.py b/onnx_modifier/flask_server.py
index dc54fb1..7d27fab 100644
--- a/onnx_modifier/flask_server.py
+++ b/onnx_modifier/flask_server.py
@@ -1,6 +1,7 @@
import argparse
import logging
-from flask import Flask, render_template, request
+import time
+from flask import Flask, render_template, request, send_file
from .onnx_modifier import onnxModifier
logging.basicConfig(level=logging.INFO)
@@ -22,6 +23,40 @@ def open_model():
return 'OK', 200
+
+@app.route('/merge_model', methods=['POST'])
+def merge_model():
+
+ onnx_file1 = request.files['file0']
+ onnx_file2 = request.files['file1']
+ timestamp = time.time()
+ global onnx_modifier
+ onnx_modifier, stream ,merged_name = onnxModifier.merge(
+ onnx_file1.filename, onnx_file1.stream,
+ onnx_file2.filename, onnx_file2.stream,
+ "", str(int(timestamp)) + "_")
+
+ return send_file(stream,
+ mimetype='application/octet-stream',
+ as_attachment=True,
+ download_name=merged_name)
+
+@app.route('/append_model', methods=['POST'])
+def append_model():
+ method = request.form.get('method')
+ onnx_file1 = request.files['file']
+ timestamp = time.time()
+ global onnx_modifier
+ if onnx_modifier:
+ onnx_modifier, stream ,merged_name = onnx_modifier.append(
+ onnx_file1.filename, onnx_file1.stream,
+ str(int(timestamp)) + "_", int(method))
+
+ return send_file(stream,
+ mimetype='application/octet-stream',
+ as_attachment=True,
+ download_name=merged_name)
+
@app.route('/download', methods=['POST'])
def modify_and_download_model():
modify_info = request.get_json()
diff --git a/onnx_modifier/onnx_modifier.py b/onnx_modifier/onnx_modifier.py
index bcd02b9..ac70523 100644
--- a/onnx_modifier/onnx_modifier.py
+++ b/onnx_modifier/onnx_modifier.py
@@ -3,6 +3,7 @@
# https://github.com/saurabh-shandilya/onnx-utils
# https://stackoverflow.com/questions/52402448/how-to-read-individual-layers-weight-bias-values-from-onnx-model
+from io import BytesIO
import os
import copy
import struct
@@ -39,13 +40,87 @@ def from_name_json_stream(cls, name, stream):
return cls(name, model_proto)
@classmethod
- def from_name_protobuf_stream(cls, name, stream):
+ def from_name_protobuf_stream(cls, name, stream, prefix = ""):
# https://leimao.github.io/blog/ONNX-IO-Stream/
logging.info("loading model...")
stream.seek(0)
model_proto = onnx.load_model(stream, "protobuf", load_external_data=False)
+ model_proto = onnx.compose.add_prefix(model_proto, prefix=prefix)
logging.info("load done!")
return cls(name, model_proto)
+
+ @classmethod
+ def merge(cls, name1, stream1, name2, stream2, prefix1, prefix2):
+ stream1.seek(0)
+ model_proto1 = onnx.load_model(stream1, "protobuf", load_external_data=False)
+ model_proto1 = onnx.compose.add_prefix(model_proto1, prefix=prefix1)
+
+ stream2.seek(0)
+ model_proto2 = onnx.load_model(stream2, "protobuf", load_external_data=False)
+ model_proto2 = onnx.compose.add_prefix(model_proto2, prefix=prefix2)
+
+ model_proto1.graph.input.extend(model_proto2.graph.input)
+ model_proto1.graph.node.extend(model_proto2.graph.node)
+ model_proto1.graph.initializer.extend(model_proto2.graph.initializer)
+ model_proto1.graph.output.extend(model_proto2.graph.output)
+
+ merged_name = name1.split('.')[0] + "_" + name2.split('.')[0] + ".onnx"
+ byte_stream = BytesIO()
+ onnx.save_model(model_proto1, byte_stream)
+ byte_stream.seek(0)
+ return cls(merged_name, model_proto1), byte_stream, merged_name
+
+
+ def find_next_node_by_input(self, model, input_name):
+ first_node = None
+ for node in model.graph.node:
+ if input_name in node.input:
+ first_node = node
+ break
+ return first_node
+
+ def find_previous_node_by_output(self, model, output_name):
+ last_node = None
+ for node in model.graph.node:
+ if output_name in node.output:
+ last_node = node
+ return last_node
+
+ def append(self, name, stream, prefix, method = 1):
+
+ stream.seek(0)
+ model_proto2 = onnx.load_model(stream, "protobuf", load_external_data=False)
+ model_proto2 = onnx.compose.add_prefix(model_proto2, prefix=prefix)
+
+ model_proto1 = self.model_proto
+ if method == 1:
+ model1_last_node = self.find_previous_node_by_output(model_proto1, model_proto1.graph.output[-1].name)
+ model2_first_node = self.find_next_node_by_input(model_proto2, model_proto2.graph.input[0].name)
+
+ model1_last_node_output_name = model1_last_node.output[0]
+ model2_first_node_input_name = model2_first_node.input[0]
+
+ model2_first_node.input.remove(model2_first_node_input_name)
+ model2_first_node.input.insert(0, model1_last_node_output_name)
+
+ del model_proto1.graph.output[-1]
+
+ if not self.find_next_node_by_input(model_proto2, model_proto2.graph.input[0].name):
+ del model_proto2.graph.input[0]
+
+ model_proto1.graph.input.extend(model_proto2.graph.input)
+ model_proto1.graph.node.extend(model_proto2.graph.node)
+ model_proto1.graph.initializer.extend(model_proto2.graph.initializer)
+ model_proto1.graph.output.extend(model_proto2.graph.output)
+
+ merged_name = self.model_name.split('.')[0] + "_" + name.split('.')[0] + ".onnx"
+
+ onnx_mdf = onnxModifier(merged_name, model_proto1)
+ byte_stream = BytesIO()
+ onnx.save_model(onnx_mdf.model_proto, byte_stream)
+ byte_stream.seek(0)
+
+ return onnx_mdf, byte_stream, merged_name
def reload(self):
self.model_proto = copy.deepcopy(self.model_proto_backup)
@@ -460,6 +535,7 @@ def modify(self, modify_info):
logging.debug("=== modify_info ===\n", modify_info)
self.add_nodes(modify_info['added_node_info'], modify_info['node_states'])
+ self.change_initializer(modify_info['added_tensor'])
self.change_initializer(modify_info['changed_initializer'])
self.change_node_io_name(modify_info['node_renamed_io'])
self.edit_inputs(modify_info['added_inputs'], modify_info['rebatch_info'])
diff --git a/onnx_modifier/static/index.js b/onnx_modifier/static/index.js
index 6b474b9..f048c17 100644
--- a/onnx_modifier/static/index.js
+++ b/onnx_modifier/static/index.js
@@ -116,7 +116,89 @@ host.BrowserHost = class {
});
}
+ // open the the model at flask server and reload
+ // open mode 0: open new model, 1: add to tail, 2: parallel
+ remote_open_model(obj_files, append_model = 0) {
+ var files = Array.from(obj_files);
+ const acceptd_file = files.filter(file => this._view.accept(file.name)).slice(0, 2);
+ // console.log(file)
+ if(acceptd_file.length == 1 && append_model == 0)
+ {
+ var file = acceptd_file[0];
+ this.upload_filename = file.name;
+ var form = new FormData();
+ form.append('file', file);
+
+ // https://stackoverflow.com/questions/66039996/javascript-fetch-upload-files-to-python-flask-restful
+ fetch('/open_model', {
+ method: 'POST',
+ body: form
+ }).then(function (response) {
+ return response.text();
+ }).then(function (text) {
+ console.log('POST response: ');
+ // Should be 'OK' if everything was successful
+ console.log(text);
+ });
+
+ if (file) {
+ this._open(file, files);
+ this._view.modifier.clearGraph();
+ }
+ } else if (acceptd_file.length == 2 || append_model != 0) {
+ var form = new FormData();
+ var url;
+ if (append_model == 0) {
+ url = '/merge_model';
+ for(var i = 0; i < acceptd_file.length; i++)
+ {
+ form.append('file' + i, acceptd_file[i]);
+ }
+ }
+ else if (append_model != 0) {
+ url = '/append_model';
+ form.append('file', acceptd_file[0]);
+ }
+ form.append('method', Number(append_model));
+ // console.log(file)
+ // this.upload_filename = file.name;
+ let filename = 'unknown';
+ // https://stackoverflow.com/questions/66039996/javascript-fetch-upload-files-to-python-flask-restful
+ this._view.showLoading();
+ fetch(url, {
+ method: 'POST',
+ body: form
+ }).then((response) =>{
+
+ const contentDisposition = response.headers.get('content-disposition');
+ if (contentDisposition) {
+ const matches = contentDisposition.match(/filename[^;=\n]*=((['"]).*?\2|[^;\n]*)/i);
+ if (matches && matches[1]) {
+ filename = decodeURIComponent(matches[1].replace(/['"]/g, ''));
+ }
+ }
+ if (!response.ok) {
+ this._view.hideLoading();
+ throw new Error(`HTTP error! status: ${response.status}`);
+ }
+ var blob = response.blob();
+ return blob;
+ }).then(blob => {
+ var file = new File([blob], filename);
+ // console.log('POST response: ');
+ // // Should be 'OK' if everything was successful
+ // console.log(text);
+ if (file) {
+ this.upload_filename = file.name;
+ files = [];
+ files.push(file);
+ this._open(file, files);
+ this._view.modifier.clearGraph();
+ }
+ });
+ }
+ }
start() {
this.window.addEventListener('error', (e) => {
@@ -328,40 +410,26 @@ host.BrowserHost = class {
});
openFileDialog.addEventListener('change', (e) => {
if (e.target && e.target.files && e.target.files.length > 0) {
- const files = Array.from(e.target.files);
- const file = files.find((file) => this._view.accept(file.name));
- // console.log(file)
- this.upload_filename = file.name;
- var form = new FormData();
- form.append('file', file);
-
- // https://stackoverflow.com/questions/66039996/javascript-fetch-upload-files-to-python-flask-restful
- fetch('/open_model', {
- method: 'POST',
- body: form
- }).then(function (response) {
- return response.text();
- }).then(function (text) {
- console.log('POST response: ');
- // Should be 'OK' if everything was successful
- console.log(text);
- });
-
-
- if (file) {
- this._open(file, files);
- this._view.modifier.clearGraph();
- }
+ this.remote_open_model(e.target.files, this.loadModelMode);
}
});
+
}
- const openModelButton = this.document.getElementById('load-model');
- if (openModelButton && openFileDialog) {
- openModelButton.addEventListener('click', () => {
- openFileDialog.value = '';
- openFileDialog.click();
+
+ var loadModelDropDown = this.document.getElementById('load-model-dropdown');
+ if (loadModelDropDown) {
+ loadModelDropDown.addEventListener('change', (e) => {
+ this.loadModelMode = loadModelDropDown.selectedIndex;
});
}
+
+ // const openModelButton = this.document.getElementById('load-model');
+ // if (openModelButton && openFileDialog) {
+ // openModelButton.addEventListener('click', () => {
+ // openFileDialog.value = '';
+ // openFileDialog.click();
+ // });
+ // }
const githubButton = this.document.getElementById('github-button');
const githubLink = this.document.getElementById('logo-github');
if (githubButton && githubLink) {
@@ -379,27 +447,8 @@ host.BrowserHost = class {
this.document.body.addEventListener('drop', (e) => {
e.preventDefault();
if (e.dataTransfer && e.dataTransfer.files && e.dataTransfer.files.length > 0) {
- const files = Array.from(e.dataTransfer.files);
- const file = files.find((file) => this._view.accept(file.name));
- this.upload_filename = file.name;
- var form = new FormData();
- form.append('file', file);
-
- // https://stackoverflow.com/questions/66039996/javascript-fetch-upload-files-to-python-flask-restful
- fetch('/open_model', {
- method: 'POST',
- body: form
- }).then(function (response) {
- return response.text();
- }).then(function (text) {
- console.log('POST response: ');
- // Should be 'OK' if everything was successful
- console.log(text);
- });
- if (file) {
- this._open(file, files);
- this._view.modifier.clearGraph();
- }
+ this.remote_open_model(e.dataTransfer.files, this.loadModelMode);
+
}
});
@@ -588,6 +637,7 @@ host.BrowserHost = class {
// 'modified_inputs_info' : this.arrayToObject(this.process_modified_inputs(this._view.modifier.inputModificationForSave,
// this._view.modifier.renameMap, this._view.modifier.name2NodeStates)),
'rebatch_info' : this.mapToObjectRec(this._view.modifier.reBatchInfo),
+ 'added_tensor' : this.mapToObjectRec(this._view.modifier.addedTensor),
'changed_initializer' : this.mapToObjectRec(this._view.modifier.initializerEditInfo),
'postprocess_args' : {'shapeInf' : this._view.modifier.downloadWithShapeInf, 'cleanUp' : this._view.modifier.downloadWithCleanUp}
})
diff --git a/onnx_modifier/static/modifier.js b/onnx_modifier/static/modifier.js
index 925c8ec..802f1b2 100644
--- a/onnx_modifier/static/modifier.js
+++ b/onnx_modifier/static/modifier.js
@@ -23,6 +23,8 @@ modifier.Modifier = class {
this.downloadWithShapeInf = false;
this.downloadWithCleanUp = false;
+ this.addedTensor = new Map();
+
}
loadModelGraph(model, graphs) {
@@ -42,6 +44,17 @@ modifier.Modifier = class {
this.name2NodeStatesOrig.set(name, 'Exist');
}
this.updateAddNodeDropDown();
+ this.updateLoadModelDropDown();
+ }
+
+ isOptionExists(selectElement, optionText) {
+
+ for (let i = 0; i < selectElement.options.length; i++) {
+ if (selectElement.options[i].text === optionText) {
+ return true;
+ }
+ }
+ return false;
}
// TODO: add filter feature like here: https://www.w3schools.com/howto/howto_js_dropdown.asp
@@ -50,12 +63,29 @@ modifier.Modifier = class {
var addNodeDropdown = this.view._host.document.getElementById('add-node-dropdown');
for (const node of this.model.supported_nodes) {
// node: [domain, op]
- var option = new Option(node[1], node[0] + ':' + node[1]);
- // console.log(option)
- addNodeDropdown.appendChild(option);
+ if (!this.isOptionExists(addNodeDropdown, node[1])) {
+ var option = new Option(node[1], node[0] + ':' + node[1]);
+ // console.log(option)
+ addNodeDropdown.appendChild(option);
+ }
}
}
+ updateLoadModelDropDown() {
+ // update dropdown supported node lost
+ var loadModelDropdown = this.view._host.document.getElementById('load-model-dropdown');
+ this.supported_load_mode = ["new", "tail", "parallel"]
+
+ if (loadModelDropdown.options.length == 0) {
+ for (let [index, mode] of this.supported_load_mode.entries()) {
+ var option = new Option(mode, index);
+ loadModelDropdown.appendChild(option);
+ }
+
+ }
+
+ }
+
getShapeTypeInfo(name) {
for (var value_info of this.graph._value_info) {
if (value_info.name == name && value_info.type && value_info.type.tensor_type) {
@@ -79,13 +109,13 @@ modifier.Modifier = class {
}
- try_get_node_name(op_type)
+ try_get_node_name(op_type, input_node_id)
{
- var node_id = (this.addNodeKey++).toString(); // in case input (onnx) node has no name
+ var node_id = (input_node_id || this.addNodeKey++).toString(); // in case input (onnx) node has no name
var modelNodeName = 'custom_added_' + op_type + node_id;
if (this.addedNode.has(modelNodeName) || this.name2NodeStates.get(modelNodeName) ){
- modelNodeName = this.randomString(16, 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ');
+ modelNodeName = try_get_node_name(op_type, Date.parse(new Date()));//this.randomString(16, 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ');
}
return modelNodeName;
}
@@ -103,6 +133,69 @@ modifier.Modifier = class {
this.applyAndUpdateView();
}
+ // duplicate a node with ( _cp + unique_id ) as param suffix
+ duplicateNode(node_name, unique_id = "") {
+ //avoid to add a existed name node
+ var srcModelNode = this.name2ModelNode.get(node_name);
+ if (!srcModelNode.type){
+ return;
+ }
+ var dstModelNodeName = this.try_get_node_name(srcModelNode.type.name);
+
+ var properties = new Map();
+ properties.set('domain', "ai.onnx");
+ properties.set('op_type', srcModelNode.type.name);
+ properties.set('name', dstModelNodeName);
+
+
+ var attributes = new Map();
+ for (const attribute of srcModelNode.attributes) {
+ attributes.set(attribute.name, [
+ attribute.value?attribute.value.toString():"undefined",
+ attribute.type||"undefined"])
+ // attributes.set(key, modelNode.attributes.get(key));
+ }
+
+ var outputs = new Map();
+ for (const output of srcModelNode.outputs) {
+ var dstNameList = [];
+ for (const srcArg of output.arguments) {
+ var dstName = srcArg.name + "_cp" + unique_id;
+ dstNameList.push([dstName, false]);
+ this.graph.copy_tensor(dstName, srcArg.name);
+
+ }
+ outputs.set(output.name, dstNameList);
+
+ }
+
+ var inputs = new Map();
+ for (const input of srcModelNode.inputs) {
+ var dstNameList = [];
+ for (const srcArg of input.arguments) {
+ var dstName = srcArg.name + "_cp" + unique_id;
+
+ if (this.graph._context._tensors &&
+ this.graph._context._tensors.has(srcArg.name)) {
+ this.graph.copy_tensor(dstName, srcArg.name);
+ var initializer_info = this.graph.get_initializer_info(dstName);
+ if(initializer_info) {
+ this.addedTensor.set(dstName, initializer_info);
+ }
+
+ }
+ dstNameList.push([dstName, false])
+ }
+ inputs.set(input.name, dstNameList)
+
+ }
+
+ this.addedNode.set(dstModelNodeName,
+ new view.LightNodeInfo(properties, attributes, inputs, outputs));
+ this.applyAndUpdateView();
+ }
+
+
addModelOutput(node_name) {
var modelNode = this.name2ModelNode.get(node_name);
// use a output argument as a proxy
@@ -402,6 +495,10 @@ modifier.Modifier = class {
for (const [modelNodeName, node_info] of this.addedNode) {
// console.log(node_info)
var node = this.graph.make_custom_added_node(node_info);
+ if (!node) {
+ console.log("node not supported yet");
+ continue;
+ }
// console.log(node)
for (const input of node.inputs) {
@@ -489,6 +586,7 @@ modifier.Modifier = class {
this.graph.reset_custom_added_node();
this.graph.reset_custom_modified_outputs();
this.graph.reset_custom_modified_inputs();
+ this.graph.reset_custom_added_tensors();
}
// reset load location
var container = this.view._getElementById('graph');
diff --git a/onnx_modifier/static/onnx.js b/onnx_modifier/static/onnx.js
index 06c98c9..8031214 100644
--- a/onnx_modifier/static/onnx.js
+++ b/onnx_modifier/static/onnx.js
@@ -465,6 +465,10 @@ onnx.Graph = class {
this._custom_added_inputs = []
this._custom_deleted_inputs = []
+ this._custom_added_tensors = new Set()
+ this._graph_config_max_input_count = 12;
+ this._graph_config_max_output_count = 12;
+
// model parameter assignment here!
// console.log(graph)
for (const initializer of graph.initializer) {
@@ -564,6 +568,34 @@ onnx.Graph = class {
return this._nodes.concat(this._custom_added_node);
}
+ // for duplicated nodes we should manage the tensors in initializers
+ // its behavior is a little different with editedInitializers
+
+ reset_custom_added_tensors() {
+ for (const tensor_name of this._custom_added_tensors) {
+ this._context._tensors.delete(tensor_name);
+ }
+ this._custom_added_tensors = new Set()
+
+ }
+
+ copy_tensor(dst_name, source_name) {
+ var tensor= this._context.tensor(source_name);
+ tensor.name = dst_name;
+ tensor.is_custom_added = 1;
+ this._context._tensors.set(dst_name, tensor);
+ this._custom_added_tensors.add(dst_name);
+
+ }
+
+ get_initializer_info(name) {
+ var initializer = this._context._tensors.get(name).initializer;
+ if (!initializer) return null;
+ return initializer.type?[initializer.type.toString().replace(/\s+/g, ''), initializer.value?initializer.toString(1).replace(/\s+/g, ''):""]:null
+ }
+
+ //
+
reset_custom_added_node() {
this._custom_added_node = []
// this._custom_add_node_io_idx = 0
@@ -576,6 +608,10 @@ onnx.Graph = class {
make_custom_added_node(node_info) {
// type of node_info == LightNodeInfo
const schema = this._context.metadata.type(node_info.properties.get('op_type'), node_info.properties.get('domain'));
+ if(!schema)
+ {
+ return null;
+ }
// console.log(schema)
// console.log(node_info.attributes)
@@ -583,8 +619,8 @@ onnx.Graph = class {
// console.log(node_info.outputs)
// var max_input = schema.max_input
// var min_input = schema.max_input
- var max_custom_add_input_num = Math.min(schema.max_input, 8) // set at most 8 custom_add inputs
- var max_custom_add_output_num = Math.min(schema.max_output, 8) // set at most 8 custom_add outputs
+ var max_custom_add_input_num = Math.min(schema.max_input, this._graph_config_max_input_count) // set at most 12 custom_add inputs
+ var max_custom_add_output_num = Math.min(schema.max_output, this._graph_config_max_output_count) // set at most 12 custom_add outputs
// console.log(node_info)
var inputs = []
@@ -605,7 +641,14 @@ onnx.Graph = class {
else {
var arg_name = 'list_custom_input_' + (this._custom_add_node_io_idx++).toString()
}
- arg_list.push(this._context.argument(arg_name))
+ var arg = this._context.argument(arg_name);
+ if (node_info_input && node_info_input[j] && node_info_input[j].length == 2) {
+ arg.is_optional = node_info_input[j][1];
+
+ } else if (input.option && input.option == 'optional') {
+ arg.is_optional = true;
+ }
+ arg_list.push(arg);
}
}
else {
@@ -615,14 +658,21 @@ onnx.Graph = class {
else {
var arg_name = 'custom_input_' + (this._custom_add_node_io_idx++).toString()
}
- arg_list = [this._context.argument(arg_name)]
+ var arg = this._context.argument(arg_name);
+ if (node_info_input && node_info_input[0] && node_info_input[0].length == 2) {
+ arg.is_optional = node_info_input[0][1];
+
+ } else if(input.option && input.option == 'optional') {
+ arg.is_optional = true;
+ }
+ arg_list = [arg]
}
for (var arg of arg_list) {
arg.is_custom_added = true;
- if (input.option && input.option == 'optional') {
- arg.is_optional = true;
- }
+ // if (input.option && input.option == 'optional') {
+ // arg.is_optional = true;
+ // }
}
inputs.push(new onnx.Parameter(input.name, arg_list));
}
@@ -643,6 +693,12 @@ onnx.Graph = class {
else {
var arg_name = 'list_custom_output_' + (this._custom_add_node_io_idx++).toString()
}
+ if (node_info_output && node_info_output[i] && node_info_output[i].length == 2) {
+ arg.is_optional = node_info_output[i][1];
+
+ } else if (output.option && output.option == 'optional') {
+ arg.is_optional = true;
+ }
arg_list.push(this._context.argument(arg_name))
}
}
@@ -653,15 +709,20 @@ onnx.Graph = class {
else {
var arg_name = 'custom_output_' + (this._custom_add_node_io_idx++).toString()
}
-
+ if (node_info_output && node_info_output[0] && node_info_output[0].length == 2) {
+ arg.is_optional = node_info_output[0][1];
+
+ } else if (output.option && output.option == 'optional') {
+ arg.is_optional = true;
+ }
arg_list = [this._context.argument(arg_name)]
}
for (var arg of arg_list) {
arg.is_custom_added = true;
- if (output.option && output.option == 'optional') {
- arg.is_optional = true;
- }
+ // if (output.option && output.option == 'optional') {
+ // arg.is_optional = true;
+ // }
}
outputs.push(new onnx.Parameter(output.name, arg_list));
}
@@ -1216,13 +1277,13 @@ onnx.Tensor = class {
return this._decode(context, 0);
}
- toString() {
+ toString(unlimit=0) {
const context = this._context();
// console.log(context)
if (context.state) {
return '';
}
- context.limit = 10000;
+ context.limit = (unlimit==0)?10000:100000000;
const value = this._decode(context, 0);
// console.log(value)
// console.log(onnx.Tensor._stringify(value, '', ' '))
@@ -1666,19 +1727,34 @@ onnx.Metadata = class {
}
constructor(data) {
+ this._order_list = ["Conv", "BatchNormalization", "LeakyRelu", "Concat", "Add", "UserDefined"]
this._map = new Map();
+ let disorder_maps = this._map;
if (data) {
const metadata = JSON.parse(data);
for (const item of metadata) {
- if (!this._map.has(item.module)) {
- this._map.set(item.module, new Map());
+ if (!disorder_maps.has(item.module)) {
+ disorder_maps.set(item.module, new Map());
}
- const map = this._map.get(item.module);
+ const map = disorder_maps.get(item.module);
if (!map.has(item.name)) {
map.set(item.name, []);
}
map.get(item.name).push(item);
}
+
+ // reorder the map, facilitate the addition of nodes
+ for(let [name, disorder_map] of disorder_maps) {
+ let ordered_map = new Map();
+ for (let order of this._order_list) {
+ if(disorder_map.has(order))
+ {
+ ordered_map.set(order, disorder_map.get(order));
+ disorder_map.delete(order);
+ }
+ }
+ this._map.set(name, new Map([...ordered_map, ...disorder_map]))
+ }
}
}
@@ -1895,6 +1971,7 @@ onnx.GraphContext = class {
}
argument(name, original_name) {
+ if(!original_name) original_name = name;
const tensor = this.tensor(name);
// console.log(tensor)
const type = tensor.initializer ? tensor.initializer.type : tensor.type || null;
diff --git a/onnx_modifier/static/view-grapher.css b/onnx_modifier/static/view-grapher.css
index 065bca0..7e6a75d 100644
--- a/onnx_modifier/static/view-grapher.css
+++ b/onnx_modifier/static/view-grapher.css
@@ -136,4 +136,36 @@
}
/* render selected nodes */
- .highlight path { stroke: #FF6347; stroke-width: 2px;}
\ No newline at end of file
+ .highlight path { stroke: #FF6347; stroke-width: 2px;}
+/* wait for merging the model, it may be extremely slow */
+
+.loading-overlay {
+ display: none; /* 默认隐藏 */
+ position: fixed;
+ top: 0;
+ left: 0;
+ width: 100%;
+ height: 100%;
+ background-color: rgba(0, 0, 0, 0.5);
+ justify-content: center;
+ align-items: center;
+ z-index: 9999;
+}
+
+
+.loading-spinner {
+ border: 8px solid #f3f3f3;
+ border-top: 8px solid #3498db;
+ border-radius: 50%;
+ width: 60px;
+ height: 60px;
+ animation: loading-spin 1s linear infinite;
+}
+
+@keyframes loading-spin {
+ 0% { transform: rotate(0deg); }
+ 100% { transform: rotate(360deg); }
+
+}
+
+.selected path { stroke: #FF6347; fill:#3498db; stroke-width: 2px;}
diff --git a/onnx_modifier/static/view-grapher.js b/onnx_modifier/static/view-grapher.js
index ebbe6f1..842ef05 100644
--- a/onnx_modifier/static/view-grapher.js
+++ b/onnx_modifier/static/view-grapher.js
@@ -131,6 +131,97 @@ grapher.Graph = class {
}
}
+ // highlight the selected nodes
+ removeHighlight() {
+ for (const [name,node] of this.modifier.name2ViewNode) {
+ if(node.element) {
+ node.element.classList.remove("selected");
+ }
+ }
+ }
+
+ setHighlight(sets) {
+ this.removeHighlight();
+ for(const set of sets) {
+ var node = this.modifier.name2ViewNode.get(set);
+ node.element.classList.add('selected');
+ }
+
+ }
+
+ // find nodes between the two selected nodes
+
+ getAllPathNodesRecursive(startNodeName, endNodeName, allPathNodeNames, currentPathNodeNames = [], deadPathNodeNames = [], depth = 0) {
+ var len = 0;
+ if (depth > 400) return len;
+ var direction_switcher = false; // startNode and endNode may be swapped
+ if(depth === 0 ) {
+ if(!this.modifier.namedEdges.get(startNodeName) ||
+ !this.modifier.name2ModelNode.get(endNodeName)) {
+ return len;
+ } else {
+ direction_switcher = true;
+ }
+
+ }
+
+ currentPathNodeNames.push(startNodeName);
+
+ if (startNodeName === endNodeName) {
+ len = currentPathNodeNames.length;
+ currentPathNodeNames.forEach(item => allPathNodeNames.add(item));
+ // allPathNodeNames = new Set([...allPathNodeNames, ...currentPathNodeNames]);
+ } else {
+
+ const children = new Set(this.modifier.namedEdges.get(startNodeName) || []);
+ var currentPathLen = currentPathNodeNames.length;
+ for (let child of children) {
+ if(allPathNodeNames.has(child))
+ {
+ currentPathNodeNames.forEach(item => allPathNodeNames.add(item));
+ continue;
+ }else if(deadPathNodeNames.includes(child))
+ {
+ continue;
+ }
+ while(currentPathNodeNames.length - currentPathLen > 0)
+ {
+ currentPathNodeNames.pop();
+ }
+ len += this.getAllPathNodesRecursive(child, endNodeName, allPathNodeNames,
+ currentPathNodeNames, deadPathNodeNames, depth + 1);
+
+ }
+ while(currentPathNodeNames.length - currentPathLen > 0)
+ {
+ currentPathNodeNames.pop();
+ }
+
+ if (len == 0) {
+ deadPathNodeNames.push(startNodeName);
+ }
+ }
+
+ if (direction_switcher && len == 0) {
+ deadPathNodeNames = [];
+ currentPathNodeNames = [];
+ len = this.getAllPathNodesRecursive(endNodeName, startNodeName, allPathNodeNames,
+ currentPathNodeNames, deadPathNodeNames, depth + 1);
+ }
+
+ return len;
+ }
+
+ getAllPathNodeNames(startNodeName, endNodeName)
+ {
+ var allPathNodeNames = new Set([startNodeName]);
+ this.getAllPathNodesRecursive(startNodeName, endNodeName, allPathNodeNames);
+ return allPathNodeNames;
+
+ }
+
+ //
+
build(document, origin) {
const createGroup = (name) => {
const element = document.createElementNS('http://www.w3.org/2000/svg', 'g');
diff --git a/onnx_modifier/static/view-sidebar.js b/onnx_modifier/static/view-sidebar.js
index 3575546..f2b5bdc 100644
--- a/onnx_modifier/static/view-sidebar.js
+++ b/onnx_modifier/static/view-sidebar.js
@@ -224,6 +224,8 @@ sidebar.NodeSidebar = class {
this._addHeader('Model Input Output editing helper');
+ this._addButton('Duplicate Node');
+ this.add_span();
this._addButton('Add Output');
this.add_span();
this._addButton('Add Input');
@@ -328,6 +330,12 @@ sidebar.NodeSidebar = class {
this._host._view.modifier.addModelOutput(this._modelNodeName);
});
}
+ if (title === 'Duplicate Node') {
+ buttonElement.addEventListener('click', () => {
+ var time_now = Date.parse(new Date())/1000;
+ this._host._view.modifier.duplicateNode(this._modelNodeName, time_now);
+ });
+ }
if (title === 'Add Input') {
buttonElement.addEventListener('click', () => {
// show dialog
@@ -965,6 +973,61 @@ sidebar.ArgumentView = class {
return this._renameAuxelements;
}
+ // just move numpy dataloader commands to a single funtion
+ add_np_dataloader(inputInitializerVal, inputInitializerType)
+ {
+ const editInitializerNumpyVal = this._host.document.createElement('div');
+ editInitializerNumpyVal.className = 'sidebar-view-item-value-line-border';
+ editInitializerNumpyVal.innerHTML = 'Or import from a *.npy file:';
+ this._element.appendChild(editInitializerNumpyVal);
+
+ const openFileButton_ = this._host.document.createElement('button');
+ openFileButton_.setAttribute("display", "none");
+ openFileButton_.innerHTML = "Open *.npy"
+ const openFileDialog_ = this._host.document.createElement('input');
+ openFileDialog_.setAttribute("type", "file");
+
+ openFileButton_.addEventListener('click', () => {
+ openFileDialog_.value = '';
+ openFileDialog_.click();
+ });
+ var orig_arg_name = this._host._view.modifier.getOriginalName(this._param_type, this._modelNodeName, this._param_index, this._arg_index);
+ openFileDialog_.addEventListener('change', (e) => {
+ if (e.target && e.target.files && e.target.files.length > 0) {
+ var reader = new FileReader();
+ var context = this;
+ reader.onload = function() {
+ var npLoader = new npyjs.Npyjs();
+ npLoader.load(reader.result, (out) => {
+ // `array` is a one-dimensional array of the raw data
+ // `shape` is a one-dimensional array that holds a numpy-style shape.
+ // console.log(
+ // `You loaded an array with ${out.shape} \nelements: ${out.data}.`
+ // );
+ var fmt_tensor = npLoader.format_np(out.data, out.shape);
+ var dataType = context._argument.type?context._argument.type._dataType:`${out.dtype}[${out.shape.toString()}]`
+ context._host._view.modifier.changeInitializer(context._modelNodeName, context._parameterName, context._param_type, context._param_index,
+ context._arg_index, dataType, fmt_tensor);
+ // [type, value]
+
+ var initializerEditInfo = context._host._view.modifier.initializerEditInfo.get(orig_arg_name)
+ if (initializerEditInfo) {
+ // [type, value]
+ inputInitializerVal.innerHTML = initializerEditInfo[1];
+ if(inputInitializerType) {
+ inputInitializerType.innerHTML = initializerEditInfo[0];
+ }
+ inputInitializerVal.setAttribute("tab-size", '10px');
+ }
+
+ });
+ };
+ reader.readAsArrayBuffer(e.target.files[0]);
+ }
+ });
+ this._element.appendChild(openFileButton_);
+ }
+
toggle() {
if (this._expander) {
if (this._expander.innerText == '+') {
@@ -1022,7 +1085,7 @@ sidebar.ArgumentView = class {
this._element.appendChild(location);
}
- if (initializer) {
+ if (initializer && !this._argument.is_custom_added) {
const editInitializerVal = this._host.document.createElement('div');
editInitializerVal.className = 'sidebar-view-item-value-line-border';
editInitializerVal.innerHTML = 'This is an initializer, you can input a new value for it here:';
@@ -1046,46 +1109,7 @@ sidebar.ArgumentView = class {
});
this._element.appendChild(inputInitializerVal);
- const editInitializerNumpyVal = this._host.document.createElement('div');
- editInitializerNumpyVal.className = 'sidebar-view-item-value-line-border';
- editInitializerNumpyVal.innerHTML = 'Or import from a *.npy file:';
- this._element.appendChild(editInitializerNumpyVal);
-
- const openFileButton_ = this._host.document.createElement('button');
- openFileButton_.setAttribute("display", "none");
- openFileButton_.innerHTML = "Open *.npy"
- const openFileDialog_ = this._host.document.createElement('input');
- openFileDialog_.setAttribute("type", "file");
-
- openFileButton_.addEventListener('click', () => {
- openFileDialog_.value = '';
- openFileDialog_.click();
- });
-
- openFileDialog_.addEventListener('change', (e) => {
- if (e.target && e.target.files && e.target.files.length > 0) {
- var reader = new FileReader();
- var context = this;
- reader.onload = function() {
- var npLoader = new npyjs.Npyjs();
- npLoader.load(reader.result, (out) => {
- // `array` is a one-dimensional array of the raw data
- // `shape` is a one-dimensional array that holds a numpy-style shape.
- // console.log(
- // `You loaded an array with ${out.shape} \nelements: ${out.data}.`
- // );
- var fmt_tensor = npLoader.format_np(out.data, out.shape);
- context._host._view.modifier.changeInitializer(context._modelNodeName, context._parameterName, context._param_type, context._param_index,
- context._arg_index, context._argument.type._dataType, fmt_tensor);
- // [type, value]
- inputInitializerVal.innerHTML = context._host._view.modifier.initializerEditInfo.get(orig_arg_name)[1];
- inputInitializerVal.setAttribute("tab-size", '10px');
- });
- };
- reader.readAsArrayBuffer(e.target.files[0]);
- }
- });
- this._element.appendChild(openFileButton_);
+ this.add_np_dataloader(inputInitializerVal, orig_arg_name)
// this._element.appendChild(openFileDialog_);
}
@@ -1142,6 +1166,8 @@ sidebar.ArgumentView = class {
this._modelNodeName, this._parameterName, this._param_type, this._param_index, this._arg_index, init_type, init_val);
});
this._element.appendChild(inputInitializerType);
+
+ this.add_np_dataloader(inputInitializerVal, inputInitializerType);
// <====== input type ======
}
diff --git a/onnx_modifier/static/view.js b/onnx_modifier/static/view.js
index 8b38c08..3e400a2 100644
--- a/onnx_modifier/static/view.js
+++ b/onnx_modifier/static/view.js
@@ -26,6 +26,11 @@ view.View = class {
direction: 'vertical',
mousewheel: 'scroll'
};
+
+ this.isAltKeyDown = false;
+ this.selectedFisrtNodeName = null;
+ this.selectedSecondNodeName = null;
+
this.lastScrollLeft = 0;
this.lastScrollTop = 0;
this._zoom = 1;
@@ -59,6 +64,42 @@ view.View = class {
container.addEventListener('scroll', (e) => this._scrollHandler(e));
container.addEventListener('wheel', (e) => this._wheelHandler(e), { passive: false });
container.addEventListener('mousedown', (e) => this._mouseDownHandler(e));
+
+ // for select nodes
+
+ container.addEventListener('keydown', (e) => {
+
+ if (e.key === 'Alt' && !this.isAltKeyDown) {
+ this.isAltKeyDown = true;
+ }
+ });
+ container.addEventListener('keyup', (e) => {
+
+ if (e.key === 'Alt' && this.isAltKeyDown) {
+ this.isAltKeyDown = false;
+ this.selectedFisrtNodeName = null;
+ this.selectedSecondNodeName = null;
+ this._graph.removeHighlight();
+ }
+ if (this.isAltKeyDown && e.key.toLowerCase() === 'j') {
+ if(this.isAltKeyDown && this.selectedFisrtNodeName) {
+ var time_now = Date.parse(new Date()) / 1000;
+ var sets = this._graph.getAllPathNodeNames(this.selectedFisrtNodeName, this.selectedSecondNodeName);
+ for (const name of sets) {
+ this.modifier.duplicateNode(name, time_now);
+ }
+ }
+ } else if (this.isAltKeyDown && e.key.toLowerCase() === 'l') {
+
+ if(this.isAltKeyDown && this.selectedFisrtNodeName) {
+ var sets = this._graph.getAllPathNodeNames(this.selectedFisrtNodeName, this.selectedSecondNodeName);
+ for (const name of sets) {
+ this._host._view.modifier.deleteSingleNode(name);
+ }
+ }
+ }
+ });
+
switch (this._host.agent) {
case 'safari':
container.addEventListener('gesturestart', (e) => this._gestureStartHandler(e), false);
@@ -74,7 +115,36 @@ view.View = class {
});
}
+ highlight(node, input, modelNodeName) {
+ if (node) {
+ if(this.isAltKeyDown) {
+ if(this.selectedFisrtNodeName)
+ {
+ this.selectedSecondNodeName = modelNodeName;
+ var sets = this._graph.getAllPathNodeNames(this.selectedFisrtNodeName, this.selectedSecondNodeName);
+ this._graph.setHighlight(sets.add(this.selectedFisrtNodeName))
+ } else {
+ this.selectedFisrtNodeName = modelNodeName;
+ this.selectedSecondNodeName = modelNodeName;
+ var sets = new Set([this.selectedFisrtNodeName]);
+ this._graph.setHighlight(sets);
+ }
+ }
+
+ }
+
+ }
+
+ showLoading() {
+ this._host.document.getElementById('loading').style.display = 'flex';
+ }
+
+ hideLoading() {
+ this._host.document.getElementById('loading').style.display = 'none';
+ }
+
show(page) {
+ this.hideLoading();
if (!page) {
page = (!this._model && !this.activeGraph) ? 'welcome' : 'default';
}
@@ -779,7 +849,7 @@ view.View = class {
}
showNodeProperties(node, input, modelNodeName) {
- if (node) {
+ if (node && !this.isAltKeyDown) {
try {
// console.log(node) // 注意是onnx.Node, 不是grapher.Node,所以没有update(), 没有element元素
const nodeSidebar = new sidebar.NodeSidebar(this._host, node, modelNodeName);
@@ -1077,6 +1147,7 @@ view.Node = class extends grapher.Node {
const tooltip = this.context.view.options.names && (node.name || node.location) ? type.name : (node.name || node.location);
const title = header.add(null, styles, content, tooltip);
title.on('click', () => this.context.view.showNodeProperties(node, null, this.modelNodeName));
+ title.on('click', () => this.context.view.highlight(node, null, this.modelNodeName));
if (node.type.nodes && node.type.nodes.length > 0) {
const definition = header.add(null, styles, '\u0192', 'Show Function Definition');
definition.on('click', () => this.context.view.pushGraph(node.type));
@@ -1112,6 +1183,7 @@ view.Node = class extends grapher.Node {
if (initializers.length > 0 || hiddenInitializers || sortedAttributes.length > 0) {
const list = this.list();
list.on('click', () => this.context.view.showNodeProperties(node, null, this.modelNodeName));
+ list.on('click', () => this.context.view.highlight(node, null, this.modelNodeName));
for (const initializer of initializers) {
const argument = initializer.arguments[0];
const type = argument.type;
diff --git a/onnx_modifier/templates/index.html b/onnx_modifier/templates/index.html
index 5eca38f..abd36b1 100644
--- a/onnx_modifier/templates/index.html
+++ b/onnx_modifier/templates/index.html
@@ -124,6 +124,20 @@
left: 2px;
top: 135px;
}
+.load-model-text {
+ font-size: 15px;
+ position: absolute;
+ border: none;
+ left: 2px;
+ top: 180px;
+ font-family:"Oliviar Sans Light";
+}
+.op-load-model-dropdown {
+ font-size: 15px;
+ position:absolute;
+ left: 2px;
+ top: 210px;
+}
.graph-add-input-dropdown {
font-size: 15px;
}
@@ -391,6 +405,9 @@