Skip to content

Commit

Permalink
Chore: upload pypi as package; Verify command line instructions; Upda…
Browse files Browse the repository at this point in the history
…te: init w/ config
  • Loading branch information
StefanHeng committed May 19, 2023
1 parent 0f8174e commit 3ba8626
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 7 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,13 @@ Below we include command line arguments and example train/eval commands for mode

**Eval**

- Evaluate a model on out-of-domain dataset `multi_eurlex`
- Evaluate a local model on out-of-domain dataset `multi_eurlex`

- ```bash
python zeroshot_classifier/models/bert.py test --domain out --dataset multi_eurlex --model_path models/2022-06-15_21-23-57_BERT-Seq-CLS-out-multi_eurlex/trained
python zeroshot_classifier/models/bert.py test --domain out --dataset multi_eurlex --model_name_or_path models/2022-06-15_21-23-57_BERT-Seq-CLS-out-multi_eurlex/trained
```





Expand Down
43 changes: 43 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from setuptools import setup, find_packages

VERSION = '0.1.1'
DESCRIPTION = """
code and data for the Findings of ACL'23 paper Label Agnostic Pre-training for Zero-shot Text Classification
by Christopher Clarke, Yuzhao Heng, Yiping Kang, Krisztian Flautner, Lingjia Tang and Jason Mars
"""

setup(
name='zeroshot-classifier',
version=VERSION,
license='MIT',
author='Christopher Clarke & Yuzhao Heng',
author_email='csclarke@umich.edu',
description=DESCRIPTION,
long_description=DESCRIPTION,
url='https://github.com/ChrisIsKing/zero-shot-text-classification',
download_url='https://github.com/ChrisIsKing/zero-shot-text-classification/archive/refs/tags/v0.1.0.tar.gz',
packages=find_packages(),
include_package_data=True,
install_requires=[
'gdown==4.5.4', 'openai==0.25.0', 'requests==2.28.1', 'tenacity==8.1.0',
'spacy==3.2.2', 'nltk==3.7', 'scikit-learn==1.1.3',
'torch==1.12.0', 'sentence-transformers==2.2.0', 'transformers==4.16.2',
'datasets==1.18.3',
'stefutils==0.22.2'
],
keywords=['python', 'nlp', 'machine-learning', 'deep-learning', 'text-classification', 'zero-shot-classification'],
classifiers=[
'Development Status :: 2 - Pre-Alpha',
'Environment :: GPU :: NVIDIA CUDA',
'Environment :: GPU :: NVIDIA CUDA :: 11.6',
'Intended Audience :: Developers',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: MIT License',
'Natural Language :: English',
'Operating System :: MacOS',
'Programming Language :: Python :: 3',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Scientific/Engineering :: Visualization'
]
)
4 changes: 2 additions & 2 deletions zeroshot_classifier/models/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def parse_args():
dset_args = dict(domain=domain)
if normalize_aspect:
dset_args['normalize_aspect'] = seed
data = get_datasets(in_domain_data_path if domain == 'in' else out_of_domain_data_path, **dset_args)
data = get_datasets(**dset_args)
if dataset_name == 'all':
train_dset, test_dset, labels = seq_cls_format(data, all=True)
else:
Expand Down Expand Up @@ -94,7 +94,7 @@ def tokenize_function(examples):
sampling=None, normalize_aspect=normalize_aspect
)
output_path = os_join(utcd_util.get_base_path(), u.proj_dir, u.model_dir, dir_nm)
proj_output_path = os_join(u.base_path, u.proj_dir, u.model_dir_nm, dir_nm, 'trained')
proj_output_path = os_join(u.base_path, u.proj_dir, u.model_dir, dir_nm, 'trained')
d_log = {'batch size': bsz, 'epochs': n_ep, 'warmup steps': warmup_steps, 'save path': output_path}
logger.info(f'Launched training with {pl.i(d_log)}... ')

Expand Down
1 change: 0 additions & 1 deletion zeroshot_classifier/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .data_path import BASE_PATH, PROJ_DIR, PKG_NM, MODEL_DIR, DSET_DIR
from .util import *
from . import training
from .gpt2_train import MyTrainer as GPT2Trainer
Expand Down
15 changes: 13 additions & 2 deletions zeroshot_classifier/util/util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
import os
import math
import json
import configparser
from os.path import join as os_join
from typing import List, Tuple, Dict, Iterable, Optional
Expand All @@ -22,7 +23,17 @@
]


sconfig = StefConfig(config_file=os_join(BASE_PATH, PROJ_DIR, PKG_NM, 'util', 'config.json')).__call__
logger = get_logger('Util')


config_path = os_join(BASE_PATH, PROJ_DIR, PKG_NM, 'util', 'config.json')
if not os.path.exists(config_path):
from zeroshot_classifier.util.config import config_dict
logger.info(f'Writing config file at {pl.i(config_path)}')
with open(config_path, 'w') as f:
json.dump(config_dict, f, indent=4)

sconfig = StefConfig(config_file=config_path).__call__
u = StefUtil(
base_path=BASE_PATH, project_dir=PROJ_DIR, package_name=PKG_NM, dataset_dir=DSET_DIR, model_dir=MODEL_DIR
)
Expand Down

0 comments on commit 3ba8626

Please sign in to comment.