Skip to content

Commit

Permalink
EF-42: Shadow property support (#168)
Browse files Browse the repository at this point in the history
* Initial shadow property support.

* Use Mql.Field for any EF.Property reference.

* Fallback on blank CallerAttributeName prefix for collection naming in functional tests.
  • Loading branch information
damieng authored Dec 11, 2024
1 parent a5e9164 commit 54d151b
Show file tree
Hide file tree
Showing 10 changed files with 598 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ public override void Validate(IModel model, IDiagnosticsLogger<DbLoggerCategory.
ValidateMaximumOneRowVersionPerEntity(model);
ValidateNoUnsupportedAttributesOrAnnotations(model);
ValidateElementNames(model);
ValidateNoUnsupportedShadowProperties(model);
ValidateNoMutableKeys(model, logger);
ValidatePrimaryKeys(model);
}
Expand Down Expand Up @@ -329,29 +328,6 @@ private static void ValidateEntityElementNames(IEntityType entityType)
}
}

/// <summary>
/// Validate that no entities have shadow unsupported properties.
/// </summary>
/// <param name="model">The <see cref="IModel"/> to validate for whether shadow properties are present.</param>
/// <exception cref="NotSupportedException">Thrown when unsupported shadow properties are found on an entity.</exception>
private static void ValidateNoUnsupportedShadowProperties(IModel model)
{
foreach (var entityType in model.GetEntityTypes())
{
var discriminatorProperty = entityType.FindDiscriminatorProperty();
var shadowProperty = entityType.GetProperties()
.FirstOrDefault(property =>
property != discriminatorProperty
&& property.IsShadowProperty()
&& !property.IsOwnedTypeKey());
if (shadowProperty != null)
{
throw new NotSupportedException(
$"Unsupported shadow property '{shadowProperty.Name}' identified on entity type '{entityType.DisplayName()}'.");
}
}
}

