Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions model-scoring-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -589,5 +589,10 @@
<artifactId>maven-artifact</artifactId>
<version>2.2.1</version>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-aws</artifactId>
<version>${tap.hadoop2.version}</version>
</dependency>
</dependencies>
</project>
7 changes: 7 additions & 0 deletions model-scoring-core/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ trustedanalytics.scoring-engine {
archive-mar = ""
}

#-Dtrustedanalytics.aws.fs.s3a.access.key="ACCESS_KEY"
trustedanalytics.aws {
fs.s3a.proxy.host = ""
fs.s3a.proxy.port = ""
fs.s3a.access.key = ""
fs.s3a.secret.key = ""
}
#-Dtrustedanalytics.scoring.port="SOME_PORT"
trustedanalytics.scoring {
identifier = "ia"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,20 @@
package org.trustedanalytics.scoring

import java.io.File
import java.net.URI
import java.util.{ArrayList => JArrayList}

import com.typesafe.config.ConfigFactory
import org.apache.commons.io.FileUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.slf4j.LoggerFactory
import org.trustedanalytics.model.archive.format.ModelArchiveFormat
import org.trustedanalytics.scoring.interfaces.Model

object ScoringEngineHelper {
private val logger = LoggerFactory.getLogger(this.getClass)
val config = ConfigFactory.load(this.getClass.getClassLoader)

/**
*
Expand All @@ -39,13 +45,50 @@ object ScoringEngineHelper {
model.output().deep == revisedModel.output().deep
}

private def getAWSConfig(): Configuration = {
val proxyHost = config.getString("trustedanalytics.aws.fs.s3a.proxy.host")
val proxyPort = config.getString("trustedanalytics.aws.fs.s3a.proxy.port")
val accessKey = config.getString("trustedanalytics.aws.fs.s3a.access.key")
val secretKey = config.getString("trustedanalytics.aws.fs.s3a.secret.key")

val cfg = new Configuration()
if(proxyHost != "" && proxyPort != "") {
cfg.set("fs.s3a.proxy.host",proxyHost)
cfg.set("fs.s3a.proxy.port",proxyPort)
}
if(accessKey != "") cfg.set("fs.s3a.access.key",accessKey) else throw new Exception("Configuration do not have AWS access key")
if(secretKey != "") cfg.set("fs.s3a.secret.key",secretKey) else throw new Exception("Configuration do not have AWS secret key")
cfg
}
def getModel(marFilePath: String): Model = {
if (marFilePath != "") {
logger.info("calling ModelArchiveFormat to get the model")
ModelArchiveFormat.read(new File(marFilePath), this.getClass.getClassLoader, None)
}
else {
null
try {
logger.info("Calling ModelArchiveFormat.read() to load the model stored locally")
return ModelArchiveFormat.read(new File(marFilePath), this.getClass.getClassLoader, None)
} catch {
case e: Exception =>
logger.info("Unale to load model from local filesystem, trying to load the model from HDFS")
var tempMarFile: File = null
tempMarFile = File.createTempFile("model", ".mar")
try
{
val cfg = getAWSConfig()
val hdfsFileSystem = org.apache.hadoop.fs.FileSystem.get(new URI(marFilePath), cfg)
hdfsFileSystem.copyToLocalFile(false, new Path(marFilePath), new Path(tempMarFile.getAbsolutePath))
val hdfsMarFilePath = tempMarFile.getAbsolutePath
sys.addShutdownHook(FileUtils.deleteQuietly(tempMarFile)) // Delete temporary directory on exit
return ModelArchiveFormat.read(new File(hdfsMarFilePath), this.getClass.getClassLoader, None)
} catch {
case e: Exception =>
logger.info("Unale to load model from HDFS...\n"+e.getMessage)
logger.info("\n"+e.getStackTraceString)
return null
} finally {
FileUtils.deleteQuietly(tempMarFile)
}
}
} else {
return null
}
}
}