diff --git a/lilac/auth.py b/lilac/auth.py index a4a0b450..8cdb9067 100644 --- a/lilac/auth.py +++ b/lilac/auth.py @@ -79,6 +79,9 @@ class AuthenticationInfo(BaseModel): def has_garden_credentials() -> bool: """Returns whether the user has Garden credentials.""" + # TODO: more granular checks based on user permissions + if env('LILAC_API_KEY') is not None: + return True config = modal.config.Config().to_dict() return ( 'token_secret' in config diff --git a/lilac/data/clustering.py b/lilac/data/clustering.py index 7530aafe..a0e3bd37 100644 --- a/lilac/data/clustering.py +++ b/lilac/data/clustering.py @@ -3,11 +3,11 @@ import itertools from typing import Callable, Iterator, Optional, Union, cast -import modal import numpy as np from tqdm import tqdm -from ..batch_utils import compress_docs, flatten_path_iter +from .. import garden_client +from ..batch_utils import flatten_path_iter from ..dataset_format import DatasetFormatInputSelector from ..embeddings.jina import JinaV2Small from ..schema import ( @@ -301,11 +301,7 @@ def _hdbscan_cluster( ) -> Iterator[Item]: """Cluster docs with HDBSCAN.""" if use_garden: - remote_fn = modal.Function.lookup('cluster', 'Cluster.cluster').remote - with DebugTimer('Compressing docs for clustering remotely'): - gzipped_docs = compress_docs(list(docs)) - response = remote_fn({'gzipped_docs': gzipped_docs, 'min_cluster_size': min_cluster_size}) - yield from response['clusters'] + yield from garden_client.cluster(list(docs), min_cluster_size=min_cluster_size) if task_info: task_info.message = 'Computing embeddings' diff --git a/lilac/env.py b/lilac/env.py index 841684f3..8434fdd0 100644 --- a/lilac/env.py +++ b/lilac/env.py @@ -56,6 +56,9 @@ class LilacEnvironment(BaseModel): COHERE_API_KEY: str = PydanticField( description='The Cohere API key, used for computing `cohere` embeddings.' ) + LILAC_API_KEY: str = PydanticField( + description='The Lilac API key, used for running Lilac Garden computations.' + ) # HuggingFace demo. HF_ACCESS_TOKEN: str = PydanticField( diff --git a/lilac/garden_client.py b/lilac/garden_client.py new file mode 100644 index 00000000..9188c6cf --- /dev/null +++ b/lilac/garden_client.py @@ -0,0 +1,63 @@ +"""Client code for sending requests to Lilac Garden.""" +import base64 +import functools +import io +import json +from typing import Any, Callable, Iterator + +import numpy as np +import requests +from pydantic import BaseModel + +from .env import env +from .schema import Item +from .utils import DebugTimer + +GARDEN_FRONT_GATE_URL = 'https://lilacai--front-gate-fastapi-app.modal.run' +GARDEN_ENCODING_SCHEME_HEADER = 'X-Lilac-EncodingScheme' +GARDEN_DOC_COUNT_HEADER = 'X-Lilac-DocCount' +PLATFORM_LIST_ENDPOINTS = 'https://platform-v3omj4i5vq-uc.a.run.app/api/v1/entitlements/list' + + +def _decode_b64_npy(b: bytes) -> np.ndarray: + return np.load(io.BytesIO(base64.b64decode(b))) + + +DECODERS: dict[str, Callable[[bytes], Item]] = {'b64-npy': _decode_b64_npy, 'json': json.loads} + + +def _call_garden(endpoint_name: str, docs: list[Any], **kwargs: Any) -> Iterator[Item]: + lilac_api_key = env('LILAC_API_KEY') + + with DebugTimer('Running garden endpoint %s' % endpoint_name): + with requests.post( + GARDEN_FRONT_GATE_URL + '/' + endpoint_name, + data=json.dumps(docs), + params={k: str(v) for k, v in kwargs.items()}, + headers={ + 'Authorization': f'Bearer {lilac_api_key}', + GARDEN_DOC_COUNT_HEADER: str(len(docs)), + }, + stream=True, + ) as response: + if response.status_code > 299: + raise requests.HTTPError(response.text) + decoder = DECODERS[response.headers[GARDEN_ENCODING_SCHEME_HEADER]] + for line in response.iter_lines(): + yield decoder(line) + + +cluster = functools.partial(_call_garden, 'clustering') +pii = functools.partial(_call_garden, 'pii') + + +class GardenEndpoint(BaseModel): + entitlement_name: str + limitations: dict[str, Any] + + +def list_endpoints() -> list[GardenEndpoint]: + """List the available endpoints.""" + # TODO: put api key in Auth header instead of json body. + lilac_api_key = env('LILAC_API_KEY') + return requests.post(PLATFORM_LIST_ENDPOINTS, json={'api_key': lilac_api_key}).json() diff --git a/lilac/router_garden.py b/lilac/router_garden.py new file mode 100644 index 00000000..41062956 --- /dev/null +++ b/lilac/router_garden.py @@ -0,0 +1,14 @@ +"""Endpoints for Lilac Garden.""" + +from fastapi import APIRouter + +from .garden_client import GardenEndpoint, list_endpoints +from .router_utils import RouteErrorHandler + +router = APIRouter(route_class=RouteErrorHandler) + + +@router.get('/') +def get_available_endpoints() -> list[GardenEndpoint]: + """Get the list of available sources.""" + return list_endpoints() diff --git a/lilac/server.py b/lilac/server.py index 9bf7019c..6250ee8d 100644 --- a/lilac/server.py +++ b/lilac/server.py @@ -32,6 +32,7 @@ router_data_loader, router_dataset, router_dataset_signals, + router_garden, router_google_login, router_rag, router_signal, @@ -141,6 +142,7 @@ def module_not_found_error(request: Request, exc: ModuleNotFoundError) -> JSONRe v1_router.include_router(router_signal.router, prefix='/signals', tags=['signals']) v1_router.include_router(router_tasks.router, prefix='/tasks', tags=['tasks']) v1_router.include_router(router_rag.router, prefix='/rag', tags=['rag']) +v1_router.include_router(router_garden.router, prefix='/garden', tags=['garden']) for source_name, source in registered_sources().items(): if source.router: diff --git a/lilac/signals/pii.py b/lilac/signals/pii.py index 12a2b4c3..1eb97459 100644 --- a/lilac/signals/pii.py +++ b/lilac/signals/pii.py @@ -1,14 +1,13 @@ """Compute text statistics for a document.""" from typing import ClassVar, Iterator, Optional -import modal from typing_extensions import override -from ..batch_utils import compress_docs +from .. import garden_client from ..schema import Field, Item, RichData, SignalInputType, field from ..signal import TextSignal from ..tasks import TaskExecutionType -from ..utils import DebugTimer, chunks +from ..utils import DebugTimer SECRETS_KEY = 'secrets' # Selected categories. For all categories, see: @@ -72,10 +71,5 @@ def compute(self, data: list[RichData]) -> list[Optional[Item]]: @override def compute_garden(self, docs: Iterator[str]) -> Iterator[Item]: - pii = modal.Function.lookup('pii', 'PII.detect') with DebugTimer('Computing PII on Lilac Garden'): - batches = chunks(docs, PII_REMOTE_BATCH_SIZE) - requests = ({'gzipped_docs': compress_docs(b)} for b in batches) - for response in pii.map(requests, order_outputs=True): - for item in response['result']: - yield item + yield from garden_client.pii(list(docs)) diff --git a/web/blueprint/src/lib/components/ComputeClusterModal.svelte b/web/blueprint/src/lib/components/ComputeClusterModal.svelte index 1380a7af..83f18b89 100644 --- a/web/blueprint/src/lib/components/ComputeClusterModal.svelte +++ b/web/blueprint/src/lib/components/ComputeClusterModal.svelte @@ -96,10 +96,10 @@ options.namespace, options.datasetName, { - input: selectedFormatSelector == null ? options.input : null, + input: selectedFormatSelector === 'none' ? options.input : null, use_garden: options.use_garden, output_path: outputColumn, - input_selector: selectedFormatSelector, + input_selector: selectedFormatSelector === 'none' ? null : selectedFormatSelector, overwrite: options.overwrite } ]); diff --git a/web/lib/fastapi_client/index.ts b/web/lib/fastapi_client/index.ts index 50ed213d..4f6ebbe2 100644 --- a/web/lib/fastapi_client/index.ts +++ b/web/lib/fastapi_client/index.ts @@ -47,6 +47,7 @@ export type { ExampleIn } from './models/ExampleIn'; export type { ExampleOrigin } from './models/ExampleOrigin'; export type { ExportOptions } from './models/ExportOptions'; export type { Field } from './models/Field'; +export type { GardenEndpoint } from './models/GardenEndpoint'; export type { GetStatsOptions } from './models/GetStatsOptions'; export type { GroupsSortBy } from './models/GroupsSortBy'; export type { HTTPValidationError } from './models/HTTPValidationError'; @@ -113,6 +114,7 @@ export { ConceptsService } from './services/ConceptsService'; export { DataLoadersService } from './services/DataLoadersService'; export { DatasetsService } from './services/DatasetsService'; export { DefaultService } from './services/DefaultService'; +export { GardenService } from './services/GardenService'; export { GoogleLoginService } from './services/GoogleLoginService'; export { LangsmithService } from './services/LangsmithService'; export { RagService } from './services/RagService'; diff --git a/web/lib/fastapi_client/models/GardenEndpoint.ts b/web/lib/fastapi_client/models/GardenEndpoint.ts new file mode 100644 index 00000000..e179ceb8 --- /dev/null +++ b/web/lib/fastapi_client/models/GardenEndpoint.ts @@ -0,0 +1,10 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +export type GardenEndpoint = { + entitlement_name: string; + limitations: Record; +}; + diff --git a/web/lib/fastapi_client/services/GardenService.ts b/web/lib/fastapi_client/services/GardenService.ts new file mode 100644 index 00000000..6ad706e4 --- /dev/null +++ b/web/lib/fastapi_client/services/GardenService.ts @@ -0,0 +1,26 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +import type { GardenEndpoint } from '../models/GardenEndpoint'; + +import type { CancelablePromise } from '../core/CancelablePromise'; +import { OpenAPI } from '../core/OpenAPI'; +import { request as __request } from '../core/request'; + +export class GardenService { + + /** + * Get Available Endpoints + * Get the list of available sources. + * @returns GardenEndpoint Successful Response + * @throws ApiError + */ + public static getAvailableEndpoints(): CancelablePromise> { + return __request(OpenAPI, { + method: 'GET', + url: '/api/v1/garden/', + }); + } + +}