From 54b88941c4c27e242fcb70161c569245a6877d2a Mon Sep 17 00:00:00 2001 From: Ruiqi Guo Date: Wed, 25 Sep 2024 12:20:02 -0400 Subject: [PATCH] Change ScaNN submission to pin version and use all available cores. (#310) --- neurips23/ood/scann/Dockerfile | 2 +- neurips23/ood/scann/scann.py | 12 +++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/neurips23/ood/scann/Dockerfile b/neurips23/ood/scann/Dockerfile index 85e1d8b1..3e1fb74f 100644 --- a/neurips23/ood/scann/Dockerfile +++ b/neurips23/ood/scann/Dockerfile @@ -2,6 +2,6 @@ FROM neurips23 RUN apt update RUN apt install -y software-properties-common -RUN pip install --no-cache-dir scann +RUN pip install --no-cache-dir scann==1.3.2 WORKDIR /home/app diff --git a/neurips23/ood/scann/scann.py b/neurips23/ood/scann/scann.py index e0a7ad36..63903231 100644 --- a/neurips23/ood/scann/scann.py +++ b/neurips23/ood/scann/scann.py @@ -1,4 +1,3 @@ - import os from pathlib import Path @@ -10,6 +9,7 @@ import numpy as np import scann from scann.proto import scann_pb2 +import multiprocessing class Scann(BaseOODANN): @@ -19,6 +19,8 @@ def __init__(self, metric, index_params): self.download = index_params.get('download', False) self.tree_size = index_params.get('tree_size', 40_000) self.serialized_dir = 'data/scann/ood' + # To account for hyper-threading. + self.num_cpus = multiprocessing.cpu_count() // 2 print('ScaNN: Init') def load_index(self, dataset): @@ -54,7 +56,7 @@ def load_index(self, dataset): os.rename(src, dst) self.searcher = scann.scann_ops_pybind.load_searcher( self.serialized_dir) - self.searcher.set_num_threads(8) + self.searcher.set_num_threads(self.num_cpus) def fit(self, dataset): if hasattr(self, 'searcher'): @@ -143,9 +145,9 @@ def fit(self, dataset): del data gc.collect() if not trained and s >= 8_000_000: - self.searcher.set_num_threads(8) + self.searcher.set_num_threads(self.num_cpus) self.searcher.rebalance(config) - self.searcher.set_num_threads(8) + self.searcher.set_num_threads(self.num_cpus) trained = True del ds # path = Path(self.serialized_dir) @@ -158,7 +160,7 @@ def query(self, X, k): X, leaves_to_search=self.leaves_to_search, pre_reorder_num_neighbors=self.reorder, - batch_size=12500 + batch_size=(X.shape[0] // self.num_cpus) )[0] def set_query_arguments(self, query_args):