diff --git a/project/Build.scala b/project/Build.scala index d371ec7..f2b953f 100644 --- a/project/Build.scala +++ b/project/Build.scala @@ -41,7 +41,8 @@ object Build extends Build { .settings(libSettings: _*) .settings(libraryDependencies ++= compile( - akkaActor + akkaActor, + scalaReflect ) ++ test( scalatest diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 1e5a245..2c59538 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -7,6 +7,8 @@ object Dependencies { val jsonLenses = "net.virtual-void" %% "json-lenses" % "0.6.1" val scalatest = "org.scalatest" %% "scalatest" % "2.2.6" val base64 = "me.lessis" %% "base64" % "0.2.0" + val scalaReflect = "org.scala-lang" % "scala-reflect" % "2.11.8" + // Only used by the tests val sprayJsonShapeless = "com.github.fommil" %% "spray-json-shapeless" % "1.1.0" diff --git a/stamina-core/src/main/scala/stamina/Persister.scala b/stamina-core/src/main/scala/stamina/Persister.scala index 64fee54..861058b 100644 --- a/stamina-core/src/main/scala/stamina/Persister.scala +++ b/stamina-core/src/main/scala/stamina/Persister.scala @@ -2,13 +2,14 @@ package stamina import scala.reflect._ import scala.util._ +import scala.reflect.runtime.universe.{Try ⇒ _, _} /** * A Persister[T, V] provides a type-safe API for persisting instances of T * at version V and unpersisting persisted instances of T for all versions up * to and including version V. */ -abstract class Persister[T: ClassTag, V <: Version: VersionInfo](val key: String) { +abstract class Persister[T: ClassTag: TypeTag, V <: Version: VersionInfo](val key: String) { lazy val currentVersion = Version.numberFor[V] def persist(t: T): Persisted @@ -18,8 +19,14 @@ abstract class Persister[T: ClassTag, V <: Version: VersionInfo](val key: String def canUnpersist(p: Persisted): Boolean = p.key == key && p.version <= currentVersion private[stamina] def convertToT(any: AnyRef): Option[T] = any match { - case t: T ⇒ Some(t) - case _ ⇒ None + case t: T ⇒ t match { + case tagged: TypeTagged[_] ⇒ + if (typeTag[T].tpe =:= tagged.tag.tpe) Some(t) + else None + case _ ⇒ + Some(t) + } + case _ ⇒ None } private[stamina] def persistAny(any: AnyRef): Persisted = { @@ -37,5 +44,11 @@ abstract class Persister[T: ClassTag, V <: Version: VersionInfo](val key: String } } - private[stamina] val tag = classTag[T] + private[stamina] val tag = + if (typeTag[T].tpe <:< typeOf[TypeTagged[_]]) typeTag[T].tpe + else classTag[T].runtimeClass +} + +object Persister { + implicit def optionTypeTag[E](implicit typeTag: TypeTag[E]) = Some(typeTag) } diff --git a/stamina-core/src/main/scala/stamina/Persisters.scala b/stamina-core/src/main/scala/stamina/Persisters.scala index f724eb9..fbf3abe 100644 --- a/stamina-core/src/main/scala/stamina/Persisters.scala +++ b/stamina-core/src/main/scala/stamina/Persisters.scala @@ -33,7 +33,7 @@ case class Persisters(persisters: List[Persister[_, _]]) { private def requireNoOverlappingTags() = { val overlappingTags = persisters.groupBy(_.tag).filter(_._2.length > 1).mapValues(_.map(_.key)) - val warnings = overlappingTags.map { case (tag, keys) ⇒ s"""Persisters with keys ${keys.mkString("'", "', '", "'")} all persist ${tag.runtimeClass}.""" } + val warnings = overlappingTags.map { case (tag, keys) ⇒ s"""Persisters with keys ${keys.mkString("'", "', '", "'")} all persist ${tag}.""" } require(overlappingTags.isEmpty, s"""Overlapping persisters: ${warnings.mkString(" ")}""") } diff --git a/stamina-core/src/main/scala/stamina/TypeTagged.scala b/stamina-core/src/main/scala/stamina/TypeTagged.scala new file mode 100644 index 0000000..b2cdadb --- /dev/null +++ b/stamina-core/src/main/scala/stamina/TypeTagged.scala @@ -0,0 +1,22 @@ +package stamina + +import scala.reflect.runtime.universe._ + +/** + * This marker interface can be used to solve the problem of nested json formats of the same + * root format. + * By example: + * trait Event[E] { + * } + * + * case class Payload1() + * case class Payload2() + * + * The Persister cannot distinguish Event[Payload1] from Event[Payload2] due to type erasure within + * Akka serialization to AnyRef. Therefore you can mark your Event envelop using a TypeTagged marker + * interface which whould allow stamina to choose the correct persister for the kind of event payload + * which should get serialized. + */ +class TypeTagged[X: TypeTag] extends AnyRef { + @transient val tag = typeTag[X] +} diff --git a/stamina-core/src/test/scala/stamina/PersistersSpec.scala b/stamina-core/src/test/scala/stamina/PersistersSpec.scala index 97f6935..710b4ac 100644 --- a/stamina-core/src/test/scala/stamina/PersistersSpec.scala +++ b/stamina-core/src/test/scala/stamina/PersistersSpec.scala @@ -1,4 +1,11 @@ package stamina +import scala.reflect.runtime.universe._ + +case class PayloadEvent[E: TypeTag](payload: E) extends TypeTagged[PayloadEvent[E]] +case class Payload1(txt: String) +case class Payload2(value: Int) +case object Payload3 +case class Payload4[T](list: List[T]) class PersistersSpec extends StaminaSpec { import TestDomain._ @@ -57,4 +64,34 @@ class PersistersSpec extends StaminaSpec { be thrownBy unpersist(Persisted("item", 1, ByteString("not an item"))) } } + + "Persist overlapping events using the TypeTagged marker interface" should { + val persister1 = persister[PayloadEvent[Payload1]]("payload1") + val persister2 = persister[PayloadEvent[Payload2]]("payload2") + val persister4 = persister[PayloadEvent[Payload4[_]]]("payload4") + + val event1 = PayloadEvent(Payload1("test")) + val event2 = PayloadEvent(Payload2(123)) + val event3 = PayloadEvent(Payload3) + val event4a = PayloadEvent(Payload4(List(1, 2, 3))) + val event4b = PayloadEvent(Payload4(List("a", "b", "c"))) + + val nestedPersisters = Persisters(persister1, persister2, persister4) + import nestedPersisters._ + + "Persist nested events correctly" in { + canPersist(event1) should be(true) + canPersist(event2) should be(true) + canPersist(event3) should be(false) + + // We currently don't support tagged types with abstract parameters: + canPersist(event4a) should be(false) + canPersist(event4b) should be(false) + } + + "correctly implement canUnpersist()" in { + canUnpersist(persister1.persist(event1)) should be(true) + canUnpersist(persister2.persist(event2)) should be(true) + } + } } diff --git a/stamina-core/src/test/scala/stamina/TestOnlyPersister.scala b/stamina-core/src/test/scala/stamina/TestOnlyPersister.scala index bc6c737..3b180c6 100644 --- a/stamina-core/src/test/scala/stamina/TestOnlyPersister.scala +++ b/stamina-core/src/test/scala/stamina/TestOnlyPersister.scala @@ -1,17 +1,18 @@ package stamina -import scala.reflect._ import akka.actor._ import akka.serialization._ +import scala.reflect.ClassTag +import scala.reflect.runtime.universe._ object TestOnlyPersister { private val system = ActorSystem("TestOnlyPersister") private val javaSerializer = new JavaSerializer(system.asInstanceOf[ExtendedActorSystem]) import javaSerializer._ - def persister[T <: AnyRef: ClassTag](key: String): Persister[T, V1] = new JavaPersister[T](key) + def persister[T <: AnyRef: ClassTag: TypeTag](key: String): Persister[T, V1] = new JavaPersister[T](key) - private class JavaPersister[T <: AnyRef: ClassTag](key: String) extends Persister[T, V1](key) { + private class JavaPersister[T <: AnyRef: ClassTag: TypeTag](key: String) extends Persister[T, V1](key) { def persist(t: T): Persisted = Persisted(key, currentVersion, toBinary(t)) def unpersist(p: Persisted): T = { if (canUnpersist(p)) fromBinary(p.bytes.toArray).asInstanceOf[T] diff --git a/stamina-json/src/main/scala/stamina/json/json.scala b/stamina-json/src/main/scala/stamina/json/json.scala index 6a72d44..22e299f 100644 --- a/stamina-json/src/main/scala/stamina/json/json.scala +++ b/stamina-json/src/main/scala/stamina/json/json.scala @@ -2,6 +2,7 @@ package stamina import scala.reflect.ClassTag import spray.json._ +import scala.reflect.runtime.universe._ import migrations._ @@ -45,7 +46,7 @@ package object json { * and unpersist version 1. Use this function to produce the initial persister * for a new domain class/event/entity. */ - def persister[T: RootJsonFormat: ClassTag](key: String): JsonPersister[T, V1] = new V1JsonPersister[T](key) + def persister[T: RootJsonFormat: ClassTag](key: String)(implicit typeTag: TypeTag[T]): JsonPersister[T, V1] = new V1JsonPersister[T](key) /** * Creates a JsonPersister[T, V] where V is a version greater than V1. @@ -53,7 +54,7 @@ package object json { * JsonMigrator[V] to migrate any values older than version V to version V before * unpersisting them. */ - def persister[T: RootJsonFormat: ClassTag, V <: Version: VersionInfo: MigratableVersion](key: String, migrator: JsonMigrator[V]): JsonPersister[T, V] = new VnJsonPersister[T, V](key, migrator) + def persister[T: RootJsonFormat: ClassTag, V <: Version: VersionInfo: MigratableVersion](key: String, migrator: JsonMigrator[V])(implicit typeTag: TypeTag[T]): JsonPersister[T, V] = new VnJsonPersister[T, V](key, migrator) private[json] def toJsonBytes[T](t: T)(implicit writer: RootJsonWriter[T]): ByteString = ByteString(writer.write(t).compactPrint) private[json] def fromJsonBytes[T](bytes: ByteString)(implicit reader: RootJsonReader[T]): T = reader.read(parseJson(bytes)) @@ -64,12 +65,12 @@ package json { /** * Simple abstract marker superclass to unify (and hide) the two internal Persister implementations. */ - sealed abstract class JsonPersister[T: RootJsonFormat: ClassTag, V <: Version: VersionInfo](key: String) extends Persister[T, V](key) { + sealed abstract class JsonPersister[T: RootJsonFormat: ClassTag: TypeTag, V <: Version: VersionInfo](key: String) extends Persister[T, V](key) { private[json] def cannotUnpersist(p: Persisted) = - s"""JsonPersister[${implicitly[ClassTag[T]].runtimeClass.getSimpleName}, V${currentVersion}](key = "${key}") cannot unpersist data with key "${p.key}" and version ${p.version}.""" + s"""JsonPersister[${implicitly[TypeTag[T]].tpe.baseClasses.head.name}, V${currentVersion}](key = "${key}") cannot unpersist data with key "${p.key}" and version ${p.version}.""" } - private[json] class V1JsonPersister[T: RootJsonFormat: ClassTag](key: String) extends JsonPersister[T, V1](key) { + private[json] class V1JsonPersister[T: RootJsonFormat: ClassTag: TypeTag](key: String) extends JsonPersister[T, V1](key) { def persist(t: T): Persisted = Persisted(key, currentVersion, toJsonBytes(t)) def unpersist(p: Persisted): T = { if (canUnpersist(p)) fromJsonBytes[T](p.bytes) @@ -77,7 +78,7 @@ package json { } } - private[json] class VnJsonPersister[T: RootJsonFormat: ClassTag, V <: Version: VersionInfo: MigratableVersion](key: String, migrator: JsonMigrator[V]) extends JsonPersister[T, V](key) { + private[json] class VnJsonPersister[T: RootJsonFormat: ClassTag: TypeTag, V <: Version: VersionInfo: MigratableVersion](key: String, migrator: JsonMigrator[V]) extends JsonPersister[T, V](key) { override def canUnpersist(p: Persisted): Boolean = p.key == key && migrator.canMigrate(p.version) def persist(t: T): Persisted = Persisted(key, currentVersion, toJsonBytes(t)) diff --git a/stamina-json/src/test/scala/stamina/json/OverlappingPersistersSpec.scala b/stamina-json/src/test/scala/stamina/json/OverlappingPersistersSpec.scala index a544757..8f3e8df 100644 --- a/stamina-json/src/test/scala/stamina/json/OverlappingPersistersSpec.scala +++ b/stamina-json/src/test/scala/stamina/json/OverlappingPersistersSpec.scala @@ -3,6 +3,7 @@ package json import spray.json._ import DefaultJsonProtocol._ +import scala.reflect.runtime.universe._ class OverlappingPersisterSpec extends StaminaJsonSpec { import OverlappingPersisterSpecDomain._ @@ -17,28 +18,21 @@ class OverlappingPersisterSpec extends StaminaJsonSpec { /** #43 In the future we might want to support this situation instead of failing at initialization time */ "correctly handle overlapping persisters" in { - val e = intercept[IllegalArgumentException] { - Persisters( - persister[Event[Payload1]]("payload1"), - persister[Event[Payload2]]("payload2") - ) - } - e.getMessage() should be("requirement failed: Overlapping persisters: Persisters with keys 'payload1', 'payload2' all persist class stamina.json.OverlappingPersisterSpecDomain$Event.") - - /** - * When we actually want to support this situation, then this should work: - * - * val event1 = Event(Payload1("abcd")) - * persisters.unpersist(persisters.persist(event1)) should equal(event1) - * val event2 = Event(Payload2(42)) - * persisters.unpersist(persisters.persist(event2)) should equal(event2) - */ + val persisters = Persisters( + persister[Event[Payload1]]("payload1"), + persister[Event[Payload2]]("payload2") + ) + + val event1 = Event(Payload1("abcd")) + persisters.unpersist(persisters.persist(event1)) should equal(event1) + val event2 = Event(Payload2(42)) + persisters.unpersist(persisters.persist(event2)) should equal(event2) } } } object OverlappingPersisterSpecDomain { - case class Event[P](payload: P) + case class Event[P: TypeTag](payload: P) extends TypeTagged[Event[P]] case class Payload1(msg: String) case class Payload2(value: Int) }