Skip to content

Commit ec20f5f

Browse files
Fix/simplify intersectBySorted
* Remove MonadIO and Eq constraints
* Simplify implementation
* Simplify tests
* Fix formatting
* Use longer benchmarks
1 parent: 160393c; commit: ec20f5f

File tree

5 files changed, with 73 additions and 90 deletions.

benchmark/Streamly/Benchmark/Prelude/Serial/NestedStream.hs

Lines changed: 20 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -417,57 +417,56 @@ o_n_space_monad value =
417417
-- Joining
418418
-------------------------------------------------------------------------------
419419

420-
toKvMap :: Int -> (Int, Int)
421-
toKvMap p = (p, p)
420+
toKv :: Int -> (Int, Int)
421+
toKv p = (p, p)
422422

423423
{-# INLINE joinWith #-}
424424
joinWith :: (S.MonadAsync m) =>
425425
((Int -> Int -> Bool) -> SerialT m Int -> SerialT m Int -> SerialT m b)
426426
-> Int
427427
-> Int
428-
-> Int
429428
-> m ()
430-
joinWith j val1 val2 i =
431-
S.drain $ j (==) (sourceUnfoldrM val1 i) (sourceUnfoldrM val2 i)
429+
joinWith j val i =
430+
S.drain $ j (==) (sourceUnfoldrM val i) (sourceUnfoldrM val (val `div` 2))
432431

433432
{-# INLINE joinMapWith #-}
434433
joinMapWith :: (S.MonadAsync m) =>
435434
(SerialT m (Int, Int) -> SerialT m (Int, Int) -> SerialT m b)
436435
-> Int
437436
-> Int
438-
-> Int
439437
-> m ()
440-
joinMapWith j val1 val2 i =
438+
joinMapWith j val i =
441439
S.drain
442440
$ j
443-
(fmap toKvMap (sourceUnfoldrM val1 i))
444-
(fmap toKvMap (sourceUnfoldrM val2 i))
441+
(fmap toKv (sourceUnfoldrM val i))
442+
(fmap toKv (sourceUnfoldrM val (val `div` 2)))
445443

446444
o_n_heap_buffering :: Int -> [Benchmark]
447445
o_n_heap_buffering value =
448446
[ bgroup "buffered"
449447
[
450-
benchIOSrc1 "joinInner"
451-
$ joinWith Internal.joinInner sqrtVal sqrtVal
448+
benchIOSrc1 "joinInner (sqrtVal)"
449+
$ joinWith Internal.joinInner sqrtVal
452450
, benchIOSrc1 "joinInnerMap"
453-
$ joinMapWith Internal.joinInnerMap sqrtVal sqrtVal
454-
, benchIOSrc1 "joinLeft"
455-
$ joinWith Internal.joinLeft sqrtVal sqrtVal
451+
$ joinMapWith Internal.joinInnerMap halfVal
452+
, benchIOSrc1 "joinLeft (sqrtVal)"
453+
$ joinWith Internal.joinLeft sqrtVal
456454
, benchIOSrc1 "joinLeftMap "
457-
$ joinMapWith Internal.joinLeftMap sqrtVal sqrtVal
458-
, benchIOSrc1 "joinOuter"
459-
$ joinWith Internal.joinOuter sqrtVal sqrtVal
455+
$ joinMapWith Internal.joinLeftMap halfVal
456+
, benchIOSrc1 "joinOuter (sqrtVal)"
457+
$ joinWith Internal.joinOuter sqrtVal
460458
, benchIOSrc1 "joinOuterMap"
461-
$ joinMapWith Internal.joinOuterMap sqrtVal sqrtVal
462-
, benchIOSrc1 "intersectBy"
463-
$ joinWith Internal.intersectBy sqrtVal sqrtVal
459+
$ joinMapWith Internal.joinOuterMap halfVal
460+
, benchIOSrc1 "intersectBy (sqrtVal)"
461+
$ joinWith Internal.intersectBy sqrtVal
464462
, benchIOSrc1 "intersectBySorted"
465-
$ joinMapWith Internal.intersectBySorted sqrtVal sqrtVal
463+
$ joinMapWith (Internal.intersectBySorted compare) halfVal
466464
]
467465
]
468466

469467
where
470468

469+
halfVal = value `div` 2
471470
sqrtVal = round $ sqrt (fromIntegral value :: Double)
472471

473472
-------------------------------------------------------------------------------

src/Streamly/Internal/Data/Stream/IsStream/Top.hs

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ module Streamly.Internal.Data.Stream.IsStream.Top
2828
-- | These are not exactly set operations because streams are not
2929
-- necessarily sets, they may have duplicated elements.
3030
, intersectBy
31-
, intersectBySorted
31+
, intersectBySorted
3232
, differenceBy
3333
, mergeDifferenceBy
3434
, unionBy
@@ -65,7 +65,6 @@ import Streamly.Internal.Data.Stream.IsStream.Common (concatM)
6565
import Streamly.Internal.Data.Stream.IsStream.Type
6666
(IsStream(..), adapt, foldl', fromList)
6767
import Streamly.Internal.Data.Stream.Serial (SerialT)
68-
--import Streamly.Internal.Data.Stream.StreamD (fromStreamD, toStreamD)
6968
import Streamly.Internal.Data.Time.Units (NanoSecond64(..), toRelTime64)
7069

7170
import qualified Data.List as List
@@ -576,18 +575,20 @@ intersectBy eq s1 s2 =
576575
xs <- Stream.toListRev $ Stream.uniqBy eq $ adapt s2
577576
return $ Stream.filter (\x -> List.any (eq x) xs) s1
578577

579-
-- | Like 'intersectBy' but works only on sorted streams.
578+
-- | Like 'intersectBy' but works only on streams sorted in ascending order.
580579
--
581580
-- Space: O(1)
582581
--
583582
-- Time: O(m+n)
584583
--
585584
-- /Pre-release/
586585
{-# INLINE intersectBySorted #-}
587-
intersectBySorted :: (IsStream t, MonadIO m, Eq a) =>
586+
intersectBySorted :: (IsStream t, Monad m) =>
588587
(a -> a -> Ordering) -> t m a -> t m a -> t m a
589588
intersectBySorted eq s1 =
590-
IsStream.fromStreamD . StreamD.intersectBySorted eq (IsStream.toStreamD s1) . IsStream.toStreamD
589+
IsStream.fromStreamD
590+
. StreamD.intersectBySorted eq (IsStream.toStreamD s1)
591+
. IsStream.toStreamD
591592

592593
-- Roughly joinLeft s1 s2 = s1 `difference` s2 + s1 `intersection` s2
593594

src/Streamly/Internal/Data/Stream/StreamD/Nesting.hs

Lines changed: 24 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -484,57 +484,47 @@ mergeBy
484484
mergeBy cmp = mergeByM (\a b -> return $ cmp a b)
485485

486486
-------------------------------------------------------------------------------
487-
-- Intersection of sorted streams ---------------------------------------------
487+
-- Intersection of sorted streams
488488
-------------------------------------------------------------------------------
489+
490+
-- Assuming the streams are sorted in ascending order
489491
{-# INLINE_NORMAL intersectBySorted #-}
490-
intersectBySorted
491-
:: (MonadIO m, Eq a)
492+
intersectBySorted :: Monad m
492493
=> (a -> a -> Ordering) -> Stream m a -> Stream m a -> Stream m a
493494
intersectBySorted cmp (Stream stepa ta) (Stream stepb tb) =
494-
Stream step (Just ta, Just tb, Nothing, Nothing, Nothing)
495+
Stream step
496+
( ta -- left stream state
497+
, tb -- right stream state
498+
, Nothing -- left value
499+
, Nothing -- right value
500+
)
495501

496502
where
497-
{-# INLINE_LATE step #-}
498503

499-
-- step 1
500-
step gst (Just sa, sb, Nothing, b, Nothing) = do
504+
{-# INLINE_LATE step #-}
505+
-- step 1, fetch the first value
506+
step gst (sa, sb, Nothing, b) = do
501507
r <- stepa gst sa
502508
return $ case r of
503-
Yield a sa' -> Skip (Just sa', sb, Just a, b, Nothing)
504-
Skip sa' -> Skip (Just sa', sb, Nothing, b, Nothing)
509+
Yield a sa' -> Skip (sa', sb, Just a, b) -- step 2/3
510+
Skip sa' -> Skip (sa', sb, Nothing, b)
505511
Stop -> Stop
506512

507-
-- step 2
508-
step gst (sa, Just sb, a, Nothing, Nothing) = do
513+
-- step 2, fetch the second value
514+
step gst (sa, sb, a@(Just _), Nothing) = do
509515
r <- stepb gst sb
510516
return $ case r of
511-
Yield b sb' -> Skip (sa, Just sb', a, Just b, Nothing)
512-
Skip sb' -> Skip (sa, Just sb', a, Nothing, Nothing)
517+
Yield b sb' -> Skip (sa, sb', a, Just b) -- step 3
518+
Skip sb' -> Skip (sa, sb', a, Nothing)
513519
Stop -> Stop
514520

