diff --git a/.github/workflows/docker-build-push-cloud-web-container-on-tag.yml b/.github/workflows/docker-build-push-cloud-web-container-on-tag.yml index c85735284c5..a6ff2f51c5a 100644 --- a/.github/workflows/docker-build-push-cloud-web-container-on-tag.yml +++ b/.github/workflows/docker-build-push-cloud-web-container-on-tag.yml @@ -65,6 +65,7 @@ jobs: NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }} NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }} NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }} + NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }} NEXT_PUBLIC_GTM_ENABLED=true NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true diff --git a/backend/ee/onyx/configs/app_configs.py b/backend/ee/onyx/configs/app_configs.py index e34f87c315f..b567db38acd 100644 --- a/backend/ee/onyx/configs/app_configs.py +++ b/backend/ee/onyx/configs/app_configs.py @@ -77,3 +77,5 @@ HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL") ANONYMOUS_USER_COOKIE_NAME = "onyx_anonymous_user" + +GATED_TENANTS_KEY = "gated_tenants" diff --git a/backend/ee/onyx/server/query_and_chat/query_backend.py b/backend/ee/onyx/server/query_and_chat/query_backend.py index 34fc9dbaf3f..2910ac3a76a 100644 --- a/backend/ee/onyx/server/query_and_chat/query_backend.py +++ b/backend/ee/onyx/server/query_and_chat/query_backend.py @@ -83,6 +83,7 @@ def handle_search_request( user=user, llm=llm, fast_llm=fast_llm, + skip_query_analysis=False, db_session=db_session, bypass_acl=False, ) diff --git a/backend/ee/onyx/server/tenants/api.py b/backend/ee/onyx/server/tenants/api.py index ed0e26d768c..95e4cc1b8cd 100644 --- a/backend/ee/onyx/server/tenants/api.py +++ b/backend/ee/onyx/server/tenants/api.py @@ -18,11 +18,16 @@ from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path from ee.onyx.server.tenants.billing import 
fetch_billing_information +from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information from ee.onyx.server.tenants.models import AnonymousUserPath from ee.onyx.server.tenants.models import BillingInformation from ee.onyx.server.tenants.models import ImpersonateRequest from ee.onyx.server.tenants.models import ProductGatingRequest +from ee.onyx.server.tenants.models import ProductGatingResponse +from ee.onyx.server.tenants.models import SubscriptionSessionResponse +from ee.onyx.server.tenants.models import SubscriptionStatusResponse +from ee.onyx.server.tenants.product_gating import store_product_gating from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant @@ -39,12 +44,9 @@ from onyx.db.engine import get_current_tenant_id from onyx.db.engine import get_session from onyx.db.engine import get_session_with_tenant -from onyx.db.notification import create_notification from onyx.db.users import delete_user_from_db from onyx.db.users import get_user_by_email from onyx.server.manage.models import UserByEmail -from onyx.server.settings.store import load_settings -from onyx.server.settings.store import store_settings from onyx.utils.logger import setup_logger from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR @@ -126,37 +128,29 @@ async def login_as_anonymous_user( @router.post("/product-gating") def gate_product( product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep) -) -> None: +) -> ProductGatingResponse: """ Gating the product means that the product is not available to the tenant. They will be directed to the billing page. 
- We gate the product when - 1) User has ended free trial without adding payment method - 2) User's card has declined + We gate the product when their subscription has ended. """ - tenant_id = product_gating_request.tenant_id - token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) - - settings = load_settings() - settings.product_gating = product_gating_request.product_gating - store_settings(settings) - - if product_gating_request.notification: - with get_session_with_tenant(tenant_id) as db_session: - create_notification(None, product_gating_request.notification, db_session) + try: + store_product_gating( + product_gating_request.tenant_id, product_gating_request.application_status + ) + return ProductGatingResponse(updated=True, error=None) - if token is not None: - CURRENT_TENANT_ID_CONTEXTVAR.reset(token) + except Exception as e: + logger.exception("Failed to gate product") + return ProductGatingResponse(updated=False, error=str(e)) -@router.get("/billing-information", response_model=BillingInformation) +@router.get("/billing-information") async def billing_information( _: User = Depends(current_admin_user), -) -> BillingInformation: +) -> BillingInformation | SubscriptionStatusResponse: logger.info("Fetching billing information") - return BillingInformation( - **fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get()) - ) + return fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get()) @router.post("/create-customer-portal-session") @@ -169,9 +163,10 @@ async def create_customer_portal_session(_: User = Depends(current_admin_user)) if not stripe_customer_id: raise HTTPException(status_code=400, detail="Stripe customer ID not found") logger.info(stripe_customer_id) + portal_session = stripe.billing_portal.Session.create( customer=stripe_customer_id, - return_url=f"{WEB_DOMAIN}/admin/cloud-settings", + return_url=f"{WEB_DOMAIN}/admin/billing", ) logger.info(portal_session) return {"url": portal_session.url} @@ -180,6 +175,20 @@ async def 
create_customer_portal_session(_: User = Depends(current_admin_user)) raise HTTPException(status_code=500, detail=str(e)) +@router.post("/create-subscription-session") +async def create_subscription_session( + _: User = Depends(current_admin_user), +) -> SubscriptionSessionResponse: + try: + tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() + session_id = fetch_stripe_checkout_session(tenant_id) + return SubscriptionSessionResponse(sessionId=session_id) + + except Exception as e: + logger.exception("Failed to create resubscription session") + raise HTTPException(status_code=500, detail=str(e)) + + @router.post("/impersonate") async def impersonate_user( impersonate_request: ImpersonateRequest, diff --git a/backend/ee/onyx/server/tenants/billing.py b/backend/ee/onyx/server/tenants/billing.py index e8d0e39ab60..98de75a9aef 100644 --- a/backend/ee/onyx/server/tenants/billing.py +++ b/backend/ee/onyx/server/tenants/billing.py @@ -6,6 +6,7 @@ from ee.onyx.configs.app_configs import STRIPE_PRICE_ID from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY from ee.onyx.server.tenants.access import generate_data_plane_token +from ee.onyx.server.tenants.models import BillingInformation from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL from onyx.utils.logger import setup_logger @@ -14,6 +15,19 @@ logger = setup_logger() +def fetch_stripe_checkout_session(tenant_id: str) -> str: + token = generate_data_plane_token() + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + } + url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session" + params = {"tenant_id": tenant_id} + response = requests.post(url, headers=headers, params=params) + response.raise_for_status() + return response.json()["sessionId"] + + def fetch_tenant_stripe_information(tenant_id: str) -> dict: token = generate_data_plane_token() headers = { @@ -27,7 +41,7 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict: return response.json() -def 
fetch_billing_information(tenant_id: str) -> dict: +def fetch_billing_information(tenant_id: str) -> BillingInformation: logger.info("Fetching billing information") token = generate_data_plane_token() headers = { @@ -38,7 +52,7 @@ def fetch_billing_information(tenant_id: str) -> dict: params = {"tenant_id": tenant_id} response = requests.get(url, headers=headers, params=params) response.raise_for_status() - billing_info = response.json() + billing_info = BillingInformation(**response.json()) return billing_info diff --git a/backend/ee/onyx/server/tenants/models.py b/backend/ee/onyx/server/tenants/models.py index 6be53d7dae9..7931a06a7a0 100644 --- a/backend/ee/onyx/server/tenants/models.py +++ b/backend/ee/onyx/server/tenants/models.py @@ -1,7 +1,8 @@ +from datetime import datetime + from pydantic import BaseModel -from onyx.configs.constants import NotificationType -from onyx.server.settings.models import GatingType +from onyx.server.settings.models import ApplicationStatus class CheckoutSessionCreationRequest(BaseModel): @@ -15,15 +16,24 @@ class CreateTenantRequest(BaseModel): class ProductGatingRequest(BaseModel): tenant_id: str - product_gating: GatingType - notification: NotificationType | None = None + application_status: ApplicationStatus + + +class SubscriptionStatusResponse(BaseModel): + subscribed: bool class BillingInformation(BaseModel): + stripe_subscription_id: str + status: str + current_period_start: datetime + current_period_end: datetime + number_of_seats: int + cancel_at_period_end: bool + canceled_at: datetime | None + trial_start: datetime | None + trial_end: datetime | None seats: int - subscription_status: str - billing_start: str - billing_end: str payment_method_enabled: bool @@ -48,3 +58,12 @@ class TenantDeletionPayload(BaseModel): class AnonymousUserPath(BaseModel): anonymous_user_path: str | None + + +class ProductGatingResponse(BaseModel): + updated: bool + error: str | None + + +class SubscriptionSessionResponse(BaseModel): + 
sessionId: str diff --git a/backend/ee/onyx/server/tenants/product_gating.py b/backend/ee/onyx/server/tenants/product_gating.py new file mode 100644 index 00000000000..43123aaaed1 --- /dev/null +++ b/backend/ee/onyx/server/tenants/product_gating.py @@ -0,0 +1,51 @@ +from typing import cast + +from ee.onyx.configs.app_configs import GATED_TENANTS_KEY +from onyx.configs.constants import ONYX_CLOUD_TENANT_ID +from onyx.redis.redis_pool import get_redis_client +from onyx.redis.redis_pool import get_redis_replica_client +from onyx.server.settings.models import ApplicationStatus +from onyx.server.settings.store import load_settings +from onyx.server.settings.store import store_settings +from onyx.setup import setup_logger +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR + +logger = setup_logger() + + +def update_tenant_gating(tenant_id: str, status: ApplicationStatus) -> None: + redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID) + + # Store the full status + status_key = f"tenant:{tenant_id}:status" + redis_client.set(status_key, status.value) + + # Maintain the GATED_ACCESS set + if status == ApplicationStatus.GATED_ACCESS: + redis_client.sadd(GATED_TENANTS_KEY, tenant_id) + else: + redis_client.srem(GATED_TENANTS_KEY, tenant_id) + + +def store_product_gating(tenant_id: str, application_status: ApplicationStatus) -> None: + try: + token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) + + settings = load_settings() + settings.application_status = application_status + store_settings(settings) + + # Store gated tenant information in Redis + update_tenant_gating(tenant_id, application_status) + + if token is not None: + CURRENT_TENANT_ID_CONTEXTVAR.reset(token) + + except Exception: + logger.exception("Failed to gate product") + raise + + +def get_gated_tenants() -> set[str]: + redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID) + return cast(set[str], redis_client.smembers(GATED_TENANTS_KEY)) diff --git 
a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/rerank_documents.py b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/rerank_documents.py index 1867da66ceb..aba7e0f69ce 100644 --- a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/rerank_documents.py +++ b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/rerank_documents.py @@ -21,10 +21,11 @@ from onyx.configs.agent_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS from onyx.configs.agent_configs import AGENT_RERANKING_STATS from onyx.context.search.models import InferenceSection -from onyx.context.search.models import SearchRequest -from onyx.context.search.pipeline import retrieval_preprocessing +from onyx.context.search.models import RerankingDetails from onyx.context.search.postprocessing.postprocessing import rerank_sections +from onyx.context.search.postprocessing.postprocessing import should_rerank from onyx.db.engine import get_session_context_manager +from onyx.db.search_settings import get_current_search_settings def rerank_documents( @@ -39,6 +40,8 @@ def rerank_documents( # Rerank post retrieval and verification. First, create a search query # then create the list of reranked sections + # If no question defined/question is None in the state, use the original + # question from the search request as query graph_config = cast(GraphConfig, config["metadata"]["config"]) question = ( @@ -47,39 +50,28 @@ def rerank_documents( assert ( graph_config.tooling.search_tool ), "search_tool must be provided for agentic search" - with get_session_context_manager() as db_session: - # we ignore some of the user specified fields since this search is - # internal to agentic search, but we still want to pass through - # persona (for stuff like document sets) and rerank settings - # (to not make an unnecessary db call). 
- search_request = SearchRequest( - query=question, - persona=graph_config.inputs.search_request.persona, - rerank_settings=graph_config.inputs.search_request.rerank_settings, - ) - _search_query = retrieval_preprocessing( - search_request=search_request, - user=graph_config.tooling.search_tool.user, # bit of a hack - llm=graph_config.tooling.fast_llm, - db_session=db_session, - ) - # skip section filtering + # Note that these are passed in values from the API and are overrides which are typically None + rerank_settings = graph_config.inputs.search_request.rerank_settings - if ( - _search_query.rerank_settings - and _search_query.rerank_settings.rerank_model_name - and _search_query.rerank_settings.num_rerank > 0 - and len(verified_documents) > 0 - ): + if rerank_settings is None: + with get_session_context_manager() as db_session: + search_settings = get_current_search_settings(db_session) + if not search_settings.disable_rerank_for_streaming: + rerank_settings = RerankingDetails.from_db_model(search_settings) + + if should_rerank(rerank_settings) and len(verified_documents) > 0: if len(verified_documents) > 1: reranked_documents = rerank_sections( - _search_query, - verified_documents, + query_str=question, + # if runnable, then rerank_settings is not None + rerank_settings=cast(RerankingDetails, rerank_settings), + sections_to_rerank=verified_documents, ) else: - num = "No" if len(verified_documents) == 0 else "One" - logger.warning(f"{num} verified document(s) found, skipping reranking") + logger.warning( + f"{len(verified_documents)} verified document(s) found, skipping reranking" + ) reranked_documents = verified_documents else: logger.warning("No reranking settings found, using unranked documents") diff --git a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/retrieve_documents.py b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/retrieve_documents.py index 4fe84d0381c..b0347f75eef 100644 --- 
a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/retrieve_documents.py +++ b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/retrieve_documents.py @@ -23,6 +23,7 @@ from onyx.context.search.models import InferenceSection from onyx.db.engine import get_session_context_manager from onyx.tools.models import SearchQueryInfo +from onyx.tools.models import SearchToolOverrideKwargs from onyx.tools.tool_implementations.search.search_tool import ( SEARCH_RESPONSE_SUMMARY_ID, ) @@ -67,9 +68,12 @@ def retrieve_documents( with get_session_context_manager() as db_session: for tool_response in search_tool.run( query=query_to_retrieve, - force_no_rerank=True, - alternate_db_session=db_session, - retrieved_sections_callback=callback_container.append, + override_kwargs=SearchToolOverrideKwargs( + force_no_rerank=True, + alternate_db_session=db_session, + retrieved_sections_callback=callback_container.append, + skip_query_analysis=not state.base_search, + ), ): # get retrieved docs to send to the rest of the graph if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID: diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py index e4539d50e5e..86c7c0b490e 100644 --- a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py +++ b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py @@ -58,6 +58,7 @@ ) from onyx.prompts.prompt_utils import handle_onyx_date_awareness from onyx.tools.force import ForceUseTool +from onyx.tools.models import SearchToolOverrideKwargs from onyx.tools.tool_constructor import SearchToolConfig from onyx.tools.tool_implementations.search.search_tool import ( SEARCH_RESPONSE_SUMMARY_ID, @@ -218,7 +219,10 @@ def get_test_config( using_tool_calling_llm=using_tool_calling_llm, ) - chat_session_id = os.environ.get("ONYX_AS_CHAT_SESSION_ID") + chat_session_id = ( + os.environ.get("ONYX_AS_CHAT_SESSION_ID") + or 
"00000000-0000-0000-0000-000000000000" + ) assert ( chat_session_id is not None ), "ONYX_AS_CHAT_SESSION_ID must be set for backend tests" @@ -341,8 +345,12 @@ def retrieve_search_docs( with get_session_context_manager() as db_session: for tool_response in search_tool.run( query=question, - force_no_rerank=True, - alternate_db_session=db_session, + override_kwargs=SearchToolOverrideKwargs( + force_no_rerank=True, + alternate_db_session=db_session, + retrieved_sections_callback=None, + skip_query_analysis=False, + ), ): # get retrieved docs to send to the rest of the graph if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID: diff --git a/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py b/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py index baa523247f2..08025605e80 100644 --- a/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py +++ b/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py @@ -478,14 +478,15 @@ def update_external_document_permissions_task( ) doc_id = document_external_access.doc_id external_access = document_external_access.external_access + try: with get_session_with_tenant(tenant_id) as db_session: - # Add the users to the DB if they don't exist batch_add_ext_perm_user_if_not_exists( db_session=db_session, emails=list(external_access.external_user_emails), + continue_on_error=True, ) - # Then we upsert the document's external permissions in postgres + # Then upsert the document's external permissions created_new_doc = upsert_document_external_perms( db_session=db_session, doc_id=doc_id, @@ -509,11 +510,11 @@ def update_external_document_permissions_task( f"action=update_permissions " f"elapsed={elapsed:.2f}" ) + except Exception: task_logger.exception( f"Exception in update_external_document_permissions_task: " - f"connector_id={connector_id} " - f"doc_id={doc_id}" + f"connector_id={connector_id} doc_id={doc_id}" ) return False diff --git 
a/backend/onyx/background/celery/tasks/monitoring/tasks.py b/backend/onyx/background/celery/tasks/monitoring/tasks.py index 91761091c59..5fa1440ca88 100644 --- a/backend/onyx/background/celery/tasks/monitoring/tasks.py +++ b/backend/onyx/background/celery/tasks/monitoring/tasks.py @@ -421,6 +421,7 @@ def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric] - Throughput (docs/min) (only if success) - Raw start/end times for each sync """ + one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1) # Get all sync records that ended in the last hour @@ -588,6 +589,10 @@ def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric] entity = db_session.scalar( select(UserGroup).where(UserGroup.id == sync_record.entity_id) ) + else: + # Only user groups and document set sync records have + # an associated entity we can use for latency metrics + continue if entity is None: task_logger.error( @@ -778,7 +783,7 @@ def cloud_check_alembic() -> bool | None: tenant_to_revision[tenant_id] = result_scalar except Exception: - task_logger.warning(f"Tenant {tenant_id} has no revision!") + task_logger.error(f"Tenant {tenant_id} has no revision!") tenant_to_revision[tenant_id] = ALEMBIC_NULL_REVISION # get the total count of each revision diff --git a/backend/onyx/background/celery/tasks/shared/tasks.py b/backend/onyx/background/celery/tasks/shared/tasks.py index 34ded164f9d..a27c4723a95 100644 --- a/backend/onyx/background/celery/tasks/shared/tasks.py +++ b/backend/onyx/background/celery/tasks/shared/tasks.py @@ -8,6 +8,7 @@ from redis.lock import Lock as RedisLock from tenacity import RetryError +from ee.onyx.server.tenants.product_gating import get_gated_tenants from onyx.access.access import get_access_for_document from onyx.background.celery.apps.app_base import task_logger from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT @@ -252,7 +253,11 @@ def cloud_beat_task_generator( try: tenant_ids = 
get_all_tenant_ids() + gated_tenants = get_gated_tenants() for tenant_id in tenant_ids: + if tenant_id in gated_tenants: + continue + current_time = time.monotonic() if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4): lock_beat.reacquire() diff --git a/backend/onyx/connectors/slack/utils.py b/backend/onyx/connectors/slack/utils.py index 4a20cd92ebc..d992eb55620 100644 --- a/backend/onyx/connectors/slack/utils.py +++ b/backend/onyx/connectors/slack/utils.py @@ -39,19 +39,6 @@ def get_message_link( return permalink -def _make_slack_api_call_logged( - call: Callable[..., SlackResponse], -) -> Callable[..., SlackResponse]: - @wraps(call) - def logged_call(**kwargs: Any) -> SlackResponse: - logger.debug(f"Making call to Slack API '{call.__name__}' with args '{kwargs}'") - result = call(**kwargs) - logger.debug(f"Call to Slack API '{call.__name__}' returned '{result}'") - return result - - return logged_call - - def _make_slack_api_call_paginated( call: Callable[..., SlackResponse], ) -> Callable[..., Generator[dict[str, Any], None, None]]: @@ -127,18 +114,14 @@ def rate_limited_call(**kwargs: Any) -> SlackResponse: def make_slack_api_call_w_retries( call: Callable[..., SlackResponse], **kwargs: Any ) -> SlackResponse: - return basic_retry_wrapper( - make_slack_api_rate_limited(_make_slack_api_call_logged(call)) - )(**kwargs) + return basic_retry_wrapper(make_slack_api_rate_limited(call))(**kwargs) def make_paginated_slack_api_call_w_retries( call: Callable[..., SlackResponse], **kwargs: Any ) -> Generator[dict[str, Any], None, None]: return _make_slack_api_call_paginated( - basic_retry_wrapper( - make_slack_api_rate_limited(_make_slack_api_call_logged(call)) - ) + basic_retry_wrapper(make_slack_api_rate_limited(call)) )(**kwargs) diff --git a/backend/onyx/context/search/pipeline.py b/backend/onyx/context/search/pipeline.py index c3d5177cd48..faf7a898892 100644 --- a/backend/onyx/context/search/pipeline.py +++ 
b/backend/onyx/context/search/pipeline.py @@ -51,6 +51,7 @@ def __init__( user: User | None, llm: LLM, fast_llm: LLM, + skip_query_analysis: bool, db_session: Session, bypass_acl: bool = False, # NOTE: VERY DANGEROUS, USE WITH CAUTION retrieval_metrics_callback: ( @@ -61,10 +62,13 @@ def __init__( rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, prompt_config: PromptConfig | None = None, ): + # NOTE: The Search Request contains a lot of fields that are overrides, many of them can be None + # and typically are None. The preprocessing will fetch default values to replace these empty overrides. self.search_request = search_request self.user = user self.llm = llm self.fast_llm = fast_llm + self.skip_query_analysis = skip_query_analysis self.db_session = db_session self.bypass_acl = bypass_acl self.retrieval_metrics_callback = retrieval_metrics_callback @@ -106,6 +110,7 @@ def _run_preprocessing(self) -> None: search_request=self.search_request, user=self.user, llm=self.llm, + skip_query_analysis=self.skip_query_analysis, db_session=self.db_session, bypass_acl=self.bypass_acl, ) @@ -160,6 +165,12 @@ def _get_sections(self) -> list[InferenceSection]: that have a corresponding chunk. This step should be fast for any document index implementation. 
+ + Current implementation timing is approximately broken down in timing as: + - 200 ms to get the embedding of the query + - 15 ms to get chunks from the document index + - possibly more to get additional surrounding chunks + - possibly more for query expansion (multilingual) """ if self._retrieved_sections is not None: return self._retrieved_sections diff --git a/backend/onyx/context/search/postprocessing/postprocessing.py b/backend/onyx/context/search/postprocessing/postprocessing.py index 41951405478..34f8a18d923 100644 --- a/backend/onyx/context/search/postprocessing/postprocessing.py +++ b/backend/onyx/context/search/postprocessing/postprocessing.py @@ -15,6 +15,7 @@ from onyx.context.search.models import InferenceChunkUncleaned from onyx.context.search.models import InferenceSection from onyx.context.search.models import MAX_METRICS_CONTENT +from onyx.context.search.models import RerankingDetails from onyx.context.search.models import RerankMetricsContainer from onyx.context.search.models import SearchQuery from onyx.document_index.document_index_utils import ( @@ -77,7 +78,8 @@ def _remove_metadata_suffix(chunk: InferenceChunkUncleaned) -> str: @log_function_time(print_only=True) def semantic_reranking( - query: SearchQuery, + query_str: str, + rerank_settings: RerankingDetails, chunks: list[InferenceChunk], model_min: int = CROSS_ENCODER_RANGE_MIN, model_max: int = CROSS_ENCODER_RANGE_MAX, @@ -88,11 +90,9 @@ def semantic_reranking( Note: this updates the chunks in place, it updates the chunk scores which came from retrieval """ - rerank_settings = query.rerank_settings - - if not rerank_settings or not rerank_settings.rerank_model_name: - # Should never reach this part of the flow without reranking settings - raise RuntimeError("Reranking flow should not be running") + assert ( + rerank_settings.rerank_model_name + ), "Reranking flow cannot run without a specific model" chunks_to_rerank = chunks[: rerank_settings.num_rerank] @@ -107,7 +107,7 @@ def 
semantic_reranking( f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}" for chunk in chunks_to_rerank ] - sim_scores_floats = cross_encoder.predict(query=query.query, passages=passages) + sim_scores_floats = cross_encoder.predict(query=query_str, passages=passages) # Old logic to handle multiple cross-encoders preserved but not used sim_scores = [numpy.array(sim_scores_floats)] @@ -165,8 +165,20 @@ def semantic_reranking( return list(ranked_chunks), list(ranked_indices) +def should_rerank(rerank_settings: RerankingDetails | None) -> bool: + """Based on the RerankingDetails model, only run rerank if the following conditions are met: + - rerank_model_name is not None + - num_rerank is greater than 0 + """ + if not rerank_settings: + return False + + return bool(rerank_settings.rerank_model_name and rerank_settings.num_rerank > 0) + + def rerank_sections( - query: SearchQuery, + query_str: str, + rerank_settings: RerankingDetails, sections_to_rerank: list[InferenceSection], rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, ) -> list[InferenceSection]: @@ -181,16 +193,13 @@ def rerank_sections( """ chunks_to_rerank = [section.center_chunk for section in sections_to_rerank] - if not query.rerank_settings: - # Should never reach this part of the flow without reranking settings - raise RuntimeError("Reranking settings not found") - ranked_chunks, _ = semantic_reranking( - query=query, + query_str=query_str, + rerank_settings=rerank_settings, chunks=chunks_to_rerank, rerank_metrics_callback=rerank_metrics_callback, ) - lower_chunks = chunks_to_rerank[query.rerank_settings.num_rerank :] + lower_chunks = chunks_to_rerank[rerank_settings.num_rerank :] # Scores from rerank cannot be meaningfully combined with scores without rerank # However the ordering is still important @@ -260,16 +269,13 @@ def search_postprocessing( rerank_task_id = None sections_yielded = False - if ( - search_query.rerank_settings - and 
search_query.rerank_settings.rerank_model_name - and search_query.rerank_settings.num_rerank > 0 - ): + if should_rerank(search_query.rerank_settings): post_processing_tasks.append( FunctionCall( rerank_sections, ( - search_query, + search_query.query, + search_query.rerank_settings, # Cannot be None here retrieved_sections, rerank_metrics_callback, ), diff --git a/backend/onyx/context/search/preprocessing/preprocessing.py b/backend/onyx/context/search/preprocessing/preprocessing.py index da228f5f1fb..2e63ed0e39e 100644 --- a/backend/onyx/context/search/preprocessing/preprocessing.py +++ b/backend/onyx/context/search/preprocessing/preprocessing.py @@ -50,11 +50,11 @@ def retrieval_preprocessing( search_request: SearchRequest, user: User | None, llm: LLM, + skip_query_analysis: bool, db_session: Session, - bypass_acl: bool = False, - skip_query_analysis: bool = False, - base_recency_decay: float = BASE_RECENCY_DECAY, favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER, + base_recency_decay: float = BASE_RECENCY_DECAY, + bypass_acl: bool = False, ) -> SearchQuery: """Logic is as follows: Any global disables apply first @@ -146,7 +146,7 @@ def retrieval_preprocessing( is_keyword, extracted_keywords = ( parallel_results[run_query_analysis.result_id] if run_query_analysis - else (None, None) + else (False, None) ) all_query_terms = query.split() diff --git a/backend/onyx/db/users.py b/backend/onyx/db/users.py index 12fd5d15c5e..894f8466e01 100644 --- a/backend/onyx/db/users.py +++ b/backend/onyx/db/users.py @@ -6,6 +6,7 @@ from fastapi_users.password import PasswordHelper from sqlalchemy import func from sqlalchemy import select +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from sqlalchemy.sql import expression from sqlalchemy.sql.elements import ColumnElement @@ -274,7 +275,7 @@ def _generate_ext_permissioned_user(email: str) -> User: def batch_add_ext_perm_user_if_not_exists( - db_session: Session, emails: list[str] + 
db_session: Session, emails: list[str], continue_on_error: bool = False ) -> list[User]: lower_emails = [email.lower() for email in emails] found_users, missing_lower_emails = _get_users_by_emails(db_session, lower_emails) @@ -283,10 +284,23 @@ def batch_add_ext_perm_user_if_not_exists( for email in missing_lower_emails: new_users.append(_generate_ext_permissioned_user(email=email)) - db_session.add_all(new_users) - db_session.commit() - - return found_users + new_users + try: + db_session.add_all(new_users) + db_session.commit() + except IntegrityError: + db_session.rollback() + if not continue_on_error: + raise + for user in new_users: + try: + db_session.add(user) + db_session.commit() + except IntegrityError: + db_session.rollback() + continue + # Fetch all users again to ensure we have the most up-to-date list + all_users, _ = _get_users_by_emails(db_session, lower_emails) + return all_users def delete_user_from_db( diff --git a/backend/onyx/document_index/vespa/index.py b/backend/onyx/document_index/vespa/index.py index f8c9b58ce4c..6dc4249bd73 100644 --- a/backend/onyx/document_index/vespa/index.py +++ b/backend/onyx/document_index/vespa/index.py @@ -17,6 +17,7 @@ import httpx # type: ignore import requests # type: ignore +from retry import retry from onyx.configs.chat_configs import DOC_TIME_DECAY from onyx.configs.chat_configs import NUM_RETURNED_HITS @@ -549,6 +550,11 @@ def update( time.monotonic() - update_start, ) + @retry( + tries=3, + delay=1, + backoff=2, + ) def _update_single_chunk( self, doc_chunk_id: UUID, @@ -559,6 +565,7 @@ def _update_single_chunk( ) -> None: """ Update a single "chunk" (document) in Vespa using its chunk ID. + Retries if we encounter transient HTTPStatusError (e.g., overload). """ update_dict: dict[str, dict] = {"fields": {}} @@ -567,13 +574,11 @@ def _update_single_chunk( update_dict["fields"][BOOST] = {"assign": fields.boost} if fields.document_sets is not None: - # WeightedSet needs a map { item: weight, ... 
} update_dict["fields"][DOCUMENT_SETS] = { "assign": {document_set: 1 for document_set in fields.document_sets} } if fields.access is not None: - # Similar to above update_dict["fields"][ACCESS_CONTROL_LIST] = { "assign": {acl_entry: 1 for acl_entry in fields.access.to_acl()} } @@ -585,7 +590,10 @@ def _update_single_chunk( logger.error("Update request received but nothing to update.") return - vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}?create=true" + vespa_url = ( + f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}" + "?create=true" + ) try: resp = http_client.put( @@ -595,8 +603,11 @@ def _update_single_chunk( ) resp.raise_for_status() except httpx.HTTPStatusError as e: - error_message = f"Failed to update doc chunk {doc_chunk_id} (doc_id={doc_id}). Details: {e.response.text}" - logger.error(error_message) + logger.error( + f"Failed to update doc chunk {doc_chunk_id} (doc_id={doc_id}). " + f"Details: {e.response.text}" + ) + # Re-raise so the @retry decorator will catch and retry raise def update_single( diff --git a/backend/onyx/natural_language_processing/utils.py b/backend/onyx/natural_language_processing/utils.py index 7b68b20d8e9..3c4d1392088 100644 --- a/backend/onyx/natural_language_processing/utils.py +++ b/backend/onyx/natural_language_processing/utils.py @@ -99,7 +99,7 @@ def _check_tokenizer_cache( if not tokenizer: logger.info( - f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}" + f"Falling back to default embedding model tokenizer: {DOCUMENT_ENCODER_MODEL}" ) tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL) diff --git a/backend/onyx/server/gpts/api.py b/backend/onyx/server/gpts/api.py index 58796d6199b..ea2aad20b1e 100644 --- a/backend/onyx/server/gpts/api.py +++ b/backend/onyx/server/gpts/api.py @@ -76,6 +76,7 @@ def gpt_search( user=None, llm=llm, fast_llm=fast_llm, + skip_query_analysis=True, db_session=db_session, ).reranked_sections diff --git 
a/backend/onyx/server/settings/models.py b/backend/onyx/server/settings/models.py index fa6a5bcbffc..f0c542797d7 100644 --- a/backend/onyx/server/settings/models.py +++ b/backend/onyx/server/settings/models.py @@ -12,10 +12,10 @@ class PageType(str, Enum): SEARCH = "search" -class GatingType(str, Enum): - FULL = "full" # Complete restriction of access to the product or service - PARTIAL = "partial" # Full access but warning (no credit card on file) - NONE = "none" # No restrictions, full access to all features +class ApplicationStatus(str, Enum): + PAYMENT_REMINDER = "payment_reminder" + GATED_ACCESS = "gated_access" + ACTIVE = "active" class Notification(BaseModel): @@ -43,7 +43,7 @@ class Settings(BaseModel): maximum_chat_retention_days: int | None = None gpu_enabled: bool | None = None - product_gating: GatingType = GatingType.NONE + application_status: ApplicationStatus = ApplicationStatus.ACTIVE anonymous_user_enabled: bool | None = None pro_search_disabled: bool | None = None auto_scroll: bool | None = None diff --git a/backend/onyx/tools/base_tool.py b/backend/onyx/tools/base_tool.py index 16ec5d92aa0..4b8479b75bc 100644 --- a/backend/onyx/tools/base_tool.py +++ b/backend/onyx/tools/base_tool.py @@ -34,7 +34,7 @@ def build_user_message_for_non_tool_calling_llm( """.strip() -class BaseTool(Tool): +class BaseTool(Tool[None]): def build_next_prompt( self, prompt_builder: "AnswerPromptBuilder", diff --git a/backend/onyx/tools/models.py b/backend/onyx/tools/models.py index a8918b691e4..1e343e74cb3 100644 --- a/backend/onyx/tools/models.py +++ b/backend/onyx/tools/models.py @@ -1,11 +1,14 @@ +from collections.abc import Callable from typing import Any from uuid import UUID from pydantic import BaseModel from pydantic import model_validator +from sqlalchemy.orm import Session from onyx.context.search.enums import SearchType from onyx.context.search.models import IndexFilters +from onyx.context.search.models import InferenceSection class ToolResponse(BaseModel): @@ 
-57,5 +60,15 @@ class SearchQueryInfo(BaseModel): recency_bias_multiplier: float +class SearchToolOverrideKwargs(BaseModel): + force_no_rerank: bool + alternate_db_session: Session | None + retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None + skip_query_analysis: bool + + class Config: + arbitrary_types_allowed = True + + CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID" MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID" diff --git a/backend/onyx/tools/tool.py b/backend/onyx/tools/tool.py index 4a8ba80996e..2c7f53647f0 100644 --- a/backend/onyx/tools/tool.py +++ b/backend/onyx/tools/tool.py @@ -1,7 +1,9 @@ import abc from collections.abc import Generator from typing import Any +from typing import Generic from typing import TYPE_CHECKING +from typing import TypeVar from onyx.llm.interfaces import LLM from onyx.llm.models import PreviousMessage @@ -14,7 +16,10 @@ from onyx.tools.models import ToolResponse -class Tool(abc.ABC): +OVERRIDE_T = TypeVar("OVERRIDE_T") + + +class Tool(abc.ABC, Generic[OVERRIDE_T]): @property @abc.abstractmethod def name(self) -> str: @@ -57,7 +62,9 @@ def get_args_for_non_tool_calling_llm( """Actual execution of the tool""" @abc.abstractmethod - def run(self, **kwargs: Any) -> Generator["ToolResponse", None, None]: + def run( + self, override_kwargs: OVERRIDE_T | None = None, **llm_kwargs: Any + ) -> Generator["ToolResponse", None, None]: raise NotImplementedError @abc.abstractmethod diff --git a/backend/onyx/tools/tool_implementations/custom/custom_tool.py b/backend/onyx/tools/tool_implementations/custom/custom_tool.py index a235383a71c..932989e44e5 100644 --- a/backend/onyx/tools/tool_implementations/custom/custom_tool.py +++ b/backend/onyx/tools/tool_implementations/custom/custom_tool.py @@ -74,6 +74,7 @@ class CustomToolCallSummary(BaseModel): tool_result: Any # The response data +# override_kwargs is not supported for custom tools class CustomTool(BaseTool): def __init__( self, @@ -235,7 +236,9 @@ def _parse_csv(self, 
csv_text: str) -> List[Dict[str, Any]]: """Actual execution of the tool""" - def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: + def run( + self, override_kwargs: dict[str, Any] | None = None, **kwargs: Any + ) -> Generator[ToolResponse, None, None]: request_body = kwargs.get(REQUEST_BODY) path_params = {} diff --git a/backend/onyx/tools/tool_implementations/images/image_generation_tool.py b/backend/onyx/tools/tool_implementations/images/image_generation_tool.py index f4e19e1c283..3185b4a001d 100644 --- a/backend/onyx/tools/tool_implementations/images/image_generation_tool.py +++ b/backend/onyx/tools/tool_implementations/images/image_generation_tool.py @@ -79,7 +79,8 @@ class ImageShape(str, Enum): LANDSCAPE = "landscape" -class ImageGenerationTool(Tool): +# override_kwargs is not supported for image generation tools +class ImageGenerationTool(Tool[None]): _NAME = "run_image_generation" _DESCRIPTION = "Generate an image from a prompt." _DISPLAY_NAME = "Image Generation" @@ -255,7 +256,9 @@ def _generate_image( "An error occurred during image generation. Please try again later." 
) - def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: + def run( + self, override_kwargs: None = None, **kwargs: str + ) -> Generator[ToolResponse, None, None]: prompt = cast(str, kwargs["prompt"]) shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE)) format = self.output_format diff --git a/backend/onyx/tools/tool_implementations/internet_search/internet_search_tool.py b/backend/onyx/tools/tool_implementations/internet_search/internet_search_tool.py index 474fa2d675f..1c6b3f21cc1 100644 --- a/backend/onyx/tools/tool_implementations/internet_search/internet_search_tool.py +++ b/backend/onyx/tools/tool_implementations/internet_search/internet_search_tool.py @@ -106,7 +106,8 @@ def internet_search_response_to_search_docs( ] -class InternetSearchTool(Tool): +# override_kwargs is not supported for internet search tools +class InternetSearchTool(Tool[None]): _NAME = "run_internet_search" _DISPLAY_NAME = "Internet Search" _DESCRIPTION = "Perform an internet search for up-to-date information." 
@@ -242,7 +243,9 @@ def _perform_search(self, query: str) -> InternetSearchResponse: ], ) - def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: + def run( + self, override_kwargs: None = None, **kwargs: str + ) -> Generator[ToolResponse, None, None]: query = cast(str, kwargs["internet_search_query"]) results = self._perform_search(query) diff --git a/backend/onyx/tools/tool_implementations/search/search_tool.py b/backend/onyx/tools/tool_implementations/search/search_tool.py index 2666b2014a4..11d147526a0 100644 --- a/backend/onyx/tools/tool_implementations/search/search_tool.py +++ b/backend/onyx/tools/tool_implementations/search/search_tool.py @@ -39,6 +39,7 @@ from onyx.secondary_llm_flows.query_expansion import history_based_query_rephrase from onyx.tools.message import ToolCallSummary from onyx.tools.models import SearchQueryInfo +from onyx.tools.models import SearchToolOverrideKwargs from onyx.tools.models import ToolResponse from onyx.tools.tool import Tool from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict @@ -77,7 +78,7 @@ class SearchResponseSummary(SearchQueryInfo): """ -class SearchTool(Tool): +class SearchTool(Tool[SearchToolOverrideKwargs]): _NAME = "run_search" _DISPLAY_NAME = "Search Tool" _DESCRIPTION = SEARCH_TOOL_DESCRIPTION @@ -275,14 +276,19 @@ def _build_response_for_specified_sections( yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs) - def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: - query = cast(str, kwargs["query"]) - force_no_rerank = cast(bool, kwargs.get("force_no_rerank", False)) - alternate_db_session = cast(Session, kwargs.get("alternate_db_session", None)) - retrieved_sections_callback = cast( - Callable[[list[InferenceSection]], None], - kwargs.get("retrieved_sections_callback"), - ) + def run( + self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any + ) -> Generator[ToolResponse, None, None]: + query = cast(str, 
llm_kwargs["query"]) + force_no_rerank = False + alternate_db_session = None + retrieved_sections_callback = None + skip_query_analysis = False + if override_kwargs: + force_no_rerank = override_kwargs.force_no_rerank + alternate_db_session = override_kwargs.alternate_db_session + retrieved_sections_callback = override_kwargs.retrieved_sections_callback + skip_query_analysis = override_kwargs.skip_query_analysis if self.selected_sections: yield from self._build_response_for_specified_sections(query) @@ -324,6 +330,7 @@ def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: user=self.user, llm=self.llm, fast_llm=self.fast_llm, + skip_query_analysis=skip_query_analysis, bypass_acl=self.bypass_acl, db_session=alternate_db_session or self.db_session, prompt_config=self.prompt_config, diff --git a/backend/scripts/debugging/onyx_vespa.py b/backend/scripts/debugging/onyx_vespa.py index 39c1b8b0d93..4f9b3f6ab97 100644 --- a/backend/scripts/debugging/onyx_vespa.py +++ b/backend/scripts/debugging/onyx_vespa.py @@ -256,16 +256,28 @@ def get_documents_for_tenant_connector( def search_for_document( - index_name: str, document_id: str, max_hits: int | None = 10 + index_name: str, + document_id: str | None = None, + tenant_id: str | None = None, + max_hits: int | None = 10, ) -> List[Dict[str, Any]]: - yql_query = ( - f'select * from sources {index_name} where document_id contains "{document_id}"' - ) + yql_query = f"select * from sources {index_name}" + + conditions = [] + if document_id is not None: + conditions.append(f'document_id contains "{document_id}"') + + if tenant_id is not None: + conditions.append(f'tenant_id contains "{tenant_id}"') + + if conditions: + yql_query += " where " + " and ".join(conditions) + params: dict[str, Any] = {"yql": yql_query} if max_hits is not None: params["hits"] = max_hits with get_vespa_http_client() as client: - response = client.get(f"{SEARCH_ENDPOINT}/search/", params=params) + response = 
client.get(f"{SEARCH_ENDPOINT}search/", params=params) response.raise_for_status() result = response.json() documents = result.get("root", {}).get("children", []) @@ -582,8 +594,15 @@ def update_document( ) -> None: update_document(self.tenant_id, connector_id, doc_id, fields) - def search_for_document(self, document_id: str) -> List[Dict[str, Any]]: - return search_for_document(self.index_name, document_id) + def delete_documents_for_tenant(self, count: int | None = None) -> None: + if not self.tenant_id: + raise Exception("Tenant ID is not set") + delete_documents_for_tenant(self.index_name, self.tenant_id, count=count) + + def search_for_document( + self, document_id: str | None = None, tenant_id: str | None = None + ) -> List[Dict[str, Any]]: + return search_for_document(self.index_name, document_id, tenant_id) def delete_document(self, connector_id: int, doc_id: str) -> None: # Delete a document. @@ -600,6 +619,147 @@ def acls(self, cc_pair_id: int, n: int | None = 10) -> None: get_document_acls(self.tenant_id, cc_pair_id, n) +def delete_where( + index_name: str, + selection: str, + cluster: str = "default", + bucket_space: str | None = None, + continuation: str | None = None, + time_chunk: str | None = None, + timeout: str | None = None, + tracelevel: int | None = None, +) -> None: + """ + Removes visited documents in `cluster` where the given selection + is true, using Vespa's 'delete where' endpoint. + + :param index_name: Typically / from your schema + :param selection: The selection string, e.g., "true" or "foo contains 'bar'" + :param cluster: The name of the cluster where documents reside + :param bucket_space: e.g. 'global' or 'default' + :param continuation: For chunked visits + :param time_chunk: If you want to chunk the visit by time + :param timeout: e.g. '10s' + :param tracelevel: Increase for verbose logs + """ + # Using index_name of form /, e.g. 
"nomic_ai_nomic_embed_text_v1" + # This route ends with "/docid/" since the actual ID is not specified — we rely on "selection". + path = f"/document/v1/{index_name}/docid/" + + params = { + "cluster": cluster, + "selection": selection, + } + + # Optional parameters + if bucket_space is not None: + params["bucketSpace"] = bucket_space + if continuation is not None: + params["continuation"] = continuation + if time_chunk is not None: + params["timeChunk"] = time_chunk + if timeout is not None: + params["timeout"] = timeout + if tracelevel is not None: + params["tracelevel"] = tracelevel # type: ignore + + with get_vespa_http_client() as client: + url = f"{VESPA_APPLICATION_ENDPOINT}{path}" + logger.info(f"Performing 'delete where' on {url} with selection={selection}...") + response = client.delete(url, params=params) + # (Optionally, you can keep fetching `continuation` from the JSON response + # if you have more documents to delete in chunks.) + response.raise_for_status() # will raise HTTPError if not 2xx + logger.info(f"Delete where completed with status: {response.status_code}") + print(f"Delete where completed with status: {response.status_code}") + + +def delete_documents_for_tenant( + index_name: str, + tenant_id: str, + route: str | None = None, + condition: str | None = None, + timeout: str | None = None, + tracelevel: int | None = None, + count: int | None = None, +) -> None: + """ + For the given tenant_id and index_name (often in the form /), + find documents via search_for_document, then delete them one at a time using Vespa's + /document/v1///docid/ endpoint. + + :param index_name: Typically / from your schema + :param tenant_id: The tenant to match in your Vespa search + :param route: Optional route parameter for delete + :param condition: Optional conditional remove + :param timeout: e.g. 
'10s'
+    :param tracelevel: Increase for verbose logs
+    """
+    deleted_count = 0
+    while True:
+        # Search for documents with the given tenant_id
+        docs = search_for_document(
+            index_name=index_name,
+            document_id=None,
+            tenant_id=tenant_id,
+            max_hits=100,  # Fetch in batches of 100
+        )
+
+        if not docs:
+            logger.info("No more documents found to delete.")
+            break
+
+        with get_vespa_http_client() as client:
+            for doc in docs:
+                if count is not None and deleted_count >= count:
+                    logger.info(f"Reached maximum delete limit of {count} documents.")
+                    return
+
+                fields = doc.get("fields", {})
+                doc_id_value = fields.get("document_id") or fields.get("documentid")
+                doc_tenant_id = fields.get("tenant_id")
+                if doc_tenant_id != tenant_id:
+                    raise Exception("Tenant ID mismatch")
+
+                if not doc_id_value:
+                    logger.warning(
+                        "Skipping a document that has no document_id in 'fields'."
+                    )
+                    continue
+
+                url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_id_value}"
+
+                params = {}
+                if condition:
+                    params["condition"] = condition
+                if route:
+                    params["route"] = route
+                if timeout:
+                    params["timeout"] = timeout
+                if tracelevel is not None:
+                    params["tracelevel"] = str(tracelevel)
+
+                response = client.delete(url, params=params)
+                if response.status_code == 200:
+                    logger.info(f"Successfully deleted doc_id={doc_id_value}")
+                    deleted_count += 1
+                else:
+                    logger.error(
+                        f"Failed to delete doc_id={doc_id_value}, "
+                        f"status={response.status_code}, response={response.text}"
+                    )
+                    print(
+                        f"Could not delete doc_id={doc_id_value}. "
+                        f"Status={response.status_code}, response={response.text}"
+                    )
+                    raise Exception(
+                        f"Could not delete doc_id={doc_id_value}. 
" + f"Status={response.status_code}, response={response.text}" + ) + + logger.info(f"Deleted {deleted_count} documents in total.") + + def main() -> None: parser = argparse.ArgumentParser(description="Vespa debugging tool") parser.add_argument( @@ -612,6 +772,7 @@ def main() -> None: "update", "delete", "get_acls", + "delete-all-documents", ], required=True, help="Action to perform", @@ -626,11 +787,20 @@ def main() -> None: parser.add_argument( "--fields", help="Fields to update, in JSON format (for update)" ) + parser.add_argument( + "--count", + type=int, + help="Maximum number of documents to delete (for delete-all-documents)", + ) args = parser.parse_args() vespa_debug = VespaDebugging(args.tenant_id) - if args.action == "config": + if args.action == "delete-all-documents": + if not args.tenant_id: + parser.error("--tenant-id is required for delete-all-documents action") + vespa_debug.delete_documents_for_tenant(count=args.count) + elif args.action == "config": vespa_debug.print_config() elif args.action == "connect": vespa_debug.check_connectivity() diff --git a/web/Dockerfile b/web/Dockerfile index b4d74110cfc..a5712207447 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -84,6 +84,9 @@ ENV NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=${NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED} ARG NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK ENV NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=${NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK} +ARG NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY +ENV NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY} + # Use NODE_OPTIONS in the build command RUN NODE_OPTIONS="${NODE_OPTIONS}" npx next build @@ -145,7 +148,6 @@ ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT} ARG NEXT_PUBLIC_CUSTOM_REFRESH_URL ENV NEXT_PUBLIC_CUSTOM_REFRESH_URL=${NEXT_PUBLIC_CUSTOM_REFRESH_URL} - ARG NEXT_PUBLIC_POSTHOG_KEY ARG NEXT_PUBLIC_POSTHOG_HOST ENV NEXT_PUBLIC_POSTHOG_KEY=${NEXT_PUBLIC_POSTHOG_KEY} @@ -166,6 +168,9 @@ ENV 
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=${NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED} ARG NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK ENV NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=${NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK} +ARG NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY +ENV NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY} + # Note: Don't expose ports here, Compose will handle that for us if necessary. # If you want to run this without compose, specify the ports to # expose via cli diff --git a/web/src/app/admin/settings/interfaces.ts b/web/src/app/admin/settings/interfaces.ts index 4574ad8a0ae..245c4dd1b6b 100644 --- a/web/src/app/admin/settings/interfaces.ts +++ b/web/src/app/admin/settings/interfaces.ts @@ -1,7 +1,7 @@ -export enum GatingType { - FULL = "full", - PARTIAL = "partial", - NONE = "none", +export enum ApplicationStatus { + PAYMENT_REMINDER = "payment_reminder", + GATED_ACCESS = "gated_access", + ACTIVE = "active", } export interface Settings { @@ -11,7 +11,7 @@ export interface Settings { needs_reindexing: boolean; gpu_enabled: boolean; pro_search_disabled: boolean | null; - product_gating: GatingType; + application_status: ApplicationStatus; auto_scroll: boolean; } diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index f425f0dc259..2ffe8aafa9d 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -2291,8 +2291,6 @@ export function ChatPage({ bg-opacity-80 duration-300 ease-in-out - - ${ !untoggled && (showHistorySidebar || sidebarVisible) ? 
"opacity-100 w-[250px] translate-x-0" diff --git a/web/src/app/ee/admin/billing/BillingAlerts.tsx b/web/src/app/ee/admin/billing/BillingAlerts.tsx new file mode 100644 index 00000000000..7015034c847 --- /dev/null +++ b/web/src/app/ee/admin/billing/BillingAlerts.tsx @@ -0,0 +1,73 @@ +import React from "react"; +import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; +import { CircleAlert, Info } from "lucide-react"; +import { BillingInformation, BillingStatus } from "./interfaces"; + +export function BillingAlerts({ + billingInformation, +}: { + billingInformation: BillingInformation; +}) { + const isTrialing = billingInformation.status === BillingStatus.TRIALING; + const isCancelled = billingInformation.cancel_at_period_end; + const isExpired = + new Date(billingInformation.current_period_end) < new Date(); + const noPaymentMethod = !billingInformation.payment_method_enabled; + + const messages: string[] = []; + + if (isExpired) { + messages.push( + "Your subscription has expired. Please resubscribe to continue using the service." + ); + } + if (isCancelled && !isExpired) { + messages.push( + `Your subscription will cancel on ${new Date( + billingInformation.current_period_end + ).toLocaleDateString()}. You can resubscribe before this date to remain uninterrupted.` + ); + } + if (isTrialing) { + messages.push( + `You're currently on a trial. Your trial ends on ${ + billingInformation.trial_end + ? new Date(billingInformation.trial_end).toLocaleDateString() + : "N/A" + }.` + ); + } + if (noPaymentMethod) { + messages.push( + "You currently have no payment method on file. Please add one to avoid service interruption." + ); + } + + const variant = isExpired || noPaymentMethod ? "destructive" : "default"; + + if (messages.length === 0) return null; + + return ( + + + {variant === "destructive" ? ( + + ) : ( + + )} + + {variant === "destructive" + ? "Important Subscription Notice" + : "Subscription Notice"} + + + +
    + {messages.map((msg, idx) => ( +
  • {msg}
  • + ))} +
+
+
+ ); +} diff --git a/web/src/app/ee/admin/billing/BillingInformationPage.tsx b/web/src/app/ee/admin/billing/BillingInformationPage.tsx index 0d00b97bef2..7eb6a5c8b26 100644 --- a/web/src/app/ee/admin/billing/BillingInformationPage.tsx +++ b/web/src/app/ee/admin/billing/BillingInformationPage.tsx @@ -1,18 +1,21 @@ "use client"; -import { CreditCard, ArrowFatUp } from "@phosphor-icons/react"; -import { useState } from "react"; import { useRouter } from "next/navigation"; -import { loadStripe } from "@stripe/stripe-js"; +import { useEffect } from "react"; import { usePopup } from "@/components/admin/connectors/Popup"; -import { SettingsIcon } from "@/components/icons/icons"; +import { fetchCustomerPortal, useBillingInformation } from "./utils"; + import { - updateSubscriptionQuantity, - fetchCustomerPortal, - statusToDisplay, - useBillingInformation, -} from "./utils"; -import { useEffect } from "react"; + Card, + CardContent, + CardDescription, + CardHeader, + CardTitle, +} from "@/components/ui/card"; +import { Button } from "@/components/ui/button"; +import { CreditCard, ArrowFatUp } from "@phosphor-icons/react"; +import { SubscriptionSummary } from "./SubscriptionSummary"; +import { BillingAlerts } from "./BillingAlerts"; export default function BillingInformationPage() { const router = useRouter(); @@ -24,9 +27,6 @@ export default function BillingInformationPage() { isLoading, } = useBillingInformation(); - if (error) { - console.error("Failed to fetch billing information:", error); - } useEffect(() => { const url = new URL(window.location.href); if (url.searchParams.has("session_id")) { @@ -35,22 +35,33 @@ export default function BillingInformationPage() { "Congratulations! 
Your subscription has been updated successfully.", type: "success", }); - // Remove the session_id from the URL url.searchParams.delete("session_id"); window.history.replaceState({}, "", url.toString()); - // You might want to refresh the billing information here - // by calling an API endpoint to get the latest data } }, [setPopup]); if (isLoading) { - return
Loading...
; + return
Loading...
; + } + + if (error) { + console.error("Failed to fetch billing information:", error); + return ( +
+ Error loading billing information. Please try again later. +
+ ); + } + + if (!billingInformation) { + return ( +
No billing information available.
+ ); } const handleManageSubscription = async () => { try { const response = await fetchCustomerPortal(); - if (!response.ok) { const errorData = await response.json(); throw new Error( @@ -61,11 +72,9 @@ export default function BillingInformationPage() { } const { url } = await response.json(); - if (!url) { throw new Error("No portal URL returned from the server"); } - router.push(url); } catch (error) { console.error("Error creating customer portal session:", error); @@ -75,138 +84,39 @@ export default function BillingInformationPage() { }); } }; - if (!billingInformation) { - return
Loading...
; - } return (
-
- {popup} - -

- {/* */} - Subscription Details -

- -
-
-
-
-

Seats

-

- Number of licensed users -

-
-

- {billingInformation.seats} -

-
-
- -
-
-
-

- Subscription Status -

-

- Current state of your subscription -

-
-

- {statusToDisplay(billingInformation.subscription_status)} -

-
-
- -
-
-
-

- Billing Start -

-

- Start date of current billing cycle -

-
-

- {new Date( - billingInformation.billing_start - ).toLocaleDateString()} -

-
-
- -
-
-
-

Billing End

-

- End date of current billing cycle -

-
-

- {new Date(billingInformation.billing_end).toLocaleDateString()} -

-
-
-
- - {!billingInformation.payment_method_enabled && ( -
-

Notice:

-

- You'll need to add a payment method before your trial ends to - continue using the service. -

-
- )} - - {billingInformation.subscription_status === "trialing" ? ( -
-

- No cap on users during trial -

-
- ) : ( -
-
-

- Current Seats: -

-

- {billingInformation.seats} -

-
-

- Seats automatically update based on adding, removing, or inviting - users. -

-
- )} -
- -
-
-
-

- Manage Subscription -

-

- View your plan, update payment, or change subscription -

-
- -
- -
+ {popup} + + + + + Subscription Details + + + + + + + + + + + + Manage Subscription + + + View your plan, update payment, or change subscription + + + + + +
); } diff --git a/web/src/app/ee/admin/billing/InfoItem.tsx b/web/src/app/ee/admin/billing/InfoItem.tsx new file mode 100644 index 00000000000..a4f8e745144 --- /dev/null +++ b/web/src/app/ee/admin/billing/InfoItem.tsx @@ -0,0 +1,17 @@ +import React from "react"; + +interface InfoItemProps { + title: string; + value: string; +} + +export function InfoItem({ title, value }: InfoItemProps) { + return ( +
+

{title}

+

+ {value} +

+
+ ); +} diff --git a/web/src/app/ee/admin/billing/SubscriptionSummary.tsx b/web/src/app/ee/admin/billing/SubscriptionSummary.tsx new file mode 100644 index 00000000000..56e682f607f --- /dev/null +++ b/web/src/app/ee/admin/billing/SubscriptionSummary.tsx @@ -0,0 +1,33 @@ +import React from "react"; +import { InfoItem } from "./InfoItem"; +import { statusToDisplay } from "./utils"; + +interface SubscriptionSummaryProps { + billingInformation: any; +} + +export function SubscriptionSummary({ + billingInformation, +}: SubscriptionSummaryProps) { + return ( +
+ + + + +
+ ); +} diff --git a/web/src/app/ee/admin/billing/interfaces.ts b/web/src/app/ee/admin/billing/interfaces.ts new file mode 100644 index 00000000000..6b0a48d08ff --- /dev/null +++ b/web/src/app/ee/admin/billing/interfaces.ts @@ -0,0 +1,19 @@ +export interface BillingInformation { + status: string; + trial_end: Date | null; + current_period_end: Date; + payment_method_enabled: boolean; + cancel_at_period_end: boolean; + current_period_start: Date; + number_of_seats: number; + canceled_at: Date | null; + trial_start: Date | null; + seats: number; +} + +export enum BillingStatus { + TRIALING = "trialing", + ACTIVE = "active", + CANCELLED = "cancelled", + EXPIRED = "expired", +} diff --git a/web/src/app/ee/admin/billing/page.tsx b/web/src/app/ee/admin/billing/page.tsx index 96cfe7585b3..b52fd0150b4 100644 --- a/web/src/app/ee/admin/billing/page.tsx +++ b/web/src/app/ee/admin/billing/page.tsx @@ -3,10 +3,16 @@ import BillingInformationPage from "./BillingInformationPage"; import { MdOutlineCreditCard } from "react-icons/md"; export interface BillingInformation { + stripe_subscription_id: string; + status: string; + current_period_start: Date; + current_period_end: Date; + number_of_seats: number; + cancel_at_period_end: boolean; + canceled_at: Date | null; + trial_start: Date | null; + trial_end: Date | null; seats: number; - subscription_status: string; - billing_start: Date; - billing_end: Date; payment_method_enabled: boolean; } diff --git a/web/src/app/ee/admin/billing/utils.ts b/web/src/app/ee/admin/billing/utils.ts index 1f2aaa8e8eb..be652ad051b 100644 --- a/web/src/app/ee/admin/billing/utils.ts +++ b/web/src/app/ee/admin/billing/utils.ts @@ -35,9 +35,16 @@ export const statusToDisplay = (status: string) => { export const useBillingInformation = () => { const url = "/api/tenants/billing-information"; - const swrResponse = useSWR(url, (url: string) => - fetch(url).then((res) => res.json()) - ); + const swrResponse = useSWR(url, async (url: string) => { + const res = 
await fetch(url); + if (!res.ok) { + const errorData = await res.json(); + throw new Error( + errorData.message || "Failed to fetch billing information" + ); + } + return res.json(); + }); return { ...swrResponse, diff --git a/web/src/app/layout.tsx b/web/src/app/layout.tsx index b0059650fe8..87cb1ddd289 100644 --- a/web/src/app/layout.tsx +++ b/web/src/app/layout.tsx @@ -13,7 +13,10 @@ import { import { Metadata } from "next"; import { buildClientUrl } from "@/lib/utilsSS"; import { Inter } from "next/font/google"; -import { EnterpriseSettings, GatingType } from "./admin/settings/interfaces"; +import { + EnterpriseSettings, + ApplicationStatus, +} from "./admin/settings/interfaces"; import { fetchAssistantData } from "@/lib/chat/fetchAssistantdata"; import { AppProvider } from "@/components/context/AppProvider"; import { PHProvider } from "./providers"; @@ -28,6 +31,7 @@ import { WebVitals } from "./web-vitals"; import { ThemeProvider } from "next-themes"; import CloudError from "@/components/errorPages/CloudErrorPage"; import Error from "@/components/errorPages/ErrorPage"; +import AccessRestrictedPage from "@/components/errorPages/AccessRestrictedPage"; const inter = Inter({ subsets: ["latin"], @@ -75,7 +79,7 @@ export default async function RootLayout({ ]); const productGating = - combinedSettings?.settings.product_gating ?? GatingType.NONE; + combinedSettings?.settings.application_status ?? ApplicationStatus.ACTIVE; const getPageContent = async (content: React.ReactNode) => ( ); - if (!combinedSettings) { - return getPageContent( - NEXT_PUBLIC_CLOUD_ENABLED ? : - ); + if (productGating === ApplicationStatus.GATED_ACCESS) { + return getPageContent(); } - if (productGating === GatingType.FULL) { + if (!combinedSettings) { return getPageContent( -
-
- -
- -

- Access Restricted -

-

- We regret to inform you that your access to Onyx has been - temporarily suspended due to a lapse in your subscription. -

-

- To reinstate your access and continue benefiting from Onyx's - powerful features, please update your payment information. -

-

- If you're an admin, you can resolve this by visiting the - billing section. For other users, please reach out to your - administrator to address this matter. -

-
-
+ NEXT_PUBLIC_CLOUD_ENABLED ? : ); } diff --git a/web/src/components/admin/ClientLayout.tsx b/web/src/components/admin/ClientLayout.tsx index 63efa45dc74..9ccec914af5 100644 --- a/web/src/components/admin/ClientLayout.tsx +++ b/web/src/components/admin/ClientLayout.tsx @@ -33,6 +33,9 @@ import { MdOutlineCreditCard } from "react-icons/md"; import { UserSettingsModal } from "@/app/chat/modal/UserSettingsModal"; import { usePopup } from "./connectors/Popup"; import { useChatContext } from "../context/ChatContext"; +import { ApplicationStatus } from "@/app/admin/settings/interfaces"; +import Link from "next/link"; +import { Button } from "../ui/button"; export function ClientLayout({ user, @@ -74,6 +77,23 @@ export function ClientLayout({ defaultModel={user?.preferences?.default_model!} /> )} + {settings?.settings.application_status === + ApplicationStatus.PAYMENT_REMINDER && ( +
+ Warning: Your trial ends in + less than 2 days and no payment method has been added. +
+ + + +
+
+ )}
{ + const response = await fetch("/api/tenants/create-subscription-session", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + }); + if (!response.ok) { + throw new Error("Failed to create resubscription session"); + } + return response.json(); +}; + +export default function AccessRestricted() { + const [isLoading, setIsLoading] = useState(false); + const [error, setError] = useState(null); + const router = useRouter(); + + const handleManageSubscription = async () => { + setIsLoading(true); + setError(null); + try { + const response = await fetchCustomerPortal(); + + if (!response.ok) { + const errorData = await response.json(); + throw new Error( + `Failed to create customer portal session: ${ + errorData.message || response.statusText + }` + ); + } + + const { url } = await response.json(); + + if (!url) { + throw new Error("No portal URL returned from the server"); + } + + router.push(url); + } catch (error) { + console.error("Error creating customer portal session:", error); + setError("Error opening customer portal. Please try again later."); + } finally { + setIsLoading(false); + } + }; + + const handleResubscribe = async () => { + setIsLoading(true); + setError(null); + if (!NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY) { + setError("Stripe public key not found"); + setIsLoading(false); + return; + } + try { + const { sessionId } = await fetchResubscriptionSession(); + const stripe = await loadStripe(NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY); + + if (stripe) { + await stripe.redirectToCheckout({ sessionId }); + } else { + throw new Error("Stripe failed to load"); + } + } catch (error) { + console.error("Error creating resubscription session:", error); + setError("Error opening resubscription page. Please try again later."); + } finally { + setIsLoading(false); + } + }; + + return ( + +

+

Access Restricted

+ +

+
+

+ We regret to inform you that your access to Onyx has been temporarily + suspended due to a lapse in your subscription. +

+

+ To reinstate your access and continue benefiting from Onyx's + powerful features, please update your payment information. +

+

+ If you're an admin, you can manage your subscription by clicking + the button below. For other users, please reach out to your + administrator to address this matter. +

+
+ + + +
+ {error &&

{error}

} +

+ Need help? Join our{" "} + + Slack community + {" "} + for support. +

+
+
+ ); +} diff --git a/web/src/components/settings/lib.ts b/web/src/components/settings/lib.ts index e3024b2c27b..d8994fe759b 100644 --- a/web/src/components/settings/lib.ts +++ b/web/src/components/settings/lib.ts @@ -1,7 +1,7 @@ import { CombinedSettings, EnterpriseSettings, - GatingType, + ApplicationStatus, Settings, } from "@/app/admin/settings/interfaces"; import { @@ -45,7 +45,7 @@ export async function fetchSettingsSS(): Promise { if (results[0].status === 403 || results[0].status === 401) { settings = { auto_scroll: true, - product_gating: GatingType.NONE, + application_status: ApplicationStatus.ACTIVE, gpu_enabled: false, maximum_chat_retention_days: null, notifications: [], diff --git a/web/src/lib/constants.ts b/web/src/lib/constants.ts index 7b40f23b23e..37fff721d43 100644 --- a/web/src/lib/constants.ts +++ b/web/src/lib/constants.ts @@ -91,3 +91,6 @@ export const NEXT_PUBLIC_ENABLE_CHROME_EXTENSION = export const NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK = process.env.NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK?.toLowerCase() === "true"; + +export const NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY = + process.env.NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY;