Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallel distinct hash aggregate #4881

Merged
merged 1 commit into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions src/include/processor/operator/aggregate/aggregate_hash_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,26 +79,35 @@ class AggregateHashTable : public BaseHashTable {
common::DataChunkState* leadingState, const std::vector<AggregateInput>& aggregateInputs,
uint64_t resultSetMultiplicity);

bool isAggregateValueDistinctForGroupByKeys(
// Returns true if the value was distinct and was inserted
// otherwise if the value already existed, returns false and the hash table is unchanged
bool insertAggregateValueIfDistinctForGroupByKeys(
const std::vector<common::ValueVector*>& groupByKeyVectors,
common::ValueVector* aggregateVector);

//! merge aggregate hash table by combining aggregate states under the same key
void merge(FactorizedTable&& other);
void merge(AggregateHashTable&& other) { merge(std::move(*other.factorizedTable)); }
// Must be called after merging hash tables with distinct functions, but only when the merged
// distinct tuples match the merged non-distinct tuples
void mergeDistinctAggregateInfo();

void finalizeAggregateStates();

void resize(uint64_t newSize);
void clear();
void resizeHashTableIfNecessary(uint32_t maxNumDistinctHashKeys);

AggregateHashTable createEmptyCopy() const { return AggregateHashTable(*this); }

DEFAULT_BOTH_MOVE(AggregateHashTable);
AggregateHashTable* getDistinctHashTable(uint64_t aggregateFunctionIdx) const {
return distinctHashTables[aggregateFunctionIdx].get();
}

protected:
virtual uint64_t matchFTEntries(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors, uint64_t numMayMatches,
virtual uint64_t matchFTEntries(std::span<const common::ValueVector*> flatKeyVectors,
std::span<const common::ValueVector*> unFlatKeyVectors, uint64_t numMayMatches,
uint64_t numNoMatches);

uint64_t matchFTEntries(const FactorizedTable& srcTable, uint64_t startOffset,
Expand All @@ -111,10 +120,10 @@ class AggregateHashTable : public BaseHashTable {
void initializeFTEntries(const FactorizedTable& sourceTable, uint64_t sourceStartOffset,
uint64_t numFTEntriesToInitialize);

uint64_t matchUnFlatVecWithFTColumn(common::ValueVector* vector, uint64_t numMayMatches,
uint64_t matchUnFlatVecWithFTColumn(const common::ValueVector* vector, uint64_t numMayMatches,
uint64_t& numNoMatches, uint32_t colIdx);

uint64_t matchFlatVecWithFTColumn(common::ValueVector* vector, uint64_t numMayMatches,
uint64_t matchFlatVecWithFTColumn(const common::ValueVector* vector, uint64_t numMayMatches,
uint64_t& numNoMatches, uint32_t colIdx);

void findHashSlots(const std::vector<common::ValueVector*>& flatKeyVectors,
Expand Down Expand Up @@ -166,7 +175,8 @@ class AggregateHashTable : public BaseHashTable {

void updateAggStates(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
const std::vector<AggregateInput>& aggregateInputs, uint64_t resultSetMultiplicity);
const std::vector<AggregateInput>& aggregateInputs, uint64_t resultSetMultiplicity,
bool updateDistinct);

void fillEntryWithInitialNullAggregateState(FactorizedTable& table, uint8_t* entry);

Expand Down Expand Up @@ -230,6 +240,20 @@ class AggregateHashTable : public BaseHashTable {
return distinctAggKeyTypes;
}

template<class Func>
uint8_t* findEntry(common::hash_t hash, Func compareKeys) {
auto slotIdx = getSlotIdxForHash(hash);
while (true) {
auto slot = (HashSlot*)getHashSlot(slotIdx);
if (slot->entry == nullptr) {
return nullptr;
} else if ((slot->hash == hash) && compareKeys(slot->entry)) {
return slot->entry;
}
increaseSlotIdx(slotIdx);
}
}

private:
// Does not copy the contents of the hash table and is provided as a convenient way of
// constructing more hash tables without having to hold on to or expose the construction
Expand All @@ -251,6 +275,7 @@ class AggregateHashTable : public BaseHashTable {

//! special handling of distinct aggregate
std::vector<std::unique_ptr<AggregateHashTable>> distinctHashTables;
std::vector<uint64_t> distinctHashEntriesProcessed;
uint32_t hashColOffsetInFT{};
uint32_t aggStateColOffsetInFT{};
uint32_t aggStateColIdxInFT{};
Expand Down
3 changes: 0 additions & 3 deletions src/include/processor/operator/aggregate/base_aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,10 @@ class BaseAggregate : public Sink {

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

bool isParallel() const final { return !containDistinctAggregate(); }

void finalizeInternal(ExecutionContext* context) override = 0;

std::unique_ptr<PhysicalOperator> copy() override = 0;

private:
bool containDistinctAggregate() const;

protected:
Expand Down
56 changes: 47 additions & 9 deletions src/include/processor/operator/aggregate/hash_aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "common/types/types.h"
#include "common/vector/value_vector.h"
#include "main/client_context.h"
#include "processor/operator/aggregate/aggregate_input.h"
#include "processor/operator/aggregate/base_aggregate.h"
#include "processor/result/factorized_table.h"
#include "processor/result/factorized_table_schema.h"
Expand All @@ -37,12 +38,23 @@ struct HashAggregateInfo {
class HashAggregateSharedState final : public BaseAggregateSharedState {

public:
explicit HashAggregateSharedState(main::ClientContext* context, HashAggregateInfo aggInfo,
const std::vector<function::AggregateFunction>& aggregateFunctions);
explicit HashAggregateSharedState(main::ClientContext* context, HashAggregateInfo hashAggInfo,
const std::vector<function::AggregateFunction>& aggregateFunctions,
std::span<AggregateInfo> aggregateInfos);

void appendTuple(std::span<uint8_t> tuple, common::hash_t hash) {
auto& partition =
globalPartitions[(hash >> shiftForPartitioning) % globalPartitions.size()];
partition.queue->appendTuple(tuple);
}

~HashAggregateSharedState();
void appendDistinctTuple(size_t distinctFuncIndex, std::span<uint8_t> tuple,
common::hash_t hash) {
auto& partition =
globalPartitions[(hash >> shiftForPartitioning) % globalPartitions.size()];
partition.distinctTableQueues[distinctFuncIndex]->appendTuple(tuple);
}

void appendTuple(std::span<uint8_t> tuple, common::hash_t hash);
void appendOverflow(common::InMemOverflowBuffer&& overflowBuffer) {
overflow.push(std::make_unique<common::InMemOverflowBuffer>(std::move(overflowBuffer)));
}
Expand Down Expand Up @@ -82,13 +94,31 @@ class HashAggregateSharedState final : public BaseAggregateSharedState {
public:
HashAggregateInfo aggInfo;
common::MPSCQueue<std::unique_ptr<common::InMemOverflowBuffer>> overflow;
struct Partition {
std::unique_ptr<AggregateHashTable> hashTable;
std::mutex mtx;
class HashTableQueue {
public:
HashTableQueue(storage::MemoryManager* memoryManager, FactorizedTableSchema tableSchema);

std::unique_ptr<HashTableQueue> copy() const {
return std::make_unique<HashTableQueue>(headBlock.load()->table.getMemoryManager(),
headBlock.load()->table.getTableSchema()->copy());
}
~HashTableQueue();

void appendTuple(std::span<uint8_t> tuple);

void mergeInto(AggregateHashTable& hashTable);

bool empty() const {
auto headBlock = this->headBlock.load();
return (headBlock == nullptr || headBlock->numTuplesReserved == 0) &&
queuedTuples.approxSize() == 0;
}

private:
struct TupleBlock {
TupleBlock(storage::MemoryManager* memoryManager, FactorizedTableSchema tableSchama)
TupleBlock(storage::MemoryManager* memoryManager, FactorizedTableSchema tableSchema)
: numTuplesReserved{0}, numTuplesWritten{0},
table{memoryManager, std::move(tableSchama)} {
table{memoryManager, std::move(tableSchema)} {
// Start at a fixed capacity of one full block (so that concurrent writes are safe).
// If it is not filled, we resize it to the actual capacity before writing it to the
// hashTable
Expand All @@ -110,6 +140,14 @@ class HashAggregateSharedState final : public BaseAggregateSharedState {
// queuedTuples (at which point, the numTuplesReserved may not be equal to the
// numTuplesWritten)
std::atomic<TupleBlock*> headBlock;
};
struct Partition {
std::unique_ptr<AggregateHashTable> hashTable;
std::mutex mtx;
std::unique_ptr<HashTableQueue> queue;
// The tables storing the distinct values for distinct aggregate functions all get merged in
// the same way as the main table
std::vector<std::unique_ptr<HashTableQueue>> distinctTableQueues;
std::atomic<bool> finalized = false;
};
std::vector<Partition> globalPartitions;
Expand Down
6 changes: 6 additions & 0 deletions src/include/processor/operator/aggregate/simple_aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ class SimpleAggregate final : public BaseAggregate {
printInfo->copy());
}

// TODO(bmwinger): We can use the same type of partitioning to handle distinct simple aggregates
// It could even be the exact same pipeline, but it would perform better if we don't use the
// hash tables for anything but collecting the distinct values
// Maybe try and move the partitioning into BaseAggregate
bool isParallel() const final { return !containDistinctAggregate(); }

private:
void computeDistinctAggregate(AggregateHashTable* distinctHT,
function::AggregateFunction* function, AggregateInput* input,
Expand Down
22 changes: 17 additions & 5 deletions src/include/processor/result/base_hash_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ class MemoryManager;
}
namespace processor {

using compare_function_t = std::function<bool(common::ValueVector*, uint32_t, const uint8_t*)>;
using raw_compare_function_t = std::function<bool(const uint8_t*, const uint8_t*)>;
using compare_function_t =
std::function<bool(const common::ValueVector*, uint32_t, const uint8_t*)>;
using ft_compare_function_t =
std::function<bool(const uint8_t*, const uint8_t*, const common::LogicalType& type)>;

class BaseHashTable {
public:
Expand All @@ -37,14 +39,24 @@ class BaseHashTable {

uint64_t getSlotIdxForHash(common::hash_t hash) const { return hash & bitmask; }
void setMaxNumHashSlots(uint64_t newSize);
void computeAndCombineVecHash(const std::vector<common::ValueVector*>& unFlatKeyVectors,
void computeAndCombineVecHash(std::span<const common::ValueVector*> unFlatKeyVectors,
uint32_t startVecIdx);

void computeVectorHashes(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors);
const std::vector<common::ValueVector*>& unFlatKeyVectors) {
computeVectorHashes(constSpan(flatKeyVectors), constSpan(unFlatKeyVectors));
}
void computeVectorHashes(std::span<const common::ValueVector*> flatKeyVectors,
std::span<const common::ValueVector*> unFlatKeyVectors);
void initSlotConstant(uint64_t numSlotsPerBlock);
bool matchFlatVecWithEntry(const std::vector<common::ValueVector*>& keyVectors,
const uint8_t* entry);

template<typename T>
std::span<const T*> constSpan(const std::vector<T*>& vector) {
return std::span(const_cast<const T**>(vector.data()), vector.size());
}

private:
void initCompareFuncs();
void initTmpHashVector();
Expand All @@ -58,7 +70,7 @@ class BaseHashTable {
storage::MemoryManager* memoryManager;
std::unique_ptr<FactorizedTable> factorizedTable;
std::vector<compare_function_t> compareEntryFuncs;
std::vector<raw_compare_function_t> rawCompareEntryFuncs;
std::vector<ft_compare_function_t> ftCompareEntryFuncs;
std::vector<common::LogicalType> keyTypes;
// Temporary arrays to hold intermediate results for appending.
std::shared_ptr<common::DataChunkState> hashState;
Expand Down
8 changes: 4 additions & 4 deletions src/include/processor/result/factorized_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,19 @@ class KUZU_API FactorizedTable {

// This function scans numTuplesToScan of rows to vectors starting at tupleIdx. Callers are
// responsible for making sure all the parameters are valid.
void scan(std::vector<common::ValueVector*>& vectors, ft_tuple_idx_t tupleIdx,
void scan(std::span<common::ValueVector*> vectors, ft_tuple_idx_t tupleIdx,
uint64_t numTuplesToScan) const {
std::vector<uint32_t> colIdxes(tableSchema.getNumColumns());
iota(colIdxes.begin(), colIdxes.end(), 0);
scan(vectors, tupleIdx, numTuplesToScan, colIdxes);
}
bool isEmpty() const { return getNumTuples() == 0; }
void scan(std::vector<common::ValueVector*>& vectors, ft_tuple_idx_t tupleIdx,
uint64_t numTuplesToScan, std::vector<uint32_t>& colIdxToScan) const;
void scan(std::span<common::ValueVector*> vectors, ft_tuple_idx_t tupleIdx,
uint64_t numTuplesToScan, std::span<uint32_t> colIdxToScan) const;
// TODO(Guodong): Unify these two interfaces along with `readUnflatCol`.
// startPos is the starting position in the tuplesToRead, not the starting position in the
// factorizedTable
void lookup(std::vector<common::ValueVector*>& vectors, std::vector<uint32_t>& colIdxesToScan,
void lookup(std::span<common::ValueVector*> vectors, std::span<uint32_t> colIdxesToScan,
uint8_t** tuplesToRead, uint64_t startPos, uint64_t numTuplesToRead) const;
void lookup(std::vector<common::ValueVector*>& vectors,
const common::SelectionVector* selVector, std::vector<uint32_t>& colIdxesToScan,
Expand Down
4 changes: 2 additions & 2 deletions src/include/processor/result/pattern_creation_info_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ class PatternCreationInfoTable : public AggregateHashTable {

PatternCreationInfo getPatternCreationInfo(const std::vector<common::ValueVector*>& keyVectors);

uint64_t matchFTEntries(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors, uint64_t numMayMatches,
uint64_t matchFTEntries(std::span<const common::ValueVector*> flatKeyVectors,
std::span<const common::ValueVector*> unFlatKeyVectors, uint64_t numMayMatches,
uint64_t numNoMatches) override;

private:
Expand Down
2 changes: 1 addition & 1 deletion src/processor/map/map_aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ std::unique_ptr<PhysicalOperator> PlanMapper::createHashAggregate(const expressi
getDataPos(unFlatKeys, *inSchema), getDataPos(payloads, *inSchema), std::move(tableSchema)};

auto sharedState = std::make_shared<HashAggregateSharedState>(clientContext,
std::move(aggregateInfo), aggFunctions);
std::move(aggregateInfo), aggFunctions, aggregateInputInfos);
auto printInfo = std::make_unique<HashAggregatePrintInfo>(allKeys, aggregates);
auto aggregate = make_unique<HashAggregate>(std::make_unique<ResultSetDescriptor>(inSchema),
sharedState, std::move(aggFunctions), std::move(aggregateInputInfos),
Expand Down
Loading