diff --git a/README_zh-CN.md b/README_zh-CN.md index fc1e68d..fc216de 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -185,6 +185,27 @@ python entry.py > 当前版本已支持从numpy文件中读取initializer数据。点击“Open *.npy”按钮,在弹出的对话框中选择numpy文件,数据便会自动解析并呈现在上方的输入框中,也支持在读取的数据基础上进一步编辑。 +## 合并两个onnx模型 + +1. 初始页面加载两个模型 + 将两个模型文件拖动到页面,或在模型选择页面选择两个模型。等待加载完成。模型会并列显示。 + 修改完成后点击`download`下载。 + +2. 在操作页面进行合并 + 点击load model mode:下方的下拉框,可以选择`new`,`tail`,`parallel`三种模式。 + `new`会打开新的模型;`tail`会将新模型追加到上一个模型的末尾;`parallel`会将新模型和旧模型并列。 + 拖动或点击按钮选择模型后会显示合并后的模型。 + 修改完成后点击`download`下载。 + +> **已知问题:** +> 1. 只能复制已知节点,未知节点需要先修改`metadata.json`,即先变为已知节点。 +> 2. 调试模式下按下Alt,进入断点或出现网页跳转后,松开Alt,返回该网页后Alt会被认为一直是按下状态。需要重新按一次解除。 +> 3. 合并模型暂时不支持回退。 + +## 选择多个连续的节点,并执行复制或删除操作 + +按住`Alt`键,点击第一个节点,`Alt`不松手点击第二个节点,完成选择。按下`j`键并松开,实现复制。按下`l`键并松开,实现删除。松开`Alt`取消选中状态。 + `onnx-modifer`正在活跃地更新中:hammer_and_wrench:。 欢迎使用,提issue,如果有帮助的话,感谢给个:star:~ diff --git a/onnx_modifier/flask_server.py b/onnx_modifier/flask_server.py index dc54fb1..7d27fab 100644 --- a/onnx_modifier/flask_server.py +++ b/onnx_modifier/flask_server.py @@ -1,6 +1,7 @@ import argparse import logging -from flask import Flask, render_template, request +import time +from flask import Flask, render_template, request, send_file from .onnx_modifier import onnxModifier logging.basicConfig(level=logging.INFO) @@ -22,6 +23,40 @@ def open_model(): return 'OK', 200 + +@app.route('/merge_model', methods=['POST']) +def merge_model(): + + onnx_file1 = request.files['file0'] + onnx_file2 = request.files['file1'] + timestamp = time.time() + global onnx_modifier + onnx_modifier, stream ,merged_name = onnxModifier.merge( + onnx_file1.filename, onnx_file1.stream, + onnx_file2.filename, onnx_file2.stream, + "", str(int(timestamp)) + "_") + + return send_file(stream, + mimetype='application/octet-stream', + as_attachment=True, + download_name=merged_name) + +@app.route('/append_model', methods=['POST']) +def append_model(): + method = request.form.get('method') + onnx_file1 = request.files['file'] + timestamp = time.time() + global onnx_modifier + if onnx_modifier: + onnx_modifier, stream ,merged_name = onnx_modifier.append( + onnx_file1.filename, onnx_file1.stream, + str(int(timestamp)) + "_", int(method)) + + return send_file(stream, + mimetype='application/octet-stream', + as_attachment=True, + download_name=merged_name) + @app.route('/download', methods=['POST']) def modify_and_download_model(): modify_info = request.get_json() diff --git a/onnx_modifier/onnx_modifier.py b/onnx_modifier/onnx_modifier.py index bcd02b9..ac70523 100644 --- a/onnx_modifier/onnx_modifier.py +++ b/onnx_modifier/onnx_modifier.py @@ -3,6 +3,7 @@ # https://github.com/saurabh-shandilya/onnx-utils # https://stackoverflow.com/questions/52402448/how-to-read-individual-layers-weight-bias-values-from-onnx-model +from io import BytesIO import os import copy import struct @@ -39,13 +40,87 @@ def from_name_json_stream(cls, name, stream): return cls(name, model_proto) @classmethod - def from_name_protobuf_stream(cls, name, stream): + def from_name_protobuf_stream(cls, name, stream, prefix = ""): # https://leimao.github.io/blog/ONNX-IO-Stream/ logging.info("loading model...") stream.seek(0) model_proto = onnx.load_model(stream, "protobuf", load_external_data=False) + model_proto = onnx.compose.add_prefix(model_proto, prefix=prefix) logging.info("load done!") return cls(name, model_proto) + + @classmethod + def merge(cls, name1, stream1, name2, stream2, prefix1, prefix2): + stream1.seek(0) + model_proto1 = onnx.load_model(stream1, "protobuf", load_external_data=False) + model_proto1 = onnx.compose.add_prefix(model_proto1, prefix=prefix1) + + stream2.seek(0) + model_proto2 = onnx.load_model(stream2, "protobuf", load_external_data=False) + model_proto2 = onnx.compose.add_prefix(model_proto2, prefix=prefix2) + + model_proto1.graph.input.extend(model_proto2.graph.input) + model_proto1.graph.node.extend(model_proto2.graph.node) + model_proto1.graph.initializer.extend(model_proto2.graph.initializer) + model_proto1.graph.output.extend(model_proto2.graph.output) + + merged_name = name1.split('.')[0] + "_" + name2.split('.')[0] + ".onnx" + byte_stream = BytesIO() + onnx.save_model(model_proto1, byte_stream) + byte_stream.seek(0) + return cls(merged_name, model_proto1), byte_stream, merged_name + + + def find_next_node_by_input(self, model, input_name): + first_node = None + for node in model.graph.node: + if input_name in node.input: + first_node = node + break + return first_node + + def find_previous_node_by_output(self, model, output_name): + last_node = None + for node in model.graph.node: + if output_name in node.output: + last_node = node + return last_node + + def append(self, name, stream, prefix, method = 1): + + stream.seek(0) + model_proto2 = onnx.load_model(stream, "protobuf", load_external_data=False) + model_proto2 = onnx.compose.add_prefix(model_proto2, prefix=prefix) + + model_proto1 = self.model_proto + if method == 1: + model1_last_node = self.find_previous_node_by_output(model_proto1, model_proto1.graph.output[-1].name) + model2_first_node = self.find_next_node_by_input(model_proto2, model_proto2.graph.input[0].name) + + model1_last_node_output_name = model1_last_node.output[0] + model2_first_node_input_name = model2_first_node.input[0] + + model2_first_node.input.remove(model2_first_node_input_name) + model2_first_node.input.insert(0, model1_last_node_output_name) + + del model_proto1.graph.output[-1] + + if not self.find_next_node_by_input(model_proto2, model_proto2.graph.input[0].name): + del model_proto2.graph.input[0] + + model_proto1.graph.input.extend(model_proto2.graph.input) + model_proto1.graph.node.extend(model_proto2.graph.node) + model_proto1.graph.initializer.extend(model_proto2.graph.initializer) + model_proto1.graph.output.extend(model_proto2.graph.output) + + merged_name = self.model_name.split('.')[0] + "_" + name.split('.')[0] + ".onnx" + + onnx_mdf = onnxModifier(merged_name, model_proto1) + byte_stream = BytesIO() + onnx.save_model(onnx_mdf.model_proto, byte_stream) + byte_stream.seek(0) + + return onnx_mdf, byte_stream, merged_name def reload(self): self.model_proto = copy.deepcopy(self.model_proto_backup) @@ -460,6 +535,7 @@ def modify(self, modify_info): logging.debug("=== modify_info ===\n", modify_info) self.add_nodes(modify_info['added_node_info'], modify_info['node_states']) + self.change_initializer(modify_info['added_tensor']) self.change_initializer(modify_info['changed_initializer']) self.change_node_io_name(modify_info['node_renamed_io']) self.edit_inputs(modify_info['added_inputs'], modify_info['rebatch_info']) diff --git a/onnx_modifier/static/index.js b/onnx_modifier/static/index.js index 6b474b9..f048c17 100644 --- a/onnx_modifier/static/index.js +++ b/onnx_modifier/static/index.js @@ -116,7 +116,89 @@ host.BrowserHost = class { }); } + // open the the model at flask server and reload + // open mode 0: open new model, 1: add to tail, 2: parallel + remote_open_model(obj_files, append_model = 0) { + var files = Array.from(obj_files); + const acceptd_file = files.filter(file => this._view.accept(file.name)).slice(0, 2); + // console.log(file) + if(acceptd_file.length == 1 && append_model == 0) + { + var file = acceptd_file[0]; + this.upload_filename = file.name; + var form = new FormData(); + form.append('file', file); + + // https://stackoverflow.com/questions/66039996/javascript-fetch-upload-files-to-python-flask-restful + fetch('/open_model', { + method: 'POST', + body: form + }).then(function (response) { + return response.text(); + }).then(function (text) { + console.log('POST response: '); + // Should be 'OK' if everything was successful + console.log(text); + }); + + if (file) { + this._open(file, files); + this._view.modifier.clearGraph(); + } + } else if (acceptd_file.length == 2 || append_model != 0) { + var form = new FormData(); + var url; + if (append_model == 0) { + url = '/merge_model'; + for(var i = 0; i < acceptd_file.length; i++) + { + form.append('file' + i, acceptd_file[i]); + } + } + else if (append_model != 0) { + url = '/append_model'; + form.append('file', acceptd_file[0]); + } + form.append('method', Number(append_model)); + // console.log(file) + // this.upload_filename = file.name; + let filename = 'unknown'; + // https://stackoverflow.com/questions/66039996/javascript-fetch-upload-files-to-python-flask-restful + this._view.showLoading(); + fetch(url, { + method: 'POST', + body: form + }).then((response) =>{ + + const contentDisposition = response.headers.get('content-disposition'); + if (contentDisposition) { + const matches = contentDisposition.match(/filename[^;=\n]*=((['"]).*?\2|[^;\n]*)/i); + if (matches && matches[1]) { + filename = decodeURIComponent(matches[1].replace(/['"]/g, '')); + } + } + if (!response.ok) { + this._view.hideLoading(); + throw new Error(`HTTP error! status: ${response.status}`); + } + var blob = response.blob(); + return blob; + }).then(blob => { + var file = new File([blob], filename); + // console.log('POST response: '); + // // Should be 'OK' if everything was successful + // console.log(text); + if (file) { + this.upload_filename = file.name; + files = []; + files.push(file); + this._open(file, files); + this._view.modifier.clearGraph(); + } + }); + } + } start() { this.window.addEventListener('error', (e) => { @@ -328,40 +410,26 @@ host.BrowserHost = class { }); openFileDialog.addEventListener('change', (e) => { if (e.target && e.target.files && e.target.files.length > 0) { - const files = Array.from(e.target.files); - const file = files.find((file) => this._view.accept(file.name)); - // console.log(file) - this.upload_filename = file.name; - var form = new FormData(); - form.append('file', file); - - // https://stackoverflow.com/questions/66039996/javascript-fetch-upload-files-to-python-flask-restful - fetch('/open_model', { - method: 'POST', - body: form - }).then(function (response) { - return response.text(); - }).then(function (text) { - console.log('POST response: '); - // Should be 'OK' if everything was successful - console.log(text); - }); - - - if (file) { - this._open(file, files); - this._view.modifier.clearGraph(); - } + this.remote_open_model(e.target.files, this.loadModelMode); } }); + } - const openModelButton = this.document.getElementById('load-model'); - if (openModelButton && openFileDialog) { - openModelButton.addEventListener('click', () => { - openFileDialog.value = ''; - openFileDialog.click(); + + var loadModelDropDown = this.document.getElementById('load-model-dropdown'); + if (loadModelDropDown) { + loadModelDropDown.addEventListener('change', (e) => { + this.loadModelMode = loadModelDropDown.selectedIndex; }); } + + // const openModelButton = this.document.getElementById('load-model'); + // if (openModelButton && openFileDialog) { + // openModelButton.addEventListener('click', () => { + // openFileDialog.value = ''; + // openFileDialog.click(); + // }); + // } const githubButton = this.document.getElementById('github-button'); const githubLink = this.document.getElementById('logo-github'); if (githubButton && githubLink) { @@ -379,27 +447,8 @@ host.BrowserHost = class { this.document.body.addEventListener('drop', (e) => { e.preventDefault(); if (e.dataTransfer && e.dataTransfer.files && e.dataTransfer.files.length > 0) { - const files = Array.from(e.dataTransfer.files); - const file = files.find((file) => this._view.accept(file.name)); - this.upload_filename = file.name; - var form = new FormData(); - form.append('file', file); - - // https://stackoverflow.com/questions/66039996/javascript-fetch-upload-files-to-python-flask-restful - fetch('/open_model', { - method: 'POST', - body: form - }).then(function (response) { - return response.text(); - }).then(function (text) { - console.log('POST response: '); - // Should be 'OK' if everything was successful - console.log(text); - }); - if (file) { - this._open(file, files); - this._view.modifier.clearGraph(); - } + this.remote_open_model(e.dataTransfer.files, this.loadModelMode); + } }); @@ -588,6 +637,7 @@ host.BrowserHost = class { // 'modified_inputs_info' : this.arrayToObject(this.process_modified_inputs(this._view.modifier.inputModificationForSave, // this._view.modifier.renameMap, this._view.modifier.name2NodeStates)), 'rebatch_info' : this.mapToObjectRec(this._view.modifier.reBatchInfo), + 'added_tensor' : this.mapToObjectRec(this._view.modifier.addedTensor), 'changed_initializer' : this.mapToObjectRec(this._view.modifier.initializerEditInfo), 'postprocess_args' : {'shapeInf' : this._view.modifier.downloadWithShapeInf, 'cleanUp' : this._view.modifier.downloadWithCleanUp} }) diff --git a/onnx_modifier/static/modifier.js b/onnx_modifier/static/modifier.js index 925c8ec..802f1b2 100644 --- a/onnx_modifier/static/modifier.js +++ b/onnx_modifier/static/modifier.js @@ -23,6 +23,8 @@ modifier.Modifier = class { this.downloadWithShapeInf = false; this.downloadWithCleanUp = false; + this.addedTensor = new Map(); + } loadModelGraph(model, graphs) { @@ -42,6 +44,17 @@ modifier.Modifier = class { this.name2NodeStatesOrig.set(name, 'Exist'); } this.updateAddNodeDropDown(); + this.updateLoadModelDropDown(); + } + + isOptionExists(selectElement, optionText) { + + for (let i = 0; i < selectElement.options.length; i++) { + if (selectElement.options[i].text === optionText) { + return true; + } + } + return false; } // TODO: add filter feature like here: https://www.w3schools.com/howto/howto_js_dropdown.asp @@ -50,12 +63,29 @@ modifier.Modifier = class { var addNodeDropdown = this.view._host.document.getElementById('add-node-dropdown'); for (const node of this.model.supported_nodes) { // node: [domain, op] - var option = new Option(node[1], node[0] + ':' + node[1]); - // console.log(option) - addNodeDropdown.appendChild(option); + if (!this.isOptionExists(addNodeDropdown, node[1])) { + var option = new Option(node[1], node[0] + ':' + node[1]); + // console.log(option) + addNodeDropdown.appendChild(option); + } } } + updateLoadModelDropDown() { + // update dropdown supported node lost + var loadModelDropdown = this.view._host.document.getElementById('load-model-dropdown'); + this.supported_load_mode = ["new", "tail", "parallel"] + + if (loadModelDropdown.options.length == 0) { + for (let [index, mode] of this.supported_load_mode.entries()) { + var option = new Option(mode, index); + loadModelDropdown.appendChild(option); + } + + } + + } + getShapeTypeInfo(name) { for (var value_info of this.graph._value_info) { if (value_info.name == name && value_info.type && value_info.type.tensor_type) { @@ -79,13 +109,13 @@ modifier.Modifier = class { } - try_get_node_name(op_type) + try_get_node_name(op_type, input_node_id) { - var node_id = (this.addNodeKey++).toString(); // in case input (onnx) node has no name + var node_id = (input_node_id || this.addNodeKey++).toString(); // in case input (onnx) node has no name var modelNodeName = 'custom_added_' + op_type + node_id; if (this.addedNode.has(modelNodeName) || this.name2NodeStates.get(modelNodeName) ){ - modelNodeName = this.randomString(16, 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'); + modelNodeName = try_get_node_name(op_type, Date.parse(new Date()));//this.randomString(16, 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'); } return modelNodeName; } @@ -103,6 +133,69 @@ modifier.Modifier = class { this.applyAndUpdateView(); } + // duplicate a node with ( _cp + unique_id ) as param suffix + duplicateNode(node_name, unique_id = "") { + //avoid to add a existed name node + var srcModelNode = this.name2ModelNode.get(node_name); + if (!srcModelNode.type){ + return; + } + var dstModelNodeName = this.try_get_node_name(srcModelNode.type.name); + + var properties = new Map(); + properties.set('domain', "ai.onnx"); + properties.set('op_type', srcModelNode.type.name); + properties.set('name', dstModelNodeName); + + + var attributes = new Map(); + for (const attribute of srcModelNode.attributes) { + attributes.set(attribute.name, [ + attribute.value?attribute.value.toString():"undefined", + attribute.type||"undefined"]) + // attributes.set(key, modelNode.attributes.get(key)); + } + + var outputs = new Map(); + for (const output of srcModelNode.outputs) { + var dstNameList = []; + for (const srcArg of output.arguments) { + var dstName = srcArg.name + "_cp" + unique_id; + dstNameList.push([dstName, false]); + this.graph.copy_tensor(dstName, srcArg.name); + + } + outputs.set(output.name, dstNameList); + + } + + var inputs = new Map(); + for (const input of srcModelNode.inputs) { + var dstNameList = []; + for (const srcArg of input.arguments) { + var dstName = srcArg.name + "_cp" + unique_id; + + if (this.graph._context._tensors && + this.graph._context._tensors.has(srcArg.name)) { + this.graph.copy_tensor(dstName, srcArg.name); + var initializer_info = this.graph.get_initializer_info(dstName); + if(initializer_info) { + this.addedTensor.set(dstName, initializer_info); + } + + } + dstNameList.push([dstName, false]) + } + inputs.set(input.name, dstNameList) + + } + + this.addedNode.set(dstModelNodeName, + new view.LightNodeInfo(properties, attributes, inputs, outputs)); + this.applyAndUpdateView(); + } + + addModelOutput(node_name) { var modelNode = this.name2ModelNode.get(node_name); // use a output argument as a proxy @@ -402,6 +495,10 @@ modifier.Modifier = class { for (const [modelNodeName, node_info] of this.addedNode) { // console.log(node_info) var node = this.graph.make_custom_added_node(node_info); + if (!node) { + console.log("node not supported yet"); + continue; + } // console.log(node) for (const input of node.inputs) { @@ -489,6 +586,7 @@ modifier.Modifier = class { this.graph.reset_custom_added_node(); this.graph.reset_custom_modified_outputs(); this.graph.reset_custom_modified_inputs(); + this.graph.reset_custom_added_tensors(); } // reset load location var container = this.view._getElementById('graph'); diff --git a/onnx_modifier/static/onnx.js b/onnx_modifier/static/onnx.js index 06c98c9..8031214 100644 --- a/onnx_modifier/static/onnx.js +++ b/onnx_modifier/static/onnx.js @@ -465,6 +465,10 @@ onnx.Graph = class { this._custom_added_inputs = [] this._custom_deleted_inputs = [] + this._custom_added_tensors = new Set() + this._graph_config_max_input_count = 12; + this._graph_config_max_output_count = 12; + // model parameter assignment here! // console.log(graph) for (const initializer of graph.initializer) { @@ -564,6 +568,34 @@ onnx.Graph = class { return this._nodes.concat(this._custom_added_node); } + // for duplicated nodes we should manage the tensors in initializers + // its behavior is a little different with editedInitializers + + reset_custom_added_tensors() { + for (const tensor_name of this._custom_added_tensors) { + this._context._tensors.delete(tensor_name); + } + this._custom_added_tensors = new Set() + + } + + copy_tensor(dst_name, source_name) { + var tensor= this._context.tensor(source_name); + tensor.name = dst_name; + tensor.is_custom_added = 1; + this._context._tensors.set(dst_name, tensor); + this._custom_added_tensors.add(dst_name); + + } + + get_initializer_info(name) { + var initializer = this._context._tensors.get(name).initializer; + if (!initializer) return null; + return initializer.type?[initializer.type.toString().replace(/\s+/g, ''), initializer.value?initializer.toString(1).replace(/\s+/g, ''):""]:null + } + + // + reset_custom_added_node() { this._custom_added_node = [] // this._custom_add_node_io_idx = 0 @@ -576,6 +608,10 @@ onnx.Graph = class { make_custom_added_node(node_info) { // type of node_info == LightNodeInfo const schema = this._context.metadata.type(node_info.properties.get('op_type'), node_info.properties.get('domain')); + if(!schema) + { + return null; + } // console.log(schema) // console.log(node_info.attributes) @@ -583,8 +619,8 @@ onnx.Graph = class { // console.log(node_info.outputs) // var max_input = schema.max_input // var min_input = schema.max_input - var max_custom_add_input_num = Math.min(schema.max_input, 8) // set at most 8 custom_add inputs - var max_custom_add_output_num = Math.min(schema.max_output, 8) // set at most 8 custom_add outputs + var max_custom_add_input_num = Math.min(schema.max_input, this._graph_config_max_input_count) // set at most 12 custom_add inputs + var max_custom_add_output_num = Math.min(schema.max_output, this._graph_config_max_output_count) // set at most 12 custom_add outputs // console.log(node_info) var inputs = [] @@ -605,7 +641,14 @@ onnx.Graph = class { else { var arg_name = 'list_custom_input_' + (this._custom_add_node_io_idx++).toString() } - arg_list.push(this._context.argument(arg_name)) + var arg = this._context.argument(arg_name); + if (node_info_input && node_info_input[j] && node_info_input[j].length == 2) { + arg.is_optional = node_info_input[j][1]; + + } else if (input.option && input.option == 'optional') { + arg.is_optional = true; + } + arg_list.push(arg); } } else { @@ -615,14 +658,21 @@ onnx.Graph = class { else { var arg_name = 'custom_input_' + (this._custom_add_node_io_idx++).toString() } - arg_list = [this._context.argument(arg_name)] + var arg = this._context.argument(arg_name); + if (node_info_input && node_info_input[0] && node_info_input[0].length == 2) { + arg.is_optional = node_info_input[0][1]; + + } else if(input.option && input.option == 'optional') { + arg.is_optional = true; + } + arg_list = [arg] } for (var arg of arg_list) { arg.is_custom_added = true; - if (input.option && input.option == 'optional') { - arg.is_optional = true; - } + // if (input.option && input.option == 'optional') { + // arg.is_optional = true; + // } } inputs.push(new onnx.Parameter(input.name, arg_list)); } @@ -643,6 +693,12 @@ onnx.Graph = class { else { var arg_name = 'list_custom_output_' + (this._custom_add_node_io_idx++).toString() } + if (node_info_output && node_info_output[i] && node_info_output[i].length == 2) { + arg.is_optional = node_info_output[i][1]; + + } else if (output.option && output.option == 'optional') { + arg.is_optional = true; + } arg_list.push(this._context.argument(arg_name)) } } @@ -653,15 +709,20 @@ onnx.Graph = class { else { var arg_name = 'custom_output_' + (this._custom_add_node_io_idx++).toString() } - + if (node_info_output && node_info_output[0] && node_info_output[0].length == 2) { + arg.is_optional = node_info_output[0][1]; + + } else if (output.option && output.option == 'optional') { + arg.is_optional = true; + } arg_list = [this._context.argument(arg_name)] } for (var arg of arg_list) { arg.is_custom_added = true; - if (output.option && output.option == 'optional') { - arg.is_optional = true; - } + // if (output.option && output.option == 'optional') { + // arg.is_optional = true; + // } } outputs.push(new onnx.Parameter(output.name, arg_list)); } @@ -1216,13 +1277,13 @@ onnx.Tensor = class { return this._decode(context, 0); } - toString() { + toString(unlimit=0) { const context = this._context(); // console.log(context) if (context.state) { return ''; } - context.limit = 10000; + context.limit = (unlimit==0)?10000:100000000; const value = this._decode(context, 0); // console.log(value) // console.log(onnx.Tensor._stringify(value, '', ' ')) @@ -1666,19 +1727,34 @@ onnx.Metadata = class { } constructor(data) { + this._order_list = ["Conv", "BatchNormalization", "LeakyRelu", "Concat", "Add", "UserDefined"] this._map = new Map(); + let disorder_maps = this._map; if (data) { const metadata = JSON.parse(data); for (const item of metadata) { - if (!this._map.has(item.module)) { - this._map.set(item.module, new Map()); + if (!disorder_maps.has(item.module)) { + disorder_maps.set(item.module, new Map()); } - const map = this._map.get(item.module); + const map = disorder_maps.get(item.module); if (!map.has(item.name)) { map.set(item.name, []); } map.get(item.name).push(item); } + + // reorder the map, facilitate the addition of nodes + for(let [name, disorder_map] of disorder_maps) { + let ordered_map = new Map(); + for (let order of this._order_list) { + if(disorder_map.has(order)) + { + ordered_map.set(order, disorder_map.get(order)); + disorder_map.delete(order); + } + } + this._map.set(name, new Map([...ordered_map, ...disorder_map])) + } } } @@ -1895,6 +1971,7 @@ onnx.GraphContext = class { } argument(name, original_name) { + if(!original_name) original_name = name; const tensor = this.tensor(name); // console.log(tensor) const type = tensor.initializer ? tensor.initializer.type : tensor.type || null; diff --git a/onnx_modifier/static/view-grapher.css b/onnx_modifier/static/view-grapher.css index 065bca0..7e6a75d 100644 --- a/onnx_modifier/static/view-grapher.css +++ b/onnx_modifier/static/view-grapher.css @@ -136,4 +136,36 @@ } /* render selected nodes */ - .highlight path { stroke: #FF6347; stroke-width: 2px;} \ No newline at end of file + .highlight path { stroke: #FF6347; stroke-width: 2px;} +/* wait for merging the model, it may be extremely slow */ + +.loading-overlay { + display: none; /* 默认隐藏 */ + position: fixed; + top: 0; + left: 0; + width: 100%; + height: 100%; + background-color: rgba(0, 0, 0, 0.5); + justify-content: center; + align-items: center; + z-index: 9999; +} + + +.loading-spinner { + border: 8px solid #f3f3f3; + border-top: 8px solid #3498db; + border-radius: 50%; + width: 60px; + height: 60px; + animation: loading-spin 1s linear infinite; +} + +@keyframes loading-spin { + 0% { transform: rotate(0deg); } + 100% { transform: rotate(360deg); } + +} + +.selected path { stroke: #FF6347; fill:#3498db; stroke-width: 2px;} diff --git a/onnx_modifier/static/view-grapher.js b/onnx_modifier/static/view-grapher.js index ebbe6f1..842ef05 100644 --- a/onnx_modifier/static/view-grapher.js +++ b/onnx_modifier/static/view-grapher.js @@ -131,6 +131,97 @@ grapher.Graph = class { } } + // highlight the selected nodes + removeHighlight() { + for (const [name,node] of this.modifier.name2ViewNode) { + if(node.element) { + node.element.classList.remove("selected"); + } + } + } + + setHighlight(sets) { + this.removeHighlight(); + for(const set of sets) { + var node = this.modifier.name2ViewNode.get(set); + node.element.classList.add('selected'); + } + + } + + // find nodes between the two selected nodes + + getAllPathNodesRecursive(startNodeName, endNodeName, allPathNodeNames, currentPathNodeNames = [], deadPathNodeNames = [], depth = 0) { + var len = 0; + if (depth > 400) return len; + var direction_switcher = false; // startNode and endNode may be swapped + if(depth === 0 ) { + if(!this.modifier.namedEdges.get(startNodeName) || + !this.modifier.name2ModelNode.get(endNodeName)) { + return len; + } else { + direction_switcher = true; + } + + } + + currentPathNodeNames.push(startNodeName); + + if (startNodeName === endNodeName) { + len = currentPathNodeNames.length; + currentPathNodeNames.forEach(item => allPathNodeNames.add(item)); + // allPathNodeNames = new Set([...allPathNodeNames, ...currentPathNodeNames]); + } else { + + const children = new Set(this.modifier.namedEdges.get(startNodeName) || []); + var currentPathLen = currentPathNodeNames.length; + for (let child of children) { + if(allPathNodeNames.has(child)) + { + currentPathNodeNames.forEach(item => allPathNodeNames.add(item)); + continue; + }else if(deadPathNodeNames.includes(child)) + { + continue; + } + while(currentPathNodeNames.length - currentPathLen > 0) + { + currentPathNodeNames.pop(); + } + len += this.getAllPathNodesRecursive(child, endNodeName, allPathNodeNames, + currentPathNodeNames, deadPathNodeNames, depth + 1); + + } + while(currentPathNodeNames.length - currentPathLen > 0) + { + currentPathNodeNames.pop(); + } + + if (len == 0) { + deadPathNodeNames.push(startNodeName); + } + } + + if (direction_switcher && len == 0) { + deadPathNodeNames = []; + currentPathNodeNames = []; + len = this.getAllPathNodesRecursive(endNodeName, startNodeName, allPathNodeNames, + currentPathNodeNames, deadPathNodeNames, depth + 1); + } + + return len; + } + + getAllPathNodeNames(startNodeName, endNodeName) + { + var allPathNodeNames = new Set([startNodeName]); + this.getAllPathNodesRecursive(startNodeName, endNodeName, allPathNodeNames); + return allPathNodeNames; + + } + + // + build(document, origin) { const createGroup = (name) => { const element = document.createElementNS('http://www.w3.org/2000/svg', 'g'); diff --git a/onnx_modifier/static/view-sidebar.js b/onnx_modifier/static/view-sidebar.js index 3575546..f2b5bdc 100644 --- a/onnx_modifier/static/view-sidebar.js +++ b/onnx_modifier/static/view-sidebar.js @@ -224,6 +224,8 @@ sidebar.NodeSidebar = class { this._addHeader('Model Input Output editing helper'); + this._addButton('Duplicate Node'); + this.add_span(); this._addButton('Add Output'); this.add_span(); this._addButton('Add Input'); @@ -328,6 +330,12 @@ sidebar.NodeSidebar = class { this._host._view.modifier.addModelOutput(this._modelNodeName); }); } + if (title === 'Duplicate Node') { + buttonElement.addEventListener('click', () => { + var time_now = Date.parse(new Date())/1000; + this._host._view.modifier.duplicateNode(this._modelNodeName, time_now); + }); + } if (title === 'Add Input') { buttonElement.addEventListener('click', () => { // show dialog @@ -965,6 +973,61 @@ sidebar.ArgumentView = class { return this._renameAuxelements; } + // just move numpy dataloader commands to a single funtion + add_np_dataloader(inputInitializerVal, inputInitializerType) + { + const editInitializerNumpyVal = this._host.document.createElement('div'); + editInitializerNumpyVal.className = 'sidebar-view-item-value-line-border'; + editInitializerNumpyVal.innerHTML = 'Or import from a *.npy file:'; + this._element.appendChild(editInitializerNumpyVal); + + const openFileButton_ = this._host.document.createElement('button'); + openFileButton_.setAttribute("display", "none"); + openFileButton_.innerHTML = "Open *.npy" + const openFileDialog_ = this._host.document.createElement('input'); + openFileDialog_.setAttribute("type", "file"); + + openFileButton_.addEventListener('click', () => { + openFileDialog_.value = ''; + openFileDialog_.click(); + }); + var orig_arg_name = this._host._view.modifier.getOriginalName(this._param_type, this._modelNodeName, this._param_index, this._arg_index); + openFileDialog_.addEventListener('change', (e) => { + if (e.target && e.target.files && e.target.files.length > 0) { + var reader = new FileReader(); + var context = this; + reader.onload = function() { + var npLoader = new npyjs.Npyjs(); + npLoader.load(reader.result, (out) => { + // `array` is a one-dimensional array of the raw data + // `shape` is a one-dimensional array that holds a numpy-style shape. + // console.log( + // `You loaded an array with ${out.shape} \nelements: ${out.data}.` + // ); + var fmt_tensor = npLoader.format_np(out.data, out.shape); + var dataType = context._argument.type?context._argument.type._dataType:`${out.dtype}[${out.shape.toString()}]` + context._host._view.modifier.changeInitializer(context._modelNodeName, context._parameterName, context._param_type, context._param_index, + context._arg_index, dataType, fmt_tensor); + // [type, value] + + var initializerEditInfo = context._host._view.modifier.initializerEditInfo.get(orig_arg_name) + if (initializerEditInfo) { + // [type, value] + inputInitializerVal.innerHTML = initializerEditInfo[1]; + if(inputInitializerType) { + inputInitializerType.innerHTML = initializerEditInfo[0]; + } + inputInitializerVal.setAttribute("tab-size", '10px'); + } + + }); + }; + reader.readAsArrayBuffer(e.target.files[0]); + } + }); + this._element.appendChild(openFileButton_); + } + toggle() { if (this._expander) { if (this._expander.innerText == '+') { @@ -1022,7 +1085,7 @@ sidebar.ArgumentView = class { this._element.appendChild(location); } - if (initializer) { + if (initializer && !this._argument.is_custom_added) { const editInitializerVal = this._host.document.createElement('div'); editInitializerVal.className = 'sidebar-view-item-value-line-border'; editInitializerVal.innerHTML = 'This is an initializer, you can input a new value for it here:'; @@ -1046,46 +1109,7 @@ sidebar.ArgumentView = class { }); this._element.appendChild(inputInitializerVal); - const editInitializerNumpyVal = this._host.document.createElement('div'); - editInitializerNumpyVal.className = 'sidebar-view-item-value-line-border'; - editInitializerNumpyVal.innerHTML = 'Or import from a *.npy file:'; - this._element.appendChild(editInitializerNumpyVal); - - const openFileButton_ = this._host.document.createElement('button'); - openFileButton_.setAttribute("display", "none"); - openFileButton_.innerHTML = "Open *.npy" - const openFileDialog_ = this._host.document.createElement('input'); - openFileDialog_.setAttribute("type", "file"); - - openFileButton_.addEventListener('click', () => { - openFileDialog_.value = ''; - openFileDialog_.click(); - }); - - openFileDialog_.addEventListener('change', (e) => { - if (e.target && e.target.files && e.target.files.length > 0) { - var reader = new FileReader(); - var context = this; - reader.onload = function() { - var npLoader = new npyjs.Npyjs(); - npLoader.load(reader.result, (out) => { - // `array` is a one-dimensional array of the raw data - // `shape` is a one-dimensional array that holds a numpy-style shape. - // console.log( - // `You loaded an array with ${out.shape} \nelements: ${out.data}.` - // ); - var fmt_tensor = npLoader.format_np(out.data, out.shape); - context._host._view.modifier.changeInitializer(context._modelNodeName, context._parameterName, context._param_type, context._param_index, - context._arg_index, context._argument.type._dataType, fmt_tensor); - // [type, value] - inputInitializerVal.innerHTML = context._host._view.modifier.initializerEditInfo.get(orig_arg_name)[1]; - inputInitializerVal.setAttribute("tab-size", '10px'); - }); - }; - reader.readAsArrayBuffer(e.target.files[0]); - } - }); - this._element.appendChild(openFileButton_); + this.add_np_dataloader(inputInitializerVal, orig_arg_name) // this._element.appendChild(openFileDialog_); } @@ -1142,6 +1166,8 @@ sidebar.ArgumentView = class { this._modelNodeName, this._parameterName, this._param_type, this._param_index, this._arg_index, init_type, init_val); }); this._element.appendChild(inputInitializerType); + + this.add_np_dataloader(inputInitializerVal, inputInitializerType); // <====== input type ====== } diff --git a/onnx_modifier/static/view.js b/onnx_modifier/static/view.js index 8b38c08..3e400a2 100644 --- a/onnx_modifier/static/view.js +++ b/onnx_modifier/static/view.js @@ -26,6 +26,11 @@ view.View = class { direction: 'vertical', mousewheel: 'scroll' }; + + this.isAltKeyDown = false; + this.selectedFisrtNodeName = null; + this.selectedSecondNodeName = null; + this.lastScrollLeft = 0; this.lastScrollTop = 0; this._zoom = 1; @@ -59,6 +64,42 @@ view.View = class { container.addEventListener('scroll', (e) => this._scrollHandler(e)); container.addEventListener('wheel', (e) => this._wheelHandler(e), { passive: false }); container.addEventListener('mousedown', (e) => this._mouseDownHandler(e)); + + // for select nodes + + container.addEventListener('keydown', (e) => { + + if (e.key === 'Alt' && !this.isAltKeyDown) { + this.isAltKeyDown = true; + } + }); + container.addEventListener('keyup', (e) => { + + if (e.key === 'Alt' && this.isAltKeyDown) { + this.isAltKeyDown = false; + this.selectedFisrtNodeName = null; + this.selectedSecondNodeName = null; + this._graph.removeHighlight(); + } + if (this.isAltKeyDown && e.key.toLowerCase() === 'j') { + if(this.isAltKeyDown && this.selectedFisrtNodeName) { + var time_now = Date.parse(new Date()) / 1000; + var sets = this._graph.getAllPathNodeNames(this.selectedFisrtNodeName, this.selectedSecondNodeName); + for (const name of sets) { + this.modifier.duplicateNode(name, time_now); + } + } + } else if (this.isAltKeyDown && e.key.toLowerCase() === 'l') { + + if(this.isAltKeyDown && this.selectedFisrtNodeName) { + var sets = this._graph.getAllPathNodeNames(this.selectedFisrtNodeName, this.selectedSecondNodeName); + for (const name of sets) { + this._host._view.modifier.deleteSingleNode(name); + } + } + } + }); + switch (this._host.agent) { case 'safari': container.addEventListener('gesturestart', (e) => this._gestureStartHandler(e), false); @@ -74,7 +115,36 @@ view.View = class { }); } + highlight(node, input, modelNodeName) { + if (node) { + if(this.isAltKeyDown) { + if(this.selectedFisrtNodeName) + { + this.selectedSecondNodeName = modelNodeName; + var sets = this._graph.getAllPathNodeNames(this.selectedFisrtNodeName, this.selectedSecondNodeName); + this._graph.setHighlight(sets.add(this.selectedFisrtNodeName)) + } else { + this.selectedFisrtNodeName = modelNodeName; + this.selectedSecondNodeName = modelNodeName; + var sets = new Set([this.selectedFisrtNodeName]); + this._graph.setHighlight(sets); + } + } + + } + + } + + showLoading() { + this._host.document.getElementById('loading').style.display = 'flex'; + } + + hideLoading() { + this._host.document.getElementById('loading').style.display = 'none'; + } + show(page) { + this.hideLoading(); if (!page) { page = (!this._model && !this.activeGraph) ? 'welcome' : 'default'; } @@ -779,7 +849,7 @@ view.View = class { } showNodeProperties(node, input, modelNodeName) { - if (node) { + if (node && !this.isAltKeyDown) { try { // console.log(node) // 注意是onnx.Node, 不是grapher.Node,所以没有update(), 没有element元素 const nodeSidebar = new sidebar.NodeSidebar(this._host, node, modelNodeName); @@ -1077,6 +1147,7 @@ view.Node = class extends grapher.Node { const tooltip = this.context.view.options.names && (node.name || node.location) ? type.name : (node.name || node.location); const title = header.add(null, styles, content, tooltip); title.on('click', () => this.context.view.showNodeProperties(node, null, this.modelNodeName)); + title.on('click', () => this.context.view.highlight(node, null, this.modelNodeName)); if (node.type.nodes && node.type.nodes.length > 0) { const definition = header.add(null, styles, '\u0192', 'Show Function Definition'); definition.on('click', () => this.context.view.pushGraph(node.type)); @@ -1112,6 +1183,7 @@ view.Node = class extends grapher.Node { if (initializers.length > 0 || hiddenInitializers || sortedAttributes.length > 0) { const list = this.list(); list.on('click', () => this.context.view.showNodeProperties(node, null, this.modelNodeName)); + list.on('click', () => this.context.view.highlight(node, null, this.modelNodeName)); for (const initializer of initializers) { const argument = initializer.arguments[0]; const type = argument.type; diff --git a/onnx_modifier/templates/index.html b/onnx_modifier/templates/index.html index 5eca38f..abd36b1 100644 --- a/onnx_modifier/templates/index.html +++ b/onnx_modifier/templates/index.html @@ -124,6 +124,20 @@ left: 2px; top: 135px; } +.load-model-text { + font-size: 15px; + position: absolute; + border: none; + left: 2px; + top: 180px; + font-family:"Oliviar Sans Light"; +} +.op-load-model-dropdown { + font-size: 15px; + position:absolute; + left: 2px; + top: 210px; +} .graph-add-input-dropdown { font-size: 15px; } @@ -391,6 +405,9 @@ +
+
+
@@ -440,6 +457,9 @@ + +