Skip to content

Commit

Permalink
Merge pull request #25 from siewers/master
Browse files Browse the repository at this point in the history
Add support for combinatorial member data
  • Loading branch information
AArnott authored Jun 22, 2021
2 parents c1d9b33 + dce27bf commit d12b4ed
Show file tree
Hide file tree
Showing 4 changed files with 373 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
namespace Xunit.Combinatorial.Tests
{
using System;
using System.Collections.Generic;
using System.Reflection;

public class CombinatorialMemberDataAttributeTests
{
public void StubIntMethod(int p1)
{
}

public void StubGuidMethod(Guid p1)
{
}

[Fact]
public void EnumerableOfIntReturnsValues()
{
var attribute = new CombinatorialMemberDataAttribute(nameof(GetValuesAsEnumerableOfInt));
var testMethod = this.GetType().GetMethod(nameof(StubIntMethod));
var parameter = testMethod.GetParameters()[0];
var values = attribute.GetValues(parameter);
Assert.Equal(new object[] { 1, 2, 3, 4 }, values);
}

[Fact]
public void EnumerableOfArrayThrows()
{
var attribute = new CombinatorialMemberDataAttribute(nameof(GetValuesAsEnumerableOfIntArray));
var testMethod = this.GetType().GetMethod(nameof(StubIntMethod));
var parameter = testMethod.GetParameters()[0];

var exception = Assert.Throws<ArgumentException>(() => attribute.GetValues(parameter));
Assert.Equal("Member GetValuesAsEnumerableOfIntArray on Xunit.Combinatorial.Tests.CombinatorialMemberDataAttributeTests returned an IEnumerable<object[]>, which is not supported", exception.Message);
}

[Fact]
public void EnumerableOfGuidReturnsValue()
{
var attribute = new CombinatorialMemberDataAttribute(nameof(GetValuesAsEnumerableOfGuid));
var testMethod = this.GetType().GetMethod(nameof(StubGuidMethod));
var parameter = testMethod.GetParameters()[0];
var values = attribute.GetValues(parameter);

Assert.Contains(values, obj => (Guid)obj != Guid.Empty);
}

[Fact]
public void IncompatibleMemberDataTypeThrows()
{
var attribute = new CombinatorialMemberDataAttribute(nameof(GetValuesAsEnumerableOfGuid));
var testMethod = this.GetType().GetMethod(nameof(StubIntMethod));
var parameter = testMethod.GetParameters()[0];

var exception = Assert.Throws<ArgumentException>(() => attribute.GetValues(parameter));
Assert.Equal("Parameter type System.Int32 is not compatible with returned member type System.Guid", exception.Message);
}

public static IEnumerable<int> GetValuesAsEnumerableOfInt()
{
yield return 1;
yield return 2;
yield return 3;
yield return 4;
}

public static IEnumerable<int[]> GetValuesAsEnumerableOfIntArray()
{
yield return new[] { 1 };
yield return new[] { 2 };
yield return new[] { 3 };
yield return new[] { 4 };
yield return new[] { 5 };
}

public static IEnumerable<Guid> GetValuesAsEnumerableOfGuid()
{
yield return Guid.NewGuid();
yield return Guid.NewGuid();
yield return Guid.NewGuid();
yield return Guid.NewGuid();
yield return Guid.NewGuid();
}
}
}
79 changes: 79 additions & 0 deletions src/Xunit.Combinatorial.Tests/CombinatorialMemberDataSampleUses.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
using System;
using System.Collections.Generic;
using System.Linq;

