diff --git a/setup.py b/setup.py index 79efdd5..95beab4 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from setuptools import setup, find_packages -VERSION = '0.2.1' +VERSION = '0.2.2' DESCRIPTION = """ code and data for the Findings of ACL'23 paper Label Agnostic Pre-training for Zero-shot Text Classification by Christopher Clarke, Yuzhao Heng, Yiping Kang, Krisztian Flautner, Lingjia Tang and Jason Mars @@ -15,7 +15,7 @@ description=DESCRIPTION, long_description=DESCRIPTION, url='https://github.com/ChrisIsKing/zero-shot-text-classification', - download_url='https://github.com/ChrisIsKing/zero-shot-text-classification/archive/refs/tags/v0.2.1.tar.gz', + download_url='https://github.com/ChrisIsKing/zero-shot-text-classification/archive/refs/tags/v0.2.2.tar.gz', packages=find_packages(), include_package_data=True, install_requires=[ diff --git a/zeroshot_classifier/models/gpt3.py b/zeroshot_classifier/models/gpt3.py index 92e681b..00411c4 100644 --- a/zeroshot_classifier/models/gpt3.py +++ b/zeroshot_classifier/models/gpt3.py @@ -36,22 +36,27 @@ class ApiCaller: """ url = 'https://api.openai.com/v1/completions' - with open(os_join(u.proj_path, 'auth', 'open-ai.json')) as f: - auth = json.load(f) - # api_key, org = auth['api-key'], auth['organization'] - api_key, org = auth['api-key-chris'], auth['organization'] - headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {api_key}', - 'OpenAI-Organization': org - } - - def __init__(self, model: str = 'text-ada-001', batched: bool = False, delay: float = None): + def __init__( + self, model: str = 'text-ada-001', batched: bool = False, delay: float = None, + token_path: str = None + ): self.model = model self.batched = batched self.delay = delay + token_path = token_path or os_join(u.proj_path, 'auth', 'open-ai.json') + assert os.path.exists(token_path), f'OpenAI token not found at {pl.i(token_path)}' + with open(token_path) as f: + auth = json.load(f) + api_key, org = auth['api-key'], auth['organization'] + openai.api_key = api_key + # self.headers = { + # 'Content-Type': 'application/json', + # 'Authorization': f'Bearer {api_key}', + # 'OpenAI-Organization': org + # } + @retry(wait=wait_random_exponential(min=1, max=60 * 30)) # Wait for 30min def completion(self, **kwargs): if self.delay: @@ -413,11 +418,13 @@ def parse_args(): if __name__ == '__main__': mic.output_width = 256 - with open(os_join(u.proj_path, 'auth', 'open-ai.json')) as f: - auth = json.load(f) - org = auth['organization'] - api_key = auth['api-key'] - openai.api_key = api_key + def set_openai_token(): + with open(os_join(u.proj_path, 'auth', 'open-ai.json')) as f: + auth = json.load(f) + org = auth['organization'] + api_key = auth['api-key'] + openai.api_key = api_key + # set_openai_token() # evaluate(model='text-ada-001', domain='in', dataset_name='emotion') # evaluate(model='text-curie-001', domain='out', dataset_name='multi_eurlex', concurrent=True) @@ -429,18 +436,20 @@ def parse_args(): # evaluate(model='text-davinci-002', domain='out', dataset_name='consumer_finance') # evaluate(model='text-davinci-002', domain='out', dataset_name='amazon_polarity', concurrent=True) - run_args = dict(model='text-curie-001', subsample=True, store_meta=True, store_frequency=10) - # dnm = 'amazon_polarity' - # dnm = 'yelp' - # dnm_ = 'consumer_finance' - # dnm_ = 'slurp' - dnm_ = 'multi_eurlex' # TODO: doesn't work w/ batched requests??? - # dnm_ = 'sgd' - rsm = [os_join( - u.eval_path, '2022-12-04_17-34-01_Zeroshot-GPT3-Eval_{md=text-curie-001, dm=out, dnm=multi_eurlex}', - '22-12-04_out-of-domain', f'{dnm_}_meta.json' - )] - # evaluate(domain='out', dataset_name=dnm_, **run_args, concurrent=True, delay=12, resume=rsm) + def subsample_large_dset(): + run_args = dict(model='text-curie-001', subsample=True, store_meta=True, store_frequency=10) + # dnm = 'amazon_polarity' + # dnm = 'yelp' + # dnm_ = 'consumer_finance' + # dnm_ = 'slurp' + dnm_ = 'multi_eurlex' # TODO: doesn't work w/ batched requests??? + # dnm_ = 'sgd' + rsm = [os_join( + u.eval_path, '2022-12-04_17-34-01_Zeroshot-GPT3-Eval_{md=text-curie-001, dm=out, dnm=multi_eurlex}', + '22-12-04_out-of-domain', f'{dnm_}_meta.json' + )] + evaluate(domain='out', dataset_name=dnm_, **run_args, concurrent=True, delay=12, resume=rsm) + # subsample_large_dset() def command_prompt(): args = parse_args()