From f1fc8ac19be458314b406dba6be16c29f0e99fde Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Sat, 15 Feb 2025 18:34:39 -0800 Subject: [PATCH] Connector checkpointing (#3876) * wip checkpointing/continue on failure more stuff for checkpointing Basic implementation FE stuff More checkpointing/failure handling rebase rebase initial scaffolding for IT IT to test checkpointing Cleanup cleanup Fix it Rebase Add todo Fix actions IT Test more Pagination + fixes + cleanup Fix IT networking fix it * rebase * Address misc comments * Address comments * Remove unused router * rebase * Fix mypy * Fixes * fix it * Fix tests * Add drop index * Add retries * reset lock timeout * Try hard drop of schema * Add timeout/retries to downgrade * rebase * test * test * test * Close all connections * test closing idle only * Fix it * fix * try using null pool * Test * fix * rebase * log * Fix * apply null pool * Fix other test * Fix quality checks * Test not using the fixture * Fix ordering * fix test * Change pooling behavior --- .github/workflows/pr-integration-tests.yml | 35 +- .vscode/launch.template.jsonc | 2 +- ...aa15_add_checkpointing_failure_handling.py | 124 +++++ .../external_permissions/slack/doc_sync.py | 4 +- .../background/celery/tasks/beat_schedule.py | 9 + .../background/celery/tasks/indexing/tasks.py | 64 +++ .../onyx/background/indexing/checkpointing.py | 80 --- .../indexing/checkpointing_utils.py | 200 +++++++ .../onyx/background/indexing/memory_tracer.py | 87 +++ backend/onyx/background/indexing/models.py | 40 ++ .../onyx/background/indexing/run_indexing.py | 483 +++++++++------- backend/onyx/background/indexing/tracer.py | 77 --- backend/onyx/configs/app_configs.py | 5 + backend/onyx/configs/constants.py | 10 + backend/onyx/connectors/connector_runner.py | 146 ++++- backend/onyx/connectors/factory.py | 31 +- backend/onyx/connectors/interfaces.py | 34 ++ .../connectors/mock_connector/connector.py | 86 +++ backend/onyx/connectors/models.py | 67 ++- backend/onyx/connectors/slack/connector.py | 395 ++++++++++--- backend/onyx/connectors/slack/utils.py | 11 +- backend/onyx/db/engine.py | 54 +- backend/onyx/db/index_attempt.py | 98 +++- backend/onyx/db/models.py | 65 ++- backend/onyx/indexing/embedder.py | 50 ++ backend/onyx/indexing/indexing_pipeline.py | 97 ++-- backend/onyx/indexing/models.py | 7 + backend/onyx/indexing/vector_db_insertion.py | 99 ++++ backend/onyx/main.py | 2 - backend/onyx/server/documents/cc_pair.py | 44 ++ backend/onyx/server/documents/indexing.py | 23 - backend/onyx/server/documents/models.py | 31 +- backend/onyx/utils/object_size_check.py | 26 + backend/scripts/dev_run_background_jobs.py | 2 +- backend/supervisord.conf | 2 +- .../tests/integration/common_utils/chat.py | 11 +- .../integration/common_utils/constants.py | 5 + .../common_utils/managers/cc_pair.py | 3 +- .../common_utils/managers/document.py | 41 ++ .../common_utils/managers/index_attempt.py | 130 ++++- .../tests/integration/common_utils/reset.py | 75 ++- .../common_utils/test_document_utils.py | 57 ++ .../tests/integration/common_utils/timeout.py | 18 + backend/tests/integration/conftest.py | 56 +- .../test_google_drive_permission_sync.py | 8 +- .../slack/test_permission_sync.py | 8 +- .../connector_job_tests/slack/test_prune.py | 4 +- .../docker-compose.mock-it-services.yml | 20 + .../mock_connector_server/Dockerfile | 9 + .../mock_connector_server/main.py | 76 +++ .../connector/test_connector_deletion.py | 28 +- .../tests/indexing/test_checkpointing.py | 518 
++++++++++++++++++ .../docker_compose/docker-compose.dev.yml | 2 + .../[ccPairId]/IndexAttemptErrorsModal.tsx | 141 +++++ .../[ccPairId]/IndexingAttemptsTable.tsx | 62 +-- .../connector/[ccPairId]/ReIndexButton.tsx | 40 +- web/src/app/admin/connector/[ccPairId]/lib.ts | 29 + .../app/admin/connector/[ccPairId]/page.tsx | 161 +++++- .../app/admin/connector/[ccPairId]/types.ts | 24 + .../indexing/[id]/IndexAttemptErrorsTable.tsx | 189 ------- web/src/app/admin/indexing/[id]/lib.ts | 3 - web/src/app/admin/indexing/[id]/page.tsx | 59 -- web/src/app/admin/indexing/[id]/types.ts | 15 - web/src/components/Status.tsx | 16 +- web/src/hooks/usePaginatedFetch.tsx | 11 +- web/src/lib/connectors/connectors.tsx | 2 + web/src/lib/sources.ts | 7 + web/src/lib/types.ts | 1 + 68 files changed, 3325 insertions(+), 1094 deletions(-) create mode 100644 backend/alembic/versions/b7a7eee5aa15_add_checkpointing_failure_handling.py delete mode 100644 backend/onyx/background/indexing/checkpointing.py create mode 100644 backend/onyx/background/indexing/checkpointing_utils.py create mode 100644 backend/onyx/background/indexing/memory_tracer.py create mode 100644 backend/onyx/background/indexing/models.py delete mode 100644 backend/onyx/background/indexing/tracer.py create mode 100644 backend/onyx/connectors/mock_connector/connector.py create mode 100644 backend/onyx/indexing/vector_db_insertion.py delete mode 100644 backend/onyx/server/documents/indexing.py create mode 100644 backend/onyx/utils/object_size_check.py create mode 100644 backend/tests/integration/common_utils/test_document_utils.py create mode 100644 backend/tests/integration/common_utils/timeout.py create mode 100644 backend/tests/integration/mock_services/docker-compose.mock-it-services.yml create mode 100644 backend/tests/integration/mock_services/mock_connector_server/Dockerfile create mode 100644 backend/tests/integration/mock_services/mock_connector_server/main.py create mode 100644 backend/tests/integration/tests/indexing/test_checkpointing.py create mode 100644 web/src/app/admin/connector/[ccPairId]/IndexAttemptErrorsModal.tsx delete mode 100644 web/src/app/admin/indexing/[id]/IndexAttemptErrorsTable.tsx delete mode 100644 web/src/app/admin/indexing/[id]/lib.ts delete mode 100644 web/src/app/admin/indexing/[id]/page.tsx delete mode 100644 web/src/app/admin/indexing/[id]/types.ts diff --git a/.github/workflows/pr-integration-tests.yml b/.github/workflows/pr-integration-tests.yml index 4d4ef5d3883..5573c51d9f2 100644 --- a/.github/workflows/pr-integration-tests.yml +++ b/.github/workflows/pr-integration-tests.yml @@ -99,7 +99,7 @@ jobs: DISABLE_TELEMETRY=true \ IMAGE_TAG=test \ DEV_MODE=true \ - docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack up -d + docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack up -d id: start_docker_multi_tenant # In practice, `cloud` Auth type would require OAUTH credentials to be set. @@ -108,12 +108,13 @@ jobs: echo "Waiting for 3 minutes to ensure API server is ready..." sleep 180 echo "Running integration tests..." 
- docker run --rm --network danswer-stack_default \ + docker run --rm --network onyx-stack_default \ --name test-runner \ -e POSTGRES_HOST=relational_db \ -e POSTGRES_USER=postgres \ -e POSTGRES_PASSWORD=password \ -e POSTGRES_DB=postgres \ + -e POSTGRES_USE_NULL_POOL=true \ -e VESPA_HOST=index \ -e REDIS_HOST=cache \ -e API_SERVER_HOST=api_server \ @@ -143,24 +144,27 @@ jobs: - name: Stop multi-tenant Docker containers run: | cd deployment/docker_compose - docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack down -v - + docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v + + # NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections - name: Start Docker containers run: | cd deployment/docker_compose ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \ AUTH_TYPE=basic \ + POSTGRES_POOL_PRE_PING=true \ + POSTGRES_USE_NULL_POOL=true \ REQUIRE_EMAIL_VERIFICATION=false \ DISABLE_TELEMETRY=true \ IMAGE_TAG=test \ - docker compose -f docker-compose.dev.yml -p danswer-stack up -d + docker compose -f docker-compose.dev.yml -p onyx-stack up -d id: start_docker - name: Wait for service to be ready run: | echo "Starting wait-for-service script..." - docker logs -f danswer-stack-api_server-1 & + docker logs -f onyx-stack-api_server-1 & start_time=$(date +%s) timeout=300 # 5 minutes in seconds @@ -190,15 +194,24 @@ jobs: done echo "Finished waiting for service." + - name: Start Mock Services + run: | + cd backend/tests/integration/mock_services + docker compose -f docker-compose.mock-it-services.yml \ + -p mock-it-services-stack up -d + + # NOTE: Use pre-ping/null to reduce flakiness due to dropped connections - name: Run Standard Integration Tests run: | echo "Running integration tests..." - docker run --rm --network danswer-stack_default \ + docker run --rm --network onyx-stack_default \ --name test-runner \ -e POSTGRES_HOST=relational_db \ -e POSTGRES_USER=postgres \ -e POSTGRES_PASSWORD=password \ -e POSTGRES_DB=postgres \ + -e POSTGRES_POOL_PRE_PING=true \ + -e POSTGRES_USE_NULL_POOL=true \ -e VESPA_HOST=index \ -e REDIS_HOST=cache \ -e API_SERVER_HOST=api_server \ @@ -208,6 +221,8 @@ jobs: -e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \ -e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \ -e TEST_WEB_HOSTNAME=test-runner \ + -e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \ + -e MOCK_CONNECTOR_SERVER_PORT=8001 \ onyxdotapp/onyx-integration:test \ /app/tests/integration/tests \ /app/tests/integration/connector_job_tests @@ -229,13 +244,13 @@ jobs: if: always() run: | cd deployment/docker_compose - docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true + docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true - name: Dump all-container logs (optional) if: always() run: | cd deployment/docker_compose - docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true + docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true - name: Upload logs if: always() @@ -249,4 +264,4 @@ jobs: if: always() run: | cd deployment/docker_compose - docker compose -f docker-compose.dev.yml -p danswer-stack down -v + docker compose -f docker-compose.dev.yml -p onyx-stack down -v diff --git a/.vscode/launch.template.jsonc b/.vscode/launch.template.jsonc index 8c965d36e80..f1454ca8e9c 100644 --- 
a/.vscode/launch.template.jsonc +++ b/.vscode/launch.template.jsonc @@ -205,7 +205,7 @@ "--loglevel=INFO", "--hostname=light@%n", "-Q", - "vespa_metadata_sync,connector_deletion,doc_permissions_upsert", + "vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup", ], "presentation": { "group": "2", diff --git a/backend/alembic/versions/b7a7eee5aa15_add_checkpointing_failure_handling.py b/backend/alembic/versions/b7a7eee5aa15_add_checkpointing_failure_handling.py new file mode 100644 index 00000000000..82205548621 --- /dev/null +++ b/backend/alembic/versions/b7a7eee5aa15_add_checkpointing_failure_handling.py @@ -0,0 +1,124 @@ +"""Add checkpointing/failure handling + +Revision ID: b7a7eee5aa15 +Revises: f39c5794c10a +Create Date: 2025-01-24 15:17:36.763172 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "b7a7eee5aa15" +down_revision = "f39c5794c10a" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "index_attempt", + sa.Column("checkpoint_pointer", sa.String(), nullable=True), + ) + op.add_column( + "index_attempt", + sa.Column("poll_range_start", sa.DateTime(timezone=True), nullable=True), + ) + op.add_column( + "index_attempt", + sa.Column("poll_range_end", sa.DateTime(timezone=True), nullable=True), + ) + + op.create_index( + "ix_index_attempt_cc_pair_settings_poll", + "index_attempt", + [ + "connector_credential_pair_id", + "search_settings_id", + "status", + sa.text("time_updated DESC"), + ], + ) + + # Drop the old IndexAttemptError table + op.drop_index("index_attempt_id", table_name="index_attempt_errors") + op.drop_table("index_attempt_errors") + + # Create the new version of the table + op.create_table( + "index_attempt_errors", + sa.Column("id", sa.Integer(), primary_key=True), + sa.Column("index_attempt_id", sa.Integer(), nullable=False), + sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False), + sa.Column("document_id", sa.String(), nullable=True), + sa.Column("document_link", sa.String(), nullable=True), + sa.Column("entity_id", sa.String(), nullable=True), + sa.Column("failed_time_range_start", sa.DateTime(timezone=True), nullable=True), + sa.Column("failed_time_range_end", sa.DateTime(timezone=True), nullable=True), + sa.Column("failure_message", sa.Text(), nullable=False), + sa.Column("is_resolved", sa.Boolean(), nullable=False, default=False), + sa.Column( + "time_created", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["index_attempt_id"], + ["index_attempt.id"], + ), + sa.ForeignKeyConstraint( + ["connector_credential_pair_id"], + ["connector_credential_pair.id"], + ), + ) + + +def downgrade() -> None: + op.execute("SET lock_timeout = '5s'") + + # try a few times to drop the table, this has been observed to fail due to other locks + # blocking the drop + NUM_TRIES = 10 + for i in range(NUM_TRIES): + try: + op.drop_table("index_attempt_errors") + break + except Exception as e: + if i == NUM_TRIES - 1: + raise e + print(f"Error dropping table: {e}. 
Retrying...") + + op.execute("SET lock_timeout = DEFAULT") + + # Recreate the old IndexAttemptError table + op.create_table( + "index_attempt_errors", + sa.Column("id", sa.Integer(), primary_key=True), + sa.Column("index_attempt_id", sa.Integer(), nullable=True), + sa.Column("batch", sa.Integer(), nullable=True), + sa.Column("doc_summaries", postgresql.JSONB(), nullable=False), + sa.Column("error_msg", sa.Text(), nullable=True), + sa.Column("traceback", sa.Text(), nullable=True), + sa.Column( + "time_created", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + ), + sa.ForeignKeyConstraint( + ["index_attempt_id"], + ["index_attempt.id"], + ), + ) + + op.create_index( + "index_attempt_id", + "index_attempt_errors", + ["time_created"], + ) + + op.drop_index("ix_index_attempt_cc_pair_settings_poll") + op.drop_column("index_attempt", "checkpoint_pointer") + op.drop_column("index_attempt", "poll_range_start") + op.drop_column("index_attempt", "poll_range_end") diff --git a/backend/ee/onyx/external_permissions/slack/doc_sync.py b/backend/ee/onyx/external_permissions/slack/doc_sync.py index 07623bae200..9522c906d2e 100644 --- a/backend/ee/onyx/external_permissions/slack/doc_sync.py +++ b/backend/ee/onyx/external_permissions/slack/doc_sync.py @@ -5,7 +5,7 @@ from onyx.access.models import ExternalAccess from onyx.connectors.slack.connector import get_channels from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries -from onyx.connectors.slack.connector import SlackPollConnector +from onyx.connectors.slack.connector import SlackConnector from onyx.db.models import ConnectorCredentialPair from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.logger import setup_logger @@ -17,7 +17,7 @@ def _get_slack_document_ids_and_channels( cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None ) -> dict[str, list[str]]: - slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config) + slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config) slack_connector.load_credentials(cc_pair.credential.credential_json) slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback) diff --git a/backend/onyx/background/celery/tasks/beat_schedule.py b/backend/onyx/background/celery/tasks/beat_schedule.py index 01ca723810f..689e9436be8 100644 --- a/backend/onyx/background/celery/tasks/beat_schedule.py +++ b/backend/onyx/background/celery/tasks/beat_schedule.py @@ -36,6 +36,15 @@ "expires": BEAT_EXPIRES_DEFAULT, }, }, + { + "name": "check-for-checkpoint-cleanup", + "task": OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP, + "schedule": timedelta(hours=1), + "options": { + "priority": OnyxCeleryPriority.LOW, + "expires": BEAT_EXPIRES_DEFAULT, + }, + }, { "name": "check-for-connector-deletion", "task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION, diff --git a/backend/onyx/background/celery/tasks/indexing/tasks.py b/backend/onyx/background/celery/tasks/indexing/tasks.py index 3f193d2693c..2b9c1422ec1 100644 --- a/backend/onyx/background/celery/tasks/indexing/tasks.py +++ b/backend/onyx/background/celery/tasks/indexing/tasks.py @@ -28,6 +28,10 @@ from onyx.background.celery.tasks.indexing.utils import IndexingCallback from onyx.background.celery.tasks.indexing.utils import try_creating_indexing_task from onyx.background.celery.tasks.indexing.utils import validate_indexing_fences +from onyx.background.indexing.checkpointing_utils import cleanup_checkpoint +from 
onyx.background.indexing.checkpointing_utils import ( + get_index_attempts_with_old_checkpoints, +) from onyx.background.indexing.job_client import SimpleJob from onyx.background.indexing.job_client import SimpleJobClient from onyx.background.indexing.job_client import SimpleJobException @@ -38,6 +42,7 @@ from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT +from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxRedisConstants from onyx.configs.constants import OnyxRedisLocks @@ -1069,3 +1074,62 @@ def connector_indexing_proxy_task( redis_connector_index.set_watchdog(False) return + + +@shared_task( + name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP, + soft_time_limit=300, +) +def check_for_checkpoint_cleanup(*, tenant_id: str | None) -> None: + """Clean up old checkpoints that are older than 7 days.""" + locked = False + redis_client = get_redis_client(tenant_id=tenant_id) + lock: RedisLock = redis_client.lock( + OnyxRedisLocks.CHECK_CHECKPOINT_CLEANUP_BEAT_LOCK, + timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT, + ) + + # these tasks should never overlap + if not lock.acquire(blocking=False): + return None + + try: + locked = True + with get_session_with_tenant(tenant_id=tenant_id) as db_session: + old_attempts = get_index_attempts_with_old_checkpoints(db_session) + for attempt in old_attempts: + task_logger.info( + f"Cleaning up checkpoint for index attempt {attempt.id}" + ) + cleanup_checkpoint_task.apply_async( + kwargs={ + "index_attempt_id": attempt.id, + "tenant_id": tenant_id, + }, + queue=OnyxCeleryQueues.CHECKPOINT_CLEANUP, + ) + + except Exception: + task_logger.exception("Unexpected exception during checkpoint cleanup") + return None + finally: + if locked: + if lock.owned(): + lock.release() + else: + task_logger.error( + "check_for_checkpoint_cleanup - Lock not owned on completion: " + f"tenant={tenant_id}" + ) + + +@shared_task( + name=OnyxCeleryTask.CLEANUP_CHECKPOINT, + bind=True, +) +def cleanup_checkpoint_task( + self: Task, *, index_attempt_id: int, tenant_id: str | None +) -> None: + """Clean up a checkpoint for a given index attempt""" + with get_session_with_tenant(tenant_id=tenant_id) as db_session: + cleanup_checkpoint(db_session, index_attempt_id) diff --git a/backend/onyx/background/indexing/checkpointing.py b/backend/onyx/background/indexing/checkpointing.py deleted file mode 100644 index a7376aab01b..00000000000 --- a/backend/onyx/background/indexing/checkpointing.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Experimental functionality related to splitting up indexing -into a series of checkpoints to better handle intermittent failures -/ jobs being killed by cloud providers.""" -import datetime - -from onyx.configs.app_configs import EXPERIMENTAL_CHECKPOINTING_ENABLED -from onyx.configs.constants import DocumentSource -from onyx.connectors.cross_connector_utils.miscellaneous_utils import datetime_to_utc - - -def _2010_dt() -> datetime.datetime: - return datetime.datetime(year=2010, month=1, day=1, tzinfo=datetime.timezone.utc) - - -def _2020_dt() -> datetime.datetime: - return datetime.datetime(year=2020, month=1, day=1, tzinfo=datetime.timezone.utc) - - -def _default_end_time( - last_successful_run: datetime.datetime | None, -) -> datetime.datetime: - """If year is before 2010, go to the beginning of 2010. 
- If year is 2010-2020, go in 5 year increments. - If year > 2020, then go in 180 day increments. - - For connectors that don't support a `filter_by` and instead rely on `sort_by` - for polling, then this will cause a massive duplication of fetches. For these - connectors, you may want to override this function to return a more reasonable - plan (e.g. extending the 2020+ windows to 6 months, 1 year, or higher).""" - last_successful_run = ( - datetime_to_utc(last_successful_run) if last_successful_run else None - ) - if last_successful_run is None or last_successful_run < _2010_dt(): - return _2010_dt() - - if last_successful_run < _2020_dt(): - return min(last_successful_run + datetime.timedelta(days=365 * 5), _2020_dt()) - - return last_successful_run + datetime.timedelta(days=180) - - -def find_end_time_for_indexing_attempt( - last_successful_run: datetime.datetime | None, - # source_type can be used to override the default for certain connectors, currently unused - source_type: DocumentSource, -) -> datetime.datetime | None: - """Is the current time unless the connector is run over a large period, in which case it is - split up into large time segments that become smaller as it approaches the present - """ - # NOTE: source_type can be used to override the default for certain connectors - end_of_window = _default_end_time(last_successful_run) - now = datetime.datetime.now(tz=datetime.timezone.utc) - if end_of_window < now: - return end_of_window - - # None signals that we should index up to current time - return None - - -def get_time_windows_for_index_attempt( - last_successful_run: datetime.datetime, source_type: DocumentSource -) -> list[tuple[datetime.datetime, datetime.datetime]]: - if not EXPERIMENTAL_CHECKPOINTING_ENABLED: - return [(last_successful_run, datetime.datetime.now(tz=datetime.timezone.utc))] - - time_windows: list[tuple[datetime.datetime, datetime.datetime]] = [] - start_of_window: datetime.datetime | None = last_successful_run - while start_of_window: - end_of_window = find_end_time_for_indexing_attempt( - last_successful_run=start_of_window, source_type=source_type - ) - time_windows.append( - ( - start_of_window, - end_of_window or datetime.datetime.now(tz=datetime.timezone.utc), - ) - ) - start_of_window = end_of_window - - return time_windows diff --git a/backend/onyx/background/indexing/checkpointing_utils.py b/backend/onyx/background/indexing/checkpointing_utils.py new file mode 100644 index 00000000000..254481e14af --- /dev/null +++ b/backend/onyx/background/indexing/checkpointing_utils.py @@ -0,0 +1,200 @@ +from datetime import datetime +from datetime import timedelta +from io import BytesIO + +from sqlalchemy import and_ +from sqlalchemy.orm import Session + +from onyx.configs.constants import FileOrigin +from onyx.connectors.models import ConnectorCheckpoint +from onyx.db.engine import get_db_current_time +from onyx.db.index_attempt import get_index_attempt +from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair +from onyx.db.models import IndexAttempt +from onyx.db.models import IndexingStatus +from onyx.file_store.file_store import get_default_file_store +from onyx.utils.logger import setup_logger +from onyx.utils.object_size_check import deep_getsizeof + + +logger = setup_logger() + +_NUM_RECENT_ATTEMPTS_TO_CONSIDER = 20 +_NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT = 100 + + +def _build_checkpoint_pointer(index_attempt_id: int) -> str: + return f"checkpoint_{index_attempt_id}.json" + + +def save_checkpoint( + db_session: Session, 
index_attempt_id: int, checkpoint: ConnectorCheckpoint +) -> str: + """Save a checkpoint for a given index attempt to the file store""" + checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id) + + file_store = get_default_file_store(db_session) + file_store.save_file( + file_name=checkpoint_pointer, + content=BytesIO(checkpoint.model_dump_json().encode()), + display_name=checkpoint_pointer, + file_origin=FileOrigin.INDEXING_CHECKPOINT, + file_type="application/json", + ) + + index_attempt = get_index_attempt(db_session, index_attempt_id) + if not index_attempt: + raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.") + index_attempt.checkpoint_pointer = checkpoint_pointer + db_session.add(index_attempt) + db_session.commit() + return checkpoint_pointer + + +def load_checkpoint( + db_session: Session, index_attempt_id: int +) -> ConnectorCheckpoint | None: + """Load a checkpoint for a given index attempt from the file store""" + checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id) + file_store = get_default_file_store(db_session) + try: + checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb") + checkpoint_data = checkpoint_io.read().decode("utf-8") + return ConnectorCheckpoint.model_validate_json(checkpoint_data) + except RuntimeError: + return None + + +def get_latest_valid_checkpoint( + db_session: Session, + cc_pair_id: int, + search_settings_id: int, + window_start: datetime, + window_end: datetime, +) -> ConnectorCheckpoint: + """Get the latest valid checkpoint for a given connector credential pair""" + checkpoint_candidates = get_recent_completed_attempts_for_cc_pair( + cc_pair_id=cc_pair_id, + search_settings_id=search_settings_id, + db_session=db_session, + limit=_NUM_RECENT_ATTEMPTS_TO_CONSIDER, + ) + checkpoint_candidates = [ + candidate + for candidate in checkpoint_candidates + if ( + candidate.poll_range_start == window_start + and candidate.poll_range_end == window_end + and candidate.status == IndexingStatus.FAILED + and candidate.checkpoint_pointer is not None + # we want to make sure that the checkpoint is actually useful + # if it's only gone through a few docs, it's probably not worth + # using. This also avoids weird cases where a connector is basically + # non-functional but still "makes progress" by slowly moving the + # checkpoint forward run after run + and candidate.total_docs_indexed + and candidate.total_docs_indexed > _NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT + ) + ] + + # don't keep using checkpoints if we've had a bunch of failed attempts in a row + # for now, capped at 10 + if len(checkpoint_candidates) == _NUM_RECENT_ATTEMPTS_TO_CONSIDER: + logger.warning( + f"{_NUM_RECENT_ATTEMPTS_TO_CONSIDER} consecutive failed attempts found " + f"for cc_pair={cc_pair_id}. Ignoring checkpoint to let the run start " + "from scratch." + ) + return ConnectorCheckpoint.build_dummy_checkpoint() + + # assumes latest checkpoint is the furthest along. This only isn't true + # if something else has gone wrong. + latest_valid_checkpoint_candidate = ( + checkpoint_candidates[0] if checkpoint_candidates else None + ) + + checkpoint = ConnectorCheckpoint.build_dummy_checkpoint() + if latest_valid_checkpoint_candidate: + try: + previous_checkpoint = load_checkpoint( + db_session=db_session, + index_attempt_id=latest_valid_checkpoint_candidate.id, + ) + except Exception: + logger.exception( + f"Failed to load checkpoint from previous failed attempt with ID " + f"{latest_valid_checkpoint_candidate.id}." 
+ ) + previous_checkpoint = None + + if previous_checkpoint is not None: + logger.info( + f"Using checkpoint from previous failed attempt with ID " + f"{latest_valid_checkpoint_candidate.id}. Previous checkpoint: " + f"{previous_checkpoint}" + ) + save_checkpoint( + db_session=db_session, + index_attempt_id=latest_valid_checkpoint_candidate.id, + checkpoint=previous_checkpoint, + ) + checkpoint = previous_checkpoint + + return checkpoint + + +def get_index_attempts_with_old_checkpoints( + db_session: Session, days_to_keep: int = 7 +) -> list[IndexAttempt]: + """Get all index attempts with checkpoints older than the specified number of days. + + Args: + db_session: The database session + days_to_keep: Number of days to keep checkpoints for (default: 7) + + Returns: + Number of checkpoints deleted + """ + cutoff_date = get_db_current_time(db_session) - timedelta(days=days_to_keep) + + # Find all index attempts with checkpoints older than cutoff_date + old_attempts = ( + db_session.query(IndexAttempt) + .filter( + and_( + IndexAttempt.checkpoint_pointer.isnot(None), + IndexAttempt.time_created < cutoff_date, + ) + ) + .all() + ) + + return old_attempts + + +def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None: + """Clean up a checkpoint for a given index attempt""" + index_attempt = get_index_attempt(db_session, index_attempt_id) + if not index_attempt: + raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.") + + if not index_attempt.checkpoint_pointer: + return None + + file_store = get_default_file_store(db_session) + file_store.delete_file(index_attempt.checkpoint_pointer) + + index_attempt.checkpoint_pointer = None + db_session.add(index_attempt) + db_session.commit() + + return None + + +def check_checkpoint_size(checkpoint: ConnectorCheckpoint) -> None: + """Check if the checkpoint content size exceeds the limit (200MB)""" + content_size = deep_getsizeof(checkpoint.checkpoint_content) + if content_size > 200_000_000: # 200MB in bytes + raise ValueError( + f"Checkpoint content size ({content_size} bytes) exceeds 200MB limit" + ) diff --git a/backend/onyx/background/indexing/memory_tracer.py b/backend/onyx/background/indexing/memory_tracer.py new file mode 100644 index 00000000000..a23af2e9226 --- /dev/null +++ b/backend/onyx/background/indexing/memory_tracer.py @@ -0,0 +1,87 @@ +import tracemalloc + +from onyx.utils.logger import setup_logger + +logger = setup_logger() + +DANSWER_TRACEMALLOC_FRAMES = 10 + + +class MemoryTracer: + def __init__(self, interval: int = 0, num_print_entries: int = 5): + self.interval = interval + self.num_print_entries = num_print_entries + self.snapshot_first: tracemalloc.Snapshot | None = None + self.snapshot_prev: tracemalloc.Snapshot | None = None + self.snapshot: tracemalloc.Snapshot | None = None + self.counter = 0 + + def start(self) -> None: + """Start the memory tracer if interval is greater than 0.""" + if self.interval > 0: + logger.debug(f"Memory tracer starting: interval={self.interval}") + tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES) + self._take_snapshot() + + def stop(self) -> None: + """Stop the memory tracer if it's running.""" + if self.interval > 0: + self.log_final_diff() + tracemalloc.stop() + logger.debug("Memory tracer stopped.") + + def _take_snapshot(self) -> None: + """Take a snapshot and update internal snapshot states.""" + snapshot = tracemalloc.take_snapshot() + # Filter out irrelevant frames + snapshot = snapshot.filter_traces( + ( + tracemalloc.Filter(False, tracemalloc.__file__), + 
tracemalloc.Filter(False, ""), + tracemalloc.Filter(False, ""), + ) + ) + + if not self.snapshot_first: + self.snapshot_first = snapshot + + if self.snapshot: + self.snapshot_prev = self.snapshot + + self.snapshot = snapshot + + def _log_diff( + self, current: tracemalloc.Snapshot, previous: tracemalloc.Snapshot + ) -> None: + """Log the memory difference between two snapshots.""" + stats = current.compare_to(previous, "traceback") + for s in stats[: self.num_print_entries]: + logger.debug(f"Tracer diff: {s}") + for line in s.traceback.format(): + logger.debug(f"* {line}") + + def increment_and_maybe_trace(self) -> None: + """Increment counter and perform trace if interval is hit.""" + if self.interval <= 0: + return + + self.counter += 1 + if self.counter % self.interval == 0: + logger.debug( + f"Running trace comparison for batch {self.counter}. interval={self.interval}" + ) + self._take_snapshot() + if self.snapshot and self.snapshot_prev: + self._log_diff(self.snapshot, self.snapshot_prev) + + def log_final_diff(self) -> None: + """Log the final memory diff between start and end of indexing.""" + if self.interval <= 0: + return + + logger.debug( + f"Running trace comparison between start and end of indexing. {self.counter} batches processed." + ) + self._take_snapshot() + if self.snapshot and self.snapshot_first: + self._log_diff(self.snapshot, self.snapshot_first) diff --git a/backend/onyx/background/indexing/models.py b/backend/onyx/background/indexing/models.py new file mode 100644 index 00000000000..6bfd364ccf1 --- /dev/null +++ b/backend/onyx/background/indexing/models.py @@ -0,0 +1,40 @@ +from datetime import datetime + +from pydantic import BaseModel + +from onyx.db.models import IndexAttemptError + + +class IndexAttemptErrorPydantic(BaseModel): + id: int + connector_credential_pair_id: int + + document_id: str | None + document_link: str | None + + entity_id: str | None + failed_time_range_start: datetime | None + failed_time_range_end: datetime | None + + failure_message: str + is_resolved: bool = False + + time_created: datetime + + index_attempt_id: int + + @classmethod + def from_model(cls, model: IndexAttemptError) -> "IndexAttemptErrorPydantic": + return cls( + id=model.id, + connector_credential_pair_id=model.connector_credential_pair_id, + document_id=model.document_id, + document_link=model.document_link, + entity_id=model.entity_id, + failed_time_range_start=model.failed_time_range_start, + failed_time_range_end=model.failed_time_range_end, + failure_message=model.failure_message, + is_resolved=model.is_resolved, + time_created=model.time_created, + index_attempt_id=model.index_attempt_id, + ) diff --git a/backend/onyx/background/indexing/run_indexing.py b/backend/onyx/background/indexing/run_indexing.py index 88c121d45ed..53bea42ecff 100644 --- a/backend/onyx/background/indexing/run_indexing.py +++ b/backend/onyx/background/indexing/run_indexing.py @@ -1,5 +1,6 @@ import time import traceback +from collections import defaultdict from datetime import datetime from datetime import timedelta from datetime import timezone @@ -7,8 +8,11 @@ from pydantic import BaseModel from sqlalchemy.orm import Session -from onyx.background.indexing.checkpointing import get_time_windows_for_index_attempt -from onyx.background.indexing.tracer import OnyxTracer +from onyx.background.indexing.checkpointing_utils import check_checkpoint_size +from onyx.background.indexing.checkpointing_utils import get_latest_valid_checkpoint +from onyx.background.indexing.checkpointing_utils import 
save_checkpoint +from onyx.background.indexing.memory_tracer import MemoryTracer +from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE @@ -17,6 +21,8 @@ from onyx.configs.constants import MilestoneRecordType from onyx.connectors.connector_runner import ConnectorRunner from onyx.connectors.factory import instantiate_connector +from onyx.connectors.models import ConnectorCheckpoint +from onyx.connectors.models import ConnectorFailure from onyx.connectors.models import Document from onyx.connectors.models import IndexAttemptMetadata from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id @@ -24,15 +30,18 @@ from onyx.db.connector_credential_pair import update_connector_credential_pair from onyx.db.engine import get_session_with_tenant from onyx.db.enums import ConnectorCredentialPairStatus +from onyx.db.index_attempt import create_index_attempt_error from onyx.db.index_attempt import get_index_attempt +from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair +from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair from onyx.db.index_attempt import mark_attempt_canceled from onyx.db.index_attempt import mark_attempt_failed from onyx.db.index_attempt import mark_attempt_partially_succeeded from onyx.db.index_attempt import mark_attempt_succeeded from onyx.db.index_attempt import transition_attempt_to_in_progress from onyx.db.index_attempt import update_docs_indexed -from onyx.db.models import ConnectorCredentialPair from onyx.db.models import IndexAttempt +from onyx.db.models import IndexAttemptError from onyx.db.models import IndexingStatus from onyx.db.models import IndexModelStatus from onyx.document_index.factory import get_default_document_index @@ -53,6 +62,7 @@ def _get_connector_runner( db_session: Session, attempt: IndexAttempt, + batch_size: int, start_time: datetime, end_time: datetime, tenant_id: str | None, @@ -100,7 +110,9 @@ def _get_connector_runner( raise e return ConnectorRunner( - connector=runnable_connector, time_range=(start_time, end_time) + connector=runnable_connector, + batch_size=batch_size, + time_range=(start_time, end_time), ) @@ -159,6 +171,66 @@ class RunIndexingContext(BaseModel): search_settings_status: IndexModelStatus +def _check_connector_and_attempt_status( + db_session_temp: Session, ctx: RunIndexingContext, index_attempt_id: int +) -> None: + """ + Checks the status of the connector credential pair and index attempt. + Raises a RuntimeError if any conditions are not met. 
+ """ + cc_pair_loop = get_connector_credential_pair_from_id( + db_session_temp, + ctx.cc_pair_id, + ) + if not cc_pair_loop: + raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.") + + if ( + cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED + and ctx.search_settings_status != IndexModelStatus.FUTURE + ) or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING: + raise RuntimeError("Connector was disabled mid run") + + index_attempt_loop = get_index_attempt(db_session_temp, index_attempt_id) + if not index_attempt_loop: + raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.") + + if index_attempt_loop.status != IndexingStatus.IN_PROGRESS: + raise RuntimeError( + f"Index Attempt was canceled, status is {index_attempt_loop.status}" + ) + + +def _check_failure_threshold( + total_failures: int, + document_count: int, + batch_num: int, + last_failure: ConnectorFailure | None, +) -> None: + """Check if we've hit the failure threshold and raise an appropriate exception if so. + + We consider the threshold hit if: + 1. We have more than 3 failures AND + 2. Failures account for more than 10% of processed documents + """ + failure_ratio = total_failures / (document_count or 1) + + FAILURE_THRESHOLD = 3 + FAILURE_RATIO_THRESHOLD = 0.1 + if total_failures > FAILURE_THRESHOLD and failure_ratio > FAILURE_RATIO_THRESHOLD: + logger.error( + f"Connector run failed with '{total_failures}' errors " + f"after '{batch_num}' batches." + ) + if last_failure and last_failure.exception: + raise last_failure.exception from last_failure.exception + + raise RuntimeError( + f"Connector run encountered too many errors, aborting. " + f"Last error: {last_failure}" + ) + + def _run_indexing( db_session: Session, index_attempt_id: int, @@ -169,11 +241,8 @@ def _run_indexing( 1. Get documents which are either new or updated from specified application 2. Embed and index these documents into the chosen datastore (vespa) 3. Updates Postgres to record the indexed documents + the outcome of this run - - TODO: do not change index attempt statuses here ... instead, set signals in redis - and allow the monitor function to clean them up """ - start_time = time.time() + start_time = time.monotonic() # jsut used for logging with get_session_with_tenant(tenant_id) as db_session_temp: index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id) @@ -221,6 +290,46 @@ def _run_indexing( db_session=db_session_temp, ) ) + if last_successful_index_time > POLL_CONNECTOR_OFFSET: + window_start = datetime.fromtimestamp( + last_successful_index_time, tz=timezone.utc + ) - timedelta(minutes=POLL_CONNECTOR_OFFSET) + else: + # don't go into "negative" time if we've never indexed before + window_start = datetime.fromtimestamp(0, tz=timezone.utc) + + most_recent_attempt = next( + iter( + get_recent_completed_attempts_for_cc_pair( + cc_pair_id=ctx.cc_pair_id, + search_settings_id=index_attempt_start.search_settings_id, + db_session=db_session_temp, + limit=1, + ) + ), + None, + ) + # if the last attempt failed, try and use the same window. This is necessary + # to ensure correctness with checkpointing. If we don't do this, things like + # new slack channels could be missed (since existing slack channels are + # cached as part of the checkpoint). 
+ if ( + most_recent_attempt + and most_recent_attempt.poll_range_end + and ( + most_recent_attempt.status == IndexingStatus.FAILED + or most_recent_attempt.status == IndexingStatus.CANCELED + ) + ): + window_end = most_recent_attempt.poll_range_end + else: + window_end = datetime.now(tz=timezone.utc) + + # add start/end now that they have been set + index_attempt_start.poll_range_start = window_start + index_attempt_start.poll_range_end = window_end + db_session_temp.add(index_attempt_start) + db_session_temp.commit() embedding_model = DefaultIndexingEmbedder.from_db_search_settings( search_settings=index_attempt_start.search_settings, @@ -234,7 +343,6 @@ def _run_indexing( ) indexing_pipeline = build_indexing_pipeline( - attempt_id=index_attempt_id, embedder=embedding_model, document_index=document_index, ignore_time_skip=( @@ -246,63 +354,73 @@ def _run_indexing( callback=callback, ) - tracer: OnyxTracer - if INDEXING_TRACER_INTERVAL > 0: - logger.debug(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}") - tracer = OnyxTracer() - tracer.start() - tracer.snap() + # Initialize memory tracer. NOTE: won't actually do anything if + # `INDEXING_TRACER_INTERVAL` is 0. + memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL) + memory_tracer.start() index_attempt_md = IndexAttemptMetadata( connector_id=ctx.connector_id, credential_id=ctx.credential_id, ) + total_failures = 0 batch_num = 0 net_doc_change = 0 document_count = 0 chunk_count = 0 - run_end_dt = None - tracer_counter: int + try: + with get_session_with_tenant(tenant_id) as db_session_temp: + index_attempt = get_index_attempt(db_session_temp, index_attempt_id) + if not index_attempt: + raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.") - for ind, (window_start, window_end) in enumerate( - get_time_windows_for_index_attempt( - last_successful_run=datetime.fromtimestamp( - last_successful_index_time, tz=timezone.utc - ), - source_type=db_connector.source, - ) - ): - cc_pair_loop: ConnectorCredentialPair | None = None - index_attempt_loop: IndexAttempt | None = None - tracer_counter = 0 - - try: - window_start = max( - window_start - timedelta(minutes=POLL_CONNECTOR_OFFSET), - datetime(1970, 1, 1, tzinfo=timezone.utc), + connector_runner = _get_connector_runner( + db_session=db_session_temp, + attempt=index_attempt, + batch_size=INDEX_BATCH_SIZE, + start_time=window_start, + end_time=window_end, + tenant_id=tenant_id, ) - with get_session_with_tenant(tenant_id) as db_session_temp: - index_attempt_loop_start = get_index_attempt( - db_session_temp, index_attempt_id - ) - if not index_attempt_loop_start: - raise RuntimeError( - f"Index attempt {index_attempt_id} not found in DB." - ) - - connector_runner = _get_connector_runner( + # don't use a checkpoint if we're explicitly indexing from + # the beginning in order to avoid weird interactions between + # checkpointing / failure handling. 
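
# For reference while reading the checkpoint handling below: a minimal sketch of
# the checkpoint shape this code relies on. The real model lives in
# onyx/connectors/models.py; only the members actually used here
# (`checkpoint_content`, `has_more`, `build_dummy_checkpoint`, pydantic JSON
# round-tripping) are sketched, and the defaults are assumptions.
from typing import Any

from pydantic import BaseModel


class _CheckpointSketch(BaseModel):
    # arbitrary connector-defined state; its size is bounded by check_checkpoint_size()
    checkpoint_content: dict[str, Any] = {}
    # False once the connector has nothing left to fetch for the poll window
    has_more: bool = True

    @classmethod
    def build_dummy_checkpoint(cls) -> "_CheckpointSketch":
        # the "start from scratch" value, used when no prior valid checkpoint exists
        return cls(checkpoint_content={}, has_more=True)
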
+ if index_attempt.from_beginning: + checkpoint = ConnectorCheckpoint.build_dummy_checkpoint() + else: + checkpoint = get_latest_valid_checkpoint( db_session=db_session_temp, - attempt=index_attempt_loop_start, - start_time=window_start, - end_time=window_end, - tenant_id=tenant_id, + cc_pair_id=ctx.cc_pair_id, + search_settings_id=index_attempt.search_settings_id, + window_start=window_start, + window_end=window_end, ) - if INDEXING_TRACER_INTERVAL > 0: - tracer.snap() - for doc_batch in connector_runner.run(): + unresolved_errors = get_index_attempt_errors_for_cc_pair( + cc_pair_id=ctx.cc_pair_id, + unresolved_only=True, + db_session=db_session_temp, + ) + doc_id_to_unresolved_errors: dict[ + str, list[IndexAttemptError] + ] = defaultdict(list) + for error in unresolved_errors: + if error.document_id: + doc_id_to_unresolved_errors[error.document_id].append(error) + + entity_based_unresolved_errors = [ + error for error in unresolved_errors if error.entity_id + ] + + while checkpoint.has_more: + logger.info( + f"Running '{ctx.source}' connector with checkpoint: {checkpoint}" + ) + for document_batch, failure, next_checkpoint in connector_runner.run( + checkpoint + ): # Check if connector is disabled mid run and stop if so unless it's the secondary # index being built. We want to populate it even for paused connectors # Often paused connectors are sources that aren't updated frequently but the @@ -313,41 +431,37 @@ def _run_indexing( # TODO: should we move this into the above callback instead? with get_session_with_tenant(tenant_id) as db_session_temp: - cc_pair_loop = get_connector_credential_pair_from_id( - db_session_temp, - ctx.cc_pair_id, + # will exception if the connector/index attempt is marked as paused/failed + _check_connector_and_attempt_status( + db_session_temp, ctx, index_attempt_id ) - if not cc_pair_loop: - raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.") - if ( - ( - cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED - and ctx.search_settings_status != IndexModelStatus.FUTURE + # save record of any failures at the connector level + if failure is not None: + total_failures += 1 + with get_session_with_tenant(tenant_id) as db_session_temp: + create_index_attempt_error( + index_attempt_id, + ctx.cc_pair_id, + failure, + db_session_temp, ) - # if it's deleting, we don't care if this is a secondary index - or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING - ): - # let the `except` block handle this - raise RuntimeError("Connector was disabled mid run") - - index_attempt_loop = get_index_attempt( - db_session_temp, index_attempt_id + + _check_failure_threshold( + total_failures, document_count, batch_num, failure ) - if not index_attempt_loop: - raise RuntimeError( - f"Index attempt {index_attempt_id} not found in DB." 
- ) - if index_attempt_loop.status != IndexingStatus.IN_PROGRESS: - # Likely due to user manually disabling it or model swap - raise RuntimeError( - f"Index Attempt was canceled, status is {index_attempt_loop.status}" - ) + # save the new checkpoint (if one is provided) + if next_checkpoint: + checkpoint = next_checkpoint + + # below is all document processing logic, so if no batch we can just continue + if document_batch is None: + continue batch_description = [] - doc_batch_cleaned = strip_null_characters(doc_batch) + doc_batch_cleaned = strip_null_characters(document_batch) for doc in doc_batch_cleaned: batch_description.append(doc.to_short_descriptor()) @@ -377,15 +491,51 @@ def _run_indexing( chunk_count += index_pipeline_result.total_chunks document_count += index_pipeline_result.total_docs - # commit transaction so that the `update` below begins - # with a brand new transaction. Postgres uses the start - # of the transactions when computing `NOW()`, so if we have - # a long running transaction, the `time_updated` field will - # be inaccurate - db_session.commit() + # resolve errors for documents that were successfully indexed + failed_document_ids = [ + failure.failed_document.document_id + for failure in index_pipeline_result.failures + if failure.failed_document + ] + successful_document_ids = [ + document.id + for document in document_batch + if document.id not in failed_document_ids + ] + for document_id in successful_document_ids: + with get_session_with_tenant(tenant_id) as db_session_temp: + if document_id in doc_id_to_unresolved_errors: + logger.info( + f"Resolving IndexAttemptError for document '{document_id}'" + ) + for error in doc_id_to_unresolved_errors[document_id]: + error.is_resolved = True + db_session_temp.add(error) + db_session_temp.commit() + + # add brand new failures + if index_pipeline_result.failures: + total_failures += len(index_pipeline_result.failures) + with get_session_with_tenant(tenant_id) as db_session_temp: + for failure in index_pipeline_result.failures: + create_index_attempt_error( + index_attempt_id, + ctx.cc_pair_id, + failure, + db_session_temp, + ) + + _check_failure_threshold( + total_failures, + document_count, + batch_num, + index_pipeline_result.failures[-1], + ) # This new value is updated every batch, so UI can refresh per batch update with get_session_with_tenant(tenant_id) as db_session_temp: + # NOTE: Postgres uses the start of the transactions when computing `NOW()` + # so we need either to commit() or to use a new session update_docs_indexed( db_session=db_session_temp, index_attempt_id=index_attempt_id, @@ -397,126 +547,77 @@ def _run_indexing( if callback: callback.progress("_run_indexing", len(doc_batch_cleaned)) - tracer_counter += 1 - if ( - INDEXING_TRACER_INTERVAL > 0 - and tracer_counter % INDEXING_TRACER_INTERVAL == 0 - ): - logger.debug( - f"Running trace comparison for batch {tracer_counter}. 
interval={INDEXING_TRACER_INTERVAL}" - ) - tracer.snap() - tracer.log_previous_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES) + memory_tracer.increment_and_maybe_trace() - run_end_dt = window_end - if ctx.is_primary: - with get_session_with_tenant(tenant_id) as db_session_temp: + # `make sure the checkpoints aren't getting too large`at some regular interval + CHECKPOINT_SIZE_CHECK_INTERVAL = 100 + if batch_num % CHECKPOINT_SIZE_CHECK_INTERVAL == 0: + check_checkpoint_size(checkpoint) + + # save latest checkpoint + with get_session_with_tenant(tenant_id) as db_session_temp: + save_checkpoint( + db_session=db_session_temp, + index_attempt_id=index_attempt_id, + checkpoint=checkpoint, + ) + + except Exception as e: + logger.exception( + "Connector run exceptioned after elapsed time: " + f"{time.monotonic() - start_time} seconds" + ) + + if isinstance(e, ConnectorStopSignal): + with get_session_with_tenant(tenant_id) as db_session_temp: + mark_attempt_canceled( + index_attempt_id, + db_session_temp, + reason=str(e), + ) + + if ctx.is_primary: update_connector_credential_pair( db_session=db_session_temp, connector_id=ctx.connector_id, credential_id=ctx.credential_id, net_docs=net_doc_change, - run_dt=run_end_dt, - ) - except Exception as e: - logger.exception( - f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds" - ) - - if isinstance(e, ConnectorStopSignal): - with get_session_with_tenant(tenant_id) as db_session_temp: - mark_attempt_canceled( - index_attempt_id, - db_session_temp, - reason=str(e), ) - if ctx.is_primary: - update_connector_credential_pair( - db_session=db_session_temp, - connector_id=ctx.connector_id, - credential_id=ctx.credential_id, - net_docs=net_doc_change, - ) + memory_tracer.stop() + raise e + else: + with get_session_with_tenant(tenant_id) as db_session_temp: + mark_attempt_failed( + index_attempt_id, + db_session_temp, + failure_reason=str(e), + full_exception_trace=traceback.format_exc(), + ) - if INDEXING_TRACER_INTERVAL > 0: - tracer.stop() - raise e - else: - # Only mark the attempt as a complete failure if this is the first indexing window. - # Otherwise, some progress was made - the next run will not start from the beginning. - # In this case, it is not accurate to mark it as a failure. When the next run begins, - # if that fails immediately, it will be marked as a failure. - # - # NOTE: if the connector is manually disabled, we should mark it as a failure regardless - # to give better clarity in the UI, as the next run will never happen. - if ( - ind == 0 - or ( - cc_pair_loop is not None and not cc_pair_loop.status.is_active() - ) - or ( - index_attempt_loop is not None - and index_attempt_loop.status != IndexingStatus.IN_PROGRESS + if ctx.is_primary: + update_connector_credential_pair( + db_session=db_session_temp, + connector_id=ctx.connector_id, + credential_id=ctx.credential_id, + net_docs=net_doc_change, ) - ): - with get_session_with_tenant(tenant_id) as db_session_temp: - mark_attempt_failed( - index_attempt_id, - db_session_temp, - failure_reason=str(e), - full_exception_trace=traceback.format_exc(), - ) - if ctx.is_primary: - update_connector_credential_pair( - db_session=db_session_temp, - connector_id=ctx.connector_id, - credential_id=ctx.credential_id, - net_docs=net_doc_change, - ) - - if INDEXING_TRACER_INTERVAL > 0: - tracer.stop() - raise e - - # break => similar to success case. 
As mentioned above, if the next run fails for the same - # reason it will then be marked as a failure - break - - if INDEXING_TRACER_INTERVAL > 0: - logger.debug( - f"Running trace comparison between start and end of indexing. {tracer_counter} batches processed." - ) - tracer.snap() - tracer.log_first_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES) - tracer.stop() - logger.debug("Memory tracer stopped.") - - if ( - index_attempt_md.num_exceptions > 0 - and index_attempt_md.num_exceptions >= batch_num - ): - with get_session_with_tenant(tenant_id) as db_session_temp: - mark_attempt_failed( - index_attempt_id, - db_session_temp, - failure_reason="All batches exceptioned.", - ) - if ctx.is_primary: - update_connector_credential_pair( - db_session=db_session_temp, - connector_id=ctx.connector_id, - credential_id=ctx.credential_id, - ) - raise Exception( - f"Connector failed - All batches exceptioned: batches={batch_num}" - ) + memory_tracer.stop() + raise e - elapsed_time = time.time() - start_time + memory_tracer.stop() + elapsed_time = time.monotonic() - start_time with get_session_with_tenant(tenant_id) as db_session_temp: - if index_attempt_md.num_exceptions == 0: + # resolve entity-based errors + for error in entity_based_unresolved_errors: + logger.info(f"Resolving IndexAttemptError for entity '{error.entity_id}'") + error.is_resolved = True + db_session_temp.add(error) + db_session_temp.commit() + + if total_failures == 0: mark_attempt_succeeded(index_attempt_id, db_session_temp) create_milestone_and_report( @@ -535,7 +636,7 @@ def _run_indexing( mark_attempt_partially_succeeded(index_attempt_id, db_session_temp) logger.info( f"Connector completed with some errors: " - f"exceptions={index_attempt_md.num_exceptions} " + f"failures={total_failures} " f"batches={batch_num} " f"docs={document_count} " f"chunks={chunk_count} " @@ -547,7 +648,7 @@ def _run_indexing( db_session=db_session_temp, connector_id=ctx.connector_id, credential_id=ctx.credential_id, - run_dt=run_end_dt, + run_dt=window_end, ) diff --git a/backend/onyx/background/indexing/tracer.py b/backend/onyx/background/indexing/tracer.py deleted file mode 100644 index 0068d4e0213..00000000000 --- a/backend/onyx/background/indexing/tracer.py +++ /dev/null @@ -1,77 +0,0 @@ -import tracemalloc - -from onyx.utils.logger import setup_logger - -logger = setup_logger() - -DANSWER_TRACEMALLOC_FRAMES = 10 - - -class OnyxTracer: - def __init__(self) -> None: - self.snapshot_first: tracemalloc.Snapshot | None = None - self.snapshot_prev: tracemalloc.Snapshot | None = None - self.snapshot: tracemalloc.Snapshot | None = None - - def start(self) -> None: - tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES) - - def stop(self) -> None: - tracemalloc.stop() - - def snap(self) -> None: - snapshot = tracemalloc.take_snapshot() - # Filter out irrelevant frames (e.g., from tracemalloc itself or importlib) - snapshot = snapshot.filter_traces( - ( - tracemalloc.Filter(False, tracemalloc.__file__), # Exclude tracemalloc - tracemalloc.Filter( - False, "" - ), # Exclude importlib - tracemalloc.Filter( - False, "" - ), # Exclude external importlib - ) - ) - - if not self.snapshot_first: - self.snapshot_first = snapshot - - if self.snapshot: - self.snapshot_prev = self.snapshot - - self.snapshot = snapshot - - def log_snapshot(self, numEntries: int) -> None: - if not self.snapshot: - return - - stats = self.snapshot.statistics("traceback") - for s in stats[:numEntries]: - logger.debug(f"Tracer snap: {s}") - for line in s.traceback: - logger.debug(f"* {line}") - - 
@staticmethod - def log_diff( - snap_current: tracemalloc.Snapshot, - snap_previous: tracemalloc.Snapshot, - numEntries: int, - ) -> None: - stats = snap_current.compare_to(snap_previous, "traceback") - for s in stats[:numEntries]: - logger.debug(f"Tracer diff: {s}") - for line in s.traceback.format(): - logger.debug(f"* {line}") - - def log_previous_diff(self, numEntries: int) -> None: - if not self.snapshot or not self.snapshot_prev: - return - - OnyxTracer.log_diff(self.snapshot, self.snapshot_prev, numEntries) - - def log_first_diff(self, numEntries: int) -> None: - if not self.snapshot or not self.snapshot_first: - return - - OnyxTracer.log_diff(self.snapshot, self.snapshot_first, numEntries) diff --git a/backend/onyx/configs/app_configs.py b/backend/onyx/configs/app_configs.py index 315046f7d36..5e87734d222 100644 --- a/backend/onyx/configs/app_configs.py +++ b/backend/onyx/configs/app_configs.py @@ -169,6 +169,11 @@ POSTGRES_API_SERVER_POOL_OVERFLOW = int( os.environ.get("POSTGRES_API_SERVER_POOL_OVERFLOW") or 10 ) + +# defaults to False +# generally should only be used for +POSTGRES_USE_NULL_POOL = os.environ.get("POSTGRES_USE_NULL_POOL", "").lower() == "true" + # defaults to False POSTGRES_POOL_PRE_PING = os.environ.get("POSTGRES_POOL_PRE_PING", "").lower() == "true" diff --git a/backend/onyx/configs/constants.py b/backend/onyx/configs/constants.py index 139f00fdf71..82108acf6e4 100644 --- a/backend/onyx/configs/constants.py +++ b/backend/onyx/configs/constants.py @@ -165,6 +165,9 @@ class DocumentSource(str, Enum): EGNYTE = "egnyte" AIRTABLE = "airtable" + # Special case just for integration tests + MOCK_CONNECTOR = "mock_connector" + DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE] @@ -243,6 +246,7 @@ class FileOrigin(str, Enum): CHAT_IMAGE_GEN = "chat_image_gen" CONNECTOR = "connector" GENERATED_REPORT = "generated_report" + INDEXING_CHECKPOINT = "indexing_checkpoint" OTHER = "other" @@ -274,6 +278,7 @@ class OnyxCeleryQueues: DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert" CONNECTOR_DELETION = "connector_deletion" LLM_MODEL_UPDATE = "llm_model_update" + CHECKPOINT_CLEANUP = "checkpoint_cleanup" # Heavy queue CONNECTOR_PRUNING = "connector_pruning" @@ -293,6 +298,7 @@ class OnyxRedisLocks: CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat" CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat" CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat" + CHECK_CHECKPOINT_CLEANUP_BEAT_LOCK = "da_lock:check_checkpoint_cleanup_beat" CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK = ( "da_lock:check_connector_doc_permissions_sync_beat" ) @@ -368,6 +374,10 @@ class OnyxCeleryTask: CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync" CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update" + # Connector checkpoint cleanup + CHECK_FOR_CHECKPOINT_CLEANUP = "check_for_checkpoint_cleanup" + CLEANUP_CHECKPOINT = "cleanup_checkpoint" + MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes" MONITOR_CELERY_QUEUES = "monitor_celery_queues" diff --git a/backend/onyx/connectors/connector_runner.py b/backend/onyx/connectors/connector_runner.py index 64b73d885eb..8cb48a3d682 100644 --- a/backend/onyx/connectors/connector_runner.py +++ b/backend/onyx/connectors/connector_runner.py @@ -1,11 +1,16 @@ import sys import time +from collections.abc import Generator from datetime import datetime from onyx.connectors.interfaces import BaseConnector -from onyx.connectors.interfaces import GenerateDocumentsOutput +from 
onyx.connectors.interfaces import CheckpointConnector +from onyx.connectors.interfaces import CheckpointOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector +from onyx.connectors.models import ConnectorCheckpoint +from onyx.connectors.models import ConnectorFailure +from onyx.connectors.models import Document from onyx.utils.logger import setup_logger @@ -15,48 +20,139 @@ TimeRange = tuple[datetime, datetime] +class CheckpointOutputWrapper: + """ + Wraps a CheckpointOutput generator to give things back in a more digestible format. + The connector format is easier for the connector implementor (e.g. it enforces exactly + one new checkpoint is returned AND that the checkpoint is at the end), thus the different + formats. + """ + + def __init__(self) -> None: + self.next_checkpoint: ConnectorCheckpoint | None = None + + def __call__( + self, + checkpoint_connector_generator: CheckpointOutput, + ) -> Generator[ + tuple[Document | None, ConnectorFailure | None, ConnectorCheckpoint | None], + None, + None, + ]: + # grabs the final return value and stores it in the `next_checkpoint` variable + def _inner_wrapper( + checkpoint_connector_generator: CheckpointOutput, + ) -> CheckpointOutput: + self.next_checkpoint = yield from checkpoint_connector_generator + return self.next_checkpoint # not used + + for document_or_failure in _inner_wrapper(checkpoint_connector_generator): + if isinstance(document_or_failure, Document): + yield document_or_failure, None, None + elif isinstance(document_or_failure, ConnectorFailure): + yield None, document_or_failure, None + else: + raise ValueError( + f"Invalid document_or_failure type: {type(document_or_failure)}" + ) + + if self.next_checkpoint is None: + raise RuntimeError( + "Checkpoint is None. This should never happen - the connector should always return a checkpoint." 
+ ) + + yield None, None, self.next_checkpoint + + class ConnectorRunner: + """ + Handles: + - Batching + - Additional exception logging + - Combining different connector types to a single interface + """ + def __init__( self, connector: BaseConnector, + batch_size: int, time_range: TimeRange | None = None, - fail_loudly: bool = False, ): self.connector = connector + self.time_range = time_range + self.batch_size = batch_size + + self.doc_batch: list[Document] = [] + + def run( + self, checkpoint: ConnectorCheckpoint + ) -> Generator[ + tuple[ + list[Document] | None, ConnectorFailure | None, ConnectorCheckpoint | None + ], + None, + None, + ]: + """Adds additional exception logging to the connector.""" + try: + if isinstance(self.connector, CheckpointConnector): + if self.time_range is None: + raise ValueError("time_range is required for CheckpointConnector") - if isinstance(self.connector, PollConnector): - if time_range is None: - raise ValueError("time_range is required for PollConnector") + start = time.monotonic() + checkpoint_connector_generator = self.connector.load_from_checkpoint( + start=self.time_range[0].timestamp(), + end=self.time_range[1].timestamp(), + checkpoint=checkpoint, + ) + next_checkpoint: ConnectorCheckpoint | None = None + # this is guaranteed to always run at least once with next_checkpoint being non-None + for document, failure, next_checkpoint in CheckpointOutputWrapper()( + checkpoint_connector_generator + ): + if document is not None: + self.doc_batch.append(document) - self.doc_batch_generator = self.connector.poll_source( - time_range[0].timestamp(), time_range[1].timestamp() - ) + if failure is not None: + yield None, failure, None - elif isinstance(self.connector, LoadConnector): - if time_range and fail_loudly: - raise ValueError( - "time_range specified, but passed in connector is not a PollConnector" - ) + if len(self.doc_batch) >= self.batch_size: + yield self.doc_batch, None, None + self.doc_batch = [] - self.doc_batch_generator = self.connector.load_from_state() + # yield remaining documents + if len(self.doc_batch) > 0: + yield self.doc_batch, None, None + self.doc_batch = [] - else: - raise ValueError(f"Invalid connector. type: {type(self.connector)}") + yield None, None, next_checkpoint - def run(self) -> GenerateDocumentsOutput: - """Adds additional exception logging to the connector.""" - try: - start = time.monotonic() - for batch in self.doc_batch_generator: - # to know how long connector is taking logger.debug( - f"Connector took {time.monotonic() - start} seconds to build a batch." + f"Connector took {time.monotonic() - start} seconds to get to the next checkpoint." ) - yield batch + else: + finished_checkpoint = ConnectorCheckpoint.build_dummy_checkpoint() + finished_checkpoint.has_more = False - start = time.monotonic() + if isinstance(self.connector, PollConnector): + if self.time_range is None: + raise ValueError("time_range is required for PollConnector") + + for document_batch in self.connector.poll_source( + start=self.time_range[0].timestamp(), + end=self.time_range[1].timestamp(), + ): + yield document_batch, None, None + + yield None, None, finished_checkpoint + elif isinstance(self.connector, LoadConnector): + for document_batch in self.connector.load_from_state(): + yield document_batch, None, None + yield None, None, finished_checkpoint + else: + raise ValueError(f"Invalid connector. 
type: {type(self.connector)}") except Exception: exc_type, _, exc_traceback = sys.exc_info() diff --git a/backend/onyx/connectors/factory.py b/backend/onyx/connectors/factory.py index d204f3c0bc1..28a67e6a5e6 100644 --- a/backend/onyx/connectors/factory.py +++ b/backend/onyx/connectors/factory.py @@ -30,12 +30,14 @@ from onyx.connectors.guru.connector import GuruConnector from onyx.connectors.hubspot.connector import HubSpotConnector from onyx.connectors.interfaces import BaseConnector +from onyx.connectors.interfaces import CheckpointConnector from onyx.connectors.interfaces import EventConnector from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.linear.connector import LinearConnector from onyx.connectors.loopio.connector import LoopioConnector from onyx.connectors.mediawiki.wiki import MediaWikiConnector +from onyx.connectors.mock_connector.connector import MockConnector from onyx.connectors.models import InputType from onyx.connectors.notion.connector import NotionConnector from onyx.connectors.onyx_jira.connector import JiraConnector @@ -43,7 +45,7 @@ from onyx.connectors.salesforce.connector import SalesforceConnector from onyx.connectors.sharepoint.connector import SharepointConnector from onyx.connectors.slab.connector import SlabConnector -from onyx.connectors.slack.connector import SlackPollConnector +from onyx.connectors.slack.connector import SlackConnector from onyx.connectors.teams.connector import TeamsConnector from onyx.connectors.web.connector import WebConnector from onyx.connectors.wikipedia.connector import WikipediaConnector @@ -66,8 +68,8 @@ def identify_connector_class( DocumentSource.WEB: WebConnector, DocumentSource.FILE: LocalFileConnector, DocumentSource.SLACK: { - InputType.POLL: SlackPollConnector, - InputType.SLIM_RETRIEVAL: SlackPollConnector, + InputType.POLL: SlackConnector, + InputType.SLIM_RETRIEVAL: SlackConnector, }, DocumentSource.GITHUB: GithubConnector, DocumentSource.GMAIL: GmailConnector, @@ -109,6 +111,8 @@ def identify_connector_class( DocumentSource.FIREFLIES: FirefliesConnector, DocumentSource.EGNYTE: EgnyteConnector, DocumentSource.AIRTABLE: AirtableConnector, + # just for integration tests + DocumentSource.MOCK_CONNECTOR: MockConnector, } connector_by_source = connector_map.get(source, {}) @@ -125,10 +129,23 @@ def identify_connector_class( if any( [ - input_type == InputType.LOAD_STATE - and not issubclass(connector, LoadConnector), - input_type == InputType.POLL and not issubclass(connector, PollConnector), - input_type == InputType.EVENT and not issubclass(connector, EventConnector), + ( + input_type == InputType.LOAD_STATE + and not issubclass(connector, LoadConnector) + ), + ( + input_type == InputType.POLL + # either poll or checkpoint works for this, in the future + # all connectors should be checkpoint connectors + and ( + not issubclass(connector, PollConnector) + and not issubclass(connector, CheckpointConnector) + ) + ), + ( + input_type == InputType.EVENT + and not issubclass(connector, EventConnector) + ), ] ): raise ConnectorMissingException( diff --git a/backend/onyx/connectors/interfaces.py b/backend/onyx/connectors/interfaces.py index a0d48f7f171..e92b077ce32 100644 --- a/backend/onyx/connectors/interfaces.py +++ b/backend/onyx/connectors/interfaces.py @@ -1,10 +1,13 @@ import abc +from collections.abc import Generator from collections.abc import Iterator from typing import Any from pydantic import BaseModel from onyx.configs.constants import 
DocumentSource +from onyx.connectors.models import ConnectorCheckpoint +from onyx.connectors.models import ConnectorFailure from onyx.connectors.models import Document from onyx.connectors.models import SlimDocument from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface @@ -14,6 +17,7 @@ GenerateDocumentsOutput = Iterator[list[Document]] GenerateSlimDocumentOutput = Iterator[list[SlimDocument]] +CheckpointOutput = Generator[Document | ConnectorFailure, None, ConnectorCheckpoint] class BaseConnector(abc.ABC): @@ -105,3 +109,33 @@ class EventConnector(BaseConnector): @abc.abstractmethod def handle_event(self, event: Any) -> GenerateDocumentsOutput: raise NotImplementedError + + +class CheckpointConnector(BaseConnector): + @abc.abstractmethod + def load_from_checkpoint( + self, + start: SecondsSinceUnixEpoch, + end: SecondsSinceUnixEpoch, + checkpoint: ConnectorCheckpoint, + ) -> CheckpointOutput: + """Yields back documents or failures. Final return is the new checkpoint. + + Final return can be access via either: + + ``` + try: + for document_or_failure in connector.load_from_checkpoint(start, end, checkpoint): + print(document_or_failure) + except StopIteration as e: + checkpoint = e.value # Extracting the return value + print(checkpoint) + ``` + + OR + + ``` + checkpoint = yield from connector.load_from_checkpoint(start, end, checkpoint) + ``` + """ + raise NotImplementedError diff --git a/backend/onyx/connectors/mock_connector/connector.py b/backend/onyx/connectors/mock_connector/connector.py new file mode 100644 index 00000000000..2cd670323ef --- /dev/null +++ b/backend/onyx/connectors/mock_connector/connector.py @@ -0,0 +1,86 @@ +from typing import Any + +import httpx +from pydantic import BaseModel + +from onyx.connectors.interfaces import CheckpointConnector +from onyx.connectors.interfaces import CheckpointOutput +from onyx.connectors.interfaces import SecondsSinceUnixEpoch +from onyx.connectors.models import ConnectorCheckpoint +from onyx.connectors.models import ConnectorFailure +from onyx.connectors.models import Document +from onyx.utils.logger import setup_logger + + +logger = setup_logger() + + +class SingleConnectorYield(BaseModel): + documents: list[Document] + checkpoint: ConnectorCheckpoint + failures: list[ConnectorFailure] + unhandled_exception: str | None = None + + +class MockConnector(CheckpointConnector): + def __init__( + self, + mock_server_host: str, + mock_server_port: int, + ) -> None: + self.mock_server_host = mock_server_host + self.mock_server_port = mock_server_port + self.client = httpx.Client(timeout=30.0) + + self.connector_yields: list[SingleConnectorYield] | None = None + self.current_yield_index: int = 0 + + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: + response = self.client.get(self._get_mock_server_url("get-documents")) + response.raise_for_status() + data = response.json() + + self.connector_yields = [ + SingleConnectorYield(**yield_data) for yield_data in data + ] + return None + + def _get_mock_server_url(self, endpoint: str) -> str: + return f"http://{self.mock_server_host}:{self.mock_server_port}/{endpoint}" + + def _save_checkpoint(self, checkpoint: ConnectorCheckpoint) -> None: + response = self.client.post( + self._get_mock_server_url("add-checkpoint"), + json=checkpoint.model_dump(mode="json"), + ) + response.raise_for_status() + + def load_from_checkpoint( + self, + start: SecondsSinceUnixEpoch, + end: SecondsSinceUnixEpoch, + checkpoint: ConnectorCheckpoint, + ) -> CheckpointOutput: + 
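        # Each call to load_from_checkpoint consumes exactly one pre-configured
        # SingleConnectorYield (fetched from the mock server in load_credentials):
        # the incoming checkpoint is first posted back to the mock server, then the
        # yield's documents and failures are emitted, and its checkpoint is returned
        # as the generator's final value.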
if self.connector_yields is None: + raise ValueError("No connector yields configured") + + # Save the checkpoint to the mock server + self._save_checkpoint(checkpoint) + + yield_index = self.current_yield_index + self.current_yield_index += 1 + current_yield = self.connector_yields[yield_index] + + # If the current yield has an unhandled exception, raise it + # This is used to simulate an unhandled failure in the connector. + if current_yield.unhandled_exception: + raise RuntimeError(current_yield.unhandled_exception) + + # yield all documents + for document in current_yield.documents: + yield document + + for failure in current_yield.failures: + yield failure + + return current_yield.checkpoint diff --git a/backend/onyx/connectors/models.py b/backend/onyx/connectors/models.py index 41123318ada..1fd53f0ace1 100644 --- a/backend/onyx/connectors/models.py +++ b/backend/onyx/connectors/models.py @@ -3,6 +3,7 @@ from typing import Any from pydantic import BaseModel +from pydantic import model_validator from onyx.configs.constants import DocumentSource from onyx.configs.constants import INDEX_SEPARATOR @@ -187,36 +188,48 @@ class SlimDocument(BaseModel): perm_sync_data: Any | None = None -class DocumentErrorSummary(BaseModel): - id: str - semantic_id: str - section_link: str | None +class IndexAttemptMetadata(BaseModel): + batch_num: int | None = None + connector_id: int + credential_id: int - @classmethod - def from_document(cls, doc: Document) -> "DocumentErrorSummary": - section_link = doc.sections[0].link if len(doc.sections) > 0 else None - return cls( - id=doc.id, semantic_id=doc.semantic_identifier, section_link=section_link - ) + +class ConnectorCheckpoint(BaseModel): + # TODO: maybe move this to something disk-based to handle extremely large checkpoints? + checkpoint_content: dict + has_more: bool @classmethod - def from_dict(cls, data: dict) -> "DocumentErrorSummary": - return cls( - id=str(data.get("id")), - semantic_id=str(data.get("semantic_id")), - section_link=str(data.get("section_link")), - ) + def build_dummy_checkpoint(cls) -> "ConnectorCheckpoint": + return ConnectorCheckpoint(checkpoint_content={}, has_more=True) - def to_dict(self) -> dict[str, str | None]: - return { - "id": self.id, - "semantic_id": self.semantic_id, - "section_link": self.section_link, - } +class DocumentFailure(BaseModel): + document_id: str + document_link: str | None = None -class IndexAttemptMetadata(BaseModel): - batch_num: int | None = None - num_exceptions: int = 0 - connector_id: int - credential_id: int + +class EntityFailure(BaseModel): + entity_id: str + missed_time_range: tuple[datetime, datetime] | None = None + + +class ConnectorFailure(BaseModel): + failed_document: DocumentFailure | None = None + failed_entity: EntityFailure | None = None + failure_message: str + exception: Exception | None = None + + model_config = {"arbitrary_types_allowed": True} + + @model_validator(mode="before") + def check_failed_fields(cls, values: dict) -> dict: + failed_document = values.get("failed_document") + failed_entity = values.get("failed_entity") + if (failed_document is None and failed_entity is None) or ( + failed_document is not None and failed_entity is not None + ): + raise ValueError( + "Exactly one of 'failed_document' or 'failed_entity' must be specified." 
+ ) + return values diff --git a/backend/onyx/connectors/slack/connector.py b/backend/onyx/connectors/slack/connector.py index 865cd8f6272..940a2a728d6 100644 --- a/backend/onyx/connectors/slack/connector.py +++ b/backend/onyx/connectors/slack/connector.py @@ -1,10 +1,16 @@ +import contextvars +import copy import re from collections.abc import Callable from collections.abc import Generator +from concurrent.futures import as_completed +from concurrent.futures import Future +from concurrent.futures import ThreadPoolExecutor from datetime import datetime from datetime import timezone from typing import Any from typing import cast +from typing import TypedDict from slack_sdk import WebClient from slack_sdk.errors import SlackApiError @@ -12,14 +18,18 @@ from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource -from onyx.connectors.interfaces import GenerateDocumentsOutput +from onyx.connectors.interfaces import CheckpointConnector +from onyx.connectors.interfaces import CheckpointOutput from onyx.connectors.interfaces import GenerateSlimDocumentOutput -from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.interfaces import SlimConnector from onyx.connectors.models import BasicExpertInfo +from onyx.connectors.models import ConnectorCheckpoint +from onyx.connectors.models import ConnectorFailure from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document +from onyx.connectors.models import DocumentFailure +from onyx.connectors.models import EntityFailure from onyx.connectors.models import Section from onyx.connectors.models import SlimDocument from onyx.connectors.slack.utils import expert_info_from_slack_id @@ -33,6 +43,8 @@ logger = setup_logger() +_SLACK_LIMIT = 900 + ChannelType = dict[str, Any] MessageType = dict[str, Any] @@ -40,6 +52,13 @@ ThreadType = list[MessageType] +class SlackCheckpointContent(TypedDict): + channel_ids: list[str] + channel_completion_map: dict[str, str] + current_channel: ChannelType | None + seen_thread_ts: list[str] + + def _collect_paginated_channels( client: WebClient, exclude_archived: bool, @@ -140,6 +159,10 @@ def get_latest_message_time(thread: ThreadType) -> datetime: return datetime.fromtimestamp(max_ts, tz=timezone.utc) +def _build_doc_id(channel_id: str, thread_ts: str) -> str: + return f"{channel_id}__{thread_ts}" + + def thread_to_doc( channel: ChannelType, thread: ThreadType, @@ -182,7 +205,7 @@ def thread_to_doc( ) return Document( - id=f"{channel_id}__{thread[0]['ts']}", + id=_build_doc_id(channel_id=channel_id, thread_ts=thread[0]["ts"]), sections=[ Section( link=get_message_link(event=m, client=client, channel_id=channel_id), @@ -267,64 +290,97 @@ def filter_channels( ] -def _get_all_docs( +def _get_channel_by_id(client: WebClient, channel_id: str) -> ChannelType: + """Get a channel by its ID. 
+ + Args: + client: The Slack WebClient instance + channel_id: The ID of the channel to fetch + + Returns: + The channel information + + Raises: + SlackApiError: If the channel cannot be fetched + """ + response = make_slack_api_call_w_retries( + client.conversations_info, + channel=channel_id, + ) + return cast(ChannelType, response["channel"]) + + +def _get_messages( + channel: ChannelType, client: WebClient, - channels: list[str] | None = None, - channel_name_regex_enabled: bool = False, oldest: str | None = None, latest: str | None = None, - msg_filter_func: Callable[[MessageType], bool] = default_msg_filter, -) -> Generator[Document, None, None]: - """Get all documents in the workspace, channel by channel""" - slack_cleaner = SlackTextCleaner(client=client) +) -> tuple[list[MessageType], bool]: + """Slack goes from newest to oldest.""" - # Cache to prevent refetching via API since users - user_cache: dict[str, BasicExpertInfo | None] = {} + # have to be in the channel in order to read messages + if not channel["is_member"]: + make_slack_api_call_w_retries( + client.conversations_join, + channel=channel["id"], + is_private=channel["is_private"], + ) + logger.info(f"Successfully joined '{channel['name']}'") - all_channels = get_channels(client) - filtered_channels = filter_channels( - all_channels, channels, channel_name_regex_enabled + response = make_slack_api_call_w_retries( + client.conversations_history, + channel=channel["id"], + oldest=oldest, + latest=latest, + limit=_SLACK_LIMIT, ) + response.validate() - for channel in filtered_channels: - channel_docs = 0 - channel_message_batches = get_channel_messages( - client=client, channel=channel, oldest=oldest, latest=latest - ) + messages = cast(list[MessageType], response.get("messages", [])) + + cursor = cast(dict[str, Any], response.get("response_metadata", {})).get( + "next_cursor", "" + ) + has_more = bool(cursor) + return messages, has_more - seen_thread_ts: set[str] = set() - for message_batch in channel_message_batches: - for message in message_batch: - filtered_thread: ThreadType | None = None - thread_ts = message.get("thread_ts") - if thread_ts: - # skip threads we've already seen, since we've already processed all - # messages in that thread - if thread_ts in seen_thread_ts: - continue - seen_thread_ts.add(thread_ts) - thread = get_thread( - client=client, channel_id=channel["id"], thread_id=thread_ts - ) - filtered_thread = [ - message for message in thread if not msg_filter_func(message) - ] - elif not msg_filter_func(message): - filtered_thread = [message] - - if filtered_thread: - channel_docs += 1 - yield thread_to_doc( - channel=channel, - thread=filtered_thread, - slack_cleaner=slack_cleaner, - client=client, - user_cache=user_cache, - ) - logger.info( - f"Pulled {channel_docs} documents from slack channel {channel['name']}" +def _message_to_doc( + message: MessageType, + client: WebClient, + channel: ChannelType, + slack_cleaner: SlackTextCleaner, + user_cache: dict[str, BasicExpertInfo | None], + seen_thread_ts: set[str], + msg_filter_func: Callable[[MessageType], bool] = default_msg_filter, +) -> Document | None: + filtered_thread: ThreadType | None = None + thread_ts = message.get("thread_ts") + if thread_ts: + # skip threads we've already seen, since we've already processed all + # messages in that thread + if thread_ts in seen_thread_ts: + return None + + thread = get_thread( + client=client, channel_id=channel["id"], thread_id=thread_ts ) + filtered_thread = [ + message for message in thread if not 
msg_filter_func(message) + ] + elif not msg_filter_func(message): + filtered_thread = [message] + + if filtered_thread: + return thread_to_doc( + channel=channel, + thread=filtered_thread, + slack_cleaner=slack_cleaner, + client=client, + user_cache=user_cache, + ) + + return None def _get_all_doc_ids( @@ -368,7 +424,7 @@ def _get_all_doc_ids( for message_ts in message_ts_set: channel_metadata_list.append( SlimDocument( - id=f"{channel_id}__{message_ts}", + id=_build_doc_id(channel_id=channel_id, thread_ts=message_ts), perm_sync_data={"channel_id": channel_id}, ) ) @@ -376,7 +432,51 @@ def _get_all_doc_ids( yield channel_metadata_list -class SlackPollConnector(PollConnector, SlimConnector): +def _process_message( + message: MessageType, + client: WebClient, + channel: ChannelType, + slack_cleaner: SlackTextCleaner, + user_cache: dict[str, BasicExpertInfo | None], + seen_thread_ts: set[str], + msg_filter_func: Callable[[MessageType], bool] = default_msg_filter, +) -> tuple[Document | None, str | None, ConnectorFailure | None]: + thread_ts = message.get("thread_ts") + try: + # causes random failures for testing checkpointing / continue on failure + # import random + # if random.random() > 0.95: + # raise RuntimeError("Random failure :P") + + doc = _message_to_doc( + message=message, + client=client, + channel=channel, + slack_cleaner=slack_cleaner, + user_cache=user_cache, + seen_thread_ts=seen_thread_ts, + msg_filter_func=msg_filter_func, + ) + return (doc, thread_ts, None) + except Exception as e: + logger.exception(f"Error processing message {message['ts']}") + return ( + None, + thread_ts, + ConnectorFailure( + failed_document=DocumentFailure( + document_id=_build_doc_id( + channel_id=channel["id"], thread_ts=(thread_ts or message["ts"]) + ), + document_link=get_message_link(message, client, channel["id"]), + ), + failure_message=str(e), + exception=e, + ), + ) + + +class SlackConnector(SlimConnector, CheckpointConnector): def __init__( self, channels: list[str] | None = None, @@ -390,9 +490,14 @@ def __init__( self.batch_size = batch_size self.client: WebClient | None = None + # just used for efficiency + self.text_cleaner: SlackTextCleaner | None = None + self.user_cache: dict[str, BasicExpertInfo | None] = {} + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: bot_token = credentials["slack_bot_token"] self.client = WebClient(token=bot_token) + self.text_cleaner = SlackTextCleaner(client=self.client) return None def retrieve_all_slim_documents( @@ -411,30 +516,155 @@ def retrieve_all_slim_documents( callback=callback, ) - def poll_source( - self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch - ) -> GenerateDocumentsOutput: - if self.client is None: + def load_from_checkpoint( + self, + start: SecondsSinceUnixEpoch, + end: SecondsSinceUnixEpoch, + checkpoint: ConnectorCheckpoint, + ) -> CheckpointOutput: + """Rough outline: + + Step 1: Get all channels, yield back Checkpoint. + Step 2: Loop through each channel. For each channel: + Step 2.1: Get messages within the time range. + Step 2.2: Process messages in parallel, yield back docs. + Step 2.3: Update checkpoint with new_latest, seen_thread_ts, and current_channel. + Slack returns messages from newest to oldest, so we need to keep track of + the latest message we've seen in each channel. + Step 2.4: If there are no more messages in the channel, switch the current + channel to the next channel. 
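        For illustration only (channel IDs and timestamps below are hypothetical),
        checkpoint_content midway through a run follows the SlackCheckpointContent
        shape, e.g.:

            {
                "channel_ids": ["C01AAAAAAAA", "C01BBBBBBBB"],
                "channel_completion_map": {"C01AAAAAAAA": "1718000000.000100"},
                "current_channel": {"id": "C01BBBBBBBB", "name": "general", ...},
                "seen_thread_ts": ["1717999999.000200"],
            }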
+ """ + if self.client is None or self.text_cleaner is None: raise ConnectorMissingCredentialError("Slack") - documents: list[Document] = [] - for document in _get_all_docs( - client=self.client, - channels=self.channels, - channel_name_regex_enabled=self.channel_regex_enabled, - # NOTE: need to impute to `None` instead of using 0.0, since Slack will - # throw an error if we use 0.0 on an account without infinite data - # retention - oldest=str(start) if start else None, - latest=str(end), - ): - documents.append(document) - if len(documents) >= self.batch_size: - yield documents - documents = [] + checkpoint_content = cast( + SlackCheckpointContent, + ( + copy.deepcopy(checkpoint.checkpoint_content) + or { + "channel_ids": None, + "channel_completion_map": {}, + "current_channel": None, + "seen_thread_ts": [], + } + ), + ) + + # if this is the very first time we've called this, need to + # get all relevant channels and save them into the checkpoint + if checkpoint_content["channel_ids"] is None: + raw_channels = get_channels(self.client) + filtered_channels = filter_channels( + raw_channels, self.channels, self.channel_regex_enabled + ) + if len(filtered_channels) == 0: + return checkpoint + + checkpoint_content["channel_ids"] = [c["id"] for c in filtered_channels] + checkpoint_content["current_channel"] = filtered_channels[0] + checkpoint = ConnectorCheckpoint( + checkpoint_content=checkpoint_content, # type: ignore + has_more=True, + ) + return checkpoint - if documents: - yield documents + final_channel_ids = checkpoint_content["channel_ids"] + channel = checkpoint_content["current_channel"] + if channel is None: + raise ValueError("current_channel key not found in checkpoint") + + channel_id = channel["id"] + if channel_id not in final_channel_ids: + raise ValueError(f"Channel {channel_id} not found in checkpoint") + + oldest = str(start) if start else None + latest = checkpoint_content["channel_completion_map"].get(channel_id, str(end)) + seen_thread_ts = set(checkpoint_content["seen_thread_ts"]) + try: + logger.debug( + f"Getting messages for channel {channel} within range {oldest} - {latest}" + ) + message_batch, has_more_in_channel = _get_messages( + channel, self.client, oldest, latest + ) + new_latest = message_batch[-1]["ts"] if message_batch else latest + + # Process messages in parallel using ThreadPoolExecutor + with ThreadPoolExecutor(max_workers=8) as executor: + futures: list[Future] = [] + for message in message_batch: + # Capture the current context so that the thread gets the current tenant ID + current_context = contextvars.copy_context() + futures.append( + executor.submit( + current_context.run, + _process_message, + message=message, + client=self.client, + channel=channel, + slack_cleaner=self.text_cleaner, + user_cache=self.user_cache, + seen_thread_ts=seen_thread_ts, + ) + ) + + for future in as_completed(futures): + doc, thread_ts, failures = future.result() + if doc: + # handle race conditions here since this is single + # threaded. Multi-threaded _process_message reads from this + # but since this is single threaded, we won't run into simul + # writes. At worst, we can duplicate a thread, which will be + # deduped later on. 
+ if thread_ts not in seen_thread_ts: + yield doc + + if thread_ts: + seen_thread_ts.add(thread_ts) + elif failures: + for failure in failures: + yield failure + + checkpoint_content["seen_thread_ts"] = list(seen_thread_ts) + checkpoint_content["channel_completion_map"][channel["id"]] = new_latest + if has_more_in_channel: + checkpoint_content["current_channel"] = channel + else: + new_channel_id = next( + ( + channel_id + for channel_id in final_channel_ids + if channel_id + not in checkpoint_content["channel_completion_map"] + ), + None, + ) + if new_channel_id: + new_channel = _get_channel_by_id(self.client, new_channel_id) + checkpoint_content["current_channel"] = new_channel + else: + checkpoint_content["current_channel"] = None + + checkpoint = ConnectorCheckpoint( + checkpoint_content=checkpoint_content, # type: ignore + has_more=checkpoint_content["current_channel"] is not None, + ) + return checkpoint + + except Exception as e: + logger.exception(f"Error processing channel {channel['name']}") + yield ConnectorFailure( + failed_entity=EntityFailure( + entity_id=channel["id"], + missed_time_range=( + datetime.fromtimestamp(start, tz=timezone.utc), + datetime.fromtimestamp(end, tz=timezone.utc), + ), + ), + failure_message=str(e), + exception=e, + ) + return checkpoint if __name__ == "__main__": @@ -442,7 +672,7 @@ def poll_source( import time slack_channel = os.environ.get("SLACK_CHANNEL") - connector = SlackPollConnector( + connector = SlackConnector( channels=[slack_channel] if slack_channel else None, ) connector.load_credentials({"slack_bot_token": os.environ["SLACK_BOT_TOKEN"]}) @@ -450,6 +680,17 @@ def poll_source( current = time.time() one_day_ago = current - 24 * 60 * 60 # 1 day - document_batches = connector.poll_source(one_day_ago, current) + checkpoint = ConnectorCheckpoint.build_dummy_checkpoint() - print(next(document_batches)) + gen = connector.load_from_checkpoint(one_day_ago, current, checkpoint) + try: + for document_or_failure in gen: + if isinstance(document_or_failure, Document): + print(document_or_failure) + elif isinstance(document_or_failure, ConnectorFailure): + print(document_or_failure) + except StopIteration as e: + checkpoint = e.value + print("Next checkpoint:", checkpoint) + + print("Next checkpoint:", checkpoint) diff --git a/backend/onyx/connectors/slack/utils.py b/backend/onyx/connectors/slack/utils.py index d992eb55620..8428a453489 100644 --- a/backend/onyx/connectors/slack/utils.py +++ b/backend/onyx/connectors/slack/utils.py @@ -34,9 +34,14 @@ def get_message_link( ) -> str: channel_id = channel_id or event["channel"] message_ts = event["ts"] - response = client.chat_getPermalink(channel=channel_id, message_ts=message_ts) - permalink = response["permalink"] - return permalink + message_ts_without_dot = message_ts.replace(".", "") + thread_ts = event.get("thread_ts") + base_url = get_base_url(client.token) + + link = f"{base_url.rstrip('/')}/archives/{channel_id}/p{message_ts_without_dot}" + ( + f"?thread_ts={thread_ts}" if thread_ts else "" + ) + return link def _make_slack_api_call_paginated( diff --git a/backend/onyx/db/engine.py b/backend/onyx/db/engine.py index b62530b653f..da7995089da 100644 --- a/backend/onyx/db/engine.py +++ b/backend/onyx/db/engine.py @@ -18,6 +18,7 @@ from fastapi import HTTPException from fastapi import Request from sqlalchemy import event +from sqlalchemy import pool from sqlalchemy import text from sqlalchemy.engine import create_engine from sqlalchemy.engine import Engine @@ -39,6 +40,7 @@ from 
onyx.configs.app_configs import POSTGRES_POOL_PRE_PING from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE from onyx.configs.app_configs import POSTGRES_PORT +from onyx.configs.app_configs import POSTGRES_USE_NULL_POOL from onyx.configs.app_configs import POSTGRES_USER from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME from onyx.configs.constants import SSL_CERT_FILE @@ -187,20 +189,38 @@ class SqlEngine: _engine: Engine | None = None _lock: threading.Lock = threading.Lock() _app_name: str = POSTGRES_UNKNOWN_APP_NAME - DEFAULT_ENGINE_KWARGS = { - "pool_size": 20, - "max_overflow": 5, - "pool_pre_ping": POSTGRES_POOL_PRE_PING, - "pool_recycle": POSTGRES_POOL_RECYCLE, - } @classmethod def _init_engine(cls, **engine_kwargs: Any) -> Engine: connection_string = build_connection_string( db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH ) - merged_kwargs = {**cls.DEFAULT_ENGINE_KWARGS, **engine_kwargs} - engine = create_engine(connection_string, **merged_kwargs) + + # Start with base kwargs that are valid for all pool types + final_engine_kwargs: dict[str, Any] = {} + + if POSTGRES_USE_NULL_POOL: + # if null pool is specified, then we need to make sure that + # we remove any passed in kwargs related to pool size that would + # cause the initialization to fail + final_engine_kwargs.update(engine_kwargs) + + final_engine_kwargs["poolclass"] = pool.NullPool + if "pool_size" in final_engine_kwargs: + del final_engine_kwargs["pool_size"] + if "max_overflow" in final_engine_kwargs: + del final_engine_kwargs["max_overflow"] + else: + final_engine_kwargs["pool_size"] = 20 + final_engine_kwargs["max_overflow"] = 5 + final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING + final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE + + # any passed in kwargs override the defaults + final_engine_kwargs.update(engine_kwargs) + + logger.info(f"Creating engine with kwargs: {final_engine_kwargs}") + engine = create_engine(connection_string, **final_engine_kwargs) if USE_IAM_AUTH: event.listen(engine, "do_connect", provide_iam_token) @@ -299,13 +319,21 @@ def get_sqlalchemy_async_engine() -> AsyncEngine: connect_args["ssl"] = ssl_context + engine_kwargs = { + "connect_args": connect_args, + "pool_pre_ping": POSTGRES_POOL_PRE_PING, + "pool_recycle": POSTGRES_POOL_RECYCLE, + } + + if POSTGRES_USE_NULL_POOL: + engine_kwargs["poolclass"] = pool.NullPool + else: + engine_kwargs["pool_size"] = POSTGRES_API_SERVER_POOL_SIZE + engine_kwargs["max_overflow"] = POSTGRES_API_SERVER_POOL_OVERFLOW + _ASYNC_ENGINE = create_async_engine( connection_string, - connect_args=connect_args, - pool_size=POSTGRES_API_SERVER_POOL_SIZE, - max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW, - pool_pre_ping=POSTGRES_POOL_PRE_PING, - pool_recycle=POSTGRES_POOL_RECYCLE, + **engine_kwargs, ) if USE_IAM_AUTH: diff --git a/backend/onyx/db/index_attempt.py b/backend/onyx/db/index_attempt.py index eabc4d5df09..092f1f9704b 100644 --- a/backend/onyx/db/index_attempt.py +++ b/backend/onyx/db/index_attempt.py @@ -11,8 +11,7 @@ from sqlalchemy import update from sqlalchemy.orm import Session -from onyx.connectors.models import Document -from onyx.connectors.models import DocumentErrorSummary +from onyx.connectors.models import ConnectorFailure from onyx.db.models import IndexAttempt from onyx.db.models import IndexAttemptError from onyx.db.models import IndexingStatus @@ -41,6 +40,27 @@ def get_last_attempt_for_cc_pair( ) +def get_recent_completed_attempts_for_cc_pair( + cc_pair_id: int, + 
search_settings_id: int, + limit: int, + db_session: Session, +) -> list[IndexAttempt]: + return ( + db_session.query(IndexAttempt) + .filter( + IndexAttempt.connector_credential_pair_id == cc_pair_id, + IndexAttempt.search_settings_id == search_settings_id, + IndexAttempt.status.notin_( + [IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS] + ), + ) + .order_by(IndexAttempt.time_updated.desc()) + .limit(limit) + .all() + ) + + def get_index_attempt( db_session: Session, index_attempt_id: int ) -> IndexAttempt | None: @@ -615,23 +635,32 @@ def count_unique_cc_pairs_with_successful_index_attempts( def create_index_attempt_error( index_attempt_id: int | None, - batch: int | None, - docs: list[Document], - exception_msg: str, - exception_traceback: str, + connector_credential_pair_id: int, + failure: ConnectorFailure, db_session: Session, ) -> int: - doc_summaries = [] - for doc in docs: - doc_summary = DocumentErrorSummary.from_document(doc) - doc_summaries.append(doc_summary.to_dict()) - new_error = IndexAttemptError( index_attempt_id=index_attempt_id, - batch=batch, - doc_summaries=doc_summaries, - error_msg=exception_msg, - traceback=exception_traceback, + connector_credential_pair_id=connector_credential_pair_id, + document_id=( + failure.failed_document.document_id if failure.failed_document else None + ), + document_link=( + failure.failed_document.document_link if failure.failed_document else None + ), + entity_id=(failure.failed_entity.entity_id if failure.failed_entity else None), + failed_time_range_start=( + failure.failed_entity.missed_time_range[0] + if failure.failed_entity and failure.failed_entity.missed_time_range + else None + ), + failed_time_range_end=( + failure.failed_entity.missed_time_range[1] + if failure.failed_entity and failure.failed_entity.missed_time_range + else None + ), + failure_message=failure.failure_message, + is_resolved=False, ) db_session.add(new_error) db_session.commit() @@ -649,3 +678,42 @@ def get_index_attempt_errors( errors = db_session.scalars(stmt) return list(errors.all()) + + +def count_index_attempt_errors_for_cc_pair( + cc_pair_id: int, + unresolved_only: bool, + db_session: Session, +) -> int: + stmt = ( + select(func.count()) + .select_from(IndexAttemptError) + .where(IndexAttemptError.connector_credential_pair_id == cc_pair_id) + ) + if unresolved_only: + stmt = stmt.where(IndexAttemptError.is_resolved.is_(False)) + + result = db_session.scalar(stmt) + return 0 if result is None else result + + +def get_index_attempt_errors_for_cc_pair( + cc_pair_id: int, + unresolved_only: bool, + db_session: Session, + page: int | None = None, + page_size: int | None = None, +) -> list[IndexAttemptError]: + stmt = select(IndexAttemptError).where( + IndexAttemptError.connector_credential_pair_id == cc_pair_id + ) + if unresolved_only: + stmt = stmt.where(IndexAttemptError.is_resolved.is_(False)) + + # Order by most recent first + stmt = stmt.order_by(desc(IndexAttemptError.time_created)) + + if page is not None and page_size is not None: + stmt = stmt.offset(page * page_size).limit(page_size) + + return list(db_session.scalars(stmt).all()) diff --git a/backend/onyx/db/models.py b/backend/onyx/db/models.py index 1187f7aeb4c..5e18e016c99 100644 --- a/backend/onyx/db/models.py +++ b/backend/onyx/db/models.py @@ -827,6 +827,19 @@ class IndexAttempt(Base): nullable=True, ) + # for polling connectors, the start and end time of the poll window + # will be set when the index attempt starts + poll_range_start: Mapped[datetime.datetime | None] = 
mapped_column( + DateTime(timezone=True), nullable=True, default=None + ) + poll_range_end: Mapped[datetime.datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True, default=None + ) + + # Points to the last checkpoint that was saved for this run. The pointer here + # can be taken to the FileStore to grab the actual checkpoint value + checkpoint_pointer: Mapped[str | None] = mapped_column(String, nullable=True) + time_created: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), @@ -870,6 +883,13 @@ class IndexAttempt(Base): desc("time_updated"), unique=False, ), + Index( + "ix_index_attempt_cc_pair_settings_poll", + "connector_credential_pair_id", + "search_settings_id", + "status", + desc("time_updated"), + ), ) def __repr__(self) -> str: @@ -886,25 +906,33 @@ def is_finished(self) -> bool: class IndexAttemptError(Base): - """ - Represents an error that was encountered during an IndexAttempt. - """ - __tablename__ = "index_attempt_errors" id: Mapped[int] = mapped_column(primary_key=True) index_attempt_id: Mapped[int] = mapped_column( ForeignKey("index_attempt.id"), - nullable=True, + nullable=False, + ) + connector_credential_pair_id: Mapped[int] = mapped_column( + ForeignKey("connector_credential_pair.id"), + nullable=False, ) - # The index of the batch where the error occurred (if looping thru batches) - # Just informational. - batch: Mapped[int | None] = mapped_column(Integer, default=None) - doc_summaries: Mapped[list[Any]] = mapped_column(postgresql.JSONB()) - error_msg: Mapped[str | None] = mapped_column(Text, default=None) - traceback: Mapped[str | None] = mapped_column(Text, default=None) + document_id: Mapped[str | None] = mapped_column(String, nullable=True) + document_link: Mapped[str | None] = mapped_column(String, nullable=True) + + entity_id: Mapped[str | None] = mapped_column(String, nullable=True) + failed_time_range_start: Mapped[datetime.datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) + failed_time_range_end: Mapped[datetime.datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) + + failure_message: Mapped[str] = mapped_column(Text) + is_resolved: Mapped[bool] = mapped_column(Boolean, default=False) + time_created: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), @@ -913,21 +941,6 @@ class IndexAttemptError(Base): # This is the reverse side of the relationship index_attempt = relationship("IndexAttempt", back_populates="error_rows") - __table_args__ = ( - Index( - "index_attempt_id", - "time_created", - ), - ) - - def __repr__(self) -> str: - return ( - f"" - f"time_created={self.time_created!r}, " - ) - class SyncRecord(Base): """ diff --git a/backend/onyx/indexing/embedder.py b/backend/onyx/indexing/embedder.py index f2e4037d913..a692827c5f8 100644 --- a/backend/onyx/indexing/embedder.py +++ b/backend/onyx/indexing/embedder.py @@ -1,6 +1,10 @@ +import time from abc import ABC from abc import abstractmethod +from collections import defaultdict +from onyx.connectors.models import ConnectorFailure +from onyx.connectors.models import DocumentFailure from onyx.db.models import SearchSettings from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.indexing.models import ChunkEmbedding @@ -217,3 +221,49 @@ def from_db_search_settings( deployment_name=search_settings.deployment_name, callback=callback, ) + + +def embed_chunks_with_failure_handling( + chunks: list[DocAwareChunk], + embedder: 
IndexingEmbedder, +) -> tuple[list[IndexChunk], list[ConnectorFailure]]: + """Tries to embed all chunks in one large batch. If that batch fails for any reason, + goes document by document to isolate the failure(s). + """ + + # First try to embed all chunks in one batch + try: + return embedder.embed_chunks(chunks=chunks), [] + except Exception: + logger.exception("Failed to embed chunk batch. Trying individual docs.") + # wait a couple seconds to let any rate limits or temporary issues resolve + time.sleep(2) + + # Try embedding each document's chunks individually + chunks_by_doc: dict[str, list[DocAwareChunk]] = defaultdict(list) + for chunk in chunks: + chunks_by_doc[chunk.source_document.id].append(chunk) + + embedded_chunks: list[IndexChunk] = [] + failures: list[ConnectorFailure] = [] + + for doc_id, chunks_for_doc in chunks_by_doc.items(): + try: + doc_embedded_chunks = embedder.embed_chunks(chunks=chunks_for_doc) + embedded_chunks.extend(doc_embedded_chunks) + except Exception as e: + logger.exception(f"Failed to embed chunks for document '{doc_id}'") + failures.append( + ConnectorFailure( + failed_document=DocumentFailure( + document_id=doc_id, + document_link=( + chunks_for_doc[0].get_link() if chunks_for_doc else None + ), + ), + failure_message=str(e), + exception=e, + ) + ) + + return embedded_chunks, failures diff --git a/backend/onyx/indexing/indexing_pipeline.py b/backend/onyx/indexing/indexing_pipeline.py index e965a5ca4a8..d750a5dcff9 100644 --- a/backend/onyx/indexing/indexing_pipeline.py +++ b/backend/onyx/indexing/indexing_pipeline.py @@ -1,23 +1,21 @@ -import traceback from collections.abc import Callable from functools import partial -from http import HTTPStatus from typing import Protocol -import httpx from pydantic import BaseModel from pydantic import ConfigDict from sqlalchemy.orm import Session from onyx.access.access import get_access_for_documents from onyx.access.models import DocumentAccess -from onyx.configs.app_configs import INDEXING_EXCEPTION_LIMIT from onyx.configs.app_configs import MAX_DOCUMENT_CHARS from onyx.configs.constants import DEFAULT_BOOST from onyx.connectors.cross_connector_utils.miscellaneous_utils import ( get_experts_stores_representations, ) +from onyx.connectors.models import ConnectorFailure from onyx.connectors.models import Document +from onyx.connectors.models import DocumentFailure from onyx.connectors.models import IndexAttemptMetadata from onyx.db.document import fetch_chunk_counts_for_documents from onyx.db.document import get_documents_by_ids @@ -29,7 +27,6 @@ from onyx.db.document import upsert_document_by_connector_credential_pair from onyx.db.document import upsert_documents from onyx.db.document_set import fetch_document_sets_for_documents -from onyx.db.index_attempt import create_index_attempt_error from onyx.db.models import Document as DBDocument from onyx.db.search_settings import get_current_search_settings from onyx.db.tag import create_or_add_document_tag @@ -41,10 +38,12 @@ from onyx.document_index.interfaces import DocumentMetadata from onyx.document_index.interfaces import IndexBatchParams from onyx.indexing.chunker import Chunker +from onyx.indexing.embedder import embed_chunks_with_failure_handling from onyx.indexing.embedder import IndexingEmbedder from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.indexing.models import DocAwareChunk from onyx.indexing.models import DocMetadataAwareIndexChunk +from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff 
from onyx.utils.logger import setup_logger from onyx.utils.timing import log_function_time @@ -67,6 +66,8 @@ class IndexingPipelineResult(BaseModel): # number of chunks that were inserted into Vespa total_chunks: int + failures: list[ConnectorFailure] + class IndexingPipelineProtocol(Protocol): def __call__( @@ -156,14 +157,10 @@ def index_doc_batch_with_handler( document_index: DocumentIndex, document_batch: list[Document], index_attempt_metadata: IndexAttemptMetadata, - attempt_id: int | None, db_session: Session, ignore_time_skip: bool = False, tenant_id: str | None = None, ) -> IndexingPipelineResult: - index_pipeline_result = IndexingPipelineResult( - new_docs=0, total_docs=len(document_batch), total_chunks=0 - ) try: index_pipeline_result = index_doc_batch( chunker=chunker, @@ -176,47 +173,25 @@ def index_doc_batch_with_handler( tenant_id=tenant_id, ) except Exception as e: - if isinstance(e, httpx.HTTPStatusError): - if e.response.status_code == HTTPStatus.INSUFFICIENT_STORAGE: - logger.error( - "NOTE: HTTP Status 507 Insufficient Storage indicates " - "you need to allocate more memory or disk space to the " - "Vespa/index container." + logger.exception(f"Failed to index document batch: {document_batch}") + index_pipeline_result = IndexingPipelineResult( + new_docs=0, + total_docs=len(document_batch), + total_chunks=0, + failures=[ + ConnectorFailure( + failed_document=DocumentFailure( + document_id=document.id, + document_link=( + document.sections[0].link if document.sections else None + ), + ), + failure_message=str(e), + exception=e, ) - - if INDEXING_EXCEPTION_LIMIT == 0: - raise - - trace = traceback.format_exc() - create_index_attempt_error( - attempt_id, - batch=index_attempt_metadata.batch_num, - docs=document_batch, - exception_msg=str(e), - exception_traceback=trace, - db_session=db_session, + for document in document_batch + ], ) - logger.exception( - f"Indexing batch {index_attempt_metadata.batch_num} failed. msg='{e}' trace='{trace}'" - ) - - index_attempt_metadata.num_exceptions += 1 - if index_attempt_metadata.num_exceptions == INDEXING_EXCEPTION_LIMIT: - logger.warning( - f"Maximum number of exceptions for this index attempt " - f"({INDEXING_EXCEPTION_LIMIT}) has been reached. " - f"The next exception will abort the indexing attempt." - ) - elif index_attempt_metadata.num_exceptions > INDEXING_EXCEPTION_LIMIT: - logger.warning( - f"Maximum number of exceptions for this index attempt " - f"({INDEXING_EXCEPTION_LIMIT}) has been exceeded." - ) - raise RuntimeError( - f"Maximum exception limit of {INDEXING_EXCEPTION_LIMIT} exceeded." 
- ) - else: - pass return index_pipeline_result @@ -376,8 +351,12 @@ def index_doc_batch( document_ids=[doc.id for doc in filtered_documents], db_session=db_session, ) + db_session.commit() return IndexingPipelineResult( - new_docs=0, total_docs=len(filtered_documents), total_chunks=0 + new_docs=0, + total_docs=len(filtered_documents), + total_chunks=0, + failures=[], ) doc_descriptors = [ @@ -390,10 +369,19 @@ def index_doc_batch( logger.debug(f"Starting indexing process for documents: {doc_descriptors}") logger.debug("Starting chunking") + # NOTE: no special handling for failures here, since the chunker is not + # a common source of failure for the indexing pipeline chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs) logger.debug("Starting embedding") - chunks_with_embeddings = embedder.embed_chunks(chunks) if chunks else [] + chunks_with_embeddings, embedding_failures = ( + embed_chunks_with_failure_handling( + chunks=chunks, + embedder=embedder, + ) + if chunks + else ([], []) + ) updatable_ids = [doc.id for doc in ctx.updatable_docs] @@ -459,7 +447,11 @@ def index_doc_batch( # A document will not be spread across different batches, so all the # documents with chunks in this set, are fully represented by the chunks # in this set - insertion_records = document_index.index( + ( + insertion_records, + vector_db_write_failures, + ) = write_chunks_to_vector_db_with_backoff( + document_index=document_index, chunks=access_aware_chunks, index_batch_params=IndexBatchParams( doc_id_to_previous_chunk_cnt=doc_id_to_previous_chunk_cnt, @@ -519,6 +511,7 @@ def index_doc_batch( new_docs=len([r for r in insertion_records if r.already_existed is False]), total_docs=len(filtered_documents), total_chunks=len(access_aware_chunks), + failures=vector_db_write_failures + embedding_failures, ) return result @@ -531,7 +524,6 @@ def build_indexing_pipeline( db_session: Session, chunker: Chunker | None = None, ignore_time_skip: bool = False, - attempt_id: int | None = None, tenant_id: str | None = None, callback: IndexingHeartbeatInterface | None = None, ) -> IndexingPipelineProtocol: @@ -553,7 +545,6 @@ def build_indexing_pipeline( embedder=embedder, document_index=document_index, ignore_time_skip=ignore_time_skip, - attempt_id=attempt_id, db_session=db_session, tenant_id=tenant_id, ) diff --git a/backend/onyx/indexing/models.py b/backend/onyx/indexing/models.py index 753ec9ad916..f62a29f1381 100644 --- a/backend/onyx/indexing/models.py +++ b/backend/onyx/indexing/models.py @@ -57,6 +57,13 @@ def to_short_descriptor(self) -> str: """Used when logging the identity of a chunk""" return f"{self.source_document.to_short_descriptor()} Chunk ID: {self.chunk_id}" + def get_link(self) -> str | None: + return ( + self.source_document.sections[0].link + if self.source_document.sections + else None + ) + class IndexChunk(DocAwareChunk): embeddings: ChunkEmbedding diff --git a/backend/onyx/indexing/vector_db_insertion.py b/backend/onyx/indexing/vector_db_insertion.py new file mode 100644 index 00000000000..b9cfa645e7c --- /dev/null +++ b/backend/onyx/indexing/vector_db_insertion.py @@ -0,0 +1,99 @@ +import time +from collections import defaultdict +from http import HTTPStatus + +import httpx + +from onyx.connectors.models import ConnectorFailure +from onyx.connectors.models import DocumentFailure +from onyx.document_index.interfaces import DocumentIndex +from onyx.document_index.interfaces import DocumentInsertionRecord +from onyx.document_index.interfaces import IndexBatchParams +from onyx.indexing.models 
import DocMetadataAwareIndexChunk +from onyx.utils.logger import setup_logger + + +logger = setup_logger() + + +def _log_insufficient_storage_error(e: Exception) -> None: + if isinstance(e, httpx.HTTPStatusError): + if e.response.status_code == HTTPStatus.INSUFFICIENT_STORAGE: + logger.error( + "NOTE: HTTP Status 507 Insufficient Storage indicates " + "you need to allocate more memory or disk space to the " + "Vespa/index container." + ) + + +def write_chunks_to_vector_db_with_backoff( + document_index: DocumentIndex, + chunks: list[DocMetadataAwareIndexChunk], + index_batch_params: IndexBatchParams, +) -> tuple[list[DocumentInsertionRecord], list[ConnectorFailure]]: + """Tries to insert all chunks in one large batch. If that batch fails for any reason, + goes document by document to isolate the failure(s). + + IMPORTANT: must pass in whole documents at a time not individual chunks, since the + vector DB interface assumes that all chunks for a single document are present. + """ + + # first try to write the chunks to the vector db + try: + return ( + list( + document_index.index( + chunks=chunks, + index_batch_params=index_batch_params, + ) + ), + [], + ) + except Exception as e: + logger.exception( + "Failed to write chunk batch to vector db. Trying individual docs." + ) + + # give some specific logging on this common failure case. + _log_insufficient_storage_error(e) + + # wait a couple seconds just to give the vector db a chance to recover + time.sleep(2) + + # try writing each doc one by one + chunks_for_docs: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(list) + for chunk in chunks: + chunks_for_docs[chunk.source_document.id].append(chunk) + + insertion_records: list[DocumentInsertionRecord] = [] + failures: list[ConnectorFailure] = [] + for doc_id, chunks_for_doc in chunks_for_docs.items(): + try: + insertion_records.extend( + document_index.index( + chunks=chunks_for_doc, + index_batch_params=index_batch_params, + ) + ) + except Exception as e: + logger.exception( + f"Failed to write document chunks for '{doc_id}' to vector db" + ) + + # give some specific logging on this common failure case. 
+ _log_insufficient_storage_error(e) + + failures.append( + ConnectorFailure( + failed_document=DocumentFailure( + document_id=doc_id, + document_link=( + chunks_for_doc[0].get_link() if chunks_for_doc else None + ), + ), + failure_message=str(e), + exception=e, + ) + ) + + return insertion_records, failures diff --git a/backend/onyx/main.py b/backend/onyx/main.py index 594870b363f..9f77c4cdf79 100644 --- a/backend/onyx/main.py +++ b/backend/onyx/main.py @@ -51,7 +51,6 @@ from onyx.server.documents.connector import router as connector_router from onyx.server.documents.credential import router as credential_router from onyx.server.documents.document import router as document_router -from onyx.server.documents.indexing import router as indexing_router from onyx.server.documents.standard_oauth import router as oauth_router from onyx.server.features.document_set.api import router as document_set_router from onyx.server.features.folder.api import router as folder_router @@ -317,7 +316,6 @@ def get_application() -> FastAPI: include_router_with_global_prefix_prepended( application, token_rate_limit_settings_router ) - include_router_with_global_prefix_prepended(application, indexing_router) include_router_with_global_prefix_prepended( application, get_full_openai_assistants_api_router() ) diff --git a/backend/onyx/server/documents/cc_pair.py b/backend/onyx/server/documents/cc_pair.py index 59e930dcd95..88742ba30a0 100644 --- a/backend/onyx/server/documents/cc_pair.py +++ b/backend/onyx/server/documents/cc_pair.py @@ -22,6 +22,7 @@ try_creating_prune_generator_task, ) from onyx.background.celery.versioned_apps.primary import app as primary_app +from onyx.background.indexing.models import IndexAttemptErrorPydantic from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryTask from onyx.db.connector_credential_pair import add_credential_to_connector @@ -39,7 +40,9 @@ from onyx.db.engine import get_session from onyx.db.enums import AccessType from onyx.db.enums import ConnectorCredentialPairStatus +from onyx.db.index_attempt import count_index_attempt_errors_for_cc_pair from onyx.db.index_attempt import count_index_attempts_for_connector +from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id from onyx.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id from onyx.db.models import SearchSettings @@ -546,6 +549,47 @@ def get_docs_sync_status( return [DocumentSyncStatus.from_model(doc) for doc in all_docs_for_cc_pair] +@router.get("/admin/cc-pair/{cc_pair_id}/errors") +def get_cc_pair_indexing_errors( + cc_pair_id: int, + include_resolved: bool = Query(False), + page: int = Query(0, ge=0), + page_size: int = Query(10, ge=1, le=100), + _: User = Depends(current_curator_or_admin_user), + db_session: Session = Depends(get_session), +) -> PaginatedReturn[IndexAttemptErrorPydantic]: + """Gives back all errors for a given CC Pair. Allows pagination based on page and page_size params. + + Args: + cc_pair_id: ID of the connector-credential pair to get errors for + include_resolved: Whether to include resolved errors in the results + page: Page number for pagination, starting at 0 + page_size: Number of errors to return per page + _: Current user, must be curator or admin + db_session: Database session + + Returns: + Paginated list of indexing errors for the CC pair. 
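    Example (illustrative; any router prefix is omitted and the IDs/values are made up):
        GET /admin/cc-pair/42/errors?include_resolved=false&page=1&page_size=10
        returns the second page of 10 unresolved errors for CC pair 42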
+ """ + total_count = count_index_attempt_errors_for_cc_pair( + db_session=db_session, + cc_pair_id=cc_pair_id, + unresolved_only=not include_resolved, + ) + + index_attempt_errors = get_index_attempt_errors_for_cc_pair( + db_session=db_session, + cc_pair_id=cc_pair_id, + unresolved_only=not include_resolved, + page=page, + page_size=page_size, + ) + return PaginatedReturn( + items=[IndexAttemptErrorPydantic.from_model(e) for e in index_attempt_errors], + total_items=total_count, + ) + + @router.put("/connector/{connector_id}/credential/{credential_id}") def associate_credential_to_connector( connector_id: int, diff --git a/backend/onyx/server/documents/indexing.py b/backend/onyx/server/documents/indexing.py deleted file mode 100644 index 864d0898eb7..00000000000 --- a/backend/onyx/server/documents/indexing.py +++ /dev/null @@ -1,23 +0,0 @@ -from fastapi import APIRouter -from fastapi import Depends -from sqlalchemy.orm import Session - -from onyx.auth.users import current_admin_user -from onyx.db.engine import get_session -from onyx.db.index_attempt import ( - get_index_attempt_errors, -) -from onyx.db.models import User -from onyx.server.documents.models import IndexAttemptError - -router = APIRouter(prefix="/manage") - - -@router.get("/admin/indexing-errors/{index_attempt_id}") -def get_indexing_errors( - index_attempt_id: int, - _: User | None = Depends(current_admin_user), - db_session: Session = Depends(get_session), -) -> list[IndexAttemptError]: - indexing_errors = get_index_attempt_errors(index_attempt_id, db_session) - return [IndexAttemptError.from_db_model(e) for e in indexing_errors] diff --git a/backend/onyx/server/documents/models.py b/backend/onyx/server/documents/models.py index 57aeb10f673..a0b425fd65f 100644 --- a/backend/onyx/server/documents/models.py +++ b/backend/onyx/server/documents/models.py @@ -8,9 +8,9 @@ from pydantic import Field from ee.onyx.server.query_history.models import ChatSessionMinimal +from onyx.background.indexing.models import IndexAttemptErrorPydantic from onyx.configs.app_configs import MASK_CREDENTIAL_PREFIX from onyx.configs.constants import DocumentSource -from onyx.connectors.models import DocumentErrorSummary from onyx.connectors.models import InputType from onyx.db.enums import AccessType from onyx.db.enums import ConnectorCredentialPairStatus @@ -19,7 +19,6 @@ from onyx.db.models import Credential from onyx.db.models import Document as DbDocument from onyx.db.models import IndexAttempt -from onyx.db.models import IndexAttemptError as DbIndexAttemptError from onyx.db.models import IndexingStatus from onyx.db.models import TaskStatus from onyx.server.models import FullUserSnapshot @@ -150,6 +149,7 @@ def from_credential_db_model(cls, credential: Credential) -> "CredentialSnapshot class IndexAttemptSnapshot(BaseModel): id: int status: IndexingStatus | None + from_beginning: bool new_docs_indexed: int # only includes completely new docs total_docs_indexed: int # includes docs that are updated docs_removed_from_index: int @@ -166,6 +166,7 @@ def from_index_attempt_db_model( return IndexAttemptSnapshot( id=index_attempt.id, status=index_attempt.status, + from_beginning=index_attempt.from_beginning, new_docs_indexed=index_attempt.new_docs_indexed or 0, total_docs_indexed=index_attempt.total_docs_indexed or 0, docs_removed_from_index=index_attempt.docs_removed_from_index or 0, @@ -181,31 +182,6 @@ def from_index_attempt_db_model( ) -class IndexAttemptError(BaseModel): - id: int - index_attempt_id: int | None - batch_number: int | None - 
doc_summaries: list[DocumentErrorSummary] - error_msg: str | None - traceback: str | None - time_created: str - - @classmethod - def from_db_model(cls, error: DbIndexAttemptError) -> "IndexAttemptError": - doc_summaries = [ - DocumentErrorSummary.from_dict(summary) for summary in error.doc_summaries - ] - return IndexAttemptError( - id=error.id, - index_attempt_id=error.index_attempt_id, - batch_number=error.batch, - doc_summaries=doc_summaries, - error_msg=error.error_msg, - traceback=error.traceback, - time_created=error.time_created.isoformat(), - ) - - # These are the types currently supported by the pagination hook # More api endpoints can be refactored and be added here for use with the pagination hook PaginatedType = TypeVar( @@ -214,6 +190,7 @@ def from_db_model(cls, error: DbIndexAttemptError) -> "IndexAttemptError": FullUserSnapshot, InvitedUserSnapshot, ChatSessionMinimal, + IndexAttemptErrorPydantic, ) diff --git a/backend/onyx/utils/object_size_check.py b/backend/onyx/utils/object_size_check.py new file mode 100644 index 00000000000..61847939179 --- /dev/null +++ b/backend/onyx/utils/object_size_check.py @@ -0,0 +1,26 @@ +import sys +from typing import TypeVar + +T = TypeVar("T", dict, list, tuple, set, frozenset) + + +def deep_getsizeof(obj: T, seen: set[int] | None = None) -> int: + """Recursively sum size of objects, handling circular references.""" + if seen is None: + seen = set() + + obj_id = id(obj) + if obj_id in seen: + return 0 # Prevent infinite recursion for circular references + + seen.add(obj_id) + size = sys.getsizeof(obj) + + if isinstance(obj, dict): + size += sum( + deep_getsizeof(k, seen) + deep_getsizeof(v, seen) for k, v in obj.items() + ) + elif isinstance(obj, (list, tuple, set, frozenset)): + size += sum(deep_getsizeof(i, seen) for i in obj) + + return size diff --git a/backend/scripts/dev_run_background_jobs.py b/backend/scripts/dev_run_background_jobs.py index 3370034c3ce..ef638aebae4 100644 --- a/backend/scripts/dev_run_background_jobs.py +++ b/backend/scripts/dev_run_background_jobs.py @@ -42,7 +42,7 @@ def run_jobs() -> None: "--loglevel=INFO", "--hostname=light@%n", "-Q", - "vespa_metadata_sync,connector_deletion,doc_permissions_upsert", + "vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup", ] cmd_worker_heavy = [ diff --git a/backend/supervisord.conf b/backend/supervisord.conf index 78d5679bae7..d42672096a5 100644 --- a/backend/supervisord.conf +++ b/backend/supervisord.conf @@ -33,7 +33,7 @@ stopasgroup=true command=celery -A onyx.background.celery.versioned_apps.light worker --loglevel=INFO --hostname=light@%%n - -Q vespa_metadata_sync,connector_deletion,doc_permissions_upsert + -Q vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup stdout_logfile=/var/log/celery_worker_light.log stdout_logfile_maxbytes=16MB redirect_stderr=true diff --git a/backend/tests/integration/common_utils/chat.py b/backend/tests/integration/common_utils/chat.py index b9de84ad52f..3554bfe5196 100644 --- a/backend/tests/integration/common_utils/chat.py +++ b/backend/tests/integration/common_utils/chat.py @@ -1,14 +1,15 @@ import requests -from sqlalchemy.orm import Session +from onyx.db.engine import get_session_context_manager from onyx.db.models import User -def test_create_chat_session_and_send_messages(db_session: Session) -> None: +def test_create_chat_session_and_send_messages() -> None: # Create a test user - test_user = User(email="test@example.com", hashed_password="dummy_hash") - 
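The deep_getsizeof helper added above differs from a bare sys.getsizeof in that it also walks container contents, while the seen id-set keeps shared and circular references from being counted twice or recursing forever. A quick usage sketch, assuming the onyx package is importable; the exact byte counts depend on the interpreter:

import sys

from onyx.utils.object_size_check import deep_getsizeof

payload = {"ids": list(range(1_000)), "meta": {"source": "mock", "tags": ("a", "b")}}

print(sys.getsizeof(payload))   # shallow: only the outer dict object
print(deep_getsizeof(payload))  # deep: includes the nested list/dict/tuple contents

# Circular references terminate instead of looping forever.
cycle: list = []
cycle.append(cycle)
print(deep_getsizeof(cycle))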
db_session.add(test_user) - db_session.commit() + with get_session_context_manager() as db_session: + test_user = User(email="test@example.com", hashed_password="dummy_hash") + db_session.add(test_user) + db_session.commit() base_url = "http://localhost:8080" # Adjust this to your API's base URL headers = {"Authorization": f"Bearer {test_user.id}"} diff --git a/backend/tests/integration/common_utils/constants.py b/backend/tests/integration/common_utils/constants.py index 57db1ad9a32..c6731e7397a 100644 --- a/backend/tests/integration/common_utils/constants.py +++ b/backend/tests/integration/common_utils/constants.py @@ -1,5 +1,7 @@ import os +ADMIN_USER_NAME = "admin_user" + API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http" API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost" API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080" @@ -9,3 +11,6 @@ GENERAL_HEADERS = {"Content-Type": "application/json"} NUM_DOCS = 5 + +MOCK_CONNECTOR_SERVER_HOST = os.getenv("MOCK_CONNECTOR_SERVER_HOST") or "localhost" +MOCK_CONNECTOR_SERVER_PORT = os.getenv("MOCK_CONNECTOR_SERVER_PORT") or 8001 diff --git a/backend/tests/integration/common_utils/managers/cc_pair.py b/backend/tests/integration/common_utils/managers/cc_pair.py index 44c029d7687..ece55db632e 100644 --- a/backend/tests/integration/common_utils/managers/cc_pair.py +++ b/backend/tests/integration/common_utils/managers/cc_pair.py @@ -223,12 +223,13 @@ def verify( @staticmethod def run_once( cc_pair: DATestCCPair, + from_beginning: bool, user_performing_action: DATestUser | None = None, ) -> None: body = { "connector_id": cc_pair.connector_id, "credential_ids": [cc_pair.credential_id], - "from_beginning": True, + "from_beginning": from_beginning, } result = requests.post( url=f"{API_SERVER_URL}/manage/admin/connector/run-once", diff --git a/backend/tests/integration/common_utils/managers/document.py b/backend/tests/integration/common_utils/managers/document.py index 9ce3430fe62..29c4bfd221a 100644 --- a/backend/tests/integration/common_utils/managers/document.py +++ b/backend/tests/integration/common_utils/managers/document.py @@ -1,9 +1,14 @@ from uuid import uuid4 import requests +from sqlalchemy import and_ +from sqlalchemy import select +from sqlalchemy.orm import Session from onyx.configs.constants import DocumentSource from onyx.db.enums import AccessType +from onyx.db.models import ConnectorCredentialPair +from onyx.db.models import DocumentByConnectorCredentialPair from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.constants import GENERAL_HEADERS from tests.integration.common_utils.constants import NUM_DOCS @@ -186,3 +191,39 @@ def verify( group_names, doc_creating_user, ) + + @staticmethod + def fetch_documents_for_cc_pair( + cc_pair_id: int, + db_session: Session, + vespa_client: vespa_fixture, + ) -> list[SimpleTestDocument]: + stmt = ( + select(DocumentByConnectorCredentialPair) + .join( + ConnectorCredentialPair, + and_( + DocumentByConnectorCredentialPair.connector_id + == ConnectorCredentialPair.connector_id, + DocumentByConnectorCredentialPair.credential_id + == ConnectorCredentialPair.credential_id, + ), + ) + .where(ConnectorCredentialPair.id == cc_pair_id) + ) + documents = db_session.execute(stmt).scalars().all() + if not documents: + return [] + + doc_ids = [document.id for document in documents] + retrieved_docs_dict = vespa_client.get_documents_by_id(doc_ids)["documents"] + + final_docs: list[SimpleTestDocument] = [] + # NOTE: they are really 
chunks, but we're assuming that for these tests + # we only have one chunk per document for now + for doc_dict in retrieved_docs_dict: + doc_id = doc_dict["fields"]["document_id"] + doc_content = doc_dict["fields"]["content"] + final_docs.append(SimpleTestDocument(id=doc_id, content=doc_content)) + + return final_docs diff --git a/backend/tests/integration/common_utils/managers/index_attempt.py b/backend/tests/integration/common_utils/managers/index_attempt.py index ca795440f2d..7cc71d9903c 100644 --- a/backend/tests/integration/common_utils/managers/index_attempt.py +++ b/backend/tests/integration/common_utils/managers/index_attempt.py @@ -4,6 +4,7 @@ import requests +from onyx.background.indexing.models import IndexAttemptErrorPydantic from onyx.db.engine import get_session_context_manager from onyx.db.enums import IndexModelStatus from onyx.db.models import IndexAttempt @@ -13,6 +14,7 @@ from onyx.server.documents.models import PaginatedReturn from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.constants import MAX_DELAY from tests.integration.common_utils.test_models import DATestIndexAttempt from tests.integration.common_utils.test_models import DATestUser @@ -92,8 +94,12 @@ def get_index_attempt_page( "page_size": page_size, } + url = ( + f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/index-attempts" + f"?{urlencode(query_params, doseq=True)}" + ) response = requests.get( - url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/index-attempts?{urlencode(query_params, doseq=True)}", + url=url, headers=user_performing_action.headers if user_performing_action else GENERAL_HEADERS, @@ -104,3 +110,125 @@ def get_index_attempt_page( items=[IndexAttemptSnapshot(**item) for item in data["items"]], total_items=data["total_items"], ) + + @staticmethod + def get_latest_index_attempt_for_cc_pair( + cc_pair_id: int, + user_performing_action: DATestUser | None = None, + ) -> IndexAttemptSnapshot | None: + """Get an IndexAttempt by ID""" + index_attempts = IndexAttemptManager.get_index_attempt_page( + cc_pair_id, user_performing_action=user_performing_action + ).items + if not index_attempts: + return None + + index_attempts = sorted( + index_attempts, key=lambda x: x.time_started or "0", reverse=True + ) + return index_attempts[0] + + @staticmethod + def wait_for_index_attempt_start( + cc_pair_id: int, + index_attempts_to_ignore: list[int] | None = None, + timeout: float = MAX_DELAY, + user_performing_action: DATestUser | None = None, + ) -> IndexAttemptSnapshot: + """Wait for an IndexAttempt to start""" + start = datetime.now() + index_attempts_to_ignore = index_attempts_to_ignore or [] + + while True: + index_attempt = IndexAttemptManager.get_latest_index_attempt_for_cc_pair( + cc_pair_id=cc_pair_id, + user_performing_action=user_performing_action, + ) + if ( + index_attempt + and index_attempt.time_started + and index_attempt.id not in index_attempts_to_ignore + ): + return index_attempt + + elapsed = (datetime.now() - start).total_seconds() + if elapsed > timeout: + raise TimeoutError( + f"IndexAttempt for CC Pair {cc_pair_id} did not start within {timeout} seconds" + ) + + @staticmethod + def get_index_attempt_by_id( + index_attempt_id: int, + cc_pair_id: int, + user_performing_action: DATestUser | None = None, + ) -> IndexAttemptSnapshot: + page_num = 0 + page_size = 10 + while True: + page = IndexAttemptManager.get_index_attempt_page( + cc_pair_id=cc_pair_id, + 
page=page_num, + page_size=page_size, + user_performing_action=user_performing_action, + ) + for attempt in page.items: + if attempt.id == index_attempt_id: + return attempt + + if len(page.items) < page_size: + break + + page_num += 1 + + raise ValueError(f"IndexAttempt {index_attempt_id} not found") + + @staticmethod + def wait_for_index_attempt_completion( + index_attempt_id: int, + cc_pair_id: int, + timeout: float = MAX_DELAY, + user_performing_action: DATestUser | None = None, + ) -> None: + """Wait for an IndexAttempt to complete""" + start = datetime.now() + while True: + index_attempt = IndexAttemptManager.get_index_attempt_by_id( + index_attempt_id=index_attempt_id, + cc_pair_id=cc_pair_id, + user_performing_action=user_performing_action, + ) + + if index_attempt.status and index_attempt.status.is_terminal(): + print(f"IndexAttempt {index_attempt_id} completed") + return + + elapsed = (datetime.now() - start).total_seconds() + if elapsed > timeout: + raise TimeoutError( + f"IndexAttempt {index_attempt_id} did not complete within {timeout} seconds" + ) + + print( + f"Waiting for IndexAttempt {index_attempt_id} to complete. " + f"elapsed={elapsed:.2f} timeout={timeout}" + ) + + @staticmethod + def get_index_attempt_errors_for_cc_pair( + cc_pair_id: int, + include_resolved: bool = True, + user_performing_action: DATestUser | None = None, + ) -> list[IndexAttemptErrorPydantic]: + url = f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/errors?page_size=100" + if include_resolved: + url += "&include_resolved=true" + response = requests.get( + url=url, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + data = response.json() + return [IndexAttemptErrorPydantic(**item) for item in data["items"]] diff --git a/backend/tests/integration/common_utils/reset.py b/backend/tests/integration/common_utils/reset.py index e586aebecd4..aa611021bf2 100644 --- a/backend/tests/integration/common_utils/reset.py +++ b/backend/tests/integration/common_utils/reset.py @@ -25,6 +25,7 @@ from onyx.setup import setup_postgres from onyx.setup import setup_vespa from onyx.utils.logger import setup_logger +from tests.integration.common_utils.timeout import run_with_timeout logger = setup_logger() @@ -66,6 +67,7 @@ def _run_migrations( def downgrade_postgres( database: str = "postgres", + schema: str = "public", config_name: str = "alembic", revision: str = "base", clear_data: bool = False, @@ -73,8 +75,8 @@ def downgrade_postgres( """Downgrade Postgres database to base state.""" if clear_data: if revision != "base": - logger.warning("Clearing data without rolling back to base state") - # Delete all rows to allow migrations to be rolled back + raise ValueError("Clearing data without rolling back to base state") + conn = psycopg2.connect( dbname=database, user=POSTGRES_USER, @@ -82,38 +84,33 @@ def downgrade_postgres( host=POSTGRES_HOST, port=POSTGRES_PORT, ) + conn.autocommit = True # Need autocommit for dropping schema cur = conn.cursor() - # Disable triggers to prevent foreign key constraints from being checked - cur.execute("SET session_replication_role = 'replica';") - - # Fetch all table names in the current database + # Close any existing connections to the schema before dropping cur.execute( - """ - SELECT tablename - FROM pg_tables - WHERE schemaname = 'public' + f""" + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE pg_stat_activity.datname = '{database}' + AND pg_stat_activity.state = 'idle 
in transaction' + AND pid <> pg_backend_pid(); """ ) - tables = cur.fetchall() - - for table in tables: - table_name = table[0] + # Drop and recreate the public schema - this removes ALL objects + cur.execute(f"DROP SCHEMA {schema} CASCADE;") + cur.execute(f"CREATE SCHEMA {schema};") - # Don't touch migration history or Kombu - if table_name in ("alembic_version", "kombu_message", "kombu_queue"): - continue + # Restore default privileges + cur.execute(f"GRANT ALL ON SCHEMA {schema} TO postgres;") + cur.execute(f"GRANT ALL ON SCHEMA {schema} TO public;") - cur.execute(f'DELETE FROM "{table_name}"') - - # Re-enable triggers - cur.execute("SET session_replication_role = 'origin';") - - conn.commit() cur.close() conn.close() + return + # Downgrade to base conn_str = build_connection_string( db=database, @@ -157,11 +154,37 @@ def reset_postgres( setup_onyx: bool = True, ) -> None: """Reset the Postgres database.""" - downgrade_postgres( - database=database, config_name=config_name, revision="base", clear_data=True - ) + # this seems to hang due to locking issues, so run with a timeout with a few retries + NUM_TRIES = 10 + TIMEOUT = 10 + success = False + for _ in range(NUM_TRIES): + logger.info(f"Downgrading Postgres... ({_ + 1}/{NUM_TRIES})") + try: + run_with_timeout( + downgrade_postgres, + TIMEOUT, + kwargs={ + "database": database, + "config_name": config_name, + "revision": "base", + "clear_data": True, + }, + ) + success = True + break + except TimeoutError: + logger.warning( + f"Postgres downgrade timed out, retrying... ({_ + 1}/{NUM_TRIES})" + ) + + if not success: + raise RuntimeError("Postgres downgrade failed after 10 timeouts.") + + logger.info("Upgrading Postgres...") upgrade_postgres(database=database, config_name=config_name, revision="head") if setup_onyx: + logger.info("Setting up Postgres...") with get_session_context_manager() as db_session: setup_postgres(db_session) diff --git a/backend/tests/integration/common_utils/test_document_utils.py b/backend/tests/integration/common_utils/test_document_utils.py new file mode 100644 index 00000000000..234968e96cc --- /dev/null +++ b/backend/tests/integration/common_utils/test_document_utils.py @@ -0,0 +1,57 @@ +import uuid +from datetime import datetime +from datetime import timezone + +from onyx.configs.constants import DocumentSource +from onyx.connectors.models import ConnectorFailure +from onyx.connectors.models import Document +from onyx.connectors.models import DocumentFailure +from onyx.connectors.models import Section + + +def create_test_document( + doc_id: str | None = None, + text: str = "Test content", + link: str = "http://test.com", + source: DocumentSource = DocumentSource.MOCK_CONNECTOR, + metadata: dict | None = None, +) -> Document: + """Create a test document with the given parameters. + + Args: + doc_id: Optional document ID. If not provided, a random UUID will be generated. + text: The text content of the document. Defaults to "Test content". + link: The link for the document section. Defaults to "http://test.com". + source: The document source. Defaults to MOCK_CONNECTOR. + metadata: Optional metadata dictionary. Defaults to empty dict. 
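The reset logic above combines two safeguards: each downgrade attempt runs in a separate process with a hard deadline via run_with_timeout (defined in tests/integration/common_utils/timeout.py just below), and a bounded retry loop turns an occasional lock-induced hang into a warning rather than a stuck test run. The same pattern in isolation, with a stand-in flaky_db_operation function (the callable must be a picklable, module-level function because it is executed in a worker process):

import time

from tests.integration.common_utils.timeout import run_with_timeout


def flaky_db_operation(label: str) -> str:
    time.sleep(1)  # stand-in for a DB call that can occasionally hang on locks
    return f"finished: {label}"


def main() -> None:
    num_tries = 3
    timeout_s = 10
    for attempt in range(num_tries):
        try:
            result = run_with_timeout(
                flaky_db_operation, timeout_s, kwargs={"label": f"attempt-{attempt + 1}"}
            )
            print(result)
            return
        except TimeoutError:
            print(f"attempt {attempt + 1} timed out, retrying...")
    raise RuntimeError(f"operation did not finish within {num_tries} attempts")


if __name__ == "__main__":
    main()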
+ """ + doc_id = doc_id or f"test-doc-{uuid.uuid4()}" + return Document( + id=doc_id, + sections=[Section(text=text, link=link)], + source=source, + semantic_identifier=doc_id, + doc_updated_at=datetime.now(timezone.utc), + metadata=metadata or {}, + ) + + +def create_test_document_failure( + doc_id: str, + failure_message: str = "Simulated failure", + document_link: str | None = None, +) -> ConnectorFailure: + """Create a test document failure with the given parameters. + + Args: + doc_id: The ID of the document that failed. + failure_message: The failure message. Defaults to "Simulated failure". + document_link: Optional link to the failed document. + """ + return ConnectorFailure( + failed_document=DocumentFailure( + document_id=doc_id, + document_link=document_link, + ), + failure_message=failure_message, + ) diff --git a/backend/tests/integration/common_utils/timeout.py b/backend/tests/integration/common_utils/timeout.py new file mode 100644 index 00000000000..64dacecaf86 --- /dev/null +++ b/backend/tests/integration/common_utils/timeout.py @@ -0,0 +1,18 @@ +import multiprocessing +from collections.abc import Callable +from typing import Any +from typing import TypeVar + +T = TypeVar("T") + + +def run_with_timeout(task: Callable[..., T], timeout: int, kwargs: dict[str, Any]) -> T: + # Use multiprocessing to prevent a thread from blocking the main thread + with multiprocessing.Pool(processes=1) as pool: + async_result = pool.apply_async(task, kwds=kwargs) + try: + # Wait at most timeout seconds for the function to complete + result = async_result.get(timeout=timeout) + return result + except multiprocessing.TimeoutError: + raise TimeoutError(f"Function timed out after {timeout} seconds") diff --git a/backend/tests/integration/conftest.py b/backend/tests/integration/conftest.py index ec50669d0bb..8c929934754 100644 --- a/backend/tests/integration/conftest.py +++ b/backend/tests/integration/conftest.py @@ -1,12 +1,11 @@ import os -from collections.abc import Generator import pytest -from sqlalchemy.orm import Session from onyx.auth.schemas import UserRole from onyx.db.engine import get_session_context_manager from onyx.db.search_settings import get_current_search_settings +from tests.integration.common_utils.constants import ADMIN_USER_NAME from tests.integration.common_utils.constants import GENERAL_HEADERS from tests.integration.common_utils.managers.user import build_email from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD @@ -36,16 +35,24 @@ def load_env_vars(env_file: str = ".env") -> None: load_env_vars() -@pytest.fixture -def db_session() -> Generator[Session, None, None]: - with get_session_context_manager() as session: - yield session +"""NOTE: for some reason using this seems to lead to misc +`sqlalchemy.exc.OperationalError: (psycopg2.OperationalError) server closed the connection unexpectedly` +errors. + +Commenting out till we can get to the bottom of it. For now, just using +instantiate the session directly within the test. 
+""" +# @pytest.fixture +# def db_session() -> Generator[Session, None, None]: +# with get_session_context_manager() as session: +# yield session @pytest.fixture -def vespa_client(db_session: Session) -> vespa_fixture: - search_settings = get_current_search_settings(db_session) - return vespa_fixture(index_name=search_settings.index_name) +def vespa_client() -> vespa_fixture: + with get_session_context_manager() as db_session: + search_settings = get_current_search_settings(db_session) + return vespa_fixture(index_name=search_settings.index_name) @pytest.fixture @@ -56,20 +63,27 @@ def reset() -> None: @pytest.fixture def new_admin_user(reset: None) -> DATestUser | None: try: - return UserManager.create(name="admin_user") + return UserManager.create(name=ADMIN_USER_NAME) except Exception: return None @pytest.fixture -def admin_user() -> DATestUser | None: +def admin_user() -> DATestUser: try: - return UserManager.create(name="admin_user") - except Exception: - pass + user = UserManager.create(name=ADMIN_USER_NAME, is_first_user=True) + + # if there are other users for some reason, reset and try again + if not UserManager.is_role(user, UserRole.ADMIN): + print("Trying to reset") + reset_all() + user = UserManager.create(name=ADMIN_USER_NAME) + return user + except Exception as e: + print(f"Failed to create admin user: {e}") try: - return UserManager.login_as_user( + user = UserManager.login_as_user( DATestUser( id="", email=build_email("admin_user"), @@ -79,10 +93,16 @@ def admin_user() -> DATestUser | None: is_active=True, ) ) - except Exception: - pass + if not UserManager.is_role(user, UserRole.ADMIN): + reset_all() + user = UserManager.create(name=ADMIN_USER_NAME) + return user + + return user + except Exception as e: + print(f"Failed to create or login as admin user: {e}") - return None + raise RuntimeError("Failed to create or login as admin user") @pytest.fixture diff --git a/backend/tests/integration/connector_job_tests/google/test_google_drive_permission_sync.py b/backend/tests/integration/connector_job_tests/google/test_google_drive_permission_sync.py index 30283097cf8..d72939196f2 100644 --- a/backend/tests/integration/connector_job_tests/google/test_google_drive_permission_sync.py +++ b/backend/tests/integration/connector_job_tests/google/test_google_drive_permission_sync.py @@ -138,7 +138,9 @@ def test_google_permission_sync( GoogleDriveManager.append_text_to_doc(drive_service, doc_id_1, doc_text_1) # run indexing - CCPairManager.run_once(cc_pair, admin_user) + CCPairManager.run_once( + cc_pair, from_beginning=True, user_performing_action=admin_user + ) CCPairManager.wait_for_indexing_completion( cc_pair=cc_pair, after=before, user_performing_action=admin_user ) @@ -184,7 +186,9 @@ def test_google_permission_sync( GoogleDriveManager.append_text_to_doc(drive_service, doc_id_2, doc_text_2) # Run indexing - CCPairManager.run_once(cc_pair, admin_user) + CCPairManager.run_once( + cc_pair, from_beginning=True, user_performing_action=admin_user + ) CCPairManager.wait_for_indexing_completion( cc_pair=cc_pair, after=before, diff --git a/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py b/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py index a40f977dc1e..d8e5b7a8e53 100644 --- a/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py +++ b/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py @@ -113,7 +113,9 @@ def test_slack_permission_sync( # Run indexing before = datetime.now(timezone.utc) 
- CCPairManager.run_once(cc_pair, admin_user) + CCPairManager.run_once( + cc_pair, from_beginning=True, user_performing_action=admin_user + ) CCPairManager.wait_for_indexing_completion( cc_pair=cc_pair, after=before, @@ -305,7 +307,9 @@ def test_slack_group_permission_sync( ) # Run indexing - CCPairManager.run_once(cc_pair, admin_user) + CCPairManager.run_once( + cc_pair, from_beginning=True, user_performing_action=admin_user + ) CCPairManager.wait_for_indexing_completion( cc_pair=cc_pair, after=before, diff --git a/backend/tests/integration/connector_job_tests/slack/test_prune.py b/backend/tests/integration/connector_job_tests/slack/test_prune.py index 4fda4d6a63c..69f045cf05a 100644 --- a/backend/tests/integration/connector_job_tests/slack/test_prune.py +++ b/backend/tests/integration/connector_job_tests/slack/test_prune.py @@ -111,7 +111,9 @@ def test_slack_prune( # Run indexing before = datetime.now(timezone.utc) - CCPairManager.run_once(cc_pair, admin_user) + CCPairManager.run_once( + cc_pair, from_beginning=True, user_performing_action=admin_user + ) CCPairManager.wait_for_indexing_completion( cc_pair=cc_pair, after=before, diff --git a/backend/tests/integration/mock_services/docker-compose.mock-it-services.yml b/backend/tests/integration/mock_services/docker-compose.mock-it-services.yml new file mode 100644 index 00000000000..1e28975ee8a --- /dev/null +++ b/backend/tests/integration/mock_services/docker-compose.mock-it-services.yml @@ -0,0 +1,20 @@ +version: '3.8' + +services: + mock_connector_server: + build: + context: ./mock_connector_server + dockerfile: Dockerfile + ports: + - "8001:8001" + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8001/health"] + interval: 10s + timeout: 5s + retries: 5 + networks: + - onyx-stack_default +networks: + onyx-stack_default: + name: onyx-stack_default + external: true diff --git a/backend/tests/integration/mock_services/mock_connector_server/Dockerfile b/backend/tests/integration/mock_services/mock_connector_server/Dockerfile new file mode 100644 index 00000000000..7e056ff7840 --- /dev/null +++ b/backend/tests/integration/mock_services/mock_connector_server/Dockerfile @@ -0,0 +1,9 @@ +FROM python:3.11.7-slim-bookworm + +WORKDIR /app + +RUN pip install fastapi uvicorn + +COPY ./main.py /app/main.py + +CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8001"] \ No newline at end of file diff --git a/backend/tests/integration/mock_services/mock_connector_server/main.py b/backend/tests/integration/mock_services/mock_connector_server/main.py new file mode 100644 index 00000000000..a6ffbf57b33 --- /dev/null +++ b/backend/tests/integration/mock_services/mock_connector_server/main.py @@ -0,0 +1,76 @@ +from fastapi import FastAPI +from fastapi import HTTPException +from pydantic import BaseModel +from pydantic import Field + +# We would like to import these, but it makes building this so much harder/slower +# from onyx.connectors.mock_connector.connector import SingleConnectorYield +# from onyx.connectors.models import ConnectorCheckpoint + +app = FastAPI() + + +# Global state to store connector behavior configuration +class ConnectorBehavior(BaseModel): + connector_yields: list[dict] = Field( + default_factory=list + ) # really list[SingleConnectorYield] + called_with_checkpoints: list[dict] = Field( + default_factory=list + ) # really list[ConnectorCheckpoint] + + +current_behavior: ConnectorBehavior = ConnectorBehavior() + + +@app.post("/set-behavior") +async def set_behavior(behavior: list[dict]) -> None: + """Set the 
behavior for the next connector run""" + global current_behavior + current_behavior = ConnectorBehavior(connector_yields=behavior) + + +@app.get("/get-documents") +async def get_documents() -> list[dict]: + """Get the next batch of documents and update the checkpoint""" + global current_behavior + + if not current_behavior.connector_yields: + raise HTTPException( + status_code=400, detail="No documents or failures configured" + ) + + connector_yields = current_behavior.connector_yields + + # Clear the current behavior after returning it + current_behavior = ConnectorBehavior() + + return connector_yields + + +@app.post("/add-checkpoint") +async def add_checkpoint(checkpoint: dict) -> None: + """Add a checkpoint to the list of checkpoints. Called by the MockConnector.""" + global current_behavior + current_behavior.called_with_checkpoints.append(checkpoint) + + +@app.get("/get-checkpoints") +async def get_checkpoints() -> list[dict]: + """Get the list of checkpoints. Used by the test to verify the + proper checkpoint ordering.""" + global current_behavior + return current_behavior.called_with_checkpoints + + +@app.post("/reset") +async def reset() -> None: + """Reset the connector behavior to default""" + global current_behavior + current_behavior = ConnectorBehavior() + + +@app.get("/health") +async def health_check() -> dict[str, str]: + """Health check endpoint""" + return {"status": "healthy"} diff --git a/backend/tests/integration/tests/connector/test_connector_deletion.py b/backend/tests/integration/tests/connector/test_connector_deletion.py index 8878a502e2a..dc685e89f39 100644 --- a/backend/tests/integration/tests/connector/test_connector_deletion.py +++ b/backend/tests/integration/tests/connector/test_connector_deletion.py @@ -9,6 +9,8 @@ from sqlalchemy.orm import Session +from onyx.connectors.models import ConnectorFailure +from onyx.connectors.models import DocumentFailure from onyx.db.engine import get_sqlalchemy_engine from onyx.db.enums import IndexingStatus from onyx.db.index_attempt import create_index_attempt @@ -101,10 +103,15 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None: create_index_attempt_error( index_attempt_id=new_attempt.id, - batch=1, - docs=[], - exception_msg="", - exception_traceback="", + connector_credential_pair_id=cc_pair_1.id, + failure=ConnectorFailure( + failure_message="Test error", + failed_document=DocumentFailure( + document_id=cc_pair_1.documents[0].id, + document_link=None, + ), + failed_entity=None, + ), db_session=db_session, ) @@ -127,10 +134,15 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None: ) create_index_attempt_error( index_attempt_id=attempt_id, - batch=1, - docs=[], - exception_msg="", - exception_traceback="", + connector_credential_pair_id=cc_pair_1.id, + failure=ConnectorFailure( + failure_message="Test error", + failed_document=DocumentFailure( + document_id=cc_pair_1.documents[0].id, + document_link=None, + ), + failed_entity=None, + ), db_session=db_session, ) diff --git a/backend/tests/integration/tests/indexing/test_checkpointing.py b/backend/tests/integration/tests/indexing/test_checkpointing.py new file mode 100644 index 00000000000..6aeac9f53f3 --- /dev/null +++ b/backend/tests/integration/tests/indexing/test_checkpointing.py @@ -0,0 +1,518 @@ +import uuid +from datetime import datetime +from datetime import timedelta +from datetime import timezone + +import httpx +import pytest + +from onyx.configs.constants import DocumentSource +from onyx.connectors.models 
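The FastAPI app above is only a control plane for the tests: a test POSTs the desired yields to /set-behavior, and the connector running inside the indexing worker pulls them back and reports the checkpoints it was handed. A rough sketch of that connector-side interaction over the same HTTP contract (illustrative only; the actual MockConnector implementation lives in onyx/connectors/mock_connector/connector.py and is not shown in this excerpt):

import httpx

MOCK_SERVER_URL = "http://localhost:8001"  # port exposed by the mock server Dockerfile above


def run_mock_connector_once(checkpoint: dict) -> list[dict]:
    # Report the checkpoint this run started from, then fetch the configured yields.
    with httpx.Client(base_url=MOCK_SERVER_URL, timeout=5.0) as client:
        # The test later reads these back via /get-checkpoints to assert ordering.
        client.post("/add-checkpoint", json=checkpoint).raise_for_status()

        # Pull whatever documents/failures the test configured via /set-behavior.
        response = client.get("/get-documents")
        response.raise_for_status()
        return response.json()  # list of SingleConnectorYield-shaped dicts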
import ConnectorCheckpoint +from onyx.connectors.models import ConnectorFailure +from onyx.connectors.models import EntityFailure +from onyx.connectors.models import InputType +from onyx.db.engine import get_session_context_manager +from onyx.db.enums import IndexingStatus +from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_HOST +from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_PORT +from tests.integration.common_utils.managers.cc_pair import CCPairManager +from tests.integration.common_utils.managers.document import DocumentManager +from tests.integration.common_utils.managers.index_attempt import IndexAttemptManager +from tests.integration.common_utils.test_document_utils import create_test_document +from tests.integration.common_utils.test_document_utils import ( + create_test_document_failure, +) +from tests.integration.common_utils.test_models import DATestUser +from tests.integration.common_utils.vespa import vespa_fixture + + +@pytest.fixture +def mock_server_client() -> httpx.Client: + print( + f"Initializing mock server client with host: " + f"{MOCK_CONNECTOR_SERVER_HOST} and port: " + f"{MOCK_CONNECTOR_SERVER_PORT}" + ) + return httpx.Client( + base_url=f"http://{MOCK_CONNECTOR_SERVER_HOST}:{MOCK_CONNECTOR_SERVER_PORT}", + timeout=5.0, + ) + + +def test_mock_connector_basic_flow( + mock_server_client: httpx.Client, + vespa_client: vespa_fixture, + admin_user: DATestUser, +) -> None: + """Test that the mock connector can successfully process documents and failures""" + # Set up mock server behavior + doc_uuid = uuid.uuid4() + test_doc = create_test_document(doc_id=f"test-doc-{doc_uuid}") + + response = mock_server_client.post( + "/set-behavior", + json=[ + { + "documents": [test_doc.model_dump(mode="json")], + "checkpoint": ConnectorCheckpoint( + checkpoint_content={}, has_more=False + ).model_dump(mode="json"), + "failures": [], + } + ], + ) + assert response.status_code == 200 + + # create CC Pair + index attempt + cc_pair = CCPairManager.create_from_scratch( + name=f"mock-connector-{uuid.uuid4()}", + source=DocumentSource.MOCK_CONNECTOR, + input_type=InputType.POLL, + connector_specific_config={ + "mock_server_host": MOCK_CONNECTOR_SERVER_HOST, + "mock_server_port": MOCK_CONNECTOR_SERVER_PORT, + }, + user_performing_action=admin_user, + ) + + # wait for index attempt to start + index_attempt = IndexAttemptManager.wait_for_index_attempt_start( + cc_pair_id=cc_pair.id, + user_performing_action=admin_user, + ) + + # wait for index attempt to finish + IndexAttemptManager.wait_for_index_attempt_completion( + index_attempt_id=index_attempt.id, + cc_pair_id=cc_pair.id, + user_performing_action=admin_user, + ) + + # validate status + finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id( + index_attempt_id=index_attempt.id, + cc_pair_id=cc_pair.id, + user_performing_action=admin_user, + ) + assert finished_index_attempt.status == IndexingStatus.SUCCESS + + # Verify results + with get_session_context_manager() as db_session: + documents = DocumentManager.fetch_documents_for_cc_pair( + cc_pair_id=cc_pair.id, + db_session=db_session, + vespa_client=vespa_client, + ) + assert len(documents) == 1 + assert documents[0].id == test_doc.id + + errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair( + cc_pair_id=cc_pair.id, + user_performing_action=admin_user, + ) + assert len(errors) == 0 + + +def test_mock_connector_with_failures( + mock_server_client: httpx.Client, + vespa_client: vespa_fixture, + admin_user: 
DATestUser, +) -> None: + """Test that the mock connector processes both successes and failures properly.""" + doc1 = create_test_document() + doc2 = create_test_document() + doc2_failure = create_test_document_failure(doc_id=doc2.id) + + response = mock_server_client.post( + "/set-behavior", + json=[ + { + "documents": [doc1.model_dump(mode="json")], + "checkpoint": ConnectorCheckpoint( + checkpoint_content={}, has_more=False + ).model_dump(mode="json"), + "failures": [doc2_failure.model_dump(mode="json")], + } + ], + ) + assert response.status_code == 200 + + # Create a CC Pair for the mock connector + cc_pair = CCPairManager.create_from_scratch( + name=f"mock-connector-failure-{uuid.uuid4()}", + source=DocumentSource.MOCK_CONNECTOR, + input_type=InputType.POLL, + connector_specific_config={ + "mock_server_host": MOCK_CONNECTOR_SERVER_HOST, + "mock_server_port": MOCK_CONNECTOR_SERVER_PORT, + }, + user_performing_action=admin_user, + ) + + # Wait for the index attempt to start and then complete + index_attempt = IndexAttemptManager.wait_for_index_attempt_start( + cc_pair_id=cc_pair.id, + user_performing_action=admin_user, + ) + + IndexAttemptManager.wait_for_index_attempt_completion( + index_attempt_id=index_attempt.id, + cc_pair_id=cc_pair.id, + user_performing_action=admin_user, + ) + + # validate status + finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id( + index_attempt_id=index_attempt.id, + cc_pair_id=cc_pair.id, + user_performing_action=admin_user, + ) + assert finished_index_attempt.status == IndexingStatus.COMPLETED_WITH_ERRORS + + # Verify results: doc1 should be indexed and doc2 should have an error entry + with get_session_context_manager() as db_session: + documents = DocumentManager.fetch_documents_for_cc_pair( + cc_pair_id=cc_pair.id, + db_session=db_session, + vespa_client=vespa_client, + ) + assert len(documents) == 1 + assert documents[0].id == doc1.id + + errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair( + cc_pair_id=cc_pair.id, + user_performing_action=admin_user, + ) + assert len(errors) == 1 + error = errors[0] + assert error.failure_message == doc2_failure.failure_message + assert error.document_id == doc2.id + + +def test_mock_connector_failure_recovery( + mock_server_client: httpx.Client, + vespa_client: vespa_fixture, + admin_user: DATestUser, +) -> None: + """Test that a failed document can be successfully indexed in a subsequent attempt + while maintaining previously successful documents.""" + # Create test documents and failure + doc1 = create_test_document() + doc2 = create_test_document() + doc2_failure = create_test_document_failure(doc_id=doc2.id) + entity_id = "test-entity-id" + entity_failure_msg = "Simulated unhandled error" + + response = mock_server_client.post( + "/set-behavior", + json=[ + { + "documents": [doc1.model_dump(mode="json")], + "checkpoint": ConnectorCheckpoint( + checkpoint_content={}, has_more=False + ).model_dump(mode="json"), + "failures": [ + doc2_failure.model_dump(mode="json"), + ConnectorFailure( + failed_entity=EntityFailure( + entity_id=entity_id, + missed_time_range=( + datetime.now(timezone.utc) - timedelta(days=1), + datetime.now(timezone.utc), + ), + ), + failure_message=entity_failure_msg, + ).model_dump(mode="json"), + ], + } + ], + ) + assert response.status_code == 200 + + # Create CC Pair and run initial indexing attempt + cc_pair = CCPairManager.create_from_scratch( + name=f"mock-connector-{uuid.uuid4()}", + source=DocumentSource.MOCK_CONNECTOR, + input_type=InputType.POLL, + 
connector_specific_config={ + "mock_server_host": MOCK_CONNECTOR_SERVER_HOST, + "mock_server_port": MOCK_CONNECTOR_SERVER_PORT, + }, + user_performing_action=admin_user, + ) + + # Wait for first index attempt to complete + initial_index_attempt = IndexAttemptManager.wait_for_index_attempt_start( + cc_pair_id=cc_pair.id, + user_performing_action=admin_user, + ) + IndexAttemptManager.wait_for_index_attempt_completion( + index_attempt_id=initial_index_attempt.id, + cc_pair_id=cc_pair.id, + user_performing_action=admin_user, + ) + + # validate status + finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id( + index_attempt_id=initial_index_attempt.id, + cc_pair_id=cc_pair.id, + user_performing_action=admin_user, + ) + assert finished_index_attempt.status == IndexingStatus.COMPLETED_WITH_ERRORS + + # Verify initial state: doc1 indexed, doc2 failed + with get_session_context_manager() as db_session: + documents = DocumentManager.fetch_documents_for_cc_pair( + cc_pair_id=cc_pair.id, + db_session=db_session, + vespa_client=vespa_client, + ) + assert len(documents) == 1 + assert documents[0].id == doc1.id + + errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair( + cc_pair_id=cc_pair.id, + user_performing_action=admin_user, + ) + assert len(errors) == 2 + error_doc2 = next(error for error in errors if error.document_id == doc2.id) + assert error_doc2.failure_message == doc2_failure.failure_message + assert not error_doc2.is_resolved + + error_entity = next(error for error in errors if error.entity_id == entity_id) + assert error_entity.failure_message == entity_failure_msg + assert not error_entity.is_resolved + + # Update mock server to return success for both documents + response = mock_server_client.post( + "/set-behavior", + json=[ + { + "documents": [ + doc1.model_dump(mode="json"), + doc2.model_dump(mode="json"), + ], + "checkpoint": ConnectorCheckpoint( + checkpoint_content={}, has_more=False + ).model_dump(mode="json"), + "failures": [], + } + ], + ) + assert response.status_code == 200 + + # Trigger another indexing attempt + # NOTE: must be from beginning to handle the entity failure + CCPairManager.run_once( + cc_pair, from_beginning=True, user_performing_action=admin_user + ) + recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start( + cc_pair_id=cc_pair.id, + index_attempts_to_ignore=[initial_index_attempt.id], + user_performing_action=admin_user, + ) + IndexAttemptManager.wait_for_index_attempt_completion( + index_attempt_id=recovery_index_attempt.id, + cc_pair_id=cc_pair.id, + user_performing_action=admin_user, + ) + + finished_second_index_attempt = IndexAttemptManager.get_index_attempt_by_id( + index_attempt_id=recovery_index_attempt.id, + cc_pair_id=cc_pair.id, + user_performing_action=admin_user, + ) + assert finished_second_index_attempt.status == IndexingStatus.SUCCESS + + # Verify both documents are now indexed + with get_session_context_manager() as db_session: + documents = DocumentManager.fetch_documents_for_cc_pair( + cc_pair_id=cc_pair.id, + db_session=db_session, + vespa_client=vespa_client, + ) + assert len(documents) == 2 + document_ids = {doc.id for doc in documents} + assert doc2.id in document_ids + assert doc1.id in document_ids + + # Verify original failures were marked as resolved + errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair( + cc_pair_id=cc_pair.id, + user_performing_action=admin_user, + ) + assert len(errors) == 2 + error_doc2 = next(error for error in errors if error.document_id == doc2.id) + 
error_entity = next(error for error in errors if error.entity_id == entity_id) + + assert error_doc2.is_resolved + assert error_entity.is_resolved + + +def test_mock_connector_checkpoint_recovery( + mock_server_client: httpx.Client, + vespa_client: vespa_fixture, + admin_user: DATestUser, +) -> None: + """Test that checkpointing works correctly when an unhandled exception occurs + and that subsequent runs pick up from the last successful checkpoint.""" + # Create test documents + # Create 100 docs for first batch, this is needed to get past the + # `_NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT` logic in `get_latest_valid_checkpoint`. + docs_batch_1 = [create_test_document() for _ in range(100)] + doc2 = create_test_document() + doc3 = create_test_document() + + # Set up mock server behavior for initial run: + # - First yield: 100 docs with checkpoint1 + # - Second yield: doc2 with checkpoint2 + # - Third yield: unhandled exception + response = mock_server_client.post( + "/set-behavior", + json=[ + { + "documents": [doc.model_dump(mode="json") for doc in docs_batch_1], + "checkpoint": ConnectorCheckpoint( + checkpoint_content={}, has_more=True + ).model_dump(mode="json"), + "failures": [], + }, + { + "documents": [doc2.model_dump(mode="json")], + "checkpoint": ConnectorCheckpoint( + checkpoint_content={}, has_more=True + ).model_dump(mode="json"), + "failures": [], + }, + { + "documents": [], + # should never hit this, unhandled exception happens first + "checkpoint": ConnectorCheckpoint( + checkpoint_content={}, has_more=False + ).model_dump(mode="json"), + "failures": [], + "unhandled_exception": "Simulated unhandled error", + }, + ], + ) + assert response.status_code == 200 + + # Create CC Pair and run initial indexing attempt + cc_pair = CCPairManager.create_from_scratch( + name=f"mock-connector-checkpoint-{uuid.uuid4()}", + source=DocumentSource.MOCK_CONNECTOR, + input_type=InputType.POLL, + connector_specific_config={ + "mock_server_host": MOCK_CONNECTOR_SERVER_HOST, + "mock_server_port": MOCK_CONNECTOR_SERVER_PORT, + }, + user_performing_action=admin_user, + ) + + # Wait for first index attempt to complete + initial_index_attempt = IndexAttemptManager.wait_for_index_attempt_start( + cc_pair_id=cc_pair.id, + user_performing_action=admin_user, + ) + IndexAttemptManager.wait_for_index_attempt_completion( + index_attempt_id=initial_index_attempt.id, + cc_pair_id=cc_pair.id, + user_performing_action=admin_user, + ) + + # validate status + finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id( + index_attempt_id=initial_index_attempt.id, + cc_pair_id=cc_pair.id, + user_performing_action=admin_user, + ) + assert finished_index_attempt.status == IndexingStatus.FAILED + + # Verify initial state: both docs should be indexed + with get_session_context_manager() as db_session: + documents = DocumentManager.fetch_documents_for_cc_pair( + cc_pair_id=cc_pair.id, + db_session=db_session, + vespa_client=vespa_client, + ) + assert len(documents) == 101 # 100 docs from first batch + doc2 + document_ids = {doc.id for doc in documents} + assert doc2.id in document_ids + assert all(doc.id in document_ids for doc in docs_batch_1) + + # Get the checkpoints that were sent to the mock server + response = mock_server_client.get("/get-checkpoints") + assert response.status_code == 200 + initial_checkpoints = response.json() + + # Verify we got the expected checkpoints in order + assert len(initial_checkpoints) > 0 + assert ( + initial_checkpoints[0]["checkpoint_content"] == {} + ) # Initial empty 
checkpoint + assert initial_checkpoints[1]["checkpoint_content"] == {} + assert initial_checkpoints[2]["checkpoint_content"] == {} + + # Reset the mock server for the next run + response = mock_server_client.post("/reset") + assert response.status_code == 200 + + # Set up mock server behavior for recovery run - should succeed fully this time + response = mock_server_client.post( + "/set-behavior", + json=[ + { + "documents": [doc3.model_dump(mode="json")], + "checkpoint": ConnectorCheckpoint( + checkpoint_content={}, has_more=False + ).model_dump(mode="json"), + "failures": [], + } + ], + ) + assert response.status_code == 200 + + # Trigger another indexing attempt + CCPairManager.run_once( + cc_pair, from_beginning=False, user_performing_action=admin_user + ) + recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start( + cc_pair_id=cc_pair.id, + index_attempts_to_ignore=[initial_index_attempt.id], + user_performing_action=admin_user, + ) + IndexAttemptManager.wait_for_index_attempt_completion( + index_attempt_id=recovery_index_attempt.id, + cc_pair_id=cc_pair.id, + user_performing_action=admin_user, + ) + + # validate status + finished_recovery_attempt = IndexAttemptManager.get_index_attempt_by_id( + index_attempt_id=recovery_index_attempt.id, + cc_pair_id=cc_pair.id, + user_performing_action=admin_user, + ) + assert finished_recovery_attempt.status == IndexingStatus.SUCCESS + + # Verify results + with get_session_context_manager() as db_session: + documents = DocumentManager.fetch_documents_for_cc_pair( + cc_pair_id=cc_pair.id, + db_session=db_session, + vespa_client=vespa_client, + ) + assert len(documents) == 102 # 100 docs from first batch + doc2 + doc3 + document_ids = {doc.id for doc in documents} + assert doc3.id in document_ids + assert doc2.id in document_ids + assert all(doc.id in document_ids for doc in docs_batch_1) + + # Get the checkpoints from the recovery run + response = mock_server_client.get("/get-checkpoints") + assert response.status_code == 200 + recovery_checkpoints = response.json() + + # Verify the recovery run started from the last successful checkpoint + assert len(recovery_checkpoints) == 1 + assert recovery_checkpoints[0]["checkpoint_content"] == {} diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index d5fba3a7ea7..14860827c5d 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -61,6 +61,7 @@ services: # Other services - POSTGRES_HOST=relational_db - POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-} + - POSTGRES_USE_NULL_POOL=${POSTGRES_USE_NULL_POOL:-} - VESPA_HOST=index - REDIS_HOST=cache - WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose @@ -174,6 +175,7 @@ services: - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-} - POSTGRES_DB=${POSTGRES_DB:-} - POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-} + - POSTGRES_USE_NULL_POOL=${POSTGRES_USE_NULL_POOL:-} - VESPA_HOST=index - REDIS_HOST=cache - WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose for OAuth2 connectors diff --git a/web/src/app/admin/connector/[ccPairId]/IndexAttemptErrorsModal.tsx b/web/src/app/admin/connector/[ccPairId]/IndexAttemptErrorsModal.tsx new file mode 100644 index 00000000000..c509d780c68 --- /dev/null +++ b/web/src/app/admin/connector/[ccPairId]/IndexAttemptErrorsModal.tsx @@ -0,0 +1,141 @@ +import { Modal } from "@/components/Modal"; +import { + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + 
TableRow, +} from "@/components/ui/table"; +import { IndexAttemptError } from "./types"; +import { localizeAndPrettify } from "@/lib/time"; +import { Button } from "@/components/ui/button"; +import { useState } from "react"; +import { PageSelector } from "@/components/PageSelector"; + +interface IndexAttemptErrorsModalProps { + errors: { + items: IndexAttemptError[]; + total_items: number; + }; + onClose: () => void; + onResolveAll: () => void; + isResolvingErrors?: boolean; + onPageChange: (page: number) => void; + currentPage: number; + pageSize?: number; +} + +const DEFAULT_PAGE_SIZE = 10; + +export default function IndexAttemptErrorsModal({ + errors, + onClose, + onResolveAll, + isResolvingErrors = false, + onPageChange, + currentPage, + pageSize = DEFAULT_PAGE_SIZE, +}: IndexAttemptErrorsModalProps) { + const totalPages = Math.ceil(errors.total_items / pageSize); + const hasUnresolvedErrors = errors.items.some((error) => !error.is_resolved); + + return ( + +
+
+ {isResolvingErrors ? ( +
+ Currently attempting to resolve all errors by performing a full + re-index. This may take some time to complete. +
+ ) : ( + <> +
+ Below are the errors encountered during indexing. Each row + represents a failed document or entity. +
+
+ Click the button below to kick off a full re-index to try and + resolve these errors. This full re-index may take much longer + than a normal update. +
+ + )} +
+ + + + + Time + Document ID + Error Message + Status + + + + {errors.items.map((error) => ( + + {localizeAndPrettify(error.time_created)} + + {error.document_link ? ( + + {error.document_id || error.entity_id || "Unknown"} + + ) : ( + error.document_id || error.entity_id || "Unknown" + )} + + + {error.failure_message} + + + + {error.is_resolved ? "Resolved" : "Unresolved"} + + + + ))} + +
+ +
+ {totalPages > 1 && ( +
+ onPageChange(page - 1)} + /> +
+ )} + +
+
+ {hasUnresolvedErrors && !isResolvingErrors && ( + + )} +
+
+
+
+
+ ); +} diff --git a/web/src/app/admin/connector/[ccPairId]/IndexingAttemptsTable.tsx b/web/src/app/admin/connector/[ccPairId]/IndexingAttemptsTable.tsx index b7c03bb7e6a..d4c96c92bdf 100644 --- a/web/src/app/admin/connector/[ccPairId]/IndexingAttemptsTable.tsx +++ b/web/src/app/admin/connector/[ccPairId]/IndexingAttemptsTable.tsx @@ -34,38 +34,26 @@ import usePaginatedFetch from "@/hooks/usePaginatedFetch"; const ITEMS_PER_PAGE = 8; const PAGES_PER_BATCH = 8; -export function IndexingAttemptsTable({ ccPair }: { ccPair: CCPairFullInfo }) { +export interface IndexingAttemptsTableProps { + ccPair: CCPairFullInfo; + indexAttempts: IndexAttemptSnapshot[]; + currentPage: number; + totalPages: number; + onPageChange: (page: number) => void; +} + +export function IndexingAttemptsTable({ + ccPair, + indexAttempts, + currentPage, + totalPages, + onPageChange, +}: IndexingAttemptsTableProps) { const [indexAttemptTracePopupId, setIndexAttemptTracePopupId] = useState< number | null >(null); - const { - currentPageData: pageOfIndexAttempts, - isLoading, - error, - currentPage, - totalPages, - goToPage, - } = usePaginatedFetch({ - itemsPerPage: ITEMS_PER_PAGE, - pagesPerBatch: PAGES_PER_BATCH, - endpoint: `${buildCCPairInfoUrl(ccPair.id)}/index-attempts`, - }); - - if (isLoading || !pageOfIndexAttempts) { - return ; - } - - if (error) { - return ( - - ); - } - - if (!pageOfIndexAttempts?.length) { + if (!indexAttempts?.length) { return ( indexAttempt.id === indexAttemptTracePopupId ); @@ -119,7 +107,7 @@ export function IndexingAttemptsTable({ ccPair }: { ccPair: CCPairFullInfo }) { - {pageOfIndexAttempts.map((indexAttempt) => { + {indexAttempts.map((indexAttempt) => { const docsPerMinute = getDocsProcessedPerMinute(indexAttempt)?.toFixed(2); return ( @@ -161,18 +149,6 @@ export function IndexingAttemptsTable({ ccPair }: { ccPair: CCPairFullInfo }) { {indexAttempt.total_docs_indexed}
- {indexAttempt.error_count > 0 && ( - - - -  View Errors - - - )} - {indexAttempt.status === "success" && ( {"-"} @@ -209,7 +185,7 @@ export function IndexingAttemptsTable({ ccPair }: { ccPair: CCPairFullInfo }) {
diff --git a/web/src/app/admin/connector/[ccPairId]/ReIndexButton.tsx b/web/src/app/admin/connector/[ccPairId]/ReIndexButton.tsx index 962339e9fe8..bd647adcd8d 100644 --- a/web/src/app/admin/connector/[ccPairId]/ReIndexButton.tsx +++ b/web/src/app/admin/connector/[ccPairId]/ReIndexButton.tsx @@ -1,11 +1,9 @@ "use client"; import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup"; -import { runConnector } from "@/lib/connector"; import { Button } from "@/components/ui/button"; import Text from "@/components/ui/text"; -import { mutate } from "swr"; -import { buildCCPairInfoUrl } from "./lib"; +import { triggerIndexing } from "./lib"; import { useState } from "react"; import { Modal } from "@/components/Modal"; import { Separator } from "@/components/ui/separator"; @@ -23,26 +21,6 @@ function ReIndexPopup({ setPopup: (popupSpec: PopupSpec | null) => void; hide: () => void; }) { - async function triggerIndexing(fromBeginning: boolean) { - const errorMsg = await runConnector( - connectorId, - [credentialId], - fromBeginning - ); - if (errorMsg) { - setPopup({ - message: errorMsg, - type: "error", - }); - } else { - setPopup({ - message: "Triggered connector run", - type: "success", - }); - } - mutate(buildCCPairInfoUrl(ccPairId)); - } - return (
@@ -50,7 +28,13 @@ function ReIndexPopup({ variant="submit" className="ml-auto" onClick={() => { - triggerIndexing(false); + triggerIndexing( + false, + connectorId, + credentialId, + ccPairId, + setPopup + ); hide(); }} > @@ -68,7 +52,13 @@ function ReIndexPopup({ variant="submit" className="ml-auto" onClick={() => { - triggerIndexing(true); + triggerIndexing( + true, + connectorId, + credentialId, + ccPairId, + setPopup + ); hide(); }} > diff --git a/web/src/app/admin/connector/[ccPairId]/lib.ts b/web/src/app/admin/connector/[ccPairId]/lib.ts index c2d02b23d75..89372eaf7a2 100644 --- a/web/src/app/admin/connector/[ccPairId]/lib.ts +++ b/web/src/app/admin/connector/[ccPairId]/lib.ts @@ -1,4 +1,7 @@ +import { PopupSpec } from "@/components/admin/connectors/Popup"; +import { runConnector } from "@/lib/connector"; import { ValidSources } from "@/lib/types"; +import { mutate } from "swr"; export function buildCCPairInfoUrl(ccPairId: string | number) { return `/api/manage/admin/cc-pair/${ccPairId}`; @@ -11,3 +14,29 @@ export function buildSimilarCredentialInfoURL( const base = `/api/manage/admin/similar-credentials/${source_type}`; return get_editable ? `${base}?get_editable=True` : base; } + +export async function triggerIndexing( + fromBeginning: boolean, + connectorId: number, + credentialId: number, + ccPairId: number, + setPopup: (popupSpec: PopupSpec | null) => void +) { + const errorMsg = await runConnector( + connectorId, + [credentialId], + fromBeginning + ); + if (errorMsg) { + setPopup({ + message: errorMsg, + type: "error", + }); + } else { + setPopup({ + message: "Triggered connector run", + type: "success", + }); + } + mutate(buildCCPairInfoUrl(ccPairId)); +} diff --git a/web/src/app/admin/connector/[ccPairId]/page.tsx b/web/src/app/admin/connector/[ccPairId]/page.tsx index a539feaaf63..fb84d31d017 100644 --- a/web/src/app/admin/connector/[ccPairId]/page.tsx +++ b/web/src/app/admin/connector/[ccPairId]/page.tsx @@ -25,13 +25,24 @@ import DeletionErrorStatus from "./DeletionErrorStatus"; import { IndexingAttemptsTable } from "./IndexingAttemptsTable"; import { ModifyStatusButtonCluster } from "./ModifyStatusButtonCluster"; import { ReIndexButton } from "./ReIndexButton"; -import { buildCCPairInfoUrl } from "./lib"; -import { CCPairFullInfo, ConnectorCredentialPairStatus } from "./types"; +import { buildCCPairInfoUrl, triggerIndexing } from "./lib"; +import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; +import { + CCPairFullInfo, + ConnectorCredentialPairStatus, + IndexAttemptError, + PaginatedIndexAttemptErrors, +} from "./types"; import { EditableStringFieldDisplay } from "@/components/EditableStringFieldDisplay"; import { Button } from "@/components/ui/button"; import EditPropertyModal from "@/components/modals/EditPropertyModal"; import * as Yup from "yup"; +import { AlertCircle } from "lucide-react"; +import IndexAttemptErrorsModal from "./IndexAttemptErrorsModal"; +import usePaginatedFetch from "@/hooks/usePaginatedFetch"; +import { IndexAttemptSnapshot } from "@/lib/types"; +import { Spinner } from "@/components/Spinner"; // synchronize these validations with the SQLAlchemy connector class until we have a // centralized schema for both frontend and backend @@ -51,43 +62,99 @@ const PruneFrequencySchema = Yup.object().shape({ .required("Property value is required"), }); +const ITEMS_PER_PAGE = 8; +const PAGES_PER_BATCH = 8; + function Main({ ccPairId }: { ccPairId: number }) { - const router = useRouter(); // Initialize the router + const router = 
useRouter(); const { data: ccPair, - isLoading, - error, + isLoading: isLoadingCCPair, + error: ccPairError, } = useSWR( buildCCPairInfoUrl(ccPairId), errorHandlingFetcher, { refreshInterval: 5000 } // 5 seconds ); + const { + currentPageData: indexAttempts, + isLoading: isLoadingIndexAttempts, + currentPage, + totalPages, + goToPage, + } = usePaginatedFetch({ + itemsPerPage: ITEMS_PER_PAGE, + pagesPerBatch: PAGES_PER_BATCH, + endpoint: `${buildCCPairInfoUrl(ccPairId)}/index-attempts`, + }); + + const { + currentPageData: indexAttemptErrorsPage, + currentPage: errorsCurrentPage, + totalPages: errorsTotalPages, + goToPage: goToErrorsPage, + } = usePaginatedFetch({ + itemsPerPage: 10, + pagesPerBatch: 1, + endpoint: `/api/manage/admin/cc-pair/${ccPairId}/errors`, + }); + + const indexAttemptErrors = indexAttemptErrorsPage + ? { + items: indexAttemptErrorsPage, + total_items: + errorsCurrentPage === errorsTotalPages && + indexAttemptErrorsPage.length === 0 + ? 0 + : errorsTotalPages * 10, + } + : null; + const [hasLoadedOnce, setHasLoadedOnce] = useState(false); const [editingRefreshFrequency, setEditingRefreshFrequency] = useState(false); const [editingPruningFrequency, setEditingPruningFrequency] = useState(false); + const [showIndexAttemptErrors, setShowIndexAttemptErrors] = useState(false); + const [showIsResolvingKickoffLoader, setShowIsResolvingKickoffLoader] = + useState(false); const { popup, setPopup } = usePopup(); + const latestIndexAttempt = indexAttempts?.[0]; + const isResolvingErrors = + (latestIndexAttempt?.status === "in_progress" || + latestIndexAttempt?.status === "not_started") && + latestIndexAttempt?.from_beginning && + // if there are errors in the latest index attempt, we don't want to show the loader + !indexAttemptErrors?.items?.some( + (error) => error.index_attempt_id === latestIndexAttempt?.id + ); + const finishConnectorDeletion = useCallback(() => { router.push("/admin/indexing/status?message=connector-deleted"); }, [router]); useEffect(() => { - if (isLoading) { + if (isLoadingCCPair) { return; } - if (ccPair && !error) { + if (ccPair && !ccPairError) { setHasLoadedOnce(true); } if ( - (hasLoadedOnce && (error || !ccPair)) || + (hasLoadedOnce && (ccPairError || !ccPair)) || (ccPair?.status === ConnectorCredentialPairStatus.DELETING && !ccPair.connector) ) { finishConnectorDeletion(); } - }, [isLoading, ccPair, error, hasLoadedOnce, finishConnectorDeletion]); + }, [ + isLoadingCCPair, + ccPair, + ccPairError, + hasLoadedOnce, + finishConnectorDeletion, + ]); const handleUpdateName = async (newName: string) => { try { @@ -191,15 +258,19 @@ function Main({ ccPairId }: { ccPairId: number }) { } }; - if (isLoading) { + if (isLoadingCCPair || isLoadingIndexAttempts) { return ; } - if (!ccPair || (!hasLoadedOnce && error)) { + if (!ccPair || (!hasLoadedOnce && ccPairError)) { return ( ); } @@ -219,6 +290,7 @@ function Main({ ccPairId }: { ccPairId: number }) { return ( <> {popup} + {showIsResolvingKickoffLoader && !isResolvingErrors && } {editingRefreshFrequency && ( )} + {showIndexAttemptErrors && indexAttemptErrors && ( + setShowIndexAttemptErrors(false)} + onResolveAll={async () => { + setShowIndexAttemptErrors(false); + setShowIsResolvingKickoffLoader(true); + await triggerIndexing( + true, + ccPair.connector.id, + ccPair.credential.id, + ccPair.id, + setPopup + ); + + // show the loader for a max of 10 seconds + setTimeout(() => { + setShowIsResolvingKickoffLoader(false); + }, 10000); + }} + isResolvingErrors={isResolvingErrors} + onPageChange={goToErrorsPage} + 
currentPage={errorsCurrentPage} + /> + )} + router.push("/admin/indexing/status")} /> @@ -342,13 +440,46 @@ function Main({ ccPairId }: { ccPairId: number }) { /> )} - {/* NOTE: no divider / title here for `ConfigDisplay` since it is optional and we need - to render these conditionally.*/}
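With `triggerIndexing` extracted into `lib.ts`, the re-index buttons and the errors modal's resolve action share one path: start a connector run, optionally from the beginning, and report the outcome through a popup. A minimal usage sketch against the signature added in this diff (the wrapper component is illustrative):

```tsx
import { usePopup } from "@/components/admin/connectors/Popup";
import { triggerIndexing } from "./lib";
import { CCPairFullInfo } from "./types";

// Sketch: resolve failures by re-running the connector from the beginning,
// mirroring what onResolveAll does in page.tsx.
function ResolveAllButton({ ccPair }: { ccPair: CCPairFullInfo }) {
  const { popup, setPopup } = usePopup();

  return (
    <>
      {popup}
      <button
        onClick={() =>
          triggerIndexing(
            true, // fromBeginning: re-process everything so failed docs are retried
            ccPair.connector.id,
            ccPair.credential.id,
            ccPair.id,
            setPopup
          )
        }
      >
        Resolve all failures
      </button>
    </>
  );
}
```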
Indexing Attempts
+ {indexAttemptErrors && indexAttemptErrors.total_items > 0 && ( + + + + Some documents failed to index + + + {isResolvingErrors ? ( + + + Resolving failures + + + ) : ( + <> + We ran into issues while processing some documents.{" "} + setShowIndexAttemptErrors(true)} + > + View details. + + + )} + + + )} + {indexAttempts && ( + + )}
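The failure banner above renders only when `total_items` is non-zero. Since the errors endpoint is read one page at a time with 10 items per page, the page can only estimate the total; the sketch below restates that estimate as a standalone helper (the function name is illustrative, the arithmetic matches the inline code earlier in this file):

```tsx
import { IndexAttemptError, PaginatedIndexAttemptErrors } from "./types";

// Sketch of the estimate used above: with 10 errors per page, the only case
// that can be pinned down exactly is "no errors at all" (last page is empty);
// otherwise the count is approximated as pages * page size.
function estimateErrorTotals(
  pageItems: IndexAttemptError[] | undefined,
  currentPage: number,
  totalPages: number,
  itemsPerPage = 10
): PaginatedIndexAttemptErrors | null {
  if (!pageItems) return null;
  const isLastPageEmpty = currentPage === totalPages && pageItems.length === 0;
  return {
    items: pageItems,
    total_items: isLastPageEmpty ? 0 : totalPages * itemsPerPage,
  };
}
```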
diff --git a/web/src/app/admin/connector/[ccPairId]/types.ts b/web/src/app/admin/connector/[ccPairId]/types.ts index 877bb3c525d..f1c534573c4 100644 --- a/web/src/app/admin/connector/[ccPairId]/types.ts +++ b/web/src/app/admin/connector/[ccPairId]/types.ts @@ -37,3 +37,27 @@ export interface PaginatedIndexAttempts { page: number; total_pages: number; } + +export interface IndexAttemptError { + id: number; + connector_credential_pair_id: number; + + document_id: string | null; + document_link: string | null; + + entity_id: string | null; + failed_time_range_start: string | null; + failed_time_range_end: string | null; + + failure_message: string; + is_resolved: boolean; + + time_created: string; + + index_attempt_id: number; +} + +export interface PaginatedIndexAttemptErrors { + items: IndexAttemptError[]; + total_items: number; +} diff --git a/web/src/app/admin/indexing/[id]/IndexAttemptErrorsTable.tsx b/web/src/app/admin/indexing/[id]/IndexAttemptErrorsTable.tsx deleted file mode 100644 index bea15555f9c..00000000000 --- a/web/src/app/admin/indexing/[id]/IndexAttemptErrorsTable.tsx +++ /dev/null @@ -1,189 +0,0 @@ -"use client"; - -import { Modal } from "@/components/Modal"; -import { PageSelector } from "@/components/PageSelector"; -import { CheckmarkIcon, CopyIcon } from "@/components/icons/icons"; -import { localizeAndPrettify } from "@/lib/time"; -import { - Table, - TableBody, - TableCell, - TableHead, - TableRow, -} from "@/components/ui/table"; -import Text from "@/components/ui/text"; -import { useState } from "react"; -import { IndexAttemptError } from "./types"; -import { TableHeader } from "@/components/ui/table"; - -const NUM_IN_PAGE = 8; - -export function CustomModal({ - isVisible, - onClose, - title, - content, - showCopyButton = false, -}: { - isVisible: boolean; - onClose: () => void; - title: string; - content: string; - showCopyButton?: boolean; -}) { - const [copyClicked, setCopyClicked] = useState(false); - - if (!isVisible) return null; - - return ( - -
- {showCopyButton && ( -
- {!copyClicked ? ( -
{ - navigator.clipboard.writeText(content); - setCopyClicked(true); - setTimeout(() => setCopyClicked(false), 2000); - }} - className="flex w-fit cursor-pointer hover:bg-accent-background p-2 border-border border rounded" - > - Copy full content - -
- ) : ( -
- Copied to clipboard - -
- )} -
- )} -
{content}
-
-
- ); -} - -export function IndexAttemptErrorsTable({ - indexAttemptErrors, -}: { - indexAttemptErrors: IndexAttemptError[]; -}) { - const [page, setPage] = useState(1); - const [modalData, setModalData] = useState<{ - id: number | null; - title: string; - content: string; - } | null>(null); - const closeModal = () => setModalData(null); - - return ( - <> - {modalData && ( - - )} - - - - - Timestamp - Batch Number - Document Summaries - Error Message - - - - {indexAttemptErrors - .slice(NUM_IN_PAGE * (page - 1), NUM_IN_PAGE * page) - .map((indexAttemptError) => { - return ( - - - {indexAttemptError.time_created - ? localizeAndPrettify(indexAttemptError.time_created) - : "-"} - - {indexAttemptError.batch_number} - - {indexAttemptError.doc_summaries && ( -
- setModalData({ - id: indexAttemptError.id, - title: "Document Summaries", - content: JSON.stringify( - indexAttemptError.doc_summaries, - null, - 2 - ), - }) - } - className="mt-2 text-link cursor-pointer select-none" - > - View Document Summaries -
- )} -
- -
- - {indexAttemptError.error_msg || "-"} - - {indexAttemptError.traceback && ( -
- setModalData({ - id: indexAttemptError.id, - title: "Exception Traceback", - content: indexAttemptError.traceback!, - }) - } - className="mt-2 text-link cursor-pointer select-none" - > - View Full Trace -
- )} -
-
-
- ); - })} -
-
- {indexAttemptErrors.length > NUM_IN_PAGE && ( -
-
- { - setPage(newPage); - window.scrollTo({ - top: 0, - left: 0, - behavior: "smooth", - }); - }} - /> -
-
- )} - - ); -} diff --git a/web/src/app/admin/indexing/[id]/lib.ts b/web/src/app/admin/indexing/[id]/lib.ts deleted file mode 100644 index f81f95d8c2f..00000000000 --- a/web/src/app/admin/indexing/[id]/lib.ts +++ /dev/null @@ -1,3 +0,0 @@ -export function buildIndexingErrorsUrl(id: string | number) { - return `/api/manage/admin/indexing-errors/${id}`; -} diff --git a/web/src/app/admin/indexing/[id]/page.tsx b/web/src/app/admin/indexing/[id]/page.tsx deleted file mode 100644 index 75aa482a00f..00000000000 --- a/web/src/app/admin/indexing/[id]/page.tsx +++ /dev/null @@ -1,59 +0,0 @@ -"use client"; -import { use } from "react"; - -import { BackButton } from "@/components/BackButton"; -import { ErrorCallout } from "@/components/ErrorCallout"; -import { ThreeDotsLoader } from "@/components/Loading"; -import { errorHandlingFetcher } from "@/lib/fetcher"; -import Title from "@/components/ui/title"; -import useSWR from "swr"; -import { IndexAttemptErrorsTable } from "./IndexAttemptErrorsTable"; -import { buildIndexingErrorsUrl } from "./lib"; -import { IndexAttemptError } from "./types"; - -function Main({ id }: { id: number }) { - const { - data: indexAttemptErrors, - isLoading, - error, - } = useSWR( - buildIndexingErrorsUrl(id), - errorHandlingFetcher - ); - - if (isLoading) { - return ; - } - - if (error || !indexAttemptErrors) { - return ( - - ); - } - - return ( - <> - -
-
- Indexing Errors for Attempt {id} -
- -
- - ); -} - -export default function Page(props: { params: Promise<{ id: string }> }) { - const params = use(props.params); - const id = parseInt(params.id); - - return ( -
-
-
- ); -} diff --git a/web/src/app/admin/indexing/[id]/types.ts b/web/src/app/admin/indexing/[id]/types.ts deleted file mode 100644 index 66480805f58..00000000000 --- a/web/src/app/admin/indexing/[id]/types.ts +++ /dev/null @@ -1,15 +0,0 @@ -export interface IndexAttemptError { - id: number; - index_attempt_id: number; - batch_number: number; - doc_summaries: DocumentErrorSummary[]; - error_msg: string; - traceback: string; - time_created: string; -} - -export interface DocumentErrorSummary { - id: string; - semantic_id: string; - section_link: string; -} diff --git a/web/src/components/Status.tsx b/web/src/components/Status.tsx index cdeeb5ff4a0..ed57bf2e7d4 100644 --- a/web/src/components/Status.tsx +++ b/web/src/components/Status.tsx @@ -41,25 +41,11 @@ export function IndexAttemptStatus({ badge = icon; } } else if (status === "completed_with_errors") { - const icon = ( + badge = ( Completed with errors ); - badge = ( - {icon}
} - popupContent={ -
- The indexing attempt completed, but some errors were encountered - during the run. -
-
- Click View Errors for more details. -
- } - /> - ); } else if (status === "success") { badge = ( diff --git a/web/src/hooks/usePaginatedFetch.tsx b/web/src/hooks/usePaginatedFetch.tsx index ba68ee154dd..820f3220430 100644 --- a/web/src/hooks/usePaginatedFetch.tsx +++ b/web/src/hooks/usePaginatedFetch.tsx @@ -7,12 +7,13 @@ import { } from "@/lib/types"; import { ChatSessionMinimal } from "@/app/ee/admin/performance/usage/types"; import { errorHandlingFetcher } from "@/lib/fetcher"; +import { PaginatedIndexAttemptErrors } from "@/app/admin/connector/[ccPairId]/types"; -type PaginatedType = - | IndexAttemptSnapshot - | AcceptedUserSnapshot - | InvitedUserSnapshot - | ChatSessionMinimal; +// Any type that has an id property +type PaginatedType = { + id: number | string; + [key: string]: any; +}; interface PaginatedApiResponse { items: T[]; diff --git a/web/src/lib/connectors/connectors.tsx b/web/src/lib/connectors/connectors.tsx index 028b17d6613..e38819a36d1 100644 --- a/web/src/lib/connectors/connectors.tsx +++ b/web/src/lib/connectors/connectors.tsx @@ -1232,6 +1232,7 @@ export interface ConnectorBase { indexing_start: Date | null; access_type: string; groups?: number[]; + from_beginning?: boolean; } export interface Connector extends ConnectorBase { @@ -1253,6 +1254,7 @@ export interface ConnectorSnapshot { indexing_start: number | null; time_created: string; time_updated: string; + from_beginning?: boolean; } export interface WebConfig { diff --git a/web/src/lib/sources.ts b/web/src/lib/sources.ts index 26c7b18b267..846918986ee 100644 --- a/web/src/lib/sources.ts +++ b/web/src/lib/sources.ts @@ -335,6 +335,13 @@ export const SOURCE_METADATA_MAP: SourceMap = { displayName: "Not Applicable", category: SourceCategory.Other, }, + + // Just so integration tests don't crash the UI + mock_connector: { + icon: GlobeIcon, + displayName: "Mock Connector", + category: SourceCategory.Other, + }, } as SourceMap; function fillSourceMetadata( diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts index a250df10f76..022ea3b7039 100644 --- a/web/src/lib/types.ts +++ b/web/src/lib/types.ts @@ -123,6 +123,7 @@ export interface FailedConnectorIndexingStatus { export interface IndexAttemptSnapshot { id: number; status: ValidStatuses | null; + from_beginning: boolean; new_docs_indexed: number; docs_removed_from_index: number; total_docs_indexed: number;