Skip to content

Commit

Permalink
Fixes balancing of the histogram; also avoids array-bounds issues
Browse files — browse the repository at this point in the history
  • Loading branch information
awildturtok committed Jan 3, 2024
1 parent 2bb084e commit 6f3456c
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 32 deletions.
Original file line number Diff line number Diff line change
@@ -1,25 +1,33 @@
package com.bakdata.conquery.models.query.statistics;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.List;

import com.google.common.math.Stats;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.doubles.DoubleList;
import lombok.Data;
import lombok.ToString;

@Data
public class BalancingStaticHistogram {
public class BalancingHistogram {
private final Node[] nodes;
private final double min;
private final double width;

public static BalancingStaticHistogram create(double min, double max, int expectedBins) {
return new BalancingStaticHistogram(new Node[expectedBins], min, (max - min) / (expectedBins - 1));
private final int expectedBins;

private int total;

public static BalancingHistogram create(double min, double max, int expectedBins) {
return new BalancingHistogram(new Node[expectedBins], min, (max - min) / (expectedBins - 1), expectedBins);
}

public void add(double value) {
total++;

final int index = (int) Math.floor((value - min) / width);

if (nodes[index] == null) {
Expand All @@ -29,24 +37,25 @@ public void add(double value) {
nodes[index].add(value);
}

public List<Node> balanced(int expectedBins, int total) {
public List<Node> balanced() {

final List<Node> merged = mergeLeft(total, nodes);
final List<Node> merged = mergeLeft(nodes);

final List<Node> split = splitRight(expectedBins, merged);
final List<Node> split = splitRight(merged);

return split;

}

private static List<Node> mergeLeft(int total, Node[] nodes) {
private List<Node> mergeLeft(Node[] nodes) {
final List<Node> bins = new ArrayList<>();

Node prior = null;

for (Node bin : nodes) {
for (int i = nodes.length - 1; i >= 0; i--) {
final Node bin = nodes[i];
// Not all bins are initialised.
if(bin == null){
if (bin == null) {
continue;
}

Expand All @@ -56,14 +65,14 @@ private static List<Node> mergeLeft(int total, Node[] nodes) {
}

// If the bin is too small, we merge-left
if ((double) prior.getCount() / total <= (1d / total)) {
if (prior.getCount() < (total / expectedBins) * 0.5d) {
prior = prior.merge(bin);
continue;
}

// Only emit bin, if we cannot merge left.
// emit prior, if we cannot merge left.
bins.add(prior);
prior = null;
prior = bin;
}

if (prior != null) {
Expand All @@ -76,34 +85,29 @@ private static List<Node> mergeLeft(int total, Node[] nodes) {
return bins;
}

private static List<Node> splitRight(int expectedBins, List<Node> nodes) {

if ((double) nodes.size() / (double) expectedBins >= 0.7d) {
return nodes;
}
private List<Node> splitRight(List<Node> nodes) {

final List<Node> bins = new ArrayList<>();

final Stats stats = nodes.stream().mapToDouble(node -> (double) node.getCount()).boxed().collect(Stats.toStats());

final double stdDev = stats.sampleStandardDeviation();
final double mean = stats.mean();

final Deque<Node> frontier = new ArrayDeque<>(nodes);

for (Node node : nodes) {
if (node.getCount() < mean + stdDev) {
while(!frontier.isEmpty()) {
final Node node = frontier.pop();
if (node.getCount() <= (total / expectedBins * 1.5d)) {
bins.add(node);
continue;
}

bins.addAll(node.split());
frontier.addFirst(node.split().get(1));
frontier.addFirst(node.split().get(0));
}

return bins;
}

@Data
public static final class Node {
@ToString.Exclude
private final DoubleList entries;
private double min = Double.MAX_VALUE;
private double max = Double.MIN_VALUE;
Expand All @@ -120,6 +124,7 @@ public Node merge(Node other) {
return out;
}

@ToString.Include
public int getCount() {
return entries.size();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,29 +86,29 @@ public ResultColumnStatistics describe() {
return new StringColumnStatsCollector.ColumnDescription(getName(), getLabel(), getDescription(), Collections.emptyList(), Collections.emptyMap());
}

final List<StringColumnStatsCollector.ColumnDescription.Entry> bins = createBins(((int) getStatistics().getN()), 15);
final List<StringColumnStatsCollector.ColumnDescription.Entry> bins = createBins(15);
final Map<String, String> extras = getExtras();

return new StringColumnStatsCollector.ColumnDescription(getName(), getLabel(), getDescription(), bins, extras);
}

@NotNull
private List<StringColumnStatsCollector.ColumnDescription.Entry> createBins(int total, int expectedBins) {
final BalancingStaticHistogram histogram = BalancingStaticHistogram.create(getStatistics().getMin(), getStatistics().getMax(), expectedBins);
private List<StringColumnStatsCollector.ColumnDescription.Entry> createBins(int expectedBins) {
final BalancingHistogram histogram = BalancingHistogram.create(getStatistics().getMin(), getStatistics().getMax(), expectedBins);

Arrays.stream(getStatistics().getValues()).forEach(histogram::add);

final List<BalancingStaticHistogram.Node> balanced = histogram.balanced(expectedBins, total);
final List<BalancingHistogram.Node> balanced = histogram.balanced();


final List<StringColumnStatsCollector.ColumnDescription.Entry> entries = new ArrayList<>();


for (BalancingStaticHistogram.Node bin : balanced) {
for (BalancingHistogram.Node bin : balanced) {
final String lower = printValue(bin.getMin());
final String upper = printValue(bin.getMax());

final String binLabel = String.format("%s - %s", lower, upper);
final String binLabel = lower.equals(upper) ? lower : String.format("%s - %s", lower, upper);


entries.add(new StringColumnStatsCollector.ColumnDescription.Entry(binLabel, bin.getCount()));
Expand Down

0 comments on commit 6f3456c

Please sign in to comment.