diff --git a/src/MongoDB.EntityFrameworkCore/Query/Visitors/MongoEFToLinqTranslatingExpressionVisitor.cs b/src/MongoDB.EntityFrameworkCore/Query/Visitors/MongoEFToLinqTranslatingExpressionVisitor.cs index 8767bbe1..f168760a 100644 --- a/src/MongoDB.EntityFrameworkCore/Query/Visitors/MongoEFToLinqTranslatingExpressionVisitor.cs +++ b/src/MongoDB.EntityFrameworkCore/Query/Visitors/MongoEFToLinqTranslatingExpressionVisitor.cs @@ -21,6 +21,7 @@ using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Query; using MongoDB.Bson; +using MongoDB.Bson.Serialization; using MongoDB.Bson.Serialization.Serializers; using MongoDB.Driver; using MongoDB.Driver.Linq; @@ -32,9 +33,10 @@ namespace MongoDB.EntityFrameworkCore.Query.Visitors; /// /// Visits the tree resolving any query context parameter bindings and EF references so the query can be used with the MongoDB V3 LINQ provider. /// -internal sealed class MongoEFToLinqTranslatingExpressionVisitor : ExpressionVisitor +internal sealed class MongoEFToLinqTranslatingExpressionVisitor : System.Linq.Expressions.ExpressionVisitor { - private static readonly MethodInfo MqlFieldMethodInfo = typeof(Mql).GetMethod(nameof(Mql.Field), BindingFlags.Public | BindingFlags.Static)!; + private static readonly MethodInfo MqlFieldMethodInfo = + typeof(Mql).GetMethod(nameof(Mql.Field), BindingFlags.Public | BindingFlags.Static)!; private readonly QueryContext _queryContext; private readonly Expression _source; @@ -56,17 +58,17 @@ public MethodCallExpression Translate( { if (efQueryExpression == null) // No LINQ methods, e.g. Direct ToList() against DbSet { - return InjectAsBsonDocumentMethod(_source, BsonDocumentSerializer.Instance); + return ApplyAsSerializer(_source, BsonDocumentSerializer.Instance, typeof(BsonDocument)); } var query = (MethodCallExpression)Visit(efQueryExpression)!; if (resultCardinality == ResultCardinality.Enumerable) { - return InjectAsBsonDocumentMethod(query, BsonDocumentSerializer.Instance); + return ApplyAsSerializer(query, BsonDocumentSerializer.Instance, typeof(BsonDocument)); } - var documentQueryableSource = InjectAsBsonDocumentMethod(query.Arguments[0], BsonDocumentSerializer.Instance); + var documentQueryableSource = ApplyAsSerializer(query.Arguments[0], BsonDocumentSerializer.Instance, typeof(BsonDocument)); return Expression.Call( null, @@ -74,11 +76,12 @@ public MethodCallExpression Translate( documentQueryableSource); } - private static MethodCallExpression InjectAsBsonDocumentMethod( + private static MethodCallExpression ApplyAsSerializer( Expression query, - BsonDocumentSerializer resultSerializer) + IBsonSerializer resultSerializer, + Type resultType) { - var asMethodInfo = AsMethodInfo.MakeGenericMethod(query.Type.GenericTypeArguments[0], typeof(BsonDocument)); + var asMethodInfo = AsMethodInfo.MakeGenericMethod(query.Type.GenericTypeArguments[0], resultType); var serializerExpression = Expression.Constant(resultSerializer, resultSerializer.GetType()); return Expression.Call( @@ -106,6 +109,20 @@ private static MethodCallExpression InjectAsBsonDocumentMethod( break; + // Wrap OfType with As(serializer) to re-attach the custom serializer in LINQ3 + case MethodCallExpression + { + Method.Name: nameof(Queryable.OfType), Method.IsGenericMethod: true, Arguments.Count: 1 + } ofTypeCall + when ofTypeCall.Method.DeclaringType == typeof(Queryable): + var resultType = ofTypeCall.Method.GetGenericArguments()[0]; + var resultEntityType = _queryContext.Context.Model.FindEntityType(resultType) + ?? throw new NotSupportedException($"OfType type '{resultType.ShortDisplayName() + }' does not map to an entity type."); + var resultSerializer = _bsonSerializerFactory.GetEntitySerializer(resultEntityType); + var translatedOfTypeCall = Expression.Call(null, ofTypeCall.Method, Visit(ofTypeCall.Arguments[0])!); + return ApplyAsSerializer(translatedOfTypeCall, resultSerializer, resultType); + // Replace object.Equals(Property(p, "propName"), ConstantExpression) elements generated by EF's Find. case MethodCallExpression {Method.Name: nameof(object.Equals), Object: null, Arguments.Count: 2} methodCallExpression: var left = Visit(RemoveObjectConvert(methodCallExpression.Arguments[0]))!; diff --git a/tests/MongoDB.EntityFrameworkCore.FunctionalTests/Mapping/DiscriminatorTests.cs b/tests/MongoDB.EntityFrameworkCore.FunctionalTests/Mapping/DiscriminatorTests.cs index 461df1fd..2645e7de 100644 --- a/tests/MongoDB.EntityFrameworkCore.FunctionalTests/Mapping/DiscriminatorTests.cs +++ b/tests/MongoDB.EntityFrameworkCore.FunctionalTests/Mapping/DiscriminatorTests.cs @@ -32,7 +32,22 @@ public void Uses_real_property_type_discriminator_property_for_read_and_write() using var db = SingleEntityDbContext.Create(collection, RealPropertyConfiguredModel); var entities = db.Entities.ToList(); - Assert.Single(entities, e => e.EntityType == "Client" && e.GetType() == typeof(Customer)); + Assert.Single(entities, e => e.EntityType == "Client" && e.GetType() == typeof(Customer) && e.Status == Status.Active + && e is Customer + { + Name: "Customer 1" + }); + Assert.Single(entities, e => e.EntityType == "Client" && e.GetType() == typeof(Customer) && e.Status == Status.Active + && e is Customer + { + Name: "Customer 2" + }); + Assert.Single(entities, e => e.EntityType == "Client" && e.GetType() == typeof(Customer) && e.Status == Status.Inactive + && e is Customer + { + Name: "Customer 1" + }); + Assert.Single(entities, e => e.EntityType == "SubClient" && e.GetType() == typeof(SubCustomer)); Assert.Single(entities, e => e.EntityType == "Order" && e.GetType() == typeof(Order)); Assert.Single(entities, e => e.EntityType == "Supplier" && e.GetType() == typeof(Supplier)); @@ -48,7 +63,9 @@ public void Uses_shadow_property_type_discriminator_for_read_and_write() using var db = SingleEntityDbContext.Create(collection, ShadowPropertyConfiguredModel); var entities = db.Entities.ToList(); - Assert.Single(entities, e => e is Customer {Name: "Customer 1"}); + Assert.Single(entities, e => e.Status == Status.Active && e is Customer {Name: "Customer 1"}); + Assert.Single(entities, e => e.Status == Status.Active && e is Customer {Name: "Customer 2"}); + Assert.Single(entities, e => e.Status == Status.Inactive && e is Customer {Name: "Customer 1"}); Assert.Single(entities, e => e is SubCustomer {Name: "SubCustomer 1"}); Assert.Single(entities, e => e is Order {OrderReference: "Order 1"}); Assert.Single(entities, e => e is Supplier {Name: "Supplier 1"}); @@ -225,9 +242,11 @@ public void Returns_correct_entity_with_OfType_query() using var db = SingleEntityDbContext.Create(collection, RealPropertyConfiguredModel); var entities = db.Entities.OfType().ToList(); - Assert.Single(entities, e => e.Name == "Customer 1"); + Assert.Single(entities, e => e is {Name: "Customer 1", Status: Status.Active}); + Assert.Single(entities, e => e is {Name: "Customer 2", Status: Status.Active}); + Assert.Single(entities, e => e is {Name: "Customer 1", Status: Status.Inactive}); Assert.Single(entities, e => e.Name == "SubCustomer 1"); - Assert.Equal(2, entities.Count); + Assert.Equal(4, entities.Count); } [Fact] @@ -238,12 +257,31 @@ public void Returns_correct_entities_with_mixed_query() using var db = SingleEntityDbContext.Create(collection, RealPropertyConfiguredModel); var entities = db.Entities.OfType().Where(e => e is Customer || e.GetType() == typeof(Order)).ToList(); - Assert.Equal(3, entities.Count); - Assert.Single(entities, e => e is Customer {Name: "Customer 1"}); + Assert.Equal(5, entities.Count); + Assert.Single(entities, e => e is Customer {Name: "Customer 1", Status: Status.Active}); + Assert.Single(entities, e => e is Customer {Name: "Customer 1", Status: Status.Inactive}); + Assert.Single(entities, e => e is Customer {Name: "Customer 2", Status: Status.Active}); Assert.Single(entities, e => e is SubCustomer {Name: "SubCustomer 1"}); Assert.Single(entities, e => e is Order {OrderReference: "Order 1"}); } + [Fact] + public void OfType_does_not_break_entity_serializer_association() + { + var collection = database.CreateCollection(); + SetupTestData(SingleEntityDbContext.Create(collection, RealPropertyConfiguredModel)); + + using var db = SingleEntityDbContext.Create(collection, RealPropertyConfiguredModel); + + var allActiveCustomers = db.Entities.OfType().Where(e => e.Status == Status.Active); + Assert.All(allActiveCustomers, f => Assert.Equal(Status.Active, f.Status)); + + var activeCustomer1 = db.Entities.Where(e => e.Status == Status.Active).OfType() + .Single(c => c.Name == "Customer 1"); + Assert.Equal("Customer 1", activeCustomer1.Name); + Assert.Equal(Status.Active, activeCustomer1.Status); + } + [Fact] public void TablePerType_throws_NotSupportedException() { @@ -282,6 +320,7 @@ private static void RealPropertyConfiguredModel(ModelBuilder mb) .HasValue("Order") .HasValue("OrderEx") .HasValue("Contact"); + mb.Entity().Property(e => e.Status).HasConversion(e => e.ToString(), s => Enum.Parse(s)); } private static void ShadowPropertyConfiguredModel(ModelBuilder mb) @@ -294,11 +333,14 @@ private static void ShadowPropertyConfiguredModel(ModelBuilder mb) .HasValue("Order") .HasValue("OrderEx") .HasValue("Contact"); + mb.Entity().Property(e => e.Status).HasConversion(e => e.ToString(), s => Enum.Parse(s)); } private static void SetupTestData(DbContext db) { - db.Add(new Customer {Name = "Customer 1", ShippingAddress = "123 Main St"}); + db.Add(new Customer {Name = "Customer 1", ShippingAddress = "123 Main St", Status = Status.Active}); + db.Add(new Customer {Name = "Customer 1", ShippingAddress = "123 Main St", Status = Status.Inactive}); + db.Add(new Customer {Name = "Customer 2", ShippingAddress = "123 Main St", Status = Status.Active}); db.Add(new Supplier {Name = "Supplier 1", Products = ["Product 1", "Product 2"]}); db.Add(new SubCustomer {Name = "SubCustomer 1", ShippingAddress = "3.5 Inch Dr.", AccountingCode = 123}); db.Add(new Order {OrderReference = "Order 1"}); @@ -309,10 +351,18 @@ private static void SetupTestData(DbContext db) db.Dispose(); } + enum Status + { + Active, + Inactive, + Unused + } + class BaseEntity { public ObjectId _id { get; set; } public string? EntityType { get; set; } + public Status Status { get; set; } = Status.Inactive; } class Customer : BaseEntity