Skip to content

Commit

Permalink
Chore: publish; Fix: gpt3 auth token
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanHeng committed May 22, 2023
1 parent 3ea4177 commit a86787e
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 30 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from setuptools import setup, find_packages

VERSION = '0.2.1'
VERSION = '0.2.2'
DESCRIPTION = """
code and data for the Findings of ACL'23 paper Label Agnostic Pre-training for Zero-shot Text Classification
by Christopher Clarke, Yuzhao Heng, Yiping Kang, Krisztian Flautner, Lingjia Tang and Jason Mars
Expand All @@ -15,7 +15,7 @@
description=DESCRIPTION,
long_description=DESCRIPTION,
url='https://github.com/ChrisIsKing/zero-shot-text-classification',
download_url='https://github.com/ChrisIsKing/zero-shot-text-classification/archive/refs/tags/v0.2.1.tar.gz',
download_url='https://github.com/ChrisIsKing/zero-shot-text-classification/archive/refs/tags/v0.2.2.tar.gz',
packages=find_packages(),
include_package_data=True,
install_requires=[
Expand Down
65 changes: 37 additions & 28 deletions zeroshot_classifier/models/gpt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,27 @@ class ApiCaller:
"""
url = 'https://api.openai.com/v1/completions'

with open(os_join(u.proj_path, 'auth', 'open-ai.json')) as f:
auth = json.load(f)
# api_key, org = auth['api-key'], auth['organization']
api_key, org = auth['api-key-chris'], auth['organization']
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {api_key}',
'OpenAI-Organization': org
}

def __init__(self, model: str = 'text-ada-001', batched: bool = False, delay: float = None):
def __init__(
        self, model: str = 'text-ada-001', batched: bool = False, delay: float = None,
        token_path: str = None
):
    """
    Store the request settings and load the OpenAI API key from a JSON credentials file.

    :param model: Model name, stored as `self.model`
    :param batched: Stored as `self.batched`
    :param delay: Stored as `self.delay`; `completion` checks it before issuing a request
    :param token_path: Path to a JSON file holding the API key under `api-key`
        If not given, defaults to `auth/open-ai.json` under `u.proj_path`
    :raises FileNotFoundError: If no file exists at the resolved token path
    """
    self.model = model
    self.batched = batched

    self.delay = delay

    token_path = token_path or os_join(u.proj_path, 'auth', 'open-ai.json')
    # Raise explicitly instead of `assert`: assertions are stripped under `python -O`,
    # which would silently drop the friendly error message
    if not os.path.exists(token_path):
        raise FileNotFoundError(f'OpenAI token not found at {pl.i(token_path)}')
    with open(token_path) as f:
        auth = json.load(f)
    # The key is set on the `openai` module itself, so it is shared by every user of
    # that module in this process, not just this instance
    openai.api_key = auth['api-key']

@retry(wait=wait_random_exponential(min=1, max=60 * 30)) # Wait for 30min
def completion(self, **kwargs):
if self.delay:
Expand Down Expand Up @@ -413,11 +418,13 @@ def parse_args():
if __name__ == '__main__':
mic.output_width = 256

with open(os_join(u.proj_path, 'auth', 'open-ai.json')) as f:
auth = json.load(f)
org = auth['organization']
api_key = auth['api-key']
openai.api_key = api_key
def set_openai_token():
    """
    Read `auth/open-ai.json` under the project path and set the module-level `openai.api_key`.
    """
    credential_path = os_join(u.proj_path, 'auth', 'open-ai.json')
    with open(credential_path) as fl:
        credentials = json.load(fl)
    # `organization` is looked up first (so the key must be present) but is not used afterwards
    _org, key = credentials['organization'], credentials['api-key']
    openai.api_key = key
# set_openai_token()

# evaluate(model='text-ada-001', domain='in', dataset_name='emotion')
# evaluate(model='text-curie-001', domain='out', dataset_name='multi_eurlex', concurrent=True)
Expand All @@ -429,18 +436,20 @@ def parse_args():
# evaluate(model='text-davinci-002', domain='out', dataset_name='consumer_finance')
# evaluate(model='text-davinci-002', domain='out', dataset_name='amazon_polarity', concurrent=True)

run_args = dict(model='text-curie-001', subsample=True, store_meta=True, store_frequency=10)
# dnm = 'amazon_polarity'
# dnm = 'yelp'
# dnm_ = 'consumer_finance'
# dnm_ = 'slurp'
dnm_ = 'multi_eurlex' # TODO: doesn't work w/ batched requests???
# dnm_ = 'sgd'
rsm = [os_join(
u.eval_path, '2022-12-04_17-34-01_Zeroshot-GPT3-Eval_{md=text-curie-001, dm=out, dnm=multi_eurlex}',
'22-12-04_out-of-domain', f'{dnm_}_meta.json'
)]
# evaluate(domain='out', dataset_name=dnm_, **run_args, concurrent=True, delay=12, resume=rsm)
def subsample_large_dset():
    """
    Call `evaluate` on the out-of-domain `multi_eurlex` dataset with `text-curie-001`,
    with subsampling enabled and the meta file of an earlier run passed via `resume`.
    """
    # Other dataset names that were swapped in here: amazon_polarity, yelp, consumer_finance, slurp, sgd
    dataset_name = 'multi_eurlex'  # TODO: doesn't work w/ batched requests???
    # Meta file written by the earlier evaluation run
    meta_path = os_join(
        u.eval_path,
        '2022-12-04_17-34-01_Zeroshot-GPT3-Eval_{md=text-curie-001, dm=out, dnm=multi_eurlex}',
        '22-12-04_out-of-domain',
        f'{dataset_name}_meta.json'
    )
    evaluate(
        domain='out', dataset_name=dataset_name,
        model='text-curie-001', subsample=True, store_meta=True, store_frequency=10,
        concurrent=True, delay=12, resume=[meta_path]
    )
# subsample_large_dset()

def command_prompt():
args = parse_args()
Expand Down

0 comments on commit a86787e

Please sign in to comment.