Skip to content

Commit a5b9133

Browse files
Fix/simplify intersectBySorted
* Remove MonadIO and Eq constraints
* Simplify implementation
* Simplify tests
* Fix formatting
* Use longer benchmarks
1 parent 836b0f5 commit a5b9133

File tree

5 files changed, with 85 additions and 97 deletions.

benchmark/Streamly/Benchmark/Prelude/Serial/NestedStream.hs

Lines changed: 33 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -417,54 +417,58 @@ o_n_space_monad value =
417417
-- Joining
418418
-------------------------------------------------------------------------------
419419

420-
toKvMap :: Int -> (Int, Int)
421-
toKvMap p = (p, p)
420+
toKv :: Int -> (Int, Int)
421+
toKv p = (p, p)
422422

423-
mkStreamLen :: (S.IsStream t, S.MonadAsync m) => Int -> t m Int
424-
mkStreamLen count = sourceUnfoldrM count 0
423+
mkStream1 :: (S.IsStream t, S.MonadAsync m) => Int -> t m Int
424+
mkStream1 count = sourceUnfoldrM count 0
425+
426+
mkStream2 :: (S.IsStream t, S.MonadAsync m) => Int -> t m Int
427+
mkStream2 count = sourceUnfoldrM count (count `div` 2)
425428