namespace Xunit.Combinatorial.Tests
{
public class CombinatorialMemberDataSampleUses
{
private static readonly Random Random = new Random();

public static readonly IEnumerable<int> IntFieldValues = Enumerable.Range(0, 5).Select(_ => Random.Next());
public static readonly IEnumerable<Guid> GuidFieldValues = Enumerable.Range(0, 5).Select(_ => Guid.NewGuid());

public static IEnumerable<int> IntPropertyValues => GetIntMethodValues();

public static IEnumerable<Guid> GuidPropertyValues => GetGuidMethodValues();

[Theory, CombinatorialData]
public void CombinatorialMemberDataFromParameterizedMethods(
[CombinatorialMemberData(nameof(GetIntRange), 0, 5)] int p1,
[CombinatorialMemberData(nameof(GetGuidRange), 5)] Guid p2)
{
Assert.True(true);
}

[Theory, CombinatorialData]
public void CombinatorialMemberDataFromProperties(
[CombinatorialMemberData(nameof(GuidPropertyValues))] Guid p1,
[CombinatorialMemberData(nameof(IntPropertyValues))] int p2)
{
Assert.True(true);
}

[Theory, CombinatorialData]
public void CombinatorialMemberDataFromMethods(
[CombinatorialMemberData(nameof(GetGuidMethodValues))] Guid p1,
[CombinatorialMemberData(nameof(GetIntMethodValues))] int p2)
{
Assert.True(true);
}

[Theory, CombinatorialData]
public void CombinatorialMemberDataFromFields(
[CombinatorialMemberData(nameof(GuidFieldValues))] Guid p1,
[CombinatorialMemberData(nameof(IntFieldValues))] int p2)
{
Assert.True(true);
}

public static IEnumerable<int> GetIntMethodValues()
{
for (var i = 0; i < 5; i++)
{
yield return Random.Next();
}
}

public static IEnumerable<Guid> GetGuidMethodValues()
{
for (var i = 0; i < 5; i++)
{
yield return Guid.NewGuid();
}
}

public static IEnumerable<int> GetIntRange(int start, int count)
{
return Enumerable.Range(start, count);
}

public static IEnumerable<Guid> GetGuidRange(int count)
{
for (var i = 0; i < count; i++)
{
yield return Guid.NewGuid();
}
}
}
}
202 changes: 202 additions & 0 deletions src/Xunit.Combinatorial/CombinatorialMemberDataAttribute.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
namespace Xunit
{
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

/// <summary>
/// Specifies which member should provide data for this parameter used for running the test method.
/// </summary>
[AttributeUsage(AttributeTargets.Parameter, AllowMultiple = true)]
public class CombinatorialMemberDataAttribute : Attribute
{
/// <summary>
/// Initializes a new instance of the <see cref="CombinatorialMemberDataAttribute"/> class.
/// </summary>
/// <param name="memberName">The name of the public static member on the test class that will provide the test data</param>
/// <param name="parameters">The parameters for the member (only supported for methods; ignored for everything else)</param>
public CombinatorialMemberDataAttribute(string memberName, params object[] parameters)
{
this.MemberName = memberName ?? throw new ArgumentNullException(nameof(memberName));
this.Parameters = parameters;
}

/// <summary>
/// Gets the member name.
/// </summary>
public string MemberName { get; }

/// <summary>
/// Gets or sets the type to retrieve the member from. If not set, then the property will be
/// retrieved from the unit test class.
/// </summary>
public Type MemberType { get; set; }

/// <summary>
/// Gets or sets the parameters passed to the member. Only supported for static methods.
/// </summary>
public object[] Parameters { get; }

/// <summary>
/// Gets the values that should be passed to this parameter on the test method.
/// </summary>
/// <param name="parameterInfo">The parameter for which the data should be provided</param>
/// <returns>An array of values.</returns>
public object[] GetValues(ParameterInfo parameterInfo)
{
var testMethod = parameterInfo.Member;

var type = this.MemberType ?? testMethod?.DeclaringType;

if (type == null)
{
return new object[0];
}

var accessor = this.GetPropertyAccessor(type, parameterInfo) ?? this.GetMethodAccessor(type, parameterInfo) ?? this.GetFieldAccessor(type, parameterInfo);
if (accessor == null)
{
var parameterText = this.Parameters?.Length > 0 ? $" with parameter types: {string.Join(", ", this.Parameters.Select(p => p?.GetType().FullName ?? "(null)"))}" : string.Empty;
throw new ArgumentException($"Could not find public static member (property, field, or method) named '{this.MemberName}' on {type.FullName}{parameterText}");
}

var obj = (IEnumerable)accessor();
return obj.Cast<object>().ToArray();
}

private Func<object> GetPropertyAccessor(Type type, ParameterInfo parameterInfo)
{
PropertyInfo propInfo = null;
for (var reflectionType = type; reflectionType != null; reflectionType = reflectionType.GetTypeInfo().BaseType)
{
propInfo = reflectionType.GetRuntimeProperty(this.MemberName);
if (propInfo != null)
{
break;
}
}

if (propInfo?.GetMethod == null || !propInfo.GetMethod.IsStatic)
{
return null;
}

this.EnsureValidMemberDataType(propInfo.PropertyType, propInfo.DeclaringType, parameterInfo);

return () => propInfo.GetValue(null, null);
}

private Func<object> GetMethodAccessor(Type type, ParameterInfo parameterInfo)
{
MethodInfo methodInfo = null;
var parameterTypes = this.Parameters == null
? new Type[0]
: this.Parameters.Select(p => p.GetType()).ToArray();
for (var reflectionType = type; reflectionType != null; reflectionType = reflectionType.GetTypeInfo().BaseType)
{
methodInfo = reflectionType.GetRuntimeMethods().FirstOrDefault(m => m.Name == this.MemberName && this.ParameterTypesCompatible(m.GetParameters(), parameterTypes));

if (methodInfo != null)
{
break;
}
}

if (methodInfo == null || !methodInfo.IsStatic)
{
return null;
}

this.EnsureValidMemberDataType(methodInfo.ReturnType, methodInfo.DeclaringType, parameterInfo);

return () => methodInfo.Invoke(null, this.Parameters);
}

private bool ParameterTypesCompatible(ParameterInfo[] parameters, Type[] parameterTypes)
{
if (parameters.Length != parameterTypes.Length)
{
return false;
}

for (var i = 0; i < parameters.Length; i++)
{
if (parameterTypes[i] != null && !parameters[i].ParameterType.GetTypeInfo()
.IsAssignableFrom(parameterTypes[i].GetTypeInfo()))
{
return false;
}
}

return true;
}

private Func<object> GetFieldAccessor(Type type, ParameterInfo parameterInfo)
{
FieldInfo fieldInfo = null;
for (var reflectionType = type; reflectionType != null; reflectionType = reflectionType.GetTypeInfo().BaseType)
{
fieldInfo = reflectionType.GetRuntimeField(this.MemberName);

if (fieldInfo != null)
{
break;
}
}

if (fieldInfo == null || !fieldInfo.IsStatic)
{
return null;
}

this.EnsureValidMemberDataType(fieldInfo.FieldType, fieldInfo.DeclaringType, parameterInfo);

return () => fieldInfo.GetValue(null);
}

private void EnsureValidMemberDataType(Type type, Type declaringType, ParameterInfo parameterType)
{
var enumerableTypeInfo = typeof(IEnumerable).GetTypeInfo();

if (!enumerableTypeInfo.IsAssignableFrom(type.GetTypeInfo()))
{
throw new ArgumentException($"Member {this.MemberName} on {type.FullName} did not return IEnumerable");
}

var enumerableGenericType = this.GetEnumerableType(type);
if (enumerableTypeInfo.IsAssignableFrom(enumerableGenericType))
{
throw new ArgumentException(
$"Member {this.MemberName} on {declaringType.FullName} returned an IEnumerable<object[]>, which is not supported");
}

if (!enumerableGenericType.IsAssignableFrom(parameterType.ParameterType.GetTypeInfo()))
{
throw new ArgumentException(
$"Parameter type {parameterType.ParameterType.FullName} is not compatible with returned member type {enumerableGenericType.FullName}");
}
}

private TypeInfo GetEnumerableType(Type enumerableType)
{
var enumerableGenericTypeDefinition = enumerableType.GetTypeInfo().GetGenericArguments();
if (enumerableGenericTypeDefinition != null)
{
return enumerableGenericTypeDefinition[0].GetTypeInfo();
}

foreach (var implementedInterface in enumerableType.GetTypeInfo().ImplementedInterfaces)
{
var interfaceTypeInfo = implementedInterface.GetTypeInfo();
if (interfaceTypeInfo.IsGenericType && interfaceTypeInfo.GetGenericTypeDefinition() == typeof(IEnumerable<>))
{
return interfaceTypeInfo.GetGenericArguments()[0].GetTypeInfo();
}
}

return null;
}
}
}
6 changes: 6 additions & 0 deletions src/Xunit.Combinatorial/ValuesUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ internal static IEnumerable<object> GetValuesFor(ParameterInfo parameter)
return rangeAttribute.Values;
}

var memberDataValuesAttribute = parameter.GetCustomAttribute<CombinatorialMemberDataAttribute>();
if (memberDataValuesAttribute != null)
{
return memberDataValuesAttribute.GetValues(parameter);
}

return GetValuesFor(parameter.ParameterType);
}

Expand Down

0 comments on commit d12b4ed

Please sign in to comment.