diff --git a/src/AppConstants.ts b/src/AppConstants.ts index 51a9614..2e726d0 100644 --- a/src/AppConstants.ts +++ b/src/AppConstants.ts @@ -88,6 +88,7 @@ export class AppConstants { static CELL_PRECISION = 'cellPrecision'; static CELL_RECALL = 'cellRecall'; static CELL_F1_SCORE = 'cellF1Score'; + static CELL_OVERALL_ACCURACY_SCORE = 'overallPrecisionScore'; /** * Initial size of a heatmap cells for the softmax stampx diff --git a/src/ConfusionMatrix.ts b/src/ConfusionMatrix.ts index bfb019b..58e6bb3 100644 --- a/src/ConfusionMatrix.ts +++ b/src/ConfusionMatrix.ts @@ -1,4 +1,4 @@ -import {IAppView} from './app'; +import {App, IAppView} from './app'; import * as d3 from 'd3'; import * as events from 'phovea_core/src/event'; import {AppConstants} from './AppConstants'; @@ -14,7 +14,7 @@ import {ACell, LabelCell, MatrixCell, PanelCell} from './confusion_matrix_cell/C import {zip} from './utils'; import * as confMeasures from './ConfusionMeasures'; import {Language} from './language'; -import {SquareMatrix, max, Matrix} from './DataStructures'; +import {SquareMatrix, max, Matrix, matrixSum} from './DataStructures'; import { DataStoreApplicationProperties, DataStoreSelectedRun, dataStoreTimelines, RenderMode } from './DataStore'; @@ -41,6 +41,7 @@ export class ConfusionMatrix implements IAppView { private f1ScoreColumn: ChartColumn; private classSizeColumn: ChartColumn; private $cells = null; + private $overallAccuracyCell: d3.Selection; constructor(parent: Element) { this.$node = d3.select(parent) @@ -129,6 +130,9 @@ export class ConfusionMatrix implements IAppView { const numBottomColumns = 1; // number of additional columns this.$node.style('--num-bottom-columns', numBottomColumns); + + this.$overallAccuracyCell = this.$node.append('div').classed('overall', true) + .append('div').classed('cell', true); } private attachListeners() { @@ -347,6 +351,7 @@ export class ConfusionMatrix implements IAppView { let dataPrecision = null; let dataRecall = null; let dataF1 = null; + let dataOverallAccuracy = null; let fpfnRendererProto: IMatrixRendererChain = null; let confMatrixRendererProto: IMatrixRendererChain = null; @@ -362,6 +367,7 @@ export class ConfusionMatrix implements IAppView { dataPrecision = datasets.map((x) => confMeasures.calcEvolution(x.multiEpochData.map((y) => y.confusionData), confMeasures.PPV)); dataRecall = datasets.map((x) => confMeasures.calcEvolution(x.multiEpochData.map((y) => y.confusionData), confMeasures.TPR)); dataF1 = datasets.map((x) => confMeasures.calcEvolution(x.multiEpochData.map((y) => y.confusionData), confMeasures.F1)); + dataOverallAccuracy = datasets.map((x) => confMeasures.calcOverallAccuracy(x.multiEpochData.map((y) => y.confusionData))); singleEpochIndex = data[1].heatcell.indexInMultiSelection; confMatrixRendererProto = { @@ -408,6 +414,7 @@ export class ConfusionMatrix implements IAppView { dataPrecision = datasets.map((x) => confMeasures.calcEvolution(x.multiEpochData.map((y) => y.confusionData), confMeasures.PPV)); dataRecall = datasets.map((x) => confMeasures.calcEvolution(x.multiEpochData.map((y) => y.confusionData), confMeasures.TPR)); dataF1 = datasets.map((x) => confMeasures.calcEvolution(x.multiEpochData.map((y) => y.confusionData), confMeasures.F1)); + dataOverallAccuracy = datasets.map((x) => confMeasures.calcOverallAccuracy(x.multiEpochData.map((y) => y.confusionData))); singleEpochIndex = null; confMatrixRendererProto = { @@ -460,6 +467,7 @@ export class ConfusionMatrix implements IAppView { this.renderPrecisionColumn(dataPrecision, precRendererProto, datasets[0].labels, singleEpochIndex, datasets.map((x) => x.datasetColor)); this.renderRecallColumn(dataRecall, precRendererProto, datasets[0].labels, singleEpochIndex, datasets.map((x) => x.datasetColor)); this.renderF1ScoreColumn(dataF1, precRendererProto, datasets[0].labels, singleEpochIndex, datasets.map((x) => x.datasetColor)); + this.renderOverallAccuracyCell(dataOverallAccuracy, precRendererProto, datasets[0].labels, singleEpochIndex, datasets.map((x) => x.datasetColor)); } private renderConfMatrixCells() { @@ -489,6 +497,18 @@ export class ConfusionMatrix implements IAppView { }); } + renderOverallAccuracyCell(data: number[][], renderer: IMatrixRendererChain, labels: string[], singleEpochIndex: number[], colors: string[]) { + const maxVal = Math.max(...[].concat(...data)); + const res = { + linecell: data.map((x, i) => [{values: x, valuesInPercent: x, max: maxVal, classLabel: null, color: colors[i]}]), + heatcell: {indexInMultiSelection: singleEpochIndex, counts: null, maxVal: 0, classLabels: null, colorValues: null} + }; + const cell = new PanelCell(res, AppConstants.CELL_OVERALL_ACCURACY_SCORE); + cell.init(this.$overallAccuracyCell); + applyRendererChain(renderer, cell, renderer.diagonal); + cell.render(); + } + renderRecallColumn(data: Matrix[], renderer: IMatrixRendererChain, labels: string[], singleEpochIndex: number[], colors: string[]) { const maxVal = Math.max(...data.map((x: Matrix) => max(x, (d) => Math.max(...d)))); let transformedData = data.map((x) => x.to1DArray()); diff --git a/src/ConfusionMeasures.ts b/src/ConfusionMeasures.ts index baae797..1352b91 100644 --- a/src/ConfusionMeasures.ts +++ b/src/ConfusionMeasures.ts @@ -95,11 +95,25 @@ export function calcEvolution(matrices: NumberMatrix[], funct: (matrix: NumberMa const res = calcForMultipleClasses(m, funct); matrix.values.map((c, i) => c[0].push(res[i])); } - - //const summedPercent = calcSummedPercent(matrices); - //matrix[order - 1][0] = summedPercent; return matrix; } +export function calcOverallAccuracy(matrices: NumberMatrix[]): number[] { + return matrices.map((m) => calcSummedPercent1(m)); +} + +export function calcSummedPercent1(matrix: NumberMatrix) { + let tpSum = 0; + let classSizeSum = 0; + for(let i = 0; i < matrix.order(); i++) { + tpSum += TP(matrix, i); + classSizeSum += ClassSize(matrix, i); + } + if(classSizeSum === 0) { + return 0; + } + return tpSum / classSizeSum; +} + diff --git a/src/confusion_matrix_cell/ACellRenderer.ts b/src/confusion_matrix_cell/ACellRenderer.ts index dea2ca2..483e07e 100644 --- a/src/confusion_matrix_cell/ACellRenderer.ts +++ b/src/confusion_matrix_cell/ACellRenderer.ts @@ -398,7 +398,8 @@ export class AxisRenderer extends ACellRenderer { //todo these are magic constants: use a more sophisticated algo to solve this let tickFrequency = 1; - if (selectedRangesLength[largest] > 20) { + const stride = Number(values[1]) - Number(values[0]); + if (stride <= 5) { tickFrequency = 4; } diff --git a/src/detail_view/DetailChart.ts b/src/detail_view/DetailChart.ts index 5b293ec..af1784b 100644 --- a/src/detail_view/DetailChart.ts +++ b/src/detail_view/DetailChart.ts @@ -87,6 +87,8 @@ export class DetailChart { text = Language.F1_SCORE_Y_LABEL; text = text + ' ' + Language.FOR_CLASS + ' '; text += cell.data.linecell[0][0].classLabel; + } else if (cell.type === AppConstants.CELL_OVERALL_ACCURACY_SCORE) { + text = Language.OVERALL_ACCURACY; } } this.$header.text(text); diff --git a/src/language.ts b/src/language.ts index 5c3e054..9b0e6cc 100644 --- a/src/language.ts +++ b/src/language.ts @@ -23,7 +23,7 @@ export class Language { static PRECISION_Y_LABEL = 'Precision [%]'; static RECALL_Y_LABEL = 'Recall [%]'; static F1_SCORE_Y_LABEL = 'F1 Score [%]'; - static OVERALL_PRECISION = 'Overall Precision'; + static OVERALL_ACCURACY = 'Overall Accuracy'; static FP_RATE = 'False Positive Rate'; static FN_RATE = 'False Negative Rate'; static FOR_CLASS = 'for class'; diff --git a/src/styles/_confusionmatrix.scss b/src/styles/_confusionmatrix.scss index 784db2b..88ac40f 100644 --- a/src/styles/_confusionmatrix.scss +++ b/src/styles/_confusionmatrix.scss @@ -13,7 +13,7 @@ .grid { display: grid; - grid-template-areas: ". . axis-top . ." ". . label-top . label-right" "axis-left label-left matrix . chart-right" ". . . . ." ". label-bottom chart-bottom . ."; + grid-template-areas: ". . axis-top . ." ". . label-top . label-right" "axis-left label-left matrix . chart-right" ". . . . ." ". label-bottom chart-bottom . overall"; grid-template-columns: 20px 45px auto 10px minmax(100px, 300px); grid-template-rows: 20px 45px auto 10px minmax(60px, 1fr); } @@ -45,6 +45,13 @@ } } +.overall { + grid-area: overall; + position: relative; + display: grid; + grid-template-columns: repeat(var(--num-right-columns), 1fr); +} + .label-right { grid-area: label-right; background: $label-bg; @@ -159,6 +166,7 @@ } } +.overall, .chart-right, .chart-bottom, .matrix {