Skip to content

Commit

Permalink
Change ScaNN submission to pin version and use all available cores. (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
arron2003 authored Sep 25, 2024
1 parent 3b4002e commit 54b8894
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion neurips23/ood/scann/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ FROM neurips23

RUN apt update
RUN apt install -y software-properties-common
RUN pip install --no-cache-dir scann
RUN pip install --no-cache-dir scann==1.3.2

WORKDIR /home/app
12 changes: 7 additions & 5 deletions neurips23/ood/scann/scann.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import os
from pathlib import Path

Expand All @@ -10,6 +9,7 @@
import numpy as np
import scann
from scann.proto import scann_pb2
import multiprocessing


class Scann(BaseOODANN):
Expand All @@ -19,6 +19,8 @@ def __init__(self, metric, index_params):
self.download = index_params.get('download', False)
self.tree_size = index_params.get('tree_size', 40_000)
self.serialized_dir = 'data/scann/ood'
# To account for hyper-threading.
self.num_cpus = multiprocessing.cpu_count() // 2
print('ScaNN: Init')

def load_index(self, dataset):
Expand Down Expand Up @@ -54,7 +56,7 @@ def load_index(self, dataset):
os.rename(src, dst)
self.searcher = scann.scann_ops_pybind.load_searcher(
self.serialized_dir)
self.searcher.set_num_threads(8)
self.searcher.set_num_threads(self.num_cpus)

def fit(self, dataset):
if hasattr(self, 'searcher'):
Expand Down Expand Up @@ -143,9 +145,9 @@ def fit(self, dataset):
del data
gc.collect()
if not trained and s >= 8_000_000:
self.searcher.set_num_threads(8)
self.searcher.set_num_threads(self.num_cpus)
self.searcher.rebalance(config)
self.searcher.set_num_threads(8)
self.searcher.set_num_threads(self.num_cpus)
trained = True
del ds
# path = Path(self.serialized_dir)
Expand All @@ -158,7 +160,7 @@ def query(self, X, k):
X,
leaves_to_search=self.leaves_to_search,
pre_reorder_num_neighbors=self.reorder,
batch_size=12500
batch_size=(X.shape[0] // self.num_cpus)
)[0]

def set_query_arguments(self, query_args):
Expand Down

0 comments on commit 54b8894

Please sign in to comment.