Skip to content

Commit

Permalink
Ensure operations after OfType use EntitySerializer. (#170)
Browse files Browse the repository at this point in the history
  • Loading branch information
damieng authored Dec 12, 2024
1 parent 54d151b commit 311e90e
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Query;
using MongoDB.Bson;
using MongoDB.Bson.Serialization;
using MongoDB.Bson.Serialization.Serializers;
using MongoDB.Driver;
using MongoDB.Driver.Linq;
Expand All @@ -32,9 +33,10 @@ namespace MongoDB.EntityFrameworkCore.Query.Visitors;
/// <summary>
/// Visits the tree resolving any query context parameter bindings and EF references so the query can be used with the MongoDB V3 LINQ provider.
/// </summary>
internal sealed class MongoEFToLinqTranslatingExpressionVisitor : ExpressionVisitor
internal sealed class MongoEFToLinqTranslatingExpressionVisitor : System.Linq.Expressions.ExpressionVisitor
{
private static readonly MethodInfo MqlFieldMethodInfo = typeof(Mql).GetMethod(nameof(Mql.Field), BindingFlags.Public | BindingFlags.Static)!;
private static readonly MethodInfo MqlFieldMethodInfo =
typeof(Mql).GetMethod(nameof(Mql.Field), BindingFlags.Public | BindingFlags.Static)!;

private readonly QueryContext _queryContext;
private readonly Expression _source;
Expand All @@ -56,29 +58,30 @@ public MethodCallExpression Translate(
{
if (efQueryExpression == null) // No LINQ methods, e.g. Direct ToList() against DbSet
{
return InjectAsBsonDocumentMethod(_source, BsonDocumentSerializer.Instance);
return ApplyAsSerializer(_source, BsonDocumentSerializer.Instance, typeof(BsonDocument));
}

var query = (MethodCallExpression)Visit(efQueryExpression)!;

if (resultCardinality == ResultCardinality.Enumerable)
{
return InjectAsBsonDocumentMethod(query, BsonDocumentSerializer.Instance);
return ApplyAsSerializer(query, BsonDocumentSerializer.Instance, typeof(BsonDocument));
}

var documentQueryableSource = InjectAsBsonDocumentMethod(query.Arguments[0], BsonDocumentSerializer.Instance);
var documentQueryableSource = ApplyAsSerializer(query.Arguments[0], BsonDocumentSerializer.Instance, typeof(BsonDocument));

return Expression.Call(
null,
query.Method.GetGenericMethodDefinition().MakeGenericMethod(typeof(BsonDocument)),
documentQueryableSource);
}

private static MethodCallExpression InjectAsBsonDocumentMethod(
private static MethodCallExpression ApplyAsSerializer(
Expression query,
BsonDocumentSerializer resultSerializer)
IBsonSerializer resultSerializer,
Type resultType)
{
var asMethodInfo = AsMethodInfo.MakeGenericMethod(query.Type.GenericTypeArguments[0], typeof(BsonDocument));
var asMethodInfo = AsMethodInfo.MakeGenericMethod(query.Type.GenericTypeArguments[0], resultType);
var serializerExpression = Expression.Constant(resultSerializer, resultSerializer.GetType());

return Expression.Call(
Expand Down Expand Up @@ -106,6 +109,20 @@ private static MethodCallExpression InjectAsBsonDocumentMethod(

break;

// Wrap OfType<T> with As(serializer) to re-attach the custom serializer in LINQ3
case MethodCallExpression
{
Method.Name: nameof(Queryable.OfType), Method.IsGenericMethod: true, Arguments.Count: 1
} ofTypeCall
when ofTypeCall.Method.DeclaringType == typeof(Queryable):
var resultType = ofTypeCall.Method.GetGenericArguments()[0];
var resultEntityType = _queryContext.Context.Model.FindEntityType(resultType)
?? throw new NotSupportedException($"OfType type '{resultType.ShortDisplayName()
}' does not map to an entity type.");
var resultSerializer = _bsonSerializerFactory.GetEntitySerializer(resultEntityType);
var translatedOfTypeCall = Expression.Call(null, ofTypeCall.Method, Visit(ofTypeCall.Arguments[0])!);
return ApplyAsSerializer(translatedOfTypeCall, resultSerializer, resultType);

// Replace object.Equals(Property(p, "propName"), ConstantExpression) elements generated by EF's Find.
case MethodCallExpression {Method.Name: nameof(object.Equals), Object: null, Arguments.Count: 2} methodCallExpression:
var left = Visit(RemoveObjectConvert(methodCallExpression.Arguments[0]))!;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,22 @@ public void Uses_real_property_type_discriminator_property_for_read_and_write()

using var db = SingleEntityDbContext.Create(collection, RealPropertyConfiguredModel);
var entities = db.Entities.ToList();
Assert.Single(entities, e => e.EntityType == "Client" && e.GetType() == typeof(Customer));
Assert.Single(entities, e => e.EntityType == "Client" && e.GetType() == typeof(Customer) && e.Status == Status.Active
&& e is Customer
{
Name: "Customer 1"
});
Assert.Single(entities, e => e.EntityType == "Client" && e.GetType() == typeof(Customer) && e.Status == Status.Active
&& e is Customer
{
Name: "Customer 2"
});
Assert.Single(entities, e => e.EntityType == "Client" && e.GetType() == typeof(Customer) && e.Status == Status.Inactive
&& e is Customer
{
Name: "Customer 1"
});

Assert.Single(entities, e => e.EntityType == "SubClient" && e.GetType() == typeof(SubCustomer));
Assert.Single(entities, e => e.EntityType == "Order" && e.GetType() == typeof(Order));
Assert.Single(entities, e => e.EntityType == "Supplier" && e.GetType() == typeof(Supplier));
Expand All @@ -48,7 +63,9 @@ public void Uses_shadow_property_type_discriminator_for_read_and_write()

using var db = SingleEntityDbContext.Create(collection, ShadowPropertyConfiguredModel);
var entities = db.Entities.ToList();
Assert.Single(entities, e => e is Customer {Name: "Customer 1"});
Assert.Single(entities, e => e.Status == Status.Active && e is Customer {Name: "Customer 1"});
Assert.Single(entities, e => e.Status == Status.Active && e is Customer {Name: "Customer 2"});
Assert.Single(entities, e => e.Status == Status.Inactive && e is Customer {Name: "Customer 1"});
Assert.Single(entities, e => e is SubCustomer {Name: "SubCustomer 1"});
Assert.Single(entities, e => e is Order {OrderReference: "Order 1"});
Assert.Single(entities, e => e is Supplier {Name: "Supplier 1"});
Expand Down Expand Up @@ -225,9 +242,11 @@ public void Returns_correct_entity_with_OfType_query()

using var db = SingleEntityDbContext.Create(collection, RealPropertyConfiguredModel);
var entities = db.Entities.OfType<Customer>().ToList();
Assert.Single(entities, e => e.Name == "Customer 1");
Assert.Single(entities, e => e is {Name: "Customer 1", Status: Status.Active});
Assert.Single(entities, e => e is {Name: "Customer 2", Status: Status.Active});
Assert.Single(entities, e => e is {Name: "Customer 1", Status: Status.Inactive});
Assert.Single(entities, e => e.Name == "SubCustomer 1");
Assert.Equal(2, entities.Count);
Assert.Equal(4, entities.Count);
}

[Fact]
Expand All @@ -238,12 +257,31 @@ public void Returns_correct_entities_with_mixed_query()

using var db = SingleEntityDbContext.Create(collection, RealPropertyConfiguredModel);
var entities = db.Entities.OfType<BaseEntity>().Where(e => e is Customer || e.GetType() == typeof(Order)).ToList();
Assert.Equal(3, entities.Count);
Assert.Single(entities, e => e is Customer {Name: "Customer 1"});
Assert.Equal(5, entities.Count);
Assert.Single(entities, e => e is Customer {Name: "Customer 1", Status: Status.Active});
Assert.Single(entities, e => e is Customer {Name: "Customer 1", Status: Status.Inactive});
Assert.Single(entities, e => e is Customer {Name: "Customer 2", Status: Status.Active});
Assert.Single(entities, e => e is SubCustomer {Name: "SubCustomer 1"});
Assert.Single(entities, e => e is Order {OrderReference: "Order 1"});
}

[Fact]
public void OfType_does_not_break_entity_serializer_association()
{
var collection = database.CreateCollection<BaseEntity>();
SetupTestData(SingleEntityDbContext.Create(collection, RealPropertyConfiguredModel));

using var db = SingleEntityDbContext.Create(collection, RealPropertyConfiguredModel);

var allActiveCustomers = db.Entities.OfType<Customer>().Where(e => e.Status == Status.Active);
Assert.All(allActiveCustomers, f => Assert.Equal(Status.Active, f.Status));

var activeCustomer1 = db.Entities.Where(e => e.Status == Status.Active).OfType<Customer>()
.Single(c => c.Name == "Customer 1");
Assert.Equal("Customer 1", activeCustomer1.Name);
Assert.Equal(Status.Active, activeCustomer1.Status);
}

[Fact]
public void TablePerType_throws_NotSupportedException()
{
Expand Down Expand Up @@ -282,6 +320,7 @@ private static void RealPropertyConfiguredModel(ModelBuilder mb)
.HasValue<Order>("Order")
.HasValue<OrderWithProducts>("OrderEx")
.HasValue<Contact>("Contact");
mb.Entity<BaseEntity>().Property(e => e.Status).HasConversion<string>(e => e.ToString(), s => Enum.Parse<Status>(s));
}

private static void ShadowPropertyConfiguredModel(ModelBuilder mb)
Expand All @@ -294,11 +333,14 @@ private static void ShadowPropertyConfiguredModel(ModelBuilder mb)
.HasValue<Order>("Order")
.HasValue<OrderWithProducts>("OrderEx")
.HasValue<Contact>("Contact");
mb.Entity<BaseEntity>().Property(e => e.Status).HasConversion<string>(e => e.ToString(), s => Enum.Parse<Status>(s));
}

private static void SetupTestData(DbContext db)
{
db.Add(new Customer {Name = "Customer 1", ShippingAddress = "123 Main St"});
db.Add(new Customer {Name = "Customer 1", ShippingAddress = "123 Main St", Status = Status.Active});
db.Add(new Customer {Name = "Customer 1", ShippingAddress = "123 Main St", Status = Status.Inactive});
db.Add(new Customer {Name = "Customer 2", ShippingAddress = "123 Main St", Status = Status.Active});
db.Add(new Supplier {Name = "Supplier 1", Products = ["Product 1", "Product 2"]});
db.Add(new SubCustomer {Name = "SubCustomer 1", ShippingAddress = "3.5 Inch Dr.", AccountingCode = 123});
db.Add(new Order {OrderReference = "Order 1"});
Expand All @@ -309,10 +351,18 @@ private static void SetupTestData(DbContext db)
db.Dispose();
}

enum Status
{
Active,
Inactive,
Unused
}

class BaseEntity
{
public ObjectId _id { get; set; }
public string? EntityType { get; set; }
public Status Status { get; set; } = Status.Inactive;
}

class Customer : BaseEntity
Expand Down

0 comments on commit 311e90e

Please sign in to comment.