426429
{-# INLINE joinInner #-}
427-
joinInner :: Int -> Int -> Int -> IO ()
428-
joinInner val1 val2 _ =
429-
S.drain $ Internal.joinInner (==) (mkStreamLen val1) $ mkStreamLen val2
430+
joinInner :: Int -> Int -> IO ()
431+
joinInner val _ =
432+
S.drain $ Internal.joinInner (==) (mkStream1 val) (mkStream2 val)
430433

431434
{-# INLINE joinInnerMap #-}
432-
joinInnerMap :: Int -> Int -> Int -> IO ()
433-
joinInnerMap val1 val2 _ =
434-
S.drain $
435-
Internal.joinInnerMap
436-
(fmap toKvMap (mkStreamLen val1))
437-
(fmap toKvMap (mkStreamLen val2))
435+
joinInnerMap :: Int -> Int -> IO ()
436+
joinInnerMap val _ =
437+
S.drain $
438+
Internal.joinInnerMap
439+
(fmap toKv (mkStream1 val))
440+
(fmap toKv (mkStream2 val))
438441

439442
{-# INLINE intersectBy #-}
440-
intersectBy :: Int -> Int -> Int -> IO ()
441-
intersectBy val1 val2 _ =
442-
S.drain $
443-
Internal.intersectBy (==)
444-
(fmap toKvMap (mkStreamLen val1))
445-
(fmap toKvMap (mkStreamLen val2))
443+
intersectBy :: Int -> Int -> IO ()
444+
intersectBy val _ =
445+
S.drain $
446+
Internal.intersectBy (==)
447+
(fmap toKv (mkStream1 val))
448+
(fmap toKv (mkStream2 val))
446449

447450
{-# INLINE intersectBySorted #-}
448-
intersectBySorted :: Int -> Int -> Int -> IO ()
449-
intersectBySorted val1 val2 _ =
450-
S.drain $
451-
Internal.intersectBySorted compare
452-
(fmap toKvMap (mkStreamLen val1))
453-
(fmap toKvMap (mkStreamLen val2))
451+
intersectBySorted :: Int -> Int -> IO ()
452+
intersectBySorted val _ =
453+
S.drain $
454+
Internal.intersectBySorted compare
455+
(fmap toKv (mkStream1 val))
456+
(fmap toKv (mkStream2 val))
454457

455458
o_n_heap_buffering :: Int -> [Benchmark]
456459
o_n_heap_buffering value =
457460
[ bgroup "buffered"
458461
[
459-
benchIOSrc1 "joinInner" (joinInner sqrtVal sqrtVal)
460-
, benchIOSrc1 "joinInnerMap" (joinInnerMap sqrtVal sqrtVal)
461-
, benchIOSrc1 "intersectBy" (intersectBy sqrtVal sqrtVal)
462-
, benchIOSrc1 "intersectBySorted" (intersectBySorted sqrtVal sqrtVal)
462+
benchIOSrc1 "joinInner (sqrtVal)" (joinInner sqrtVal)
463+
, benchIOSrc1 "joinInnerMap" (joinInnerMap halfVal)
464+
, benchIOSrc1 "intersectBy" (intersectBy halfVal)
465+
, benchIOSrc1 "intersectBySorted" (intersectBySorted halfVal)
463466
]
464467
]
465468

466469
where
467470

471+
halfVal = value `div` 2
468472
sqrtVal = round $ sqrt (fromIntegral value :: Double)
469473

470474
-------------------------------------------------------------------------------

src/Streamly/Internal/Data/Stream/IsStream/Top.hs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ module Streamly.Internal.Data.Stream.IsStream.Top
2828
-- | These are not exactly set operations because streams are not
2929
-- necessarily sets, they may have duplicated elements.
3030
, intersectBy
31-
, intersectBySorted
31+
, intersectBySorted
3232
, differenceBy
3333
, mergeDifferenceBy
3434
, unionBy
@@ -65,7 +65,6 @@ import Streamly.Internal.Data.Stream.IsStream.Common (concatM)
6565
import Streamly.Internal.Data.Stream.IsStream.Type
6666
(IsStream(..), adapt, foldl', fromList)
6767
import Streamly.Internal.Data.Stream.Serial (SerialT)
68-
--import Streamly.Internal.Data.Stream.StreamD (fromStreamD, toStreamD)
6968
import Streamly.Internal.Data.Time.Units (NanoSecond64(..), toRelTime64)
7069

7170
import qualified Data.List as List
@@ -536,18 +535,20 @@ intersectBy eq s1 s2 =
536535
xs <- Stream.toListRev $ Stream.uniqBy eq $ adapt s2
537536
return $ Stream.filter (\x -> List.any (eq x) xs) s1
538537

539-
-- | Like 'intersectBy' but works only on sorted streams.
538+
-- | Like 'intersectBy' but works only on streams sorted in ascending order.
540539
--
541540
-- Space: O(1)
542541
--
543542
-- Time: O(m+n)
544543
--
545544
-- /Pre-release/
546545
{-# INLINE intersectBySorted #-}
547-
intersectBySorted :: (IsStream t, MonadIO m, Eq a) =>
546+
intersectBySorted :: (IsStream t, Monad m) =>
548547
(a -> a -> Ordering) -> t m a -> t m a -> t m a
549548
intersectBySorted eq s1 =
550-
IsStream.fromStreamD . StreamD.intersectBySorted eq (IsStream.toStreamD s1) . IsStream.toStreamD
549+
IsStream.fromStreamD
550+
. StreamD.intersectBySorted eq (IsStream.toStreamD s1)
551+
. IsStream.toStreamD
551552

552553
-- Roughly leftJoin s1 s2 = s1 `difference` s2 + s1 `intersection` s2
553554

src/Streamly/Internal/Data/Stream/StreamD/Nesting.hs

Lines changed: 24 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -484,57 +484,47 @@ mergeBy
484484
mergeBy cmp = mergeByM (\a b -> return $ cmp a b)
485485

486486
-------------------------------------------------------------------------------
487-
-- Intersection of sorted streams ---------------------------------------------
487+
-- Intersection of sorted streams
488488
-------------------------------------------------------------------------------
489+
490+
-- Assuming the streams are sorted in ascending order
489491
{-# INLINE_NORMAL intersectBySorted #-}
490-
intersectBySorted
491-
:: (MonadIO m, Eq a)
492+
intersectBySorted :: Monad m
492493
=> (a -> a -> Ordering) -> Stream m a -> Stream m a -> Stream m a
493494
intersectBySorted cmp (Stream stepa ta) (Stream stepb tb) =
494-
Stream step (Just ta, Just tb, Nothing, Nothing, Nothing)
495+
Stream step
496+
( ta -- left stream state
497+
, tb -- right stream state
498+
, Nothing -- left value
499+
, Nothing -- right value
500+
)
495501

496502
where
497-
{-# INLINE_LATE step #-}
498503

499-
-- step 1
500-
step gst (Just sa, sb, Nothing, b, Nothing) = do
504+
{-# INLINE_LATE step #-}
505+
-- step 1, fetch the first value
506+
step gst (sa, sb, Nothing, b) = do
501507
r <- stepa gst sa
502508
return $ case r of
503-
Yield a sa' -> Skip (Just sa', sb, Just a, b, Nothing)
504-
Skip sa' -> Skip (Just sa', sb, Nothing, b, Nothing)
509+
Yield a sa' -> Skip (sa', sb, Just a, b) -- step 2/3
510+
Skip sa' -> Skip (sa', sb, Nothing, b)
505511
Stop -> Stop
506512

507-
-- step 2
508-
step gst (sa, Just sb, a, Nothing, Nothing) = do
513+
-- step 2, fetch the second value
514+
step gst (sa, sb, a@(Just _), Nothing) = do
509515
r <- stepb gst sb
510516
return $ case r of
511-
Yield b sb' -> Skip (sa, Just sb', a, Just b, Nothing)
512-
Skip sb' -> Skip (sa, Just sb', a, Nothing, Nothing)
517+
Yield b sb' -> Skip (sa, sb', a, Just b) -- step 3
518+
Skip sb' -> Skip (sa, sb', a, Nothing)
513519
Stop -> Stop
514520

515-
-- step 3
516-
-- both the values are available compare it
517-
step _ (sa, sb, Just a, Just b, Nothing) = do
521+
-- step 3, compare the two values
522+
step _ (sa, sb, Just a, Just b) = do
518523
let res = cmp a b
519524
return $ case res of
520-
GT -> Skip (sa, sb, Just a, Nothing, Nothing)
521-
LT -> Skip (sa, sb, Nothing, Just b, Nothing)
522-
EQ -> Yield a (sa, sb, Nothing, Just a, Just b) -- step 4
523-
524-
-- step 4
525-
-- Matching element
526-
step gst (Just sa, Just sb, Nothing, Just _, Just b) = do
527-
r1 <- stepa gst sa
528-
return $ case r1 of
529-
Yield a' sa' -> do
530-
if a' == b -- match with prev a
531-
then Yield a' (Just sa', Just sb, Nothing, Just b, Just b) --step 1
532-
else Skip (Just sa', Just sb, Just a', Nothing, Nothing)
533-
534-
Skip sa' -> Skip (Just sa', Just sb, Nothing, Nothing, Nothing)
535-
Stop -> Stop
536-
537-
step _ (_, _, _, _, _) = return Stop
525+
GT -> Skip (sa, sb, Just a, Nothing) -- step 2
526+
LT -> Skip (sa, sb, Nothing, Just b) -- step 1
527+
EQ -> Yield a (sa, sb, Nothing, Just b) -- step 1
538528

