diff --git a/kernel/src/main/scala/org/apache/toree/magic/builtin/AddJar.scala b/kernel/src/main/scala/org/apache/toree/magic/builtin/AddJar.scala index 48a812400..ef5e9276e 100644 --- a/kernel/src/main/scala/org/apache/toree/magic/builtin/AddJar.scala +++ b/kernel/src/main/scala/org/apache/toree/magic/builtin/AddJar.scala @@ -18,17 +18,18 @@ package org.apache.toree.magic.builtin import java.io.{File, PrintStream} -import java.net.URL +import java.net.{URL, URI} import java.nio.file.{Files, Paths} - import org.apache.toree.magic._ import org.apache.toree.magic.builtin.AddJar._ import org.apache.toree.magic.dependencies._ import org.apache.toree.utils.{ArgumentParsingSupport, DownloadSupport, LogLike, FileUtils} import com.typesafe.config.Config +import org.apache.hadoop.fs.Path import org.apache.toree.plugins.annotations.Event object AddJar { + val HADOOP_FS_SCHEMES = Set("hdfs", "s3", "s3n", "file") private var jarDir:Option[String] = None @@ -63,18 +64,18 @@ class AddJar private def printStream = new PrintStream(outputStream) /** - * Retrieves file name from URL. + * Retrieves file name from a URI. * - * @param location The remote location (URL) - * @return The name of the remote URL, or an empty string if one does not exist + * @param location a URI + * @return The file name of the remote URI, or an empty string if one does not exist */ def getFileFromLocation(location: String): String = { - val url = new URL(location) - val file = url.getFile.split("/") - if (file.length > 0) { - file.last + val uri = new URI(location) + val pathParts = uri.getPath.split("/") + if (pathParts.nonEmpty) { + pathParts.last } else { - "" + "" } } @@ -122,10 +123,27 @@ class AddJar // Report beginning of download printStream.println(s"Starting download from $jarRemoteLocation") - downloadFile( - new URL(jarRemoteLocation), - new File(downloadLocation).toURI.toURL - ) + val jar = URI.create(jarRemoteLocation) + if (HADOOP_FS_SCHEMES.contains(jar.getScheme)) { + val conf = kernel.sparkContext.hadoopConfiguration + val jarPath = new Path(jarRemoteLocation) + val fs = jarPath.getFileSystem(conf) + val destPath = if (downloadLocation.startsWith("file:")) { + new Path(downloadLocation) + } else { + new Path("file:" + downloadLocation) + } + + fs.copyToLocalFile( + false /* keep original file */, + jarPath, destPath, + true /* don't create checksum files */) + } else { + downloadFile( + new URL(jarRemoteLocation), + new File(downloadLocation).toURI.toURL + ) + } // Report download finished printStream.println(s"Finished download of $jarName") diff --git a/kernel/src/test/scala/org/apache/toree/magic/builtin/AddJarSpec.scala b/kernel/src/test/scala/org/apache/toree/magic/builtin/AddJarSpec.scala index 1c7b3fc64..8d1f44ba2 100644 --- a/kernel/src/test/scala/org/apache/toree/magic/builtin/AddJarSpec.scala +++ b/kernel/src/test/scala/org/apache/toree/magic/builtin/AddJarSpec.scala @@ -91,7 +91,8 @@ class AddJarSpec extends FunSpec with Matchers with MockitoSugar { url = """http://www.example.com/remotecontent?filepath=/path/to/someJar.jar""" jarName = addJarMagic.getFileFromLocation(url) - assert(jarName == "someJar.jar") + // File names come from the path, not from the query fragment + assert(jarName == "remotecontent") url = """http://www.example.com/""" jarName = addJarMagic.getFileFromLocation(url)