diff --git a/Dockerfile b/Dockerfile index e94004c..f5181c4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,7 +14,7 @@ RUN go mod download # Copy the go source COPY cmd/main.go cmd/main.go COPY api/ api/ -COPY internal/controller/ internal/controller/ +COPY internal/ internal/ # Copy Makefile and build command requirements COPY config/ config/ diff --git a/api/v1alpha1/modelregistry_annotations_test.go b/api/v1alpha1/modelregistry_annotations_test.go new file mode 100644 index 0000000..aece5c8 --- /dev/null +++ b/api/v1alpha1/modelregistry_annotations_test.go @@ -0,0 +1,64 @@ +package v1alpha1_test + +import ( + "reflect" + "testing" + + "github.com/opendatahub-io/model-registry-operator/api/v1alpha1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestHandleAnnotations(t *testing.T) { + tests := []struct { + name string + annotations map[string]string + wantAnnotations map[string]string + mrRestImage string + wantMrRestImage string + }{ + { + name: "no handler annotations", + annotations: map[string]string{ + "modelregistry.opendatahub.io/other-annotation": "true", + }, + wantAnnotations: map[string]string{ + "modelregistry.opendatahub.io/other-annotation": "true", + }, + mrRestImage: "test", + wantMrRestImage: "test", + }, + { + name: "reset defaults", + annotations: map[string]string{ + "modelregistry.opendatahub.io/reset-spec-defaults": "", + }, + wantAnnotations: map[string]string{}, + mrRestImage: "test", + wantMrRestImage: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mr := &v1alpha1.ModelRegistry{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: tt.annotations, + }, + Spec: v1alpha1.ModelRegistrySpec{ + Rest: v1alpha1.RestSpec{ + Image: tt.mrRestImage, + }, + }, + } + mr.HandleAnnotations() + + if !reflect.DeepEqual(mr.GetAnnotations(), tt.wantAnnotations) { + t.Errorf("HandleAnnotations() = %v, want %v", mr.GetAnnotations(), tt.wantAnnotations) + } + + if mr.Spec.Rest.Image != tt.wantMrRestImage { + t.Errorf("HandleAnnotations() = %v, want %v", mr.Name, tt.wantMrRestImage) + } + }) + } +} diff --git a/api/v1alpha1/modelregistry_webhook.go b/api/v1alpha1/modelregistry_webhook.go index 89cadb3..e9aa61b 100644 --- a/api/v1alpha1/modelregistry_webhook.go +++ b/api/v1alpha1/modelregistry_webhook.go @@ -19,18 +19,19 @@ package v1alpha1 import ( "context" "fmt" + "maps" + "slices" + "strings" + "github.com/opendatahub-io/model-registry-operator/internal/controller/config" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/validation/field" - "maps" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/webhook" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" - "slices" - "strings" ) // log is for logging in this package. diff --git a/api/v1alpha1/modelregistry_webhook_test.go b/api/v1alpha1/modelregistry_webhook_test.go new file mode 100644 index 0000000..e0e7d55 --- /dev/null +++ b/api/v1alpha1/modelregistry_webhook_test.go @@ -0,0 +1,387 @@ +package v1alpha1_test + +import ( + "fmt" + "reflect" + "testing" + + "github.com/opendatahub-io/model-registry-operator/api/v1alpha1" + "github.com/opendatahub-io/model-registry-operator/internal/controller/config" +) + +var ( + certName = "test-cert" + tlsMode = "SIMPLE" + audience = "test-audience" + authProvider = "test-auth-provider" + authLabelKey = "test-auth-labels" + authLabelValue = "true" + authLabel = fmt.Sprintf("%s=%s", authLabelKey, authLabelValue) + domain = "example.com" +) + +func TestValidateDatabase(t *testing.T) { + tests := []struct { + name string + mrSpec *v1alpha1.ModelRegistry + wantErr bool + }{ + { + name: "valid - mysql", + mrSpec: &v1alpha1.ModelRegistry{Spec: v1alpha1.ModelRegistrySpec{ + MySQL: &v1alpha1.MySQLConfig{}, + }}, + wantErr: false, + }, + { + name: "valid - postgres", + mrSpec: &v1alpha1.ModelRegistry{Spec: v1alpha1.ModelRegistrySpec{ + Postgres: &v1alpha1.PostgresConfig{}, + }}, + wantErr: false, + }, + { + name: "invalid - missing databases", + mrSpec: &v1alpha1.ModelRegistry{Spec: v1alpha1.ModelRegistrySpec{}}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, errList := tt.mrSpec.ValidateDatabase() + if tt.wantErr { + if len(errList) == 0 { + t.Errorf("ValidateDatabase() error = %v, wantErr %v", errList, tt.wantErr) + } + } else { + if len(errList) > 0 { + t.Errorf("ValidateDatabase() error = %v, wantErr %v", errList, tt.wantErr) + } + } + }) + } +} + +func TestValidateIstioConfig(t *testing.T) { + tests := []struct { + name string + mrSpec *v1alpha1.ModelRegistry + wantErr bool + }{ + { + name: "invalid - istio missing authProvider", + mrSpec: &v1alpha1.ModelRegistry{Spec: v1alpha1.ModelRegistrySpec{ + Istio: &v1alpha1.IstioConfig{}, + }}, + wantErr: true, + }, + { + name: "invalid - istio missing authConfigLabels", + mrSpec: &v1alpha1.ModelRegistry{Spec: v1alpha1.ModelRegistrySpec{ + Istio: &v1alpha1.IstioConfig{AuthProvider: "istio"}, + }}, + wantErr: true, + }, + { + name: "invalid - istio gateway missing domain", + mrSpec: &v1alpha1.ModelRegistry{Spec: v1alpha1.ModelRegistrySpec{ + Istio: &v1alpha1.IstioConfig{ + AuthProvider: "istio", + AuthConfigLabels: map[string]string{"auth": "enabled"}, + Gateway: &v1alpha1.GatewayConfig{}, + }, + }}, + wantErr: true, + }, + { + name: "invalid - istio gateway rest custom TLS missing credentials", + mrSpec: &v1alpha1.ModelRegistry{Spec: v1alpha1.ModelRegistrySpec{ + Istio: &v1alpha1.IstioConfig{ + AuthProvider: "istio", + AuthConfigLabels: map[string]string{"auth": "enabled"}, + Gateway: &v1alpha1.GatewayConfig{ + Domain: "test.com", + Rest: v1alpha1.ServerConfig{ + TLS: &v1alpha1.TLSServerSettings{ + Mode: "SIMPLE", + }, + }, + }, + }, + }}, + wantErr: true, + }, + { + name: "invalid - istio gateway grpc custom TLS missing credentials", + mrSpec: &v1alpha1.ModelRegistry{Spec: v1alpha1.ModelRegistrySpec{ + Istio: &v1alpha1.IstioConfig{ + AuthProvider: "istio", + AuthConfigLabels: map[string]string{"auth": "enabled"}, + Gateway: &v1alpha1.GatewayConfig{ + Domain: "test.com", + Grpc: v1alpha1.ServerConfig{ + TLS: &v1alpha1.TLSServerSettings{ + Mode: "SIMPLE", + }, + }, + }, + }, + }}, + wantErr: true, + }, + { + name: "valid - istio config", + mrSpec: &v1alpha1.ModelRegistry{Spec: v1alpha1.ModelRegistrySpec{ + Istio: &v1alpha1.IstioConfig{ + AuthProvider: "istio", + AuthConfigLabels: map[string]string{"auth": "enabled"}, + Gateway: &v1alpha1.GatewayConfig{ + Domain: "test.com", + }, + }, + }}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, errList := tt.mrSpec.ValidateIstioConfig() + if tt.wantErr { + if len(errList) == 0 { + t.Errorf("ValidateIstioConfig() error = %v, wantErr %v", errList, tt.wantErr) + } + } else { + if len(errList) > 0 { + t.Errorf("ValidateIstioConfig() error = %v, wantErr %v", errList, tt.wantErr) + } + } + }) + } +} + +func TestDefault(t *testing.T) { + var httpPort int32 = v1alpha1.DefaultHttpPort + defaultIstioGateway := v1alpha1.DefaultIstioGateway + + tests := []struct { + name string + mrSpec *v1alpha1.ModelRegistry + wantMrSpec *v1alpha1.ModelRegistry + }{ + { + name: "set default values", + mrSpec: &v1alpha1.ModelRegistry{ + Spec: v1alpha1.ModelRegistrySpec{ + Rest: v1alpha1.RestSpec{}, + Postgres: &v1alpha1.PostgresConfig{}, + MySQL: &v1alpha1.MySQLConfig{}, + Istio: &v1alpha1.IstioConfig{ + Gateway: &v1alpha1.GatewayConfig{}, + }, + }, + }, + wantMrSpec: &v1alpha1.ModelRegistry{ + Spec: v1alpha1.ModelRegistrySpec{ + Rest: v1alpha1.RestSpec{ + ServiceRoute: config.RouteDisabled, + }, + Postgres: nil, + MySQL: nil, + Istio: &v1alpha1.IstioConfig{ + TlsMode: v1alpha1.DefaultTlsMode, + Gateway: &v1alpha1.GatewayConfig{ + IstioIngress: &defaultIstioGateway, + Rest: v1alpha1.ServerConfig{ + Port: &httpPort, + GatewayRoute: config.RouteEnabled, + }, + Grpc: v1alpha1.ServerConfig{ + Port: &httpPort, + GatewayRoute: config.RouteEnabled, + }, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.mrSpec.Default() + if !reflect.DeepEqual(tt.mrSpec, tt.wantMrSpec) { + t.Errorf("Default() = %v, want %v", tt.mrSpec, tt.wantMrSpec) + } + }) + } +} + +func TestCleanupRuntimeDefaults(t *testing.T) { + setupDefaults(t) + + tests := []struct { + name string + mrSpec *v1alpha1.ModelRegistry + wantMrSpec *v1alpha1.ModelRegistry + }{ + { + name: "cleanup runtime default values", + mrSpec: &v1alpha1.ModelRegistry{ + Spec: v1alpha1.ModelRegistrySpec{ + Rest: v1alpha1.RestSpec{ + Resources: config.MlmdRestResourceRequirements.DeepCopy(), + Image: config.DefaultRestImage, + }, + Grpc: v1alpha1.GrpcSpec{ + Resources: config.MlmdGRPCResourceRequirements.DeepCopy(), + Image: config.DefaultGrpcImage, + }, + Istio: &v1alpha1.IstioConfig{ + Audiences: []string{audience}, + AuthProvider: authProvider, + AuthConfigLabels: map[string]string{authLabelKey: authLabelValue}, + Gateway: &v1alpha1.GatewayConfig{ + Domain: domain, + Rest: v1alpha1.ServerConfig{ + TLS: &v1alpha1.TLSServerSettings{ + Mode: tlsMode, + CredentialName: &certName, + }, + }, + Grpc: v1alpha1.ServerConfig{ + TLS: &v1alpha1.TLSServerSettings{ + Mode: tlsMode, + CredentialName: &certName, + }, + }, + }, + }, + }, + }, + wantMrSpec: &v1alpha1.ModelRegistry{ + Spec: v1alpha1.ModelRegistrySpec{ + Rest: v1alpha1.RestSpec{ + Resources: nil, + Image: "", + }, + Grpc: v1alpha1.GrpcSpec{ + Resources: nil, + Image: "", + }, + Istio: &v1alpha1.IstioConfig{ + Audiences: []string{}, + AuthProvider: "", + AuthConfigLabels: map[string]string{}, + Gateway: &v1alpha1.GatewayConfig{ + Domain: "", + Rest: v1alpha1.ServerConfig{ + TLS: &v1alpha1.TLSServerSettings{ + Mode: tlsMode, + CredentialName: nil, + }, + }, + Grpc: v1alpha1.ServerConfig{ + TLS: &v1alpha1.TLSServerSettings{ + Mode: tlsMode, + CredentialName: nil, + }, + }, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.mrSpec.CleanupRuntimeDefaults() + if !reflect.DeepEqual(tt.mrSpec, tt.wantMrSpec) { + t.Errorf("CleanupRuntimeDefaults() = %v, want %v", tt.mrSpec, tt.wantMrSpec) + } + }) + } +} + +func TestRuntimeDefaults(t *testing.T) { + setupDefaults(t) + + tests := []struct { + name string + mrSpec *v1alpha1.ModelRegistry + wantMrSpec *v1alpha1.ModelRegistry + }{ + { + name: "set runtime default values", + mrSpec: &v1alpha1.ModelRegistry{ + Spec: v1alpha1.ModelRegistrySpec{ + Istio: &v1alpha1.IstioConfig{ + Gateway: &v1alpha1.GatewayConfig{ + Rest: v1alpha1.ServerConfig{ + TLS: &v1alpha1.TLSServerSettings{ + Mode: tlsMode, + }, + }, + Grpc: v1alpha1.ServerConfig{ + TLS: &v1alpha1.TLSServerSettings{ + Mode: tlsMode, + }, + }, + }, + }, + }, + }, + wantMrSpec: &v1alpha1.ModelRegistry{ + Spec: v1alpha1.ModelRegistrySpec{ + Grpc: v1alpha1.GrpcSpec{ + Resources: config.MlmdGRPCResourceRequirements.DeepCopy(), + Image: config.DefaultGrpcImage, + }, + Rest: v1alpha1.RestSpec{ + Resources: config.MlmdRestResourceRequirements.DeepCopy(), + Image: config.DefaultRestImage, + }, + Istio: &v1alpha1.IstioConfig{ + Audiences: []string{audience}, + AuthProvider: authProvider, + AuthConfigLabels: map[string]string{authLabelKey: authLabelValue}, + Gateway: &v1alpha1.GatewayConfig{ + Domain: domain, + Rest: v1alpha1.ServerConfig{ + TLS: &v1alpha1.TLSServerSettings{ + Mode: tlsMode, + CredentialName: &certName, + }, + }, + Grpc: v1alpha1.ServerConfig{ + TLS: &v1alpha1.TLSServerSettings{ + Mode: tlsMode, + CredentialName: &certName, + }, + }, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.mrSpec.RuntimeDefaults() + if !reflect.DeepEqual(tt.mrSpec, tt.wantMrSpec) { + t.Errorf("RuntimeDefaults() = %v, want %v", tt.mrSpec, tt.wantMrSpec) + } + }) + } +} + +func setupDefaults(t testing.TB) { + t.Helper() + + config.SetDefaultAudiences([]string{audience}) + config.SetDefaultDomain(domain, k8sClient, false) + config.SetDefaultCert(certName) + config.SetDefaultAuthProvider(authProvider) + config.SetDefaultAuthConfigLabels(authLabel) +} diff --git a/api/v1alpha1/webhook_suite_test.go b/api/v1alpha1/webhook_suite_test.go index 27e811a..c971af7 100644 --- a/api/v1alpha1/webhook_suite_test.go +++ b/api/v1alpha1/webhook_suite_test.go @@ -14,22 +14,24 @@ See the License for the specific language governing permissions and limitations under the License. */ -package v1alpha1 +package v1alpha1_test import ( "context" "crypto/tls" "fmt" - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "net" "path/filepath" "runtime" "testing" "time" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/opendatahub-io/model-registry-operator/api/v1alpha1" admissionv1 "k8s.io/api/admission/v1" //+kubebuilder:scaffold:imports @@ -94,7 +96,7 @@ var _ = BeforeSuite(func() { Expect(cfg).NotTo(BeNil()) scheme := apimachineryruntime.NewScheme() - err = AddToScheme(scheme) + err = v1alpha1.AddToScheme(scheme) Expect(err).NotTo(HaveOccurred()) err = admissionv1.AddToScheme(scheme) @@ -123,7 +125,7 @@ var _ = BeforeSuite(func() { }) Expect(err).NotTo(HaveOccurred()) - err = (&ModelRegistry{}).SetupWebhookWithManager(mgr) + err = (&v1alpha1.ModelRegistry{}).SetupWebhookWithManager(mgr) Expect(err).NotTo(HaveOccurred()) //+kubebuilder:scaffold:webhook @@ -171,27 +173,45 @@ var _ = Describe("Model Registry validating webhook", func() { mr2 = newModelRegistry(ctx, mrNameBase+suffix2, namespaceBase+suffix2) Expect(k8sClient.Create(ctx, mr2)).Should(Succeed()) }) + + It("Should not allow creation of MR instance with invalid database config", func(ctx context.Context) { + mr := newModelRegistry(ctx, mrNameBase+"-invalid-db-create", namespaceBase) + mr.Spec = v1alpha1.ModelRegistrySpec{} + + Expect(k8sClient.Create(ctx, mr)).ShouldNot(Succeed()) + }) + + It("Should not allow update of MR instance with invalid database config", func(ctx context.Context) { + mr := newModelRegistry(ctx, mrNameBase+"-invalid-db-update", namespaceBase) + Expect(k8sClient.Create(ctx, mr)).Should(Succeed()) + + mr.Spec = v1alpha1.ModelRegistrySpec{ + MySQL: &v1alpha1.MySQLConfig{}, + } + + Expect(k8sClient.Update(ctx, mr)).ShouldNot(Succeed()) + }) }) -func newModelRegistry(ctx context.Context, name string, namespace string) *ModelRegistry { +func newModelRegistry(ctx context.Context, name string, namespace string) *v1alpha1.ModelRegistry { // create test namespace Expect(client.IgnoreAlreadyExists(k8sClient.Create(ctx, &corev1.Namespace{ ObjectMeta: metav1.ObjectMeta{Name: namespace}, }))).Should(Succeed()) // return - return &ModelRegistry{ + return &v1alpha1.ModelRegistry{ ObjectMeta: metav1.ObjectMeta{ Name: name, Namespace: namespace, }, - Spec: ModelRegistrySpec{ - Rest: RestSpec{}, - Grpc: GrpcSpec{}, - MySQL: &MySQLConfig{ + Spec: v1alpha1.ModelRegistrySpec{ + Rest: v1alpha1.RestSpec{}, + Grpc: v1alpha1.GrpcSpec{}, + MySQL: &v1alpha1.MySQLConfig{ Host: "test-db", Username: "test-user", - PasswordSecret: &SecretKeyValue{ + PasswordSecret: &v1alpha1.SecretKeyValue{ Name: "test-secret", Key: "test-key", }, diff --git a/go.mod b/go.mod index 45fb805..1bfbd92 100644 --- a/go.mod +++ b/go.mod @@ -72,6 +72,7 @@ require ( github.com/prometheus/procfs v0.10.1 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect diff --git a/go.sum b/go.sum index 299b840..6c744f5 100644 --- a/go.sum +++ b/go.sum @@ -241,8 +241,8 @@ github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6ke github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= -github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= -github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/soheilhy/cmux v0.1.5 h1:jjzc5WVemNEDTLwv9tlmemhC73tI08BNOIGwBOo10Js= github.com/soheilhy/cmux v0.1.5/go.mod h1:T7TcVDs9LWfQgPlPsdngu6I6QIoyIFZDDC6sNE1GqG0= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= @@ -380,6 +380,7 @@ golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= diff --git a/internal/controller/config/defaults.go b/internal/controller/config/defaults.go index 89553ab..68d6485 100644 --- a/internal/controller/config/defaults.go +++ b/internal/controller/config/defaults.go @@ -19,6 +19,10 @@ package config import ( "context" "embed" + "fmt" + "strings" + "text/template" + configv1 "github.com/openshift/api/config/v1" "github.com/spf13/viper" v1 "k8s.io/api/core/v1" @@ -26,8 +30,6 @@ import ( "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/client" klog "sigs.k8s.io/controller-runtime/pkg/log" - "strings" - "text/template" ) //go:embed templates/*.yaml.tmpl @@ -152,7 +154,7 @@ func GetDefaultDomain() string { namespacedName := types.NamespacedName{Name: "cluster"} err := defaultClient.Get(context.Background(), namespacedName, &ingress) if err != nil { - klog.Log.Error(err, "error getting OpenShift domain name", ingress.GetObjectKind(), namespacedName) + klog.Log.Error(err, "error getting OpenShift domain name", fmt.Sprintf("%+v", ingress.GetObjectKind()), namespacedName) return "" } defaultDomain = ingress.Spec.Domain diff --git a/internal/controller/config/defaults_test.go b/internal/controller/config/defaults_test.go index cc61658..264a84b 100644 --- a/internal/controller/config/defaults_test.go +++ b/internal/controller/config/defaults_test.go @@ -1,8 +1,19 @@ -package config +package config_test import ( + "fmt" "os" "testing" + + "github.com/go-logr/logr" + "github.com/opendatahub-io/model-registry-operator/api/v1alpha1" + "github.com/opendatahub-io/model-registry-operator/internal/controller" + "github.com/opendatahub-io/model-registry-operator/internal/controller/config" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + rbac "k8s.io/api/rbac/v1" ) func TestGetStringConfigWithDefault(t *testing.T) { @@ -11,20 +22,20 @@ func TestGetStringConfigWithDefault(t *testing.T) { configName string want string }{ - {name: "test " + GrpcImage, configName: GrpcImage, want: "success1"}, - {name: "test " + RestImage, configName: RestImage, want: "success2"}, + {name: "test " + config.GrpcImage, configName: config.GrpcImage, want: "success1"}, + {name: "test " + config.RestImage, configName: config.RestImage, want: "success2"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { os.Setenv(tt.configName, tt.want) - if got := GetStringConfigWithDefault(tt.configName, "fail"); got != tt.want { + if got := config.GetStringConfigWithDefault(tt.configName, "fail"); got != tt.want { t.Errorf("GetStringConfigWithDefault() = %v, want %v", got, tt.want) } }) } } -/*func TestParseTemplates(t *testing.T) { +func TestParseTemplates(t *testing.T) { tests := []struct { name string spec v1alpha1.ModelRegistrySpec @@ -35,7 +46,7 @@ func TestGetStringConfigWithDefault(t *testing.T) { } // parse all templates - templates, err := ParseTemplates() + templates, err := config.ParseTemplates() if err != nil { t.Errorf("ParseTemplates() error = %v", err) } @@ -50,17 +61,116 @@ func TestGetStringConfigWithDefault(t *testing.T) { params := controller.ModelRegistryParams{ Name: "test", Namespace: "test-namespace", - Spec: tt.spec, + Spec: &tt.spec, } - got, err := reconciler.Apply(params, tt.name, result) + var result rbac.Role + err := reconciler.Apply(¶ms, tt.name, &result) if (err != nil) != tt.wantErr { t.Errorf("ParseTemplates() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("ParseTemplates() got = %v, want %v", got, tt.want) + + if result.Name != fmt.Sprintf("registry-user-%s", params.Name) { + t.Errorf("ParseTemplates() got = %v, want %v", result.Name, fmt.Sprintf("registry-user-%s", params.Name)) + } + + if result.Namespace != params.Namespace { + t.Errorf("ParseTemplates() got = %v, want %v", result.Namespace, params.Namespace) + } + }) + } +} + +func TestSetGetDefaultAudiences(t *testing.T) { + tests := []struct { + name string + audiences []string + }{ + {name: "test1", audiences: []string{"audience1", "audience2"}}, + {name: "test2", audiences: []string{"audience3", "audience4"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config.SetDefaultAudiences(tt.audiences) + if got := config.GetDefaultAudiences(); len(got) != len(tt.audiences) { + t.Errorf("GetDefaultAudiences() = %v, want %v", got, tt.audiences) + } + }) + } +} + +func TestSetGetDefaultAuthProvider(t *testing.T) { + tests := []struct { + name string + provider string + }{ + {name: "test1", provider: "provider1"}, + {name: "test2", provider: "provider2"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config.SetDefaultAuthProvider(tt.provider) + if got := config.GetDefaultAuthProvider(); got != tt.provider { + t.Errorf("GetDefaultAuthProvider() = %v, want %v", got, tt.provider) } }) } } -*/ + +func TestSetGetDefaultAuthConfigLabels(t *testing.T) { + tests := []struct { + name string + labels string + }{ + {name: "test1", labels: "label1"}, + {name: "test2", labels: "label2"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config.SetDefaultAuthConfigLabels(tt.labels) + if got := config.GetDefaultAuthConfigLabels(); len(got) != 1 { + t.Errorf("GetDefaultAuthConfigLabels() = %v, want %v", got, tt.labels) + } + }) + } +} + +func TestSetGetDefaultCert(t *testing.T) { + tests := []struct { + name string + cert string + }{ + {name: "test1", cert: "cert1"}, + {name: "test2", cert: "cert2"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config.SetDefaultCert(tt.cert) + if got := config.GetDefaultCert(); got != tt.cert { + t.Errorf("GetDefaultCert() = %v, want %v", got, tt.cert) + } + }) + } +} + +var _ = Describe("Defaults integration tests", func() { + Describe("TestSetGetDefaultDomain", func() { + It("Should return the set domain on openshift", func() { + config.SetDefaultDomain("domain1", k8sClient, true) + + Expect(config.GetDefaultDomain()).To(Equal("domain1")) + }) + + It("Should return the set domain on non-openshift", func() { + config.SetDefaultDomain("domain2", k8sClient, false) + + Expect(config.GetDefaultDomain()).To(Equal("domain2")) + }) + + It("Should return the domain from ingress when no domain is set", func() { + config.SetDefaultDomain("", k8sClient, true) + + Expect(config.GetDefaultDomain()).To(Equal("domain3")) + }) + }) +}) diff --git a/internal/controller/config/suite_test.go b/internal/controller/config/suite_test.go new file mode 100644 index 0000000..ea68d16 --- /dev/null +++ b/internal/controller/config/suite_test.go @@ -0,0 +1,137 @@ +/* +Copyright 2023. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package config_test + +import ( + "context" + "fmt" + "os" + "path/filepath" + "runtime" + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/opendatahub-io/model-registry-operator/internal/utils" + + configv1 "github.com/openshift/api/config/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + apiruntime "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/rest" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/envtest" + logf "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/log/zap" + //+kubebuilder:scaffold:imports +) + +// These tests use Ginkgo (BDD-style Go testing framework). Refer to +// http://onsi.github.io/ginkgo/ to learn more about Ginkgo. + +type remoteCRD struct { + url string + fileName string +} + +var ( + cfg *rest.Config + k8sClient client.Client + testEnv *envtest.Environment + testCRDLocalPath = "./testdata/crd" + remoteCRDs = []remoteCRD{ + { + url: "https://raw.githubusercontent.com/openshift/api/refs/heads/master/config/v1/zz_generated.crd-manifests/0000_10_config-operator_01_ingresses.crd.yaml", + fileName: "ingress.openshift.io_ingresses.yaml", + }, + } +) + +func TestControllers(t *testing.T) { + RegisterFailHandler(Fail) + + RunSpecs(t, "Controller Suite") +} + +var _ = BeforeSuite(func() { + var err error + + logf.SetLogger(zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true))) + + schm := apiruntime.NewScheme() + + err = configv1.AddToScheme(schm) + Expect(err).NotTo(HaveOccurred()) + + //+kubebuilder:scaffold:scheme + + // Download CRDs + + if err := os.MkdirAll(filepath.Dir(testCRDLocalPath), 0755); err != nil { + Fail(err.Error()) + } + + for _, crd := range remoteCRDs { + if err := utils.DownloadFile(crd.url, filepath.Join(testCRDLocalPath, crd.fileName)); err != nil { + Fail(err.Error()) + } + } + + By("bootstrapping test environment") + useExistingCluster := false + testEnv = &envtest.Environment{ + Scheme: schm, + CRDDirectoryPaths: []string{ + filepath.Join("testdata", "crd"), + }, + ErrorIfCRDPathMissing: true, + + // The BinaryAssetsDirectory is only required if you want to run the tests directly + // without call the makefile target test. If not informed it will look for the + // default path defined in controller-runtime which is /usr/local/kubebuilder/. + // Note that you must have the required binaries setup under the bin directory to perform + // the tests directly. When we run make test it will be setup and used automatically. + BinaryAssetsDirectory: filepath.Join("..", "..", "bin", "k8s", + fmt.Sprintf("1.28.0-%s-%s", runtime.GOOS, runtime.GOARCH)), + UseExistingCluster: &useExistingCluster, + } + // cfg is defined in this file globally. + cfg, err = testEnv.Start() + Expect(err).NotTo(HaveOccurred()) + Expect(cfg).NotTo(BeNil()) + + k8sClient, err = client.New(cfg, client.Options{Scheme: schm}) + Expect(err).NotTo(HaveOccurred()) + Expect(k8sClient).NotTo(BeNil()) + + clusterIngress := &configv1.Ingress{ + ObjectMeta: metav1.ObjectMeta{ + Name: "cluster", + }, + Spec: configv1.IngressSpec{ + Domain: "domain3", + }, + } + + err = k8sClient.Create(context.Background(), clusterIngress) + Expect(err).NotTo(HaveOccurred()) +}) + +var _ = AfterSuite(func() { + By("tearing down the test environment") + err := testEnv.Stop() + Expect(err).NotTo(HaveOccurred()) +}) diff --git a/internal/controller/config/testdata/crd/.gitignore b/internal/controller/config/testdata/crd/.gitignore new file mode 100644 index 0000000..5142034 --- /dev/null +++ b/internal/controller/config/testdata/crd/.gitignore @@ -0,0 +1,4 @@ +# CRDs files will be fetched from the test suite + +*.yaml +!*_mocked.yaml diff --git a/internal/controller/modelregistry_controller.go b/internal/controller/modelregistry_controller.go index 0718ae9..2f6ddac 100644 --- a/internal/controller/modelregistry_controller.go +++ b/internal/controller/modelregistry_controller.go @@ -20,11 +20,12 @@ import ( "context" errors2 "errors" "fmt" + "strings" + "text/template" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/client-go/kubernetes" - "strings" - "text/template" "github.com/banzaicloud/k8s-objectmatcher/patch" "github.com/go-logr/logr" diff --git a/internal/controller/modelregistry_controller_status.go b/internal/controller/modelregistry_controller_status.go index 4cfac75..e27e199 100644 --- a/internal/controller/modelregistry_controller_status.go +++ b/internal/controller/modelregistry_controller_status.go @@ -20,7 +20,10 @@ import ( "bufio" "context" "fmt" - "github.com/evanphx/json-patch/v5" + "regexp" + "strings" + + jsonpatch "github.com/evanphx/json-patch/v5" "github.com/go-logr/logr" modelregistryv1alpha1 "github.com/opendatahub-io/model-registry-operator/api/v1alpha1" routev1 "github.com/openshift/api/route/v1" @@ -33,11 +36,9 @@ import ( "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/json" - "regexp" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" klog "sigs.k8s.io/controller-runtime/pkg/log" - "strings" ) // Definitions to manage status conditions @@ -366,13 +367,15 @@ func (r *ModelRegistryReconciler) SetIstioCondition(ctx context.Context, req ctr available := true // verify that virtualservice, destinationrule, authorizationpolicy are available name := req.NamespacedName - message, available = r.CheckIstioResourcesAvailable(ctx, name, log, message, available) + message, available, reason = r.CheckIstioResourcesAvailable(ctx, name, log, message, available, reason) - message, available, reason = r.CheckAuthConfigCondition(ctx, name, log, message, available, reason) + if r.CreateAuthResources { + message, available, reason = r.CheckAuthConfigCondition(ctx, name, log, message, available, reason) + } status := metav1.ConditionFalse if available { - if reason == ReasonResourcesAvailable { + if reason == ReasonResourcesAvailable || (!r.CreateAuthResources && reason == ReasonResourcesCreated) { status = metav1.ConditionTrue } // additionally verify that Deployment pod has 3 containers including the istio-envoy proxy @@ -390,7 +393,6 @@ func (r *ModelRegistryReconciler) SetIstioCondition(ctx context.Context, req ctr func (r *ModelRegistryReconciler) CheckDeploymentPods(ctx context.Context, name types.NamespacedName, log logr.Logger, message string, reason string, status metav1.ConditionStatus) (string, string, metav1.ConditionStatus) { - pods := corev1.PodList{} if err := r.Client.List(ctx, &pods, client.MatchingLabels{"app": name.Name, "component": "model-registry"}, @@ -401,15 +403,24 @@ func (r *ModelRegistryReconciler) CheckDeploymentPods(ctx context.Context, name reason = ReasonResourcesUnavailable status = metav1.ConditionFalse - } else { - // check that pods have 3 containers - for _, pod := range pods.Items { - if len(pod.Spec.Containers) != 3 { - message = fmt.Sprintf("Istio proxy unavailable in Pod %s", pod.Name) - reason = ReasonResourcesUnavailable - status = metav1.ConditionFalse - break - } + return message, reason, status + } + + if len(pods.Items) == 0 { + message = fmt.Sprintf("No Pods found for Deployment %s", name.Name) + reason = ReasonResourcesUnavailable + status = metav1.ConditionFalse + + return message, reason, status + } + + // check that pods have 3 containers + for _, pod := range pods.Items { + if len(pod.Spec.Containers) != 3 { + message = fmt.Sprintf("Istio proxy unavailable in Pod %s", pod.Name) + reason = ReasonResourcesUnavailable + status = metav1.ConditionFalse + break } } @@ -453,7 +464,7 @@ func (r *ModelRegistryReconciler) CheckAuthConfigCondition(ctx context.Context, } func (r *ModelRegistryReconciler) CheckIstioResourcesAvailable(ctx context.Context, name types.NamespacedName, - log logr.Logger, message string, available bool) (string, bool) { + log logr.Logger, message string, available bool, reason string) (string, bool, string) { var resource client.Object resource = &v1beta1.VirtualService{} @@ -461,23 +472,28 @@ func (r *ModelRegistryReconciler) CheckIstioResourcesAvailable(ctx context.Conte log.Error(err, "Failed to get model registry Istio VirtualService", "name", name) message = fmt.Sprintf("Failed to find VirtualService: %s", err.Error()) available = false + reason = ReasonResourcesUnavailable } resource = &v1beta1.DestinationRule{} if err := r.Get(ctx, name, resource); err != nil { log.Error(err, "Failed to get model registry Istio DestinationRule", "name", name) message = fmt.Sprintf("Failed to find DestinationRule: %s", err.Error()) available = false + reason = ReasonResourcesUnavailable } - resource = &v1beta12.AuthorizationPolicy{} - policyName := name - policyName.Name = policyName.Name + "-authorino" - if err := r.Get(ctx, policyName, resource); err != nil { - log.Error(err, "Failed to get model registry Istio AuthorizationPolicy", "name", policyName) - message = fmt.Sprintf("Failed to find AuthorizationPolicy %s: %s", policyName, err.Error()) - available = false + if r.CreateAuthResources { + resource = &v1beta12.AuthorizationPolicy{} + policyName := name + policyName.Name = policyName.Name + "-authorino" + if err := r.Get(ctx, policyName, resource); err != nil { + log.Error(err, "Failed to get model registry Istio AuthorizationPolicy", "name", policyName) + message = fmt.Sprintf("Failed to find AuthorizationPolicy %s: %s", policyName, err.Error()) + available = false + reason = ReasonResourcesUnavailable + } } - return message, available + return message, available, reason } func (r *ModelRegistryReconciler) SetGatewayCondition(ctx context.Context, req ctrl.Request, @@ -511,7 +527,7 @@ func (r *ModelRegistryReconciler) SetGatewayCondition(ctx context.Context, req c status := metav1.ConditionFalse if available { - if reason == ReasonResourcesAvailable { + if reason == ReasonResourcesAvailable || (!r.IsOpenShift && reason == ReasonResourcesCreated) { status = metav1.ConditionTrue } } else { diff --git a/internal/controller/modelregistry_controller_test.go b/internal/controller/modelregistry_controller_test.go index aaef623..3d6311e 100644 --- a/internal/controller/modelregistry_controller_test.go +++ b/internal/controller/modelregistry_controller_test.go @@ -19,23 +19,29 @@ package controller import ( "context" "fmt" + "os" + "strings" + "text/template" + "time" + "github.com/opendatahub-io/model-registry-operator/internal/controller/config" - "github.com/openshift/api" + routev1 "github.com/openshift/api/route/v1" + userv1 "github.com/openshift/api/user/v1" corev1 "k8s.io/api/core/v1" + rbacv1 "k8s.io/api/rbac/v1" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/meta" + "k8s.io/apimachinery/pkg/util/intstr" "k8s.io/client-go/tools/record" - "os" ctrl "sigs.k8s.io/controller-runtime" - "strings" - "text/template" - "time" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" appsv1 "k8s.io/api/apps/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/reconcile" "github.com/opendatahub-io/model-registry-operator/api/v1alpha1" @@ -127,7 +133,9 @@ var _ = Describe("ModelRegistry controller", func() { err = k8sClient.Create(ctx, modelRegistry) Expect(err).To(Not(HaveOccurred())) - Eventually(validateRegistry(ctx, typeNamespaceName, template, modelRegistry), + modelRegistryReconciler := initModelRegistryReconciler(template) + + Eventually(validateRegistryBase(ctx, typeNamespaceName, modelRegistry, modelRegistryReconciler), time.Minute, time.Second).Should(Succeed()) }) @@ -151,7 +159,243 @@ var _ = Describe("ModelRegistry controller", func() { err = k8sClient.Create(ctx, modelRegistry) Expect(err).To(Not(HaveOccurred())) - Eventually(validateRegistry(ctx, typeNamespaceName, template, modelRegistry), + modelRegistryReconciler := initModelRegistryReconciler(template) + + Eventually(validateRegistryBase(ctx, typeNamespaceName, modelRegistry, modelRegistryReconciler), + time.Minute, time.Second).Should(Succeed()) + }) + + It("When using OpenShift - serviceRoute enabled", func() { + registryName = "model-registry-openshift-with-serviceroute" + specInit() + + var mySQLPort int32 = 3306 + modelRegistry.Spec.Postgres = nil + modelRegistry.Spec.MySQL = &v1alpha1.MySQLConfig{ + Host: "model-registry-db", + Port: &mySQLPort, + Database: "model_registry", + Username: "mlmduser", + PasswordSecret: &v1alpha1.SecretKeyValue{ + Name: "model-registry-db", + Key: "database-password", + }, + } + modelRegistry.Spec.Rest.ServiceRoute = config.RouteEnabled + + err = k8sClient.Create(ctx, modelRegistry) + Expect(err).To(Not(HaveOccurred())) + + modelRegistryReconciler := initModelRegistryReconciler(template) + + Eventually(validateRegistryOpenshift(ctx, typeNamespaceName, modelRegistry, modelRegistryReconciler), + time.Minute, time.Second).Should(Succeed()) + + By("Checking if the Openshift Route was successfully created in the reconciliation") + Eventually(func() error { + found := &routev1.Route{} + + return k8sClient.Get(ctx, types.NamespacedName{Name: fmt.Sprintf("%s-http", modelRegistry.Name), Namespace: modelRegistry.Namespace}, found) + }, 5*time.Second, time.Second).Should(Succeed()) + }) + + It("When using OpenShift - serviceRoute disabled", func() { + registryName = "model-registry-openshift-without-serviceroute" + specInit() + + var mySQLPort int32 = 3306 + modelRegistry.Spec.Postgres = nil + modelRegistry.Spec.MySQL = &v1alpha1.MySQLConfig{ + Host: "model-registry-db", + Port: &mySQLPort, + Database: "model_registry", + Username: "mlmduser", + PasswordSecret: &v1alpha1.SecretKeyValue{ + Name: "model-registry-db", + Key: "database-password", + }, + } + + err = k8sClient.Create(ctx, modelRegistry) + Expect(err).To(Not(HaveOccurred())) + + modelRegistryReconciler := initModelRegistryReconciler(template) + + Eventually(validateRegistryOpenshift(ctx, typeNamespaceName, modelRegistry, modelRegistryReconciler), + time.Minute, time.Second).Should(Succeed()) + + By("Checking if the Openshift Route was not created in the reconciliation") + found := &routev1.Route{} + + err := k8sClient.Get(ctx, types.NamespacedName{Name: fmt.Sprintf("%s-http", modelRegistry.Name), Namespace: modelRegistry.Namespace}, found) + Expect(err).To(HaveOccurred()) + }) + + It("When using Istio", func() { + registryName = "model-registry-istio" + specInit() + + var mySQLPort int32 = 3306 + modelRegistry.Spec.Postgres = nil + modelRegistry.Spec.MySQL = &v1alpha1.MySQLConfig{ + Host: "model-registry-db", + Port: &mySQLPort, + Database: "model_registry", + Username: "mlmduser", + PasswordSecret: &v1alpha1.SecretKeyValue{ + Name: "model-registry-db", + Key: "database-password", + }, + } + modelRegistry.Spec.Istio = &v1alpha1.IstioConfig{ + AuthProvider: "opendatahub-auth-provider", + AuthConfigLabels: map[string]string{ + "auth": "enabled", + }, + Gateway: &v1alpha1.GatewayConfig{ + Domain: "example.com", + Rest: v1alpha1.ServerConfig{ + GatewayRoute: "enabled", + }, + Grpc: v1alpha1.ServerConfig{ + GatewayRoute: "enabled", + }, + }, + } + + err = k8sClient.Create(ctx, modelRegistry) + Expect(err).To(Not(HaveOccurred())) + + modelRegistryReconciler := initModelRegistryReconciler(template) + + Eventually(validateRegistryIstio(ctx, typeNamespaceName, modelRegistry, modelRegistryReconciler), + time.Minute, time.Second).Should(Succeed()) + }) + + It("When using Istio on Openshift", func() { + registryName = "model-registry-istio-openshift" + specInit() + + var mySQLPort int32 = 3306 + modelRegistry.Spec.Postgres = nil + modelRegistry.Spec.MySQL = &v1alpha1.MySQLConfig{ + Host: "model-registry-db", + Port: &mySQLPort, + Database: "model_registry", + Username: "mlmduser", + PasswordSecret: &v1alpha1.SecretKeyValue{ + Name: "model-registry-db", + Key: "database-password", + }, + } + modelRegistry.Spec.Istio = &v1alpha1.IstioConfig{ + AuthProvider: "opendatahub-auth-provider", + AuthConfigLabels: map[string]string{ + "auth": "enabled", + }, + Gateway: &v1alpha1.GatewayConfig{ + Domain: "example.com", + Rest: v1alpha1.ServerConfig{ + GatewayRoute: "enabled", + }, + Grpc: v1alpha1.ServerConfig{ + GatewayRoute: "enabled", + }, + }, + } + + err = k8sClient.Create(ctx, modelRegistry) + Expect(err).To(Not(HaveOccurred())) + + modelRegistryReconciler := initModelRegistryReconciler(template) + + modelRegistryReconciler.IsOpenShift = true + + Eventually(validateRegistryIstio(ctx, typeNamespaceName, modelRegistry, modelRegistryReconciler), + time.Minute, time.Second).Should(Succeed()) + }) + + It("When using Istio and Authorino", func() { + registryName = "model-registry-istio-authorino" + specInit() + + var mySQLPort int32 = 3306 + modelRegistry.Spec.Postgres = nil + modelRegistry.Spec.MySQL = &v1alpha1.MySQLConfig{ + Host: "model-registry-db", + Port: &mySQLPort, + Database: "model_registry", + Username: "mlmduser", + PasswordSecret: &v1alpha1.SecretKeyValue{ + Name: "model-registry-db", + Key: "database-password", + }, + } + modelRegistry.Spec.Istio = &v1alpha1.IstioConfig{ + AuthProvider: "opendatahub-auth-provider", + AuthConfigLabels: map[string]string{ + "auth": "enabled", + }, + Gateway: &v1alpha1.GatewayConfig{ + Domain: "example.com", + Rest: v1alpha1.ServerConfig{ + GatewayRoute: "enabled", + }, + Grpc: v1alpha1.ServerConfig{ + GatewayRoute: "enabled", + }, + }, + } + + err = k8sClient.Create(ctx, modelRegistry) + Expect(err).To(Not(HaveOccurred())) + + modelRegistryReconciler := initModelRegistryReconciler(template) + + Eventually(validateRegistryAuth(ctx, typeNamespaceName, modelRegistry, modelRegistryReconciler), + time.Minute, time.Second).Should(Succeed()) + }) + + It("When using Istio and Authorino on openshift", func() { + registryName = "model-registry-istio-authorino-openshift" + specInit() + + var mySQLPort int32 = 3306 + modelRegistry.Spec.Postgres = nil + modelRegistry.Spec.MySQL = &v1alpha1.MySQLConfig{ + Host: "model-registry-db", + Port: &mySQLPort, + Database: "model_registry", + Username: "mlmduser", + PasswordSecret: &v1alpha1.SecretKeyValue{ + Name: "model-registry-db", + Key: "database-password", + }, + } + modelRegistry.Spec.Istio = &v1alpha1.IstioConfig{ + AuthProvider: "opendatahub-auth-provider", + AuthConfigLabels: map[string]string{ + "auth": "enabled", + }, + Gateway: &v1alpha1.GatewayConfig{ + Domain: "example.com", + Rest: v1alpha1.ServerConfig{ + GatewayRoute: "enabled", + }, + Grpc: v1alpha1.ServerConfig{ + GatewayRoute: "enabled", + }, + }, + } + + err = k8sClient.Create(ctx, modelRegistry) + Expect(err).To(Not(HaveOccurred())) + + modelRegistryReconciler := initModelRegistryReconciler(template) + + modelRegistryReconciler.IsOpenShift = true + + Eventually(validateRegistryAuth(ctx, typeNamespaceName, modelRegistry, modelRegistryReconciler), time.Minute, time.Second).Should(Succeed()) }) @@ -165,6 +409,13 @@ var _ = Describe("ModelRegistry controller", func() { return k8sClient.Delete(context.TODO(), found) }, 2*time.Minute, time.Second).Should(Succeed()) + By("Cleaning up istio services") + svc := corev1.Service{} + svc.Name = "istio" + svc.Namespace = typeNamespaceName.Namespace + + _ = k8sClient.Delete(ctx, &svc) + // TODO(user): Attention if you improve this code by adding other context test you MUST // be aware of the current delete namespace limitations. // More info: https://book.kubebuilder.io/reference/envtest.html#testing-considerations @@ -180,7 +431,21 @@ var _ = Describe("ModelRegistry controller", func() { }) }) -func validateRegistry(ctx context.Context, typeNamespaceName types.NamespacedName, template *template.Template, modelRegistry *v1alpha1.ModelRegistry) func() error { +func initModelRegistryReconciler(template *template.Template) *ModelRegistryReconciler { + scheme := k8sClient.Scheme() + + modelRegistryReconciler := &ModelRegistryReconciler{ + Client: k8sClient, + Scheme: scheme, + Recorder: &record.FakeRecorder{}, + Log: ctrl.Log.WithName("controller"), + Template: template, + } + + return modelRegistryReconciler +} + +func validateRegistryBase(ctx context.Context, typeNamespaceName types.NamespacedName, modelRegistry *v1alpha1.ModelRegistry, modelRegistryReconciler *ModelRegistryReconciler) func() error { return func() error { By("Checking if the custom resource was successfully created") Eventually(func() error { @@ -188,16 +453,32 @@ func validateRegistry(ctx context.Context, typeNamespaceName types.NamespacedNam return k8sClient.Get(ctx, typeNamespaceName, found) }, time.Minute, time.Second).Should(Succeed()) - scheme := k8sClient.Scheme() - _ = api.Install(scheme) - modelRegistryReconciler := &ModelRegistryReconciler{ - Client: k8sClient, - Scheme: scheme, - Recorder: &record.FakeRecorder{}, - Log: ctrl.Log.WithName("controller"), - Template: template, + By("Mocking the Pod creation to perform the tests") + mrPod := &corev1.Pod{} + mrPod.Name = typeNamespaceName.Name + mrPod.Namespace = typeNamespaceName.Namespace + mrPod.Labels = map[string]string{"app": typeNamespaceName.Name, "component": "model-registry"} + mrPod.Spec.Containers = []corev1.Container{ + { + Name: "model-registry-rest", + Image: config.DefaultRestImage, + }, + { + Name: "model-registry-grpc", + Image: config.DefaultGrpcImage, + }, + } + + if modelRegistryReconciler.HasIstio { + mrPod.Spec.Containers = append(mrPod.Spec.Containers, corev1.Container{ + Name: "istio-proxy", + Image: "istio-proxy", + }) } + err := k8sClient.Create(ctx, mrPod) + Expect(err).To(Not(HaveOccurred())) + By("Reconciling the custom resource created") Eventually(func() error { result, err := modelRegistryReconciler.Reconcile(ctx, reconcile.Request{ @@ -228,7 +509,6 @@ func validateRegistry(ctx context.Context, typeNamespaceName types.NamespacedNam // reconcile done! return nil }, time.Minute, time.Second).Should(Succeed()) - //Expect(err).To(Not(HaveOccurred())) By("Checking if Deployment was successfully created in the reconciliation") Eventually(func() error { @@ -236,29 +516,131 @@ func validateRegistry(ctx context.Context, typeNamespaceName types.NamespacedNam return k8sClient.Get(ctx, typeNamespaceName, found) }, time.Minute, time.Second).Should(Succeed()) + if modelRegistry.Spec.Istio != nil && modelRegistry.Spec.Istio.Gateway != nil && modelRegistryReconciler.IsOpenShift { + By("Checking if the Route was successfully created in the reconciliation") + routes := &routev1.RouteList{} + err = k8sClient.List(ctx, routes, client.MatchingLabels{ + "app": typeNamespaceName.Name, + "component": "model-registry", + "maistra.io/gateway-name": typeNamespaceName.Name, + }) + Expect(err).To(Not(HaveOccurred())) + + By("Mocking the conditions in the Route to perform the tests") + if len(routes.Items) > 0 { + for _, route := range routes.Items { + ingresses := []routev1.RouteIngress{ + { + Conditions: []routev1.RouteIngressCondition{ + { + Type: routev1.RouteAdmitted, + Status: corev1.ConditionTrue, + }, + }, + }, + } + + route.Status.Ingress = ingresses + + err = k8sClient.Status().Update(ctx, &route) + Expect(err).To(Not(HaveOccurred())) + } + + Eventually(func() error { + _, err := modelRegistryReconciler.Reconcile(ctx, reconcile.Request{ + NamespacedName: typeNamespaceName, + }) + + return err + }, time.Minute, time.Second).Should(Succeed()) + } + } + + if modelRegistryReconciler.CreateAuthResources { + By("Checking if the Auth resources were successfully created in the reconciliation") + authConfig := CreateAuthConfig() + Eventually(func() error { + return k8sClient.Get(ctx, typeNamespaceName, authConfig) + }, time.Minute, time.Second).Should(Succeed()) + + By("Mocking conditions in the AuthConfig to perform the tests") + err := unstructured.SetNestedMap(authConfig.Object, map[string]interface{}{ + "conditions": []interface{}{ + map[string]interface{}{ + "type": "Ready", + "status": "True", + }, + }}, "status") + Expect(err).To(Not(HaveOccurred())) + + By("Updating the AuthConfig to set the Ready condition to True") + err = k8sClient.Status().Update(ctx, authConfig) + Expect(err).To(Not(HaveOccurred())) + + authConfigTwo := CreateAuthConfig() + Eventually(func() error { + return k8sClient.Get(ctx, typeNamespaceName, authConfigTwo) + }, time.Minute, time.Second).Should(Succeed()) + + Eventually(func() error { + if authConfig.Object["status"] == nil { + return fmt.Errorf("status not set") + } + + return nil + }, time.Minute, time.Second).Should(Succeed()) + + Eventually(func() error { + _, err := modelRegistryReconciler.Reconcile(ctx, reconcile.Request{ + NamespacedName: typeNamespaceName, + }) + + return err + }, time.Minute, time.Second).Should(Succeed()) + } + By("Checking the latest Status Condition added to the ModelRegistry instance") Eventually(func() error { err := k8sClient.Get(ctx, typeNamespaceName, modelRegistry) Expect(err).To(Not(HaveOccurred())) - // also check hosts in status - hosts := modelRegistry.Status.Hosts - Expect(len(hosts)).To(Equal(3)) - name := modelRegistry.Name - namespace := modelRegistry.Namespace - Expect(hosts[0]). - To(Equal(fmt.Sprintf("%s.%s.svc.cluster.local", name, namespace))) - Expect(hosts[1]). - To(Equal(fmt.Sprintf("%s.%s", name, namespace))) - Expect(hosts[2]). - To(Equal(name)) - Expect(modelRegistry.Status.HostsStr).To(Equal(strings.Join(hosts, ","))) + if modelRegistry.Spec.Istio != nil && modelRegistry.Spec.Istio.Gateway != nil { + hosts := modelRegistry.Status.Hosts + Expect(len(hosts)).To(Equal(5)) + name := modelRegistry.Name + namespace := modelRegistry.Namespace + domain := modelRegistry.Spec.Istio.Gateway.Domain + Expect(hosts[0]). + To(Equal(fmt.Sprintf("%s-rest.%s", name, domain))) + Expect(hosts[1]). + To(Equal(fmt.Sprintf("%s-grpc.%s", name, domain))) + Expect(hosts[2]). + To(Equal(fmt.Sprintf("%s.%s.svc.cluster.local", name, namespace))) + Expect(hosts[3]). + To(Equal(fmt.Sprintf("%s.%s", name, namespace))) + Expect(hosts[4]). + To(Equal(name)) + Expect(modelRegistry.Status.HostsStr).To(Equal(strings.Join(hosts, ","))) + } else { + // also check hosts in status + hosts := modelRegistry.Status.Hosts + Expect(len(hosts)).To(Equal(3)) + name := modelRegistry.Name + namespace := modelRegistry.Namespace + Expect(hosts[0]). + To(Equal(fmt.Sprintf("%s.%s.svc.cluster.local", name, namespace))) + Expect(hosts[1]). + To(Equal(fmt.Sprintf("%s.%s", name, namespace))) + Expect(hosts[2]). + To(Equal(name)) + Expect(modelRegistry.Status.HostsStr).To(Equal(strings.Join(hosts, ","))) + } if !meta.IsStatusConditionTrue(modelRegistry.Status.Conditions, ConditionTypeProgressing) { return fmt.Errorf("Condition %s is not true", ConditionTypeProgressing) } if !meta.IsStatusConditionTrue(modelRegistry.Status.Conditions, ConditionTypeAvailable) { - return fmt.Errorf("Condition %s is not true", ConditionTypeAvailable) + return fmt.Errorf("Condition %s is not true: %+v", ConditionTypeAvailable, modelRegistry.Status) } return nil }, time.Minute, time.Second).Should(Succeed()) @@ -282,3 +664,67 @@ func validateRegistry(ctx context.Context, typeNamespaceName types.NamespacedNam return nil } } + +func validateRegistryOpenshift(ctx context.Context, typeNamespaceName types.NamespacedName, modelRegistry *v1alpha1.ModelRegistry, modelRegistryReconciler *ModelRegistryReconciler) func() error { + return func() error { + modelRegistryReconciler.IsOpenShift = true + + Eventually(validateRegistryBase(ctx, typeNamespaceName, modelRegistry, modelRegistryReconciler)).Should(Succeed()) + + By("Checking if the Openshift Group was successfully created in the reconciliation") + Eventually(func() error { + found := &userv1.Group{} + + return k8sClient.Get(ctx, types.NamespacedName{Name: fmt.Sprintf("%s-users", modelRegistry.Name)}, found) + }, 5*time.Second, time.Second).Should(Succeed()) + + By("Checking if the Openshift RoleBinding was successfully created in the reconciliation") + Eventually(func() error { + found := &rbacv1.RoleBinding{} + + return k8sClient.Get(ctx, types.NamespacedName{Name: fmt.Sprintf("%s-users", modelRegistry.Name), Namespace: modelRegistry.Namespace}, found) + }, 5*time.Second, time.Second).Should(Succeed()) + + return nil + } +} + +func validateRegistryIstio(ctx context.Context, typeNamespaceName types.NamespacedName, modelRegistry *v1alpha1.ModelRegistry, modelRegistryReconciler *ModelRegistryReconciler) func() error { + return func() error { + modelRegistryReconciler.HasIstio = true + + svc := corev1.Service{} + svc.Name = "istio" + svc.Namespace = typeNamespaceName.Namespace + svc.Labels = map[string]string{"istio": v1alpha1.DefaultIstioGateway} + svc.Spec.Ports = []corev1.ServicePort{ + { + Name: "http2", + Port: 80, + TargetPort: intstr.FromInt(80), + }, + { + Name: "https", + Port: 443, + TargetPort: intstr.FromInt(443), + }, + } + + err := k8sClient.Create(ctx, &svc) + Expect(err).To(Not(HaveOccurred())) + + Eventually(validateRegistryBase(ctx, typeNamespaceName, modelRegistry, modelRegistryReconciler)).Should(Succeed()) + + return nil + } +} + +func validateRegistryAuth(ctx context.Context, typeNamespaceName types.NamespacedName, modelRegistry *v1alpha1.ModelRegistry, modelRegistryReconciler *ModelRegistryReconciler) func() error { + return func() error { + modelRegistryReconciler.CreateAuthResources = true + + Eventually(validateRegistryIstio(ctx, typeNamespaceName, modelRegistry, modelRegistryReconciler)).Should(Succeed()) + + return nil + } +} diff --git a/internal/controller/operationresult_enum.go b/internal/controller/operationresult_enum.go deleted file mode 100644 index 15e7007..0000000 --- a/internal/controller/operationresult_enum.go +++ /dev/null @@ -1,263 +0,0 @@ -// Code generated by "go-enum -type=OperationResult"; DO NOT EDIT. - -// Install go-enum by `go get install github.com/searKing/golang/tools/go-enum` -package controller - -import ( - "database/sql" - "database/sql/driver" - "encoding" - "encoding/json" - "fmt" - "strconv" -) - -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[ResourceUnchanged-0] - _ = x[ResourceCreated-1] - _ = x[ResourceUpdated-2] -} - -const _OperationResult_name = "ResourceUnchangedResourceCreatedResourceUpdated" - -var _OperationResult_index = [...]uint8{0, 17, 32, 47} - -func _() { - var _nil_OperationResult_value = func() (val OperationResult) { return }() - - // An "cannot convert OperationResult literal (type OperationResult) to type fmt.Stringer" compiler error signifies that the base type have changed. - // Re-run the go-enum command to generate them again. - var _ fmt.Stringer = _nil_OperationResult_value -} - -func (i OperationResult) String() string { - if i < 0 || i >= OperationResult(len(_OperationResult_index)-1) { - return "OperationResult(" + strconv.FormatInt(int64(i), 10) + ")" - } - return _OperationResult_name[_OperationResult_index[i]:_OperationResult_index[i+1]] -} - -// New returns a pointer to a new addr filled with the OperationResult value passed in. -func (i OperationResult) New() *OperationResult { - clone := i - return &clone -} - -var _OperationResult_values = []OperationResult{0, 1, 2} - -var _OperationResult_name_to_values = map[string]OperationResult{ - _OperationResult_name[0:17]: 0, - _OperationResult_name[17:32]: 1, - _OperationResult_name[32:47]: 2, -} - -// ParseOperationResultString retrieves an enum value from the enum constants string name. -// Throws an error if the param is not part of the enum. -func ParseOperationResultString(s string) (OperationResult, error) { - if val, ok := _OperationResult_name_to_values[s]; ok { - return val, nil - } - return 0, fmt.Errorf("%s does not belong to OperationResult values", s) -} - -// OperationResultValues returns all values of the enum -func OperationResultValues() []OperationResult { - return _OperationResult_values -} - -// IsAOperationResult returns "true" if the value is listed in the enum definition. "false" otherwise -func (i OperationResult) Registered() bool { - for _, v := range _OperationResult_values { - if i == v { - return true - } - } - return false -} - -func _() { - var _nil_OperationResult_value = func() (val OperationResult) { return }() - - // An "cannot convert OperationResult literal (type OperationResult) to type encoding.BinaryMarshaler" compiler error signifies that the base type have changed. - // Re-run the go-enum command to generate them again. - var _ encoding.BinaryMarshaler = &_nil_OperationResult_value - - // An "cannot convert OperationResult literal (type OperationResult) to type encoding.BinaryUnmarshaler" compiler error signifies that the base type have changed. - // Re-run the go-enum command to generate them again. - var _ encoding.BinaryUnmarshaler = &_nil_OperationResult_value -} - -// MarshalBinary implements the encoding.BinaryMarshaler interface for OperationResult -func (i OperationResult) MarshalBinary() (data []byte, err error) { - return []byte(i.String()), nil -} - -// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface for OperationResult -func (i *OperationResult) UnmarshalBinary(data []byte) error { - var err error - *i, err = ParseOperationResultString(string(data)) - return err -} - -func _() { - var _nil_OperationResult_value = func() (val OperationResult) { return }() - - // An "cannot convert OperationResult literal (type OperationResult) to type json.Marshaler" compiler error signifies that the base type have changed. - // Re-run the go-enum command to generate them again. - var _ json.Marshaler = _nil_OperationResult_value - - // An "cannot convert OperationResult literal (type OperationResult) to type encoding.Unmarshaler" compiler error signifies that the base type have changed. - // Re-run the go-enum command to generate them again. - var _ json.Unmarshaler = &_nil_OperationResult_value -} - -// MarshalJSON implements the json.Marshaler interface for OperationResult -func (i OperationResult) MarshalJSON() ([]byte, error) { - return json.Marshal(i.String()) -} - -// UnmarshalJSON implements the json.Unmarshaler interface for OperationResult -func (i *OperationResult) UnmarshalJSON(data []byte) error { - var s string - if err := json.Unmarshal(data, &s); err != nil { - return fmt.Errorf("OperationResult should be a string, got %s", data) - } - - var err error - *i, err = ParseOperationResultString(s) - return err -} - -func _() { - var _nil_OperationResult_value = func() (val OperationResult) { return }() - - // An "cannot convert OperationResult literal (type OperationResult) to type encoding.TextMarshaler" compiler error signifies that the base type have changed. - // Re-run the go-enum command to generate them again. - var _ encoding.TextMarshaler = _nil_OperationResult_value - - // An "cannot convert OperationResult literal (type OperationResult) to type encoding.TextUnmarshaler" compiler error signifies that the base type have changed. - // Re-run the go-enum command to generate them again. - var _ encoding.TextUnmarshaler = &_nil_OperationResult_value -} - -// MarshalText implements the encoding.TextMarshaler interface for OperationResult -func (i OperationResult) MarshalText() ([]byte, error) { - return []byte(i.String()), nil -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface for OperationResult -func (i *OperationResult) UnmarshalText(text []byte) error { - var err error - *i, err = ParseOperationResultString(string(text)) - return err -} - -//func _() { -// var _nil_OperationResult_value = func() (val OperationResult) { return }() -// -// // An "cannot convert OperationResult literal (type OperationResult) to type yaml.Marshaler" compiler error signifies that the base type have changed. -// // Re-run the go-enum command to generate them again. -// var _ yaml.Marshaler = _nil_OperationResult_value -// -// // An "cannot convert OperationResult literal (type OperationResult) to type yaml.Unmarshaler" compiler error signifies that the base type have changed. -// // Re-run the go-enum command to generate them again. -// var _ yaml.Unmarshaler = &_nil_OperationResult_value -//} - -// MarshalYAML implements a YAML Marshaler for OperationResult -func (i OperationResult) MarshalYAML() (interface{}, error) { - return i.String(), nil -} - -// UnmarshalYAML implements a YAML Unmarshaler for OperationResult -func (i *OperationResult) UnmarshalYAML(unmarshal func(interface{}) error) error { - var s string - if err := unmarshal(&s); err != nil { - return err - } - - var err error - *i, err = ParseOperationResultString(s) - return err -} - -func _() { - var _nil_OperationResult_value = func() (val OperationResult) { return }() - - // An "cannot convert OperationResult literal (type OperationResult) to type driver.Valuer" compiler error signifies that the base type have changed. - // Re-run the go-enum command to generate them again. - var _ driver.Valuer = _nil_OperationResult_value - - // An "cannot convert OperationResult literal (type OperationResult) to type sql.Scanner" compiler error signifies that the base type have changed. - // Re-run the go-enum command to generate them again. - var _ sql.Scanner = &_nil_OperationResult_value -} - -func (i OperationResult) Value() (driver.Value, error) { - return i.String(), nil -} - -func (i *OperationResult) Scan(value interface{}) error { - if value == nil { - return nil - } - - str, ok := value.(string) - if !ok { - bytes, ok := value.([]byte) - if !ok { - return fmt.Errorf("value is not a byte slice") - } - - str = string(bytes[:]) - } - - val, err := ParseOperationResultString(str) - if err != nil { - return err - } - - *i = val - return nil -} - -// OperationResultSliceContains reports whether sunEnums is within enums. -func OperationResultSliceContains(enums []OperationResult, sunEnums ...OperationResult) bool { - var seenEnums = map[OperationResult]bool{} - for _, e := range sunEnums { - seenEnums[e] = false - } - - for _, v := range enums { - if _, has := seenEnums[v]; has { - seenEnums[v] = true - } - } - - for _, seen := range seenEnums { - if !seen { - return false - } - } - - return true -} - -// OperationResultSliceContainsAny reports whether any sunEnum is within enums. -func OperationResultSliceContainsAny(enums []OperationResult, sunEnums ...OperationResult) bool { - var seenEnums = map[OperationResult]struct{}{} - for _, e := range sunEnums { - seenEnums[e] = struct{}{} - } - - for _, v := range enums { - if _, has := seenEnums[v]; has { - return true - } - } - - return false -} diff --git a/internal/controller/suite_test.go b/internal/controller/suite_test.go index 884d388..94d3066 100644 --- a/internal/controller/suite_test.go +++ b/internal/controller/suite_test.go @@ -18,6 +18,7 @@ package controller import ( "fmt" + "os" "path/filepath" "runtime" "testing" @@ -25,23 +26,55 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - "k8s.io/client-go/kubernetes/scheme" + routev1 "github.com/openshift/api/route/v1" + userv1 "github.com/openshift/api/user/v1" + istioclientv1beta1 "istio.io/client-go/pkg/apis/networking/v1beta1" + istiosecurityv1beta1 "istio.io/client-go/pkg/apis/security/v1beta1" + appsv1 "k8s.io/api/apps/v1" + corev1 "k8s.io/api/core/v1" + rbacv1 "k8s.io/api/rbac/v1" + apiruntime "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/client-go/rest" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/envtest" logf "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" + "sigs.k8s.io/controller-runtime/pkg/scheme" modelregistryv1alpha1 "github.com/opendatahub-io/model-registry-operator/api/v1alpha1" + "github.com/opendatahub-io/model-registry-operator/internal/utils" //+kubebuilder:scaffold:imports ) // These tests use Ginkgo (BDD-style Go testing framework). Refer to // http://onsi.github.io/ginkgo/ to learn more about Ginkgo. -var cfg *rest.Config -var k8sClient client.Client -var testEnv *envtest.Environment +type remoteCRD struct { + url string + fileName string +} + +var ( + cfg *rest.Config + k8sClient client.Client + testEnv *envtest.Environment + testCRDLocalPath = "./testdata/crd" + remoteCRDs = []remoteCRD{ + { + url: "https://raw.githubusercontent.com/Kuadrant/authorino/refs/heads/main/install/crd/authorino.kuadrant.io_authconfigs.yaml", + fileName: "authorino.kuadrant.io_authconfigs.yaml", + }, + { + url: "https://raw.githubusercontent.com/istio/istio/refs/heads/master/manifests/charts/base/files/crd-all.gen.yaml", + fileName: "istio.yaml", + }, + { + url: "https://raw.githubusercontent.com/openshift/api/refs/heads/master/route/v1/zz_generated.crd-manifests/routes-Default.crd.yaml", + fileName: "route.openshift.io_routes.yaml", + }, + } +) func TestControllers(t *testing.T) { RegisterFailHandler(Fail) @@ -50,12 +83,63 @@ func TestControllers(t *testing.T) { } var _ = BeforeSuite(func() { + var err error + logf.SetLogger(zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true))) + schm := apiruntime.NewScheme() + + err = corev1.AddToScheme(schm) + Expect(err).NotTo(HaveOccurred()) + + err = appsv1.AddToScheme(schm) + Expect(err).NotTo(HaveOccurred()) + + err = rbacv1.AddToScheme(schm) + Expect(err).NotTo(HaveOccurred()) + + err = userv1.AddToScheme(schm) + Expect(err).NotTo(HaveOccurred()) + + err = routev1.AddToScheme(schm) + Expect(err).NotTo(HaveOccurred()) + + err = istiosecurityv1beta1.AddToScheme(schm) + Expect(err).NotTo(HaveOccurred()) + + err = istioclientv1beta1.AddToScheme(schm) + Expect(err).NotTo(HaveOccurred()) + + authorinoScheme := &scheme.Builder{GroupVersion: schema.GroupVersion{Group: "authorino.kuadrant.io", Version: "v1beta3"}} + + err = authorinoScheme.AddToScheme(schm) + Expect(err).NotTo(HaveOccurred()) + + err = modelregistryv1alpha1.AddToScheme(schm) + Expect(err).NotTo(HaveOccurred()) + + //+kubebuilder:scaffold:scheme + + // Download CRDs + + if err := os.MkdirAll(filepath.Dir(testCRDLocalPath), 0755); err != nil { + Fail(err.Error()) + } + + for _, crd := range remoteCRDs { + if err := utils.DownloadFile(crd.url, filepath.Join(testCRDLocalPath, crd.fileName)); err != nil { + Fail(err.Error()) + } + } + By("bootstrapping test environment") useExistingCluster := false testEnv = &envtest.Environment{ - CRDDirectoryPaths: []string{filepath.Join("..", "..", "config", "crd", "bases")}, + Scheme: schm, + CRDDirectoryPaths: []string{ + filepath.Join("..", "..", "config", "crd", "bases"), + filepath.Join("testdata", "crd"), + }, ErrorIfCRDPathMissing: true, // The BinaryAssetsDirectory is only required if you want to run the tests directly @@ -67,22 +151,14 @@ var _ = BeforeSuite(func() { fmt.Sprintf("1.28.0-%s-%s", runtime.GOOS, runtime.GOARCH)), UseExistingCluster: &useExistingCluster, } - - var err error // cfg is defined in this file globally. cfg, err = testEnv.Start() Expect(err).NotTo(HaveOccurred()) Expect(cfg).NotTo(BeNil()) - err = modelregistryv1alpha1.AddToScheme(scheme.Scheme) - Expect(err).NotTo(HaveOccurred()) - - //+kubebuilder:scaffold:scheme - - k8sClient, err = client.New(cfg, client.Options{Scheme: scheme.Scheme}) + k8sClient, err = client.New(cfg, client.Options{Scheme: schm}) Expect(err).NotTo(HaveOccurred()) Expect(k8sClient).NotTo(BeNil()) - }) var _ = AfterSuite(func() { diff --git a/internal/controller/testdata/crd/.gitignore b/internal/controller/testdata/crd/.gitignore new file mode 100644 index 0000000..5142034 --- /dev/null +++ b/internal/controller/testdata/crd/.gitignore @@ -0,0 +1,4 @@ +# CRDs files will be fetched from the test suite + +*.yaml +!*_mocked.yaml diff --git a/internal/controller/testdata/crd/group.openshift.io_mocked.yaml b/internal/controller/testdata/crd/group.openshift.io_mocked.yaml new file mode 100644 index 0000000..be57c88 --- /dev/null +++ b/internal/controller/testdata/crd/group.openshift.io_mocked.yaml @@ -0,0 +1,40 @@ +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + creationTimestamp: null + name: groups.user.openshift.io +spec: + group: user.openshift.io + names: + kind: Group + listKind: GroupList + plural: groups + shortNames: [] + singular: group + scope: Cluster + versions: + - name: v1 + schema: + openAPIV3Schema: + description: Group is a list of users. + properties: + apiVersion: + type: string + kind: + type: string + metadata: + type: object + spec: + type: object + status: + type: object + required: [] + type: object + served: true + storage: true +status: + acceptedNames: + kind: "" + plural: "" + conditions: [] + storedVersions: [] diff --git a/internal/utils/io.go b/internal/utils/io.go new file mode 100644 index 0000000..aea3847 --- /dev/null +++ b/internal/utils/io.go @@ -0,0 +1,30 @@ +package utils + +import ( + "io" + "net/http" + "os" +) + +func DownloadFile(url string, path string) error { + resp, err := http.Get(url) + if err != nil { + return err + } + + defer resp.Body.Close() + + file, err := os.Create(path) + if err != nil { + return err + } + + defer file.Close() + + _, err = io.Copy(file, resp.Body) + if err != nil { + return err + } + + return nil +}