Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add garden client and replace cluster implementation #1186

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions lilac/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ class AuthenticationInfo(BaseModel):

def has_garden_credentials() -> bool:
"""Returns whether the user has Garden credentials."""
# TODO: more granular checks based on user permissions
if env('LILAC_API_KEY') is not None:
return True
config = modal.config.Config().to_dict()
return (
'token_secret' in config
Expand Down
10 changes: 3 additions & 7 deletions lilac/data/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import itertools
from typing import Callable, Iterator, Optional, Union, cast

import modal
import numpy as np
from tqdm import tqdm

from ..batch_utils import compress_docs, flatten_path_iter
from .. import garden_client
from ..batch_utils import flatten_path_iter
from ..dataset_format import DatasetFormatInputSelector
from ..embeddings.jina import JinaV2Small
from ..schema import (
Expand Down Expand Up @@ -311,11 +311,7 @@ def _hdbscan_cluster(
) -> Iterator[Item]:
"""Cluster docs with HDBSCAN."""
if use_garden:
remote_fn = modal.Function.lookup('cluster', 'Cluster.cluster').remote
with DebugTimer('Compressing docs for clustering remotely'):
gzipped_docs = compress_docs(list(docs))
response = remote_fn({'gzipped_docs': gzipped_docs, 'min_cluster_size': min_cluster_size})
yield from response['clusters']
yield from garden_client.cluster(list(docs), min_cluster_size=min_cluster_size)

if task_info:
task_info.message = 'Computing embeddings'
Expand Down
3 changes: 3 additions & 0 deletions lilac/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ class LilacEnvironment(BaseModel):
COHERE_API_KEY: str = PydanticField(
description='The Cohere API key, used for computing `cohere` embeddings.'
)
LILAC_API_KEY: str = PydanticField(
description='The Lilac API key, used for running Lilac Garden computations.'
)

# HuggingFace demo.
HF_ACCESS_TOKEN: str = PydanticField(
Expand Down
47 changes: 47 additions & 0 deletions lilac/garden_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Client code for sending requests to Lilac Garden."""
import base64
import functools
import io
import json
from typing import Any, Callable, Iterator

import numpy as np
import requests

from .env import env
from .schema import Item
from .utils import DebugTimer

GARDEN_FRONT_GATE_URL = 'https://lilacai--front-gate-fastapi-app-dev.modal.run'
GARDEN_ENCODING_SCHEME_HEADER = 'X-Lilac-EncodingScheme'


def _decode_b64_npy(b: bytes) -> np.ndarray:
return np.load(io.BytesIO(base64.b64decode(b)))


DECODERS: dict[str, Callable[[bytes], Item]] = {'b64-npy': _decode_b64_npy, 'json': json.loads}


def _call_garden(endpoint_name: str, docs: list[Any], **kwargs: Any) -> Iterator[Item]:
lilac_api_key = env('LILAC_API_KEY')

with DebugTimer('Running garden endpoint %s' % endpoint_name):
with requests.post(
GARDEN_FRONT_GATE_URL + '/' + endpoint_name,
data=json.dumps(docs),
params={k: str(v) for k, v in kwargs.items()},
headers={
'Authorization': 'Bearer %s' % lilac_api_key,
brilee marked this conversation as resolved.
Show resolved Hide resolved
'X-Lilac-RowCount': str(len(docs)),
brilee marked this conversation as resolved.
Show resolved Hide resolved
},
stream=True,
) as response:
if response.status_code > 299:
raise requests.HTTPError(response.text)
decoder = DECODERS[response.headers[GARDEN_ENCODING_SCHEME_HEADER]]
for line in response.iter_lines():
yield decoder(line)


cluster = functools.partial(_call_garden, 'cluster')
Loading