diff --git a/src/MongoDB.EntityFrameworkCore/Query/Visitors/MongoEFToLinqTranslatingExpressionVisitor.cs b/src/MongoDB.EntityFrameworkCore/Query/Visitors/MongoEFToLinqTranslatingExpressionVisitor.cs
index 8767bbe1..f168760a 100644
--- a/src/MongoDB.EntityFrameworkCore/Query/Visitors/MongoEFToLinqTranslatingExpressionVisitor.cs
+++ b/src/MongoDB.EntityFrameworkCore/Query/Visitors/MongoEFToLinqTranslatingExpressionVisitor.cs
@@ -21,6 +21,7 @@
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Query;
using MongoDB.Bson;
+using MongoDB.Bson.Serialization;
using MongoDB.Bson.Serialization.Serializers;
using MongoDB.Driver;
using MongoDB.Driver.Linq;
@@ -32,9 +33,10 @@ namespace MongoDB.EntityFrameworkCore.Query.Visitors;
///
/// Visits the tree resolving any query context parameter bindings and EF references so the query can be used with the MongoDB V3 LINQ provider.
///
-internal sealed class MongoEFToLinqTranslatingExpressionVisitor : ExpressionVisitor
+internal sealed class MongoEFToLinqTranslatingExpressionVisitor : System.Linq.Expressions.ExpressionVisitor
{
- private static readonly MethodInfo MqlFieldMethodInfo = typeof(Mql).GetMethod(nameof(Mql.Field), BindingFlags.Public | BindingFlags.Static)!;
+ private static readonly MethodInfo MqlFieldMethodInfo =
+ typeof(Mql).GetMethod(nameof(Mql.Field), BindingFlags.Public | BindingFlags.Static)!;
private readonly QueryContext _queryContext;
private readonly Expression _source;
@@ -56,17 +58,17 @@ public MethodCallExpression Translate(
{
if (efQueryExpression == null) // No LINQ methods, e.g. Direct ToList() against DbSet
{
- return InjectAsBsonDocumentMethod(_source, BsonDocumentSerializer.Instance);
+ return ApplyAsSerializer(_source, BsonDocumentSerializer.Instance, typeof(BsonDocument));
}
var query = (MethodCallExpression)Visit(efQueryExpression)!;
if (resultCardinality == ResultCardinality.Enumerable)
{
- return InjectAsBsonDocumentMethod(query, BsonDocumentSerializer.Instance);
+ return ApplyAsSerializer(query, BsonDocumentSerializer.Instance, typeof(BsonDocument));
}
- var documentQueryableSource = InjectAsBsonDocumentMethod(query.Arguments[0], BsonDocumentSerializer.Instance);
+ var documentQueryableSource = ApplyAsSerializer(query.Arguments[0], BsonDocumentSerializer.Instance, typeof(BsonDocument));
return Expression.Call(
null,
@@ -74,11 +76,12 @@ public MethodCallExpression Translate(
documentQueryableSource);
}
- private static MethodCallExpression InjectAsBsonDocumentMethod(
+ private static MethodCallExpression ApplyAsSerializer(
Expression query,
- BsonDocumentSerializer resultSerializer)
+ IBsonSerializer resultSerializer,
+ Type resultType)
{
- var asMethodInfo = AsMethodInfo.MakeGenericMethod(query.Type.GenericTypeArguments[0], typeof(BsonDocument));
+ var asMethodInfo = AsMethodInfo.MakeGenericMethod(query.Type.GenericTypeArguments[0], resultType);
var serializerExpression = Expression.Constant(resultSerializer, resultSerializer.GetType());
return Expression.Call(
@@ -106,6 +109,20 @@ private static MethodCallExpression InjectAsBsonDocumentMethod(
break;
+ // Wrap OfType with As(serializer) to re-attach the custom serializer in LINQ3
+ case MethodCallExpression
+ {
+ Method.Name: nameof(Queryable.OfType), Method.IsGenericMethod: true, Arguments.Count: 1
+ } ofTypeCall
+ when ofTypeCall.Method.DeclaringType == typeof(Queryable):
+ var resultType = ofTypeCall.Method.GetGenericArguments()[0];
+ var resultEntityType = _queryContext.Context.Model.FindEntityType(resultType)
+ ?? throw new NotSupportedException($"OfType type '{resultType.ShortDisplayName()
+ }' does not map to an entity type.");
+ var resultSerializer = _bsonSerializerFactory.GetEntitySerializer(resultEntityType);
+ var translatedOfTypeCall = Expression.Call(null, ofTypeCall.Method, Visit(ofTypeCall.Arguments[0])!);
+ return ApplyAsSerializer(translatedOfTypeCall, resultSerializer, resultType);
+
// Replace object.Equals(Property(p, "propName"), ConstantExpression) elements generated by EF's Find.
case MethodCallExpression {Method.Name: nameof(object.Equals), Object: null, Arguments.Count: 2} methodCallExpression:
var left = Visit(RemoveObjectConvert(methodCallExpression.Arguments[0]))!;
diff --git a/tests/MongoDB.EntityFrameworkCore.FunctionalTests/Mapping/DiscriminatorTests.cs b/tests/MongoDB.EntityFrameworkCore.FunctionalTests/Mapping/DiscriminatorTests.cs
index 461df1fd..2645e7de 100644
--- a/tests/MongoDB.EntityFrameworkCore.FunctionalTests/Mapping/DiscriminatorTests.cs
+++ b/tests/MongoDB.EntityFrameworkCore.FunctionalTests/Mapping/DiscriminatorTests.cs
@@ -32,7 +32,22 @@ public void Uses_real_property_type_discriminator_property_for_read_and_write()
using var db = SingleEntityDbContext.Create(collection, RealPropertyConfiguredModel);
var entities = db.Entities.ToList();
- Assert.Single(entities, e => e.EntityType == "Client" && e.GetType() == typeof(Customer));
+ Assert.Single(entities, e => e.EntityType == "Client" && e.GetType() == typeof(Customer) && e.Status == Status.Active
+ && e is Customer
+ {
+ Name: "Customer 1"
+ });
+ Assert.Single(entities, e => e.EntityType == "Client" && e.GetType() == typeof(Customer) && e.Status == Status.Active
+ && e is Customer
+ {
+ Name: "Customer 2"
+ });
+ Assert.Single(entities, e => e.EntityType == "Client" && e.GetType() == typeof(Customer) && e.Status == Status.Inactive
+ && e is Customer
+ {
+ Name: "Customer 1"
+ });
+
Assert.Single(entities, e => e.EntityType == "SubClient" && e.GetType() == typeof(SubCustomer));
Assert.Single(entities, e => e.EntityType == "Order" && e.GetType() == typeof(Order));
Assert.Single(entities, e => e.EntityType == "Supplier" && e.GetType() == typeof(Supplier));
@@ -48,7 +63,9 @@ public void Uses_shadow_property_type_discriminator_for_read_and_write()
using var db = SingleEntityDbContext.Create(collection, ShadowPropertyConfiguredModel);
var entities = db.Entities.ToList();
- Assert.Single(entities, e => e is Customer {Name: "Customer 1"});
+ Assert.Single(entities, e => e.Status == Status.Active && e is Customer {Name: "Customer 1"});
+ Assert.Single(entities, e => e.Status == Status.Active && e is Customer {Name: "Customer 2"});
+ Assert.Single(entities, e => e.Status == Status.Inactive && e is Customer {Name: "Customer 1"});
Assert.Single(entities, e => e is SubCustomer {Name: "SubCustomer 1"});
Assert.Single(entities, e => e is Order {OrderReference: "Order 1"});
Assert.Single(entities, e => e is Supplier {Name: "Supplier 1"});
@@ -225,9 +242,11 @@ public void Returns_correct_entity_with_OfType_query()
using var db = SingleEntityDbContext.Create(collection, RealPropertyConfiguredModel);
var entities = db.Entities.OfType().ToList();
- Assert.Single(entities, e => e.Name == "Customer 1");
+ Assert.Single(entities, e => e is {Name: "Customer 1", Status: Status.Active});
+ Assert.Single(entities, e => e is {Name: "Customer 2", Status: Status.Active});
+ Assert.Single(entities, e => e is {Name: "Customer 1", Status: Status.Inactive});
Assert.Single(entities, e => e.Name == "SubCustomer 1");
- Assert.Equal(2, entities.Count);
+ Assert.Equal(4, entities.Count);
}
[Fact]
@@ -238,12 +257,31 @@ public void Returns_correct_entities_with_mixed_query()
using var db = SingleEntityDbContext.Create(collection, RealPropertyConfiguredModel);
var entities = db.Entities.OfType().Where(e => e is Customer || e.GetType() == typeof(Order)).ToList();
- Assert.Equal(3, entities.Count);
- Assert.Single(entities, e => e is Customer {Name: "Customer 1"});
+ Assert.Equal(5, entities.Count);
+ Assert.Single(entities, e => e is Customer {Name: "Customer 1", Status: Status.Active});
+ Assert.Single(entities, e => e is Customer {Name: "Customer 1", Status: Status.Inactive});
+ Assert.Single(entities, e => e is Customer {Name: "Customer 2", Status: Status.Active});
Assert.Single(entities, e => e is SubCustomer {Name: "SubCustomer 1"});
Assert.Single(entities, e => e is Order {OrderReference: "Order 1"});
}
+ [Fact]
+ public void OfType_does_not_break_entity_serializer_association()
+ {
+ var collection = database.CreateCollection();
+ SetupTestData(SingleEntityDbContext.Create(collection, RealPropertyConfiguredModel));
+
+ using var db = SingleEntityDbContext.Create(collection, RealPropertyConfiguredModel);
+
+ var allActiveCustomers = db.Entities.OfType().Where(e => e.Status == Status.Active);
+ Assert.All(allActiveCustomers, f => Assert.Equal(Status.Active, f.Status));
+
+ var activeCustomer1 = db.Entities.Where(e => e.Status == Status.Active).OfType()
+ .Single(c => c.Name == "Customer 1");
+ Assert.Equal("Customer 1", activeCustomer1.Name);
+ Assert.Equal(Status.Active, activeCustomer1.Status);
+ }
+
[Fact]
public void TablePerType_throws_NotSupportedException()
{
@@ -282,6 +320,7 @@ private static void RealPropertyConfiguredModel(ModelBuilder mb)
.HasValue("Order")
.HasValue("OrderEx")
.HasValue("Contact");
+ mb.Entity().Property(e => e.Status).HasConversion(e => e.ToString(), s => Enum.Parse(s));
}
private static void ShadowPropertyConfiguredModel(ModelBuilder mb)
@@ -294,11 +333,14 @@ private static void ShadowPropertyConfiguredModel(ModelBuilder mb)
.HasValue("Order")
.HasValue("OrderEx")
.HasValue("Contact");
+ mb.Entity().Property(e => e.Status).HasConversion(e => e.ToString(), s => Enum.Parse(s));
}
private static void SetupTestData(DbContext db)
{
- db.Add(new Customer {Name = "Customer 1", ShippingAddress = "123 Main St"});
+ db.Add(new Customer {Name = "Customer 1", ShippingAddress = "123 Main St", Status = Status.Active});
+ db.Add(new Customer {Name = "Customer 1", ShippingAddress = "123 Main St", Status = Status.Inactive});
+ db.Add(new Customer {Name = "Customer 2", ShippingAddress = "123 Main St", Status = Status.Active});
db.Add(new Supplier {Name = "Supplier 1", Products = ["Product 1", "Product 2"]});
db.Add(new SubCustomer {Name = "SubCustomer 1", ShippingAddress = "3.5 Inch Dr.", AccountingCode = 123});
db.Add(new Order {OrderReference = "Order 1"});
@@ -309,10 +351,18 @@ private static void SetupTestData(DbContext db)
db.Dispose();
}
+ enum Status
+ {
+ Active,
+ Inactive,
+ Unused
+ }
+
class BaseEntity
{
public ObjectId _id { get; set; }
public string? EntityType { get; set; }
+ public Status Status { get; set; } = Status.Inactive;
}
class Customer : BaseEntity