From 5899c468ec58c1126eae5c953d1bc3a1b66ca7f2 Mon Sep 17 00:00:00 2001 From: Matthias Friedrich <1573457+matzefriedrich@users.noreply.github.com> Date: Fri, 26 Jul 2024 02:48:39 +0200 Subject: [PATCH] Feature/support non interface types (#10) Adds support for non-interface types (both registration and resolving) * Allows reflect.Pointer type for registration and activation * Refactors service_type.go module (adds struct with name and reflected type to replace all usages of refect.Type) * Adds new tests --- CHANGELOG.md | 6 ++ README.md | 2 +- internal/core/function_info.go | 30 ++++++--- .../tests/registry_register_instance_test.go | 63 +++++++++++++++++++ internal/tests/registry_test.go | 10 +-- internal/tests/resolver_options_test.go | 3 +- internal/tests/resolver_test.go | 2 +- pkg/registration/activator.go | 1 + pkg/registration/dependency.go | 3 +- pkg/registration/registry.go | 16 ++--- pkg/registration/registry_accessor.go | 5 +- pkg/registration/service_registration.go | 16 ++--- pkg/registration/service_type.go | 7 --- pkg/resolving/resolver.go | 7 ++- pkg/types/service_type.go | 44 +++++++++++++ pkg/types/types.go | 27 +++++--- 16 files changed, 185 insertions(+), 57 deletions(-) create mode 100644 internal/tests/registry_register_instance_test.go delete mode 100644 pkg/registration/service_type.go create mode 100644 pkg/types/service_type.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 104419f..51845e6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added * Adds the `Activate[T]` method which can resolve an instance from an unregistered activator func +* Allows registration and activation of pointer types + +### Changes + +* Renames the `ServiceType[T]` method to `MakeServiceType[T]`; adds service type representing the reflected type and typename +* Replaces all usages of `reflect.Type` by `ServiceType` in all Parsley interfaces ## v0.5.0 - 2024-07-16 diff --git a/README.md b/README.md index 2875322..510f9e6 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ Though dependency injection is less prevalent in Golang (compared to other langu ## Features - ✔️ Register types via constructor functions -- ✔️ Resolve objects by interface +- ✔️ Resolve objects by type (both interface and pointer type) - ✔️ Constructor injection - ⏳ Injection via field initialization (requires annotation) - ❌ Injection via setter methods diff --git a/internal/core/function_info.go b/internal/core/function_info.go index 3422d74..bb0180c 100644 --- a/internal/core/function_info.go +++ b/internal/core/function_info.go @@ -2,14 +2,16 @@ package core import ( "errors" + "fmt" "github.com/matzefriedrich/parsley/pkg/types" "reflect" + "strings" ) type functionInfo struct { funcType reflect.Type - returnType reflect.Type - parameterTypes []reflect.Type + returnType types.ServiceType + parameterTypes []types.ServiceType } var _ types.FunctionInfo = &functionInfo{} @@ -35,29 +37,39 @@ func (f functionInfo) Name() string { return f.funcType.Name() } -func (f functionInfo) ParameterTypes() []reflect.Type { +func (f functionInfo) ParameterTypes() []types.ServiceType { return f.parameterTypes } -func (f functionInfo) ReturnType() reflect.Type { +func (f functionInfo) ReturnType() types.ServiceType { return f.returnType } -func returnType(funcType reflect.Type) (reflect.Type, error) { +func (f functionInfo) String() string { + parameterTypeNames := make([]string, len(f.parameterTypes)) + for _, t := range f.parameterTypes { + parameterTypeNames = append(parameterTypeNames, t.Name()) + } + funcTypeName := f.returnType.ReflectedType().Elem().String() + return fmt.Sprintf("f(%s) %s", strings.Join(parameterTypeNames, ","), funcTypeName) +} + +func returnType(funcType reflect.Type) (types.ServiceType, error) { numReturnValues := funcType.NumOut() if numReturnValues != 1 { return nil, errors.New("return type has to have exactly one return value") } serviceType := funcType.Out(0) - return serviceType, nil + return types.ServiceTypeFrom(serviceType), nil } -func parameterTypes(funcType reflect.Type) []reflect.Type { - parameters := make([]reflect.Type, 0) +func parameterTypes(funcType reflect.Type) []types.ServiceType { + parameters := make([]types.ServiceType, 0) numParameters := funcType.NumIn() for i := 0; i < numParameters; i++ { parameterType := funcType.In(i) - parameters = append(parameters, parameterType) + serviceType := types.ServiceTypeFrom(parameterType) + parameters = append(parameters, serviceType) } return parameters } diff --git a/internal/tests/registry_register_instance_test.go b/internal/tests/registry_register_instance_test.go new file mode 100644 index 0000000..02ac695 --- /dev/null +++ b/internal/tests/registry_register_instance_test.go @@ -0,0 +1,63 @@ +package tests + +import ( + "context" + "github.com/matzefriedrich/parsley/pkg/registration" + "github.com/matzefriedrich/parsley/pkg/resolving" + "github.com/stretchr/testify/assert" + "testing" +) + +func Test_Registry_RegisterInstance_accepts_pointer(t *testing.T) { + + // Arrange + registry := registration.NewServiceRegistry() + + options := NewOptions("value") + + // Act + err := registration.RegisterInstance(registry, options) + resolver := resolving.NewResolver(registry) + actual, _ := resolving.ResolveRequiredService[*someOptions](resolver, resolving.NewScopedContext(context.Background())) + + // Assert + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal(t, options.value, actual.value) +} + +func Test_Registry_RegisterInstance_resolve_object_with_pointer_dependency(t *testing.T) { + + // Arrange + registry := registration.NewServiceRegistry() + + options := NewOptions("value") + _ = registration.RegisterInstance(registry, options) + _ = registration.RegisterTransient(registry, NewOptionsConsumer) + + resolver := resolving.NewResolver(registry) + + // Act + actual, _ := resolving.ResolveRequiredService[*optionsConsumer](resolver, resolving.NewScopedContext(context.Background())) + + // Assert + assert.NotNil(t, actual) +} + +type someOptions struct { + value string +} + +func NewOptions(value string) *someOptions { + return &someOptions{value} +} + +type optionsConsumer struct { + options *someOptions +} + +func NewOptionsConsumer(options *someOptions) *optionsConsumer { + return &optionsConsumer{ + options: options, + } +} diff --git a/internal/tests/registry_test.go b/internal/tests/registry_test.go index 3388b93..577f9e5 100644 --- a/internal/tests/registry_test.go +++ b/internal/tests/registry_test.go @@ -22,8 +22,8 @@ func Test_ServiceRegistry_register_types_with_different_lifetime_behavior(t *tes _ = registration.RegisterSingleton(sut, NewFoo) _ = registration.RegisterTransient(sut, NewFooConsumer) - fooRegistered := sut.IsRegistered(registration.ServiceType[Foo]()) - fooConsumerRegistered := sut.IsRegistered(registration.ServiceType[FooConsumer]()) + fooRegistered := sut.IsRegistered(types.MakeServiceType[Foo]()) + fooConsumerRegistered := sut.IsRegistered(types.MakeServiceType[FooConsumer]()) // Assert assert.True(t, fooRegistered) @@ -115,8 +115,8 @@ func Test_Registry_RegisterModule_registers_collection_of_services(t *testing.T) _ = sut.RegisterModule(fooModule) - fooRegistered := sut.IsRegistered(registration.ServiceType[Foo]()) - fooConsumerRegistered := sut.IsRegistered(registration.ServiceType[FooConsumer]()) + fooRegistered := sut.IsRegistered(types.MakeServiceType[Foo]()) + fooConsumerRegistered := sut.IsRegistered(types.MakeServiceType[FooConsumer]()) // Assert assert.True(t, fooRegistered) @@ -133,7 +133,7 @@ func Test_Registry_RegisterInstance_registers_singleton_service_from_object(t *t // Act _ = registration.RegisterInstance(sut, instance) - fooRegistered := sut.IsRegistered(registration.ServiceType[Foo]()) + fooRegistered := sut.IsRegistered(types.MakeServiceType[Foo]()) r := resolving.NewResolver(sut) diff --git a/internal/tests/resolver_options_test.go b/internal/tests/resolver_options_test.go index 69ac45a..ecc0f57 100644 --- a/internal/tests/resolver_options_test.go +++ b/internal/tests/resolver_options_test.go @@ -4,6 +4,7 @@ import ( "context" "github.com/matzefriedrich/parsley/pkg/registration" "github.com/matzefriedrich/parsley/pkg/resolving" + "github.com/matzefriedrich/parsley/pkg/types" "github.com/stretchr/testify/assert" "testing" ) @@ -20,7 +21,7 @@ func Test_Resolver_ResolveWithOptions_inject_unregistered_service_instance(t *te r := resolving.NewResolver(sut) parsleyContext := resolving.NewScopedContext(context.Background()) - consumers, _ := r.ResolveWithOptions(parsleyContext, registration.ServiceType[FooConsumer](), resolving.WithInstance[Foo](NewFoo())) + consumers, _ := r.ResolveWithOptions(parsleyContext, types.MakeServiceType[FooConsumer](), resolving.WithInstance[Foo](NewFoo())) assert.Equal(t, 1, len(consumers)) consumer1 := consumers[0] diff --git a/internal/tests/resolver_test.go b/internal/tests/resolver_test.go index f5e5860..147c0a1 100644 --- a/internal/tests/resolver_test.go +++ b/internal/tests/resolver_test.go @@ -21,7 +21,7 @@ func Test_Resolver_Resolve_returns_err_if_circular_dependency_detected(t *testin scope := resolving.NewScopedContext(context.Background()) // Act - _, err := r.Resolve(scope, registration.ServiceType[fooBar]()) + _, err := r.Resolve(scope, types.MakeServiceType[fooBar]()) // Assert assert.ErrorIs(t, err, types.ErrCircularDependencyDetected) diff --git a/pkg/registration/activator.go b/pkg/registration/activator.go index 7da8628..1a675ca 100644 --- a/pkg/registration/activator.go +++ b/pkg/registration/activator.go @@ -14,6 +14,7 @@ func CreateServiceActivatorFrom[T any](instance T) (func() T, error) { switch t.Kind() { case reflect.Func: case reflect.Interface: + case reflect.Pointer: default: return nil, types.NewRegistryError(types.ErrorActivatorFunctionInvalidReturnType) } diff --git a/pkg/registration/dependency.go b/pkg/registration/dependency.go index 1554cb9..ffef1a9 100644 --- a/pkg/registration/dependency.go +++ b/pkg/registration/dependency.go @@ -2,7 +2,6 @@ package registration import ( "github.com/matzefriedrich/parsley/pkg/types" - "reflect" "sync" ) @@ -40,7 +39,7 @@ func (d *dependencyInfo) AddRequiredServiceInfo(child types.DependencyInfo) { d.children = append(d.children, child) } -func (d *dependencyInfo) RequiredServiceTypes() []reflect.Type { +func (d *dependencyInfo) RequiredServiceTypes() []types.ServiceType { return d.registration.RequiredServiceTypes() } diff --git a/pkg/registration/registry.go b/pkg/registration/registry.go index e7c3275..e5a7fcb 100644 --- a/pkg/registration/registry.go +++ b/pkg/registration/registry.go @@ -23,13 +23,13 @@ func RegisterSingleton(registry types.ServiceRegistry, activatorFunc any) error return registry.Register(activatorFunc, types.LifetimeSingleton) } -func (s *serviceRegistry) addOrUpdateServiceRegistrationListFor(serviceType reflect.Type) types.ServiceRegistrationList { - list, exists := s.registrations[serviceType] +func (s *serviceRegistry) addOrUpdateServiceRegistrationListFor(serviceType types.ServiceType) types.ServiceRegistrationList { + list, exists := s.registrations[serviceType.ReflectedType()] if exists { return list } list = NewServiceRegistrationList(s.identifierSource) - s.registrations[serviceType] = list + s.registrations[serviceType.ReflectedType()] = list return list } @@ -60,23 +60,23 @@ func (s *serviceRegistry) RegisterModule(modules ...types.ModuleFunc) error { return nil } -func (s *serviceRegistry) IsRegistered(serviceType reflect.Type) bool { - _, found := s.registrations[serviceType] +func (s *serviceRegistry) IsRegistered(serviceType types.ServiceType) bool { + _, found := s.registrations[serviceType.ReflectedType()] return found } -func (s *serviceRegistry) TryGetServiceRegistrations(serviceType reflect.Type) (types.ServiceRegistrationList, bool) { +func (s *serviceRegistry) TryGetServiceRegistrations(serviceType types.ServiceType) (types.ServiceRegistrationList, bool) { if s.IsRegistered(serviceType) == false { return nil, false } - list, found := s.registrations[serviceType] + list, found := s.registrations[serviceType.ReflectedType()] if found && list.IsEmpty() == false { return list, true } return nil, false } -func (s *serviceRegistry) TryGetSingleServiceRegistration(serviceType reflect.Type) (types.ServiceRegistration, bool) { +func (s *serviceRegistry) TryGetSingleServiceRegistration(serviceType types.ServiceType) (types.ServiceRegistration, bool) { list, found := s.TryGetServiceRegistrations(serviceType) if found && list.IsEmpty() == false { registrations := list.Registrations() diff --git a/pkg/registration/registry_accessor.go b/pkg/registration/registry_accessor.go index 4fb8731..7f775d0 100644 --- a/pkg/registration/registry_accessor.go +++ b/pkg/registration/registry_accessor.go @@ -2,14 +2,13 @@ package registration import ( "github.com/matzefriedrich/parsley/pkg/types" - "reflect" ) type multiRegistryAccessor struct { registries []types.ServiceRegistryAccessor } -func (m *multiRegistryAccessor) TryGetSingleServiceRegistration(serviceType reflect.Type) (types.ServiceRegistration, bool) { +func (m *multiRegistryAccessor) TryGetSingleServiceRegistration(serviceType types.ServiceType) (types.ServiceRegistration, bool) { for _, registry := range m.registries { registration, ok := registry.TryGetSingleServiceRegistration(serviceType) if ok { @@ -19,7 +18,7 @@ func (m *multiRegistryAccessor) TryGetSingleServiceRegistration(serviceType refl return nil, false } -func (m *multiRegistryAccessor) TryGetServiceRegistrations(serviceType reflect.Type) (types.ServiceRegistrationList, bool) { +func (m *multiRegistryAccessor) TryGetServiceRegistrations(serviceType types.ServiceType) (types.ServiceRegistrationList, bool) { for _, registry := range m.registries { registration, ok := registry.TryGetServiceRegistrations(serviceType) if ok { diff --git a/pkg/registration/service_registration.go b/pkg/registration/service_registration.go index cefca2f..4c48bc9 100644 --- a/pkg/registration/service_registration.go +++ b/pkg/registration/service_registration.go @@ -18,11 +18,11 @@ type serviceRegistration struct { } type typeInfo struct { - t reflect.Type + t types.ServiceType name string } -func newTypeInfo(t reflect.Type) typeInfo { +func newTypeInfo(t types.ServiceType) typeInfo { return typeInfo{ t: t, name: t.Name(), @@ -66,15 +66,15 @@ func (s *serviceRegistration) LifetimeScope() types.LifetimeScope { return s.lifetimeScope } -func (s *serviceRegistration) RequiredServiceTypes() []reflect.Type { - requiredTypes := make([]reflect.Type, len(s.parameters)) +func (s *serviceRegistration) RequiredServiceTypes() []types.ServiceType { + requiredTypes := make([]types.ServiceType, len(s.parameters)) for i, p := range s.parameters { requiredTypes[i] = p.t } return requiredTypes } -func (s *serviceRegistration) ServiceType() reflect.Type { +func (s *serviceRegistration) ServiceType() types.ServiceType { return s.serviceType.t } @@ -99,9 +99,11 @@ func CreateServiceRegistration(activatorFunc any, lifetimeScope types.LifetimeSc } serviceType := info.ReturnType() - switch serviceType.Kind() { + switch serviceType.ReflectedType().Kind() { case reflect.Func: return newServiceRegistration(serviceType, lifetimeScope, value), nil + case reflect.Pointer: + fallthrough case reflect.Interface: requiredTypes := info.ParameterTypes() return newServiceRegistration(serviceType, lifetimeScope, value, requiredTypes...), nil @@ -110,7 +112,7 @@ func CreateServiceRegistration(activatorFunc any, lifetimeScope types.LifetimeSc } } -func newServiceRegistration(serviceType reflect.Type, scope types.LifetimeScope, activatorFunc reflect.Value, parameters ...reflect.Type) *serviceRegistration { +func newServiceRegistration(serviceType types.ServiceType, scope types.LifetimeScope, activatorFunc reflect.Value, parameters ...types.ServiceType) *serviceRegistration { parameterTypeInfos := make([]typeInfo, len(parameters)) for i, p := range parameters { parameterTypeInfos[i] = newTypeInfo(p) diff --git a/pkg/registration/service_type.go b/pkg/registration/service_type.go deleted file mode 100644 index da1e629..0000000 --- a/pkg/registration/service_type.go +++ /dev/null @@ -1,7 +0,0 @@ -package registration - -import "reflect" - -func ServiceType[T any]() reflect.Type { - return reflect.TypeOf(new(T)).Elem() -} diff --git a/pkg/resolving/resolver.go b/pkg/resolving/resolver.go index 9365695..12d6471 100644 --- a/pkg/resolving/resolver.go +++ b/pkg/resolving/resolver.go @@ -19,10 +19,11 @@ func ResolveRequiredServices[T any](resolver types.Resolver, ctx context.Context switch t.Kind() { case reflect.Func: case reflect.Interface: + case reflect.Pointer: default: return []T{}, types.NewResolverError(types.ErrorActivatorFunctionInvalidReturnType) } - resolvedInstances, err := resolver.Resolve(ctx, registration.ServiceType[T]()) + resolvedInstances, err := resolver.Resolve(ctx, types.MakeServiceType[T]()) if err != nil { return []T{}, err } @@ -87,11 +88,11 @@ func (r *resolver) createResolverRegistryAccessor(resolverOptions ...types.Resol return r.registry, nil } -func (r *resolver) Resolve(ctx context.Context, serviceType reflect.Type) ([]interface{}, error) { +func (r *resolver) Resolve(ctx context.Context, serviceType types.ServiceType) ([]interface{}, error) { return r.ResolveWithOptions(ctx, serviceType) } -func (r *resolver) ResolveWithOptions(ctx context.Context, serviceType reflect.Type, resolverOptions ...types.ResolverOptionsFunc) ([]interface{}, error) { +func (r *resolver) ResolveWithOptions(ctx context.Context, serviceType types.ServiceType, resolverOptions ...types.ResolverOptionsFunc) ([]interface{}, error) { registry, registryErr := r.createResolverRegistryAccessor(resolverOptions...) if registryErr != nil { diff --git a/pkg/types/service_type.go b/pkg/types/service_type.go new file mode 100644 index 0000000..f7892f3 --- /dev/null +++ b/pkg/types/service_type.go @@ -0,0 +1,44 @@ +package types + +import ( + "reflect" +) + +type serviceType struct { + reflectedType reflect.Type + name string +} + +func (s serviceType) ReflectedType() reflect.Type { + return s.reflectedType +} + +func (s serviceType) Name() string { + return s.name +} + +func MakeServiceType[T any]() ServiceType { + elem := reflect.TypeOf(new(T)).Elem() + return &serviceType{ + reflectedType: elem, + name: elem.String(), + } +} + +func ServiceTypeFrom(t reflect.Type) ServiceType { + name := "" + switch t.Kind() { + case reflect.Ptr: + name = t.Elem().String() + case reflect.Interface: + name = t.Name() + case reflect.Func: + name = t.String() + default: + panic("unsupported type: " + t.String()) + } + return &serviceType{ + reflectedType: t, + name: name, + } +} diff --git a/pkg/types/types.go b/pkg/types/types.go index 25a1cd5..e19950b 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -2,20 +2,27 @@ package types import ( "context" + "fmt" "reflect" ) type FunctionInfo interface { + fmt.Stringer Name() string - ParameterTypes() []reflect.Type - ReturnType() reflect.Type + ParameterTypes() []ServiceType + ReturnType() ServiceType +} + +type ServiceType interface { + Name() string + ReflectedType() reflect.Type } type ServiceRegistry interface { ServiceRegistryAccessor CreateLinkedRegistry() ServiceRegistry CreateScope() ServiceRegistry - IsRegistered(serviceType reflect.Type) bool + IsRegistered(serviceType ServiceType) bool Register(activatorFunc any, scope LifetimeScope) error RegisterModule(modules ...ModuleFunc) error } @@ -23,8 +30,8 @@ type ServiceRegistry interface { type ModuleFunc func(registry ServiceRegistry) error type ServiceRegistryAccessor interface { - TryGetServiceRegistrations(serviceType reflect.Type) (ServiceRegistrationList, bool) - TryGetSingleServiceRegistration(serviceType reflect.Type) (ServiceRegistration, bool) + TryGetServiceRegistrations(serviceType ServiceType) (ServiceRegistrationList, bool) + TryGetSingleServiceRegistration(serviceType ServiceType) (ServiceRegistration, bool) } type ServiceRegistration interface { @@ -32,8 +39,8 @@ type ServiceRegistration interface { InvokeActivator(params ...interface{}) (interface{}, error) IsSame(other ServiceRegistration) bool LifetimeScope() LifetimeScope - RequiredServiceTypes() []reflect.Type - ServiceType() reflect.Type + RequiredServiceTypes() []ServiceType + ServiceType() ServiceType } type ServiceRegistrationList interface { @@ -53,8 +60,8 @@ type RegistrationConfigurationFunc func(r ServiceRegistration) type ResolverOptionsFunc func(registry ServiceRegistry) error type Resolver interface { - Resolve(ctx context.Context, serviceType reflect.Type) ([]interface{}, error) - ResolveWithOptions(ctx context.Context, serviceType reflect.Type, options ...ResolverOptionsFunc) ([]interface{}, error) + Resolve(ctx context.Context, serviceType ServiceType) ([]interface{}, error) + ResolveWithOptions(ctx context.Context, serviceType ServiceType, options ...ResolverOptionsFunc) ([]interface{}, error) } type DependencyInfo interface { @@ -64,7 +71,7 @@ type DependencyInfo interface { HasInstance() bool Instance() interface{} Registration() ServiceRegistration - RequiredServiceTypes() []reflect.Type + RequiredServiceTypes() []ServiceType RequiredServices() ([]interface{}, error) ServiceTypeName() string SetInstance(instance interface{}) error