Skip to content

Commit

Permalink
Adding evaluate script.
Browse files Browse the repository at this point in the history
  • Loading branch information
jzonthemtn committed Aug 31, 2022
1 parent 91e7186 commit 7019c95
Show file tree
Hide file tree
Showing 3 changed files with 215 additions and 4 deletions.
5 changes: 1 addition & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,10 @@ For a trained model, see https://huggingface.co/jzonthemtn/distilbert-imdb.
## Requirements

```
python3 -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
python3 -m pip install transformers onnxruntime torch sklearn
```

## Train

`python3 train.py`

## Convert to ONNX

`python3 -m transformers.onnx --model=local-pt-checkpoint/ --feature sequence-classification exported-to-onnx`
43 changes: 43 additions & 0 deletions evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from transformers import AutoTokenizer, AutoModelForSequenceClassification
tokenizer = AutoTokenizer.from_pretrained("distilbert-imdb/")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-imdb/")

from datasets import load_dataset
imdb = load_dataset("imdb")
small_test_dataset = imdb["test"].shuffle(seed=42).select([i for i in list(range(300))])

def preprocess_function(examples):
return tokenizer(examples["text"], truncation=True)

tokenized_test = small_test_dataset.map(preprocess_function, batched=True)

from transformers import DataCollatorWithPadding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

import numpy as np
from datasets import load_metric

def compute_metrics(eval_pred):
load_accuracy = load_metric("accuracy")
load_f1 = load_metric("f1")

logits, labels = eval_pred
predictions = np.argmax(logits, axis=-1)
accuracy = load_accuracy.compute(predictions=predictions, references=labels)["accuracy"]
f1 = load_f1.compute(predictions=predictions, references=labels)["f1"]
return {"accuracy": accuracy, "f1": f1}

from transformers import TrainingArguments, Trainer

trainer = Trainer(
model=model,
eval_dataset=tokenized_test,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics
)

results = trainer.evaluate()
print(results)


171 changes: 171 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
absl-py==0.15.0
agate==1.6.0
agate-dbf==0.2.0
agate-excel==0.2.3
agate-sql==0.5.2
aiohttp==3.8.1
aiosignal==1.2.0
appdirs==1.4.3
apturl==0.5.2
astunparse==1.6.3
async-timeout==4.0.2
atomicwrites==1.1.5
attrs==19.3.0
Babel==2.6.0
beautifulsoup4==4.8.2
blinker==1.4
Brlapi==0.7.0
cachetools==5.0.0
certifi==2019.11.28
chardet==3.0.4
charset-normalizer==2.0.12
chrome-gnome-shell==0.0.0
Click==7.0
colorama==0.4.3
command-not-found==0.3
configobj==5.0.6
cryptography==2.8
csvkit==1.0.2
cupshelpers==1.0
datasets==2.4.0
dbfread==2.0.7
dbus-python==1.2.16
defer==1.0.6
defusedxml==0.7.1
dill==0.3.5.1
distlib==0.3.0
distro==1.4.0
distro-info===0.23ubuntu1
entrypoints==0.3
et-xmlfile==1.0.1
filelock==3.0.12
flatbuffers==1.12
frozenlist==1.3.1
fsspec==2022.8.0
future==0.18.2
gast==0.3.3
google-auth==2.6.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
graphviz==0.8.4
grpcio==1.32.0
h5py==2.10.0
html5lib==1.0.1
httplib2==0.14.0
huggingface-hub==0.9.1
idna==2.8
importlib-metadata==4.11.1
isodate==0.6.0
jdcal==1.0
joblib==1.1.0
Keras-Preprocessing==1.1.2
keyring==18.0.1
language-selector==0.1
launchpadlib==1.10.13
lazr.restfulclient==0.14.2
lazr.uri==1.0.3
leather==0.3.3
louis==3.12.0
lxml==4.5.0
macaroonbakery==1.3.1
Markdown==3.3.6
more-itertools==4.2.0
multidict==6.0.2
multiprocess==0.70.13
mxnet==1.7.0.post2
netifaces==0.10.4
networkx==2.6.3
numpy==1.19.5
oauthlib==3.1.0
olefile==0.46
onnx==1.11.0
openpyxl==3.0.3
opt-einsum==3.3.0
packaging==21.3
pandas==1.4.4
parsedatetime==2.4
pexpect==4.6.0
Pillow==7.0.0
pluggy==0.13.0
prompt-toolkit==2.0.10
protobuf==3.19.4
py==1.8.1
pyarrow==9.0.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycairo==1.16.2
pycups==1.9.73
Pygments==2.3.1
PyGObject==3.36.0
PyICU==2.4.2
PyJWT==1.7.1
pymacaroons==0.13.0
PyMySQL==0.9.3
PyNaCl==1.3.0
pyparsing==2.4.6
pyRFC3339==1.1
pytest==4.6.9
python-apt==2.0.0+ubuntu0.20.4.8
python-dateutil==2.8.2
python-debian===0.1.36ubuntu1
python-http-client==3.3.7
python-slugify==4.0.0
pytimeparse==1.1.5
pytz==2022.2.1
pyudev==0.21.0
pyxdg==0.26
PyYAML==6.0
regex==2022.8.17
reportlab==3.5.34
requests==2.27.1
requests-oauthlib==1.3.1
requests-unixsocket==0.2.0
responses==0.18.0
rsa==4.8
scikit-learn==1.1.2
scipy==1.9.1
screen-resolution-extra==0.0.0
SecretStorage==2.3.1
sendgrid==6.9.7
simplejson==3.16.0
six==1.15.0
sklearn==0.0
soupsieve==1.9.5
SQLAlchemy==1.3.12
sqlparse==0.2.4
ssh-import-id==5.10
starkbank-ecdsa==2.0.3
systemd-python==234
tabulate==0.8.6
tensorboard==2.8.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.4.4
tensorflow-estimator==2.4.0
termcolor==1.1.0
terminaltables==3.1.0
threadpoolctl==3.1.0
tokenizers==0.12.1
torch==1.12.1+cu116
torchaudio==0.12.1+cu116
torchvision==0.13.1+cu116
tqdm==4.64.0
transformers==4.21.2
typing-extensions==3.7.4.3
ubuntu-advantage-tools==27.10
ubuntu-drivers-common==0.0.0
ufw==0.36
unattended-upgrades==0.1
Unidecode==1.1.1
urllib3==1.26.8
virtualenv==20.0.17
wadllib==1.3.3
wcwidth==0.1.8
webencodings==0.5.1
Werkzeug==2.0.3
wrapt==1.12.1
xkit==0.0.0
xlrd==1.1.0
xxhash==3.0.0
yarl==1.8.1
zipp==3.7.0

0 comments on commit 7019c95

Please sign in to comment.