diff --git a/src/main/scala/nak/cluster/GDBSCAN.scala b/src/main/scala/nak/cluster/GDBSCAN.scala
index be602ec..7332163 100644
--- a/src/main/scala/nak/cluster/GDBSCAN.scala
+++ b/src/main/scala/nak/cluster/GDBSCAN.scala
@@ -1,5 +1,7 @@
 package nak.cluster
 
+import scala.language.{ implicitConversions, postfixOps }
+
 import breeze.numerics._
 import breeze.linalg._
 import breeze.util._
@@ -26,13 +28,14 @@ class GDBSCAN[T](
    * @param data - each row is treated as a feature vector
    * @return clusters - a list of clusters with
    */
-  def cluster(data: DenseMatrix[T]): Seq[Cluster[T]] = {
+  def cluster(data: DenseMatrix[T]): Seq[AbstractCluster[T]] = {
     // Visited - using row indices
     val visited = MutableSet[Point[T]]()
     val clustered = MutableSet[Point[T]]()
     // Init points
     val points = for (row <- 0 until data.rows) yield Point(row)(data(row, ::).inner)
+    val noise = Noise[T]()
 
     // Start clustering
     points.collect {
@@ -44,12 +47,13 @@ class GDBSCAN[T](
           expand(point, neighbours, cluster)(points, visited, clustered)
           Some(cluster)
         } else {
+          noise add point
           None // noise
         }
-    }.flatten // remove noise
+    }.flatten :+ noise
   }
 
-  private def expand(point: Point[T], neighbours: Seq[Point[T]], cluster: Cluster[T])(implicit points: Seq[Point[T]], visited: MutableSet[Point[T]], clustered: MutableSet[Point[T]]) {
+  private def expand(point: Point[T], neighbours: Seq[Point[T]], cluster: AbstractCluster[T])(points: Seq[Point[T]], visited: MutableSet[Point[T]], clustered: MutableSet[Point[T]]) {
     cluster add point
     clustered add point
     neighbours.foldLeft(neighbours) {
@@ -80,19 +84,25 @@ object GDBSCAN {
     override def toString() = s"[$row]: $value"
   }
 
-  /** Cluster description */
-  case class Cluster[T](id: Long) {
-    private var _points = ListBuffer[Point[T]]()
-
+  abstract class AbstractCluster[T] {
+    val id: Long
+
+    protected var _points = ListBuffer[Point[T]]()
+
     def add(p: Point[T]) {
       _points += p
     }
-
+
     def points: Seq[Point[T]] = Seq(_points: _*)
+  }
 
+  /** Cluster description */
+  case class Cluster[T](id: Long) extends AbstractCluster[T] {
     override def toString() = s"Cluster [$id]\t:\t${_points.size} points\t${_points mkString "|"}"
   }
 
+  case class Noise[T](id: Long = -1) extends AbstractCluster[T]
+
 }
 
 /**
@@ -122,4 +132,4 @@ object DBSCAN {
   def isCorePoint(minPoints: Double)(point: Point[Double], neighbours: Seq[Point[Double]]): Boolean = {
     neighbours.size >= minPoints
   }
-}
\ No newline at end of file
+}
diff --git a/src/test/scala/nak/cluster/GDBSCANTest.scala b/src/test/scala/nak/cluster/GDBSCANTest.scala
index fa200ef..c7715c5 100644
--- a/src/test/scala/nak/cluster/GDBSCANTest.scala
+++ b/src/test/scala/nak/cluster/GDBSCANTest.scala
@@ -28,9 +28,10 @@ class GDBSCANTest extends FlatSpec with Matchers {
     val cluster = gdbscan cluster input
     val clusterPoints = cluster.map(_.points.map(_.value.toArray))
 
-    cluster.size shouldBe 2
+    cluster.size shouldBe 3
     clusterPoints(0) should contain only (Array(0.9, 1.0), Array(1.0, 1.0), Array(1.0, 1.1))
     clusterPoints(1) should contain only (Array(15.0, 15.0), Array(15.0, 14.1), Array(15.3, 15.0))
+    clusterPoints(2) should contain only (Array(5.0, 5.0))
   }
 
   it should "work with custom predicates" in {
@@ -55,8 +56,9 @@ class GDBSCANTest extends FlatSpec with Matchers {
     val cluster = gdbscan cluster input
     val clusterPoints = cluster.map(_.points.map(_.value.toArray))
 
-    cluster.size shouldBe 2
+    cluster.size shouldBe 3
     clusterPoints(0) should contain only (Array(1.0), Array(3.0))
     clusterPoints(1) should contain only (Array(2.0), Array(4.0))
+    clusterPoints(2) shouldBe empty
   }
-}
\ No newline at end of file
+}
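
Reviewer note: a minimal usage sketch (not part of the patch) showing how the new return type can be consumed. The DBSCAN.getNeighbours(epsilon, distance) and Kmeans.euclideanDistance helpers are assumed from the existing README/test setup and are unchanged here; only the handling of the trailing Noise bucket is new.

import breeze.linalg.DenseMatrix
import nak.cluster._
import nak.cluster.GDBSCAN._

object NoiseHandlingSketch extends App {
  // 7 rows: two dense groups plus one isolated point that should land in Noise
  val input = DenseMatrix(
    (0.9, 1.0), (1.0, 1.0), (1.0, 1.1),
    (15.0, 15.0), (15.0, 14.1), (15.3, 15.0),
    (5.0, 5.0))

  // Constructor arguments follow the existing usage (assumed helpers)
  val gdbscan = new GDBSCAN(
    DBSCAN.getNeighbours(epsilon = 1, distance = Kmeans.euclideanDistance),
    DBSCAN.isCorePoint(minPoints = 2))

  // cluster(...) now returns the proper clusters plus one trailing Noise bucket
  val results = gdbscan cluster input
  val (noise, clusters) = results.partition {
    case _: Noise[_] => true
    case _           => false
  }

  clusters.foreach(c => println(s"cluster ${c.id}: ${c.points.size} points"))
  println(s"noise rows: ${noise.flatMap(_.points).map(_.row).mkString(", ")}")
}

Existing callers that expected Seq[Cluster[T]] will need a similar split (or a filter on Noise), since the noise bucket is now always appended to the result.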