/// <summary>
/// Validates that the only keys that can actually be changed are shadow keys used by owned entities.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public void ProcessPropertyAdded(
IConventionPropertyBuilder propertyBuilder,
IConventionContext<IConventionPropertyBuilder> context)
{
if (propertyBuilder.Metadata.Name == "_id" || propertyBuilder.Metadata.IsShadowProperty()) return;
if (propertyBuilder.Metadata.Name == "_id" || propertyBuilder.Metadata.IsOwnedCollectionShadowKey()) return;

propertyBuilder.HasElementName(propertyBuilder.Metadata.Name.ToCamelCase(CultureInfo.CurrentCulture));
}
Expand All @@ -50,8 +50,6 @@ public void ProcessNavigationAdded(
IConventionNavigationBuilder navigationBuilder,
IConventionContext<IConventionNavigationBuilder> context)
{
if (navigationBuilder.Metadata.IsShadowProperty()) return;

var name = navigationBuilder.Metadata.Name.ToCamelCase(CultureInfo.CurrentCulture);
navigationBuilder.Metadata.TargetEntityType.SetAnnotation(MongoAnnotationNames.ElementName, name);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Query;
using MongoDB.Bson;
using MongoDB.Bson.Serialization.Serializers;
using MongoDB.Driver;
using MongoDB.Driver.Linq;
using MongoDB.EntityFrameworkCore.Extensions;
using MongoDB.EntityFrameworkCore.Serializers;

namespace MongoDB.EntityFrameworkCore.Query.Visitors;

Expand All @@ -30,15 +34,20 @@ namespace MongoDB.EntityFrameworkCore.Query.Visitors;
/// </summary>
internal sealed class MongoEFToLinqTranslatingExpressionVisitor : ExpressionVisitor
{
private static readonly MethodInfo MqlFieldMethodInfo = typeof(Mql).GetMethod(nameof(Mql.Field), BindingFlags.Public | BindingFlags.Static)!;

private readonly QueryContext _queryContext;
private readonly Expression _source;
private readonly BsonSerializerFactory _bsonSerializerFactory;

internal MongoEFToLinqTranslatingExpressionVisitor(
QueryContext queryContext,
Expression source)
Expression source,
BsonSerializerFactory bsonSerializerFactory)
{
_queryContext = queryContext;
_source = source;
_bsonSerializerFactory = bsonSerializerFactory;
}

public MethodCallExpression Translate(
Expand Down Expand Up @@ -100,26 +109,72 @@ private static MethodCallExpression InjectAsBsonDocumentMethod(
// Replace object.Equals(Property(p, "propName"), ConstantExpression) elements generated by EF's Find.
case MethodCallExpression {Method.Name: nameof(object.Equals), Object: null, Arguments.Count: 2} methodCallExpression:
var left = Visit(RemoveObjectConvert(methodCallExpression.Arguments[0]))!;
var leftType = methodCallExpression.Method.GetParameters()[0].ParameterType;
var right = Visit(RemoveObjectConvert(methodCallExpression.Arguments[1]))!;
var rightType = methodCallExpression.Method.GetParameters()[0].ParameterType;
var method = methodCallExpression.Method;

if (left.Type == right.Type)
{
return Expression.Equal(RemoveObjectConvert(left), RemoveObjectConvert(right));
}

return Expression.Call(null, methodCallExpression.Method, ConvertIfRequired(left, leftType),
ConvertIfRequired(right, rightType));
var parameters = method.GetParameters();
left = ConvertIfRequired(left, parameters[0].ParameterType);
right = ConvertIfRequired(right, parameters[1].ParameterType);
return Expression.Call(null, method, left, right);

// Replace EF-generated Property(p, "propName") with p.propName.
// Replace EF-generated Property(p, "propName") with Property(p.propName) or Mql.Field(p, "propName", serializer)
case MethodCallExpression methodCallExpression
when methodCallExpression.Method.IsEFPropertyMethod()
&& methodCallExpression.Arguments[1] is ConstantExpression propertyNameExpression:
var source = Visit(methodCallExpression.Arguments[0])
?? throw new InvalidOperationException("Unsupported source to EF.Property expression.");
var property = source.Type.GetProperties()
.First(prop => prop.Name == propertyNameExpression.GetConstantValue<string>());
var propertyExpression = Expression.Property(source, property);

return methodCallExpression.Method.ReturnType != property.PropertyType
? Expression.Convert(propertyExpression, methodCallExpression.Method.ReturnType)
: propertyExpression;
var propertyName = propertyNameExpression.GetConstantValue<string>();
var entityType = _queryContext.Context.Model.FindEntityType(source.Type);
if (entityType != null)
{
// Try an EF property
var efProperty = entityType.FindProperty(propertyName);
if (efProperty != null)
{
var elementName = efProperty.IsPrimaryKey() && entityType.FindPrimaryKey()?.Properties.Count > 1
? "_id." + efProperty.GetElementName()
: efProperty.GetElementName();
var mqlField = MqlFieldMethodInfo.MakeGenericMethod(source.Type, efProperty.ClrType);
var serializer = BsonSerializerFactory.CreateTypeSerializer(efProperty);
var callExpression = Expression.Call(null, mqlField, source,
Expression.Constant(elementName),
Expression.Constant(serializer));
return ConvertIfRequired(callExpression, methodCallExpression.Method.ReturnType);
}

// Try an EF navigation if no property
var efNavigation = entityType.FindNavigation(propertyName);
if (efNavigation != null)
{
var elementName = efNavigation.TargetEntityType.GetContainingElementName();
var mqlField = MqlFieldMethodInfo.MakeGenericMethod(source.Type, efNavigation.ClrType);
var serializer = _bsonSerializerFactory.GetNavigationSerializer(efNavigation);
var callExpression = Expression.Call(null, mqlField, source,
Expression.Constant(elementName),
Expression.Constant(serializer));
return ConvertIfRequired(callExpression, methodCallExpression.Method.ReturnType);
}
}

// Try CLR property
// This should not really be required but is kept here for backwards compatibility with any edge cases.
var clrProperty = source.Type.GetProperties().FirstOrDefault(p => p.Name == propertyName);
if (clrProperty != null)
{
var propertyExpression = Expression.Property(source, clrProperty);
return ConvertIfRequired(propertyExpression, methodCallExpression.Method.ReturnType);
}

return VisitMethodCall(methodCallExpression);

case MethodCallExpression {Arguments.Count: > 0} methodCallExpression when methodCallExpression.Arguments[0] is EntityQueryRootExpression e:
return base.Visit(expression);

// Unwrap include expressions.
case IncludeExpression includeExpression:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,12 @@ private static QueryingEnumerable<TResult, TResult> TranslateAndExecuteUnshapedQ
var collection = mongoQueryContext.MongoClient.GetCollection<TSource>(queryExpression.CollectionExpression.CollectionName);
var source = collection.AsQueryable().As(serializer);

var queryTranslator = new MongoEFToLinqTranslatingExpressionVisitor(queryContext, source.Expression);
var queryTranslator = new MongoEFToLinqTranslatingExpressionVisitor(queryContext, source.Expression, bsonSerializerFactory);
var translatedQuery = queryTranslator.Visit(queryExpression.CapturedExpression)!;

var executableQuery =
new MongoExecutableQuery(translatedQuery, resultCardinality, (IMongoQueryProvider)source.Provider, collection.CollectionNamespace);
new MongoExecutableQuery(translatedQuery, resultCardinality, (IMongoQueryProvider)source.Provider,
collection.CollectionNamespace);

return new QueryingEnumerable<TResult, TResult>(
mongoQueryContext,
Expand All @@ -160,11 +161,11 @@ private static QueryingEnumerable<BsonDocument, TResult> TranslateAndExecuteQuer
var collection = mongoQueryContext.MongoClient.GetCollection<TSource>(queryExpression.CollectionExpression.CollectionName);
var source = collection.AsQueryable().As((IBsonSerializer<TSource>)bsonSerializerFactory.GetEntitySerializer(entityType));

var queryTranslator = new MongoEFToLinqTranslatingExpressionVisitor(queryContext, source.Expression);
var queryTranslator = new MongoEFToLinqTranslatingExpressionVisitor(queryContext, source.Expression, bsonSerializerFactory);
var translatedQuery = queryTranslator.Translate(queryExpression.CapturedExpression, resultCardinality);

var executableQuery =
new MongoExecutableQuery(translatedQuery, resultCardinality, (IMongoQueryProvider)source.Provider, collection.CollectionNamespace);
var executableQuery = new MongoExecutableQuery(translatedQuery, resultCardinality, (IMongoQueryProvider)source.Provider,
collection.CollectionNamespace);

return new QueryingEnumerable<BsonDocument, TResult>(
mongoQueryContext,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/* Copyright 2023-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Diagnostics;
using MongoDB.Bson;
using MongoDB.Driver;
using MongoDB.EntityFrameworkCore.Metadata.Conventions;

namespace MongoDB.EntityFrameworkCore.FunctionalTests.Mapping;

[XUnitCollection("MappingTests")]
public class NavigationProxyTests(TemporaryDatabaseFixture database)
: IClassFixture<TemporaryDatabaseFixture>
{
[Fact]
public void Navigation_proxy_single_can_use_shadow_property()
{
var originalAuthor = new Author {Name = "Damien"};
var originalPost = new Post {Title = "Navigation proxy", Author = originalAuthor};

{
var db = new BloggingContext(For(database.MongoDatabase).UseLazyLoadingProxies().Options);
db.Authors.Add(originalAuthor);
db.Posts.Add(originalPost);
db.SaveChanges();
}

{
var db = new BloggingContext(For(database.MongoDatabase).UseLazyLoadingProxies().Options);
var foundPost = db.Posts.First(p => p.Id == originalPost.Id);
Assert.Equal(originalAuthor.Name, foundPost.Author.Name);
}
}

[Fact]
public void Navigation_proxy_collection_loads()
{
var blog = new Blog { Name = "By Proxy" };
var originalAuthor1 = new Author {Name = "Damien"};
var originalPost1 = new Post {Title = "Navigation 1", Author = originalAuthor1, Blog = blog};
var originalAuthor2 = new Author {Name = "Henry"};
var originalPost2 = new Post {Title = "Navigation 2", Author = originalAuthor2, Blog = blog};

{
var db = new BloggingContext(For(database.MongoDatabase).UseLazyLoadingProxies().Options);
db.AddRange(originalAuthor1, originalAuthor2, originalPost1, originalPost2, blog);
db.SaveChanges();
}

{
var db = new BloggingContext(For(database.MongoDatabase).UseLazyLoadingProxies().Options);
var foundBlog = db.Blogs.First(b => b.Id == blog.Id);
Assert.Equal(2, foundBlog.Posts.Count);
Assert.Single(foundBlog.Posts, p => p.Author.Name == originalAuthor1.Name);
Assert.Single(foundBlog.Posts, p => p.Author.Name == originalAuthor2.Name);
}
}

public class BloggingContext(DbContextOptions options, Action<ModelBuilder>? mb = null)
: DbContext(options)
{
public DbSet<Author> Authors { get; set; }
public DbSet<Post> Posts { get; set; }
public DbSet<Blog> Blogs { get; set; }

protected override void ConfigureConventions(ModelConfigurationBuilder cb)
{
base.ConfigureConventions(cb);
cb.Conventions.Add(_ => new CamelCaseElementNameConvention());
}

protected override void OnModelCreating(ModelBuilder modelBuilder)
{
base.OnModelCreating(modelBuilder);
mb?.Invoke(modelBuilder);
}
}

public static DbContextOptionsBuilder<BloggingContext> For(IMongoDatabase mongoDatabase) =>
new DbContextOptionsBuilder<BloggingContext>()
.UseMongoDB(mongoDatabase.Client, mongoDatabase.DatabaseNamespace.DatabaseName)
.ConfigureWarnings(x => x.Ignore(CoreEventId.ManyServiceProvidersCreatedWarning));

public class Post
{
public ObjectId Id { get; set; }
public string Title { get; set; }
public virtual Author Author { get; set; }
public virtual Blog Blog { get; set; }
}

public class Author
{
public ObjectId Id { get; set; }
public string Name { get; set; }
}

public class Blog
{
public ObjectId Id { get; set; }
public string Name { get; set; }
public virtual List<Post> Posts { get; set; } = new();
}
}
Loading

0 comments on commit 54d151b

Please sign in to comment.