Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public class ExpressionInfo {
"collection_funcs", "predicate_funcs", "conditional_funcs", "conversion_funcs",
"csv_funcs", "datetime_funcs", "generator_funcs", "hash_funcs", "json_funcs",
"lambda_funcs", "map_funcs", "math_funcs", "misc_funcs", "string_funcs", "struct_funcs",
"window_funcs", "xml_funcs", "table_funcs"));
"window_funcs", "xml_funcs", "table_funcs", "url_funcs"));

private static final Set<String> validSources =
new HashSet<>(Arrays.asList("built-in", "hive", "python_udf", "scala_udf", "java_udf"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,6 @@ object FunctionRegistry {
expressionBuilder("lpad", LPadExpressionBuilder),
expression[StringTrimLeft]("ltrim"),
expression[JsonTuple]("json_tuple"),
expression[ParseUrl]("parse_url"),
expression[StringLocate]("position", true),
expression[FormatString]("printf", true),
expression[RegExpExtract]("regexp_extract"),
Expand Down Expand Up @@ -586,6 +585,11 @@ object FunctionRegistry {
expression[XPathString]("xpath_string"),
expression[RegExpCount]("regexp_count"),

// url functions
expression[UrlEncode]("url_encode"),
expression[UrlDecode]("url_decode"),
expression[ParseUrl]("parse_url"),

// datetime functions
expression[AddMonths]("add_months"),
expression[CurrentDate]("current_date"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@

package org.apache.spark.sql.catalyst.expressions

import java.net.{URI, URISyntaxException}
import java.text.{BreakIterator, DecimalFormat, DecimalFormatSymbols}
import java.util.{Base64 => JBase64}
import java.util.{HashMap, Locale, Map => JMap}
import java.util.regex.Pattern

import scala.collection.mutable.ArrayBuffer

Expand Down Expand Up @@ -1626,181 +1624,6 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression = Litera
copy(str = newFirst, len = newSecond, pad = newThird)
}

object ParseUrl {
private val HOST = UTF8String.fromString("HOST")
private val PATH = UTF8String.fromString("PATH")
private val QUERY = UTF8String.fromString("QUERY")
private val REF = UTF8String.fromString("REF")
private val PROTOCOL = UTF8String.fromString("PROTOCOL")
private val FILE = UTF8String.fromString("FILE")
private val AUTHORITY = UTF8String.fromString("AUTHORITY")
private val USERINFO = UTF8String.fromString("USERINFO")
private val REGEXPREFIX = "(&|^)"
private val REGEXSUBFIX = "=([^&]*)"
}

/**
* Extracts a part from a URL
*/
@ExpressionDescription(
usage = "_FUNC_(url, partToExtract[, key]) - Extracts a part from a URL.",
examples = """
Examples:
> SELECT _FUNC_('http://spark.apache.org/path?query=1', 'HOST');
spark.apache.org
> SELECT _FUNC_('http://spark.apache.org/path?query=1', 'QUERY');
query=1
> SELECT _FUNC_('http://spark.apache.org/path?query=1', 'QUERY', 'query');
1
""",
since = "2.0.0",
group = "string_funcs")
case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.get.ansiEnabled)
extends Expression with ExpectsInputTypes with CodegenFallback {
def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled)

override def nullable: Boolean = true
override def inputTypes: Seq[DataType] = Seq.fill(children.size)(StringType)
override def dataType: DataType = StringType
override def prettyName: String = "parse_url"

// If the url is a constant, cache the URL object so that we don't need to convert url
// from UTF8String to String to URL for every row.
@transient private lazy val cachedUrl = children(0) match {
case Literal(url: UTF8String, _) if url ne null => getUrl(url)
case _ => null
}

// If the key is a constant, cache the Pattern object so that we don't need to convert key
// from UTF8String to String to StringBuilder to String to Pattern for every row.
@transient private lazy val cachedPattern = children(2) match {
case Literal(key: UTF8String, _) if key ne null => getPattern(key)
case _ => null
}

// If the partToExtract is a constant, cache the Extract part function so that we don't need
// to check the partToExtract for every row.
@transient private lazy val cachedExtractPartFunc = children(1) match {
case Literal(part: UTF8String, _) => getExtractPartFunc(part)
case _ => null
}

import ParseUrl._

override def checkInputDataTypes(): TypeCheckResult = {
if (children.size > 3 || children.size < 2) {
TypeCheckResult.TypeCheckFailure(s"$prettyName function requires two or three arguments")
} else {
super[ExpectsInputTypes].checkInputDataTypes()
}
}

private def getPattern(key: UTF8String): Pattern = {
Pattern.compile(REGEXPREFIX + key.toString + REGEXSUBFIX)
}

private def getUrl(url: UTF8String): URI = {
try {
new URI(url.toString)
} catch {
case e: URISyntaxException if failOnError =>
throw QueryExecutionErrors.invalidUrlError(url, e)
case _: URISyntaxException => null
}
}

private def getExtractPartFunc(partToExtract: UTF8String): URI => String = {

// partToExtract match {
// case HOST => _.toURL().getHost
// case PATH => _.toURL().getPath
// case QUERY => _.toURL().getQuery
// case REF => _.toURL().getRef
// case PROTOCOL => _.toURL().getProtocol
// case FILE => _.toURL().getFile
// case AUTHORITY => _.toURL().getAuthority
// case USERINFO => _.toURL().getUserInfo
// case _ => (url: URI) => null
// }

partToExtract match {
case HOST => _.getHost
case PATH => _.getRawPath
case QUERY => _.getRawQuery
case REF => _.getRawFragment
case PROTOCOL => _.getScheme
case FILE =>
(url: URI) =>
if (url.getRawQuery ne null) {
url.getRawPath + "?" + url.getRawQuery
} else {
url.getRawPath
}
case AUTHORITY => _.getRawAuthority
case USERINFO => _.getRawUserInfo
case _ => (url: URI) => null
}
}

private def extractValueFromQuery(query: UTF8String, pattern: Pattern): UTF8String = {
val m = pattern.matcher(query.toString)
if (m.find()) {
UTF8String.fromString(m.group(2))
} else {
null
}
}

private def extractFromUrl(url: URI, partToExtract: UTF8String): UTF8String = {
if (cachedExtractPartFunc ne null) {
UTF8String.fromString(cachedExtractPartFunc.apply(url))
} else {
UTF8String.fromString(getExtractPartFunc(partToExtract).apply(url))
}
}

private def parseUrlWithoutKey(url: UTF8String, partToExtract: UTF8String): UTF8String = {
if (cachedUrl ne null) {
extractFromUrl(cachedUrl, partToExtract)
} else {
val currentUrl = getUrl(url)
if (currentUrl ne null) {
extractFromUrl(currentUrl, partToExtract)
} else {
null
}
}
}

override def eval(input: InternalRow): Any = {
val evaluated = children.map{e => e.eval(input).asInstanceOf[UTF8String]}
if (evaluated.contains(null)) return null
if (evaluated.size == 2) {
parseUrlWithoutKey(evaluated(0), evaluated(1))
} else {
// 3-arg, i.e. QUERY with key
assert(evaluated.size == 3)
if (evaluated(1) != QUERY) {
return null
}

val query = parseUrlWithoutKey(evaluated(0), evaluated(1))
if (query eq null) {
return null
}

if (cachedPattern ne null) {
extractValueFromQuery(query, cachedPattern)
} else {
extractValueFromQuery(query, getPattern(evaluated(2)))
}
}
}

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ParseUrl =
copy(children = newChildren)
}

/**
* Returns the input formatted according do printf-style format strings
*/
Expand Down
Loading