515-
-- step 3
516-
-- both the values are available compare it
517-
step _ (sa, sb, Just a, Just b, Nothing) = do
521+
-- step 3, compare the two values
522+
step _ (sa, sb, Just a, Just b) = do
518523
let res = cmp a b
519524
return $ case res of
520-
GT -> Skip (sa, sb, Just a, Nothing, Nothing)
521-
LT -> Skip (sa, sb, Nothing, Just b, Nothing)
522-
EQ -> Yield a (sa, sb, Nothing, Just a, Just b) -- step 4
523-
524-
-- step 4
525-
-- Matching element
526-
step gst (Just sa, Just sb, Nothing, Just _, Just b) = do
527-
r1 <- stepa gst sa
528-
return $ case r1 of
529-
Yield a' sa' -> do
530-
if a' == b -- match with prev a
531-
then Yield a' (Just sa', Just sb, Nothing, Just b, Just b) --step 1
532-
else Skip (Just sa', Just sb, Just a', Nothing, Nothing)
533-
534-
Skip sa' -> Skip (Just sa', Just sb, Nothing, Nothing, Nothing)
535-
Stop -> Stop
536-
537-
step _ (_, _, _, _, _) = return Stop
525+
GT -> Skip (sa, sb, Just a, Nothing) -- step 2
526+
LT -> Skip (sa, sb, Nothing, Just b) -- step 1
527+
EQ -> Yield a (sa, sb, Nothing, Just b) -- step 1
538528

