Skip to content

Commit

Permalink
Avoid rebinding PrimExprs that were created in the Aggregator itself
Browse files Browse the repository at this point in the history
It's redundant because we only need to rebind PrimExprs that came from
a lateral subquery.  As explained at [1], rather than carefully
analysing which PrimExprs came from a lateral subquery we can just
rebind everything.  A previous commit [2] changed things so that
everything mentioned in aggregator order expressions is rebound.
Instead we probably should have done what this commit does, that is,
added an Unpackspec constraint to aggregate.

Fixes #587

This is a simpler approach to resolving the issues discussed at

* #585
* #578

This still suffers from the problem described at:

#578 (comment)

i.e. if we find a way of duplicating field names, such as

    O.asc (\x -> snd x O..++ snd x)

then we can still create crashing queries.  The benefit of this
comment is that there is a way of generating non-crashing queries!

[1] https://github.com/tomjaguarpaw/haskell-opaleye/blob/52a4063188dd617ff91050dc0f2e27fc0570633c/src/Opaleye/Internal/Aggregate.hs#L111-L114

[2] d848317, part of
    #576
  • Loading branch information
tomjaguarpaw committed Jan 21, 2024
1 parent 4dcf913 commit 0a2ef1b
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 44 deletions.
5 changes: 2 additions & 3 deletions Test/Opaleye/Test/Arbitrary.hs
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ genSelectArr =
genSelectMapper :: [TQ.Gen (O.Select Fields -> O.Select Fields)]
genSelectMapper =
[ do
return (O.distinctExplicit distinctFields)
return (O.distinctExplicit unpackFields distinctFields)
, do
ArbitraryPositiveInt l <- TQ.arbitrary
return (O.limit l)
Expand All @@ -481,9 +481,8 @@ genSelectMapper =
, do
o <- TQ.arbitrary
return (O.orderBy (arbitraryOrder o))

, do
return (O.aggregate aggregateFields)
return (O.aggregateExplicit unpackFields aggregateFields)
, do
let q' q = P.dimap (\_ -> fst . firstBoolOrTrue (O.sqlBool True))
(fieldsList
Expand Down
4 changes: 2 additions & 2 deletions Test/QuickCheck.hs
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ order o (ArbitrarySelect q) =

distinct :: ArbitrarySelect -> Connection -> IO TQ.Property
distinct =
compareDenotation' (O.distinctExplicit distinctFields) nub
compareDenotation' (O.distinctExplicit unpackFields distinctFields) nub

-- When we generalise compareDenotation... we can just test
--
Expand All @@ -455,7 +455,7 @@ valuesEmpty l =

aggregate :: ArbitrarySelect -> Connection -> IO TQ.Property
aggregate =
compareDenotationNoSort' (O.aggregate aggregateFields)
compareDenotationNoSort' (O.aggregateExplicit unpackFields aggregateFields)
aggregateDenotation


Expand Down
4 changes: 2 additions & 2 deletions Test/Test.hs
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ testStringArrayAggregateOrdered = it "" $ q `selectShouldReturnSorted` expected


testStringArrayAggregateOrderedDistinct :: Test
testStringArrayAggregateOrderedDistinct = xit "" $ q `selectShouldReturnSorted` expected
testStringArrayAggregateOrderedDistinct = it "" $ q `selectShouldReturnSorted` expected
where q =
O.aggregateOrdered
(O.asc snd)
Expand Down Expand Up @@ -1477,7 +1477,7 @@ testUnnest = do

testSetAggregate :: Test
testSetAggregate = do
xit "set aggregate (percentile_cont)" $ testH query (`shouldBe` [expectation])
it "set aggregate (percentile_cont)" $ testH query (`shouldBe` [expectation])
where query :: Select (Field O.SqlFloat8)
query = O.aggregate median (O.values as)

Expand Down
22 changes: 15 additions & 7 deletions src/Opaleye/Aggregate.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}

-- | Perform aggregation on 'S.Select's. To aggregate a 'S.Select' you
-- should construct an 'Aggregator' encoding how you want the
Expand Down Expand Up @@ -33,11 +34,14 @@ module Opaleye.Aggregate
, stringAgg
-- * Counting rows
, countRows
-- * Explicit
, aggregateExplicit
) where

import Control.Arrow (second)
import Control.Arrow (second, (<<<))
import Data.Profunctor (lmap)
import qualified Data.Profunctor as P
import qualified Data.Profunctor.Product.Default as D

