Skip to content

Commit

Permalink
Parallel distinct hash aggregate
Browse files Browse the repository at this point in the history
  • Loading branch information
benjaminwinger committed Feb 12, 2025
1 parent b3f88d7 commit 18b0903
Show file tree
Hide file tree
Showing 16 changed files with 586 additions and 186 deletions.
37 changes: 31 additions & 6 deletions src/include/processor/operator/aggregate/aggregate_hash_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,26 +79,35 @@ class AggregateHashTable : public BaseHashTable {
common::DataChunkState* leadingState, const std::vector<AggregateInput>& aggregateInputs,
uint64_t resultSetMultiplicity);

bool isAggregateValueDistinctForGroupByKeys(
// Returns true if the value was distinct, in which case it is inserted.
// Otherwise (the value already existed), returns false and leaves the hash table unchanged.
bool insertAggregateValueIfDistinctForGroupByKeys(
const std::vector<common::ValueVector*>& groupByKeyVectors,
common::ValueVector* aggregateVector);

//! merge aggregate hash table by combining aggregate states under the same key
void merge(FactorizedTable&& other);
// Convenience overload: merges another aggregate hash table by forwarding its underlying
// factorized table (as an rvalue) to merge(FactorizedTable&&).
void merge(AggregateHashTable&& other) {
    auto& otherTable = *other.factorizedTable;
    merge(std::move(otherTable));
}
// Must be called after merging hash tables with distinct functions, but only when the merged
// distinct tuples match the merged non-distinct tuples
void mergeDistinctAggregateInfo();

void finalizeAggregateStates();

void resize(uint64_t newSize);
void clear();
void resizeHashTableIfNecessary(uint32_t maxNumDistinctHashKeys);

// Returns a new AggregateHashTable built by invoking the copy constructor on *this.
// NOTE(review): the name suggests the result holds none of this table's entries; confirm
// against the copy constructor's definition.
AggregateHashTable createEmptyCopy() const { return AggregateHashTable(*this); }

DEFAULT_BOTH_MOVE(AggregateHashTable);
// Returns a non-owning pointer to the hash table stored in distinctHashTables at
// aggregateFunctionIdx. The index is not bounds-checked, and the result is null when that
// slot holds no table.
AggregateHashTable* getDistinctHashTable(uint64_t aggregateFunctionIdx) const {
    const auto& distinctTable = distinctHashTables[aggregateFunctionIdx];
    return distinctTable.get();
}

protected:
virtual uint64_t matchFTEntries(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors, uint64_t numMayMatches,
virtual uint64_t matchFTEntries(std::span<const common::ValueVector*> flatKeyVectors,
std::span<const common::ValueVector*> unFlatKeyVectors, uint64_t numMayMatches,
uint64_t numNoMatches);

uint64_t matchFTEntries(const FactorizedTable& srcTable, uint64_t startOffset,
Expand All @@ -111,10 +120,10 @@ class AggregateHashTable : public BaseHashTable {
void initializeFTEntries(const FactorizedTable& sourceTable, uint64_t sourceStartOffset,
uint64_t numFTEntriesToInitialize);

uint64_t matchUnFlatVecWithFTColumn(common::ValueVector* vector, uint64_t numMayMatches,
uint64_t matchUnFlatVecWithFTColumn(const common::ValueVector* vector, uint64_t numMayMatches,
uint64_t& numNoMatches, uint32_t colIdx);

uint64_t matchFlatVecWithFTColumn(common::ValueVector* vector, uint64_t numMayMatches,
uint64_t matchFlatVecWithFTColumn(const common::ValueVector* vector, uint64_t numMayMatches,
uint64_t& numNoMatches, uint32_t colIdx);

void findHashSlots(const std::vector<common::ValueVector*>& flatKeyVectors,
Expand Down Expand Up @@ -166,7 +175,8 @@ class AggregateHashTable : public BaseHashTable {

void updateAggStates(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
const std::vector<AggregateInput>& aggregateInputs, uint64_t resultSetMultiplicity);
const std::vector<AggregateInput>& aggregateInputs, uint64_t resultSetMultiplicity,
bool updateDistinct);

void fillEntryWithInitialNullAggregateState(FactorizedTable& table, uint8_t* entry);

Expand Down Expand Up @@ -230,6 +240,20 @@ class AggregateHashTable : public BaseHashTable {
return distinctAggKeyTypes;
}

// Searches the hash slots for an entry matching `hash`.
//
// Probing starts at the slot that getSlotIdxForHash maps `hash` to and advances with
// increaseSlotIdx. For each visited slot:
//   - a null entry ends the search and nullptr is returned;
//   - a slot whose stored hash equals `hash` and for which compareKeys(entry) returns true
//     ends the search and that entry is returned.
//
// compareKeys is only invoked on slots whose stored hash already equals `hash`.
// NOTE(review): the loop exits only on a match or a slot with a null entry, so it presumably
// relies on the table always containing at least one such slot - confirm.
template<class Func>
uint8_t* findEntry(common::hash_t hash, Func compareKeys) {
    for (auto probeIdx = getSlotIdxForHash(hash);; increaseSlotIdx(probeIdx)) {
        auto* candidate = (HashSlot*)getHashSlot(probeIdx);
        if (candidate->entry == nullptr) {
            return nullptr;
        }
        if (candidate->hash == hash && compareKeys(candidate->entry)) {
            return candidate->entry;
        }
    }
}

private:
// Does not copy the contents of the hash table and is provided as a convenient way of
// constructing more hash tables without having to hold on to or expose the construction
Expand All @@ -251,6 +275,7 @@ class AggregateHashTable : public BaseHashTable {

//! special handling of distinct aggregate
std::vector<std::unique_ptr<AggregateHashTable>> distinctHashTables;
std::vector<uint64_t> distinctHashEntriesProcessed;
uint32_t hashColOffsetInFT{};
uint32_t aggStateColOffsetInFT{};
uint32_t aggStateColIdxInFT{};
Expand Down
3 changes: 0 additions & 3 deletions src/include/processor/operator/aggregate/base_aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,10 @@ class BaseAggregate : public Sink {

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

bool isParallel() const final { return !containDistinctAggregate(); }

void finalizeInternal(ExecutionContext* context) override = 0;

std::unique_ptr<PhysicalOperator> copy() override = 0;

private:
bool containDistinctAggregate() const;

protected:
Expand Down
56 changes: 47 additions & 9 deletions src/include/processor/operator/aggregate/hash_aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "common/types/types.h"
#include "common/vector/value_vector.h"
#include "main/client_context.h"
#include "processor/operator/aggregate/aggregate_input.h"
#include "processor/operator/aggregate/base_aggregate.h"
#include "processor/result/factorized_table.h"
#include "processor/result/factorized_table_schema.h"
Expand All @@ -37,12 +38,23 @@ struct HashAggregateInfo {
class HashAggregateSharedState final : public BaseAggregateSharedState {

public:
explicit HashAggregateSharedState(main::ClientContext* context, HashAggregateInfo aggInfo,
const std::vector<function::AggregateFunction>& aggregateFunctions);
explicit HashAggregateSharedState(main::ClientContext* context, HashAggregateInfo hashAggInfo,
const std::vector<function::AggregateFunction>& aggregateFunctions,
std::span<AggregateInfo> aggregateInfos);

// Appends `tuple` to the queue of the global partition selected by `hash`.
// The partition index is (hash >> shiftForPartitioning) modulo the number of global
// partitions.
void appendTuple(std::span<uint8_t> tuple, common::hash_t hash) {
    const auto partitionIdx = (hash >> shiftForPartitioning) % globalPartitions.size();
    globalPartitions[partitionIdx].queue->appendTuple(tuple);
}

~HashAggregateSharedState();
// Appends `tuple` to the distinct-value queue at index distinctFuncIndex within the global
// partition selected by `hash`. The partition index is computed the same way as in
// appendTuple: (hash >> shiftForPartitioning) modulo the number of global partitions.
void appendDistinctTuple(size_t distinctFuncIndex, std::span<uint8_t> tuple,
    common::hash_t hash) {
    const auto partitionIdx = (hash >> shiftForPartitioning) % globalPartitions.size();
    auto& distinctQueue = globalPartitions[partitionIdx].distinctTableQueues[distinctFuncIndex];
    distinctQueue->appendTuple(tuple);
}

void appendTuple(std::span<uint8_t> tuple, common::hash_t hash);
// Moves `overflowBuffer` into a newly heap-allocated InMemOverflowBuffer and pushes the
// owning pointer onto the `overflow` queue.
void appendOverflow(common::InMemOverflowBuffer&& overflowBuffer) {
    auto ownedBuffer = std::make_unique<common::InMemOverflowBuffer>(std::move(overflowBuffer));
    overflow.push(std::move(ownedBuffer));
}
Expand Down Expand Up @@ -82,13 +94,31 @@ class HashAggregateSharedState final : public BaseAggregateSharedState {
public:
HashAggregateInfo aggInfo;
common::MPSCQueue<std::unique_ptr<common::InMemOverflowBuffer>> overflow;
struct Partition {
std::unique_ptr<AggregateHashTable> hashTable;
std::mutex mtx;
class HashTableQueue {
public:
HashTableQueue(storage::MemoryManager* memoryManager, FactorizedTableSchema tableSchema);

// Creates a new HashTableQueue from the memory manager of the current head block's table and
// a copy of that table's schema.
//
// headBlock is an atomic pointer; it is loaded exactly once so that the memory manager and
// the schema are both read from the same block, even if the head is replaced concurrently.
// NOTE(review): the head block is dereferenced without a null check, although empty() treats
// a null headBlock as possible - callers presumably only copy while a head block exists;
// confirm.
std::unique_ptr<HashTableQueue> copy() const {
    auto* currentHead = headBlock.load();
    return std::make_unique<HashTableQueue>(currentHead->table.getMemoryManager(),
        currentHead->table.getTableSchema()->copy());
}
~HashTableQueue();

void appendTuple(std::span<uint8_t> tuple);

void mergeInto(AggregateHashTable& hashTable);

bool empty() const {
auto headBlock = this->headBlock.load();
return (headBlock == nullptr || headBlock->numTuplesReserved == 0) &&
queuedTuples.approxSize() == 0;
}

private:
struct TupleBlock {
TupleBlock(storage::MemoryManager* memoryManager, FactorizedTableSchema tableSchama)
TupleBlock(storage::MemoryManager* memoryManager, FactorizedTableSchema tableSchema)
: numTuplesReserved{0}, numTuplesWritten{0},
table{memoryManager, std::move(tableSchama)} {
table{memoryManager, std::move(tableSchema)} {
// Start at a fixed capacity of one full block (so that concurrent writes are safe).
// If it is not filled, we resize it to the actual capacity before writing it to the
// hashTable
Expand All @@ -110,6 +140,14 @@ class HashAggregateSharedState final : public BaseAggregateSharedState {
// queuedTuples (at which point, the numTuplesReserved may not be equal to the
// numTuplesWritten)
std::atomic<TupleBlock*> headBlock;
};
struct Partition {
std::unique_ptr<AggregateHashTable> hashTable;
std::mutex mtx;
std::unique_ptr<HashTableQueue> queue;
// The tables storing the distinct values for distinct aggregate functions all get merged in
// the same way as the main table
std::vector<std::unique_ptr<HashTableQueue>> distinctTableQueues;
std::atomic<bool> finalized = false;
};
std::vector<Partition> globalPartitions;
Expand Down
6 changes: 6 additions & 0 deletions src/include/processor/operator/aggregate/simple_aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ class SimpleAggregate final : public BaseAggregate {
printInfo->copy());
}

// TODO(bmwinger): We can use the same type of partitioning to handle distinct simple aggregates
// It could even be the exact same pipeline, but it would perform better if we don't use the
// hash tables for anything but collecting the distinct values
// Maybe try and move the partitioning into BaseAggregate
bool isParallel() const final { return !containDistinctAggregate(); }

private:
void computeDistinctAggregate(AggregateHashTable* distinctHT,
function::AggregateFunction* function, AggregateInput* input,
Expand Down
22 changes: 17 additions & 5 deletions src/include/processor/result/base_hash_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ class MemoryManager;
}
namespace processor {

using compare_function_t = std::function<bool(common::ValueVector*, uint32_t, const uint8_t*)>;
using raw_compare_function_t = std::function<bool(const uint8_t*, const uint8_t*)>;
using compare_function_t =
std::function<bool(const common::ValueVector*, uint32_t, const uint8_t*)>;
using ft_compare_function_t =
std::function<bool(const uint8_t*, const uint8_t*, const common::LogicalType& type)>;

class BaseHashTable {
public:
Expand All @@ -37,14 +39,24 @@ class BaseHashTable {

uint64_t getSlotIdxForHash(common::hash_t hash) const { return hash & bitmask; }
void setMaxNumHashSlots(uint64_t newSize);
void computeAndCombineVecHash(const std::vector<common::ValueVector*>& unFlatKeyVectors,
void computeAndCombineVecHash(std::span<const common::ValueVector*> unFlatKeyVectors,
uint32_t startVecIdx);

void computeVectorHashes(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors);
const std::vector<common::ValueVector*>& unFlatKeyVectors) {
computeVectorHashes(constSpan(flatKeyVectors), constSpan(unFlatKeyVectors));
}
void computeVectorHashes(std::span<const common::ValueVector*> flatKeyVectors,
std::span<const common::ValueVector*> unFlatKeyVectors);
void initSlotConstant(uint64_t numSlotsPerBlock);
bool matchFlatVecWithEntry(const std::vector<common::ValueVector*>& keyVectors,
const uint8_t* entry);

// Views a vector of mutable pointers (`T*`) as a span of pointers-to-const (`const T*`)
// without copying. The returned span refers to the vector's own storage, so it is only valid
// while that storage is.
template<typename T>
std::span<const T*> constSpan(const std::vector<T*>& vector) {
    // const_cast adds const to the pointee type; it is applied to the vector's element array.
    auto* elements = const_cast<const T**>(vector.data());
    return std::span<const T*>{elements, vector.size()};
}

private:
void initCompareFuncs();
void initTmpHashVector();
Expand All @@ -58,7 +70,7 @@ class BaseHashTable {
storage::MemoryManager* memoryManager;
std::unique_ptr<FactorizedTable> factorizedTable;
std::vector<compare_function_t> compareEntryFuncs;
std::vector<raw_compare_function_t> rawCompareEntryFuncs;
std::vector<ft_compare_function_t> ftCompareEntryFuncs;
std::vector<common::LogicalType> keyTypes;
// Temporary arrays to hold intermediate results for appending.
std::shared_ptr<common::DataChunkState> hashState;
Expand Down
8 changes: 4 additions & 4 deletions src/include/processor/result/factorized_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,19 @@ class KUZU_API FactorizedTable {

// This function scans numTuplesToScan of rows to vectors starting at tupleIdx. Callers are
// responsible for making sure all the parameters are valid.
void scan(std::vector<common::ValueVector*>& vectors, ft_tuple_idx_t tupleIdx,
void scan(std::span<common::ValueVector*> vectors, ft_tuple_idx_t tupleIdx,
uint64_t numTuplesToScan) const {
std::vector<uint32_t> colIdxes(tableSchema.getNumColumns());
iota(colIdxes.begin(), colIdxes.end(), 0);
scan(vectors, tupleIdx, numTuplesToScan, colIdxes);
}
bool isEmpty() const { return getNumTuples() == 0; }
void scan(std::vector<common::ValueVector*>& vectors, ft_tuple_idx_t tupleIdx,
uint64_t numTuplesToScan, std::vector<uint32_t>& colIdxToScan) const;
void scan(std::span<common::ValueVector*> vectors, ft_tuple_idx_t tupleIdx,
uint64_t numTuplesToScan, std::span<uint32_t> colIdxToScan) const;
// TODO(Guodong): Unify these two interfaces along with `readUnflatCol`.
// startPos is the starting position in the tuplesToRead, not the starting position in the
// factorizedTable
void lookup(std::vector<common::ValueVector*>& vectors, std::vector<uint32_t>& colIdxesToScan,
void lookup(std::span<common::ValueVector*> vectors, std::span<uint32_t> colIdxesToScan,
uint8_t** tuplesToRead, uint64_t startPos, uint64_t numTuplesToRead) const;
void lookup(std::vector<common::ValueVector*>& vectors,
const common::SelectionVector* selVector, std::vector<uint32_t>& colIdxesToScan,
Expand Down
4 changes: 2 additions & 2 deletions src/include/processor/result/pattern_creation_info_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ class PatternCreationInfoTable : public AggregateHashTable {

PatternCreationInfo getPatternCreationInfo(const std::vector<common::ValueVector*>& keyVectors);

uint64_t matchFTEntries(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors, uint64_t numMayMatches,
uint64_t matchFTEntries(std::span<const common::ValueVector*> flatKeyVectors,
std::span<const common::ValueVector*> unFlatKeyVectors, uint64_t numMayMatches,
uint64_t numNoMatches) override;

private:
Expand Down
2 changes: 1 addition & 1 deletion src/processor/map/map_aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ std::unique_ptr<PhysicalOperator> PlanMapper::createHashAggregate(const expressi
getDataPos(unFlatKeys, *inSchema), getDataPos(payloads, *inSchema), std::move(tableSchema)};

auto sharedState = std::make_shared<HashAggregateSharedState>(clientContext,
std::move(aggregateInfo), aggFunctions);
std::move(aggregateInfo), aggFunctions, aggregateInputInfos);
auto printInfo = std::make_unique<HashAggregatePrintInfo>(allKeys, aggregates);
auto aggregate = make_unique<HashAggregate>(std::make_unique<ResultSetDescriptor>(inSchema),
sharedState, std::move(aggFunctions), std::move(aggregateInputInfos),
Expand Down
Loading

0 comments on commit 18b0903

Please sign in to comment.