Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/onyx-dot-app/onyx into feat…
Browse files Browse the repository at this point in the history
…ure/propagate-exceptions
  • Loading branch information
Richard Kuo (Danswer) committed Feb 14, 2025
2 parents e40a844 + b70db15 commit f84720b
Show file tree
Hide file tree
Showing 45 changed files with 916 additions and 353 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ jobs:
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
Expand Down
2 changes: 2 additions & 0 deletions backend/ee/onyx/configs/app_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,5 @@
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")

ANONYMOUS_USER_COOKIE_NAME = "onyx_anonymous_user"

GATED_TENANTS_KEY = "gated_tenants"
1 change: 1 addition & 0 deletions backend/ee/onyx/server/query_and_chat/query_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def handle_search_request(
user=user,
llm=llm,
fast_llm=fast_llm,
skip_query_analysis=False,
db_session=db_session,
bypass_acl=False,
)
Expand Down
59 changes: 34 additions & 25 deletions backend/ee/onyx/server/tenants/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,16 @@
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import AnonymousUserPath
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
Expand All @@ -39,12 +44,9 @@
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_tenant
from onyx.db.notification import create_notification
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.server.settings.store import load_settings
from onyx.server.settings.store import store_settings
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

Expand Down Expand Up @@ -126,37 +128,29 @@ async def login_as_anonymous_user(
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> None:
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when
1) User has ended free trial without adding payment method
2) User's card has declined
We gate the product when their subscription has ended.
"""
tenant_id = product_gating_request.tenant_id
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

settings = load_settings()
settings.product_gating = product_gating_request.product_gating
store_settings(settings)

if product_gating_request.notification:
with get_session_with_tenant(tenant_id) as db_session:
create_notification(None, product_gating_request.notification, db_session)
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)

if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))


@router.get("/billing-information", response_model=BillingInformation)
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation:
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
return BillingInformation(
**fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
)
return fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())


@router.post("/create-customer-portal-session")
Expand All @@ -169,9 +163,10 @@ async def create_customer_portal_session(_: User = Depends(current_admin_user))
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)

portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/cloud-settings",
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
Expand All @@ -180,6 +175,20 @@ async def create_customer_portal_session(_: User = Depends(current_admin_user))
raise HTTPException(status_code=500, detail=str(e))


@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)

except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))


@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,
Expand Down
18 changes: 16 additions & 2 deletions backend/ee/onyx/server/tenants/billing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import BillingInformation
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.utils.logger import setup_logger

Expand All @@ -14,6 +15,19 @@
logger = setup_logger()


def fetch_stripe_checkout_session(tenant_id: str) -> str:
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session"
params = {"tenant_id": tenant_id}
response = requests.post(url, headers=headers, params=params)
response.raise_for_status()
return response.json()["sessionId"]


def fetch_tenant_stripe_information(tenant_id: str) -> dict:
token = generate_data_plane_token()
headers = {
Expand All @@ -27,7 +41,7 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict:
return response.json()


def fetch_billing_information(tenant_id: str) -> dict:
def fetch_billing_information(tenant_id: str) -> BillingInformation:
logger.info("Fetching billing information")
token = generate_data_plane_token()
headers = {
Expand All @@ -38,7 +52,7 @@ def fetch_billing_information(tenant_id: str) -> dict:
params = {"tenant_id": tenant_id}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
billing_info = response.json()
billing_info = BillingInformation(**response.json())
return billing_info


Expand Down
33 changes: 26 additions & 7 deletions backend/ee/onyx/server/tenants/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from datetime import datetime

from pydantic import BaseModel

from onyx.configs.constants import NotificationType
from onyx.server.settings.models import GatingType
from onyx.server.settings.models import ApplicationStatus


class CheckoutSessionCreationRequest(BaseModel):
Expand All @@ -15,15 +16,24 @@ class CreateTenantRequest(BaseModel):

class ProductGatingRequest(BaseModel):
tenant_id: str
product_gating: GatingType
notification: NotificationType | None = None
application_status: ApplicationStatus


class SubscriptionStatusResponse(BaseModel):
subscribed: bool


class BillingInformation(BaseModel):
stripe_subscription_id: str
status: str
current_period_start: datetime
current_period_end: datetime
number_of_seats: int
cancel_at_period_end: bool
canceled_at: datetime | None
trial_start: datetime | None
trial_end: datetime | None
seats: int
subscription_status: str
billing_start: str
billing_end: str
payment_method_enabled: bool


Expand All @@ -48,3 +58,12 @@ class TenantDeletionPayload(BaseModel):

class AnonymousUserPath(BaseModel):
anonymous_user_path: str | None


class ProductGatingResponse(BaseModel):
updated: bool
error: str | None


class SubscriptionSessionResponse(BaseModel):
sessionId: str
51 changes: 51 additions & 0 deletions backend/ee/onyx/server/tenants/product_gating.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import cast

from ee.onyx.configs.app_configs import GATED_TENANTS_KEY
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.store import load_settings
from onyx.server.settings.store import store_settings
from onyx.setup import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()


def update_tenant_gating(tenant_id: str, status: ApplicationStatus) -> None:
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)

# Store the full status
status_key = f"tenant:{tenant_id}:status"
redis_client.set(status_key, status.value)

# Maintain the GATED_ACCESS set
if status == ApplicationStatus.GATED_ACCESS:
redis_client.sadd(GATED_TENANTS_KEY, tenant_id)
else:
redis_client.srem(GATED_TENANTS_KEY, tenant_id)


def store_product_gating(tenant_id: str, application_status: ApplicationStatus) -> None:
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

settings = load_settings()
settings.application_status = application_status
store_settings(settings)

# Store gated tenant information in Redis
update_tenant_gating(tenant_id, application_status)

if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

except Exception:
logger.exception("Failed to gate product")
raise


def get_gated_tenants() -> set[str]:
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
return cast(set[str], redis_client.smembers(GATED_TENANTS_KEY))
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
from onyx.configs.agent_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.agent_configs import AGENT_RERANKING_STATS
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import SearchRequest
from onyx.context.search.pipeline import retrieval_preprocessing
from onyx.context.search.models import RerankingDetails
from onyx.context.search.postprocessing.postprocessing import rerank_sections
from onyx.context.search.postprocessing.postprocessing import should_rerank
from onyx.db.engine import get_session_context_manager
from onyx.db.search_settings import get_current_search_settings


def rerank_documents(
Expand All @@ -39,6 +40,8 @@ def rerank_documents(

# Rerank post retrieval and verification. First, create a search query
# then create the list of reranked sections
# If no question defined/question is None in the state, use the original
# question from the search request as query

graph_config = cast(GraphConfig, config["metadata"]["config"])
question = (
Expand All @@ -47,39 +50,28 @@ def rerank_documents(
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
with get_session_context_manager() as db_session:
# we ignore some of the user specified fields since this search is
# internal to agentic search, but we still want to pass through
# persona (for stuff like document sets) and rerank settings
# (to not make an unnecessary db call).
search_request = SearchRequest(
query=question,
persona=graph_config.inputs.search_request.persona,
rerank_settings=graph_config.inputs.search_request.rerank_settings,
)
_search_query = retrieval_preprocessing(
search_request=search_request,
user=graph_config.tooling.search_tool.user, # bit of a hack
llm=graph_config.tooling.fast_llm,
db_session=db_session,
)

# skip section filtering
# Note that these are passed in values from the API and are overrides which are typically None
rerank_settings = graph_config.inputs.search_request.rerank_settings

if (
_search_query.rerank_settings
and _search_query.rerank_settings.rerank_model_name
and _search_query.rerank_settings.num_rerank > 0
and len(verified_documents) > 0
):
if rerank_settings is None:
with get_session_context_manager() as db_session:
search_settings = get_current_search_settings(db_session)
if not search_settings.disable_rerank_for_streaming:
rerank_settings = RerankingDetails.from_db_model(search_settings)

if should_rerank(rerank_settings) and len(verified_documents) > 0:
if len(verified_documents) > 1:
reranked_documents = rerank_sections(
_search_query,
verified_documents,
query_str=question,
# if runnable, then rerank_settings is not None
rerank_settings=cast(RerankingDetails, rerank_settings),
sections_to_rerank=verified_documents,
)
else:
num = "No" if len(verified_documents) == 0 else "One"
logger.warning(f"{num} verified document(s) found, skipping reranking")
logger.warning(
f"{len(verified_documents)} verified document(s) found, skipping reranking"
)
reranked_documents = verified_documents
else:
logger.warning("No reranking settings found, using unranked documents")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_context_manager
from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
Expand Down Expand Up @@ -67,9 +68,12 @@ def retrieve_documents(
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=query_to_retrieve,
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=not state.base_search,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
Expand Down
Loading

0 comments on commit f84720b

Please sign in to comment.