diff --git a/scala-interpreter/src/main/scala/org/apache/toree/kernel/interpreter/scala/ScalaDisplayers.scala b/scala-interpreter/src/main/scala/org/apache/toree/kernel/interpreter/scala/ScalaDisplayers.scala index 08a8f07b7..b144e2ccd 100644 --- a/scala-interpreter/src/main/scala/org/apache/toree/kernel/interpreter/scala/ScalaDisplayers.scala +++ b/scala-interpreter/src/main/scala/org/apache/toree/kernel/interpreter/scala/ScalaDisplayers.scala @@ -23,8 +23,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Try import org.apache.spark.SparkContext -import org.apache.spark.sql.Row -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{DataFrame, Row, SparkSession} import jupyter.Displayer import jupyter.Displayers import jupyter.MIMETypes @@ -64,6 +63,13 @@ object ScalaDisplayers { } }) + Displayers.register(classOf[DataFrame], new Displayer[DataFrame] { + override def display(df: DataFrame): util.Map[String, String] = toJava { + val (text, html) = displayRows(df.head(20), Some(df.schema.fieldNames)) + Map(MIMEType.PlainText -> text, MIMEType.TextHtml -> html) + } + }) + Displayers.register(classOf[Array[Row]], new Displayer[Array[Row]] { override def display(arr: Array[Row]): util.Map[String, String] = toJava { val (text, html) = displayRows(arr)