Skip to content

Commit

Permalink
Plot a confusion matrix.
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniel-B-Tufvesson committed Mar 12, 2024
1 parent 817a77e commit 1f1c18a
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 8 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
stanza==1.7.0
pandas==2.2.1
pandas==2.2.1
matplotlib==3.8
10 changes: 7 additions & 3 deletions scripts/classify_with_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,21 @@
print(f'Most frequenct class is "{most_frequent.most_common}"')
print()
print('Baseline results:')
evaluation.evaluate(corpus, most_frequent, labels)
evaluation.evaluate(corpus, most_frequent, labels,
draw_conf_matrix=True)
print()

# Evaluate punctuation classifier.
punctuation_classifier = rb.PunctuationClassifier()
print('Punctuation classifier results:')
evaluation.evaluate(corpus, punctuation_classifier, labels)
evaluation.evaluate(corpus, punctuation_classifier, labels,
draw_conf_matrix=True)
print()

# Evaluate clause classifier.
clause_classifier = rb.ClauseClassifier()
print('Clause classifier results:')
evaluation.evaluate(corpus, clause_classifier, labels)
evaluation.evaluate(corpus, clause_classifier, labels,
print_missclassified=('assertion', 'none'),
draw_conf_matrix=False)
print()
40 changes: 36 additions & 4 deletions speechact/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@


def evaluate(corpus: corp.Corpus, classifier: cb.Classifier, labels: list[str],
print_classifications=False):
print_missclassified: tuple[str, str]|None=None,
draw_conf_matrix=False):
"""
Evaluate the classifier on the CoNNL-U corpus.
"""

# Collect all the correct and predicted labels.
all_correct_labels = []
all_predicted_labels = []
missclassified = []
for batch in corpus.batched_docs(100):

# Get the correct labels for batch.
Expand All @@ -30,6 +32,15 @@ def evaluate(corpus: corp.Corpus, classifier: cb.Classifier, labels: list[str],
predicted_labels = [sentence.speech_act for sentence in batch.sentences]
all_predicted_labels += predicted_labels

# Collect the missclassifed.
if print_missclassified:
for sentence, correct in zip(batch.sentences, correct_labels):
if (correct == print_missclassified[0] and
sentence.speech_act == print_missclassified[1]):

missclassified.append(sentence.text)


# Compute accuracy.
accuracy = metrics.accuracy_score(y_true=all_correct_labels,
y_pred=all_predicted_labels)
Expand All @@ -50,7 +61,28 @@ def evaluate(corpus: corp.Corpus, classifier: cb.Classifier, labels: list[str],
labels=labels)
conf_matrix_dframe = pd.DataFrame(conf_matrix,
index = labels,
columns = labels
)
columns = labels)
print('Confusion matrix:')
print(conf_matrix_dframe)
print(conf_matrix_dframe)

# Plot the confusion matrix.
if draw_conf_matrix:
plot_confusion_matrix(conf_matrix, labels)

# Print missclassified sentences.
if print_missclassified:
print()
print(f'{len(missclassified)} "{print_missclassified[0]}" sentences missclassified as "{print_missclassified[1]}".')
print('Printing missclassified sentences:')
for sentence_text in missclassified:
print(sentence_text)


def plot_confusion_matrix(confusion_matrix, labels: list[str]):
import matplotlib.pyplot as plt

display = metrics.ConfusionMatrixDisplay(confusion_matrix,
display_labels=labels)

display.plot(xticks_rotation='vertical')
plt.show()

0 comments on commit 1f1c18a

Please sign in to comment.