Chore: pull changes from my repo

StefanHeng committed May 17, 2023
1 parent 1dd42a5 commit 0d42a50
Showing 8 changed files with 63 additions and 27 deletions.
4 changes: 2 additions & 2 deletions requirements.txt
@@ -32,7 +32,7 @@ gdown==4.5.4
 google-auth==2.15.0
 google-auth-oauthlib==0.4.6
 grpcio==1.51.1
-huggingface-hub==0.4.0
+huggingface-hub==0.14.1
 icecream==2.1.3
 idna==3.4
 importlib-metadata==5.1.0
@@ -91,7 +91,7 @@ spacy-legacy==3.0.10
 spacy-loggers==1.0.3
 srsly==2.4.5
 starlette==0.22.0
-stefutils==0.20.2
+stefutils==0.22.2
 sty==1.0.4
 tenacity==8.1.0
 tensorboard==2.10.1
4 changes: 2 additions & 2 deletions zeroshot_classifier/models/__init__.py
@@ -1,5 +1,5 @@
-from . import gpt2
-from . import architecture
+from . architecture import *
 from . import dual_bi_encoder
+from .gpt2 import *
 from . import gpt3
 from . import gpt_neo
3 changes: 1 addition & 2 deletions zeroshot_classifier/models/architecture/__init__.py
@@ -1,2 +1 @@
-from .binary_bert import load_sliced_binary_bert
-from .sbert import BinaryBertCrossEncoder, BiEncoder
+from .sbert import *
3 changes: 3 additions & 0 deletions zeroshot_classifier/models/architecture/sbert.py
@@ -18,6 +18,9 @@
 from stefutil import *


+__all__ = ['BinaryBertCrossEncoder', 'BiEncoder']
+
+
 class BinaryBertCrossEncoder(CrossEncoder):
     logger = get_logger('Bin BERT Train')
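
Note on this change: the `__all__` added above is what keeps the new star imports equivalent to the old explicit ones. With it, `from .sbert import *` in the package `__init__.py` re-exports exactly `BinaryBertCrossEncoder` and `BiEncoder`, nothing else. A minimal sketch of the mechanism (module and class names are illustrative, not from this repo):

    # mymod.py
    __all__ = ['Public']     # `from mymod import *` binds only the names listed here

    class Public: ...
    class AlsoPublic: ...    # importable directly, but excluded from `import *`

    # client.py
    from mymod import *      # binds Public; AlsoPublic stays hidden

Without `__all__`, a star import would pull in every non-underscore top-level name, including names the module itself imported, such as `CrossEncoder` here.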
1 change: 0 additions & 1 deletion zeroshot_classifier/models/bi-encoder.py
@@ -3,7 +3,6 @@
 import random
 from os.path import join as os_join
 from typing import List, Dict
-from argparse import ArgumentParser

 import numpy as np
 from torch.utils.data import DataLoader
57 changes: 40 additions & 17 deletions zeroshot_classifier/models/gpt2.py
@@ -45,6 +45,9 @@
 HF_MODEL_NAME = 'gpt2-medium'


+__all__ = ['ZsGPT2Tokenizer', 'ZsGPT2LMHeadModel']
+
+
 logger = get_logger(MODEL_NAME)

@@ -97,7 +100,7 @@ def __getitem__(self, key: Tuple[str, str]):
         )
         return super().__getitem__(key)

-    def __init__(self, form: str = 'vanilla', **kwargs):
+    def __init__(self, form: str = 'vanilla', verbose: bool = False, **kwargs):
         """
         :param form: One of [`vanilla`, `implicit`, `explicit`]
             See `binary_bert::modes`
Expand All @@ -114,7 +117,7 @@ def __init__(self, form: str = 'vanilla', **kwargs):
# TODO: when re-loaded, PAD token doesn't seem to be added...
else:
spec_toks.append(utcd_util.EOT_TOKEN) # SGD end of turn
ca.check_mismatch('GPT2 Training Strategy', form, ['vanilla', 'implicit', 'explicit'])
ca(gpt2_training_strategy=form)
self.form = form
self.did2aspect, aspect_sep_token = None, None
if form == 'implicit':
@@ -139,9 +142,11 @@ def __init__(self, form: str = 'vanilla', **kwargs):

         self.warned_desc = set()  # Warning for each dataset happens once @property

+        self.verbose = verbose
         self.logger = get_logger(self.__class__.__qualname__)
-        d_log = dict(form=form, added_vocab=list(self.get_added_vocab().keys()), vocab_size=self.vocab_size)
-        self.logger.info(f'{pl.i(self.__class__.__qualname__)} initialized with {pl.i(d_log)}')
+        if verbose:
+            d_log = dict(form=form, added_vocab=list(self.get_added_vocab().keys()), vocab_size=self.vocab_size)
+            self.logger.info(f'{pl.i(self.__class__.__qualname__)} initialized with {pl.i(d_log)}')

     @property
     def max_len_single_sentence(self) -> int:
@@ -186,7 +191,8 @@ def __call__(
         idxs_tpl = np.random.randint(len(self.templates), size=ln)

         def call_single(
-                i, dataset_id: int = None, text: str = None, labels: List[int] = None, label_options: List[str] = None
+                i, dataset_id: int = None, text: str = None, labels: List[int] = None, label_options: List[str] = None,
+                aspect: str = None
         ):
             dset_nm: str = None if mode == 'inference-sample' else sconfig('UTCD.dataset_id2name')[dataset_id]
             if mode == 'inference-sample':
@@ -201,7 +207,7 @@ def lb_int2desc(lb: int) -> str:
             n_cls = len(descs)
             # `label` is shared across all datasets, map to local label within dataset
             if self.cache_utcd is None:
-                path = os_join(utcd_util.get_base_path(), PROJ_DIR, DSET_DIR, 'processed', dataset_name)
+                path = os_join(utcd_util.get_base_path(), u.proj_dir, u.dset_dir, 'processed', dataset_name)
                 # cos `Sequential`; each split, the label is the same
                 self.cache_utcd = datasets.load_from_disk(path)[split].features['labels'].feature
             # The ordering indicates int<=>str label mapping, i.e., index is int label,
@@ -238,14 +244,20 @@ def lb_int2desc(lb: int) -> str:
             ids_ques = self._call_paren(question, **kwargs)
             ids_text = self._call_paren(text, **kwargs)
             if self.form == 'implicit':
-                ids_asp = self._call_paren(self.did2aspect[dataset_id], **kwargs)
+                if dataset_id is None:
+                    assert aspect is not None
+                else:
+                    assert aspect is None
+                    aspect = self.did2aspect[dataset_id]
+                ids_asp = self._call_paren(aspect, **kwargs)
                 ids_text = ids_asp + [self.enc_spec(self.aspect_sep_token)] + ids_text
             id_sep = self.enc_spec(self.ques_sep_token)
             ids_answ = [self._call_paren(a, **kwargs) for a in answers]
             ids_answ = sum(join_it(ids_answ, [id_sep]), start=[])
             ln_q, ln_t, ln_a = len(ids_ques), len(ids_text), len(ids_answ)

             if mode == 'inference':
+                assert dset_nm is not None  # sanity check not `inference-sample`
                 # If text sample is so long that we need to truncate, leave room for one label only
                 ln_cont = (1+ln_q+1) + (1+ln_t+1) + 1  # for `pref_answ`
                 max_label_id_length = self.cache[dset_nm, split]['max_label_id_length']
@@ -282,7 +294,7 @@ def lb_int2desc(lb: int) -> str:
             tids = [self.enc_spec(self.question_type_token)] * n_ques + \
                 [self.enc_spec(self.text_type_token)] * n_text + \
                 [self.enc_spec(self.answer_type_token)] * n_answ
-            if mode == 'inference':
+            if mode in ['inference', 'inference-sample']:
                 ids, tids = ids[:-(n_answ-1)], tids[:-(n_answ-1)]
                 assert len(ids) == (n_ques+n_text+1)  # sanity check
             msks = [1] * len(ids)  # Encode ids are attended for CLM
@@ -307,12 +319,13 @@ def pad(ints: List[int], name) -> List[int]:
             out = {k: (pad(ints, k) if mode == 'train' else ints) for k, ints in ((
                 ('input_ids', ids), ('attention_mask', msks), ('token_type_ids', tids), ('position_ids', pids)
             ))}
-            out['dataset_id'] = dataset_id  # For computing zero-shot classification accuracy
+            if dataset_id is not None:
+                out['dataset_id'] = dataset_id  # For computing zero-shot classification accuracy
             if mode == 'stats':  # the number of tokens for just the text part
                 out['ids_text'] = ids_text
             return out
         # See `zeroshot_classifier.util.util.py::process_utcd_dataset`
-        keys_ = ['dataset_id', 'text', 'labels', 'label_options']
+        keys_ = ['dataset_id', 'text', 'labels', 'label_options', 'aspect']
         if mode == 'inference-sample':
             assert not is_batched, f'Batched {pl.i("inference-sample")} not supported'
         else:
@@ -324,7 +337,7 @@ def pad(ints: List[int], name) -> List[int]:
             ))]
             return BatchEncoding({k: [d[k] for d in ds] for k in ds[0]})  # Stack all the ids
         else:
-            return BatchEncoding(call_single(0, *[samples[k] for k in keys_]))
+            return BatchEncoding(call_single(0, *[samples.get(k, None) for k in keys_]))


 class ZsGPT2Model(GPT2Model):
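
Note on the hunk above: switching `samples[k]` to `samples.get(k, None)` lets an unbatched sample omit keys entirely (in `inference-sample` mode there is no `dataset_id`, and `aspect` is only meaningful for the implicit form), so `call_single` receives `None` for whatever is absent. Roughly (sample values made up for illustration):

    samples = {'text': 'book me a flight', 'label_options': ['flight', 'hotel']}
    keys_ = ['dataset_id', 'text', 'labels', 'label_options', 'aspect']
    args = [samples.get(k, None) for k in keys_]
    # -> [None, 'book me a flight', None, ['flight', 'hotel'], None]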
@@ -381,7 +394,7 @@ def forward(self, dataset_id=None, **kwargs):
         return super().forward(**kwargs)

     @classmethod
-    def from_pretrained(cls, *args, is_zs_gpt2: bool = False, **kwargs):
+    def from_pretrained(cls, *args, is_zs_gpt2: bool = True, **kwargs):
         """
         :param is_zs_gpt2: If True, loads a local `ZsGPT2LMHeadModel`; otherwise, expects a GPT2 model
         """
@@ -428,17 +441,27 @@ def prepare_inputs_for_generation(input_ids, past=None, **kwargs):
         if past:
             position_ids = position_ids[:, -1].unsqueeze(-1)

-        return {
+        # ========================== Begin of modified ==========================
+        # return {
+        #     "input_ids": input_ids,
+        #     "past_key_values": past,
+        #     "use_cache": kwargs.get("use_cache"),
+        #     "position_ids": position_ids,
+        #     "attention_mask": attention_mask,
+        #     "token_type_ids": token_type_ids,
+        # }
+        ret = {
             "input_ids": input_ids,
             "past_key_values": past,
             "use_cache": kwargs.get("use_cache"),
             "position_ids": position_ids,
             "attention_mask": attention_mask,
             "token_type_ids": token_type_ids,
-            # ========================== Begin of added ==========================
-            'dataset_id': kwargs['dataset_id']  # Should definitely exist
-            # ========================== End of added ==========================
         }
+        if 'dataset_id' in kwargs:  # only case it doesn't exist: `inference-sample` mode
+            ret['dataset_id'] = kwargs['dataset_id']
+        return ret
+        # ========================== End of modified ==========================

     def _update_model_kwargs_for_generation(
             self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
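
Note on the `prepare_inputs_for_generation` hunk: `generate()` forwards its model kwargs through this function, so the old unconditional `kwargs['dataset_id']` would raise a `KeyError` whenever no dataset id is supplied, as in `inference-sample` mode; the rewrite forwards the key only when present. The guard pattern in isolation (a sketch; names as in the hunk above):

    ret = {'input_ids': input_ids}                # always-present generation inputs
    if 'dataset_id' in kwargs:                    # absent only in `inference-sample` mode
        ret['dataset_id'] = kwargs['dataset_id']
    return ret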
@@ -770,7 +793,7 @@ def map_func(examples):
 def load_trained(
         form: str = 'vanilla', epoch: int = 3, normalize_aspect: bool = False, model_name_or_path: str = None
 ) -> Tuple[ZsGPT2LMHeadModel, ZsGPT2Tokenizer, str]:
-    ca.check_mismatch('GPT2 Training Strategy', form, ['vanilla', 'implicit', 'explicit'])
+    ca(gpt2_training_strategy=form)

     d_log = dict(form=form, epoch=epoch, normalize_aspect=normalize_aspect)
     logger.info(f'Loading model with {pl.i(d_log)}... ')
9 changes: 9 additions & 0 deletions zeroshot_classifier/util/config.json
@@ -1879,6 +1879,15 @@
                 "implicit-on-text-encode-sep",
                 "explicit"
             ]
+        },
+        {
+            "display_name": "GPT2 Training Strategy",
+            "attr_name": "gpt2_training_strategy",
+            "accepted_values": [
+                "vanilla",
+                "implicit",
+                "explicit"
+            ]
         }
     ]
 }
9 changes: 6 additions & 3 deletions zeroshot_classifier/util/config.py
@@ -108,7 +108,7 @@
         path='UTCD/in-domain/emotion', aspect='sentiment', eval_labels_same=True, domain='in',
         name='Emotion', name_compact='Emotion'
     ),
-    # `eval_labels_same` := has some unique test labels
+    # not `eval_labels_same` := has some unique test labels
     sgd=dict(
         path='UTCD/in-domain/sgd', aspect='intent', eval_labels_same=False, domain='in',
         name='Schema-Guided Dialogue', name_compact='SGD'
@@ -133,8 +133,7 @@
         path='UTCD/in-domain/yahoo', aspect='topic', eval_labels_same=True, domain='in',
         name='Yahoo Answer Topics', name_compact='Yahoo'
     ),
-    # Out-of-domain datasets: test split intended to evaluation
-    # TODO: until new multi-label format supported
+    # Out-of-domain datasets: only test split used & intended for evaluation
     amazon_polarity=dict(
         path='UTCD/out-of-domain/amazon_polarity', aspect='sentiment', eval_labels_same=True, domain='out',
         name='Amazon Review Polarity', name_compact='Amazon Polarity'
@@ -224,6 +223,10 @@
             'implicit-on-text-encode-sep',
             'explicit'
         ]
+    ),
+    dict(
+        display_name='GPT2 Training Strategy', attr_name='gpt2_training_strategy',
+        accepted_values=['vanilla', 'implicit', 'explicit']
     )
 ]
 }
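
Note on this registration: it is what the two `ca.check_mismatch(...)` → `ca(gpt2_training_strategy=form)` replacements in gpt2.py rely on; the keyword matches `attr_name`, and the value is validated against `accepted_values`. The diff does not show stefutils' internals, so here is only a self-contained sketch of a keyword-dispatched argument checker of this shape (class and method names below are invented for illustration, not stefutils' API):

    class CheckArg:
        """Validate keyword arguments against registered accepted values."""
        def __init__(self):
            self.registry = {}  # attr_name -> (display_name, accepted_values)

        def register(self, display_name, attr_name, accepted_values):
            self.registry[attr_name] = (display_name, accepted_values)

        def __call__(self, **kwargs):
            for attr_name, value in kwargs.items():
                display_name, accepted = self.registry[attr_name]
                if value not in accepted:
                    raise ValueError(f'Unexpected {display_name}: got {value!r}, expect one of {accepted}')

    ca = CheckArg()
    ca.register('GPT2 Training Strategy', 'gpt2_training_strategy', ['vanilla', 'implicit', 'explicit'])
    ca(gpt2_training_strategy='implicit')   # passes
    ca(gpt2_training_strategy='foo')        # raises ValueError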
