From 1f1c18aeb0b46db7fdcfdbb20a1126922b468854 Mon Sep 17 00:00:00 2001
From: danieltufvesson <135624160+Daniel-B-Tufvesson@users.noreply.github.com>
Date: Tue, 12 Mar 2024 20:16:19 +0100
Subject: [PATCH] Plot a confusion matrix.

---
 requirements.txt               |  3 ++-
 scripts/classify_with_rules.py | 10 ++++++---
 speechact/evaluation.py        | 40 ++++++++++++++++++++++++++++++----
 3 files changed, 45 insertions(+), 8 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index faca9b8..a017973 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,3 @@
 stanza==1.7.0
-pandas==2.2.1
\ No newline at end of file
+pandas==2.2.1
+matplotlib==3.8
\ No newline at end of file
diff --git a/scripts/classify_with_rules.py b/scripts/classify_with_rules.py
index 99cdeec..43870d9 100644
--- a/scripts/classify_with_rules.py
+++ b/scripts/classify_with_rules.py
@@ -20,17 +20,21 @@
     print(f'Most frequenct class is "{most_frequent.most_common}"')
     print()
     print('Baseline results:')
-    evaluation.evaluate(corpus, most_frequent, labels)
+    evaluation.evaluate(corpus, most_frequent, labels,
+                        draw_conf_matrix=True)
     print()
 
     # Evaluate punctuation classifier.
     punctuation_classifier = rb.PunctuationClassifier()
     print('Punctuation classifier results:')
-    evaluation.evaluate(corpus, punctuation_classifier, labels)
+    evaluation.evaluate(corpus, punctuation_classifier, labels,
+                        draw_conf_matrix=True)
     print()
 
     # Evaluate clause classifier.
     clause_classifier = rb.ClauseClassifier()
     print('Clause classifier results:')
-    evaluation.evaluate(corpus, clause_classifier, labels)
+    evaluation.evaluate(corpus, clause_classifier, labels,
+                        print_misclassified=('assertion', 'none'),
+                        draw_conf_matrix=False)
     print()
\ No newline at end of file
diff --git a/speechact/evaluation.py b/speechact/evaluation.py
index 2ff0bce..dfb2e45 100644
--- a/speechact/evaluation.py
+++ b/speechact/evaluation.py
@@ -9,7 +9,8 @@
 
 def evaluate(corpus: corp.Corpus, classifier: cb.Classifier, labels: list[str],
-             print_classifications=False):
+             print_misclassified: tuple[str, str]|None=None,
+             draw_conf_matrix=False):
     """
     Evaluate the classifier on the CoNNL-U corpus.
     """
 
@@ -17,6 +18,7 @@
     # Collect all the correct and predicted labels.
     all_correct_labels = []
     all_predicted_labels = []
+    misclassified = []
 
     for batch in corpus.batched_docs(100):
         # Get the correct labels for batch.
@@ -30,6 +32,15 @@
         predicted_labels = [sentence.speech_act for sentence in batch.sentences]
         all_predicted_labels += predicted_labels
 
+        # Collect the misclassified sentences.
+        if print_misclassified:
+            for sentence, correct in zip(batch.sentences, correct_labels):
+                if (correct == print_misclassified[0] and
+                    sentence.speech_act == print_misclassified[1]):
+
+                    misclassified.append(sentence.text)
+
+
     # Compute accuracy.
     accuracy = metrics.accuracy_score(y_true=all_correct_labels,
                                       y_pred=all_predicted_labels)
@@ -50,7 +61,28 @@
                                       labels=labels)
     conf_matrix_dframe = pd.DataFrame(conf_matrix,
                                       index = labels,
-                                      columns = labels
-                                      )
+                                      columns = labels)
     print('Confusion matrix:')
-    print(conf_matrix_dframe)
\ No newline at end of file
+    print(conf_matrix_dframe)
+
+    # Plot the confusion matrix.
+    if draw_conf_matrix:
+        plot_confusion_matrix(conf_matrix, labels)
+
+    # Print misclassified sentences.
+    if print_misclassified:
+        print()
+        print(f'{len(misclassified)} "{print_misclassified[0]}" sentences misclassified as "{print_misclassified[1]}".')
+        print('Printing misclassified sentences:')
+        for sentence_text in misclassified:
+            print(sentence_text)
+
+
+def plot_confusion_matrix(confusion_matrix, labels: list[str]):
+    import matplotlib.pyplot as plt
+
+    display = metrics.ConfusionMatrixDisplay(confusion_matrix,
+                                             display_labels=labels)
+
+    display.plot(xticks_rotation='vertical')
+    plt.show()
\ No newline at end of file
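
For reference, the plotting helper added to speechact/evaluation.py reduces to
scikit-learn's ConfusionMatrixDisplay plus matplotlib. Below is a minimal
standalone sketch of that path; the label set (beyond 'assertion' and 'none',
which appear in the patch) and the count matrix are made up for illustration:

    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn import metrics

    # Hypothetical speech-act labels and toy counts; rows are true labels,
    # columns are predicted labels.
    labels = ['assertion', 'question', 'directive', 'none']
    conf_matrix = np.array([[50,  2,  1,  7],
                            [ 3, 40,  0,  1],
                            [ 0,  1, 30,  2],
                            [ 9,  0,  1, 25]])

    # Same calls as plot_confusion_matrix above: wrap the precomputed matrix
    # in a display object and render it with vertical x-axis tick labels.
    display = metrics.ConfusionMatrixDisplay(conf_matrix, display_labels=labels)
    display.plot(xticks_rotation='vertical')
    plt.show()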