From 0a2ef1bd1e3f245189c3fddaa72c3869ffe16c6d Mon Sep 17 00:00:00 2001 From: Tom Ellis Date: Sun, 21 Jan 2024 09:41:56 +0000 Subject: [PATCH] Avoid rebinding PrimExprs that were created in the Aggregator itself It's redundant because we only need to rebind PrimExprs that came from a lateral subquery. As explained at [1], rather than carefully analysing which PrimExprs came from a lateral subquery we can just rebind everything. A previous commit [2] changed things so that everything mentioned in aggregator order expressions is rebound. Instead we probably should have done what this commit does, that is, added an Unpackspec constraint to aggregate. Fixes https://github.com/tomjaguarpaw/haskell-opaleye/issues/587 This is a simpler approach to resolving the issues discussed at * https://github.com/tomjaguarpaw/haskell-opaleye/pull/585 * https://github.com/tomjaguarpaw/haskell-opaleye/pull/578 This still suffers from the problem described at: https://github.com/tomjaguarpaw/haskell-opaleye/pull/578#issuecomment-1782638274 i.e. if we find a way of duplicating field names, such as O.asc (\x -> snd x O..++ snd x) then we can still create crashing queries. The benefit of this commit is that there is a way of generating non-crashing queries! 
[1] https://github.com/tomjaguarpaw/haskell-opaleye/blob/52a4063188dd617ff91050dc0f2e27fc0570633c/src/Opaleye/Internal/Aggregate.hs#L111-L114 [2] d8483176cc15c32f2fd833aee960bfd0dce1360d, part of https://github.com/tomjaguarpaw/haskell-opaleye/pull/576 --- Test/Opaleye/Test/Arbitrary.hs | 5 ++--- Test/QuickCheck.hs | 4 ++-- Test/Test.hs | 4 ++-- src/Opaleye/Aggregate.hs | 22 ++++++++++++++------- src/Opaleye/Distinct.hs | 4 +++- src/Opaleye/Internal/Aggregate.hs | 32 ++++++++++--------------------- src/Opaleye/Internal/Distinct.hs | 8 +++++--- src/Opaleye/Internal/PrimQuery.hs | 4 ++-- src/Opaleye/Internal/Sql.hs | 4 ++-- 9 files changed, 43 insertions(+), 44 deletions(-) diff --git a/Test/Opaleye/Test/Arbitrary.hs b/Test/Opaleye/Test/Arbitrary.hs index d1eeb4a27..6401e9de7 100644 --- a/Test/Opaleye/Test/Arbitrary.hs +++ b/Test/Opaleye/Test/Arbitrary.hs @@ -471,7 +471,7 @@ genSelectArr = genSelectMapper :: [TQ.Gen (O.Select Fields -> O.Select Fields)] genSelectMapper = [ do - return (O.distinctExplicit distinctFields) + return (O.distinctExplicit unpackFields distinctFields) , do ArbitraryPositiveInt l <- TQ.arbitrary return (O.limit l) @@ -481,9 +481,8 @@ genSelectMapper = , do o <- TQ.arbitrary return (O.orderBy (arbitraryOrder o)) - , do - return (O.aggregate aggregateFields) + return (O.aggregateExplicit unpackFields aggregateFields) , do let q' q = P.dimap (\_ -> fst . firstBoolOrTrue (O.sqlBool True)) (fieldsList diff --git a/Test/QuickCheck.hs b/Test/QuickCheck.hs index 6ec4e4a4e..6738fd278 100644 --- a/Test/QuickCheck.hs +++ b/Test/QuickCheck.hs @@ -432,7 +432,7 @@ order o (ArbitrarySelect q) = distinct :: ArbitrarySelect -> Connection -> IO TQ.Property distinct = - compareDenotation' (O.distinctExplicit distinctFields) nub + compareDenotation' (O.distinctExplicit unpackFields distinctFields) nub -- When we generalise compareDenotation... 
we can just test -- @@ -455,7 +455,7 @@ valuesEmpty l = aggregate :: ArbitrarySelect -> Connection -> IO TQ.Property aggregate = - compareDenotationNoSort' (O.aggregate aggregateFields) + compareDenotationNoSort' (O.aggregateExplicit unpackFields aggregateFields) aggregateDenotation diff --git a/Test/Test.hs b/Test/Test.hs index 77ed58b33..658457042 100644 --- a/Test/Test.hs +++ b/Test/Test.hs @@ -620,7 +620,7 @@ testStringArrayAggregateOrdered = it "" $ q `selectShouldReturnSorted` expected testStringArrayAggregateOrderedDistinct :: Test -testStringArrayAggregateOrderedDistinct = xit "" $ q `selectShouldReturnSorted` expected +testStringArrayAggregateOrderedDistinct = it "" $ q `selectShouldReturnSorted` expected where q = O.aggregateOrdered (O.asc snd) @@ -1477,7 +1477,7 @@ testUnnest = do testSetAggregate :: Test testSetAggregate = do - xit "set aggregate (percentile_cont)" $ testH query (`shouldBe` [expectation]) + it "set aggregate (percentile_cont)" $ testH query (`shouldBe` [expectation]) where query :: Select (Field O.SqlFloat8) query = O.aggregate median (O.values as) diff --git a/src/Opaleye/Aggregate.hs b/src/Opaleye/Aggregate.hs index 6b00f9f8c..f6e936aa7 100644 --- a/src/Opaleye/Aggregate.hs +++ b/src/Opaleye/Aggregate.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} -- | Perform aggregation on 'S.Select's. 
To aggregate a 'S.Select' you -- should construct an 'Aggregator' encoding how you want the @@ -33,11 +34,14 @@ module Opaleye.Aggregate , stringAgg -- * Counting rows , countRows + -- * Explicit + , aggregateExplicit ) where -import Control.Arrow (second) +import Control.Arrow (second, (<<<)) import Data.Profunctor (lmap) import qualified Data.Profunctor as P +import qualified Data.Profunctor.Product.Default as D import qualified Opaleye.Internal.Aggregate as A import Opaleye.Internal.Aggregate (Aggregator, orderAggregate) @@ -46,6 +50,7 @@ import qualified Opaleye.Internal.QueryArr as Q import qualified Opaleye.Internal.HaskellDB.PrimQuery as HPQ import qualified Opaleye.Internal.Operators as O import qualified Opaleye.Internal.PackMap as PM +import Opaleye.Internal.Rebind (rebindExplicit) import qualified Opaleye.Internal.Tag as Tag import qualified Opaleye.Internal.Unpackspec as U @@ -85,11 +90,8 @@ result of an aggregation. -} -- See 'Opaleye.Internal.Sql.aggregate' for details of how aggregating -- by an empty query with no group by is handled. -aggregate :: Aggregator a b -> S.Select a -> S.Select b -aggregate agg q = Q.productQueryArr $ do - (a, pq) <- Q.runSimpleSelect q - t <- Tag.fresh - pure (second ($ pq) (A.aggregateU agg (a, t))) +aggregate :: D.Default U.Unpackspec a a => Aggregator a b -> S.Select a -> S.Select b +aggregate = aggregateExplicit D.def -- | Order the values within each aggregation in `Aggregator` using -- the given ordering. This is only relevant for aggregations that @@ -100,7 +102,13 @@ aggregate agg q = Q.productQueryArr $ do -- you need different orderings for different aggregations, use -- 'Opaleye.Internal.Aggregate.orderAggregate'. 
-aggregateOrdered :: Ord.Order a -> Aggregator a b -> S.Select a -> S.Select b +aggregateExplicit :: U.Unpackspec a a' -> Aggregator a' b -> S.Select a -> S.Select b +aggregateExplicit u agg q = Q.productQueryArr $ do + (a, pq) <- Q.runSimpleSelect (rebindExplicit u <<< q) + t <- Tag.fresh + pure (second ($ pq) (A.aggregateU agg (a, t))) + +aggregateOrdered :: D.Default U.Unpackspec a a => Ord.Order a -> Aggregator a b -> S.Select a -> S.Select b aggregateOrdered o agg = aggregate (orderAggregate o agg) -- | Aggregate only distinct values diff --git a/src/Opaleye/Distinct.hs b/src/Opaleye/Distinct.hs index adfecd7fb..299e6c3b4 100644 --- a/src/Opaleye/Distinct.hs +++ b/src/Opaleye/Distinct.hs @@ -18,6 +18,7 @@ import Opaleye.Internal.Distinct import Opaleye.Order import qualified Data.Profunctor.Product.Default as D +import Opaleye.Internal.Unpackspec (Unpackspec) -- | Remove duplicate rows from the 'Select'. -- @@ -40,5 +41,6 @@ import qualified Data.Profunctor.Product.Default as D -- 'Opaleye.Lateral.laterally' 'distinct' :: 'Data.Profunctor.Product.Default' 'Distinctspec' fields fields => 'Opaleye.Select.SelectArr' i fields -> 'Opaleye.Select.SelectArr' i fields -- @ distinct :: D.Default Distinctspec fields fields => + D.Default Unpackspec fields fields => Select fields -> Select fields -distinct = distinctExplicit D.def +distinct = distinctExplicit D.def D.def diff --git a/src/Opaleye/Internal/Aggregate.hs b/src/Opaleye/Internal/Aggregate.hs index 2f31aed02..eb69f7b75 100644 --- a/src/Opaleye/Internal/Aggregate.hs +++ b/src/Opaleye/Internal/Aggregate.hs @@ -135,37 +135,25 @@ aggregatorApply = Aggregator $ PM.PackMap $ \f (agg, a) -> aggregateU :: Aggregator a b -> (a, T.Tag) -> (b, PQ.PrimQuery -> PQ.PrimQuery) aggregateU agg (c0, t0) = (c1, primQ') - where (c1, projPEs_inners) = + where projPEs_inners :: PQ.Bindings HPQ.Aggregate + (c1, projPEs_inners) = PM.run (runAggregator agg (extractAggregateFields t0) c0) - projPEs = map fst projPEs_inners - inners = 
concatMap snd projPEs_inners + projPEs = projPEs_inners - primQ' = PQ.Aggregate projPEs . PQ.Rebind True inners + primQ' = PQ.Aggregate projPEs extractAggregateFields - :: Traversable t - => T.Tag - -> t HPQ.PrimExpr - -> PM.PM [((HPQ.Symbol, - t HPQ.Symbol), - PQ.Bindings HPQ.PrimExpr)] - HPQ.PrimExpr + :: T.Tag + -> HPQ.Aggregate + -> PM.PM (PQ.Bindings HPQ.Aggregate) HPQ.PrimExpr extractAggregateFields tag agg = do i <- PM.new + let sinner = HPQ.Symbol ("result" ++ i) tag - let souter = HPQ.Symbol ("result" ++ i) tag + PM.write (sinner, agg) - bindings <- for agg $ \pe -> do - j <- PM.new - let sinner = HPQ.Symbol ("inner" ++ j) tag - pure (sinner, pe) - - let agg' = fmap fst bindings - - PM.write ((souter, agg'), toList bindings) - - pure (HPQ.AttrExpr souter) + pure (HPQ.AttrExpr sinner) unsafeMax :: Aggregator (C.Field a) (C.Field a) unsafeMax = makeAggr HPQ.AggrMax diff --git a/src/Opaleye/Internal/Distinct.hs b/src/Opaleye/Internal/Distinct.hs index 888b67859..9eb6346c7 100644 --- a/src/Opaleye/Internal/Distinct.hs +++ b/src/Opaleye/Internal/Distinct.hs @@ -5,19 +5,21 @@ module Opaleye.Internal.Distinct where import qualified Opaleye.Internal.MaybeFields as M import Opaleye.Select (Select) import Opaleye.Field (Field_) -import Opaleye.Aggregate (Aggregator, groupBy, aggregate) +import Opaleye.Aggregate (Aggregator, groupBy, aggregateExplicit) import qualified Data.Profunctor as P import qualified Data.Profunctor.Product as PP import Data.Profunctor.Product.Default (Default, def) +import Opaleye.Internal.Unpackspec (Unpackspec) -- We implement distinct simply by grouping by all columns. We could -- instead implement it as SQL's DISTINCT but implementing it in terms -- of something else that we already have is easier at this point. 
-distinctExplicit :: Distinctspec fields fields' +distinctExplicit :: Unpackspec fields fields + -> Distinctspec fields fields' -> Select fields -> Select fields' -distinctExplicit (Distinctspec agg) = aggregate agg +distinctExplicit u (Distinctspec agg) = aggregateExplicit u agg newtype Distinctspec a b = Distinctspec (Aggregator a b) diff --git a/src/Opaleye/Internal/PrimQuery.hs b/src/Opaleye/Internal/PrimQuery.hs index 88a557a2b..d19cdabd9 100644 --- a/src/Opaleye/Internal/PrimQuery.hs +++ b/src/Opaleye/Internal/PrimQuery.hs @@ -133,7 +133,7 @@ data PrimQuery' a = Unit | Product (NEL.NonEmpty (Lateral, PrimQuery' a)) [HPQ.PrimExpr] -- | The subqueries to take the product of and the -- restrictions to apply - | Aggregate (Bindings (HPQ.Aggregate' HPQ.Symbol)) + | Aggregate (Bindings HPQ.Aggregate) (PrimQuery' a) | Window (Bindings (HPQ.WndwOp, HPQ.Partition)) (PrimQuery' a) -- | Represents both @DISTINCT ON@ and @ORDER BY@ @@ -178,7 +178,7 @@ data PrimQueryFoldP a p p' = PrimQueryFold , empty :: a -> p' , baseTable :: TableIdentifier -> Bindings HPQ.PrimExpr -> p' , product :: NEL.NonEmpty (Lateral, p) -> [HPQ.PrimExpr] -> p' - , aggregate :: Bindings (HPQ.Aggregate' HPQ.Symbol) + , aggregate :: Bindings HPQ.Aggregate -> p -> p' , window :: Bindings (HPQ.WndwOp, HPQ.Partition) -> p -> p' diff --git a/src/Opaleye/Internal/Sql.hs b/src/Opaleye/Internal/Sql.hs index db25d84ff..e00a4ab34 100644 --- a/src/Opaleye/Internal/Sql.hs +++ b/src/Opaleye/Internal/Sql.hs @@ -160,7 +160,7 @@ product ss pes = SelectFrom $ PQ.Lateral -> Lateral PQ.NonLateral -> NonLateral -aggregate :: PQ.Bindings (HPQ.Aggregate' HPQ.Symbol) +aggregate :: PQ.Bindings HPQ.Aggregate -> Select -> Select aggregate aggrs' s = @@ -191,7 +191,7 @@ aggregate aggrs' s = handleEmpty = ensureColumnsGen SP.deliteral aggrs :: [(Symbol, HPQ.Aggregate)] - aggrs = (map . Arr.second . fmap) HPQ.AttrExpr aggrs' + aggrs = aggrs' groupBy' :: [(symbol, HPQ.Aggregate)] -> NEL.NonEmpty HSql.SqlExpr