539529
------------------------------------------------------------------------------
540530
-- Combine N Streams - unfoldMany

streamly.cabal

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ extra-source-files:
101101
test/Streamly/Test/Data/Array/Prim/Pinned.hs
102102
test/Streamly/Test/Data/Array/Foreign.hs
103103
test/Streamly/Test/Data/Array/Stream/Foreign.hs
104-
test/Streamly/Test/Data/Parser/ParserD.hs
104+
test/Streamly/Test/Data/Parser/ParserD.hs
105105
test/Streamly/Test/FileSystem/Event.hs
106106
test/Streamly/Test/FileSystem/Event/Common.hs
107107
test/Streamly/Test/FileSystem/Event/Darwin.hs

test/Streamly/Test/Prelude/Top.hs

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
module Main (main)
2-
where
1+
module Main (main) where
32

43
import Data.List (intersect, sort)
4+
import Streamly.Prelude (SerialT)
55
import Test.QuickCheck
66
( Gen
77
, Property
@@ -66,8 +66,16 @@ joinInnerMap =
6666
]
6767
assert (v1 == v2)
6868

69-
intersectBy :: Property
70-
intersectBy =
69+
intersectBy ::
70+
([Int] -> [Int])
71+
-> ( (Int -> Int -> a)
72+
-> SerialT IO Int
73+
-> SerialT IO Int
74+
-> SerialT IO Int
75+
)
76+
-> (Int -> Int -> a)
77+
-> Property
78+
intersectBy srt intersectFunc cmp =
7179
forAll (listOf (chooseInt (min_value, max_value))) $ \ls0 ->
7280
forAll (listOf (chooseInt (min_value, max_value))) $ \ls1 ->
7381
monadicIO $ action (sort ls0) (sort ls1)
@@ -78,43 +86,28 @@ intersectBy =
7886
v1 <-
7987
run
8088
$ S.toList
81-
$ Top.intersectBy
82-
(==)
83-
(S.fromList ls0)
84-
(S.fromList ls1)
85-
let v2 = intersect ls0 ls1
86-
assert (v1 == sort v2)
87-
88-
intersectBySorted :: Property
89-
intersectBySorted =
90-
forAll (listOf (chooseInt (min_value, max_value))) $ \ls0 ->
91-
forAll (listOf (chooseInt (min_value, max_value))) $ \ls1 ->
92-
monadicIO $ action (sort ls0) (sort ls1)
93-
94-
where
95-
96-
action ls0 ls1 = do
97-
v1 <-
98-
run
99-
$ S.toList
100-
$ Top.intersectBySorted
101-
compare
89+
$ intersectFunc
90+
cmp
10291
(S.fromList ls0)
10392
(S.fromList ls1)
10493
let v2 = intersect ls0 ls1
10594
assert (v1 == sort v2)
10695

10796
-------------------------------------------------------------------------------
97+
-- Main
98+
-------------------------------------------------------------------------------
99+
108100
moduleName :: String
109101
moduleName = "Prelude.Top"
110102

111103
main :: IO ()
112104
main = hspec $ do
113105
describe moduleName $ do
114106
-- Joins
115-
116107
prop "joinInner" Main.joinInner
117108
prop "joinInnerMap" Main.joinInnerMap
109+
118110
-- intersect
119-
prop "intersectBy" Main.intersectBy
120-
prop "intersectBySorted" Main.intersectBySorted
111+
prop "intersectBy" (intersectBy id Top.intersectBy (==))
112+
prop "intersectBySorted"
113+
(intersectBy sort Top.intersectBySorted compare)

0 commit comments

Comments
 (0)