Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added overall accuracy cell #189

Merged
merged 6 commits into from
Jul 3, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/AppConstants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ export class AppConstants {
static CELL_PRECISION = 'cellPrecision';
static CELL_RECALL = 'cellRecall';
static CELL_F1_SCORE = 'cellF1Score';
static CELL_OVERALL_ACCURACY_SCORE = 'overallPrecisionScore';

/**
* Initial size of a heatmap cells for the softmax stampx
Expand Down
24 changes: 22 additions & 2 deletions src/ConfusionMatrix.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import {IAppView} from './app';
import {App, IAppView} from './app';
import * as d3 from 'd3';
import * as events from 'phovea_core/src/event';
import {AppConstants} from './AppConstants';
Expand All @@ -14,7 +14,7 @@ import {ACell, LabelCell, MatrixCell, PanelCell} from './confusion_matrix_cell/C
import {zip} from './utils';
import * as confMeasures from './ConfusionMeasures';
import {Language} from './language';
import {SquareMatrix, max, Matrix} from './DataStructures';
import {SquareMatrix, max, Matrix, matrixSum} from './DataStructures';
import {
DataStoreApplicationProperties, DataStoreSelectedRun, dataStoreTimelines, RenderMode
} from './DataStore';
Expand All @@ -41,6 +41,7 @@ export class ConfusionMatrix implements IAppView {
private f1ScoreColumn: ChartColumn;
private classSizeColumn: ChartColumn;
private $cells = null;
private $overallAccuracyCell: d3.Selection<any>;

constructor(parent: Element) {
this.$node = d3.select(parent)
Expand Down Expand Up @@ -129,6 +130,9 @@ export class ConfusionMatrix implements IAppView {

const numBottomColumns = 1; // number of additional columns
this.$node.style('--num-bottom-columns', numBottomColumns);

this.$overallAccuracyCell = this.$node.append('div').classed('overall', true)
.append('div').classed('cell', true);
}

private attachListeners() {
Expand Down Expand Up @@ -347,6 +351,7 @@ export class ConfusionMatrix implements IAppView {
let dataPrecision = null;
let dataRecall = null;
let dataF1 = null;
let dataOverallAccuracy = null;

let fpfnRendererProto: IMatrixRendererChain = null;
let confMatrixRendererProto: IMatrixRendererChain = null;
Expand All @@ -362,6 +367,7 @@ export class ConfusionMatrix implements IAppView {
dataPrecision = datasets.map((x) => confMeasures.calcEvolution(x.multiEpochData.map((y) => y.confusionData), confMeasures.PPV));
dataRecall = datasets.map((x) => confMeasures.calcEvolution(x.multiEpochData.map((y) => y.confusionData), confMeasures.TPR));
dataF1 = datasets.map((x) => confMeasures.calcEvolution(x.multiEpochData.map((y) => y.confusionData), confMeasures.F1));
dataOverallAccuracy = datasets.map((x) => confMeasures.calcOverallAccuracy(x.multiEpochData.map((y) => y.confusionData)));
singleEpochIndex = data[1].heatcell.indexInMultiSelection;

confMatrixRendererProto = {
Expand Down Expand Up @@ -408,6 +414,7 @@ export class ConfusionMatrix implements IAppView {
dataPrecision = datasets.map((x) => confMeasures.calcEvolution(x.multiEpochData.map((y) => y.confusionData), confMeasures.PPV));
dataRecall = datasets.map((x) => confMeasures.calcEvolution(x.multiEpochData.map((y) => y.confusionData), confMeasures.TPR));
dataF1 = datasets.map((x) => confMeasures.calcEvolution(x.multiEpochData.map((y) => y.confusionData), confMeasures.F1));
dataOverallAccuracy = datasets.map((x) => confMeasures.calcOverallAccuracy(x.multiEpochData.map((y) => y.confusionData)));
singleEpochIndex = null;

confMatrixRendererProto = {
Expand Down Expand Up @@ -460,6 +467,7 @@ export class ConfusionMatrix implements IAppView {
this.renderPrecisionColumn(dataPrecision, precRendererProto, datasets[0].labels, singleEpochIndex, datasets.map((x) => x.datasetColor));
this.renderRecallColumn(dataRecall, precRendererProto, datasets[0].labels, singleEpochIndex, datasets.map((x) => x.datasetColor));
this.renderF1ScoreColumn(dataF1, precRendererProto, datasets[0].labels, singleEpochIndex, datasets.map((x) => x.datasetColor));
this.renderOverallAccuracyCell(dataOverallAccuracy, precRendererProto, datasets[0].labels, singleEpochIndex, datasets.map((x) => x.datasetColor));
}

private renderConfMatrixCells() {
Expand Down Expand Up @@ -489,6 +497,18 @@ export class ConfusionMatrix implements IAppView {
});
}

renderOverallAccuracyCell(data: number[][], renderer: IMatrixRendererChain, labels: string[], singleEpochIndex: number[], colors: string[]) {
const maxVal = Math.max(...[].concat(...data));
const res = {
linecell: data.map((x, i) => [{values: x, valuesInPercent: x, max: maxVal, classLabel: null, color: colors[i]}]),
heatcell: {indexInMultiSelection: singleEpochIndex, counts: null, maxVal: 0, classLabels: null, colorValues: null}
};
const cell = new PanelCell(res, AppConstants.CELL_OVERALL_ACCURACY_SCORE);
cell.init(this.$overallAccuracyCell);
applyRendererChain(renderer, cell, renderer.diagonal);
cell.render();
}

renderRecallColumn(data: Matrix<number[]>[], renderer: IMatrixRendererChain, labels: string[], singleEpochIndex: number[], colors: string[]) {
const maxVal = Math.max(...data.map((x: Matrix<number[]>) => max(x, (d) => Math.max(...d))));
let transformedData = data.map((x) => x.to1DArray());
Expand Down
20 changes: 17 additions & 3 deletions src/ConfusionMeasures.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,25 @@ export function calcEvolution(matrices: NumberMatrix[], funct: (matrix: NumberMa
const res = calcForMultipleClasses(m, funct);
matrix.values.map((c, i) => c[0].push(res[i]));
}

//const summedPercent = calcSummedPercent(matrices);
//matrix[order - 1][0] = summedPercent;
return matrix;
}

export function calcOverallAccuracy(matrices: NumberMatrix[]): number[] {
return matrices.map((m) => calcSummedPercent1(m));
}

export function calcSummedPercent1(matrix: NumberMatrix) {
let tpSum = 0;
let classSizeSum = 0;
for(let i = 0; i < matrix.order(); i++) {
tpSum += TP(matrix, i);
classSizeSum += ClassSize(matrix, i);
}
if(classSizeSum === 0) {
return 0;
}
return tpSum / classSizeSum;
}



3 changes: 2 additions & 1 deletion src/confusion_matrix_cell/ACellRenderer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,8 @@ export class AxisRenderer extends ACellRenderer {

//todo these are magic constants: use a more sophisticated algo to solve this
let tickFrequency = 1;
if (selectedRangesLength[largest] > 20) {
const stride = Number(values[1]) - Number(values[0]);
if (stride <= 5) {
tickFrequency = 4;
}

Expand Down
2 changes: 2 additions & 0 deletions src/detail_view/DetailChart.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ export class DetailChart {
text = Language.F1_SCORE_Y_LABEL;
text = text + ' ' + Language.FOR_CLASS + ' ';
text += cell.data.linecell[0][0].classLabel;
} else if (cell.type === AppConstants.CELL_OVERALL_ACCURACY_SCORE) {
text = Language.OVERALL_ACCURACY;
}
}
this.$header.text(text);
Expand Down
2 changes: 1 addition & 1 deletion src/language.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ export class Language {
static PRECISION_Y_LABEL = 'Precision [%]';
static RECALL_Y_LABEL = 'Recall [%]';
static F1_SCORE_Y_LABEL = 'F1 Score [%]';
static OVERALL_PRECISION = 'Overall Precision';
static OVERALL_ACCURACY = 'Overall Accuracy';
static FP_RATE = 'False Positive Rate';
static FN_RATE = 'False Negative Rate';
static FOR_CLASS = 'for class';
Expand Down
10 changes: 9 additions & 1 deletion src/styles/_confusionmatrix.scss
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

.grid {
display: grid;
grid-template-areas: ". . axis-top . ." ". . label-top . label-right" "axis-left label-left matrix . chart-right" ". . . . ." ". label-bottom chart-bottom . .";
grid-template-areas: ". . axis-top . ." ". . label-top . label-right" "axis-left label-left matrix . chart-right" ". . . . ." ". label-bottom chart-bottom . overall";
grid-template-columns: 20px 45px auto 10px minmax(100px, 300px);
grid-template-rows: 20px 45px auto 10px minmax(60px, 1fr);
}
Expand Down Expand Up @@ -45,6 +45,13 @@
}
}

.overall {
grid-area: overall;
position: relative;
display: grid;
grid-template-columns: repeat(var(--num-right-columns), 1fr);
}

.label-right {
grid-area: label-right;
background: $label-bg;
Expand Down Expand Up @@ -159,6 +166,7 @@
}
}

.overall,
.chart-right,
.chart-bottom,
.matrix {
Expand Down