Skip to content

Commit

Permalink
fix: ensure create_index and load_collection are fully completed (#…
Browse files Browse the repository at this point in the history
…2476)

- Add `wait_for_creating_index`, `wait_for_loading_collection` to ensure
`create_index` and `load_collection` are fully completed

Signed-off-by: Ruichen Bao <ruichen.bao@zju.edu.cn>
  • Loading branch information
brcarry authored Dec 19, 2024
1 parent 2841bdb commit e4bb0ef
Showing 1 changed file with 109 additions and 1 deletion.
110 changes: 109 additions & 1 deletion pymilvus/client/async_grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@
import base64
import copy
import socket
import time
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
from urllib import parse

import grpc
from grpc._cython import cygrpc

from pymilvus.decorators import retry_on_rpc_failure, upgrade_reminder
from pymilvus.decorators import ignore_unimplemented, retry_on_rpc_failure, upgrade_reminder
from pymilvus.exceptions import (
AmbiguousIndexName,
DescribeCollectionException,
ExceptionsMessage,
MilvusException,
ParamError,
)
Expand All @@ -32,6 +35,7 @@
from .types import (
DataType,
ExtraList,
IndexState,
Status,
get_cost_extra,
)
Expand Down Expand Up @@ -291,6 +295,43 @@ async def load_collection(
response = await self._async_stub.LoadCollection(request, timeout=timeout)
check_status(response)

await self.wait_for_loading_collection(collection_name, timeout, is_refresh=refresh)

@retry_on_rpc_failure()
async def wait_for_loading_collection(
self, collection_name: str, timeout: Optional[float] = None, is_refresh: bool = False
):
start = time.time()

def can_loop(t: int) -> bool:
return True if timeout is None else t <= (start + timeout)

while can_loop(time.time()):
progress = await self.get_loading_progress(
collection_name, timeout=timeout, is_refresh=is_refresh
)
if progress >= 100:
return
await asyncio.sleep(Config.WaitTimeDurationWhenLoad)
raise MilvusException(
message=f"wait for loading collection timeout, collection: {collection_name}"
)

@retry_on_rpc_failure()
async def get_loading_progress(
self,
collection_name: str,
partition_names: Optional[List[str]] = None,
timeout: Optional[float] = None,
is_refresh: bool = False,
):
request = Prepare.get_loading_progress(collection_name, partition_names)
response = await self._async_stub.GetLoadingProgress(request, timeout=timeout)
check_status(response.status)
if is_refresh:
return response.refresh_progress
return response.progress

@retry_on_rpc_failure()
async def describe_collection(
self, collection_name: str, timeout: Optional[float] = None, **kwargs
Expand Down Expand Up @@ -635,8 +676,67 @@ async def create_index(
status = await self._async_stub.CreateIndex(index_param, timeout=timeout)
check_status(status)

index_success, fail_reason = await self.wait_for_creating_index(
collection_name=collection_name,
index_name=index_name,
timeout=timeout,
field_name=field_name,
)

if not index_success:
raise MilvusException(message=fail_reason)

return Status(status.code, status.reason)

@retry_on_rpc_failure()
async def wait_for_creating_index(
self, collection_name: str, index_name: str, timeout: Optional[float] = None, **kwargs
):
timestamp = await self.alloc_timestamp()
start = time.time()
while True:
await asyncio.sleep(0.5)
state, fail_reason = await self.get_index_state(
collection_name, index_name, timeout=timeout, timestamp=timestamp, **kwargs
)
if state == IndexState.Finished:
return True, fail_reason
if state == IndexState.Failed:
return False, fail_reason
end = time.time()
if isinstance(timeout, int) and end - start > timeout:
msg = (
f"collection {collection_name} create index {index_name} "
f"timeout in {timeout}s"
)
raise MilvusException(message=msg)

@retry_on_rpc_failure()
async def get_index_state(
self,
collection_name: str,
index_name: str,
timeout: Optional[float] = None,
timestamp: Optional[int] = None,
**kwargs,
):
request = Prepare.describe_index_request(collection_name, index_name, timestamp)
response = await self._async_stub.DescribeIndex(request, timeout=timeout)
status = response.status
check_status(status)

if len(response.index_descriptions) == 1:
index_desc = response.index_descriptions[0]
return index_desc.state, index_desc.index_state_fail_reason
# just for create_index.
field_name = kwargs.pop("field_name", "")
if field_name != "":
for index_desc in response.index_descriptions:
if index_desc.field_name == field_name:
return index_desc.state, index_desc.index_state_fail_reason

raise AmbiguousIndexName(message=ExceptionsMessage.AmbiguousIndexName)

@retry_on_rpc_failure()
async def get(
self,
Expand Down Expand Up @@ -705,3 +805,11 @@ async def wait_for_connect_response():

check_status(response.status)
return response.identifier

@retry_on_rpc_failure()
@ignore_unimplemented(0)
async def alloc_timestamp(self, timeout: Optional[float] = None) -> int:
request = milvus_types.AllocTimestampRequest()
response = await self._async_stub.AllocTimestamp(request, timeout=timeout)
check_status(response.status)
return response.timestamp

0 comments on commit e4bb0ef

Please sign in to comment.