From 3b377247d25050b78732001eccb77dd3189b440f Mon Sep 17 00:00:00 2001 From: huanghaoyuanhhy Date: Fri, 24 Jan 2025 11:30:32 +0800 Subject: [PATCH] *: cp d39b7d6 and impl mtls auth (#519) Signed-off-by: huanghaoyuanhhy --- .github/workflows/main.yaml | 1 - README.md | 79 +++++----- cmd/backup_yaml.go | 23 +-- configs/backup.yaml | 18 ++- core/backup_context.go | 41 +----- core/client/cfg.go | 71 --------- core/client/cfg_test.go | 49 ------- core/client/grpc.go | 265 +++++++++++++++++++++------------- core/client/grpc_test.go | 33 ++++- core/client/restful.go | 29 +++- core/client/restful_test.go | 15 ++ core/paramtable/base_table.go | 34 +++-- core/paramtable/params.go | 56 +++++-- 13 files changed, 364 insertions(+), 350 deletions(-) delete mode 100644 core/client/cfg.go delete mode 100644 core/client/cfg_test.go create mode 100644 core/client/restful_test.go diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index c871138..f3a91aa 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -441,7 +441,6 @@ jobs: shell: bash run: | yq -i '.log.level = "debug"' configs/backup.yaml - yq -i '.milvus.authorizationEnabled = true' configs/backup.yaml yq -i '.minio.bucketName = "milvus-bucket"' configs/backup.yaml yq -i '.minio.rootPath = "file"' configs/backup.yaml yq -i '.minio.backupPort = 9010' configs/backup.yaml diff --git a/README.md b/README.md index 8d54147..ab69b5f 100644 --- a/README.md +++ b/README.md @@ -92,45 +92,46 @@ http://localhost:8080/api/v1/docs/index.html Below is a summary of the configurations supported in `backup.yaml`: -| **Section** | **Field** | **Description** | **Default/Example** | -|--------------------|-----------------------------------|-------------------------------------------------------------------------------------------------------|-----------------------------------------------| -| `log` | `level` | Logging level. Supported: `debug`, `info`, `warn`, `error`, `panic`, `fatal`. | `info` | -| | `console` | Whether to print logs to the console. | `true` | -| | `file.rootPath` | Path to the log file. | `logs/backup.log` | -| `http` | `simpleResponse` | Whether to enable simple HTTP responses. | `true` | -| `milvus` | `address` | Milvus proxy address. | `localhost` | -| | `port` | Milvus proxy port. | `19530` | -| | `authorizationEnabled` | Whether to enable authorization. | `false` | -| | `tlsMode` | TLS mode (0: none, 1: one-way, 2: two-way). | `0` | -| | `user` | Username for Milvus. | `root` | -| | `password` | Password for Milvus. | `Milvus` | -| | `tlsCertPath` | Path to your certificate file | `/path/to/certificate` | -| | `serverName ` | Server name | `localhost` | -| `minio` | `storageType` | Storage type for Milvus (e.g., `local`, `minio`, `s3`, `aws`, `gcp`, `ali(aliyun)`, `azure`, `tc(tencent)`). | `minio` | -| | `address` | MinIO/S3 address. | `localhost` | -| | `port` | MinIO/S3 port. | `9000` | -| | `accessKeyID` | MinIO/S3 access key ID. | `minioadmin` | -| | `secretAccessKey` | MinIO/S3 secret access key. | `minioadmin` | -| | `useSSL` | Whether to use SSL for MinIO/S3. | `false` | -| | `bucketName` | Bucket name in MinIO/S3. | `a-bucket` | -| | `rootPath` | Storage root path in MinIO/S3. | `files` | -| `minio (backup)` | `backupStorageType` | Backup storage type (e.g., `local`, `minio`, `s3`, `aws`, `gcp`, `ali(aliyun)`, `azure`, `tc(tencent)`). | `minio` | -| | `backupAddress` | Address of backup storage. | `localhost` | -| | `backupPort` | Port of backup storage. | `9000` | -| | `backupUseSSL` | Whether to use SSL for backup storage. | `false` | -| | `backupAccessKeyID` | Backup storage access key ID. | `minioadmin` | -| | `backupSecretAccessKey` | Backup storage secret access key. | `minioadmin` | -| | `backupBucketName` | Bucket name for backups. | `a-bucket` | -| | `backupRootPath` | Root path to store backup data. | `backup` | -| | `crossStorage` | Enable cross-storage backups (e.g., MinIO to AWS S3). | `false` | -| `backup` | `maxSegmentGroupSize` | Maximum segment group size for backups. | `2G` | -| | `parallelism.backupCollection` | Collection-level parallelism for backup. | `4` | -| | `parallelism.copydata` | Thread pool size for copying data. | `128` | -| | `parallelism.restoreCollection` | Collection-level parallelism for restore. | `2` | -| | `keepTempFiles` | Whether to keep temporary files during restore (for debugging). | `false` | -| | `gcPause.enable` | Pause Milvus garbage collection during backup. | `true` | -| | `gcPause.seconds` | Duration to pause garbage collection (in seconds). | `7200` | -| | `gcPause.address` | Address for Milvus garbage collection API. | `http://localhost:9091` | +| **Section** | **Field** | **Description** | **Default/Example** | +|--------------------|---------------------------------|--------------------------------------------------------------------------------------------------------------|-------------------------| +| `log` | `level` | Logging level. Supported: `debug`, `info`, `warn`, `error`, `panic`, `fatal`. | `info` | +| | `console` | Whether to print logs to the console. | `true` | +| | `file.rootPath` | Path to the log file. | `logs/backup.log` | +| `http` | `simpleResponse` | Whether to enable simple HTTP responses. | `true` | +| `milvus` | `address` | Milvus proxy address. | `localhost` | +| | `port` | Milvus proxy port. | `19530` | +| | `user` | Username for Milvus. | `root` | +| | `password` | Password for Milvus. | `Milvus` | +| | `tlsMode` | TLS mode (0: none, 1: one-way, 2: two-way/mtls). | `0` | +| | `caCertPath` | Path to your ca certificate file | `/path/to/certificate` | +| | `serverName ` | Server name | `localhost` | +| | `mtlsCertPath` | Path to your mtls certificate file | `/path/to/certificate` | +| | `mtlsKeyPath ` | Path to your mtls key file | `/path/to/key` | +| `minio` | `storageType` | Storage type for Milvus (e.g., `local`, `minio`, `s3`, `aws`, `gcp`, `ali(aliyun)`, `azure`, `tc(tencent)`). | `minio` | +| | `address` | MinIO/S3 address. | `localhost` | +| | `port` | MinIO/S3 port. | `9000` | +| | `accessKeyID` | MinIO/S3 access key ID. | `minioadmin` | +| | `secretAccessKey` | MinIO/S3 secret access key. | `minioadmin` | +| | `useSSL` | Whether to use SSL for MinIO/S3. | `false` | +| | `bucketName` | Bucket name in MinIO/S3. | `a-bucket` | +| | `rootPath` | Storage root path in MinIO/S3. | `files` | +| `minio (backup)` | `backupStorageType` | Backup storage type (e.g., `local`, `minio`, `s3`, `aws`, `gcp`, `ali(aliyun)`, `azure`, `tc(tencent)`). | `minio` | +| | `backupAddress` | Address of backup storage. | `localhost` | +| | `backupPort` | Port of backup storage. | `9000` | +| | `backupUseSSL` | Whether to use SSL for backup storage. | `false` | +| | `backupAccessKeyID` | Backup storage access key ID. | `minioadmin` | +| | `backupSecretAccessKey` | Backup storage secret access key. | `minioadmin` | +| | `backupBucketName` | Bucket name for backups. | `a-bucket` | +| | `backupRootPath` | Root path to store backup data. | `backup` | +| | `crossStorage` | Enable cross-storage backups (e.g., MinIO to AWS S3). | `false` | +| `backup` | `maxSegmentGroupSize` | Maximum segment group size for backups. | `2G` | +| | `parallelism.backupCollection` | Collection-level parallelism for backup. | `4` | +| | `parallelism.copydata` | Thread pool size for copying data. | `128` | +| | `parallelism.restoreCollection` | Collection-level parallelism for restore. | `2` | +| | `keepTempFiles` | Whether to keep temporary files during restore (for debugging). | `false` | +| | `gcPause.enable` | Pause Milvus garbage collection during backup. | `true` | +| | `gcPause.seconds` | Duration to pause garbage collection (in seconds). | `7200` | +| | `gcPause.address` | Address for Milvus garbage collection API. | `http://localhost:9091` | For more details, refer to the [backup.yaml](configs/backup.yaml) configuration file. diff --git a/cmd/backup_yaml.go b/cmd/backup_yaml.go index d16055a..f5582fd 100644 --- a/cmd/backup_yaml.go +++ b/cmd/backup_yaml.go @@ -35,14 +35,15 @@ type YAMLConFig struct { } `yaml:"http"` } `yaml:"log"` Milvus struct { - Address string `yaml:"address"` - Port int `yaml:"port"` - AuthorizationEnabled bool `yaml:"authorizationEnabled"` - TlsMode int `yaml:"tlsMode"` - User string `yaml:"user"` - Password string `yaml:"password"` - TlsCertPath string `yaml:"tlsCertPath"` - ServerName string `yaml:"serverName"` + Address string `yaml:"address"` + Port int `yaml:"port"` + User string `yaml:"user"` + Password string `yaml:"password"` + TlsMode int `yaml:"tlsMode"` + CACertPath string `yaml:"caCertPath"` + ServerName string `yaml:"serverName"` + mtlsCertPath string `yaml:"mtlsCertPath"` + mtlsKeyPath string `yaml:"mtlsKeyPath"` } `yaml:"milvus"` Minio struct { Address string `yaml:"address"` @@ -77,13 +78,13 @@ func printParams(base *paramtable.BackupParams) { yml.Milvus.Address = base.LoadWithDefault("milvus.address", "localhost") yml.Milvus.Port = base.ParseIntWithDefault("milvus.port", 19530) - yml.Milvus.AuthorizationEnabled = base.ParseBool("milvus.authorizationEnabled", false) yml.Milvus.TlsMode = base.ParseIntWithDefault("milvus.tlsMode", 0) yml.Milvus.User = base.BaseTable.LoadWithDefault("milvus.user", "") yml.Milvus.Password = base.BaseTable.LoadWithDefault("milvus.password", "") - yml.Milvus.TlsCertPath = base.BaseTable.LoadWithDefault("milvus.tlsCertPath", "") + yml.Milvus.CACertPath = base.BaseTable.LoadWithDefault("milvus.tlsCertPath", "") yml.Milvus.ServerName = base.BaseTable.LoadWithDefault("milvus.serverName", "localhost") - + yml.Milvus.mtlsCertPath = base.BaseTable.LoadWithDefault("milvus.mtlsCertPath", "") + yml.Milvus.mtlsKeyPath = base.BaseTable.LoadWithDefault("milvus.mtlsKeyPath", "") yml.Minio.Address = base.LoadWithDefault("minio.address", "localhost") yml.Minio.Port = base.ParseIntWithDefault("minio.port", 9000) yml.Minio.AccessKeyID = base.BaseTable.LoadWithDefault("minio.accessKeyID", "") diff --git a/configs/backup.yaml b/configs/backup.yaml index 3aa23ed..27a354b 100644 --- a/configs/backup.yaml +++ b/configs/backup.yaml @@ -12,14 +12,21 @@ http: milvus: address: localhost port: 19530 - authorizationEnabled: false - # tls mode values [0, 1, 2] - # 0 is close, 1 is one-way authentication, 2 is two-way authentication. - tlsMode: 0 user: "root" password: "Milvus" - tlsCertPath: "" + + # tls mode values [0, 1, 2] + # 0 is close, 1 is one-way authentication, 2 is mutual authentication + tlsMode: 0 + # tls cert path for validate server, will be used when tlsMode is 1 or 2 + caCertPath: "" serverName: "" + # mutual tls cert path, for server to validate client. + # Will be used when tlsMode is 2 + # for backward compatibility, if not set, will use tlsmode 1. + # WARN: in future version, if user set tlsmode 2, but not set mtlsCertPath, will cause error. + mtlsCertPath: "" + mtlsKeyPath: "" # Related configuration of minio, which is responsible for data persistence for Milvus. minio: @@ -43,6 +50,7 @@ minio: backupSecretAccessKey: minioadmin # MinIO/S3 encryption string backupBucketName: "a-bucket" # Bucket name to store backup data. Backup data will store to backupBucketName/backupRootPath backupRootPath: "backup" # Rootpath to store backup data. Backup data will store to backupBucketName/backupRootPath + backupUseSSL: false # Access to MinIO/S3 with SSL # If you need to back up or restore data between two different storage systems, direct client-side copying is not supported. # Set this option to true to enable data transfer through Milvus Backup. diff --git a/core/backup_context.go b/core/backup_context.go index b204bd5..3a85645 100644 --- a/core/backup_context.go +++ b/core/backup_context.go @@ -64,39 +64,8 @@ type BackupContext struct { bulkinsertWorkerPools sync.Map } -func paramsToCfg(params *paramtable.BackupParams) (*client.Cfg, error) { - ep := params.MilvusCfg.Address + ":" + params.MilvusCfg.Port - log.Debug("Start Milvus client", zap.String("endpoint", ep)) - - var enableTLS bool - switch params.MilvusCfg.TLSMode { - case 0: - enableTLS = false - case 1, 2: - enableTLS = true - default: - log.Error("milvus.TLSMode is illegal, support value 0, 1, 2") - return nil, fmt.Errorf("milvus.TLSMode is illegal, support value 0, 1, 2") - } - - cfg := &client.Cfg{ - Host: ep, - EnableTLS: enableTLS, - Username: params.MilvusCfg.User, - Password: params.MilvusCfg.Password, - } - - return cfg, nil -} - func CreateGrpcClient(params *paramtable.BackupParams) (client.Grpc, error) { - cfg, err := paramsToCfg(params) - if err != nil { - log.Error("failed to create milvus client", zap.Error(err)) - return nil, fmt.Errorf("failed to create milvus client: %w", err) - } - - cli, err := client.NewGrpc(cfg) + cli, err := client.NewGrpc(¶ms.MilvusCfg) if err != nil { log.Error("failed to create milvus client", zap.Error(err)) return nil, fmt.Errorf("failed to create milvus client: %w", err) @@ -105,13 +74,7 @@ func CreateGrpcClient(params *paramtable.BackupParams) (client.Grpc, error) { } func CreateRestfulClient(params *paramtable.BackupParams) (client.RestfulBulkInsert, error) { - cfg, err := paramsToCfg(params) - if err != nil { - log.Error("failed to create restful client", zap.Error(err)) - return nil, fmt.Errorf("failed to create restful client: %w", err) - } - - cli, err := client.NewRestful(cfg) + cli, err := client.NewRestful(¶ms.MilvusCfg) if err != nil { log.Error("failed to create restful client", zap.Error(err)) return nil, fmt.Errorf("failed to create restful client: %w", err) diff --git a/core/client/cfg.go b/core/client/cfg.go deleted file mode 100644 index 12d42c8..0000000 --- a/core/client/cfg.go +++ /dev/null @@ -1,71 +0,0 @@ -package client - -import ( - "crypto/tls" - "encoding/base64" - "errors" - "fmt" - "net/url" - - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" -) - -type Cfg struct { - Host string // Remote address, "localhost:19530". - EnableTLS bool // Enable TLS for connection. - DialOpts []grpc.DialOption - - Username string // Username for auth. - Password string // Password for auth. -} - -func (cfg *Cfg) parseGrpcAuth() string { - if cfg.Username != "" || cfg.Password != "" { - value := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", cfg.Username, cfg.Password))) - return value - } - - return "" -} - -func (cfg *Cfg) parseGrpc() (*url.URL, []grpc.DialOption, error) { - remoteURL := &url.URL{Host: cfg.Host} - // Remote Host should never be empty. - if remoteURL.Host == "" { - return nil, nil, errors.New("client: empty remote host of milvus address") - } - - opts := cfg.DialOpts - if cfg.EnableTLS { - opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))) - } else { - opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) - } - - if remoteURL.Port() == "" && cfg.EnableTLS { - remoteURL.Host += ":443" - } - - return remoteURL, opts, nil -} - -func (cfg *Cfg) parseRestfulAuth() string { - if cfg.Username != "" || cfg.Password != "" { - return fmt.Sprintf("%s:%s", cfg.Username, cfg.Password) - } - - return "" -} - -func (cfg *Cfg) parseRestful() *url.URL { - u := &url.URL{Host: cfg.Host} - if cfg.EnableTLS { - u.Scheme = "https" - } else { - u.Scheme = "http" - } - - return u -} diff --git a/core/client/cfg_test.go b/core/client/cfg_test.go deleted file mode 100644 index c3f7ec3..0000000 --- a/core/client/cfg_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package client - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestCfg_ParseAuth(t *testing.T) { - cfg := &Cfg{Username: "username", Password: "password"} - got := cfg.parseGrpcAuth() - assert.Equal(t, "dXNlcm5hbWU6cGFzc3dvcmQ=", got) - - cfg = &Cfg{} - got = cfg.parseGrpcAuth() - assert.Equal(t, "", got) -} - -func TestCfg_ParseGrpc(t *testing.T) { - cfg := &Cfg{Host: ""} - _, _, err := cfg.parseGrpc() - assert.Error(t, err) - - cfg = &Cfg{Host: "localhost:19530"} - u, _, err := cfg.parseGrpc() - assert.NoError(t, err) - assert.Equal(t, "localhost:19530", u.Host) - - cfg = &Cfg{Host: "localhost", EnableTLS: true} - u, _, err = cfg.parseGrpc() - assert.NoError(t, err) - assert.Equal(t, "localhost:443", u.Host) -} - -func TestCfg_ParseRestful(t *testing.T) { - cfg := &Cfg{Host: "localhost:19530", EnableTLS: false} - u := cfg.parseRestful() - assert.Equal(t, "http", u.Scheme) - - cfg = &Cfg{Host: "localhost:19530", EnableTLS: true} - u = cfg.parseRestful() - assert.Equal(t, "https", u.Scheme) -} - -func TestCfg_ParseRestfulAuth(t *testing.T) { - cfg := &Cfg{Username: "username", Password: "password"} - got := cfg.parseRestfulAuth() - assert.Equal(t, "username:password", got) -} diff --git a/core/client/grpc.go b/core/client/grpc.go index eb2acde..20fc3e3 100644 --- a/core/client/grpc.go +++ b/core/client/grpc.go @@ -2,6 +2,9 @@ package client import ( "context" + "crypto/tls" + "crypto/x509" + "encoding/base64" "errors" "fmt" "math" @@ -14,13 +17,17 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/backoff" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + "github.com/zilliztech/milvus-backup/core/paramtable" "github.com/zilliztech/milvus-backup/internal/log" "github.com/zilliztech/milvus-backup/version" ) @@ -93,7 +100,7 @@ const ( databaseHeader = `dbname` ) -func ok(status *commonpb.Status) bool { return status.GetCode() == 0 } +func statusOk(status *commonpb.Status) bool { return status.GetCode() == 0 } func checkResponse(resp any, err error) error { if err != nil { @@ -102,11 +109,11 @@ func checkResponse(resp any, err error) error { switch resp.(type) { case interface{ GetStatus() *commonpb.Status }: - if !ok(resp.(interface{ GetStatus() *commonpb.Status }).GetStatus()) { + if !statusOk(resp.(interface{ GetStatus() *commonpb.Status }).GetStatus()) { return fmt.Errorf("client: operation failed: %v", resp.(interface{ GetStatus() *commonpb.Status }).GetStatus()) } case *commonpb.Status: - if !ok(resp.(*commonpb.Status)) { + if !statusOk(resp.(*commonpb.Status)) { return fmt.Errorf("client: operation failed: %v", resp.(*commonpb.Status)) } } @@ -116,11 +123,10 @@ func checkResponse(resp any, err error) error { var _ Grpc = (*GrpcClient)(nil) type GrpcClient struct { - cfg *Cfg - conn *grpc.ClientConn srv milvuspb.MilvusServiceClient + user string auth string // get from connect @@ -129,51 +135,116 @@ type GrpcClient struct { flags uint64 } -func NewGrpc(cfg *Cfg) (*GrpcClient, error) { - addr, opts, err := cfg.parseGrpc() +func grpcAuth(username, password string) string { + if username != "" || password != "" { + value := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", username, password))) + return value + } + + return "" +} + +func transCred(cfg *paramtable.MilvusConfig) (credentials.TransportCredentials, error) { + if cfg.TLSMode < 0 || cfg.TLSMode > 2 { + return nil, errors.New("milvus.TLSMode is illegal, support value 0, 1, 2") + } + + // tls mode 0 disable tls + if cfg.TLSMode == 0 { + return insecure.NewCredentials(), nil + } + + // tls mode 1, 2 + + // validate server cert + tlsCfg := &tls.Config{ServerName: cfg.ServerName} + if cfg.CACertPath != "" { + b, err := os.ReadFile(cfg.CACertPath) + if err != nil { + return nil, fmt.Errorf("client: read ca cert %w", err) + } + cp := x509.NewCertPool() + if !cp.AppendCertsFromPEM(b) { + return nil, fmt.Errorf("client: failed to append ca certificates") + } + + tlsCfg.RootCAs = cp + } + + // tls mode 1, server tls + if cfg.TLSMode == 1 { + return credentials.NewTLS(tlsCfg), nil + } + + // tls mode 2, mutual tls + // use mTLS but key/cert path not set, for backward compatibility, use server tls instead + // WARN: this behavior will be removed after v0.6.0 + if cfg.TLSMode == 2 { + if cfg.MTLSKeyPath == "" || cfg.MTLSCertPath == "" { + log.Warn("client: mutual tls enabled but key/cert path not set! will use server tls instead") + return credentials.NewTLS(tlsCfg), nil + } + + // use mTLS + cert, err := tls.LoadX509KeyPair(cfg.MTLSCertPath, cfg.MTLSKeyPath) + if err != nil { + return nil, fmt.Errorf("client: load client cert: %w", err) + } + tlsCfg.Certificates = []tls.Certificate{cert} + } + + return credentials.NewTLS(tlsCfg), nil +} + +func NewGrpc(cfg *paramtable.MilvusConfig) (*GrpcClient, error) { + host := fmt.Sprintf("%s:%s", cfg.Address, cfg.Port) + log.Info("New milvus grpc client", zap.String("host", host)) + + auth := grpcAuth(cfg.User, cfg.Password) + + cerd, err := transCred(cfg) if err != nil { - return nil, fmt.Errorf("client: parse address failed: %w", err) + return nil, fmt.Errorf("client: create transport credentials: %w", err) } - auth := cfg.parseGrpcAuth() - opts = append(opts, defaultDialOpt()...) - conn, err := grpc.NewClient(addr.Host, opts...) + opts := defaultDialOpt() + opts = append(opts, grpc.WithTransportCredentials(cerd)) + conn, err := grpc.NewClient(host, opts...) if err != nil { return nil, fmt.Errorf("client: create grpc client failed: %w", err) } srv := milvuspb.NewMilvusServiceClient(conn) cli := &GrpcClient{ - cfg: cfg, - conn: conn, srv: srv, + user: cfg.User, auth: auth, } return cli, nil } -func (m *GrpcClient) hasFlags(flags uint64) bool { return (m.flags & flags) > 0 } -func (m *GrpcClient) SupportMultiDatabase() bool { return !m.hasFlags(disableDatabase) } +func (g *GrpcClient) hasFlags(flags uint64) bool { return (g.flags & flags) > 0 } +func (g *GrpcClient) SupportMultiDatabase() bool { return !g.hasFlags(disableDatabase) } -func (m *GrpcClient) newCtx(ctx context.Context) context.Context { - if m.auth != "" { - return metadata.AppendToOutgoingContext(ctx, authorizationHeader, m.auth) +func (g *GrpcClient) newCtx(ctx context.Context) context.Context { + if g.auth != "" { + return metadata.AppendToOutgoingContext(ctx, authorizationHeader, g.auth) } - if m.identifier != "" { - return metadata.AppendToOutgoingContext(ctx, identifierHeader, m.identifier) + if g.identifier != "" { + return metadata.AppendToOutgoingContext(ctx, identifierHeader, g.identifier) } return ctx } -func (m *GrpcClient) newCtxWithDB(ctx context.Context, db string) context.Context { - ctx = m.newCtx(ctx) +func (g *GrpcClient) newCtxWithDB(ctx context.Context, db string) context.Context { + ctx = g.newCtx(ctx) return metadata.AppendToOutgoingContext(ctx, databaseHeader, db) } -func (m *GrpcClient) connect(ctx context.Context) error { +func (g *GrpcClient) connect(ctx context.Context) error { hostName, err := os.Hostname() if err != nil { return fmt.Errorf("get hostname failed: %w", err) @@ -184,39 +255,39 @@ func (m *GrpcClient) connect(ctx context.Context) error { SdkType: "Backup Tool Custom SDK", SdkVersion: version.Version, LocalTime: time.Now().String(), - User: m.cfg.Username, + User: g.user, Host: hostName, }, } - resp, err := m.srv.Connect(ctx, connReq) + resp, err := g.srv.Connect(ctx, connReq) if err != nil { s, ok := status.FromError(err) if ok { if s.Code() == codes.Unimplemented { log.Info("The server does not support the Connect API, skipping") - m.flags |= disableDatabase + g.flags |= disableDatabase } } return fmt.Errorf("client: connect to server failed: %w", err) } - if !ok(resp.GetStatus()) { + if !statusOk(resp.GetStatus()) { return fmt.Errorf("client: connect to server failed: %v", resp.GetStatus()) } - m.serverVersion = resp.GetServerInfo().GetBuildTags() - m.identifier = strconv.FormatInt(resp.GetIdentifier(), 10) + g.serverVersion = resp.GetServerInfo().GetBuildTags() + g.identifier = strconv.FormatInt(resp.GetIdentifier(), 10) return nil } -func (m *GrpcClient) Close() error { - return m.conn.Close() +func (g *GrpcClient) Close() error { + return g.conn.Close() } -func (m *GrpcClient) GetVersion(ctx context.Context) (string, error) { - ctx = m.newCtx(ctx) - resp, err := m.srv.GetVersion(ctx, &milvuspb.GetVersionRequest{}) +func (g *GrpcClient) GetVersion(ctx context.Context) (string, error) { + ctx = g.newCtx(ctx) + resp, err := g.srv.GetVersion(ctx, &milvuspb.GetVersionRequest{}) if err := checkResponse(resp, err); err != nil { return "", fmt.Errorf("client: get version failed: %w", err) } @@ -224,13 +295,13 @@ func (m *GrpcClient) GetVersion(ctx context.Context) (string, error) { return resp.GetVersion(), nil } -func (m *GrpcClient) CreateDatabase(ctx context.Context, dbName string) error { - ctx = m.newCtx(ctx) - if m.hasFlags(disableDatabase) { +func (g *GrpcClient) CreateDatabase(ctx context.Context, dbName string) error { + ctx = g.newCtx(ctx) + if g.hasFlags(disableDatabase) { return errors.New("client: the server does not support database") } - resp, err := m.srv.CreateDatabase(ctx, &milvuspb.CreateDatabaseRequest{DbName: dbName}) + resp, err := g.srv.CreateDatabase(ctx, &milvuspb.CreateDatabaseRequest{DbName: dbName}) if err := checkResponse(resp, err); err != nil { return fmt.Errorf("client: create database failed: %w", err) } @@ -238,13 +309,13 @@ func (m *GrpcClient) CreateDatabase(ctx context.Context, dbName string) error { return nil } -func (m *GrpcClient) ListDatabases(ctx context.Context) ([]string, error) { - ctx = m.newCtx(ctx) - if m.hasFlags(disableDatabase) { +func (g *GrpcClient) ListDatabases(ctx context.Context) ([]string, error) { + ctx = g.newCtx(ctx) + if g.hasFlags(disableDatabase) { return nil, errors.New("client: the server does not support database") } - resp, err := m.srv.ListDatabases(ctx, &milvuspb.ListDatabasesRequest{}) + resp, err := g.srv.ListDatabases(ctx, &milvuspb.ListDatabasesRequest{}) if err := checkResponse(resp, err); err != nil { return nil, fmt.Errorf("client: list databases failed: %w", err) } @@ -252,9 +323,9 @@ func (m *GrpcClient) ListDatabases(ctx context.Context) ([]string, error) { return resp.GetDbNames(), nil } -func (m *GrpcClient) DescribeCollection(ctx context.Context, db, collName string) (*milvuspb.DescribeCollectionResponse, error) { - ctx = m.newCtxWithDB(ctx, db) - resp, err := m.srv.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{CollectionName: collName}) +func (g *GrpcClient) DescribeCollection(ctx context.Context, db, collName string) (*milvuspb.DescribeCollectionResponse, error) { + ctx = g.newCtxWithDB(ctx, db) + resp, err := g.srv.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{CollectionName: collName}) if err := checkResponse(resp, err); err != nil { return nil, fmt.Errorf("client: describe collection failed: %w", err) } @@ -262,9 +333,9 @@ func (m *GrpcClient) DescribeCollection(ctx context.Context, db, collName string return resp, nil } -func (m *GrpcClient) ListIndex(ctx context.Context, db, collName string) ([]*milvuspb.IndexDescription, error) { - ctx = m.newCtxWithDB(ctx, db) - resp, err := m.srv.DescribeIndex(ctx, &milvuspb.DescribeIndexRequest{CollectionName: collName}) +func (g *GrpcClient) ListIndex(ctx context.Context, db, collName string) ([]*milvuspb.IndexDescription, error) { + ctx = g.newCtxWithDB(ctx, db) + resp, err := g.srv.DescribeIndex(ctx, &milvuspb.DescribeIndexRequest{CollectionName: collName}) if err := checkResponse(resp, err); err != nil { return nil, fmt.Errorf("client: describe index failed: %w", err) } @@ -272,18 +343,18 @@ func (m *GrpcClient) ListIndex(ctx context.Context, db, collName string) ([]*mil return resp.IndexDescriptions, nil } -func (m *GrpcClient) ShowPartitions(ctx context.Context, db, collName string) (*milvuspb.ShowPartitionsResponse, error) { - ctx = m.newCtxWithDB(ctx, db) - resp, err := m.srv.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{CollectionName: collName}) +func (g *GrpcClient) ShowPartitions(ctx context.Context, db, collName string) (*milvuspb.ShowPartitionsResponse, error) { + ctx = g.newCtxWithDB(ctx, db) + resp, err := g.srv.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{CollectionName: collName}) if err := checkResponse(resp, err); err != nil { return nil, fmt.Errorf("client: show partitions failed: %w", err) } return resp, nil } -func (m *GrpcClient) GetLoadingProgress(ctx context.Context, db, collName string, partitionNames []string) (int64, error) { - ctx = m.newCtxWithDB(ctx, db) - resp, err := m.srv.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{CollectionName: collName, PartitionNames: partitionNames}) +func (g *GrpcClient) GetLoadingProgress(ctx context.Context, db, collName string, partitionNames []string) (int64, error) { + ctx = g.newCtxWithDB(ctx, db) + resp, err := g.srv.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{CollectionName: collName, PartitionNames: partitionNames}) if err != nil { return 0, fmt.Errorf("client: get loading progress failed: %w", err) } @@ -291,9 +362,9 @@ func (m *GrpcClient) GetLoadingProgress(ctx context.Context, db, collName string return resp.GetProgress(), nil } -func (m *GrpcClient) GetPersistentSegmentInfo(ctx context.Context, db, collName string) ([]*milvuspb.PersistentSegmentInfo, error) { - ctx = m.newCtxWithDB(ctx, db) - resp, err := m.srv.GetPersistentSegmentInfo(ctx, &milvuspb.GetPersistentSegmentInfoRequest{CollectionName: collName}) +func (g *GrpcClient) GetPersistentSegmentInfo(ctx context.Context, db, collName string) ([]*milvuspb.PersistentSegmentInfo, error) { + ctx = g.newCtxWithDB(ctx, db) + resp, err := g.srv.GetPersistentSegmentInfo(ctx, &milvuspb.GetPersistentSegmentInfoRequest{CollectionName: collName}) if err := checkResponse(resp, err); err != nil { return nil, fmt.Errorf("client: get persistent segment info failed: %w", err) } @@ -301,9 +372,9 @@ func (m *GrpcClient) GetPersistentSegmentInfo(ctx context.Context, db, collName return resp.GetInfos(), nil } -func (m *GrpcClient) Flush(ctx context.Context, db, collName string) (*milvuspb.FlushResponse, error) { - ctx = m.newCtxWithDB(ctx, db) - resp, err := m.srv.Flush(ctx, &milvuspb.FlushRequest{CollectionNames: []string{collName}}) +func (g *GrpcClient) Flush(ctx context.Context, db, collName string) (*milvuspb.FlushResponse, error) { + ctx = g.newCtxWithDB(ctx, db) + resp, err := g.srv.Flush(ctx, &milvuspb.FlushRequest{CollectionNames: []string{collName}}) if err := checkResponse(resp, err); err != nil { return nil, fmt.Errorf("client: flush failed: %w", err) } @@ -312,7 +383,7 @@ func (m *GrpcClient) Flush(ctx context.Context, db, collName string) (*milvuspb. ids := segmentIDs.GetData() if has && len(ids) > 0 { flushed := func() bool { - getFlushResp, err := m.srv.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{ + getFlushResp, err := g.srv.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{ SegmentIDs: ids, FlushTs: resp.GetCollFlushTs()[collName], CollectionName: collName, @@ -337,9 +408,9 @@ func (m *GrpcClient) Flush(ctx context.Context, db, collName string) (*milvuspb. return resp, nil } -func (m *GrpcClient) ListCollections(ctx context.Context, db string) (*milvuspb.ShowCollectionsResponse, error) { - ctx = m.newCtxWithDB(ctx, db) - resp, err := m.srv.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{}) +func (g *GrpcClient) ListCollections(ctx context.Context, db string) (*milvuspb.ShowCollectionsResponse, error) { + ctx = g.newCtxWithDB(ctx, db) + resp, err := g.srv.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{}) if err := checkResponse(resp, err); err != nil { return nil, fmt.Errorf("client: list collections failed: %w", err) } @@ -347,9 +418,9 @@ func (m *GrpcClient) ListCollections(ctx context.Context, db string) (*milvuspb. return resp, nil } -func (m *GrpcClient) HasCollection(ctx context.Context, db, collName string) (bool, error) { - ctx = m.newCtxWithDB(ctx, db) - resp, err := m.srv.HasCollection(ctx, &milvuspb.HasCollectionRequest{CollectionName: collName}) +func (g *GrpcClient) HasCollection(ctx context.Context, db, collName string) (bool, error) { + ctx = g.newCtxWithDB(ctx, db) + resp, err := g.srv.HasCollection(ctx, &milvuspb.HasCollectionRequest{CollectionName: collName}) if err := checkResponse(resp, err); err != nil { return false, fmt.Errorf("client: has collection failed: %w", err) } @@ -365,8 +436,8 @@ type GrpcBulkInsertInput struct { IsL0 bool } -func (m *GrpcClient) BulkInsert(ctx context.Context, input GrpcBulkInsertInput) (int64, error) { - ctx = m.newCtxWithDB(ctx, input.DB) +func (g *GrpcClient) BulkInsert(ctx context.Context, input GrpcBulkInsertInput) (int64, error) { + ctx = g.newCtxWithDB(ctx, input.DB) var opts []*commonpb.KeyValuePair if input.EndTime > 0 { opts = append(opts, &commonpb.KeyValuePair{Key: "end_time", Value: strconv.FormatInt(input.EndTime, 10)}) @@ -385,7 +456,7 @@ func (m *GrpcClient) BulkInsert(ctx context.Context, input GrpcBulkInsertInput) Files: input.Paths, Options: opts, } - resp, err := m.srv.Import(ctx, in) + resp, err := g.srv.Import(ctx, in) if err := checkResponse(resp, err); err != nil { return 0, fmt.Errorf("client: bulk insert failed: %w", err) } @@ -393,9 +464,9 @@ func (m *GrpcClient) BulkInsert(ctx context.Context, input GrpcBulkInsertInput) return resp.GetTasks()[0], nil } -func (m *GrpcClient) GetBulkInsertState(ctx context.Context, taskID int64) (*milvuspb.GetImportStateResponse, error) { - ctx = m.newCtx(ctx) - resp, err := m.srv.GetImportState(ctx, &milvuspb.GetImportStateRequest{Task: taskID}) +func (g *GrpcClient) GetBulkInsertState(ctx context.Context, taskID int64) (*milvuspb.GetImportStateResponse, error) { + ctx = g.newCtx(ctx) + resp, err := g.srv.GetImportState(ctx, &milvuspb.GetImportStateRequest{Task: taskID}) if err := checkResponse(resp, err); err != nil { return nil, fmt.Errorf("client: get bulk insert state failed: %w", err) } @@ -411,8 +482,8 @@ type CreateCollectionInput struct { partitionNum int64 } -func (m *GrpcClient) CreateCollection(ctx context.Context, input CreateCollectionInput) error { - ctx = m.newCtxWithDB(ctx, input.DB) +func (g *GrpcClient) CreateCollection(ctx context.Context, input CreateCollectionInput) error { + ctx = g.newCtxWithDB(ctx, input.DB) bs, err := proto.Marshal(input.Schema) if err != nil { return fmt.Errorf("client: create collection failed: %w", err) @@ -425,7 +496,7 @@ func (m *GrpcClient) CreateCollection(ctx context.Context, input CreateCollectio NumPartitions: input.partitionNum, } - resp, err := m.srv.CreateCollection(ctx, in) + resp, err := g.srv.CreateCollection(ctx, in) if err := checkResponse(resp, err); err != nil { return fmt.Errorf("client: create collection failed: %w", err) } @@ -433,9 +504,9 @@ func (m *GrpcClient) CreateCollection(ctx context.Context, input CreateCollectio return nil } -func (m *GrpcClient) DropCollection(ctx context.Context, db string, collectionName string) error { - ctx = m.newCtxWithDB(ctx, db) - resp, err := m.srv.DropCollection(ctx, &milvuspb.DropCollectionRequest{CollectionName: collectionName}) +func (g *GrpcClient) DropCollection(ctx context.Context, db string, collectionName string) error { + ctx = g.newCtxWithDB(ctx, db) + resp, err := g.srv.DropCollection(ctx, &milvuspb.DropCollectionRequest{CollectionName: collectionName}) if err := checkResponse(resp, err); err != nil { return fmt.Errorf("client: drop collection failed: %w", err) } @@ -443,10 +514,10 @@ func (m *GrpcClient) DropCollection(ctx context.Context, db string, collectionNa return nil } -func (m *GrpcClient) CreatePartition(ctx context.Context, db, collName, partitionName string) error { - ctx = m.newCtxWithDB(ctx, db) +func (g *GrpcClient) CreatePartition(ctx context.Context, db, collName, partitionName string) error { + ctx = g.newCtxWithDB(ctx, db) in := &milvuspb.CreatePartitionRequest{CollectionName: collName, PartitionName: partitionName} - resp, err := m.srv.CreatePartition(ctx, in) + resp, err := g.srv.CreatePartition(ctx, in) if err := checkResponse(resp, err); err != nil { return fmt.Errorf("client: create partition failed: %w", err) } @@ -454,10 +525,10 @@ func (m *GrpcClient) CreatePartition(ctx context.Context, db, collName, partitio return nil } -func (m *GrpcClient) HasPartition(ctx context.Context, db, collName string, partitionName string) (bool, error) { - ctx = m.newCtxWithDB(ctx, db) +func (g *GrpcClient) HasPartition(ctx context.Context, db, collName string, partitionName string) (bool, error) { + ctx = g.newCtxWithDB(ctx, db) in := &milvuspb.HasPartitionRequest{CollectionName: collName, PartitionName: partitionName} - resp, err := m.srv.HasPartition(ctx, in) + resp, err := g.srv.HasPartition(ctx, in) if err := checkResponse(resp, err); err != nil { return false, fmt.Errorf("client: has partition failed: %w", err) } @@ -481,8 +552,8 @@ type CreateIndexInput struct { Params map[string]string } -func (m *GrpcClient) CreateIndex(ctx context.Context, input CreateIndexInput) error { - ctx = m.newCtxWithDB(ctx, input.DB) +func (g *GrpcClient) CreateIndex(ctx context.Context, input CreateIndexInput) error { + ctx = g.newCtxWithDB(ctx, input.DB) in := &milvuspb.CreateIndexRequest{ CollectionName: input.CollectionName, FieldName: input.FieldName, @@ -490,7 +561,7 @@ func (m *GrpcClient) CreateIndex(ctx context.Context, input CreateIndexInput) er ExtraParams: mapKvPairs(input.Params), } - resp, err := m.srv.CreateIndex(ctx, in) + resp, err := g.srv.CreateIndex(ctx, in) if err := checkResponse(resp, err); err != nil { return fmt.Errorf("client: create index failed: %w", err) } @@ -498,19 +569,19 @@ func (m *GrpcClient) CreateIndex(ctx context.Context, input CreateIndexInput) er return nil } -func (m *GrpcClient) DropIndex(ctx context.Context, db, collName, indexName string) error { - ctx = m.newCtxWithDB(ctx, db) +func (g *GrpcClient) DropIndex(ctx context.Context, db, collName, indexName string) error { + ctx = g.newCtxWithDB(ctx, db) in := &milvuspb.DropIndexRequest{CollectionName: collName, IndexName: indexName} - resp, err := m.srv.DropIndex(ctx, in) + resp, err := g.srv.DropIndex(ctx, in) if err := checkResponse(resp, err); err != nil { return fmt.Errorf("client: drop index failed: %w", err) } return nil } -func (m *GrpcClient) BackupRBAC(ctx context.Context) (*milvuspb.BackupRBACMetaResponse, error) { - ctx = m.newCtx(ctx) - resp, err := m.srv.BackupRBAC(ctx, &milvuspb.BackupRBACMetaRequest{}) +func (g *GrpcClient) BackupRBAC(ctx context.Context) (*milvuspb.BackupRBACMetaResponse, error) { + ctx = g.newCtx(ctx) + resp, err := g.srv.BackupRBAC(ctx, &milvuspb.BackupRBACMetaRequest{}) if err := checkResponse(resp, err); err != nil { return nil, fmt.Errorf("client: backup rbac failed: %w", err) } @@ -518,9 +589,9 @@ func (m *GrpcClient) BackupRBAC(ctx context.Context) (*milvuspb.BackupRBACMetaRe return resp, nil } -func (m *GrpcClient) RestoreRBAC(ctx context.Context, rbacMeta *milvuspb.RBACMeta) error { - ctx = m.newCtx(ctx) - resp, err := m.srv.RestoreRBAC(ctx, &milvuspb.RestoreRBACMetaRequest{RBACMeta: rbacMeta}) +func (g *GrpcClient) RestoreRBAC(ctx context.Context, rbacMeta *milvuspb.RBACMeta) error { + ctx = g.newCtx(ctx) + resp, err := g.srv.RestoreRBAC(ctx, &milvuspb.RestoreRBACMetaRequest{RBACMeta: rbacMeta}) if err := checkResponse(resp, err); err != nil { return fmt.Errorf("client: restore rbac failed: %w", err) } diff --git a/core/client/grpc_test.go b/core/client/grpc_test.go index f72f220..a5d229c 100644 --- a/core/client/grpc_test.go +++ b/core/client/grpc_test.go @@ -7,12 +7,39 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/stretchr/testify/assert" + "google.golang.org/grpc/credentials/insecure" + + "github.com/zilliztech/milvus-backup/core/paramtable" ) -func TestOk(t *testing.T) { - assert.True(t, ok(&commonpb.Status{Code: 0})) +func TestGrpcAuth(t *testing.T) { + got := grpcAuth("username", "password") + assert.Equal(t, "dXNlcm5hbWU6cGFzc3dvcmQ=", got) + + got = grpcAuth("", "") + assert.Equal(t, "", got) +} + +func TestTransCred(t *testing.T) { + cred, err := transCred(¶mtable.MilvusConfig{TLSMode: 3}) + assert.Error(t, err) + assert.Nil(t, cred) + + cred, err = transCred(¶mtable.MilvusConfig{TLSMode: 0}) + assert.NoError(t, err) + assert.Equal(t, insecure.NewCredentials(), cred) + + cred, err = transCred(¶mtable.MilvusConfig{TLSMode: 1}) + assert.NoError(t, err) + + cred, err = transCred(¶mtable.MilvusConfig{TLSMode: 2}) + assert.NoError(t, err) +} + +func TestStatusOk(t *testing.T) { + assert.True(t, statusOk(&commonpb.Status{Code: 0})) - assert.False(t, ok(&commonpb.Status{Code: 1})) + assert.False(t, statusOk(&commonpb.Status{Code: 1})) } func TestCheckResponse(t *testing.T) { diff --git a/core/client/restful.go b/core/client/restful.go index 7087e0f..c92e61b 100644 --- a/core/client/restful.go +++ b/core/client/restful.go @@ -2,6 +2,7 @@ package client import ( "context" + "errors" "fmt" "strconv" "time" @@ -9,6 +10,7 @@ import ( "github.com/imroc/req/v3" "go.uber.org/zap" + "github.com/zilliztech/milvus-backup/core/paramtable" "github.com/zilliztech/milvus-backup/internal/log" ) @@ -153,11 +155,30 @@ func (r *RestfulClient) GetBulkInsertState(ctx context.Context, dbName, jobID st return &getResp, nil } -func NewRestful(cfg *Cfg) (*RestfulClient, error) { - baseURL := cfg.parseRestful() - cli := req.C().SetBaseURL(baseURL.String()) +func restfulAuth(username, password string) string { + if username != "" || password != "" { + return fmt.Sprintf("%s:%s", username, password) + } + + return "" +} + +func NewRestful(cfg *paramtable.MilvusConfig) (*RestfulClient, error) { + host := fmt.Sprintf("%s:%s", cfg.Address, cfg.Port) + log.Info("New milvus restful client", zap.String("host", host)) + + var baseURL string + if cfg.TLSMode == 0 { + baseURL = "http://" + host + } else if cfg.TLSMode == 1 || cfg.TLSMode == 2 { + baseURL = "https://" + host + } else { + return nil, errors.New("client: invalid tls mode") + } + + cli := req.C().SetBaseURL(baseURL) - if auth := cfg.parseRestfulAuth(); len(auth) != 0 { + if auth := restfulAuth(cfg.User, cfg.Password); len(auth) != 0 { cli.SetCommonBearerAuthToken(auth) } diff --git a/core/client/restful_test.go b/core/client/restful_test.go new file mode 100644 index 0000000..15118a6 --- /dev/null +++ b/core/client/restful_test.go @@ -0,0 +1,15 @@ +package client + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRestfulAuth(t *testing.T) { + got := restfulAuth("username", "password") + assert.Equal(t, "username:password", got) + + got = restfulAuth("", "") + assert.Equal(t, "", got) +} diff --git a/core/paramtable/base_table.go b/core/paramtable/base_table.go index 7fb2e88..1c31af0 100644 --- a/core/paramtable/base_table.go +++ b/core/paramtable/base_table.go @@ -51,14 +51,13 @@ const ( DefaultStorageType = "minio" - DefaultMilvusAddress = "localhost" - DefaultMilvusPort = "19530" - DefaultMilvusAuthorizationEnabled = "false" - DefaultMilvusTlsMode = "0" - DefaultMilvusUser = "root" - DefaultMilvusPassword = "Milvus" - DefaultMilvusTLSCertPath = "" - DefaultMilvusServerName = "" + DefaultMilvusAddress = "localhost" + DefaultMilvusPort = "19530" + DefaultMilvusTlsMode = "0" + DefaultMilvusUser = "root" + DefaultMilvusPassword = "Milvus" + DefaultMilvusTLSCertPath = "" + DefaultMilvusServerName = "" ) var defaultYaml = DefaultBackupYaml @@ -512,11 +511,6 @@ func (gp *BaseTable) loadMilvusConfig() { _ = gp.Save("milvus.port", milvusPort) } - milvusAuthorizationEnabled := os.Getenv("MILVUS_AUTHORIZATION_ENABLED") - if milvusAuthorizationEnabled != "" { - _ = gp.Save("milvus.authorizationEnabled", milvusAuthorizationEnabled) - } - milvusTlsMode := os.Getenv("MILVUS_TLS_MODE") if milvusTlsMode != "" { _ = gp.Save("milvus.tlsMode", milvusTlsMode) @@ -532,13 +526,23 @@ func (gp *BaseTable) loadMilvusConfig() { _ = gp.Save("milvus.password", milvusPassword) } - milvusTLSCertPath := os.Getenv("MILVUS_TLS_CERTPATH") + milvusTLSCertPath := os.Getenv("MILVUS_CA_CERT_PATH") if milvusTLSCertPath != "" { - _ = gp.Save("milvus.tlsCertPath", milvusTLSCertPath) + _ = gp.Save("milvus.caCertPath", milvusTLSCertPath) } milvusServerName := os.Getenv("MILVUS_SERVER_NAME") if milvusServerName != "" { _ = gp.Save("milvus.serverName", milvusServerName) } + + milvusMtlsCertPath := os.Getenv("MILVUS_MTLS_CERT_PATH") + if milvusMtlsCertPath != "" { + _ = gp.Save("milvus.mtlsCertPath", milvusMtlsCertPath) + } + + milvusMtlsKeyPath := os.Getenv("MILVUS_MTLS_KEY_PATH") + if milvusMtlsKeyPath != "" { + _ = gp.Save("milvus.mtlsKeyPath", milvusMtlsKeyPath) + } } diff --git a/core/paramtable/params.go b/core/paramtable/params.go index 765c218..8be55fa 100644 --- a/core/paramtable/params.go +++ b/core/paramtable/params.go @@ -105,14 +105,21 @@ func (p *BackupConfig) initGcPauseAddress() { type MilvusConfig struct { Base *BaseTable - Address string - Port string - User string - Password string - TLSCertPath string - ServerName string - AuthorizationEnabled bool - TLSMode int + Address string + Port string + + User string + Password string + + TLSMode int + + // tls credentials for validate server + CACertPath string + ServerName string + + // tls credentials for validate client, eg: mTLS + MTLSCertPath string + MTLSKeyPath string } func (p *MilvusConfig) init(base *BaseTable) { @@ -120,12 +127,17 @@ func (p *MilvusConfig) init(base *BaseTable) { p.initAddress() p.initPort() + p.initUser() p.initPassword() - p.initTLSCertPath() - p.initServerName() - p.initAuthorizationEnabled() + p.initTLSMode() + + p.initCACertPath() + p.initServerName() + + p.initMTLSCertPath() + p.initMTLSKeyPath() } func (p *MilvusConfig) initAddress() { @@ -160,9 +172,9 @@ func (p *MilvusConfig) initPassword() { p.Password = password } -func (p *MilvusConfig) initTLSCertPath() { - tlsCertPath := p.Base.LoadWithDefault("milvus.tlsCertPath", "") - p.TLSCertPath = tlsCertPath +func (p *MilvusConfig) initCACertPath() { + caCertPath := p.Base.LoadWithDefault("milvus.caCertPath", "") + p.CACertPath = caCertPath } func (p *MilvusConfig) initServerName() { @@ -170,11 +182,23 @@ func (p *MilvusConfig) initServerName() { p.ServerName = serverName } -func (p *MilvusConfig) initAuthorizationEnabled() { - p.AuthorizationEnabled = p.Base.ParseBool("milvus.authorizationEnabled", false) +func (p *MilvusConfig) initMTLSCertPath() { + // for backward compatibility, if mTLS cert path is not set, use tls mode 1. + // WARN: This behavior will be removed in the version after v0.6.0 + mtlsCertPath := p.Base.LoadWithDefault("milvus.mtlsCertPath", "") + p.MTLSCertPath = mtlsCertPath +} + +func (p *MilvusConfig) initMTLSKeyPath() { + // for backward compatibility, if mTLS key path is not set, use tls mode 1 instead of 2. + // WARN: This behavior will be removed in the version after v0.6.0 + mtlsKeyPath := p.Base.LoadWithDefault("milvus.mtlsKeyPath", "") + p.MTLSKeyPath = mtlsKeyPath } func (p *MilvusConfig) initTLSMode() { + // for backward compatibility, if mTLS cert path is not set, use tls mode 1 instead of 2. + // WARN: This behavior will be removed in the version after v0.6.0 p.TLSMode = p.Base.ParseIntWithDefault("milvus.tlsMode", 0) }