From 35de02d25c8430f6b1eb6dee9d17e9905a658a27 Mon Sep 17 00:00:00 2001 From: Sean Smith Date: Sun, 5 Jan 2025 13:36:23 -0600 Subject: [PATCH] more --- .../src/functions/aggregate/builtin/first.rs | 170 ++---- .../src/functions/aggregate/builtin/minmax.rs | 577 ++++++++---------- .../functions/aggregate/builtin/regr_count.rs | 52 +- .../functions/aggregate/builtin/string_agg.rs | 39 +- crates/stdutil/src/marker.rs | 22 +- 5 files changed, 390 insertions(+), 470 deletions(-) diff --git a/crates/rayexec_execution/src/functions/aggregate/builtin/first.rs b/crates/rayexec_execution/src/functions/aggregate/builtin/first.rs index 435b4d848..12192f7da 100644 --- a/crates/rayexec_execution/src/functions/aggregate/builtin/first.rs +++ b/crates/rayexec_execution/src/functions/aggregate/builtin/first.rs @@ -1,53 +1,38 @@ -use std::borrow::Borrow; -use std::fmt::{self, Debug}; +use std::fmt::Debug; use std::marker::PhantomData; -use half::f16; use rayexec_error::{not_implemented, Result}; -use crate::arrays::array::ArrayData2; use crate::arrays::buffer::physical_type::{ AddressableMut, MutablePhysicalStorage, PhysicalBinary, + PhysicalBool, + PhysicalF16, + PhysicalF32, + PhysicalF64, + PhysicalI128, + PhysicalI16, + PhysicalI32, + PhysicalI64, + PhysicalI8, + PhysicalInterval, PhysicalType, + PhysicalU128, + PhysicalU16, + PhysicalU32, + PhysicalU64, + PhysicalU8, + PhysicalUntypedNull, + PhysicalUtf8, }; -use crate::arrays::datatype::{DataType, DataTypeId}; -use crate::arrays::executor::aggregate::{AggregateState2, StateFinalizer}; -use crate::arrays::executor::builder::{ArrayBuilder, GermanVarlenBuffer}; -use crate::arrays::executor::physical_type::{ - PhysicalBinary_2, - PhysicalBool_2, - PhysicalF16_2, - PhysicalF32_2, - PhysicalF64_2, - PhysicalI128_2, - PhysicalI16_2, - PhysicalI32_2, - PhysicalI64_2, - PhysicalI8_2, - PhysicalInterval_2, - PhysicalStorage2, - PhysicalType2, - PhysicalU128_2, - PhysicalU16_2, - PhysicalU32_2, - PhysicalU64_2, - PhysicalU8_2, - PhysicalUntypedNull_2, -}; +use crate::arrays::datatype::DataTypeId; use crate::arrays::executor_exp::aggregate::AggregateState; use crate::arrays::executor_exp::PutBuffer; -use crate::arrays::scalar::interval::Interval; -use crate::arrays::storage::{PrimitiveStorage, UntypedNull}; use crate::expr::Expression; use crate::functions::aggregate::states::{ - boolean_finalize, drain, - new_unary_aggregate_states2, - primitive_finalize, unary_update, - untyped_null_finalize, AggregateGroupStates, TypedAggregateGroupStates, }; @@ -94,67 +79,27 @@ impl AggregateFunction for First { let datatype = inputs[0].datatype(table_list)?; let function_impl: Box = match datatype.physical_type() { - // PhysicalType::Boolean + PhysicalType::UntypedNull => Box::new(FirstPrimitiveImpl::::new()), + PhysicalType::Boolean => Box::new(FirstPrimitiveImpl::::new()), + PhysicalType::Int8 => Box::new(FirstPrimitiveImpl::::new()), + PhysicalType::Int16 => Box::new(FirstPrimitiveImpl::::new()), + PhysicalType::Int32 => Box::new(FirstPrimitiveImpl::::new()), + PhysicalType::Int64 => Box::new(FirstPrimitiveImpl::::new()), + PhysicalType::Int128 => Box::new(FirstPrimitiveImpl::::new()), + PhysicalType::UInt8 => Box::new(FirstPrimitiveImpl::::new()), + PhysicalType::UInt16 => Box::new(FirstPrimitiveImpl::::new()), + PhysicalType::UInt32 => Box::new(FirstPrimitiveImpl::::new()), + PhysicalType::UInt64 => Box::new(FirstPrimitiveImpl::::new()), + PhysicalType::UInt128 => Box::new(FirstPrimitiveImpl::::new()), + PhysicalType::Float16 => Box::new(FirstPrimitiveImpl::::new()), + PhysicalType::Float32 => Box::new(FirstPrimitiveImpl::::new()), + PhysicalType::Float64 => Box::new(FirstPrimitiveImpl::::new()), + PhysicalType::Interval => Box::new(FirstPrimitiveImpl::::new()), + PhysicalType::Utf8 => Box::new(FirstStringImpl), + PhysicalType::Binary => Box::new(FirstBinaryImpl), other => not_implemented!("FIRST for physical type: {other}"), }; - // let function_impl: Box = match datatype.physical_type2()? { - // PhysicalType2::UntypedNull => Box::new(FirstUntypedNullImpl), - // PhysicalType2::Boolean => Box::new(FirstBoolImpl), - // PhysicalType2::Float16 => Box::new(FirstPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::Float32 => Box::new(FirstPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::Float64 => Box::new(FirstPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::Int8 => Box::new(FirstPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::Int16 => Box::new(FirstPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::Int32 => Box::new(FirstPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::Int64 => Box::new(FirstPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::Int128 => Box::new(FirstPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::UInt8 => Box::new(FirstPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::UInt16 => Box::new(FirstPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::UInt32 => Box::new(FirstPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::UInt64 => Box::new(FirstPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::UInt128 => Box::new(FirstPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::Interval => Box::new( - // FirstPrimitiveImpl::::new(datatype.clone()), - // ), - // PhysicalType2::Binary => Box::new(FirstBinaryImpl { - // datatype: datatype.clone(), - // }), - // PhysicalType2::Utf8 => Box::new(FirstBinaryImpl { - // datatype: datatype.clone(), - // }), - // PhysicalType2::List => { - // // TODO: Easy, clone underlying array and select. - // not_implemented!("FIRST for list arrays") - // } - // }; - Ok(PlannedAggregateFunction { function: Box::new(*self), return_type: datatype, @@ -169,6 +114,12 @@ pub struct FirstPrimitiveImpl { _s: PhantomData, } +impl FirstPrimitiveImpl { + const fn new() -> Self { + FirstPrimitiveImpl { _s: PhantomData } + } +} + impl AggregateFunctionImpl for FirstPrimitiveImpl where S: MutablePhysicalStorage, @@ -183,18 +134,31 @@ where } } -// #[derive(Debug, Clone, Copy)] -// pub struct FirstBinaryImpl; +#[derive(Debug, Clone, Copy)] +pub struct FirstBinaryImpl; -// impl AggregateFunctionImpl for FirstBinaryImpl { -// fn new_states(&self) -> Box { -// Box::new(TypedAggregateGroupStates::new( -// FirstBinaryState::default, -// unary_update::, -// drain::, -// )) -// } -// } +impl AggregateFunctionImpl for FirstBinaryImpl { + fn new_states(&self) -> Box { + Box::new(TypedAggregateGroupStates::new( + FirstBinaryState::default, + unary_update::, + drain::, + )) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct FirstStringImpl; + +impl AggregateFunctionImpl for FirstStringImpl { + fn new_states(&self) -> Box { + Box::new(TypedAggregateGroupStates::new( + FirstStringState::default, + unary_update::, + drain::, + )) + } +} #[derive(Debug, Default)] pub struct FirstPrimitiveState { diff --git a/crates/rayexec_execution/src/functions/aggregate/builtin/minmax.rs b/crates/rayexec_execution/src/functions/aggregate/builtin/minmax.rs index 448e18fc1..1b242f11a 100644 --- a/crates/rayexec_execution/src/functions/aggregate/builtin/minmax.rs +++ b/crates/rayexec_execution/src/functions/aggregate/builtin/minmax.rs @@ -1,17 +1,13 @@ -use std::borrow::Borrow; use std::fmt::Debug; use std::marker::PhantomData; -use half::f16; use rayexec_error::{not_implemented, Result}; -use crate::arrays::array::ArrayData2; use crate::arrays::buffer::physical_type::{ AddressableMut, MutablePhysicalStorage, PhysicalBinary, PhysicalBool, - PhysicalDictionary, PhysicalF16, PhysicalF32, PhysicalF64, @@ -21,8 +17,6 @@ use crate::arrays::buffer::physical_type::{ PhysicalI64, PhysicalI8, PhysicalInterval, - PhysicalList, - PhysicalStorage, PhysicalType, PhysicalU128, PhysicalU16, @@ -32,42 +26,13 @@ use crate::arrays::buffer::physical_type::{ PhysicalUntypedNull, PhysicalUtf8, }; -use crate::arrays::datatype::{DataType, DataTypeId}; -use crate::arrays::executor::aggregate::{AggregateState2, StateFinalizer}; -use crate::arrays::executor::builder::{ArrayBuilder, GermanVarlenBuffer}; -use crate::arrays::executor::physical_type::{ - PhysicalBinary_2, - PhysicalBool_2, - PhysicalF16_2, - PhysicalF32_2, - PhysicalF64_2, - PhysicalI128_2, - PhysicalI16_2, - PhysicalI32_2, - PhysicalI64_2, - PhysicalI8_2, - PhysicalInterval_2, - PhysicalStorage2, - PhysicalType2, - PhysicalU128_2, - PhysicalU16_2, - PhysicalU32_2, - PhysicalU64_2, - PhysicalU8_2, - PhysicalUntypedNull_2, -}; +use crate::arrays::datatype::DataTypeId; use crate::arrays::executor_exp::aggregate::AggregateState; use crate::arrays::executor_exp::PutBuffer; -use crate::arrays::scalar::interval::Interval; -use crate::arrays::storage::{PrimitiveStorage, UntypedNull}; use crate::expr::Expression; use crate::functions::aggregate::states::{ - boolean_finalize, drain, - new_unary_aggregate_states2, - primitive_finalize, unary_update, - untyped_null_finalize, AggregateGroupStates, TypedAggregateGroupStates, }; @@ -113,65 +78,34 @@ impl AggregateFunction for Min { let datatype = inputs[0].datatype(table_list)?; - unimplemented!() - // let function_impl: Box = match datatype.physical_type2()? { - // PhysicalType2::UntypedNull => Box::new(MinMaxUntypedNull), - // PhysicalType2::Boolean => Box::new(MinBoolImpl::new()), - // PhysicalType2::Float16 => Box::new(MinPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::Float32 => Box::new(MinPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::Float64 => Box::new(MinPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::Int8 => { - // Box::new(MinPrimitiveImpl::::new(datatype.clone())) - // } - // PhysicalType2::Int16 => Box::new(MinPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::Int32 => Box::new(MinPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::Int64 => Box::new(MinPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::Int128 => Box::new(MinPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::UInt8 => { - // Box::new(MinPrimitiveImpl::::new(datatype.clone())) - // } - // PhysicalType2::UInt16 => Box::new(MinPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::UInt32 => Box::new(MinPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::UInt64 => Box::new(MinPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::UInt128 => Box::new(MinPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::Interval => Box::new( - // MinPrimitiveImpl::::new(datatype.clone()), - // ), - // PhysicalType2::Binary => Box::new(MinBinaryImpl::new(datatype.clone())), - // PhysicalType2::Utf8 => Box::new(MinBinaryImpl::new(datatype.clone())), - // PhysicalType2::List => { - // not_implemented!("MIN for list arrays") - // } - // }; - - // Ok(PlannedAggregateFunction { - // function: Box::new(*self), - // return_type: datatype, - // inputs, - // function_impl, - // }) + let function_impl: Box = match datatype.physical_type() { + PhysicalType::UntypedNull => Box::new(MinPrimitiveImpl::::new()), + PhysicalType::Boolean => Box::new(MinPrimitiveImpl::::new()), + PhysicalType::Int8 => Box::new(MinPrimitiveImpl::::new()), + PhysicalType::Int16 => Box::new(MinPrimitiveImpl::::new()), + PhysicalType::Int32 => Box::new(MinPrimitiveImpl::::new()), + PhysicalType::Int64 => Box::new(MinPrimitiveImpl::::new()), + PhysicalType::Int128 => Box::new(MinPrimitiveImpl::::new()), + PhysicalType::UInt8 => Box::new(MinPrimitiveImpl::::new()), + PhysicalType::UInt16 => Box::new(MinPrimitiveImpl::::new()), + PhysicalType::UInt32 => Box::new(MinPrimitiveImpl::::new()), + PhysicalType::UInt64 => Box::new(MinPrimitiveImpl::::new()), + PhysicalType::UInt128 => Box::new(MinPrimitiveImpl::::new()), + PhysicalType::Float16 => Box::new(MinPrimitiveImpl::::new()), + PhysicalType::Float32 => Box::new(MinPrimitiveImpl::::new()), + PhysicalType::Float64 => Box::new(MinPrimitiveImpl::::new()), + PhysicalType::Interval => Box::new(MinPrimitiveImpl::::new()), + PhysicalType::Utf8 => Box::new(MinStringImpl), + PhysicalType::Binary => Box::new(MinBinaryImpl), + other => not_implemented!("max for type {other:?}"), + }; + + Ok(PlannedAggregateFunction { + function: Box::new(*self), + return_type: datatype, + inputs, + function_impl, + }) } } @@ -225,325 +159,288 @@ impl AggregateFunction for Max { PhysicalType::Float32 => Box::new(MaxPrimitiveImpl::::new()), PhysicalType::Float64 => Box::new(MaxPrimitiveImpl::::new()), PhysicalType::Interval => Box::new(MaxPrimitiveImpl::::new()), - // PhysicalType::Utf8 => Box::new(MaxImpl::::new()), - // PhysicalType::Binary => Box::new(MaxImpl::::new()), + PhysicalType::Utf8 => Box::new(MaxStringImpl), + PhysicalType::Binary => Box::new(MaxBinaryImpl), other => not_implemented!("max for type {other:?}"), }; - // let function_impl: Box = match datatype.physical_type2()? { - // PhysicalType2::UntypedNull => Box::new(MinMaxUntypedNull), - // PhysicalType2::Boolean => Box::new(MaxBoolImpl::new()), - // PhysicalType2::Float16 => Box::new(MaxPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::Float32 => Box::new(MaxPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::Float64 => Box::new(MaxPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::Int8 => { - // Box::new(MaxPrimitiveImpl::::new(datatype.clone())) - // } - // PhysicalType2::Int16 => Box::new(MaxPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::Int32 => Box::new(MaxPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::Int64 => Box::new(MaxPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::Int128 => Box::new(MaxPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::UInt8 => { - // Box::new(MaxPrimitiveImpl::::new(datatype.clone())) - // } - // PhysicalType2::UInt16 => Box::new(MaxPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::UInt32 => Box::new(MaxPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::UInt64 => Box::new(MaxPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::UInt128 => Box::new(MaxPrimitiveImpl::::new( - // datatype.clone(), - // )), - // PhysicalType2::Interval => Box::new( - // MaxPrimitiveImpl::::new(datatype.clone()), - // ), - // PhysicalType2::Binary => Box::new(MaxBinaryImpl::new(datatype.clone())), - // PhysicalType2::Utf8 => Box::new(MaxBinaryImpl::new(datatype.clone())), - // PhysicalType2::List => { - // not_implemented!("MAX for list arrays") - // } - // }; - unimplemented!() - - // Ok(PlannedAggregateFunction { - // function: Box::new(*self), - // return_type: datatype, - // inputs, - // function_impl, - // }) - } -} - -#[derive(Debug, Clone)] -pub struct MinMaxUntypedNull; - -impl AggregateFunctionImpl for MinMaxUntypedNull { - fn new_states(&self) -> Box { - // Note min vs max doesn't matter. Everything is null. - new_unary_aggregate_states2::( - MinState::::default, - untyped_null_finalize, - ) + Ok(PlannedAggregateFunction { + function: Box::new(*self), + return_type: datatype, + inputs, + function_impl, + }) } } -pub type MinBinaryImpl = MinMaxBinaryImpl; -pub type MaxBinaryImpl = MinMaxBinaryImpl; - -#[derive(Debug)] -pub struct MinMaxBinaryImpl { - datatype: DataType, - _m: PhantomData, +#[derive(Debug, Clone, Copy)] +pub struct MaxPrimitiveImpl { + _s: PhantomData, } -impl MinMaxBinaryImpl { - fn new(datatype: DataType) -> Self { - MinMaxBinaryImpl { - datatype, - _m: PhantomData, - } +impl MaxPrimitiveImpl { + const fn new() -> Self { + MaxPrimitiveImpl { _s: PhantomData } } } -impl AggregateFunctionImpl for MinMaxBinaryImpl +impl AggregateFunctionImpl for MaxPrimitiveImpl where - M: for<'a> AggregateState2<&'a [u8], Vec> + Default + Sync + Send + 'static, + S: MutablePhysicalStorage, + S::StorageType: Default + Debug + Sync + Send + PartialOrd + Copy, { fn new_states(&self) -> Box { - let datatype = self.datatype.clone(); - - new_unary_aggregate_states2::(M::default, move |states| { - let builder = ArrayBuilder { - datatype: datatype.clone(), - buffer: GermanVarlenBuffer::<[u8]>::with_len(states.len()), - }; - StateFinalizer::finalize(states, builder) - }) + Box::new(TypedAggregateGroupStates::new( + MaxStatePrimitive::::default, + unary_update::, + drain::, + )) } } -impl Clone for MinMaxBinaryImpl { - fn clone(&self) -> Self { - Self::new(self.datatype.clone()) +#[derive(Debug, Clone, Copy)] +pub struct MaxBinaryImpl; + +impl AggregateFunctionImpl for MaxBinaryImpl { + fn new_states(&self) -> Box { + Box::new(TypedAggregateGroupStates::new( + MaxStateBinary::default, + unary_update::, + drain::, + )) } } -pub type MinBoolImpl = MinMaxBoolImpl>; -pub type MaxBoolImpl = MinMaxBoolImpl>; +#[derive(Debug, Clone, Copy)] +pub struct MaxStringImpl; -#[derive(Debug)] -pub struct MinMaxBoolImpl { - _m: PhantomData, +impl AggregateFunctionImpl for MaxStringImpl { + fn new_states(&self) -> Box { + Box::new(TypedAggregateGroupStates::new( + MaxStateString::default, + unary_update::, + drain::, + )) + } } -impl MinMaxBoolImpl { - fn new() -> Self { - MinMaxBoolImpl { _m: PhantomData } - } +#[derive(Debug, Default)] +pub struct MaxStatePrimitive { + max: T, + valid: bool, } -impl AggregateFunctionImpl for MinMaxBoolImpl +impl AggregateState<&T, T> for MaxStatePrimitive where - M: AggregateState2 + Default + Sync + Send + 'static, + T: Debug + Sync + Send + PartialOrd + Copy, { - fn new_states(&self) -> Box { - new_unary_aggregate_states2::(M::default, move |states| { - boolean_finalize(DataType::Boolean, states) - }) - } -} - -impl Clone for MinMaxBoolImpl { - fn clone(&self) -> Self { - Self::new() - } -} + fn merge(&mut self, other: &mut Self) -> Result<()> { + if !self.valid { + self.valid = other.valid; + std::mem::swap(&mut self.max, &mut other.max); + return Ok(()); + } -pub type MinPrimitiveImpl = MinMaxPrimitiveImpl, S, T>; -pub type MaxPrimitiveImpl2 = MinMaxPrimitiveImpl, S, T>; + if self.max.lt(&other.max) { + std::mem::swap(&mut self.max, &mut other.max); + } -// TODO: Remove T -#[derive(Debug)] -pub struct MinMaxPrimitiveImpl { - datatype: DataType, - _m: PhantomData, - _s: PhantomData, - _t: PhantomData, -} + Ok(()) + } -impl MinMaxPrimitiveImpl { - fn new(datatype: DataType) -> Self { - MinMaxPrimitiveImpl { - datatype, - _m: PhantomData, - _s: PhantomData, - _t: PhantomData, + fn update(&mut self, input: &T) -> Result<()> { + if !self.valid { + self.max = *input; + return Ok(()); } - } -} -impl AggregateFunctionImpl for MinMaxPrimitiveImpl -where - for<'a> S: PhysicalStorage2 = T>, - T: PartialOrd + Debug + Default + Sync + Send + Copy + 'static, - M: AggregateState2 + Default + Sync + Send + 'static, - ArrayData2: From>, -{ - fn new_states(&self) -> Box { - let datatype = self.datatype.clone(); + if self.max.lt(input) { + self.max = *input; + } - new_unary_aggregate_states2::(M::default, move |states| { - primitive_finalize(datatype.clone(), states) - }) + Ok(()) } -} -impl Clone for MinMaxPrimitiveImpl { - fn clone(&self) -> Self { - Self::new(self.datatype.clone()) + fn finalize(&mut self, output: PutBuffer) -> Result<()> + where + M: AddressableMut, + { + if self.valid { + output.put(&self.max); + } else { + output.put_null(); + } + + Ok(()) } } #[derive(Debug, Default)] -pub struct MinState { - min: T, +pub struct MaxStateBinary { + max: Vec, valid: bool, } -impl AggregateState2 for MinState -where - T: PartialOrd + Debug + Default + Copy, -{ +impl AggregateState<&[u8], [u8]> for MaxStateBinary { fn merge(&mut self, other: &mut Self) -> Result<()> { if !self.valid { self.valid = other.valid; - self.min = other.min; - } else if other.valid && other.min < self.min { - self.min = other.min; + std::mem::swap(&mut self.max, &mut other.max); + return Ok(()); + } + + if self.max.lt(&other.max) { + std::mem::swap(&mut self.max, &mut other.max); } Ok(()) } - fn update(&mut self, input: T) -> Result<()> { + fn update(&mut self, input: &[u8]) -> Result<()> { if !self.valid { - self.valid = true; - self.min = input; - } else if input < self.min { - self.min = input + self.max = input.to_vec(); + return Ok(()); } + + if self.max.as_slice().lt(input) { + self.max = input.to_vec(); + } + Ok(()) } - fn finalize(&mut self) -> Result<(T, bool)> { + fn finalize(&mut self, output: PutBuffer) -> Result<()> + where + M: AddressableMut, + { if self.valid { - Ok((self.min, true)) + output.put(&self.max); } else { - Ok((T::default(), false)) + output.put_null(); } + + Ok(()) } } #[derive(Debug, Default)] -pub struct MinStateBinary { - min: Vec, +pub struct MaxStateString { + max: String, valid: bool, } -impl AggregateState2<&[u8], Vec> for MinStateBinary { +impl AggregateState<&str, str> for MaxStateString { fn merge(&mut self, other: &mut Self) -> Result<()> { if !self.valid { self.valid = other.valid; - std::mem::swap(&mut self.min, &mut other.min); - } else if other.valid && other.min < self.min { - std::mem::swap(&mut self.min, &mut other.min); + std::mem::swap(&mut self.max, &mut other.max); + return Ok(()); + } + + if self.max.lt(&other.max) { + std::mem::swap(&mut self.max, &mut other.max); } Ok(()) } - fn update(&mut self, input: &[u8]) -> Result<()> { + fn update(&mut self, input: &str) -> Result<()> { if !self.valid { - self.valid = true; - self.min = input.into(); - } else if input < self.min.as_slice() { - self.min = input.into(); + self.max = input.to_string(); + return Ok(()); + } + + if self.max.as_str().lt(input) { + self.max = input.to_string(); } Ok(()) } - fn finalize(&mut self) -> Result<(Vec, bool)> { + fn finalize(&mut self, output: PutBuffer) -> Result<()> + where + M: AddressableMut, + { if self.valid { - Ok((std::mem::take(&mut self.min), true)) + output.put(&self.max); } else { - Ok((Vec::new(), false)) + output.put_null(); } + + Ok(()) } } #[derive(Debug, Clone, Copy)] -pub struct MaxPrimitiveImpl { +pub struct MinPrimitiveImpl { _s: PhantomData, } -impl MaxPrimitiveImpl { +impl MinPrimitiveImpl { const fn new() -> Self { - MaxPrimitiveImpl { _s: PhantomData } + MinPrimitiveImpl { _s: PhantomData } } } -impl AggregateFunctionImpl for MaxPrimitiveImpl +impl AggregateFunctionImpl for MinPrimitiveImpl where S: MutablePhysicalStorage, S::StorageType: Default + Debug + Sync + Send + PartialOrd + Copy, { fn new_states(&self) -> Box { Box::new(TypedAggregateGroupStates::new( - MaxStatePrimitive::::default, + MinStatePrimitive::::default, unary_update::, drain::, )) } } +#[derive(Debug, Clone, Copy)] +pub struct MinBinaryImpl; + +impl AggregateFunctionImpl for MinBinaryImpl { + fn new_states(&self) -> Box { + Box::new(TypedAggregateGroupStates::new( + MinStateBinary::default, + unary_update::, + drain::, + )) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct MinStringImpl; + +impl AggregateFunctionImpl for MinStringImpl { + fn new_states(&self) -> Box { + Box::new(TypedAggregateGroupStates::new( + MinStateString::default, + unary_update::, + drain::, + )) + } +} + #[derive(Debug, Default)] -pub struct MaxStatePrimitive { - max: T, +pub struct MinStatePrimitive { + min: T, valid: bool, } -impl AggregateState<&T, T> for MaxStatePrimitive +impl AggregateState<&T, T> for MinStatePrimitive where T: Debug + Sync + Send + PartialOrd + Copy, { fn merge(&mut self, other: &mut Self) -> Result<()> { if !self.valid { self.valid = other.valid; - std::mem::swap(&mut self.max, &mut other.max); + std::mem::swap(&mut self.min, &mut other.min); return Ok(()); } - if self.max.lt(&other.max) { - std::mem::swap(&mut self.max, &mut other.max); + if self.min.gt(&other.min) { + std::mem::swap(&mut self.min, &mut other.min); } Ok(()) @@ -551,12 +448,12 @@ where fn update(&mut self, input: &T) -> Result<()> { if !self.valid { - self.max = *input; + self.min = *input; return Ok(()); } - if self.max.lt(input) { - self.max = *input; + if self.min.gt(input) { + self.min = *input; } Ok(()) @@ -567,7 +464,7 @@ where M: AddressableMut, { if self.valid { - output.put(&self.max); + output.put(&self.min); } else { output.put_null(); } @@ -577,79 +474,97 @@ where } #[derive(Debug, Default)] -pub struct MaxState2 { - max: T, +pub struct MinStateBinary { + min: Vec, valid: bool, } -impl AggregateState2 for MaxState2 -where - T: PartialOrd + Debug + Default + Copy, -{ +impl AggregateState<&[u8], [u8]> for MinStateBinary { fn merge(&mut self, other: &mut Self) -> Result<()> { if !self.valid { self.valid = other.valid; - self.max = other.max; - } else if other.valid && other.max > self.max { - self.max = other.max; + std::mem::swap(&mut self.min, &mut other.min); + return Ok(()); } + + if self.min.gt(&other.min) { + std::mem::swap(&mut self.min, &mut other.min); + } + Ok(()) } - fn update(&mut self, input: T) -> Result<()> { + fn update(&mut self, input: &[u8]) -> Result<()> { if !self.valid { - self.valid = true; - self.max = input; - } else if input > self.max { - self.max = input + self.min = input.to_vec(); + return Ok(()); + } + + if self.min.as_slice().gt(input) { + self.min = input.to_vec(); } Ok(()) } - fn finalize(&mut self) -> Result<(T, bool)> { + fn finalize(&mut self, output: PutBuffer) -> Result<()> + where + M: AddressableMut, + { if self.valid { - Ok((self.max, true)) + output.put(&self.min); } else { - Ok((T::default(), false)) + output.put_null(); } + + Ok(()) } } #[derive(Debug, Default)] -pub struct MaxStateBinary2 { - max: Vec, +pub struct MinStateString { + min: String, valid: bool, } -impl AggregateState2<&[u8], Vec> for MaxStateBinary2 { +impl AggregateState<&str, str> for MinStateString { fn merge(&mut self, other: &mut Self) -> Result<()> { if !self.valid { self.valid = other.valid; - std::mem::swap(&mut self.max, &mut other.max); - } else if other.valid && other.max > self.max { - std::mem::swap(&mut self.max, &mut other.max); + std::mem::swap(&mut self.min, &mut other.min); + return Ok(()); + } + + if self.min.gt(&other.min) { + std::mem::swap(&mut self.min, &mut other.min); } Ok(()) } - fn update(&mut self, input: &[u8]) -> Result<()> { + fn update(&mut self, input: &str) -> Result<()> { if !self.valid { - self.valid = true; - self.max = input.into(); - } else if input > self.max.as_slice() { - self.max = input.into(); + self.min = input.to_string(); + return Ok(()); + } + + if self.min.as_str().gt(input) { + self.min = input.to_string(); } Ok(()) } - fn finalize(&mut self) -> Result<(Vec, bool)> { + fn finalize(&mut self, output: PutBuffer) -> Result<()> + where + M: AddressableMut, + { if self.valid { - Ok((std::mem::take(&mut self.max), true)) + output.put(&self.min); } else { - Ok((Vec::new(), false)) + output.put_null(); } + + Ok(()) } } diff --git a/crates/rayexec_execution/src/functions/aggregate/builtin/regr_count.rs b/crates/rayexec_execution/src/functions/aggregate/builtin/regr_count.rs index 29914d218..ddad9d65f 100644 --- a/crates/rayexec_execution/src/functions/aggregate/builtin/regr_count.rs +++ b/crates/rayexec_execution/src/functions/aggregate/builtin/regr_count.rs @@ -1,16 +1,14 @@ use std::fmt::Debug; +use std::marker::PhantomData; use rayexec_error::Result; +use crate::arrays::buffer::physical_type::{AddressableMut, PhysicalF64, PhysicalStorage}; use crate::arrays::datatype::{DataType, DataTypeId}; -use crate::arrays::executor::aggregate::AggregateState2; -use crate::arrays::executor::physical_type::PhysicalAny; +use crate::arrays::executor_exp::aggregate::AggregateState; +use crate::arrays::executor_exp::PutBuffer; use crate::expr::Expression; -use crate::functions::aggregate::states::{ - new_binary_aggregate_states2, - primitive_finalize, - AggregateGroupStates, -}; +use crate::functions::aggregate::states::AggregateGroupStates; use crate::functions::aggregate::{ AggregateFunction, AggregateFunctionImpl, @@ -59,7 +57,7 @@ impl AggregateFunction for RegrCount { function: Box::new(*self), return_type: DataType::Float64, inputs, - function_impl: Box::new(RegrCountImpl), + function_impl: Box::new(RegrCountImpl::::new()), }), (a, b) => Err(invalid_input_types_error(self, &[a, b])), } @@ -67,14 +65,22 @@ impl AggregateFunction for RegrCount { } #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct RegrCountImpl; +pub struct RegrCountImpl { + _s: PhantomData, +} + +impl RegrCountImpl { + const fn new() -> Self { + RegrCountImpl { _s: PhantomData } + } +} -impl AggregateFunctionImpl for RegrCountImpl { +impl AggregateFunctionImpl for RegrCountImpl +where + S: PhysicalStorage, +{ fn new_states(&self) -> Box { - new_binary_aggregate_states2::( - RegrCountState::default, - move |states| primitive_finalize(DataType::Int64, states), - ) + unimplemented!() } } @@ -83,22 +89,30 @@ impl AggregateFunctionImpl for RegrCountImpl { /// Note that this can be used for any input type, but the sql function we /// expose only accepts f64 (to match Postgres). #[derive(Debug, Clone, Copy, Default)] -pub struct RegrCountState { +pub struct RegrCountState { count: i64, + _s: PhantomData, } -impl AggregateState2<((), ()), i64> for RegrCountState { +impl AggregateState<&S::StorageType, i64> for RegrCountState +where + S: PhysicalStorage, +{ fn merge(&mut self, other: &mut Self) -> Result<()> { self.count += other.count; Ok(()) } - fn update(&mut self, _input: ((), ())) -> Result<()> { + fn update(&mut self, _input: &S::StorageType) -> Result<()> { self.count += 1; Ok(()) } - fn finalize(&mut self) -> Result<(i64, bool)> { - Ok((self.count, true)) + fn finalize(&mut self, output: PutBuffer) -> Result<()> + where + M: AddressableMut, + { + output.put(&self.count); + Ok(()) } } diff --git a/crates/rayexec_execution/src/functions/aggregate/builtin/string_agg.rs b/crates/rayexec_execution/src/functions/aggregate/builtin/string_agg.rs index cd1ec0d65..b1407c091 100644 --- a/crates/rayexec_execution/src/functions/aggregate/builtin/string_agg.rs +++ b/crates/rayexec_execution/src/functions/aggregate/builtin/string_agg.rs @@ -2,13 +2,18 @@ use std::fmt::Debug; use rayexec_error::{RayexecError, Result}; +use crate::arrays::buffer::physical_type::{AddressableMut, PhysicalUtf8}; use crate::arrays::datatype::{DataType, DataTypeId}; -use crate::arrays::executor::aggregate::{AggregateState2, StateFinalizer}; -use crate::arrays::executor::builder::{ArrayBuilder, GermanVarlenBuffer}; -use crate::arrays::executor::physical_type::PhysicalUtf8_2; +use crate::arrays::executor_exp::aggregate::AggregateState; +use crate::arrays::executor_exp::PutBuffer; use crate::arrays::scalar::ScalarValue; use crate::expr::Expression; -use crate::functions::aggregate::states::{new_unary_aggregate_states2, AggregateGroupStates}; +use crate::functions::aggregate::states::{ + drain, + unary_update, + AggregateGroupStates, + TypedAggregateGroupStates, +}; use crate::functions::aggregate::{ AggregateFunction, AggregateFunctionImpl, @@ -99,13 +104,11 @@ impl AggregateFunctionImpl for StringAggImpl { string: None, }; - new_unary_aggregate_states2::(state_init, move |states| { - let builder = ArrayBuilder { - datatype: DataType::Utf8, - buffer: GermanVarlenBuffer::::with_len(states.len()), - }; - StateFinalizer::finalize(states, builder) - }) + Box::new(TypedAggregateGroupStates::new( + state_init, + unary_update::, + drain::, + )) } } @@ -119,7 +122,7 @@ pub struct StringAggState { string: Option, } -impl AggregateState2<&str, String> for StringAggState { +impl AggregateState<&str, str> for StringAggState { fn merge(&mut self, other: &mut Self) -> Result<()> { if self.string.is_none() { std::mem::swap(self, other); @@ -148,10 +151,14 @@ impl AggregateState2<&str, String> for StringAggState { Ok(()) } - fn finalize(&mut self) -> Result<(String, bool)> { - match self.string.take() { - Some(s) => Ok((s, true)), - None => Ok((String::new(), false)), + fn finalize(&mut self, output: PutBuffer) -> Result<()> + where + M: AddressableMut, + { + match &self.string { + Some(s) => output.put(s), + None => output.put_null(), } + Ok(()) } } diff --git a/crates/stdutil/src/marker.rs b/crates/stdutil/src/marker.rs index 27a8868b4..956ddfd7e 100644 --- a/crates/stdutil/src/marker.rs +++ b/crates/stdutil/src/marker.rs @@ -7,7 +7,7 @@ use std::marker::PhantomData; /// bounds. This lets us make structs and other types covariant to `T` but /// without the potential inheritence of `?Sized` (or other undesired traits) in /// the outer type. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct PhantomCovariant(PhantomData T>) where T: ?Sized; @@ -20,3 +20,23 @@ where PhantomCovariant(PhantomData) } } + +impl Clone for PhantomCovariant +where + T: ?Sized, +{ + fn clone(&self) -> Self { + Self::new() + } +} + +impl Copy for PhantomCovariant where T: ?Sized {} + +impl Default for PhantomCovariant +where + T: ?Sized, +{ + fn default() -> Self { + Self::new() + } +}