import qualified Opaleye.Internal.Aggregate as A
import Opaleye.Internal.Aggregate (Aggregator, orderAggregate)
Expand All @@ -46,6 +50,7 @@ import qualified Opaleye.Internal.QueryArr as Q
import qualified Opaleye.Internal.HaskellDB.PrimQuery as HPQ
import qualified Opaleye.Internal.Operators as O
import qualified Opaleye.Internal.PackMap as PM
import Opaleye.Internal.Rebind (rebindExplicit)
import qualified Opaleye.Internal.Tag as Tag
import qualified Opaleye.Internal.Unpackspec as U

Expand Down Expand Up @@ -85,11 +90,8 @@ result of an aggregation.
-}
-- See 'Opaleye.Internal.Sql.aggregate' for details of how aggregating
-- by an empty query with no group by is handled.
aggregate :: Aggregator a b -> S.Select a -> S.Select b
aggregate agg q = Q.productQueryArr $ do
(a, pq) <- Q.runSimpleSelect q
t <- Tag.fresh
pure (second ($ pq) (A.aggregateU agg (a, t)))
aggregate :: D.Default U.Unpackspec a a => Aggregator a b -> S.Select a -> S.Select b
aggregate = aggregateExplicit D.def

-- | Order the values within each aggregation in `Aggregator` using
-- the given ordering. This is only relevant for aggregations that
Expand All @@ -100,7 +102,13 @@ aggregate agg q = Q.productQueryArr $ do
-- you need different orderings for different aggregations, use
-- 'Opaleye.Internal.Aggregate.orderAggregate'.

aggregateOrdered :: Ord.Order a -> Aggregator a b -> S.Select a -> S.Select b
aggregateExplicit :: U.Unpackspec a a' -> Aggregator a' b -> S.Select a -> S.Select b
aggregateExplicit u agg q = Q.productQueryArr $ do
(a, pq) <- Q.runSimpleSelect (rebindExplicit u <<< q)
t <- Tag.fresh
pure (second ($ pq) (A.aggregateU agg (a, t)))

aggregateOrdered :: D.Default U.Unpackspec a a => Ord.Order a -> Aggregator a b -> S.Select a -> S.Select b
aggregateOrdered o agg = aggregate (orderAggregate o agg)

-- | Aggregate only distinct values
Expand Down
4 changes: 3 additions & 1 deletion src/Opaleye/Distinct.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import Opaleye.Internal.Distinct
import Opaleye.Order

import qualified Data.Profunctor.Product.Default as D
import Opaleye.Internal.Unpackspec (Unpackspec)

