@@ -228,6 +228,7 @@ public void TestDiscriminators()
228228 customDomain . BsonClassMap . RegisterClassMap < BasePerson > ( cm =>
229229 {
230230 cm . AutoMap ( ) ;
231+ cm . SetDiscriminator ( "bp" ) ;
231232 cm . SetIsRootClass ( true ) ;
232233 } ) ;
233234
@@ -247,10 +248,21 @@ public void TestDiscriminators()
247248
248249 var client = CreateClientWithDomain ( customDomain ) ;
249250 var collection = GetTypedCollection < BasePerson > ( client ) ;
251+ var untypedCollection = GetUntypedCollection ( client ) ;
250252
251253 var bp1 = new DerivedPerson1 { Name = "Alice" , Age = 30 , ExtraField1 = "Field1" } ;
252254 var bp2 = new DerivedPerson2 { Name = "Bob" , Age = 40 , ExtraField2 = "Field2" } ;
253- collection . InsertMany ( new BasePerson [ ] { bp1 , bp2 } ) ;
255+ collection . InsertMany ( [ bp1 , bp2 ] ) ;
256+
257+ var retrieved1 = untypedCollection . FindSync ( Builders < BsonDocument > . Filter . Eq ( "_id" , bp1 . Id ) ) . ToList ( ) . Single ( ) ;
258+ var retrieved2 = untypedCollection . FindSync ( Builders < BsonDocument > . Filter . Eq ( "_id" , bp2 . Id ) ) . ToList ( ) . Single ( ) ;
259+
260+ var expectedDiscriminator1 =
261+ $ """ _t" : ["bp", "dp1"]""" ;
262+ var expectedDiscriminator2 =
263+ $ """ _t" : ["bp", "dp2"]""" ;
264+ Assert . Contains ( expectedDiscriminator1 , retrieved1 . ToString ( ) ) ;
265+ Assert . Contains ( expectedDiscriminator2 , retrieved2 . ToString ( ) ) ;
254266
255267 //Aggregate with OfType
256268 var retrievedDerivedPerson1 = collection . Aggregate ( ) . OfType < DerivedPerson1 > ( ) . Single ( ) ;
@@ -317,6 +329,74 @@ void AssertBasePerson(BasePerson expected, BasePerson retrieved)
317329 }
318330 }
319331
332+ [ Fact ]
333+ public void TestDiscriminatorsWithAttributes ( )
334+ {
335+ RequireServer . Check ( ) ;
336+
337+ var customDomain = BsonSerializer . CreateSerializationDomain ( ) ;
338+ customDomain . RegisterSerializer ( new CustomStringSerializer ( ) ) ;
339+
340+ var client = CreateClientWithDomain ( customDomain ) ;
341+ var collection = GetTypedCollection < BasePersonAttribute > ( client ) ;
342+ var untypedCollection = GetUntypedCollection ( client ) ;
343+
344+ var bp1 = new DerivedPersonAttribute1 { Name = "Alice" , Age = 30 , ExtraField1 = "Field1" } ;
345+ var bp2 = new DerivedPersonAttribute2 { Name = "Bob" , Age = 40 , ExtraField2 = "Field2" } ;
346+ collection . InsertMany ( [ bp1 , bp2 ] ) ;
347+
348+ var retrieved1 = untypedCollection . FindSync ( Builders < BsonDocument > . Filter . Eq ( "_id" , bp1 . Id ) ) . ToList ( ) . Single ( ) ;
349+ var retrieved2 = untypedCollection . FindSync ( Builders < BsonDocument > . Filter . Eq ( "_id" , bp2 . Id ) ) . ToList ( ) . Single ( ) ;
350+
351+ var expectedDiscriminator1 =
352+ $ """ _t" : ["bp", "dp1"]""" ;
353+ var expectedDiscriminator2 =
354+ $ """ _t" : ["bp", "dp2"]""" ;
355+ Assert . Contains ( expectedDiscriminator1 , retrieved1 . ToString ( ) ) ;
356+ Assert . Contains ( expectedDiscriminator2 , retrieved2 . ToString ( ) ) ;
357+
358+ //Aggregate with OfType
359+ var retrievedDerivedPerson1 = collection . Aggregate ( ) . OfType < DerivedPersonAttribute1 > ( ) . Single ( ) ;
360+ var retrievedDerivedPerson2 = collection . Aggregate ( ) . OfType < DerivedPersonAttribute2 > ( ) . Single ( ) ;
361+
362+ AssertDerivedPerson1 ( bp1 , retrievedDerivedPerson1 ) ;
363+ AssertDerivedPerson2 ( bp2 , retrievedDerivedPerson2 ) ;
364+
365+ //AppendStage with OfType
366+ retrievedDerivedPerson1 = collection . AsQueryable ( ) . AppendStage ( PipelineStageDefinitionBuilder . OfType < BasePersonAttribute , DerivedPersonAttribute1 > ( ) ) . Single ( ) ;
367+ retrievedDerivedPerson2 = collection . AsQueryable ( ) . AppendStage ( PipelineStageDefinitionBuilder . OfType < BasePersonAttribute , DerivedPersonAttribute2 > ( ) ) . Single ( ) ;
368+
369+ AssertDerivedPerson1 ( bp1 , retrievedDerivedPerson1 ) ;
370+ AssertDerivedPerson2 ( bp2 , retrievedDerivedPerson2 ) ;
371+
372+ //LINQ with OfType
373+ retrievedDerivedPerson1 = collection . AsQueryable ( ) . OfType < DerivedPersonAttribute1 > ( ) . Single ( ) ;
374+ retrievedDerivedPerson2 = collection . AsQueryable ( ) . OfType < DerivedPersonAttribute2 > ( ) . Single ( ) ;
375+
376+ AssertDerivedPerson1 ( bp1 , retrievedDerivedPerson1 ) ;
377+ AssertDerivedPerson2 ( bp2 , retrievedDerivedPerson2 ) ;
378+
379+
380+ void AssertDerivedPerson1 ( DerivedPersonAttribute1 expected , DerivedPersonAttribute1 retrieved )
381+ {
382+ AssertBasePerson ( expected , retrieved ) ;
383+ Assert . Equal ( expected . ExtraField1 , retrieved . ExtraField1 ) ;
384+ }
385+
386+ void AssertDerivedPerson2 ( DerivedPersonAttribute2 expected , DerivedPersonAttribute2 retrieved )
387+ {
388+ AssertBasePerson ( expected , retrieved ) ;
389+ Assert . Equal ( expected . ExtraField2 , retrieved . ExtraField2 ) ;
390+ }
391+
392+ void AssertBasePerson ( BasePersonAttribute expected , BasePersonAttribute retrieved )
393+ {
394+ Assert . Equal ( expected . Id , retrieved . Id ) ;
395+ Assert . Equal ( expected . Name , retrieved . Name ) ;
396+ Assert . Equal ( expected . Age , retrieved . Age ) ;
397+ }
398+ }
399+
320400 private static IMongoCollection < T > GetTypedCollection < T > ( IMongoClient client ) =>
321401 client . GetDatabase ( DriverTestConfiguration . DatabaseNamespace . DatabaseName )
322402 . GetCollection < T > ( DriverTestConfiguration . CollectionNamespace . CollectionName ) ;
@@ -375,6 +455,25 @@ public class DerivedPerson2 : BasePerson
375455 public string ExtraField2 { get ; set ; }
376456 }
377457
458+ [ BsonDiscriminator ( "bp" , RootClass = true ) ]
459+ public class BasePersonAttribute
460+ {
461+ [ BsonId ] public ObjectId Id { get ; set ; } = ObjectId . GenerateNewId ( ) ;
462+ public string Name { get ; set ; }
463+ public int Age { get ; set ; }
464+ }
465+
466+ [ BsonDiscriminator ( "dp1" ) ]
467+ public class DerivedPersonAttribute1 : BasePersonAttribute
468+ {
469+ public string ExtraField1 { get ; set ; }
470+ }
471+
472+ [ BsonDiscriminator ( "dp2" ) ]
473+ public class DerivedPersonAttribute2 : BasePersonAttribute
474+ {
475+ public string ExtraField2 { get ; set ; }
476+ }
378477
379478 // This serializer adds the _appended variable to any serialised string
380479 public class CustomStringSerializer ( string appended = "test" )
0 commit comments