Skip to content

Commit

Permalink
Don't store punctuation logits if no punctuation is applied
Browse files: browse the repository at this point in the history
  • Loading branch information
markus583 committed May 15, 2024
1 parent 1825000 commit a4f91af
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion wtpsplit/evaluation/intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,13 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
dset_group.create_dataset("test_logit_lengths", data=test_logit_lengths)
else:
test_labels = get_labels(lang_code, test_sentences, after_space=False)

if args.skip_punct:
print(test_logits[:, 0].shape)
# remove punct logits
test_logits = test_logits[:, 0]
# back to [N, 1]
test_logits = np.expand_dims(test_logits, axis=1)
print(test_logits.shape)
dset_group.create_dataset("test_logits", data=test_logits)
dset_group.create_dataset("test_labels", data=test_labels)

Expand Down Expand Up @@ -296,6 +302,11 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
else:
train_labels = get_labels(lang_code, train_sentences, after_space=False)

if args.skip_punct:
# remove punct logits
train_logits = train_logits[:, 0]
# back to [N, 1]
train_logits = np.expand_dims(train_logits, axis=1)
dset_group.create_dataset("train_logits", data=train_logits)
dset_group.create_dataset("train_labels", data=train_labels)

Expand Down

0 comments on commit a4f91af

Please sign in to comment.