-- | Remove duplicate rows from the 'Select'.
--
Expand All @@ -40,5 +41,6 @@ import qualified Data.Profunctor.Product.Default as D
-- 'Opaleye.Lateral.laterally' 'distinct' :: 'Data.Profunctor.Product.Default' 'Distinctspec' fields fields => 'Opaleye.Select.SelectArr' i fields -> 'Opaleye.Select.SelectArr' i fields
-- @
distinct :: D.Default Distinctspec fields fields =>
D.Default Unpackspec fields fields =>
Select fields -> Select fields
distinct = distinctExplicit D.def
distinct = distinctExplicit D.def D.def
32 changes: 10 additions & 22 deletions src/Opaleye/Internal/Aggregate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -135,37 +135,25 @@ aggregatorApply = Aggregator $ PM.PackMap $ \f (agg, a) ->
aggregateU :: Aggregator a b
-> (a, T.Tag) -> (b, PQ.PrimQuery -> PQ.PrimQuery)
aggregateU agg (c0, t0) = (c1, primQ')
where (c1, projPEs_inners) =
where projPEs_inners :: PQ.Bindings HPQ.Aggregate
(c1, projPEs_inners) =
PM.run (runAggregator agg (extractAggregateFields t0) c0)

projPEs = map fst projPEs_inners
inners = concatMap snd projPEs_inners
projPEs = projPEs_inners

primQ' = PQ.Aggregate projPEs . PQ.Rebind True inners
primQ' = PQ.Aggregate projPEs

extractAggregateFields
:: Traversable t
=> T.Tag
-> t HPQ.PrimExpr
-> PM.PM [((HPQ.Symbol,
t HPQ.Symbol),
PQ.Bindings HPQ.PrimExpr)]
HPQ.PrimExpr
:: T.Tag
-> HPQ.Aggregate
-> PM.PM (PQ.Bindings HPQ.Aggregate) HPQ.PrimExpr
extractAggregateFields tag agg = do
i <- PM.new
let sinner = HPQ.Symbol ("result" ++ i) tag

let souter = HPQ.Symbol ("result" ++ i) tag
PM.write (sinner, agg)

bindings <- for agg $ \pe -> do
j <- PM.new
let sinner = HPQ.Symbol ("inner" ++ j) tag
pure (sinner, pe)

let agg' = fmap fst bindings

PM.write ((souter, agg'), toList bindings)

pure (HPQ.AttrExpr souter)
pure (HPQ.AttrExpr sinner)

unsafeMax :: Aggregator (C.Field a) (C.Field a)
unsafeMax = makeAggr HPQ.AggrMax
Expand Down
8 changes: 5 additions & 3 deletions src/Opaleye/Internal/Distinct.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,21 @@ module Opaleye.Internal.Distinct where
import qualified Opaleye.Internal.MaybeFields as M
import Opaleye.Select (Select)
import Opaleye.Field (Field_)
import Opaleye.Aggregate (Aggregator, groupBy, aggregate)
import Opaleye.Aggregate (Aggregator, groupBy, aggregateExplicit)

import qualified Data.Profunctor as P
import qualified Data.Profunctor.Product as PP
import Data.Profunctor.Product.Default (Default, def)
import Opaleye.Internal.Unpackspec (Unpackspec)

-- We implement distinct simply by grouping by all columns. We could
-- instead implement it as SQL's DISTINCT but implementing it in terms
-- of something else that we already have is easier at this point.

distinctExplicit :: Distinctspec fields fields'
distinctExplicit :: Unpackspec fields fields
-> Distinctspec fields fields'
-> Select fields -> Select fields'
distinctExplicit (Distinctspec agg) = aggregate agg
distinctExplicit u (Distinctspec agg) = aggregateExplicit u agg

newtype Distinctspec a b = Distinctspec (Aggregator a b)

Expand Down
4 changes: 2 additions & 2 deletions src/Opaleye/Internal/PrimQuery.hs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ data PrimQuery' a = Unit
| Product (NEL.NonEmpty (Lateral, PrimQuery' a)) [HPQ.PrimExpr]
-- | The subqueries to take the product of and the
-- restrictions to apply
| Aggregate (Bindings (HPQ.Aggregate' HPQ.Symbol))
| Aggregate (Bindings HPQ.Aggregate)
(PrimQuery' a)
| Window (Bindings (HPQ.WndwOp, HPQ.Partition)) (PrimQuery' a)
-- | Represents both @DISTINCT ON@ and @ORDER BY@
Expand Down Expand Up @@ -178,7 +178,7 @@ data PrimQueryFoldP a p p' = PrimQueryFold
, empty :: a -> p'
, baseTable :: TableIdentifier -> Bindings HPQ.PrimExpr -> p'
, product :: NEL.NonEmpty (Lateral, p) -> [HPQ.PrimExpr] -> p'
, aggregate :: Bindings (HPQ.Aggregate' HPQ.Symbol)
, aggregate :: Bindings HPQ.Aggregate
-> p
-> p'
, window :: Bindings (HPQ.WndwOp, HPQ.Partition) -> p -> p'
Expand Down
4 changes: 2 additions & 2 deletions src/Opaleye/Internal/Sql.hs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ product ss pes = SelectFrom $
PQ.Lateral -> Lateral
PQ.NonLateral -> NonLateral

aggregate :: PQ.Bindings (HPQ.Aggregate' HPQ.Symbol)
aggregate :: PQ.Bindings HPQ.Aggregate
-> Select
-> Select
aggregate aggrs' s =
Expand Down Expand Up @@ -191,7 +191,7 @@ aggregate aggrs' s =
handleEmpty = ensureColumnsGen SP.deliteral

aggrs :: [(Symbol, HPQ.Aggregate)]
aggrs = (map . Arr.second . fmap) HPQ.AttrExpr aggrs'
aggrs = aggrs'

groupBy' :: [(symbol, HPQ.Aggregate)]
-> NEL.NonEmpty HSql.SqlExpr
Expand Down

0 comments on commit 0a2ef1b

Please sign in to comment.