From 153a518e1a6052d2e7349f5130238be65b2a701a Mon Sep 17 00:00:00 2001 From: "min.tian" Date: Mon, 5 Aug 2024 16:23:19 +0800 Subject: [PATCH] .. Signed-off-by: min.tian --- .../components/check_results/filters.py | 69 ++++++++++++++----- vectordb_bench/frontend/pages/filter.py | 4 +- vectordb_bench/frontend/vdb_benchmark.py | 16 +---- 3 files changed, 54 insertions(+), 35 deletions(-) diff --git a/vectordb_bench/frontend/components/check_results/filters.py b/vectordb_bench/frontend/components/check_results/filters.py index 34da854bb..2bf2ed3e8 100644 --- a/vectordb_bench/frontend/components/check_results/filters.py +++ b/vectordb_bench/frontend/components/check_results/filters.py @@ -1,15 +1,18 @@ from vectordb_bench.backend.cases import Case +from vectordb_bench.backend.dataset import DatasetWithSizeType +from vectordb_bench.backend.filters import FilterType from vectordb_bench.frontend.components.check_results.data import getChartData from vectordb_bench.frontend.components.check_results.expanderStyle import ( initSidebarExanderStyle, ) +from vectordb_bench.frontend.config.dbCaseConfigs import CASE_NAME_ORDER from vectordb_bench.frontend.config.styles import SIDEBAR_CONTROL_COLUMNS import streamlit as st from vectordb_bench.models import CaseResult, TestResult -def getshownData(st, results: list[TestResult], display_case_name_order: list[str]): +def getshownData(st, results: list[TestResult], filter_type: FilterType): # hide the nav st.markdown( "", @@ -19,9 +22,7 @@ def getshownData(st, results: list[TestResult], display_case_name_order: list[st st.header("Filters") shownResults = getshownResults(st, results) - showDBNames, showCaseNames = getShowDbsAndCases( - st, shownResults, display_case_name_order - ) + showDBNames, showCaseNames = getShowDbsAndCases(st, shownResults, filter_type) shownData, failedTasks = getChartData(shownResults, showDBNames, showCaseNames) @@ -56,7 +57,7 @@ def getshownResults(st, results: list[TestResult]) -> list[CaseResult]: def getShowDbsAndCases( - st, result: list[CaseResult], display_case_name_order + st, result: list[CaseResult], filter_type: FilterType ) -> tuple[list[str], list[str]]: initSidebarExanderStyle(st) allDbNames = list(set({res.task_config.db_name for res in result})) @@ -67,12 +68,8 @@ def getShowDbsAndCases( ) for res in result ] - allCaseNameSet = set({case.name for case in allCases}) - allCaseNames = [ - case_name - for case_name in display_case_name_order - if case_name in allCaseNameSet - ] + allCases = [case for case in allCases if case.filters.type == filter_type] + # DB Filter dbFilterContainer = st.container() showDBNames = filterView( @@ -81,15 +78,49 @@ def getShowDbsAndCases( allDbNames, col=1, ) + showCaseNames = [] + + if filter_type == FilterType.NonFilter: + allCaseNameSet = set({case.name for case in allCases}) + allCaseNames = [ + case_name for case_name in CASE_NAME_ORDER if case_name in allCaseNameSet + ] + [ + case_name + for case_name in allCaseNameSet + if case_name not in CASE_NAME_ORDER + ] + + # Case Filter + caseFilterContainer = st.container() + showCaseNames = filterView( + caseFilterContainer, + "Case Filter", + [caseName for caseName in allCaseNames], + col=1, + ) - # Case Filter - caseFilterContainer = st.container() - showCaseNames = filterView( - caseFilterContainer, - "Case Filter", - [caseName for caseName in allCaseNames], - col=1, - ) + if filter_type == FilterType.Label: + container = st.container() + datasetWithSizeTypes = [ + dataset_with_size_type for dataset_with_size_type in DatasetWithSizeType + ] + showDatasetWithSizeTypes = filterView( + container, + "Case Filter", + datasetWithSizeTypes, + col=1, + optionLables=[v.value for v in datasetWithSizeTypes], + ) + datasets = [ + dataset_with_size_type.get_manager() + for dataset_with_size_type in showDatasetWithSizeTypes + ] + showCaseNames = list( + set([case.name for case in allCases if case.dataset in datasets]) + ) + + if filter_type == FilterType.Int: + raise NotImplementedError return showDBNames, showCaseNames diff --git a/vectordb_bench/frontend/pages/filter.py b/vectordb_bench/frontend/pages/filter.py index 98fb94882..d34dcfb7f 100644 --- a/vectordb_bench/frontend/pages/filter.py +++ b/vectordb_bench/frontend/pages/filter.py @@ -1,4 +1,5 @@ import streamlit as st +from vectordb_bench.backend.filters import FilterType from vectordb_bench.frontend.components.check_results.footer import footer from vectordb_bench.frontend.components.check_results.stPageConfig import ( initResultsPageConfig, @@ -31,10 +32,9 @@ def main(): ) # results selector and filter - display_case_name_order = [] resultSelectorContainer = st.sidebar.container() shownData, failedTasks, showCaseNames = getshownData( - resultSelectorContainer, allResults, display_case_name_order + resultSelectorContainer, allResults, filter_type=FilterType.Label ) resultSelectorContainer.divider() diff --git a/vectordb_bench/frontend/vdb_benchmark.py b/vectordb_bench/frontend/vdb_benchmark.py index 69edf892c..eaeb25ac7 100644 --- a/vectordb_bench/frontend/vdb_benchmark.py +++ b/vectordb_bench/frontend/vdb_benchmark.py @@ -1,5 +1,6 @@ import streamlit as st from vectordb_bench.backend.cases import CaseType +from vectordb_bench.backend.filters import FilterType from vectordb_bench.frontend.components.check_results.footer import footer from vectordb_bench.frontend.components.check_results.stPageConfig import ( initResultsPageConfig, @@ -31,21 +32,8 @@ def main(): # results selector and filter resultSelectorContainer = st.sidebar.container() - display_case_name_order = [ - case_type.case_name() - for case_type in [ - CaseType.Performance768D100M, - CaseType.Performance768D10M, - CaseType.Performance768D1M, - CaseType.Performance1536D5M, - CaseType.Performance1536D500K, - CaseType.Performance1536D50K, - CaseType.CapacityDim960, - CaseType.CapacityDim128, - ] - ] shownData, failedTasks, showCaseNames = getshownData( - resultSelectorContainer, allResults, display_case_name_order + resultSelectorContainer, allResults, filter_type=FilterType.NonFilter ) resultSelectorContainer.divider()