Skip to content

Commit

Permalink
fix several bugs and add new parameter to support partition key. (#792)
Browse files Browse the repository at this point in the history
* bug fix: check the fieldName and indexName matches the describeIndex request

Signed-off-by: Nian Liu <nian.liu@zilliz.com>

* add partitionKey support for collection create, add ShardsNum, Array, etc.

Signed-off-by: Nian Liu <nian.liu@zilliz.com>

* Feature sync: update createCollection to add several new params; add AddFieldReq, update QueryResp and SearchResp

Signed-off-by: Nian Liu <nian.liu@zilliz.com>

---------

Signed-off-by: Nian Liu <nian.liu@zilliz.com>
(cherry picked from commit af62656)
  • Loading branch information
nianliuu committed Mar 8, 2024
1 parent 93fb99e commit c9eecfc
Show file tree
Hide file tree
Showing 21 changed files with 176 additions and 125 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,5 @@ volumes/
examples/main/java/io/milvus/tls/*
!examples/main/java/io/milvus/tls/gen.sh
!examples/main/java/io/milvus/tls/openssl.cnf
/src/main/java/io/milvus/v2/examples/*
/src/main/java/io/milvus/v2/examples/ManageCollectionDemo.java
6 changes: 2 additions & 4 deletions src/main/java/io/milvus/v2/client/MilvusClientV2.java
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,10 @@ public void createCollection(CreateCollectionReq request) {

/**
* Creates a collection schema.
* @param enableDynamicField enable dynamic field
* @param description collection description
* @return CreateCollectionReq.CollectionSchema
*/
public CreateCollectionReq.CollectionSchema createSchema(Boolean enableDynamicField, String description) {
return collectionService.createSchema(enableDynamicField, description);
public CreateCollectionReq.CollectionSchema createSchema() {
return collectionService.createSchema();
}

/**
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/io/milvus/v2/examples/Simple.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.TimeUnit;

public class Simple {
Integer dim = 2;
Expand All @@ -36,19 +35,20 @@ public static void main(String[] args) {

public void run() throws InterruptedException {
ConnectConfig connectConfig = ConnectConfig.builder()
.uri("https://in01-******.aws-us-west-2.vectordb.zillizcloud.com:19531")
.token("*****")
.uri("https://in01-***.aws-us-west-2.vectordb.zillizcloud.com:19531")
.token("***")
.build();
MilvusClientV2 client = new MilvusClientV2(connectConfig);
// check collection exists
if (client.hasCollection(HasCollectionReq.builder().collectionName(collectionName).build())) {
logger.info("collection exists");
client.dropCollection(DropCollectionReq.builder().collectionName(collectionName).build());
TimeUnit.SECONDS.sleep(1);
logger.info("collection dropped");
}
// create collection
CreateCollectionReq createCollectionReq = CreateCollectionReq.builder()
.collectionName(collectionName)
.description("simple collection")
.dimension(dim)
.build();
client.createCollection(createCollectionReq);
Expand Down
29 changes: 17 additions & 12 deletions src/main/java/io/milvus/v2/examples/Simple_Schema.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
import io.milvus.v2.client.MilvusClientV2;
import io.milvus.v2.common.DataType;
import io.milvus.v2.common.IndexParam;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import io.milvus.v2.service.collection.request.DropCollectionReq;
import io.milvus.v2.service.collection.request.HasCollectionReq;
import io.milvus.v2.service.collection.request.LoadCollectionReq;
import io.milvus.v2.service.collection.request.*;
import io.milvus.v2.service.index.request.CreateIndexReq;
import io.milvus.v2.service.vector.request.InsertReq;
import io.milvus.v2.service.vector.request.QueryReq;
Expand All @@ -30,25 +27,27 @@ public class Simple_Schema {
static Logger logger = LoggerFactory.getLogger(Simple_Schema.class);
public void run() throws InterruptedException {
ConnectConfig connectConfig = ConnectConfig.builder()
.uri("https://in01-****.aws-us-west-2.vectordb.zillizcloud.com:19531")
.token("******")
.uri("https://in01-***.aws-us-west-2.vectordb.zillizcloud.com:19531")
.token("***")
.build();
MilvusClientV2 client = new MilvusClientV2(connectConfig);
// check collection exists
if (client.hasCollection(HasCollectionReq.builder().collectionName(collectionName).build())) {
logger.info("collection exists");
client.dropCollection(DropCollectionReq.builder().collectionName(collectionName).build());
TimeUnit.SECONDS.sleep(1);
}
// create collection
CreateCollectionReq.CollectionSchema collectionSchema = client.createSchema(Boolean.TRUE, "");
collectionSchema.addPrimaryField("id", DataType.Int64, Boolean.TRUE, Boolean.FALSE);
collectionSchema.addVectorField("vector", DataType.FloatVector, dim);
collectionSchema.addScalarField("num", DataType.Int32);
CreateCollectionReq.CollectionSchema collectionSchema = client.createSchema();
collectionSchema.addField(AddFieldReq.builder().fieldName("id").dataType(DataType.Int64).isPrimaryKey(Boolean.TRUE).autoID(Boolean.FALSE).description("id").build());
collectionSchema.addField(AddFieldReq.builder().fieldName("vector").dataType(DataType.FloatVector).dimension(dim).build());
collectionSchema.addField(AddFieldReq.builder().fieldName("num").dataType(DataType.Int64).isPartitionKey(Boolean.TRUE).build());
collectionSchema.addField(AddFieldReq.builder().fieldName("array").dataType(DataType.Array).elementType(DataType.Int32).maxCapacity(10).description("array").build());

CreateCollectionReq createCollectionReq = CreateCollectionReq.builder()
.collectionName(collectionName)
.description("simple collection")
.collectionSchema(collectionSchema)
.enableDynamicField(Boolean.FALSE)
.build();
client.createCollection(createCollectionReq);
//create index
Expand All @@ -61,6 +60,7 @@ public void run() throws InterruptedException {
.indexParams(Collections.singletonList(indexParam))
.build();
client.createIndex(createIndexReq);
TimeUnit.SECONDS.sleep(1);
client.loadCollection(LoadCollectionReq.builder().collectionName(collectionName).build());
//insert data
List<JSONObject> insertData = new ArrayList<>();
Expand All @@ -71,9 +71,12 @@ public void run() throws InterruptedException {
// generate random float vector
vectorList.add(new Random().nextFloat());
}
List<Integer> array = new ArrayList<>();
array.add(i);
jsonObject.put("id", (long) i);
jsonObject.put("vector", vectorList);
jsonObject.put("num", i);
jsonObject.put("num", (long) i);
jsonObject.put("array", array);
insertData.add(jsonObject);
}

Expand All @@ -88,11 +91,13 @@ public void run() throws InterruptedException {
.filter("id in [0]")
.build();
QueryResp queryResp = client.query(queryReq);
queryResp.getQueryResults().get(0).getEntity().get("vector");
System.out.println(queryResp);
//search data
SearchReq searchReq = SearchReq.builder()
.collectionName(collectionName)
.data(Collections.singletonList(insertData.get(0).get("vector")))
.outputFields(Collections.singletonList("vector"))
.topK(10)
.build();
SearchResp searchResp = client.search(searchReq);
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/io/milvus/v2/exception/ErrorCode.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
public enum ErrorCode {
SUCCESS(0),
COLLECTION_NOT_FOUND(1),
SERVER_ERROR(2);
SERVER_ERROR(2),
INVALID_PARAMS(3);

private final int code;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public void createCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockin

CollectionSchema schema = CollectionSchema.newBuilder()
.setName(request.getCollectionName())
.setDescription(request.getDescription())
.addFields(vectorSchema)
.addFields(idSchema)
.setEnableDynamicField(request.getEnableDynamicField())
Expand All @@ -53,6 +54,7 @@ public void createCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockin
CreateCollectionRequest createCollectionRequest = CreateCollectionRequest.newBuilder()
.setCollectionName(request.getCollectionName())
.setSchema(schema.toByteString())
.setShardsNum(request.getNumShards())
.build();

Status status = blockingStub.createCollection(createCollectionRequest);
Expand All @@ -69,7 +71,12 @@ public void createCollection(MilvusServiceGrpc.MilvusServiceBlockingStub blockin
.build();
indexService.createIndex(blockingStub, createIndexReq);
//load collection
loadCollection(blockingStub, LoadCollectionReq.builder().collectionName(request.getCollectionName()).build());
try {
//TimeUnit.MILLISECONDS.sleep(1000);
loadCollection(blockingStub, LoadCollectionReq.builder().collectionName(request.getCollectionName()).build());
} catch (Exception e) {
throw new MilvusClientException(ErrorCode.SERVER_ERROR, "Load collection failed" + e.getMessage());
}
}

public void createCollectionWithSchema(MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub, CreateCollectionReq request) {
Expand All @@ -78,8 +85,8 @@ public void createCollectionWithSchema(MilvusServiceGrpc.MilvusServiceBlockingSt
//convert CollectionSchema to io.milvus.grpc.CollectionSchema
CollectionSchema grpcSchema = CollectionSchema.newBuilder()
.setName(request.getCollectionName())
.setDescription(request.getCollectionSchema().getDescription())
.setEnableDynamicField(request.getCollectionSchema().getEnableDynamicField())
.setDescription(request.getDescription())
.setEnableDynamicField(request.getEnableDynamicField())
.build();
for (CreateCollectionReq.FieldSchema fieldSchema : request.getCollectionSchema().getFieldSchemaList()) {
grpcSchema = grpcSchema.toBuilder().addFields(SchemaUtils.convertToGrpcFieldSchema(fieldSchema)).build();
Expand All @@ -89,8 +96,11 @@ public void createCollectionWithSchema(MilvusServiceGrpc.MilvusServiceBlockingSt
CreateCollectionRequest createCollectionRequest = CreateCollectionRequest.newBuilder()
.setCollectionName(request.getCollectionName())
.setSchema(grpcSchema.toByteString())
.setShardsNum(request.getNumShards())
.build();

if (request.getNumPartitions() != null) {
createCollectionRequest = createCollectionRequest.toBuilder().setNumPartitions(request.getNumPartitions()).build();
}
Status createCollectionResponse = blockingStub.createCollection(createCollectionRequest);
rpcUtils.handleResponse(title, createCollectionResponse);

Expand Down Expand Up @@ -181,7 +191,7 @@ public void loadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub milvusSer
String title = String.format("LoadCollectionRequest collectionName:%s", request.getCollectionName());
LoadCollectionRequest loadCollectionRequest = LoadCollectionRequest.newBuilder()
.setCollectionName(request.getCollectionName())
.setReplicaNumber(request.getReplicaNum())
.setReplicaNumber(request.getNumReplicas())
.build();
Status status = milvusServiceBlockingStub.loadCollection(loadCollectionRequest);
rpcUtils.handleResponse(title, status);
Expand Down Expand Up @@ -230,10 +240,8 @@ public GetCollectionStatsResp getCollectionStats(MilvusServiceGrpc.MilvusService
return getCollectionStatsResp;
}

public CreateCollectionReq.CollectionSchema createSchema(Boolean enableDynamicField, String description) {
public CreateCollectionReq.CollectionSchema createSchema() {
return CreateCollectionReq.CollectionSchema.builder()
.enableDynamicField(enableDynamicField)
.description(description)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.milvus.v2.service.collection.request;

import io.milvus.v2.common.DataType;
import lombok.Builder;
import lombok.Data;
import lombok.experimental.SuperBuilder;

@Data
@SuperBuilder
public class AddFieldReq {
private String fieldName;
@Builder.Default
private String description = "";
private DataType dataType;
@Builder.Default
private Integer maxLength = 65535;
@Builder.Default
private Boolean isPrimaryKey = Boolean.FALSE;
@Builder.Default
private Boolean isPartitionKey = Boolean.FALSE;
@Builder.Default
private Boolean autoID = Boolean.FALSE;
private Integer dimension;
private io.milvus.v2.common.DataType elementType;
private Integer maxCapacity;
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import io.milvus.v2.common.DataType;
import io.milvus.v2.common.IndexParam;
import io.milvus.v2.exception.ErrorCode;
import io.milvus.v2.exception.MilvusClientException;
import lombok.Builder;
import lombok.Data;
import lombok.NonNull;
Expand All @@ -15,6 +17,8 @@
public class CreateCollectionReq {
@NonNull
private String collectionName;
@Builder.Default
private String description = "";
private Integer dimension;

@Builder.Default
Expand All @@ -29,23 +33,52 @@ public class CreateCollectionReq {
private String metricType = IndexParam.MetricType.IP.name();
@Builder.Default
private Boolean autoID = Boolean.FALSE;

// used by quickly create collections and create collections with schema
@Builder.Default
private Boolean enableDynamicField = Boolean.TRUE;
@Builder.Default
private Integer numShards = 1;

// create collections with schema
private CollectionSchema collectionSchema;

private List<IndexParam> indexParams;

//private String partitionKeyField;
private Integer numPartitions;

@Data
@SuperBuilder
public static class CollectionSchema {
@Builder.Default
private List<CreateCollectionReq.FieldSchema> fieldSchemaList = new ArrayList<>();
@Builder.Default
private String description = "";
@NonNull
private Boolean enableDynamicField;

public void addField(AddFieldReq addFieldReq) {
CreateCollectionReq.FieldSchema fieldSchema = FieldSchema.builder()
.name(addFieldReq.getFieldName())
.dataType(addFieldReq.getDataType())
.description(addFieldReq.getDescription())
.isPrimaryKey(addFieldReq.getIsPrimaryKey())
.isPartitionKey(addFieldReq.getIsPartitionKey())
.autoID(addFieldReq.getAutoID())
.build();
if (addFieldReq.getDataType().equals(DataType.Array)) {
if (addFieldReq.getElementType() == null) {
throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Element type, maxCapacity are required for array field");
}
fieldSchema.setElementType(addFieldReq.getElementType());
fieldSchema.setMaxCapacity(addFieldReq.getMaxCapacity());
} else if (addFieldReq.getDataType().equals(DataType.VarChar)) {
fieldSchema.setMaxLength(addFieldReq.getMaxLength());
} else if (addFieldReq.getDataType().equals(DataType.FloatVector) || addFieldReq.getDataType().equals(DataType.BinaryVector)) {
if (addFieldReq.getDimension() == null) {
throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Dimension is required for vector field");
}
fieldSchema.setDimension(addFieldReq.getDimension());
}
fieldSchemaList.add(fieldSchema);
}

public CreateCollectionReq.FieldSchema getField(String fieldName) {
for (CreateCollectionReq.FieldSchema field : fieldSchemaList) {
Expand All @@ -55,72 +88,23 @@ public CreateCollectionReq.FieldSchema getField(String fieldName) {
}
return null;
}

public void addPrimaryField(String fieldName, DataType dataType, Boolean isPrimaryKey, Boolean autoID) {
// primary key field
CreateCollectionReq.FieldSchema fieldSchema = CreateCollectionReq.FieldSchema.builder()
.name(fieldName)
.dataType(dataType)
.isPrimaryKey(isPrimaryKey)
.autoID(autoID)
.build();
fieldSchemaList.add(fieldSchema);
}

public void addPrimaryField(String fieldName, DataType dataType, Integer maxLength, Boolean isPrimaryKey, Boolean autoID) {
// primary key field
CreateCollectionReq.FieldSchema fieldSchema = CreateCollectionReq.FieldSchema.builder()
.name(fieldName)
.dataType(dataType)
.maxLength(maxLength)
.isPrimaryKey(isPrimaryKey)
.autoID(autoID)
.build();
fieldSchemaList.add(fieldSchema);
}

public void addVectorField(String fieldName, DataType dataType, Integer dimension) {
// vector field
CreateCollectionReq.FieldSchema fieldSchema = CreateCollectionReq.FieldSchema.builder()
.name(fieldName)
.dataType(dataType)
.dimension(dimension)
.build();
fieldSchemaList.add(fieldSchema);
}

public void addScalarField(String fieldName, DataType dataType, Integer maxLength) {
// scalar field
CreateCollectionReq.FieldSchema fieldSchema = CreateCollectionReq.FieldSchema.builder()
.name(fieldName)
.dataType(dataType)
.maxLength(maxLength)
.build();
fieldSchemaList.add(fieldSchema);
}

public void addScalarField(String fieldName, DataType dataType) {
// scalar field
CreateCollectionReq.FieldSchema fieldSchema = CreateCollectionReq.FieldSchema.builder()
.name(fieldName)
.dataType(dataType)
.build();
fieldSchemaList.add(fieldSchema);
}
}

@Data
@SuperBuilder
public static class FieldSchema {
//TODO: check here
private String name;
@Builder.Default
private String description = "";
private DataType dataType;
@Builder.Default
private Integer maxLength = 65535;
private Integer dimension;
@Builder.Default
private Boolean isPrimaryKey = Boolean.FALSE;
@Builder.Default
private Boolean isPartitionKey = Boolean.FALSE;
@Builder.Default
private Boolean autoID = Boolean.FALSE;
private DataType elementType;
private Integer maxCapacity;
Expand Down
Loading

0 comments on commit c9eecfc

Please sign in to comment.