From 7802f417c318ad53e251504d965ccd287d042b0c Mon Sep 17 00:00:00 2001 From: ianlee Date: Sat, 24 Aug 2024 23:52:21 +0900 Subject: [PATCH] mathpresso custom initialize --- core/api/replicate_manager.go | 5 +- core/reader/collection_reader.go | 5 +- core/reader/replicate_channel_manager.go | 38 +- core/reader/target_client.go | 12 +- core/util/devon_client_resource.go | 114 ++++++ core/writer/config_option.go | 42 +- core/writer/devon_handler.go | 467 +++++++++++++++++++++++ server/cdc_impl.go | 102 +++-- server/model/common.go | 4 +- 9 files changed, 739 insertions(+), 50 deletions(-) create mode 100644 core/util/devon_client_resource.go create mode 100644 core/writer/devon_handler.go diff --git a/core/api/replicate_manager.go b/core/api/replicate_manager.go index 4b61e019..e200102c 100644 --- a/core/api/replicate_manager.go +++ b/core/api/replicate_manager.go @@ -35,7 +35,8 @@ type ChannelManager interface { AddDroppedCollection(ids []int64) AddDroppedPartition(ids []int64) - StartReadCollection(ctx context.Context, info *pb.CollectionInfo, seekPositions []*msgpb.MsgPosition) error + StartReadCollection(ctx context.Context, info *pb.CollectionInfo, targetDBType string, seekPositions []*msgpb.MsgPosition) error + StopReadCollection(ctx context.Context, info *pb.CollectionInfo) error AddPartition(ctx context.Context, collectionInfo *pb.CollectionInfo, partitionInfo *pb.PartitionInfo) error @@ -103,7 +104,7 @@ func (d *DefaultChannelManager) AddDroppedPartition(ids []int64) { log.Warn("AddDroppedPartition is not implemented, please check it") } -func (d *DefaultChannelManager) StartReadCollection(ctx context.Context, info *pb.CollectionInfo, seekPositions []*msgpb.MsgPosition) error { +func (d *DefaultChannelManager) StartReadCollection(ctx context.Context, info *pb.CollectionInfo, targetDBType string, seekPositions []*msgpb.MsgPosition) error { log.Warn("StartReadCollection is not implemented, please check it") return nil } diff --git 
a/core/reader/collection_reader.go b/core/reader/collection_reader.go index bdf0ccba..175fd8c4 100644 --- a/core/reader/collection_reader.go +++ b/core/reader/collection_reader.go @@ -68,6 +68,7 @@ type CollectionReader struct { quitOnce sync.Once retryOptions []retry.Option + targetDBType string } func NewCollectionReader(id string, @@ -124,7 +125,7 @@ func (reader *CollectionReader) StartRead(ctx context.Context) { Timestamp: info.CreateTime, }) } - if err := reader.channelManager.StartReadCollection(ctx, info, startPositions); err != nil { + if err := reader.channelManager.StartReadCollection(ctx, info, reader.targetDBType, startPositions); err != nil { collectionLog.Warn("fail to start to replicate the collection data in the watch process", zap.Any("info", info), zap.Error(err)) reader.sendError(err) } @@ -251,7 +252,7 @@ func (reader *CollectionReader) StartRead(ctx context.Context) { zap.String("name", info.Schema.Name), zap.Int64("collection_id", info.ID), zap.String("state", info.State.String())) - if err := reader.channelManager.StartReadCollection(ctx, info, seekPositions); err != nil { + if err := reader.channelManager.StartReadCollection(ctx, info, reader.targetDBType, seekPositions); err != nil { readerLog.Warn("fail to start to replicate the collection data", zap.Any("collection", info), zap.Error(err)) reader.sendError(err) } diff --git a/core/reader/replicate_channel_manager.go b/core/reader/replicate_channel_manager.go index 2d737332..b43f3461 100644 --- a/core/reader/replicate_channel_manager.go +++ b/core/reader/replicate_channel_manager.go @@ -20,6 +20,7 @@ package reader import ( "context" + "fmt" "io" "math" "sort" @@ -159,7 +160,7 @@ func (r *replicateChannelManager) AddDroppedPartition(ids []int64) { log.Info("has removed dropped partitions", zap.Int64s("ids", ids)) } -func (r *replicateChannelManager) StartReadCollection(ctx context.Context, info *pb.CollectionInfo, seekPositions []*msgpb.MsgPosition) error { +func (r 
*replicateChannelManager) StartReadCollection(ctx context.Context, info *pb.CollectionInfo, targetDBType string, seekPositions []*msgpb.MsgPosition) error { r.addCollectionLock.Lock() *r.addCollectionCnt++ r.addCollectionLock.Unlock() @@ -261,6 +262,7 @@ func (r *replicateChannelManager) StartReadCollection(ctx context.Context, info } targetMsgCount := len(info.StartPositions) + barrier := NewBarrier(targetMsgCount, func(msgTs uint64, b *Barrier) { select { case <-b.CloseChan: @@ -284,9 +286,16 @@ func (r *replicateChannelManager) StartReadCollection(ctx context.Context, info var successChannels []string var channelHandlers []*replicateChannelHandler - err = ForeachChannel(info.VirtualChannelNames, targetInfo.VChannels, func(sourceVChannel, targetVChannel string) error { - sourcePChannel := funcutil.ToPhysicalChannel(sourceVChannel) - targetPChannel := funcutil.ToPhysicalChannel(targetVChannel) + toPhysicalChannel := func(vChannel string) string { + if strings.ToLower(vChannel) != "mysql" && strings.ToLower(vChannel) != "bigquery" { + return funcutil.ToPhysicalChannel(vChannel) + } + + return strings.ToLower(vChannel) + } + err = ForeachChannel(targetDBType, info.VirtualChannelNames, targetInfo.VChannels, targetInfo.DatabaseName, targetInfo.CollectionName, func(sourceVChannel, targetVChannel string) error { + sourcePChannel := toPhysicalChannel(sourceVChannel) + targetPChannel := toPhysicalChannel(targetVChannel) channelHandler, err := r.startReadChannel(&model.SourceCollectionInfo{ PChannelName: sourcePChannel, VChannelName: sourceVChannel, @@ -340,12 +349,25 @@ func GetVChannelByPChannel(pChannel string, vChannels []string) string { return "" } -func ForeachChannel(sourcePChannels, targetPChannels []string, f func(sourcePChannel, targetPChannel string) error) error { - if len(sourcePChannels) != len(targetPChannels) { - return errors.New("the lengths of source and target channels are not equal") +func ForeachChannel(targetDBType string, sourcePChannels, 
targetPChannels []string, targetDBName, targetCollectionName string, f func(sourcePChannel, targetPChannel string) error) error { + if targetDBType == "milvus" { + if len(sourcePChannels) != len(targetPChannels) { + return errors.New("the lengths of source and target channels are not equal") + } } + sources := make([]string, len(sourcePChannels)) - targets := make([]string, len(targetPChannels)) + var targets []string + + if targetDBType == "milvus" { + targets = make([]string, len(targetPChannels)) + } else { + targets = make([]string, len(sourcePChannels)) + for i := 0; i < len(sourcePChannels); i++ { + targetPChannels = append(targetPChannels, fmt.Sprintf("%s-%s-%s-dml_%d_v0", targetDBName, targetCollectionName, targetDBType, i)) + } + } + copy(sources, sourcePChannels) copy(targets, targetPChannels) sort.Strings(sources) diff --git a/core/reader/target_client.go b/core/reader/target_client.go index 66280cad..c913c60c 100644 --- a/core/reader/target_client.go +++ b/core/reader/target_client.go @@ -45,10 +45,14 @@ type TargetClient struct { } type TargetConfig struct { - URI string - Token string - APIKey string - DialConfig util.DialConfig + TargetDBType string + URI string + Token string + ProjectId string + APIKey string + ConnectionTimeout int + DialConfig util.DialConfig + TargetCDCAgentUri string } func NewTarget(ctx context.Context, config TargetConfig) (api.TargetAPI, error) { diff --git a/core/util/devon_client_resource.go b/core/util/devon_client_resource.go new file mode 100644 index 00000000..70c6b4e2 --- /dev/null +++ b/core/util/devon_client_resource.go @@ -0,0 +1,114 @@ +/* + * Licensed to the LF AI & Data foundation under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * // + * http://www.apache.org/licenses/LICENSE-2.0 + * // + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * custom by qanda + */ + +package util + +import ( + "context" + "fmt" + "net" + "reflect" + "sync" + "time" + + "github.com/cockroachdb/errors" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/pkg/util/resource" + + "github.com/zilliztech/milvus-cdc/core/log" +) + +const ( + DBClientResourceTyp = "db_client" + DBClientExpireTime = 30 * time.Second +) + +var ( + dbClientManager *DBClientResourceManager + dbClientManagerOnce sync.Once +) + +type DBClientResourceManager struct { + manager resource.Manager +} + +func GetDBClientManager() *DBClientResourceManager { + dbClientManagerOnce.Do(func() { + manager := resource.NewManager(0, 0, nil) + dbClientManager = &DBClientResourceManager{ + manager: manager, + } + }) + return dbClientManager +} + +func (m *DBClientResourceManager) newDBClient(cdcAgentHost, cdcAgentPort, address, database, collection string, dialConfig DialConfig) resource.NewResourceFunc { + return func() (resource.Resource, error) { + conn, err := net.Dial("tcp", cdcAgentHost+":"+cdcAgentPort) + if err != nil { + log.Warn("Error connecting:", zap.Error(err)) + return nil, err + } + /* + c, err := client.NewClient(ctx, client.Config{ + Address: address, + APIKey: apiKey, + EnableTLSAuth: enableTLS, + DBName: database, + }) + if err != nil { + log.Warn("fail to new the db client", zap.String("database", database), zap.String("address", address), zap.Error(err)) + return nil, err + } + */ + res := resource.NewSimpleResource(conn, DBClientResourceTyp, fmt.Sprintf("%s:%s:%s", address, database, collection), DBClientExpireTime, func() 
{ + _ = conn.Close() + }) + + return res, nil + } +} + +func (m *DBClientResourceManager) GetDBClient(ctx context.Context, cdcAgentHost, cdcAgentPort, address, database, collection string, dialConfig DialConfig, connectionTimeout int) (net.Conn, error) { + if database == "" { + database = DefaultDbName + } + ctxLog := log.Ctx(ctx).With(zap.String("database", database), zap.String("address", address)) + res, err := m.manager.Get(DBClientResourceTyp, + getDBClientResourceName(address, database, collection), + m.newDBClient(cdcAgentHost, cdcAgentPort, address, database, collection, dialConfig)) + if err != nil { + ctxLog.Error("fail to get db client", zap.Error(err)) + return nil, err + } + if obj, ok := res.Get().(net.Conn); ok && obj != nil { + return obj, nil + } + ctxLog.Warn("invalid resource object", zap.Any("obj", reflect.TypeOf(res.Get()))) + return nil, errors.New("invalid resource object") +} + +func (m *DBClientResourceManager) DeleteDBClient(address, database, collection string) { + _ = m.manager.Delete(DBClientResourceTyp, getDBClientResourceName(address, database, collection)) +} + +func getDBClientResourceName(address, database, collection string) string { + return fmt.Sprintf("%s:%s:%s", address, database, collection) +} diff --git a/core/writer/config_option.go b/core/writer/config_option.go index f2ad6a51..a6501299 100644 --- a/core/writer/config_option.go +++ b/core/writer/config_option.go @@ -23,19 +23,19 @@ import ( "github.com/zilliztech/milvus-cdc/core/util" ) -func TokenOption(token string) config.Option[*MilvusDataHandler] { +func MilvusTokenOption(token string) config.Option[*MilvusDataHandler] { return config.OptionFunc[*MilvusDataHandler](func(object *MilvusDataHandler) { object.token = token }) } -func URIOption(uri string) config.Option[*MilvusDataHandler] { +func MilvusURIOption(uri string) config.Option[*MilvusDataHandler] { return config.OptionFunc[*MilvusDataHandler](func(object *MilvusDataHandler) { object.uri = uri }) } -func 
ConnectTimeoutOption(timeout int) config.Option[*MilvusDataHandler] { +func MilvusConnectTimeoutOption(timeout int) config.Option[*MilvusDataHandler] { return config.OptionFunc[*MilvusDataHandler](func(object *MilvusDataHandler) { if timeout > 0 { object.connectTimeout = timeout @@ -43,14 +43,46 @@ func ConnectTimeoutOption(timeout int) config.Option[*MilvusDataHandler] { }) } -func IgnorePartitionOption(ignore bool) config.Option[*MilvusDataHandler] { +func MilvusIgnorePartitionOption(ignore bool) config.Option[*MilvusDataHandler] { return config.OptionFunc[*MilvusDataHandler](func(object *MilvusDataHandler) { object.ignorePartition = ignore }) } -func DialConfigOption(dialConfig util.DialConfig) config.Option[*MilvusDataHandler] { +func MilvusDialConfigOption(dialConfig util.DialConfig) config.Option[*MilvusDataHandler] { return config.OptionFunc[*MilvusDataHandler](func(object *MilvusDataHandler) { object.dialConfig = dialConfig }) } + +func TokenOption(token string) config.Option[*DataHandler] { + return config.OptionFunc[*DataHandler](func(object *DataHandler) { + object.token = token + }) +} + +func URIOption(uri string) config.Option[*DataHandler] { + return config.OptionFunc[*DataHandler](func(object *DataHandler) { + object.uri = uri + }) +} + +func ConnectTimeoutOption(timeout int) config.Option[*DataHandler] { + return config.OptionFunc[*DataHandler](func(object *DataHandler) { + if timeout > 0 { + object.connectTimeout = timeout + } + }) +} + +func IgnorePartitionOption(ignore bool) config.Option[*DataHandler] { + return config.OptionFunc[*DataHandler](func(object *DataHandler) { + object.ignorePartition = ignore + }) +} + +func DialConfigOption(dialConfig util.DialConfig) config.Option[*DataHandler] { + return config.OptionFunc[*DataHandler](func(object *DataHandler) { + object.dialConfig = dialConfig + }) +} diff --git a/core/writer/devon_handler.go b/core/writer/devon_handler.go new file mode 100644 index 00000000..cb67b76e --- /dev/null +++ 
b/core/writer/devon_handler.go @@ -0,0 +1,467 @@ +package writer + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "github.com/golang/protobuf/proto" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/util/resource" + "github.com/milvus-io/milvus/pkg/util/retry" + "github.com/zilliztech/milvus-cdc/core/api" + "github.com/zilliztech/milvus-cdc/core/config" + "github.com/zilliztech/milvus-cdc/core/log" + "github.com/zilliztech/milvus-cdc/core/util" + "go.uber.org/zap" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "net" + "strconv" + "strings" + "time" +) + +type DataHandler struct { + api.DataHandler + + agentHost string `json:"agent_host"` + agentPort string `json:"agent_port"` + token string `json:"token"` + uri string `json:"uri"` + address string `json:"address"` + database string `json:"database"` + collection string `json:"collection"` + username string `json:"user_name"` + password string `json:"password"` + enableTLS bool `json:"enable_tls"` + ignorePartition bool // sometimes the has partition api is a deny api + connectTimeout int `json:"connect_timeout"` + projectId string `json:"project_id"` + msgType commonpb.MsgType `json:"msg_type"` + msgBytes [][]byte `json:"msg_bytes"` + retryOptions []retry.Option `json:"retry_options"` + dialConfig util.DialConfig `json:"dial_config"` +} + +func NewDataHandler(options ...config.Option[*DataHandler]) (*DataHandler, error) { + handler := &DataHandler{ + connectTimeout: 5, + } + for _, option := range options { + option.Apply(handler) + } + if handler.address == "" { + return nil, errors.New("empty cdc agent address") + } + + var err error + timeoutContext, cancel := context.WithTimeout(context.Background(), time.Duration(handler.connectTimeout)*time.Second) + defer cancel() + + err = handler.DBOp(timeoutContext, func(db net.Conn) error { + return nil + }) + if err != nil { + 
log.Warn("fail to new the cdc agent client", zap.Error(err)) + return nil, err + } + handler.retryOptions = util.GetRetryOptions(config.GetCommonConfig().Retry) + return handler, nil +} + +func (m *DataHandler) DBOp(ctx context.Context, f func(db net.Conn) error) error { + retryDBAgentFunc := func(c net.Conn) error { + var err error + retryErr := retry.Do(ctx, func() error { + c, err = util.GetDBClientManager().GetDBClient(ctx, m.agentHost, m.agentPort, m.address, m.database, m.collection, util.DialConfig{}, m.connectTimeout) + if err != nil { + log.Warn("fail to get cdc agent client", zap.Error(err)) + return err + } + err = f(c) + if status.Code(err) == codes.Canceled { + util.GetDBClientManager().DeleteDBClient(m.address, m.database, m.collection) + log.Warn("grpc: the client connection is closing, waiting...", zap.Error(err)) + time.Sleep(resource.DefaultExpiration) + } + return err + }, m.retryOptions...) + if retryErr != nil && err != nil { + return err + } + if retryErr != nil { + return retryErr + } + return nil + } + + dbClient, err := util.GetDBClientManager().GetDBClient(ctx, m.agentHost, m.agentPort, m.address, m.database, m.collection, util.DialConfig{}, m.connectTimeout) + if err != nil { + log.Warn("fail to get cdc agent client", zap.Error(err)) + return err + } + + return retryDBAgentFunc(dbClient) +} + +func (m *DataHandler) ReplicaMessageHandler(ctx context.Context, msgType commonpb.MsgType, param *api.ReplicateMessageParam) error { + if param.ChannelName != "" { + m.database = strings.Split(param.ChannelName, "-")[0] + m.collection = strings.Split(param.ChannelName, "-")[1] + } + + m.msgType = msgType + m.msgBytes = param.MsgsBytes + + messageJSON, err := json.Marshal(m) + if err != nil { + log.Warn("Error marshalling JSON:", zap.Error(err)) + return err + } + + return m.DBOp(ctx, func(dbClient net.Conn) error { + _, err := dbClient.Write(messageJSON) + if err != nil { + log.Warn("failed to send message:", zap.Error(err)) + return err + } + + 
return nil + }) +} + +func (m *DataHandler) CreateCollection(ctx context.Context, param *api.CreateCollectionParam) error { + messageJSON, err := json.Marshal(param) + if err != nil { + log.Warn("Error marshalling JSON:", zap.Error(err)) + return err + } + + return m.DBOp(ctx, func(dbClient net.Conn) error { + _, err := dbClient.Write(messageJSON) + if err != nil { + log.Warn("failed to send message:", zap.Error(err)) + return err + } + + return nil + }) +} + +func (m *DataHandler) DropCollection(ctx context.Context, param *api.DropCollectionParam) error { + messageJSON, err := json.Marshal(param) + if err != nil { + log.Warn("Error marshalling JSON:", zap.Error(err)) + return err + } + + return m.DBOp(ctx, func(dbClient net.Conn) error { + _, err := dbClient.Write(messageJSON) + if err != nil { + log.Warn("failed to send message:", zap.Error(err)) + return err + } + + return nil + }) +} + +func float32SliceToStringSlice(input []float32) []string { + output := make([]string, len(input)) + for i, v := range input { + output[i] = strconv.FormatFloat(float64(v), 'f', -1, 32) + } + return output +} + +func (m *DataHandler) Insert(ctx context.Context, param *api.InsertParam) (err error) { + columns := []string{} + + rowValues := [][]interface{}{} + for _, col := range param.Columns { + columns = append(columns, fmt.Sprintf("`%s`", col.Name())) + + colValues := []interface{}{} + switch data := col.FieldData().GetField().(type) { + case *schemapb.FieldData_Scalars: + switch scalarData := data.Scalars.Data.(type) { + case *schemapb.ScalarField_LongData: + for _, v := range scalarData.LongData.Data { + colValues = append(colValues, v) + } + case *schemapb.ScalarField_BoolData: + for _, v := range scalarData.BoolData.Data { + colValues = append(colValues, v) + } + case *schemapb.ScalarField_StringData: + for _, v := range scalarData.StringData.Data { + colValues = append(colValues, fmt.Sprintf("'%s'", v)) + } + case *schemapb.ScalarField_ArrayData: + for _, v := range 
scalarData.ArrayData.Data { + colValues = append(colValues, fmt.Sprintf("'%s'", v)) + } + case *schemapb.ScalarField_IntData: + for _, v := range scalarData.IntData.Data { + colValues = append(colValues, v) + } + case *schemapb.ScalarField_FloatData: + for _, v := range scalarData.FloatData.Data { + colValues = append(colValues, v) + } + case *schemapb.ScalarField_DoubleData: + for _, v := range scalarData.DoubleData.Data { + colValues = append(colValues, v) + } + case *schemapb.ScalarField_JsonData: + for _, v := range scalarData.JsonData.Data { + colValues = append(colValues, fmt.Sprintf("'%s'", v)) + } + case *schemapb.ScalarField_BytesData: + for _, v := range scalarData.BytesData.Data { + colValues = append(colValues, fmt.Sprintf("'%s'", v)) + } + default: + return fmt.Errorf("unsupported scalar data type: %T", scalarData) + } + + rowValues = append(rowValues, colValues) + case *schemapb.FieldData_Vectors: + switch vectorData := data.Vectors.Data.(type) { + case *schemapb.VectorField_FloatVector: + dim := data.Vectors.Dim + cnt := int64(1) + var vec []float32 + for _, v := range vectorData.FloatVector.Data { + vec = append(vec, v) + + if cnt == dim { + colValues = append(colValues, fmt.Sprintf("[%v]", join(float32SliceToStringSlice(vec), ","))) + vec = []float32{} + cnt = 1 + } else { + cnt++ + } + } + + rowValues = append(rowValues, colValues) + default: + return fmt.Errorf("unsupported vector data type: %T", vectorData) + } + default: + return fmt.Errorf("unsupported field data type: %T", data) + } + } + + var value string + var values []string + + for rowCnt := 0; rowCnt < len(rowValues[0]); rowCnt++ { + for colNo, _ := range columns { + if colNo == 0 { + value = fmt.Sprintf("(%s", fmt.Sprintf("%v", rowValues[colNo][rowCnt])) + } else { + value = fmt.Sprintf("%s,%s", value, fmt.Sprintf("%v", rowValues[colNo][rowCnt])) + } + } + value = fmt.Sprintf("%s)", value) + + values = append(values, value) + } + + messageJSON, err := json.Marshal(param) + if err != 
nil { + log.Warn("Error marshalling JSON:", zap.Error(err)) + return err + } + + err = m.DBOp(ctx, func(dbClient net.Conn) error { + _, err := dbClient.Write(messageJSON) + if err != nil { + log.Warn("failed to send message:", zap.Error(err)) + return err + } + + return nil + }) + + if err != nil { + return err + } + + return nil +} + +func StructToReader(s interface{}) (*bytes.Reader, error) { + // Serialize the struct to JSON. + data, err := json.Marshal(s) + if err != nil { + return nil, fmt.Errorf("error marshalling struct: %v", err) + } + + // Wrap the encoded bytes in a *bytes.Reader and return it. + reader := bytes.NewReader(data) + return reader, nil +} + +func (m *DataHandler) Delete(ctx context.Context, param *api.DeleteParam) error { + messageJSON, err := json.Marshal(param) + if err != nil { + log.Warn("Error marshalling JSON:", zap.Error(err)) + return err + } + + return m.DBOp(ctx, func(dbClient net.Conn) error { + _, err := dbClient.Write(messageJSON) + if err != nil { + log.Warn("failed to send message:", zap.Error(err)) + return err + } + + return nil + }) +} + +func (m *DataHandler) CreateIndex(ctx context.Context, param *api.CreateIndexParam) error { + messageJSON, err := json.Marshal(param) + if err != nil { + log.Warn("Error marshalling JSON:", zap.Error(err)) + return err + } + + return m.DBOp(ctx, func(dbClient net.Conn) error { + _, err := dbClient.Write(messageJSON) + if err != nil { + log.Warn("failed to send message:", zap.Error(err)) + return err + } + + return nil + }) +} + +func (m *DataHandler) DropIndex(ctx context.Context, param *api.DropIndexParam) error { + // BigQuery does not support indexes. 
+ return nil +} + +func (m *DataHandler) CreateDatabase(ctx context.Context, param *api.CreateDatabaseParam) error { + messageJSON, err := json.Marshal(param) + if err != nil { + log.Warn("Error marshalling JSON:", zap.Error(err)) + return err + } + + return m.DBOp(ctx, func(dbClient net.Conn) error { + _, err := dbClient.Write(messageJSON) + if err != nil { + log.Warn("failed to send message:", zap.Error(err)) + return err + } + + return nil + }) +} + +func (m *DataHandler) DropDatabase(ctx context.Context, param *api.DropDatabaseParam) error { + messageJSON, err := json.Marshal(param) + if err != nil { + log.Warn("Error marshalling JSON:", zap.Error(err)) + return err + } + + return m.DBOp(ctx, func(dbClient net.Conn) error { + _, err := dbClient.Write(messageJSON) + if err != nil { + log.Warn("failed to send message:", zap.Error(err)) + return err + } + + return nil + }) +} + +func (m *DataHandler) ReplicateMessage(ctx context.Context, param *api.ReplicateMessageParam) error { + for i, msgBytes := range param.MsgsBytes { + header := &commonpb.MsgHeader{} + err := proto.Unmarshal(msgBytes, header) + if err != nil { + log.Warn("failed to unmarshal msg header", zap.Int("index", i), zap.Error(err)) + return err + } + + if header.GetBase() == nil { + log.Warn("msg header base is nil", zap.Int("index", i), zap.Error(err)) + return err + } + + err = m.ReplicaMessageHandler(ctx, header.GetBase().GetMsgType(), param) + if err != nil { + log.Warn("failed to replca message handle", zap.Int("index", i), zap.Error(err)) + return err + } + } + + return nil +} + +func (m *DataHandler) DescribeCollection(ctx context.Context, param *api.DescribeCollectionParam) error { + messageJSON, err := json.Marshal(param) + if err != nil { + log.Warn("Error marshalling JSON:", zap.Error(err)) + return err + } + + return m.DBOp(ctx, func(dbClient net.Conn) error { + _, err := dbClient.Write(messageJSON) + if err != nil { + log.Warn("failed to send message:", zap.Error(err)) + return err + } 
+ + return nil + }) +} + +func (m *DataHandler) DescribeDatabase(ctx context.Context, param *api.DescribeDatabaseParam) error { + messageJSON, err := json.Marshal(param) + if err != nil { + log.Warn("Error marshalling JSON:", zap.Error(err)) + return err + } + + return m.DBOp(ctx, func(dbClient net.Conn) error { + _, err := dbClient.Write(messageJSON) + if err != nil { + log.Warn("failed to send message:", zap.Error(err)) + return err + } + + return nil + }) +} + +func join(elems []string, sep string) string { + switch len(elems) { + case 0: + return "" + case 1: + return elems[0] + } + n := len(sep) * (len(elems) - 1) + for i := 0; i < len(elems); i++ { + n += len(elems[i]) + } + + var b strings.Builder + b.Grow(n) + b.WriteString(elems[0]) + for _, s := range elems[1:] { + b.WriteString(sep) + b.WriteString(s) + } + return b.String() +} diff --git a/server/cdc_impl.go b/server/cdc_impl.go index e0ff8e62..ede07bc6 100644 --- a/server/cdc_impl.go +++ b/server/cdc_impl.go @@ -401,16 +401,26 @@ func (e *MetaCDC) validCreateRequest(req *request.CreateRequest) error { connectParam.Token = GetMilvusToken(connectParam) connectParam.URI = GetMilvusURI(connectParam) - _, err := cdcwriter.NewMilvusDataHandler( - cdcwriter.URIOption(connectParam.URI), - cdcwriter.TokenOption(connectParam.Token), - cdcwriter.IgnorePartitionOption(connectParam.IgnorePartition), - cdcwriter.ConnectTimeoutOption(connectParam.ConnectTimeout), - cdcwriter.DialConfigOption(connectParam.DialConfig), - ) + var err error + if req.MilvusConnectParam.TargetDBType == "milvus" { + _, err = cdcwriter.NewMilvusDataHandler( + cdcwriter.MilvusURIOption(connectParam.URI), + cdcwriter.MilvusTokenOption(connectParam.Token), + cdcwriter.MilvusIgnorePartitionOption(connectParam.IgnorePartition), + cdcwriter.MilvusConnectTimeoutOption(connectParam.ConnectTimeout), + cdcwriter.MilvusDialConfigOption(connectParam.DialConfig), + ) + } else { + _, err = cdcwriter.NewDataHandler( + cdcwriter.URIOption(connectParam.URI), 
+ cdcwriter.TokenOption(connectParam.Token), + cdcwriter.ConnectTimeoutOption(connectParam.ConnectTimeout), + cdcwriter.DialConfigOption(connectParam.DialConfig), + ) + } if err != nil { - log.Warn("fail to connect the milvus", zap.Any("connect_param", connectParam), zap.Error(err)) - return errors.WithMessage(err, "fail to connect the milvus") + log.Warn(fmt.Sprintf("fail to connect the %s", req.MilvusConnectParam.TargetDBType), zap.Any("connect_param", connectParam), zap.Error(err)) + return errors.WithMessage(err, fmt.Sprintf("fail to connect the %s", req.MilvusConnectParam.TargetDBType)) } return nil } @@ -578,11 +588,27 @@ func (e *MetaCDC) newReplicateEntity(info *meta.TaskInfo) (*ReplicateEntity, err milvusConnectParam.Token = GetMilvusToken(milvusConnectParam) milvusConnectParam.URI = GetMilvusURI(milvusConnectParam) milvusAddress := milvusConnectParam.URI - milvusClient, err := cdcreader.NewTarget(timeoutCtx, cdcreader.TargetConfig{ - URI: milvusAddress, - Token: milvusConnectParam.Token, - DialConfig: milvusConnectParam.DialConfig, - }) + + var milvusClient api.TargetAPI + var err error + + if strings.ToLower(milvusConnectParam.TargetDBType) == "milvus" { + milvusClient, err = cdcreader.NewTarget(timeoutCtx, cdcreader.TargetConfig{ + URI: milvusAddress, + Token: milvusConnectParam.Token, + DialConfig: milvusConnectParam.DialConfig, + ProjectId: milvusConnectParam.ProjectId, + TargetDBType: milvusConnectParam.TargetDBType, + }) + } else { + milvusClient, err = cdcreader.NewTarget(timeoutCtx, cdcreader.TargetConfig{ + URI: milvusAddress, + Token: milvusConnectParam.Token, + ProjectId: milvusConnectParam.ProjectId, + TargetDBType: milvusConnectParam.TargetDBType, + }) + } + cancelFunc() if err != nil { taskLog.Warn("fail to new target", zap.String("address", milvusAddress), zap.Error(err)) @@ -630,21 +656,41 @@ func (e *MetaCDC) newReplicateEntity(info *meta.TaskInfo) (*ReplicateEntity, err return nil, servererror.NewClientError("fail to create replicate 
channel manager") } targetConfig := milvusConnectParam - dataHandler, err := cdcwriter.NewMilvusDataHandler( - cdcwriter.URIOption(targetConfig.URI), - cdcwriter.TokenOption(targetConfig.Token), - cdcwriter.IgnorePartitionOption(targetConfig.IgnorePartition), - cdcwriter.ConnectTimeoutOption(targetConfig.ConnectTimeout), - cdcwriter.DialConfigOption(targetConfig.DialConfig), - ) - if err != nil { - taskLog.Warn("fail to new the data handler", zap.Error(err)) - return nil, servererror.NewClientError("fail to new the data handler, task_id: ") + var writerObj api.Writer + + if strings.ToLower(targetConfig.TargetDBType) == "milvus" { + dataHandler, err := cdcwriter.NewMilvusDataHandler( + cdcwriter.MilvusURIOption(targetConfig.URI), + cdcwriter.MilvusTokenOption(targetConfig.Token), + cdcwriter.MilvusIgnorePartitionOption(targetConfig.IgnorePartition), + cdcwriter.MilvusConnectTimeoutOption(targetConfig.ConnectTimeout), + cdcwriter.MilvusDialConfigOption(targetConfig.DialConfig), + ) + if err != nil { + taskLog.Warn("fail to new the data handler", zap.Error(err)) + return nil, servererror.NewClientError("fail to new the data handler, task_id: ") + } + writerObj = cdcwriter.NewChannelWriter(dataHandler, config.WriterConfig{ + MessageBufferSize: bufferSize, + Retry: e.config.Retry, + }, metaOp.GetAllDroppedObj()) + } else { + dataHandler, err := cdcwriter.NewDataHandler( + cdcwriter.URIOption(targetConfig.URI), + cdcwriter.TokenOption(targetConfig.Token), + cdcwriter.ConnectTimeoutOption(targetConfig.ConnectTimeout), + cdcwriter.DialConfigOption(targetConfig.DialConfig), + ) + if err != nil { + taskLog.Warn("fail to new the data handler", zap.Error(err)) + return nil, servererror.NewClientError("fail to new the data handler, task_id: ") + } + writerObj = cdcwriter.NewChannelWriter(dataHandler, config.WriterConfig{ + MessageBufferSize: bufferSize, + Retry: e.config.Retry, + }, metaOp.GetAllDroppedObj()) } - writerObj := cdcwriter.NewChannelWriter(dataHandler, 
config.WriterConfig{ - MessageBufferSize: bufferSize, - Retry: e.config.Retry, - }, metaOp.GetAllDroppedObj()) + e.replicateEntityMap.Lock() defer e.replicateEntityMap.Unlock() entity, ok := e.replicateEntityMap.data[milvusAddress] diff --git a/server/model/common.go b/server/model/common.go index 2fd508ec..a432795b 100644 --- a/server/model/common.go +++ b/server/model/common.go @@ -22,6 +22,7 @@ import "github.com/zilliztech/milvus-cdc/core/util" //go:generate easytags $GOFILE json,mapstructure type MilvusConnectParam struct { + TargetDBType string `json:"target_db_type" mapstructure:"target_db_type"` // Deprecated: use uri instead Host string `json:"host" mapstructure:"host"` // Deprecated: use uri instead @@ -37,7 +38,8 @@ type MilvusConnectParam struct { DialConfig util.DialConfig `json:"dial_config,omitempty" mapstructure:"dial_config,omitempty"` IgnorePartition bool `json:"ignore_partition" mapstructure:"ignore_partition"` // ConnectTimeout unit: s - ConnectTimeout int `json:"connect_timeout" mapstructure:"connect_timeout"` + ConnectTimeout int `json:"connect_timeout" mapstructure:"connect_timeout"` + ProjectId string `json:"project_id" mapstructure:"project_id,omitempty"` } type CollectionInfo struct {