-
Notifications
You must be signed in to change notification settings - Fork 3
Plugin import keras #120
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Plugin import keras #120
Changes from 14 commits
9708c77
5bfd953
01c312a
90a403f
8979f69
da00cfb
c27bcef
6fcd15a
112abbe
2c697a7
2c93550
acfc268
0caeac1
8684e64
1b8143c
c56b62c
0b1921c
a4d8a71
d36c3b0
0f5a34b
5c35223
6c257ce
1cd436a
aba4e1e
57e71dc
aeb519a
6b1a1d5
8a66175
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
* text eol=lf | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,3 +30,9 @@ node_modules | |
tmp/ | ||
blob-local-storage/ | ||
notes/ | ||
|
||
#IntelliJ | ||
.idea | ||
|
||
#vscode | ||
.vscode |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,299 @@ | ||
/*globals define*/ | ||
/*eslint-env node, browser*/ | ||
|
||
/** | ||
* Generated by PluginGenerator 2.20.5 from webgme on Tue Sep 10 2019 15:28:36 GMT-0500 (Central Daylight Time). | ||
umesh-timalsina marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
* A plugin that inherits from the PluginBase. To see source code documentation about available | ||
* properties and methods visit %host%/docs/source/PluginBase.html. | ||
*/ | ||
|
||
define([ | ||
'plugin/PluginConfig', | ||
'text!./metadata.json', | ||
'plugin/PluginBase', | ||
'./utils/JSONModelMaps', | ||
'./utils/json-model-parser', | ||
], function ( | ||
PluginConfig, | ||
pluginMetadata, | ||
PluginBase, | ||
ModelMaps, | ||
JSONLayerParser, | ||
) { | ||
'use strict'; | ||
|
||
pluginMetadata = JSON.parse(pluginMetadata); | ||
|
||
|
||
/** | ||
* Initializes a new instance of ImportKeras. | ||
* @class | ||
* @augments {PluginBase} | ||
* @classdesc This class represents the plugin ImportKeras. | ||
* @constructor | ||
*/ | ||
var ImportKeras = function () { | ||
// Call base class' constructor. | ||
PluginBase.call(this); | ||
this.pluginMetadata = pluginMetadata; | ||
}; | ||
|
||
/** | ||
* Metadata associated with the plugin. Contains id, name, version, description, icon, configStructure etc. | ||
* This is also available at the instance at this.pluginMetadata. | ||
* @type {object} | ||
*/ | ||
ImportKeras.metadata = pluginMetadata; | ||
|
||
// Prototypical inheritance from PluginBase. | ||
ImportKeras.prototype = Object.create(PluginBase.prototype); | ||
ImportKeras.prototype.constructor = ImportKeras; | ||
|
||
/** | ||
* Main function for the plugin to execute. This will perform the execution. | ||
* Notes: | ||
* - Always log with the provided logger.[error,warning,info,debug]. | ||
* - Do NOT put any user interaction logic UI, etc. inside this method. | ||
* - callback always has to be called even if error happened. | ||
* | ||
* @param {function(Error|null, plugin.PluginResult)} callback - the result callback | ||
*/ | ||
ImportKeras.prototype.main = async function(callback){ | ||
let srcJsonHash = this.getCurrentConfig().srcModel; | ||
if (!srcJsonHash) { | ||
callback(new Error('Keras Json Not Provided'), this.result); | ||
return; | ||
} | ||
try { | ||
this.archName = this.getCurrentConfig().archName; | ||
umesh-timalsina marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
let metadata = await this.blobClient.getMetadata(srcJsonHash); | ||
this.archName = this.archName ? this.archName : metadata.name.replace('.json', ''); | ||
let modelJson = await this.blobClient.getObjectAsJSON(srcJsonHash); | ||
this.modelInfo = JSONLayerParser.flatten(modelJson).config; | ||
this.addNewArchitecture(); | ||
this.addLayers(); | ||
await this.addConnections(); | ||
await this.save('Completed Import Model'); | ||
this.result.setSuccess(true); | ||
callback(null, this.result); | ||
} catch (err) { | ||
this.logger.debug(`Something Went Wrong, Error Message: ${err}`); | ||
callback(err, this.result); | ||
} | ||
}; | ||
|
||
ImportKeras.prototype.addNewArchitecture = function () { | ||
// Add Architecture | ||
this.importedArch = this.core.createNode({ | ||
umesh-timalsina marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
parent: this.activeNode, | ||
base: this.META.Architecture | ||
}); | ||
const uniqueName = this.archName; | ||
this.core.setAttribute(this.importedArch, 'name', uniqueName); | ||
|
||
this.logger.debug(`Added ${uniqueName} as a new architecture.`); | ||
}; | ||
|
||
// This should add layers. constraints, initializers, regularizers as well as activations to the layer. | ||
ImportKeras.prototype.addLayers = function () { | ||
let layers = this.modelInfo.layers; | ||
let layerToCreate = null; | ||
this.layerInfo = {}; | ||
layers.forEach((layer) => { | ||
layerToCreate = this._getMetaTypeForClass(layer.class_name); | ||
let layerNode = this.core.createNode({ | ||
parent: this.importedArch, | ||
base: this.META[layerToCreate] | ||
}); | ||
this.logger.debug(`Added ${layerToCreate}\ | ||
to ${this.core.getAttribute(this.importedArch, 'name')}`); | ||
|
||
// Add all attributes, from the model JSON, as well as from the layers schema | ||
this._addLayerAttributes(layerNode, layer); | ||
this._addConfigurableNodes(layerNode, layer); | ||
}); | ||
}; | ||
|
||
// All the attributes, which do not require a separate node to be created | ||
// 1. First find the validAttributeNames for the layer | ||
// 2. If the name is in layer, set it. | ||
// 3. If the name is in layer.config, set it. | ||
// 4. Finally, check the layers schema for remaining attributes. | ||
ImportKeras.prototype._addLayerAttributes = function (layerNode, attrObj) { | ||
let config = attrObj.config; | ||
let validAttributeNamesForThisLayer = this.core.getValidAttributeNames(layerNode); | ||
let configKeys = Object.keys(config); | ||
let remainingKeys = Object.keys(attrObj) | ||
.filter(value => value !== 'config'); | ||
|
||
validAttributeNamesForThisLayer.forEach((attribute) => { | ||
if (remainingKeys.indexOf(this._jsonConfigToNodeAttr(attribute)) > -1) { | ||
this.core.setAttribute(layerNode, attribute, this._toPythonIterable(attrObj[this._jsonConfigToNodeAttr(attribute)])); | ||
this.logger.debug(`Set ${attribute} for ${this.core.getGuid(layerNode)}` + | ||
` to ${this.core.getAttribute(layerNode, attribute)}`); | ||
} else if (configKeys.indexOf(this._jsonConfigToNodeAttr(attribute)) > -1) { | ||
this.core.setAttribute(layerNode, attribute, this._toPythonIterable(config[this._jsonConfigToNodeAttr(attribute)])); | ||
this.logger.debug(`Set ${attribute} for ${this.core.getGuid(layerNode)}` + | ||
` to ${this.core.getAttribute(layerNode, attribute)}`); | ||
|
||
} | ||
}); | ||
let layerName = this.core.getAttribute(layerNode, 'name'); | ||
this.layerInfo[layerName] = layerNode; | ||
|
||
}; | ||
|
||
ImportKeras.prototype._addConfigurableNodes = function (layerNode, layerConfig) { | ||
let allPointerNames = this.core.getValidPointerNames(layerNode); | ||
let config = layerConfig.config; | ||
this.logger.debug(`Layer ${this.core.getAttribute(layerNode, 'name')}` + | ||
` has following configurable attributes ${allPointerNames.join(', ')}`); | ||
allPointerNames.filter(pointer => !!config[pointer]) | ||
.forEach((pointer) => { | ||
if (typeof config[pointer] == 'string') { | ||
this._addStringPropertiesNode(layerNode, config, pointer); | ||
} else { | ||
this._addPluggableNodes(layerNode, config, pointer); | ||
} | ||
}); | ||
}; | ||
|
||
ImportKeras.prototype._addStringPropertiesNode = function(layerNode, config, pointer) { | ||
let configurableNode = this.core.createNode({ | ||
parent: layerNode, | ||
base: this.META[config[pointer]] | ||
}); | ||
// This will set the necessary pointers. | ||
// Of things like activations and so on... | ||
this.core.setPointer(layerNode, pointer, configurableNode); | ||
this.logger.debug(`Added ${this.core.getAttribute(configurableNode, 'name')}` | ||
+ ` as ${pointer} to the layer ` | ||
+ `${this.core.getAttribute(layerNode, 'name')}`); | ||
}; | ||
|
||
ImportKeras.prototype._addPluggableNodes = function (layerNode, config, pointer){ | ||
let pluggableNode = this.core.createNode({ | ||
parent: layerNode, | ||
base: this.META[config[pointer].class_name] | ||
}); | ||
this.logger.debug(`Added ${this.core.getAttribute(pluggableNode, 'name')} as` + | ||
` ${pointer} to the layer ${this.core.getAttribute(layerNode, 'name')}`); | ||
let validArgumentsForThisNode = this.core.getValidAttributeNames(pluggableNode); | ||
let configForAddedNode = config[pointer].config; | ||
if (validArgumentsForThisNode && configForAddedNode) { | ||
validArgumentsForThisNode.forEach((arg) => { | ||
if (configForAddedNode[arg]) | ||
this.core.setAttribute(pluggableNode, arg, | ||
this._toPythonIterable(configForAddedNode[arg])); | ||
}); | ||
} | ||
}; | ||
|
||
// This method is used to convert javascript arrays to a | ||
// tuple/ list(Python) in string Representation. Needed for | ||
// Code generation. | ||
ImportKeras.prototype._toPythonIterable = function (obj) { | ||
if (obj == null) { | ||
return 'None'; | ||
} | ||
if (obj instanceof Array) { | ||
return '[' + obj.map((val) => { | ||
return this._toPythonIterable(val); | ||
}).join(', ') + ']'; | ||
} else { | ||
return obj; | ||
} | ||
}; | ||
|
||
|
||
// This method is used to convert various classes from the | ||
// keras JSON to deepforge meta Nodes | ||
ImportKeras.prototype._getMetaTypeForClass = function (kerasClass) { | ||
let classMap = ModelMaps.CLASS_MAP; | ||
if (Object.keys(classMap).indexOf(kerasClass) > -1) { | ||
return classMap[kerasClass]; | ||
} else { | ||
return kerasClass; | ||
} | ||
}; | ||
|
||
// Change the model converts some JSON | ||
// layer attributes names (from keras) to the correct | ||
// attribute name for deepforge-keras nodes | ||
ImportKeras.prototype._jsonConfigToNodeAttr = function (orgName) { | ||
let argMap = ModelMaps.ARGUMENTS_MAP; | ||
if (Object.keys(argMap).indexOf(orgName) > -1) { | ||
return argMap[orgName]; | ||
} else { | ||
return orgName; | ||
} | ||
}; | ||
|
||
/**********************The functions below Add Connections between the Layers**************/ | ||
ImportKeras.prototype.addConnections = async function () { | ||
// this._findNodeByName(); | ||
let layers = this.modelInfo.layers; | ||
let layerInputConnections = {}; | ||
let layerOutputConnections = {}; | ||
let connections = null; | ||
layers.forEach((layer) => { | ||
layerInputConnections[layer.name] = []; | ||
layerOutputConnections[layer.name] = []; | ||
}); | ||
|
||
layers.forEach((layer) => { | ||
if (layer.inbound_nodes.length > 0) { | ||
connections = layer.inbound_nodes; | ||
connections.forEach((connection) => { | ||
|
||
if (this._layerNameExists(connection)) { | ||
layerInputConnections[layer.name].push(connection); | ||
layerOutputConnections[connection].push(layer.name); | ||
} | ||
}); | ||
|
||
} | ||
}); | ||
|
||
await this._updateConnections(layerInputConnections); | ||
}; | ||
|
||
|
||
ImportKeras.prototype._layerNameExists = function (layerName) { | ||
let allLayerNames = this.modelInfo.layers.map((layer) => { | ||
return layer.name; | ||
}); | ||
|
||
return allLayerNames.indexOf(layerName) > -1; | ||
}; | ||
|
||
ImportKeras.prototype._updateConnections = function (inputs) { | ||
let allLayerNames = Object.keys(inputs); | ||
return Promise.all(allLayerNames.map((layerName) => { | ||
let dstLayer = this.layerInfo[layerName]; | ||
let srcs = inputs[layerName]; | ||
return Promise.all(srcs.map((src, index) => { | ||
return this._connectLayers(this.layerInfo[src], dstLayer, index); | ||
})); | ||
})); | ||
}; | ||
|
||
ImportKeras.prototype._connectLayers = async function (srcLayer, dstLayer, index) { | ||
|
||
let srcPort = await this.core.loadMembers(srcLayer, 'outputs'); | ||
let dstPort = await this.core.loadMembers(dstLayer, 'inputs'); | ||
if (dstPort && srcPort) { | ||
this.core.addMember(dstPort[0], 'source', srcPort[0]); | ||
this.core.setMemberRegistry(dstPort[0], | ||
'source', | ||
this.core.getPath(srcPort[0]), | ||
'position', {x: 100, y: 100}); | ||
this.core.setMemberAttribute(dstPort[0], 'source', | ||
this.core.getPath(srcPort[0]), | ||
'index', index); | ||
this.logger.debug(`Connected ${this.core.getAttribute(srcLayer, 'name')} ` + | ||
`with ${this.core.getAttribute(dstLayer, 'name')} as input ${index}`); | ||
} | ||
}; | ||
|
||
return ImportKeras; | ||
}); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
{ | ||
"id": "ImportKeras", | ||
"name": "ImportKeras", | ||
"version": "0.1.0", | ||
"description": "", | ||
"icon": { | ||
"class": "glyphicon glyphicon-download-alt", | ||
"src": "" | ||
}, | ||
"disableServerSideExecution": false, | ||
"disableBrowserSideExecution": false, | ||
"dependencies": [], | ||
"writeAccessRequired": false, | ||
"configStructure": [{ | ||
"name": "srcModel", | ||
"displayName": "Keras Model JSON File", | ||
"description": "The Keras model JSON to import.", | ||
"valueType": "asset", | ||
"readOnly": false | ||
}, | ||
{ | ||
"name": "archName", | ||
"displayName": "Model Name", | ||
"description": "Name of the imported model", | ||
"valueType": "string", | ||
"readOnly": false | ||
} | ||
] | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
/*globals define */ | ||
define([], function() { | ||
const ModelMaps = {}; | ||
|
||
ModelMaps.CLASS_MAP = { | ||
InputLayer: 'Input' | ||
}; | ||
|
||
ModelMaps.ARGUMENTS_MAP = { | ||
batch_shape: 'batch_input_shape' | ||
}; | ||
|
||
ModelMaps.ModelTypes = { | ||
sequential : 'Sequential', | ||
functional : 'Model' | ||
}; | ||
|
||
ModelMaps.AbstractLayerTypeMapping = { | ||
Activation: 'activation', | ||
ActivityRegularization: 'activity_regularizer' | ||
}; | ||
|
||
|
||
return ModelMaps; | ||
|
||
}); |
Uh oh!
There was an error while loading. Please reload this page.