539529
------------------------------------------------------------------------------
540530
-- Combine N Streams - unfoldMany

streamly.cabal

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -101,7 +101,7 @@ extra-source-files:
101101
test/Streamly/Test/Data/Array/Prim/Pinned.hs
102102
test/Streamly/Test/Data/Array/Foreign.hs
103103
test/Streamly/Test/Data/Array/Stream/Foreign.hs
104-
test/Streamly/Test/Data/Parser/ParserD.hs
104+
test/Streamly/Test/Data/Parser/ParserD.hs
105105
test/Streamly/Test/FileSystem/Event.hs
106106
test/Streamly/Test/FileSystem/Event/Common.hs
107107
test/Streamly/Test/FileSystem/Event/Darwin.hs

test/Streamly/Test/Prelude/Top.hs

Lines changed: 22 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,8 @@
1-
module Main (main)
2-
where
1+
module Main (main) where
32

43
import Data.List (elem, intersect, nub, sort)
54
import Data.Maybe (isNothing)
5+
import Streamly.Prelude (SerialT)
66
import Test.QuickCheck
77
( Gen
88
, Property
@@ -169,8 +169,16 @@ joinLeftMap =
169169
let v2 = joinLeftList ls0 ls1
170170
assert (v1 == v2)
171171

172-
intersectBy :: Property
173-
intersectBy =
172+
intersectBy ::
173+
([Int] -> [Int])
174+
-> ( (Int -> Int -> a)
175+
-> SerialT IO Int
176+
-> SerialT IO Int
177+
-> SerialT IO Int
178+
)
179+
-> (Int -> Int -> a)
180+
-> Property
181+
intersectBy _srt intersectFunc cmp =
174182
forAll (listOf (chooseInt (min_value, max_value))) $ \ls0 ->
175183
forAll (listOf (chooseInt (min_value, max_value))) $ \ls1 ->
176184
monadicIO $ action (sort ls0) (sort ls1)
@@ -181,41 +189,24 @@ intersectBy =
181189
v1 <-
182190
run
183191
$ S.toList
184-
$ Top.intersectBy
185-
(==)
192+
$ intersectFunc
193+
cmp
186194
(S.fromList ls0)
187195
(S.fromList ls1)
188-
let v2 = intersect ls0 ls1
189-
assert (v1 == sort v2)
190-
191-
intersectBySorted :: Property
192-
intersectBySorted =
193-
forAll (listOf (chooseInt (min_value, max_value))) $ \ls0 ->
194-
forAll (listOf (chooseInt (min_value, max_value))) $ \ls1 ->
195-
monadicIO $ action (sort ls0) (sort ls1)
196-
197-
where
198-
199-
action ls0 ls1 = do
200-
v1 <-
201-
run
202-
$ S.toList
203-
$ Top.intersectBySorted
204-
compare
205-
(S.fromList ls0)
206-
(S.fromList ls1)
207-
let v2 = intersect ls0 ls1
196+
let v2 = ls0 `intersect` ls1
208197
assert (v1 == sort v2)
209198

210199
-------------------------------------------------------------------------------
200+
-- Main
201+
-------------------------------------------------------------------------------
202+
211203
moduleName :: String
212204
moduleName = "Prelude.Top"
213205

214206
main :: IO ()
215207
main = hspec $ do
216208
describe moduleName $ do
217209
-- Joins
218-
219210
prop "joinInner" Main.joinInner
220211
prop "joinInnerMap" Main.joinInnerMap
221212
-- XXX currently API is broken https://github.com/composewell/streamly/issues/1032
@@ -224,5 +215,7 @@ main = hspec $ do
224215
prop "joinLeft" Main.joinLeft
225216
prop "joinLeftMap" Main.joinLeftMap
226217
-- intersect
227-
prop "intersectBy" Main.intersectBy
228-
prop "intersectBySorted" Main.intersectBySorted
218+
-- XXX currently API is broken https://github.com/composewell/streamly/issues/1471
219+
--prop "intersectBy" (intersectBy id Top.intersectBy (==))
220+
prop "intersectBySorted"
221+
(intersectBy sort Top.intersectBySorted compare)

0 commit comments

Comments
 (0)