From 66e8a549f45cbdc5097b1798283c2720b4ea6e65 Mon Sep 17 00:00:00 2001 From: Gautier DI FOLCO Date: Thu, 18 Sep 2025 19:24:31 +0200 Subject: [PATCH 01/18] WPB-19712: Allow team admin to update the channels to user-group association --- changelog.d/2-features/WPB-19713 | 1 + integration/test/API/Brig.hs | 5 +++ integration/test/Test/UserGroup.hs | 3 +- .../src/Wire/API/Routes/Public/Brig.hs | 5 +-- .../src/Wire/UserGroupStore.hs | 1 + .../src/Wire/UserGroupStore/Postgres.hs | 23 +++++++++++++ .../src/Wire/UserGroupSubsystem.hs | 1 + .../Wire/UserGroupSubsystem/Interpreter.hs | 15 +++++++++ .../Wire/MockInterpreters/UserGroupStore.hs | 11 ++++++- postgres-schema.sql | 11 +++++++ services/brig/brig.cabal | 3 +- services/brig/default.nix | 1 + services/brig/src/Brig/API/Public.hs | 33 +++++++++++++++++-- 13 files changed, 105 insertions(+), 8 deletions(-) create mode 100644 changelog.d/2-features/WPB-19713 diff --git a/changelog.d/2-features/WPB-19713 b/changelog.d/2-features/WPB-19713 new file mode 100644 index 0000000000..cdbf1fd8a2 --- /dev/null +++ b/changelog.d/2-features/WPB-19713 @@ -0,0 +1 @@ +Implement `channels` and `channelsCount` in `user-groups` endpoints. diff --git a/integration/test/API/Brig.hs b/integration/test/API/Brig.hs index bce2531ada..9963e05709 100644 --- a/integration/test/API/Brig.hs +++ b/integration/test/API/Brig.hs @@ -1061,6 +1061,11 @@ getUserGroup user gid = do req <- baseRequest user Brig Versioned $ joinHttpPath ["user-groups", gid] submit "GET" req +getUserGroupWithChannels :: (MakesValue user) => user -> String -> App Response +getUserGroupWithChannels user gid = do + req <- baseRequest user Brig Versioned $ joinHttpPath ["user-groups", gid] + submit "GET" $ req & addQueryParams [("include_channels", "true")] + updateUserGroupChannels :: (MakesValue user) => user -> String -> [String] -> App Response updateUserGroupChannels user gid convIds = do req <- baseRequest user Brig Versioned $ joinHttpPath ["user-groups", gid, "channels"] diff --git a/integration/test/Test/UserGroup.hs b/integration/test/Test/UserGroup.hs index 6b3fae9674..9cab119809 100644 --- a/integration/test/Test/UserGroup.hs +++ b/integration/test/Test/UserGroup.hs @@ -423,8 +423,9 @@ testUserGroupUpdateChannels = do notif %. "payload.0.user_group.id" `shouldMatch` gid -- bobId <- asString $ bob %. "id" - bindResponse (getUserGroup alice gid) $ \resp -> do + bindResponse (getUserGroupWithChannels alice gid) $ \resp -> do resp.status `shouldMatchInt` 200 + resp.json %. "channels" `shouldMatch` [object ["id" .= convId.id_, "domain" .= convId.domain]] -- FUTUREWORK: check the actual associated channels -- resp.json %. "members" `shouldMatch` [bobId] diff --git a/libs/wire-api/src/Wire/API/Routes/Public/Brig.hs b/libs/wire-api/src/Wire/API/Routes/Public/Brig.hs index fc8e1ea258..a81fcd5496 100644 --- a/libs/wire-api/src/Wire/API/Routes/Public/Brig.hs +++ b/libs/wire-api/src/Wire/API/Routes/Public/Brig.hs @@ -314,7 +314,7 @@ type UserGroupAPI = ) :<|> Named "get-user-group" - ( Summary "[STUB] (channels in response not implemented)" + ( Summary "Fetch a group accessible from the logged-in user" :> From 'V10 :> ZLocalUser :> CanThrow 'UserGroupNotFound @@ -331,7 +331,7 @@ type UserGroupAPI = ) :<|> Named "get-user-groups" - ( Summary "[STUB] (channelsCount not implemented)" + ( Summary "Fetch groups accessible from the logged-in user" :> From 'V10 :> ZLocalUser :> "user-groups" @@ -342,6 +342,7 @@ type UserGroupAPI = :> QueryParam' '[Optional, Strict, LastSeenNameDesc] "last_seen_name" UserGroupName :> QueryParam' '[Optional, Strict, LastSeenCreatedAtDesc] "last_seen_created_at" UTCTimeMillis :> QueryParam' '[Optional, Strict, LastSeenIdDesc] "last_seen_id" UserGroupId + :> QueryFlag "include_channels" :> QueryFlag "include_member_count" :> Get '[JSON] UserGroupPage ) diff --git a/libs/wire-subsystems/src/Wire/UserGroupStore.hs b/libs/wire-subsystems/src/Wire/UserGroupStore.hs index 2136b2a60a..2e31839ce9 100644 --- a/libs/wire-subsystems/src/Wire/UserGroupStore.hs +++ b/libs/wire-subsystems/src/Wire/UserGroupStore.hs @@ -41,5 +41,6 @@ data UserGroupStore m a where UpdateUsers :: UserGroupId -> Vector UserId -> UserGroupStore m () RemoveUser :: UserGroupId -> UserId -> UserGroupStore m () UpdateUserGroupChannels :: UserGroupId -> Vector ConvId -> UserGroupStore m () + ListUserGroupChannels :: UserGroupId -> UserGroupStore m (Vector ConvId) makeSem ''UserGroupStore diff --git a/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs b/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs index 3249df272e..48d49bdab7 100644 --- a/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs +++ b/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs @@ -54,6 +54,7 @@ interpretUserGroupStoreToPostgres = UpdateUsers gid uids -> updateUsers gid uids RemoveUser gid uid -> removeUser gid uid UpdateUserGroupChannels gid convIds -> updateUserGroupChannels gid convIds + ListUserGroupChannels gid -> listUserGroupChannels gid updateUsers :: (UserGroupStorePostgresEffectConstraints r) => UserGroupId -> Vector UserId -> Sem r () updateUsers gid uids = do @@ -441,6 +442,28 @@ updateUserGroupChannels gid convIds = do on conflict (user_group_id, conv_id) do nothing |] +listUserGroupChannels :: + forall r. + (UserGroupStorePostgresEffectConstraints r) => + UserGroupId -> + Sem r (Vector ConvId) +listUserGroupChannels gid = do + pool <- input + eitherErrorOrUnit <- liftIO $ use pool session + either throw pure eitherErrorOrUnit + where + session :: Session (Vector ConvId) + session = statement gid selectStatement + + selectStatement :: Statement UserGroupId (Vector ConvId) + selectStatement = + dimap + toUUID + (fmap Id) + [vectorStatement| + select (conv_id :: uuid) from user_group_channel where user_group_id = ($1 :: uuid) + |] + crudUser :: forall r. (UserGroupStorePostgresEffectConstraints r) => diff --git a/libs/wire-subsystems/src/Wire/UserGroupSubsystem.hs b/libs/wire-subsystems/src/Wire/UserGroupSubsystem.hs index 16df18d5b5..1bfebb3705 100644 --- a/libs/wire-subsystems/src/Wire/UserGroupSubsystem.hs +++ b/libs/wire-subsystems/src/Wire/UserGroupSubsystem.hs @@ -34,5 +34,6 @@ data UserGroupSubsystem m a where RemoveUser :: UserId -> UserGroupId -> UserId -> UserGroupSubsystem m () RemoveUserFromAllGroups :: UserId -> TeamId -> UserGroupSubsystem m () UpdateChannels :: UserId -> UserGroupId -> Vector ConvId -> UserGroupSubsystem m () + ListChannels :: UserId -> UserGroupId -> UserGroupSubsystem m (Vector ConvId) makeSem ''UserGroupSubsystem diff --git a/libs/wire-subsystems/src/Wire/UserGroupSubsystem/Interpreter.hs b/libs/wire-subsystems/src/Wire/UserGroupSubsystem/Interpreter.hs index 5eb9936add..d275718f3b 100644 --- a/libs/wire-subsystems/src/Wire/UserGroupSubsystem/Interpreter.hs +++ b/libs/wire-subsystems/src/Wire/UserGroupSubsystem/Interpreter.hs @@ -55,6 +55,7 @@ interpretUserGroupSubsystem = interpret $ \case RemoveUser remover groupId removeeId -> removeUser remover groupId removeeId RemoveUserFromAllGroups uid tid -> removeUserFromAllGroups uid tid UpdateChannels performer groupId channelIds -> updateChannels performer groupId channelIds + ListChannels performer groupId -> listChannels performer groupId data UserGroupSubsystemError = UserGroupNotATeamAdmin @@ -397,3 +398,17 @@ updateChannels performer groupId channelIds = do pushNotifications [ mkEvent performer (UserGroupUpdated groupId) admins ] + +listChannels :: + ( Member UserSubsystem r, + Member Store.UserGroupStore r, + Member (Error UserGroupSubsystemError) r, + Member TeamSubsystem r + ) => + UserId -> + UserGroupId -> + Sem r (Vector ConvId) +listChannels performer groupId = do + void $ getUserGroup performer groupId >>= note UserGroupNotFound + void $ getUserTeam performer >>= note UserGroupNotATeamAdmin + Store.listUserGroupChannels groupId diff --git a/libs/wire-subsystems/test/unit/Wire/MockInterpreters/UserGroupStore.hs b/libs/wire-subsystems/test/unit/Wire/MockInterpreters/UserGroupStore.hs index 935197ffe2..59c901df94 100644 --- a/libs/wire-subsystems/test/unit/Wire/MockInterpreters/UserGroupStore.hs +++ b/libs/wire-subsystems/test/unit/Wire/MockInterpreters/UserGroupStore.hs @@ -11,7 +11,7 @@ import Data.Domain (Domain (Domain)) import Data.Id import Data.Json.Util import Data.Map qualified as Map -import Data.Qualified (Qualified (Qualified)) +import Data.Qualified import Data.Text qualified as T import Data.Time.Clock import Data.Vector (Vector, fromList) @@ -65,6 +65,7 @@ userGroupStoreTestInterpreter = UpdateUsers gid uids -> updateUsersImpl gid uids RemoveUser gid uid -> removeUserImpl gid uid UpdateUserGroupChannels gid convIds -> updateUserGroupChannelsImpl gid convIds + ListUserGroupChannels gid -> listUserGroupChannelsImpl gid updateUsersImpl :: (UserGroupStoreInMemEffectConstraints r) => UserGroupId -> Vector UserId -> Sem r () updateUsersImpl gid uids = do @@ -201,6 +202,14 @@ updateUserGroupChannelsImpl gid convIds = do modifyUserGroupsGidOnly gid (Map.alter f) +listUserGroupChannelsImpl :: + (UserGroupStoreInMemEffectConstraints r) => + UserGroupId -> + Sem r (Vector ConvId) +listUserGroupChannelsImpl gid = + foldMap (fmap qUnqualified) . (runIdentity . (.channels) . snd <=< find ((== gid) . snd . fst) . Map.toList) + <$> get @(Map (TeamId, UserGroupId) UserGroup) + ---------------------------------------------------------------------- modifyUserGroupsGidOnly :: diff --git a/postgres-schema.sql b/postgres-schema.sql index 9075fe312f..70443f919e 100644 --- a/postgres-schema.sql +++ b/postgres-schema.sql @@ -104,6 +104,17 @@ CREATE TABLE public.user_group_channel ( ); + +-- +-- Name: user_group_channel; Type: TABLE; Schema: public; Owner: wire-server +-- + +CREATE TABLE public.user_group_channel ( + user_group_id uuid NOT NULL, + conv_id uuid NOT NULL +); + + ALTER TABLE public.user_group_channel OWNER TO "wire-server"; -- diff --git a/services/brig/brig.cabal b/services/brig/brig.cabal index 67449f7bb0..4d921d809b 100644 --- a/services/brig/brig.cabal +++ b/services/brig/brig.cabal @@ -218,7 +218,7 @@ library , amqp , async >=2.1 , auto-update >=0.1 - , base >=4 && <5 + , base >=4 && <5 , base-prelude , base16-bytestring >=0.1 , base64-bytestring >=1.0 @@ -314,6 +314,7 @@ library , uri-bytestring >=0.2 , utf8-string , uuid >=1.3.5 + , vector >=0.13.2.0 , wai >=3.0 , wai-extra >=3.0 , wai-middleware-gunzip >=0.0.2 diff --git a/services/brig/default.nix b/services/brig/default.nix index 550ed4212b..5827352d74 100644 --- a/services/brig/default.nix +++ b/services/brig/default.nix @@ -266,6 +266,7 @@ mkDerivation { uri-bytestring utf8-string uuid + vector wai wai-extra wai-middleware-gunzip diff --git a/services/brig/src/Brig/API/Public.hs b/services/brig/src/Brig/API/Public.hs index 7ca8f6b1a0..53fc7562b7 100644 --- a/services/brig/src/Brig/API/Public.hs +++ b/services/brig/src/Brig/API/Public.hs @@ -85,6 +85,7 @@ import Data.Qualified import Data.Range import Data.Schema () import Data.Text.Encoding qualified as Text +import Data.Vector qualified as Vector import Data.ZAuth.CryptoSign (CryptoSign) import Data.ZAuth.Token qualified as ZAuth import FileEmbedLzma @@ -1678,7 +1679,20 @@ createUserGroup :: (_) => Local UserId -> NewUserGroup -> Handler r UserGroup createUserGroup lusr newUserGroup = lift . liftSem $ UserGroup.createGroup (tUnqualified lusr) newUserGroup getUserGroup :: (_) => Local UserId -> UserGroupId -> Bool -> Handler r (Maybe UserGroup) -getUserGroup lusr ugid _ = lift . liftSem $ UserGroup.getGroup (tUnqualified lusr) ugid +getUserGroup lusr ugid includeChannels = + lift . liftSem $ do + mUserGroup <- UserGroup.getGroup (tUnqualified lusr) ugid + if includeChannels + then forM mUserGroup $ \userGroup -> do + fetchedChannels <- + fmap (tUntagged . qualifyAs lusr) + <$> UserGroup.listChannels (tUnqualified lusr) userGroup.id_ + pure + userGroup + { channels = Identity $ Just fetchedChannels, + channelsCount = Just $ Vector.length fetchedChannels + } + else pure mUserGroup getUserGroups :: (_) => @@ -1691,9 +1705,22 @@ getUserGroups :: Maybe UTCTimeMillis -> Maybe UserGroupId -> Bool -> + Bool -> Handler r UserGroupPage -getUserGroups lusr q sortByKeys sortOrder pSize mLastName mLastCreatedAt mLastId includeMemberCount = - lift . liftSem $ UserGroup.getGroups (tUnqualified lusr) q sortByKeys sortOrder pSize mLastName mLastCreatedAt mLastId includeMemberCount +getUserGroups lusr q sortByKeys sortOrder pSize mLastName mLastCreatedAt mLastId includeChannels includeMemberCount = + lift . liftSem $ do + userGroups <- UserGroup.getGroups (tUnqualified lusr) q sortByKeys sortOrder pSize mLastName mLastCreatedAt mLastId includeMemberCount + if includeChannels + then do + newPage <- + forM userGroups.page $ \userGroup -> do + fetchedChannels <- UserGroup.listChannels (tUnqualified lusr) userGroup.id_ + pure + userGroup + { channelsCount = Just $ Vector.length fetchedChannels + } + pure userGroups {page = newPage} + else pure userGroups updateUserGroup :: (_) => Local UserId -> UserGroupId -> UserGroupUpdate -> (Handler r) () updateUserGroup lusr gid gupd = lift . liftSem $ UserGroup.updateGroup (tUnqualified lusr) gid gupd From f48d12e5a318e8e543ef210b330b795acc23597b Mon Sep 17 00:00:00 2001 From: Gautier DI FOLCO Date: Tue, 7 Oct 2025 10:46:03 +0200 Subject: [PATCH 02/18] Update postgres-schema.sql Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- postgres-schema.sql | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/postgres-schema.sql b/postgres-schema.sql index 70443f919e..5c3fea345a 100644 --- a/postgres-schema.sql +++ b/postgres-schema.sql @@ -105,16 +105,6 @@ CREATE TABLE public.user_group_channel ( --- --- Name: user_group_channel; Type: TABLE; Schema: public; Owner: wire-server --- - -CREATE TABLE public.user_group_channel ( - user_group_id uuid NOT NULL, - conv_id uuid NOT NULL -); - - ALTER TABLE public.user_group_channel OWNER TO "wire-server"; -- From 7b1396004a2949534df8eda7185097118bcd67ed Mon Sep 17 00:00:00 2001 From: Leif Battermann Date: Tue, 7 Oct 2025 11:31:21 +0000 Subject: [PATCH 03/18] wip move include channels to store --- .../src/Wire/UserGroupStore.hs | 6 +-- .../src/Wire/UserGroupStore/Postgres.hs | 8 ++-- .../src/Wire/UserGroupSubsystem.hs | 4 +- .../Wire/UserGroupSubsystem/Interpreter.hs | 42 +++++++------------ services/brig/src/Brig/API/Public.hs | 29 +------------ 5 files changed, 27 insertions(+), 62 deletions(-) diff --git a/libs/wire-subsystems/src/Wire/UserGroupStore.hs b/libs/wire-subsystems/src/Wire/UserGroupStore.hs index 2e31839ce9..ad973bd7a2 100644 --- a/libs/wire-subsystems/src/Wire/UserGroupStore.hs +++ b/libs/wire-subsystems/src/Wire/UserGroupStore.hs @@ -18,7 +18,8 @@ data UserGroupPageRequest = UserGroupPageRequest paginationState :: PaginationState, sortOrder :: SortOrder, pageSize :: PageSize, - includeMemberCount :: Bool + includeMemberCount :: Bool, + includeChannels :: Bool } data PaginationState = PaginationSortByName (Maybe (UserGroupName, UserGroupId)) | PaginationSortByCreatedAt (Maybe (UTCTimeMillis, UserGroupId)) @@ -33,7 +34,7 @@ toSortBy = \case data UserGroupStore m a where CreateUserGroup :: TeamId -> NewUserGroup -> ManagedBy -> UserGroupStore m UserGroup - GetUserGroup :: TeamId -> UserGroupId -> UserGroupStore m (Maybe UserGroup) + GetUserGroup :: TeamId -> UserGroupId -> Bool -> UserGroupStore m (Maybe UserGroup) GetUserGroups :: UserGroupPageRequest -> UserGroupStore m UserGroupPage UpdateUserGroup :: TeamId -> UserGroupId -> UserGroupUpdate -> UserGroupStore m (Maybe ()) DeleteUserGroup :: TeamId -> UserGroupId -> UserGroupStore m (Maybe ()) @@ -41,6 +42,5 @@ data UserGroupStore m a where UpdateUsers :: UserGroupId -> Vector UserId -> UserGroupStore m () RemoveUser :: UserGroupId -> UserId -> UserGroupStore m () UpdateUserGroupChannels :: UserGroupId -> Vector ConvId -> UserGroupStore m () - ListUserGroupChannels :: UserGroupId -> UserGroupStore m (Vector ConvId) makeSem ''UserGroupStore diff --git a/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs b/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs index 48d49bdab7..87d70034de 100644 --- a/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs +++ b/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs @@ -46,7 +46,7 @@ interpretUserGroupStoreToPostgres :: interpretUserGroupStoreToPostgres = interpret $ \case CreateUserGroup team newUserGroup managedBy -> createUserGroup team newUserGroup managedBy - GetUserGroup team userGroupId -> getUserGroup team userGroupId + GetUserGroup team userGroupId includeChannels -> getUserGroup team userGroupId includeChannels GetUserGroups req -> getUserGroups req UpdateUserGroup tid gid gup -> updateGroup tid gid gup DeleteUserGroup tid gid -> deleteGroup tid gid @@ -54,7 +54,6 @@ interpretUserGroupStoreToPostgres = UpdateUsers gid uids -> updateUsers gid uids RemoveUser gid uid -> removeUser gid uid UpdateUserGroupChannels gid convIds -> updateUserGroupChannels gid convIds - ListUserGroupChannels gid -> listUserGroupChannels gid updateUsers :: (UserGroupStorePostgresEffectConstraints r) => UserGroupId -> Vector UserId -> Sem r () updateUsers gid uids = do @@ -79,8 +78,10 @@ getUserGroup :: (UserGroupStorePostgresEffectConstraints r) => TeamId -> UserGroupId -> + Bool -> Sem r (Maybe UserGroup) -getUserGroup team id_ = do +getUserGroup team id_ includeChannels = do + todo "implement includeChannels" includeChannels pool <- input eitherUserGroup <- liftIO $ use pool session either throw pure eitherUserGroup @@ -131,6 +132,7 @@ getUserGroups :: UserGroupPageRequest -> Sem r UserGroupPage getUserGroups req@(UserGroupPageRequest {..}) = do + todo "implement includeChannels" includeChannels pool <- input eitherResult <- liftIO $ use pool do TxSessions.transaction TxSessions.ReadCommitted TxSessions.Read do diff --git a/libs/wire-subsystems/src/Wire/UserGroupSubsystem.hs b/libs/wire-subsystems/src/Wire/UserGroupSubsystem.hs index 1bfebb3705..f50ab4046a 100644 --- a/libs/wire-subsystems/src/Wire/UserGroupSubsystem.hs +++ b/libs/wire-subsystems/src/Wire/UserGroupSubsystem.hs @@ -14,7 +14,7 @@ import Wire.API.UserGroup.Pagination data UserGroupSubsystem m a where CreateGroup :: UserId -> NewUserGroup -> UserGroupSubsystem m UserGroup - GetGroup :: UserId -> UserGroupId -> UserGroupSubsystem m (Maybe UserGroup) + GetGroup :: UserId -> UserGroupId -> Bool -> UserGroupSubsystem m (Maybe UserGroup) GetGroups :: UserId -> Maybe Text -> @@ -25,6 +25,7 @@ data UserGroupSubsystem m a where Maybe UTCTimeMillis -> Maybe UserGroupId -> Bool -> + Bool -> UserGroupSubsystem m UserGroupPage UpdateGroup :: UserId -> UserGroupId -> UserGroupUpdate -> UserGroupSubsystem m () DeleteGroup :: UserId -> UserGroupId -> UserGroupSubsystem m () @@ -34,6 +35,5 @@ data UserGroupSubsystem m a where RemoveUser :: UserId -> UserGroupId -> UserId -> UserGroupSubsystem m () RemoveUserFromAllGroups :: UserId -> TeamId -> UserGroupSubsystem m () UpdateChannels :: UserId -> UserGroupId -> Vector ConvId -> UserGroupSubsystem m () - ListChannels :: UserId -> UserGroupId -> UserGroupSubsystem m (Vector ConvId) makeSem ''UserGroupSubsystem diff --git a/libs/wire-subsystems/src/Wire/UserGroupSubsystem/Interpreter.hs b/libs/wire-subsystems/src/Wire/UserGroupSubsystem/Interpreter.hs index d275718f3b..50b34416d1 100644 --- a/libs/wire-subsystems/src/Wire/UserGroupSubsystem/Interpreter.hs +++ b/libs/wire-subsystems/src/Wire/UserGroupSubsystem/Interpreter.hs @@ -44,9 +44,9 @@ interpretUserGroupSubsystem :: InterpreterFor UserGroupSubsystem r interpretUserGroupSubsystem = interpret $ \case CreateGroup creator newGroup -> createUserGroup creator newGroup - GetGroup getter gid -> getUserGroup getter gid - GetGroups getter q sortByKeys sortOrder pSize mLastGroupName mLastCreatedAt mLastGroupId includeMemberCount -> - getUserGroups getter q sortByKeys sortOrder pSize mLastGroupName mLastCreatedAt mLastGroupId includeMemberCount + GetGroup getter gid includeChannels -> getUserGroup getter gid includeChannels + GetGroups getter q sortByKeys sortOrder pSize mLastGroupName mLastCreatedAt mLastGroupId includeMemberCount includeChannels -> + getUserGroups getter q sortByKeys sortOrder pSize mLastGroupName mLastCreatedAt mLastGroupId includeMemberCount includeChannels UpdateGroup updater groupId groupUpdate -> updateGroup updater groupId groupUpdate DeleteGroup deleter groupId -> deleteGroup deleter groupId AddUser adder groupId addeeId -> addUser adder groupId addeeId @@ -55,7 +55,6 @@ interpretUserGroupSubsystem = interpret $ \case RemoveUser remover groupId removeeId -> removeUser remover groupId removeeId RemoveUserFromAllGroups uid tid -> removeUserFromAllGroups uid tid UpdateChannels performer groupId channelIds -> updateChannels performer groupId channelIds - ListChannels performer groupId -> listChannels performer groupId data UserGroupSubsystemError = UserGroupNotATeamAdmin @@ -144,11 +143,12 @@ getUserGroup :: ) => UserId -> UserGroupId -> + Bool -> Sem r (Maybe UserGroup) -getUserGroup getter gid = runMaybeT $ do +getUserGroup getter gid includeChannels = runMaybeT $ do team <- MaybeT $ getUserTeam getter getterCanSeeAll <- mkGetterCanSeeAll getter team - userGroup <- MaybeT $ Store.getUserGroup team gid + userGroup <- MaybeT $ Store.getUserGroup team gid includeChannels if getterCanSeeAll || getter `elem` (toList (runIdentity userGroup.members)) then pure userGroup else MaybeT $ pure Nothing @@ -179,8 +179,9 @@ getUserGroups :: Maybe UTCTimeMillis -> Maybe UserGroupId -> Bool -> + Bool -> Sem r UserGroupPage -getUserGroups getter searchString sortBy' sortOrder' mPageSize mLastGroupName mLastCreatedAt mLastGroupId includeMemberCount' = do +getUserGroups getter searchString sortBy' sortOrder' mPageSize mLastGroupName mLastCreatedAt mLastGroupId includeMemberCount' includeChannels' = do team :: TeamId <- getUserTeam getter >>= ifNothing UserGroupNotATeamAdmin getterCanSeeAll :: Bool <- fromMaybe False <$> runMaybeT (mkGetterCanSeeAll getter team) unless getterCanSeeAll (throw UserGroupNotATeamAdmin) @@ -193,7 +194,8 @@ getUserGroups getter searchString sortBy' sortOrder' mPageSize mLastGroupName mL SortByCreatedAt -> PaginationSortByCreatedAt $ (,) <$> mLastCreatedAt <*> mLastGroupId, team = team, searchString = searchString, - includeMemberCount = includeMemberCount' + includeMemberCount = includeMemberCount', + includeChannels = includeChannels' } Store.getUserGroups pageReq where @@ -258,7 +260,7 @@ addUser :: UserId -> Sem r () addUser adder groupId addeeId = do - ug <- getUserGroup adder groupId >>= note UserGroupNotFound + ug <- getUserGroup adder groupId False >>= note UserGroupNotFound team <- getTeamAsAdmin adder >>= note UserGroupNotATeamAdmin void $ internalGetTeamMember addeeId team >>= note UserGroupMemberIsNotInTheSameTeam unless (addeeId `elem` runIdentity ug.members) $ do @@ -280,7 +282,7 @@ addUsers :: Vector UserId -> Sem r () addUsers adder groupId addeeIds = do - ug <- getUserGroup adder groupId >>= note UserGroupNotFound + ug <- getUserGroup adder groupId False >>= note UserGroupNotFound team <- getTeamAsAdmin adder >>= note UserGroupNotATeamAdmin forM_ addeeIds $ \addeeId -> internalGetTeamMember addeeId team >>= note UserGroupMemberIsNotInTheSameTeam @@ -305,7 +307,7 @@ updateUsers :: Vector UserId -> Sem r () updateUsers updater groupId uids = do - void $ getUserGroup updater groupId >>= note UserGroupNotFound + void $ getUserGroup updater groupId False >>= note UserGroupNotFound team <- getTeamAsAdmin updater >>= note UserGroupNotATeamAdmin forM_ uids $ \uid -> internalGetTeamMember uid team >>= note UserGroupMemberIsNotInTheSameTeam @@ -327,7 +329,7 @@ removeUser :: UserId -> Sem r () removeUser remover groupId removeeId = do - ug <- getUserGroup remover groupId >>= note UserGroupNotFound + ug <- getUserGroup remover groupId False >>= note UserGroupNotFound team <- getTeamAsAdmin remover >>= note UserGroupNotATeamAdmin void $ internalGetTeamMember removeeId team >>= note UserGroupMemberIsNotInTheSameTeam when (removeeId `elem` runIdentity ug.members) $ do @@ -385,7 +387,7 @@ updateChannels :: Vector ConvId -> Sem r () updateChannels performer groupId channelIds = do - void $ getUserGroup performer groupId >>= note UserGroupNotFound + void $ getUserGroup performer groupId False >>= note UserGroupNotFound teamId <- getTeamAsAdmin performer >>= note UserGroupNotATeamAdmin for_ channelIds $ \channelId -> do conv <- internalGetConversation channelId >>= note UserGroupChannelNotFound @@ -398,17 +400,3 @@ updateChannels performer groupId channelIds = do pushNotifications [ mkEvent performer (UserGroupUpdated groupId) admins ] - -listChannels :: - ( Member UserSubsystem r, - Member Store.UserGroupStore r, - Member (Error UserGroupSubsystemError) r, - Member TeamSubsystem r - ) => - UserId -> - UserGroupId -> - Sem r (Vector ConvId) -listChannels performer groupId = do - void $ getUserGroup performer groupId >>= note UserGroupNotFound - void $ getUserTeam performer >>= note UserGroupNotATeamAdmin - Store.listUserGroupChannels groupId diff --git a/services/brig/src/Brig/API/Public.hs b/services/brig/src/Brig/API/Public.hs index 53fc7562b7..96f7a8a637 100644 --- a/services/brig/src/Brig/API/Public.hs +++ b/services/brig/src/Brig/API/Public.hs @@ -85,7 +85,6 @@ import Data.Qualified import Data.Range import Data.Schema () import Data.Text.Encoding qualified as Text -import Data.Vector qualified as Vector import Data.ZAuth.CryptoSign (CryptoSign) import Data.ZAuth.Token qualified as ZAuth import FileEmbedLzma @@ -1680,19 +1679,7 @@ createUserGroup lusr newUserGroup = lift . liftSem $ UserGroup.createGroup (tUnq getUserGroup :: (_) => Local UserId -> UserGroupId -> Bool -> Handler r (Maybe UserGroup) getUserGroup lusr ugid includeChannels = - lift . liftSem $ do - mUserGroup <- UserGroup.getGroup (tUnqualified lusr) ugid - if includeChannels - then forM mUserGroup $ \userGroup -> do - fetchedChannels <- - fmap (tUntagged . qualifyAs lusr) - <$> UserGroup.listChannels (tUnqualified lusr) userGroup.id_ - pure - userGroup - { channels = Identity $ Just fetchedChannels, - channelsCount = Just $ Vector.length fetchedChannels - } - else pure mUserGroup + lift . liftSem $ UserGroup.getGroup (tUnqualified lusr) ugid includeChannels getUserGroups :: (_) => @@ -1708,19 +1695,7 @@ getUserGroups :: Bool -> Handler r UserGroupPage getUserGroups lusr q sortByKeys sortOrder pSize mLastName mLastCreatedAt mLastId includeChannels includeMemberCount = - lift . liftSem $ do - userGroups <- UserGroup.getGroups (tUnqualified lusr) q sortByKeys sortOrder pSize mLastName mLastCreatedAt mLastId includeMemberCount - if includeChannels - then do - newPage <- - forM userGroups.page $ \userGroup -> do - fetchedChannels <- UserGroup.listChannels (tUnqualified lusr) userGroup.id_ - pure - userGroup - { channelsCount = Just $ Vector.length fetchedChannels - } - pure userGroups {page = newPage} - else pure userGroups + lift . liftSem $ UserGroup.getGroups (tUnqualified lusr) q sortByKeys sortOrder pSize mLastName mLastCreatedAt mLastId includeMemberCount includeChannels updateUserGroup :: (_) => Local UserId -> UserGroupId -> UserGroupUpdate -> (Handler r) () updateUserGroup lusr gid gupd = lift . liftSem $ UserGroup.updateGroup (tUnqualified lusr) gid gupd From 07769d8d97194941b5d27c05d36b08c8615b6117 Mon Sep 17 00:00:00 2001 From: Leif Battermann Date: Tue, 7 Oct 2025 11:55:50 +0000 Subject: [PATCH 04/18] wip get channels impl --- .../src/Wire/UserGroupStore/Postgres.hs | 47 +++++++++++++++---- 1 file changed, 39 insertions(+), 8 deletions(-) diff --git a/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs b/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs index 87d70034de..be188c257c 100644 --- a/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs +++ b/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs @@ -15,6 +15,7 @@ import Data.Text.Encoding qualified as TE import Data.Time import Data.UUID as UUID import Data.Vector (Vector) +import Data.Vector qualified as V import Hasql.Decoders qualified as HD import Hasql.Encoders qualified as HE import Hasql.Pool @@ -81,9 +82,8 @@ getUserGroup :: Bool -> Sem r (Maybe UserGroup) getUserGroup team id_ includeChannels = do - todo "implement includeChannels" includeChannels pool <- input - eitherUserGroup <- liftIO $ use pool session + eitherUserGroup <- liftIO $ use pool (if includeChannels then sessionWithChannels else session) either throw pure eitherUserGroup where session :: Session (Maybe UserGroup) @@ -95,12 +95,27 @@ getUserGroup team id_ includeChannels = do channels = mempty pure $ UserGroup_ {..} + sessionWithChannels :: Session (Maybe UserGroup) + sessionWithChannels = runMaybeT do + (name, managedBy, createdAt, memberIds, channelIds) <- MaybeT $ statement (id_, team) getGroupWithMembersAndChannelsStatement + let members = Identity (fmap Id memberIds) + membersCount = Just (fromIntegral (V.length memberIds)) + channels = Identity (Just (fmap (todo "qualify channel" . Id) channelIds)) + channelsCount = Just (fromIntegral (V.length channelIds)) + pure $ UserGroup_ {..} + decodeMetadataRow :: (Text, Int32, UTCTime) -> Either Text (UserGroupName, ManagedBy, UTCTimeMillis) decodeMetadataRow (name, managedByInt, utcTime) = (,,toUTCTimeMillis utcTime) <$> userGroupNameFromText name <*> managedByFromInt32 managedByInt + decodeWithArrays :: (Text, Int32, UTCTime, Vector UUID, Vector UUID) -> Either Text (UserGroupName, ManagedBy, UTCTimeMillis, Vector UUID, Vector UUID) + decodeWithArrays (name, managedByInt, utcTime, membs, chans) = do + n <- userGroupNameFromText name + m <- managedByFromInt32 managedByInt + pure (n, m, toUTCTimeMillis utcTime, membs, chans) + getGroupMetadataStatement :: Statement (UserGroupId, TeamId) (Maybe (UserGroupName, ManagedBy, UTCTimeMillis)) getGroupMetadataStatement = lmap (\(gid, tid) -> (gid.toUUID, tid.toUUID)) @@ -117,6 +132,21 @@ getUserGroup team id_ includeChannels = do select (user_id :: uuid) from user_group_member where user_group_id = ($1 :: uuid) |] + getGroupWithMembersAndChannelsStatement :: Statement (UserGroupId, TeamId) (Maybe (UserGroupName, ManagedBy, UTCTimeMillis, Vector UUID, Vector UUID)) + getGroupWithMembersAndChannelsStatement = + lmap (\(gid, tid) -> (gid.toUUID, tid.toUUID)) + . refineResult (mapM decodeWithArrays) + $ [maybeStatement| + select + (name :: text), + (managed_by :: int), + (created_at :: timestamptz), + coalesce((select array_agg(ugm.user_id) from user_group_member ugm where ugm.user_group_id = ug.id), array[]::uuid[]) :: uuid[], + coalesce((select array_agg(ugc.conv_id) from user_group_channel ugc where ugc.user_group_id = ug.id), array[]::uuid[]) :: uuid[] + from user_group ug + where ug.id = ($1 :: uuid) and ug.team_id = ($2 :: uuid) + |] + divide3 :: (Divisible f) => (p -> (a, b, c)) -> f a -> f b -> f c -> f p divide3 f a b c = divide (\p -> let (x, y, z) = f p in (x, (y, z))) a (divide id b c) @@ -132,7 +162,6 @@ getUserGroups :: UserGroupPageRequest -> Sem r UserGroupPage getUserGroups req@(UserGroupPageRequest {..}) = do - todo "implement includeChannels" includeChannels pool <- input eitherResult <- liftIO $ use pool do TxSessions.transaction TxSessions.ReadCommitted TxSessions.Read do @@ -221,19 +250,20 @@ getUserGroups req@(UserGroupPageRequest {..}) = do encodeTime :: HE.Params UTCTimeMillis encodeTime = contramap fromUTCTimeMillis $ HE.param $ HE.nonNullable HE.timestamptz - decodeRow :: HD.Result [(UUID, Text, Int32, UTCTime, Maybe Int32)] + decodeRow :: HD.Result [(UUID, Text, Int32, UTCTime, Maybe Int32, Maybe Int32)] decodeRow = HD.rowList - ( (,,,,) + ( (,,,,,) <$> HD.column (HD.nonNullable HD.uuid) <*> HD.column (HD.nonNullable HD.text) <*> HD.column (HD.nonNullable HD.int4) <*> HD.column (HD.nonNullable HD.timestamptz) <*> (if req.includeMemberCount then Just <$> HD.column (HD.nonNullable HD.int4) else pure Nothing) + <*> (if req.includeChannels then Just <$> HD.column (HD.nonNullable HD.int4) else pure Nothing) ) - parseRow :: (UUID, Text, Int32, UTCTime, Maybe Int32) -> Either Text UserGroupMeta - parseRow (Id -> id_, namePre, managedByPre, toUTCTimeMillis -> createdAt, membersCountRaw) = do + parseRow :: (UUID, Text, Int32, UTCTime, Maybe Int32, Maybe Int32) -> Either Text UserGroupMeta + parseRow (Id -> id_, namePre, managedByPre, toUTCTimeMillis -> createdAt, membersCountRaw, channelsCountRaw) = do managedBy <- case managedByPre of 0 -> pure ManagedByWire 1 -> pure ManagedByScim @@ -241,7 +271,7 @@ getUserGroups req@(UserGroupPageRequest {..}) = do name <- userGroupNameFromText namePre let members = Const () membersCount = fromIntegral <$> membersCountRaw - channelsCount = Nothing + channelsCount = fromIntegral <$> channelsCountRaw channels = mempty pure $ UserGroup_ {..} @@ -257,6 +287,7 @@ getUserGroups req@(UserGroupPageRequest {..}) = do filter (not . T.null) $ ["id", "name", "managed_by", "created_at"] <> ["(select count(*) from user_group_member as ugm where ugm.user_group_id = ug.id) as members" | includeMemberCount] + <> ["(select count(*) from user_group_channel as ugc where ugc.user_group_id = ug.id) as channels" | includeChannels] whr = "where team_id = ($1 :: uuid)" sortColumn = toSortBy paginationState orderBy = T.unwords ["order by", sortColumnName sortColumn, sortOrderClause sortOrder <> ", id", sortOrderClause sortOrder] From 4831bda9e9502984cf9b32db61ef7d9d9ddf4a45 Mon Sep 17 00:00:00 2001 From: Leif Battermann Date: Tue, 7 Oct 2025 12:11:43 +0000 Subject: [PATCH 05/18] remove unused function --- .../src/Wire/UserGroupStore/Postgres.hs | 22 ------------------- 1 file changed, 22 deletions(-) diff --git a/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs b/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs index be188c257c..1bc9ab1050 100644 --- a/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs +++ b/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs @@ -475,28 +475,6 @@ updateUserGroupChannels gid convIds = do on conflict (user_group_id, conv_id) do nothing |] -listUserGroupChannels :: - forall r. - (UserGroupStorePostgresEffectConstraints r) => - UserGroupId -> - Sem r (Vector ConvId) -listUserGroupChannels gid = do - pool <- input - eitherErrorOrUnit <- liftIO $ use pool session - either throw pure eitherErrorOrUnit - where - session :: Session (Vector ConvId) - session = statement gid selectStatement - - selectStatement :: Statement UserGroupId (Vector ConvId) - selectStatement = - dimap - toUUID - (fmap Id) - [vectorStatement| - select (conv_id :: uuid) from user_group_channel where user_group_id = ($1 :: uuid) - |] - crudUser :: forall r. (UserGroupStorePostgresEffectConstraints r) => From 9d497d3d5bc93049fbd403321390cf9042d38bd5 Mon Sep 17 00:00:00 2001 From: Leif Battermann Date: Tue, 7 Oct 2025 13:00:38 +0000 Subject: [PATCH 06/18] qualify channel ids, update tests --- .../src/Wire/UserGroupStore/Postgres.hs | 51 +++++++++++++----- .../Wire/MockInterpreters/UserGroupStore.hs | 11 ++-- .../UserGroupSubsystem/InterpreterSpec.hs | 54 +++++++++---------- 3 files changed, 70 insertions(+), 46 deletions(-) diff --git a/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs b/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs index 1bc9ab1050..fad1b1e541 100644 --- a/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs +++ b/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs @@ -9,6 +9,7 @@ import Data.Functor.Contravariant.Divisible import Data.Id import Data.Json.Util import Data.Profunctor +import Data.Qualified (Local, QualifiedWithTag (tUntagged), qualifyAs) import Data.Range import Data.Text qualified as T import Data.Text.Encoding qualified as TE @@ -42,7 +43,7 @@ type UserGroupStorePostgresEffectConstraints r = interpretUserGroupStoreToPostgres :: forall r. - (UserGroupStorePostgresEffectConstraints r) => + (UserGroupStorePostgresEffectConstraints r, Member (Input (Local ())) r) => InterpreterFor UserGroupStore r interpretUserGroupStoreToPostgres = interpret $ \case @@ -74,33 +75,41 @@ updateUsers gid uids = do delete from user_group_member where user_group_id = ($1 :: uuid) |] +-- TODO: move to a shared place +qualifyLocal :: (Member (Input (Local ())) r) => a -> Sem r (Local a) +qualifyLocal a = do + l <- input + pure $ qualifyAs l a + getUserGroup :: forall r. - (UserGroupStorePostgresEffectConstraints r) => + (UserGroupStorePostgresEffectConstraints r, Member (Input (Local ())) r) => TeamId -> UserGroupId -> Bool -> Sem r (Maybe UserGroup) getUserGroup team id_ includeChannels = do pool <- input - eitherUserGroup <- liftIO $ use pool (if includeChannels then sessionWithChannels else session) + loc <- qualifyLocal () + eitherUserGroup <- liftIO $ use pool (if includeChannels then sessionWithChannels loc else session) either throw pure eitherUserGroup where session :: Session (Maybe UserGroup) session = runMaybeT do (name, managedBy, createdAt) <- MaybeT $ statement (id_, team) getGroupMetadataStatement members <- lift $ Identity <$> statement id_ getGroupMembersStatement + -- TODO: add counts let membersCount = Nothing channelsCount = Nothing channels = mempty pure $ UserGroup_ {..} - sessionWithChannels :: Session (Maybe UserGroup) - sessionWithChannels = runMaybeT do + sessionWithChannels :: Local a -> Session (Maybe UserGroup) + sessionWithChannels loc = runMaybeT do (name, managedBy, createdAt, memberIds, channelIds) <- MaybeT $ statement (id_, team) getGroupWithMembersAndChannelsStatement let members = Identity (fmap Id memberIds) membersCount = Just (fromIntegral (V.length memberIds)) - channels = Identity (Just (fmap (todo "qualify channel" . Id) channelIds)) + channels = Identity (Just (fmap (tUntagged . qualifyAs loc . Id) channelIds)) channelsCount = Just (fromIntegral (V.length channelIds)) pure $ UserGroup_ {..} @@ -250,20 +259,34 @@ getUserGroups req@(UserGroupPageRequest {..}) = do encodeTime :: HE.Params UTCTimeMillis encodeTime = contramap fromUTCTimeMillis $ HE.param $ HE.nonNullable HE.timestamptz - decodeRow :: HD.Result [(UUID, Text, Int32, UTCTime, Maybe Int32, Maybe Int32)] + decodeRow :: HD.Result [(UUID, Text, Int32, UTCTime, Maybe Int32, Int32, Maybe (Vector UUID))] decodeRow = HD.rowList - ( (,,,,,) + ( (,,,,,,) <$> HD.column (HD.nonNullable HD.uuid) <*> HD.column (HD.nonNullable HD.text) <*> HD.column (HD.nonNullable HD.int4) <*> HD.column (HD.nonNullable HD.timestamptz) <*> (if req.includeMemberCount then Just <$> HD.column (HD.nonNullable HD.int4) else pure Nothing) - <*> (if req.includeChannels then Just <$> HD.column (HD.nonNullable HD.int4) else pure Nothing) + <*> HD.column (HD.nonNullable HD.int4) + <*> ( if req.includeChannels + then + Just + <$> HD.column + ( HD.nonNullable + ( HD.array + ( HD.dimension + V.replicateM + (HD.element (HD.nonNullable HD.uuid)) + ) + ) + ) + else pure Nothing + ) ) - parseRow :: (UUID, Text, Int32, UTCTime, Maybe Int32, Maybe Int32) -> Either Text UserGroupMeta - parseRow (Id -> id_, namePre, managedByPre, toUTCTimeMillis -> createdAt, membersCountRaw, channelsCountRaw) = do + parseRow :: (UUID, Text, Int32, UTCTime, Maybe Int32, Int32, Maybe (Vector UUID)) -> Either Text UserGroupMeta + parseRow (Id -> id_, namePre, managedByPre, toUTCTimeMillis -> createdAt, membersCountRaw, channelsCountRaw, _maybeChannels) = do managedBy <- case managedByPre of 0 -> pure ManagedByWire 1 -> pure ManagedByScim @@ -271,7 +294,8 @@ getUserGroups req@(UserGroupPageRequest {..}) = do name <- userGroupNameFromText namePre let members = Const () membersCount = fromIntegral <$> membersCountRaw - channelsCount = fromIntegral <$> channelsCountRaw + channelsCount = Just (fromIntegral channelsCountRaw) + -- TODO: process channels channels = mempty pure $ UserGroup_ {..} @@ -287,7 +311,8 @@ getUserGroups req@(UserGroupPageRequest {..}) = do filter (not . T.null) $ ["id", "name", "managed_by", "created_at"] <> ["(select count(*) from user_group_member as ugm where ugm.user_group_id = ug.id) as members" | includeMemberCount] - <> ["(select count(*) from user_group_channel as ugc where ugc.user_group_id = ug.id) as channels" | includeChannels] + <> ["(select count(*) from user_group_channel as ugc where ugc.user_group_id = ug.id) as channels"] + <> ["coalesce((select array_agg(ugc.conv_id) from user_group_channel as ugc where ugc.user_group_id = ug.id), array[]::uuid[]) as channel_ids" | includeChannels] whr = "where team_id = ($1 :: uuid)" sortColumn = toSortBy paginationState orderBy = T.unwords ["order by", sortColumnName sortColumn, sortOrderClause sortOrder <> ", id", sortOrderClause sortOrder] diff --git a/libs/wire-subsystems/test/unit/Wire/MockInterpreters/UserGroupStore.hs b/libs/wire-subsystems/test/unit/Wire/MockInterpreters/UserGroupStore.hs index 59c901df94..eac892b614 100644 --- a/libs/wire-subsystems/test/unit/Wire/MockInterpreters/UserGroupStore.hs +++ b/libs/wire-subsystems/test/unit/Wire/MockInterpreters/UserGroupStore.hs @@ -57,7 +57,7 @@ userGroupStoreTestInterpreter :: (UserGroupStoreInMemEffectConstraints r) => Int userGroupStoreTestInterpreter = interpret $ \case CreateUserGroup tid ng mb -> createUserGroupImpl tid ng mb - GetUserGroup tid gid -> getUserGroupImpl tid gid + GetUserGroup tid gid includeChannels -> getUserGroupImpl tid gid includeChannels GetUserGroups req -> getUserGroupsImpl req UpdateUserGroup tid gid gup -> updateUserGroupImpl tid gid gup DeleteUserGroup tid gid -> deleteUserGroupImpl tid gid @@ -65,7 +65,6 @@ userGroupStoreTestInterpreter = UpdateUsers gid uids -> updateUsersImpl gid uids RemoveUser gid uid -> removeUserImpl gid uid UpdateUserGroupChannels gid convIds -> updateUserGroupChannelsImpl gid convIds - ListUserGroupChannels gid -> listUserGroupChannelsImpl gid updateUsersImpl :: (UserGroupStoreInMemEffectConstraints r) => UserGroupId -> Vector UserId -> Sem r () updateUsersImpl gid uids = do @@ -94,8 +93,8 @@ createUserGroupImpl tid nug managedBy = do modify (Map.insert (tid, gid) ug) pure ug -getUserGroupImpl :: (UserGroupStoreInMemEffectConstraints r) => TeamId -> UserGroupId -> Sem r (Maybe UserGroup) -getUserGroupImpl tid gid = (Map.lookup (tid, gid)) <$> get @UserGroupInMemState +getUserGroupImpl :: (UserGroupStoreInMemEffectConstraints r) => TeamId -> UserGroupId -> Bool -> Sem r (Maybe UserGroup) +getUserGroupImpl tid gid _includeChannels = (Map.lookup (tid, gid)) <$> get @UserGroupInMemState getUserGroupsImpl :: (UserGroupStoreInMemEffectConstraints r) => UserGroupPageRequest -> Sem r UserGroupPage getUserGroupsImpl UserGroupPageRequest {..} = do @@ -153,7 +152,7 @@ getUserGroupsImpl UserGroupPageRequest {..} = do updateUserGroupImpl :: (UserGroupStoreInMemEffectConstraints r) => TeamId -> UserGroupId -> UserGroupUpdate -> Sem r (Maybe ()) updateUserGroupImpl tid gid (UserGroupUpdate newName) = do - exists <- getUserGroupImpl tid gid + exists <- getUserGroupImpl tid gid False let f :: Maybe UserGroup -> Maybe UserGroup f Nothing = Nothing f (Just g) = Just (g {name = newName} :: UserGroup) @@ -163,7 +162,7 @@ updateUserGroupImpl tid gid (UserGroupUpdate newName) = do deleteUserGroupImpl :: (UserGroupStoreInMemEffectConstraints r) => TeamId -> UserGroupId -> Sem r (Maybe ()) deleteUserGroupImpl tid gid = do - exists <- getUserGroupImpl tid gid + exists <- getUserGroupImpl tid gid False modify (Map.delete (tid, gid)) pure $ exists $> () diff --git a/libs/wire-subsystems/test/unit/Wire/UserGroupSubsystem/InterpreterSpec.hs b/libs/wire-subsystems/test/unit/Wire/UserGroupSubsystem/InterpreterSpec.hs index 77fd7e3031..e01f2500e8 100644 --- a/libs/wire-subsystems/test/unit/Wire/UserGroupSubsystem/InterpreterSpec.hs +++ b/libs/wire-subsystems/test/unit/Wire/UserGroupSubsystem/InterpreterSpec.hs @@ -138,7 +138,7 @@ spec = timeoutHook $ describe "UserGroupSubsystem.Interpreter" do members = User.userId <$> V.fromList members } createdGroup <- createGroup (ownerId team) newUserGroup - retrievedGroup <- getGroup (ownerId team) createdGroup.id_ + retrievedGroup <- getGroup (ownerId team) createdGroup.id_ False now <- toUTCTimeMillis <$> get let assert = createdGroup.name === newUserGroupName @@ -238,7 +238,7 @@ spec = timeoutHook $ describe "UserGroupSubsystem.Interpreter" do . runDependencies (allUsers team) (galleyTeam team) . interpretUserGroupSubsystem $ do - mGroup <- getGroup (ownerId team) groupId + mGroup <- getGroup (ownerId team) groupId False pure $ mGroup === Nothing prop "team admins can get all groups in their team; outsiders can see nothing" $ \team otherTeam userGroupName -> @@ -253,11 +253,11 @@ spec = timeoutHook $ describe "UserGroupSubsystem.Interpreter" do } group1 <- createGroup (ownerId team) newUserGroup - getGroupAdmin <- getGroup (ownerId team) group1.id_ - getGroupOutsider <- getGroup (ownerId otherTeam) group1.id_ + getGroupAdmin <- getGroup (ownerId team) group1.id_ False + getGroupOutsider <- getGroup (ownerId otherTeam) group1.id_ False - getGroupsAdmin <- getGroups (ownerId team) (Just (userGroupNameToText userGroupName)) Nothing Nothing Nothing Nothing Nothing Nothing False - getGroupsOutsider <- try $ getGroups (ownerId otherTeam) (Just (userGroupNameToText userGroupName)) Nothing Nothing Nothing Nothing Nothing Nothing False + getGroupsAdmin <- getGroups (ownerId team) (Just (userGroupNameToText userGroupName)) Nothing Nothing Nothing Nothing Nothing Nothing False False + getGroupsOutsider <- try $ getGroups (ownerId otherTeam) (Just (userGroupNameToText userGroupName)) Nothing Nothing Nothing Nothing Nothing Nothing False False pure $ getGroupAdmin === Just group1 @@ -288,10 +288,10 @@ spec = timeoutHook $ describe "UserGroupSubsystem.Interpreter" do group1 <- createGroup (ownerId team1) newUserGroup1 group2 <- createGroup (ownerId team2) newUserGroup2 - getOwnGroup <- getGroup (ownerId team1) group1.id_ - getOtherGroup <- getGroup (ownerId team1) group2.id_ - getOwnGroups <- getGroups (ownerId team1) (Just (userGroupNameToText userGroupName1)) Nothing Nothing Nothing Nothing Nothing Nothing False - getOtherGroups <- getGroups (ownerId team1) (Just (userGroupNameToText userGroupName2)) Nothing Nothing Nothing Nothing Nothing Nothing False + getOwnGroup <- getGroup (ownerId team1) group1.id_ False + getOtherGroup <- getGroup (ownerId team1) group2.id_ False + getOwnGroups <- getGroups (ownerId team1) (Just (userGroupNameToText userGroupName1)) Nothing Nothing Nothing Nothing Nothing Nothing False False + getOtherGroups <- getGroups (ownerId team1) (Just (userGroupNameToText userGroupName2)) Nothing Nothing Nothing Nothing Nothing Nothing False False pure $ getOwnGroup === Just group1 @@ -305,10 +305,10 @@ spec = timeoutHook $ describe "UserGroupSubsystem.Interpreter" do let newGroups = [NewUserGroup (either undefined id $ userGroupNameFromText name) mempty | name <- ["1", "2", "2", "33"]] groups <- (\ng -> passTime 1 >> createGroup (ownerId team1) ng) `mapM` newGroups - get0 <- getGroups (ownerId team1) (Just "nope") Nothing Nothing Nothing Nothing Nothing Nothing False - get1 <- getGroups (ownerId team1) (Just "1") Nothing Nothing Nothing Nothing Nothing Nothing False - get2 <- getGroups (ownerId team1) (Just "2") Nothing Nothing Nothing Nothing Nothing Nothing False - get3 <- getGroups (ownerId team1) (Just "3") Nothing Nothing Nothing Nothing Nothing Nothing False + get0 <- getGroups (ownerId team1) (Just "nope") Nothing Nothing Nothing Nothing Nothing Nothing False False + get1 <- getGroups (ownerId team1) (Just "1") Nothing Nothing Nothing Nothing Nothing Nothing False False + get2 <- getGroups (ownerId team1) (Just "2") Nothing Nothing Nothing Nothing Nothing Nothing False False + get3 <- getGroups (ownerId team1) (Just "3") Nothing Nothing Nothing Nothing Nothing Nothing False False pure do get0.page `shouldBe` [] @@ -336,7 +336,7 @@ spec = timeoutHook $ describe "UserGroupSubsystem.Interpreter" do results :: [UserGroupPage] <- do let fetch mLastName mLastCreatedAt mLastGroupId = do - p <- getGroups (ownerId team1) Nothing (Just SortByCreatedAt) Nothing (Just pageSize) mLastName mLastCreatedAt mLastGroupId False + p <- getGroups (ownerId team1) Nothing (Just SortByCreatedAt) Nothing (Just pageSize) mLastName mLastCreatedAt mLastGroupId False False if length p.page < pageSizeToInt pageSize then pure [p] else do @@ -377,9 +377,9 @@ spec = timeoutHook $ describe "UserGroupSubsystem.Interpreter" do group1b <- mkGroup "1" group3b <- mkGroup "3" - sortByDefaults <- getGroups (ownerId team1) Nothing Nothing Nothing Nothing Nothing Nothing Nothing False - sortByNameDesc <- getGroups (ownerId team1) Nothing (Just SortByName) (Just Desc) Nothing Nothing Nothing Nothing False - sortByCreatedAtAsc <- getGroups (ownerId team1) Nothing (Just SortByCreatedAt) (Just Asc) Nothing Nothing Nothing Nothing False + sortByDefaults <- getGroups (ownerId team1) Nothing Nothing Nothing Nothing Nothing Nothing Nothing False False + sortByNameDesc <- getGroups (ownerId team1) Nothing (Just SortByName) (Just Desc) Nothing Nothing Nothing Nothing False False + sortByCreatedAtAsc <- getGroups (ownerId team1) Nothing (Just SortByCreatedAt) (Just Asc) Nothing Nothing Nothing Nothing False False let expectSortByDefaults = [[group1b, group2b, group3b], [group1a, group2a, group3a]] expectSortByNameDesc = [[group3a, group3b], [group2a, group2b], [group1a, group1b]] @@ -409,9 +409,9 @@ spec = timeoutHook $ describe "UserGroupSubsystem.Interpreter" do . interpretUserGroupSubsystem $ do ug0 :: UserGroup <- createGroup (ownerId team) (NewUserGroup originalName mempty) - ug1 :: Maybe UserGroup <- getGroup (ownerId team) ug0.id_ + ug1 :: Maybe UserGroup <- getGroup (ownerId team) ug0.id_ False updateGroup (ownerId team) ug0.id_ userGroupUpdate - ug2 :: Maybe UserGroup <- getGroup (ownerId team) ug0.id_ + ug2 :: Maybe UserGroup <- getGroup (ownerId team) ug0.id_ False pure $ (ug1 === Just ug0) .&&. (ug2 === Just (ug0 {name = userGroupUpdate.name} :: UserGroup)) @@ -475,9 +475,9 @@ spec = timeoutHook $ describe "UserGroupSubsystem.Interpreter" do ug <- createGroup (ownerId team) (NewUserGroup name mempty) ug2 <- createGroup (ownerId team) (NewUserGroup name2 mempty) - mUg <- getGroup (ownerId team) ug.id_ - isDeleted <- isNothing <$> (deleteGroup (ownerId team) ug.id_ >> getGroup (ownerId team) ug.id_) - mUg2 <- getGroup (ownerId team) ug2.id_ + mUg <- getGroup (ownerId team) ug.id_ False + isDeleted <- isNothing <$> (deleteGroup (ownerId team) ug.id_ >> getGroup (ownerId team) ug.id_ False) + mUg2 <- getGroup (ownerId team) ug2.id_ False e1 <- catchExpectedError $ deleteGroup (ownerId team2) ug.id_ e2 <- catchExpectedError $ deleteGroup (ownerId team) (Id UUID.nil) @@ -544,16 +544,16 @@ spec = timeoutHook $ describe "UserGroupSubsystem.Interpreter" do ug :: UserGroup <- createGroup (ownerId team) (NewUserGroup newGroupName mempty) addUser (ownerId team) ug.id_ (User.userId mbr1) - ugWithFirst <- getGroup (ownerId team) ug.id_ + ugWithFirst <- getGroup (ownerId team) ug.id_ False addUser (ownerId team) ug.id_ (User.userId mbr1) - ugWithIdemP <- getGroup (ownerId team) ug.id_ + ugWithIdemP <- getGroup (ownerId team) ug.id_ False addUser (ownerId team) ug.id_ (User.userId mbr2) - ugWithSecond <- getGroup (ownerId team) ug.id_ + ugWithSecond <- getGroup (ownerId team) ug.id_ False removeUser (ownerId team) ug.id_ (User.userId mbr1) - ugWithoutFirst <- getGroup (ownerId team) ug.id_ + ugWithoutFirst <- getGroup (ownerId team) ug.id_ False removeUser (ownerId team) ug.id_ (User.userId mbr1) -- idemp let propertyCheck = ((.members) <$> ugWithFirst) === Just (Identity $ V.fromList [User.userId mbr1]) From 018c2fe412dff1dfe4764e137713c9922a85b6c5 Mon Sep 17 00:00:00 2001 From: Leif Battermann Date: Tue, 7 Oct 2025 13:10:10 +0000 Subject: [PATCH 07/18] more test cases --- integration/test/Test/UserGroup.hs | 41 +++++++++++++++++-- .../Wire/UserGroupSubsystem/Interpreter.hs | 3 +- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/integration/test/Test/UserGroup.hs b/integration/test/Test/UserGroup.hs index 9cab119809..68899407ed 100644 --- a/integration/test/Test/UserGroup.hs +++ b/integration/test/Test/UserGroup.hs @@ -375,6 +375,7 @@ testUserGroupMembersCount = do resp.json %. "page.0.membersCount" `shouldMatchInt` 2 resp.json %. "total" `shouldMatchInt` 1 +<<<<<<< HEAD testUserGroupRemovalOnDelete :: (HasCallStack) => App () testUserGroupRemovalOnDelete = do (alice, tid, [bob, charlie]) <- createTeam OwnDomain 3 @@ -394,6 +395,13 @@ testUserGroupRemovalOnDelete = do testUserGroupUpdateChannels :: (HasCallStack) => App () testUserGroupUpdateChannels = do +||||||| constructed merge base +testUserGroupUpdateChannels :: (HasCallStack) => App () +testUserGroupUpdateChannels = do +======= +testUserGroupUpdateChannelsSucceeds :: (HasCallStack) => App () +testUserGroupUpdateChannelsSucceeds = do +>>>>>>> more test cases (alice, tid, [_bob]) <- createTeam OwnDomain 2 setTeamFeatureLockStatus alice tid "channels" "unlocked" let config = @@ -412,23 +420,48 @@ testUserGroupUpdateChannels = do >>= getJSON 200 gid <- ug %. "id" & asString +<<<<<<< HEAD convId <- postConversation alice (defMLS {team = Just tid, groupConvType = Just "channel"}) +||||||| constructed merge base + convId <- + postConversation alice (defProteus {team = Just tid}) +======= + convs <- + replicateM 5 + $ postConversation alice (defProteus {team = Just tid}) +>>>>>>> more test cases >>= getJSON 201 >>= objConvId +<<<<<<< HEAD withWebSocket alice $ \wsAlice -> do updateUserGroupChannels alice gid [convId.id_] >>= assertSuccess notif <- awaitMatch isUserGroupUpdatedNotif wsAlice notif %. "payload.0.user_group.id" `shouldMatch` gid +||||||| constructed merge base + updateUserGroupChannels alice gid [convId.id_] >>= assertSuccess +======= +>>>>>>> more test cases + + -- TODO: also check the user-groups search endpoint reflects the channels + updateUserGroupChannels alice gid ((.id_) <$> take 2 convs) >>= assertSuccess + + bindResponse (getUserGroupWithChannels alice gid) $ \resp -> do + resp.status `shouldMatchInt` 200 + (resp.json %. "channels" >>= asList >>= traverse objQid) `shouldMatchSet` for (take 2 convs) objQid + + updateUserGroupChannels alice gid ((.id_) <$> drop 1 convs) >>= assertSuccess - -- bobId <- asString $ bob %. "id" bindResponse (getUserGroupWithChannels alice gid) $ \resp -> do resp.status `shouldMatchInt` 200 - resp.json %. "channels" `shouldMatch` [object ["id" .= convId.id_, "domain" .= convId.domain]] + (resp.json %. "channels" >>= asList >>= traverse objQid) `shouldMatchSet` for (drop 1 convs) objQid --- FUTUREWORK: check the actual associated channels --- resp.json %. "members" `shouldMatch` [bobId] + updateUserGroupChannels alice gid [] >>= assertSuccess + + bindResponse (getUserGroupWithChannels alice gid) $ \resp -> do + resp.status `shouldMatchInt` 200 + (resp.json %. "channels" >>= fmap length . asList) `shouldMatchInt` 0 testUserGroupUpdateChannelsNonAdmin :: (HasCallStack) => App () testUserGroupUpdateChannelsNonAdmin = do diff --git a/libs/wire-subsystems/src/Wire/UserGroupSubsystem/Interpreter.hs b/libs/wire-subsystems/src/Wire/UserGroupSubsystem/Interpreter.hs index 50b34416d1..1a522145fb 100644 --- a/libs/wire-subsystems/src/Wire/UserGroupSubsystem/Interpreter.hs +++ b/libs/wire-subsystems/src/Wire/UserGroupSubsystem/Interpreter.hs @@ -371,7 +371,8 @@ removeUserFromAllGroups uid tid = do fmap Store.userGroupCreatedAtPaginationState mug, team = tid, searchString = Nothing, - includeMemberCount = False + includeMemberCount = False, + includeChannels = False } updateChannels :: From 18252962549ea17bfdfcf7e0d652187f084c5215 Mon Sep 17 00:00:00 2001 From: Gautier DI FOLCO Date: Tue, 7 Oct 2025 16:20:09 +0200 Subject: [PATCH 08/18] fix: mock interpreter --- .../Wire/MockInterpreters/UserGroupStore.hs | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/libs/wire-subsystems/test/unit/Wire/MockInterpreters/UserGroupStore.hs b/libs/wire-subsystems/test/unit/Wire/MockInterpreters/UserGroupStore.hs index eac892b614..bb82139310 100644 --- a/libs/wire-subsystems/test/unit/Wire/MockInterpreters/UserGroupStore.hs +++ b/libs/wire-subsystems/test/unit/Wire/MockInterpreters/UserGroupStore.hs @@ -18,6 +18,7 @@ import Data.Vector (Vector, fromList) import GHC.Stack import Imports import Polysemy +import Polysemy.Input import Polysemy.Internal (Append) import Polysemy.State import System.Random (StdGen, mkStdGen) @@ -42,6 +43,7 @@ type UserGroupStoreInMemEffectConstraints r = type UserGroupStoreInMemEffectStack = '[ UserGroupStore, State UserGroupInMemState, + Input (Local ()), Rnd.Random, State StdGen ] @@ -50,10 +52,11 @@ runInMemoryUserGroupStore :: (Member MockNow r) => UserGroupInMemState -> Sem (U runInMemoryUserGroupStore state = evalState (mkStdGen 3) . randomToStatefulStdGen + . runInputConst (toLocalUnsafe (Domain "my-domain") ()) . evalState state . userGroupStoreTestInterpreter -userGroupStoreTestInterpreter :: (UserGroupStoreInMemEffectConstraints r) => InterpreterFor UserGroupStore r +userGroupStoreTestInterpreter :: (UserGroupStoreInMemEffectConstraints r, Member (Input (Local ())) r) => InterpreterFor UserGroupStore r userGroupStoreTestInterpreter = interpret $ \case CreateUserGroup tid ng mb -> createUserGroupImpl tid ng mb @@ -94,11 +97,20 @@ createUserGroupImpl tid nug managedBy = do pure ug getUserGroupImpl :: (UserGroupStoreInMemEffectConstraints r) => TeamId -> UserGroupId -> Bool -> Sem r (Maybe UserGroup) -getUserGroupImpl tid gid _includeChannels = (Map.lookup (tid, gid)) <$> get @UserGroupInMemState +getUserGroupImpl tid gid includeChannels = fmap filterChannels . Map.lookup (tid, gid) <$> get @UserGroupInMemState + where + filterChannels ug = + if includeChannels + then ug + else (ug :: UserGroup) {channels = mempty} getUserGroupsImpl :: (UserGroupStoreInMemEffectConstraints r) => UserGroupPageRequest -> Sem r UserGroupPage getUserGroupsImpl UserGroupPageRequest {..} = do - meta <- ((snd <$>) . sieve . fmap (_2 %~ userGroupToMeta) . Map.toList) <$> get @UserGroupInMemState + let filterChannels ug = + if includeChannels + then (ug :: UserGroup) {channels = mempty, channelsCount = Just $ maybe 0 length ug.channels.runIdentity} + else (ug :: UserGroup) {channels = mempty} + meta <- ((snd <$>) . sieve . fmap (_2 %~ userGroupToMeta . filterChannels) . Map.toList) <$> get @UserGroupInMemState pure $ UserGroupPage meta (length meta) where sieve, @@ -183,18 +195,19 @@ removeUserImpl gid uid = do modifyUserGroupsGidOnly gid (Map.alter f) updateUserGroupChannelsImpl :: - (UserGroupStoreInMemEffectConstraints r) => + (UserGroupStoreInMemEffectConstraints r, Member (Input (Local ())) r) => UserGroupId -> Vector ConvId -> Sem r () updateUserGroupChannelsImpl gid convIds = do + qualifyLocal <- qualifyAs <$> input let f :: Maybe UserGroup -> Maybe UserGroup f Nothing = Nothing f (Just g) = Just ( g - { channels = Identity $ Just $ flip Qualified (Domain "") <$> convIds, - channelsCount = Just $ length convIds + { channels = Identity $ Just $ tUntagged . qualifyLocal <$> convIds, + channelsCount = Nothing } :: UserGroup ) From e88b1be20e77da682cebde3d70bb575243baab57 Mon Sep 17 00:00:00 2001 From: Gautier DI FOLCO Date: Tue, 7 Oct 2025 18:19:38 +0200 Subject: [PATCH 09/18] refactor: split UserGroup/UserGroupMeta --- .../src/Wire/API/Routes/Public/Brig.hs | 2 +- libs/wire-api/src/Wire/API/UserGroup.hs | 89 +++++++------------ .../Test/Wire/API/Golden/Manual/Pagination.hs | 16 ++-- .../Test/Wire/API/Golden/Manual/UserGroup.hs | 12 +-- .../golden/testObject_UserGroupPage_2.json | 4 +- .../golden/testObject_UserGroupPage_3.json | 2 +- .../test/golden/testObject_UserGroup_2.json | 1 - .../src/Wire/UserGroupStore/Postgres.hs | 30 +++---- .../Wire/UserGroupSubsystem/Interpreter.hs | 8 +- 9 files changed, 58 insertions(+), 106 deletions(-) diff --git a/libs/wire-api/src/Wire/API/Routes/Public/Brig.hs b/libs/wire-api/src/Wire/API/Routes/Public/Brig.hs index a81fcd5496..209dfc0b38 100644 --- a/libs/wire-api/src/Wire/API/Routes/Public/Brig.hs +++ b/libs/wire-api/src/Wire/API/Routes/Public/Brig.hs @@ -342,7 +342,7 @@ type UserGroupAPI = :> QueryParam' '[Optional, Strict, LastSeenNameDesc] "last_seen_name" UserGroupName :> QueryParam' '[Optional, Strict, LastSeenCreatedAtDesc] "last_seen_created_at" UTCTimeMillis :> QueryParam' '[Optional, Strict, LastSeenIdDesc] "last_seen_id" UserGroupId - :> QueryFlag "include_channels" + :> QueryFlag "include_channels_count" :> QueryFlag "include_member_count" :> Get '[JSON] UserGroupPage ) diff --git a/libs/wire-api/src/Wire/API/UserGroup.hs b/libs/wire-api/src/Wire/API/UserGroup.hs index 7b36a366e8..3ac71d58b9 100644 --- a/libs/wire-api/src/Wire/API/UserGroup.hs +++ b/libs/wire-api/src/Wire/API/UserGroup.hs @@ -25,7 +25,6 @@ import Control.Applicative import Data.Aeson qualified as A import Data.Id import Data.Json.Util -import Data.Kind import Data.OpenApi qualified as OpenApi import Data.Qualified (Qualified) import Data.Range @@ -100,86 +99,60 @@ instance ToSchema UserGroupAddUsers where UserGroupAddUsers <$> (.members) .= field "members" (vector schema) -type UserGroup = UserGroup_ Identity - -type UserGroupMeta = UserGroup_ (Const ()) - -userGroupToMeta :: UserGroup -> UserGroupMeta -userGroupToMeta ug = - UserGroup_ - { id_ = ug.id_, - name = ug.name, - members = Const (), - channels = Const (), - membersCount = ug.membersCount, - channelsCount = ug.channelsCount, - managedBy = ug.managedBy, - createdAt = ug.createdAt - } +data UserGroup = UserGroup + { id_ :: UserGroupId, + name :: UserGroupName, + members :: Vector UserId, + channels :: Maybe (Vector (Qualified ConvId)), + managedBy :: ManagedBy, + createdAt :: UTCTimeMillis + } + deriving (Eq, Ord, Show, Generic) + deriving (A.ToJSON, A.FromJSON, OpenApi.ToSchema) via Schema UserGroup + deriving (Arbitrary) via GenericUniform UserGroup -data UserGroup_ (f :: Type -> Type) = UserGroup_ +data UserGroupMeta = UserGroupMeta { id_ :: UserGroupId, name :: UserGroupName, - members :: f (Vector UserId), - channels :: f (Maybe (Vector (Qualified ConvId))), membersCount :: Maybe Int, channelsCount :: Maybe Int, managedBy :: ManagedBy, createdAt :: UTCTimeMillis } - deriving (Generic) - -deriving instance Eq (UserGroup_ (Const ())) - -deriving instance Ord (UserGroup_ (Const ())) - -deriving instance Show (UserGroup_ (Const ())) - -deriving via GenericUniform (UserGroup_ (Const ())) instance Arbitrary (UserGroup_ (Const ())) - -deriving via Schema (UserGroup_ (Const ())) instance A.ToJSON (UserGroup_ (Const ())) - -deriving via Schema (UserGroup_ (Const ())) instance A.FromJSON (UserGroup_ (Const ())) + deriving (Eq, Ord, Show, Generic) + deriving (A.ToJSON, A.FromJSON, OpenApi.ToSchema) via Schema UserGroupMeta + deriving (Arbitrary) via GenericUniform UserGroupMeta -deriving via Schema (UserGroup_ (Const ())) instance OpenApi.ToSchema (UserGroup_ (Const ())) +userGroupToMeta :: UserGroup -> UserGroupMeta +userGroupToMeta ug = + UserGroupMeta + { id_ = ug.id_, + name = ug.name, + membersCount = Just $ length ug.members, + channelsCount = length <$> ug.channels, + managedBy = ug.managedBy, + createdAt = ug.createdAt + } -instance ToSchema (UserGroup_ (Const ())) where +instance ToSchema UserGroupMeta where schema = object "UserGroupMeta" $ - UserGroup_ + UserGroupMeta <$> (.id_) .= field "id" schema <*> (.name) .= field "name" schema - <*> (.members) .= pure mempty - <*> (.channels) .= pure mempty <*> (.membersCount) .= maybe_ (optField "membersCount" schema) <*> (.channelsCount) .= maybe_ (optField "channelsCount" schema) <*> (.managedBy) .= field "managedBy" schema <*> (.createdAt) .= field "createdAt" schema -deriving instance Eq (UserGroup_ Identity) - -deriving instance Ord (UserGroup_ Identity) - -deriving instance Show (UserGroup_ Identity) - -deriving via GenericUniform (UserGroup_ Identity) instance Arbitrary (UserGroup_ Identity) - -deriving via Schema (UserGroup_ Identity) instance A.ToJSON (UserGroup_ Identity) - -deriving via Schema (UserGroup_ Identity) instance A.FromJSON (UserGroup_ Identity) - -deriving via Schema (UserGroup_ Identity) instance OpenApi.ToSchema (UserGroup_ Identity) - -instance ToSchema (UserGroup_ Identity) where +instance ToSchema UserGroup where schema = object "UserGroup" $ - UserGroup_ + UserGroup <$> (.id_) .= field "id" schema <*> (.name) .= field "name" schema - <*> (runIdentity . (.members)) .= field "members" (Identity <$> vector schema) - <*> (runIdentity . (.channels)) .= (Identity <$> maybe_ (optField "channels" (vector schema))) - <*> (.membersCount) .= maybe_ (optField "membersCount" schema) - <*> (.channelsCount) .= maybe_ (optField "channelsCount" schema) + <*> (.members) .= field "members" (vector schema) + <*> (.channels) .= (maybe_ (optField "channels" (vector schema))) <*> (.managedBy) .= field "managedBy" schema <*> (.createdAt) .= field "createdAt" schema diff --git a/libs/wire-api/test/golden/Test/Wire/API/Golden/Manual/Pagination.hs b/libs/wire-api/test/golden/Test/Wire/API/Golden/Manual/Pagination.hs index 791a8acb58..add8ee0cb8 100644 --- a/libs/wire-api/test/golden/Test/Wire/API/Golden/Manual/Pagination.hs +++ b/libs/wire-api/test/golden/Test/Wire/API/Golden/Manual/Pagination.hs @@ -19,24 +19,22 @@ Just someOtherUTCTime = readUTCTimeMillis "2021-12-12T00:00:00.000Z" ug1 :: UserGroup ug1 = - UserGroup_ + UserGroup { id_ = Id UUID.nil, name = either (error . show) id (userGroupNameFromText "*"), members = mempty, channels = mempty, - membersCount = Nothing, - channelsCount = Just 1, managedBy = ManagedByWire, createdAt = someUTCTime } ug2 :: UserGroup ug2 = - UserGroup_ + UserGroup { id_ = Id . fromJust $ UUID.fromString "63dd98c0-552d-11f0-8df7-b3e03cd56036", name = either (error . show) id (userGroupNameFromText "##name1##"), members = - Identity . Vec.fromList $ + Vec.fromList $ ( Id . fromJust . UUID.fromString <$> [ "1f815fa2-552f-11f0-8642-77f29e68cbc9", "28a9c560-552f-11f0-9082-97e15e952720", @@ -44,22 +42,18 @@ ug2 = ] ), channels = mempty, - membersCount = Nothing, - channelsCount = Just 1, managedBy = ManagedByWire, createdAt = someUTCTime } ug3 :: UserGroup ug3 = - UserGroup_ + UserGroup { id_ = Id . fromJust $ UUID.fromString "60278b50-552d-11f0-892b-ebd66f6c2c30", name = either (error . show) id (userGroupNameFromText "!! user group !!"), members = - Identity $ Vec.fromList (Id . fromJust . UUID.fromString <$> ["37b636e2-552f-11f0-abe8-5bf7b2ad08c9"]), + Vec.fromList (Id . fromJust . UUID.fromString <$> ["37b636e2-552f-11f0-abe8-5bf7b2ad08c9"]), channels = mempty, - membersCount = Nothing, - channelsCount = Nothing, managedBy = ManagedByScim, createdAt = someOtherUTCTime } diff --git a/libs/wire-api/test/golden/Test/Wire/API/Golden/Manual/UserGroup.hs b/libs/wire-api/test/golden/Test/Wire/API/Golden/Manual/UserGroup.hs index ae40851ffc..5109f6cca3 100644 --- a/libs/wire-api/test/golden/Test/Wire/API/Golden/Manual/UserGroup.hs +++ b/libs/wire-api/test/golden/Test/Wire/API/Golden/Manual/UserGroup.hs @@ -40,11 +40,9 @@ testObject_UserGroupUpdate_2 = UserGroupUpdate (unsafeToUserGroupName "some name testObject_UserGroup_1 :: UserGroupMeta testObject_UserGroup_1 = - UserGroup_ + UserGroupMeta { id_ = userGroupId1, name = (unsafeToUserGroupName "name"), - members = (Const ()), - channels = (Const ()), membersCount = Nothing, channelsCount = Just 0, managedBy = ManagedByWire, @@ -53,16 +51,14 @@ testObject_UserGroup_1 = testObject_UserGroup_2 :: UserGroup testObject_UserGroup_2 = - UserGroup_ + UserGroup { id_ = userGroupId2, name = (unsafeToUserGroupName "yet another one"), - members = (Identity $ fromList [userId1, userId2]), + members = fromList [userId1, userId2], channels = - Identity . Just . fromList $ + Just . fromList $ [ Qualified (Id (fromJust (UUID.fromString "445c08d2-a16b-49ea-a274-4208bb2efe8f"))) (Domain "example.com") ], - membersCount = Nothing, - channelsCount = Just 1, managedBy = ManagedByScim, createdAt = someUTCTime } diff --git a/libs/wire-api/test/golden/testObject_UserGroupPage_2.json b/libs/wire-api/test/golden/testObject_UserGroupPage_2.json index 8d25eee556..a06529dfce 100644 --- a/libs/wire-api/test/golden/testObject_UserGroupPage_2.json +++ b/libs/wire-api/test/golden/testObject_UserGroupPage_2.json @@ -1,17 +1,17 @@ { "page": [ { - "channelsCount": 1, "createdAt": "2025-04-16T16:22:21.703Z", "id": "00000000-0000-0000-0000-000000000000", "managedBy": "wire", + "membersCount": 0, "name": "*" }, { - "channelsCount": 1, "createdAt": "2025-04-16T16:22:21.703Z", "id": "63dd98c0-552d-11f0-8df7-b3e03cd56036", "managedBy": "wire", + "membersCount": 3, "name": "##name1##" } ], diff --git a/libs/wire-api/test/golden/testObject_UserGroupPage_3.json b/libs/wire-api/test/golden/testObject_UserGroupPage_3.json index 4ae0bf5816..aa338298e8 100644 --- a/libs/wire-api/test/golden/testObject_UserGroupPage_3.json +++ b/libs/wire-api/test/golden/testObject_UserGroupPage_3.json @@ -1,10 +1,10 @@ { "page": [ { - "channelsCount": 1, "createdAt": "2025-04-16T16:22:21.703Z", "id": "63dd98c0-552d-11f0-8df7-b3e03cd56036", "managedBy": "wire", + "membersCount": 3, "name": "##name1##" } ], diff --git a/libs/wire-api/test/golden/testObject_UserGroup_2.json b/libs/wire-api/test/golden/testObject_UserGroup_2.json index a295351d44..dad96c1366 100644 --- a/libs/wire-api/test/golden/testObject_UserGroup_2.json +++ b/libs/wire-api/test/golden/testObject_UserGroup_2.json @@ -5,7 +5,6 @@ "id": "445c08d2-a16b-49ea-a274-4208bb2efe8f" } ], - "channelsCount": 1, "createdAt": "2025-04-16T16:22:21.703Z", "id": "19bdd268-1adc-11f0-9a71-d351719dd165", "managedBy": "scim", diff --git a/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs b/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs index fad1b1e541..158388f718 100644 --- a/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs +++ b/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs @@ -97,21 +97,16 @@ getUserGroup team id_ includeChannels = do session :: Session (Maybe UserGroup) session = runMaybeT do (name, managedBy, createdAt) <- MaybeT $ statement (id_, team) getGroupMetadataStatement - members <- lift $ Identity <$> statement id_ getGroupMembersStatement - -- TODO: add counts - let membersCount = Nothing - channelsCount = Nothing - channels = mempty - pure $ UserGroup_ {..} + members <- lift $ statement id_ getGroupMembersStatement + let channels = mempty + pure $ UserGroup {..} sessionWithChannels :: Local a -> Session (Maybe UserGroup) sessionWithChannels loc = runMaybeT do (name, managedBy, createdAt, memberIds, channelIds) <- MaybeT $ statement (id_, team) getGroupWithMembersAndChannelsStatement - let members = Identity (fmap Id memberIds) - membersCount = Just (fromIntegral (V.length memberIds)) - channels = Identity (Just (fmap (tUntagged . qualifyAs loc . Id) channelIds)) - channelsCount = Just (fromIntegral (V.length channelIds)) - pure $ UserGroup_ {..} + let members = fmap Id memberIds + channels = Just (fmap (tUntagged . qualifyAs loc . Id) channelIds) + pure $ UserGroup {..} decodeMetadataRow :: (Text, Int32, UTCTime) -> Either Text (UserGroupName, ManagedBy, UTCTimeMillis) decodeMetadataRow (name, managedByInt, utcTime) = @@ -292,12 +287,9 @@ getUserGroups req@(UserGroupPageRequest {..}) = do 1 -> pure ManagedByScim bad -> Left $ "Could not parse managedBy value: " <> T.pack (show bad) name <- userGroupNameFromText namePre - let members = Const () - membersCount = fromIntegral <$> membersCountRaw + let membersCount = fromIntegral <$> membersCountRaw channelsCount = Just (fromIntegral channelsCountRaw) - -- TODO: process channels - channels = mempty - pure $ UserGroup_ {..} + pure $ UserGroupMeta {..} -- \| Compile a pagination state into select query to return the next page. Result is the -- query string and the search string (which needs escaping). @@ -358,12 +350,10 @@ createUserGroup team newUserGroup managedBy = do (id_, name, managedBy_, createdAt) <- Tx.statement (newUserGroup.name, team, managedBy) insertGroupStatement Tx.statement (toUUID id_, newUserGroup.members) insertGroupMembersStatement pure - UserGroup_ - { membersCount = Nothing, - members = Identity newUserGroup.members, + UserGroup + { members = newUserGroup.members, channels = mempty, managedBy = managedBy_, - channelsCount = Nothing, id_, name, createdAt diff --git a/libs/wire-subsystems/src/Wire/UserGroupSubsystem/Interpreter.hs b/libs/wire-subsystems/src/Wire/UserGroupSubsystem/Interpreter.hs index 1a522145fb..e9705faf32 100644 --- a/libs/wire-subsystems/src/Wire/UserGroupSubsystem/Interpreter.hs +++ b/libs/wire-subsystems/src/Wire/UserGroupSubsystem/Interpreter.hs @@ -149,7 +149,7 @@ getUserGroup getter gid includeChannels = runMaybeT $ do team <- MaybeT $ getUserTeam getter getterCanSeeAll <- mkGetterCanSeeAll getter team userGroup <- MaybeT $ Store.getUserGroup team gid includeChannels - if getterCanSeeAll || getter `elem` (toList (runIdentity userGroup.members)) + if getterCanSeeAll || getter `elem` toList userGroup.members then pure userGroup else MaybeT $ pure Nothing @@ -263,7 +263,7 @@ addUser adder groupId addeeId = do ug <- getUserGroup adder groupId False >>= note UserGroupNotFound team <- getTeamAsAdmin adder >>= note UserGroupNotATeamAdmin void $ internalGetTeamMember addeeId team >>= note UserGroupMemberIsNotInTheSameTeam - unless (addeeId `elem` runIdentity ug.members) $ do + unless (addeeId `elem` ug.members) $ do Store.addUser groupId addeeId admins <- fmap (^. TM.userId) . (^. teamMembers) <$> internalGetTeamAdmins team pushNotifications @@ -287,7 +287,7 @@ addUsers adder groupId addeeIds = do forM_ addeeIds $ \addeeId -> internalGetTeamMember addeeId team >>= note UserGroupMemberIsNotInTheSameTeam - let missingAddeeIds = toList addeeIds \\ toList (runIdentity ug.members) + let missingAddeeIds = toList addeeIds \\ toList ug.members unless (null missingAddeeIds) $ do mapM_ (Store.addUser groupId) missingAddeeIds admins <- fmap (^. TM.userId) . (^. teamMembers) <$> internalGetTeamAdmins team @@ -332,7 +332,7 @@ removeUser remover groupId removeeId = do ug <- getUserGroup remover groupId False >>= note UserGroupNotFound team <- getTeamAsAdmin remover >>= note UserGroupNotATeamAdmin void $ internalGetTeamMember removeeId team >>= note UserGroupMemberIsNotInTheSameTeam - when (removeeId `elem` runIdentity ug.members) $ do + when (removeeId `elem` ug.members) $ do Store.removeUser groupId removeeId admins <- fmap (^. TM.userId) . (^. teamMembers) <$> internalGetTeamAdmins team pushNotifications From 368f4deb9c7f6e551dd5a26bf9db61f88326e912 Mon Sep 17 00:00:00 2001 From: Leif Battermann Date: Wed, 8 Oct 2025 08:01:47 +0000 Subject: [PATCH 10/18] Revert "refactor: split UserGroup/UserGroupMeta" This reverts commit 6985978e73589e0a60ba9523d57f1e1691d335c7. --- .../src/Wire/API/Routes/Public/Brig.hs | 2 +- libs/wire-api/src/Wire/API/UserGroup.hs | 89 ++++++++++++------- .../Test/Wire/API/Golden/Manual/Pagination.hs | 16 ++-- .../Test/Wire/API/Golden/Manual/UserGroup.hs | 12 ++- .../golden/testObject_UserGroupPage_2.json | 4 +- .../golden/testObject_UserGroupPage_3.json | 2 +- .../test/golden/testObject_UserGroup_2.json | 1 + .../src/Wire/UserGroupStore/Postgres.hs | 30 ++++--- .../Wire/UserGroupSubsystem/Interpreter.hs | 8 +- 9 files changed, 106 insertions(+), 58 deletions(-) diff --git a/libs/wire-api/src/Wire/API/Routes/Public/Brig.hs b/libs/wire-api/src/Wire/API/Routes/Public/Brig.hs index 209dfc0b38..a81fcd5496 100644 --- a/libs/wire-api/src/Wire/API/Routes/Public/Brig.hs +++ b/libs/wire-api/src/Wire/API/Routes/Public/Brig.hs @@ -342,7 +342,7 @@ type UserGroupAPI = :> QueryParam' '[Optional, Strict, LastSeenNameDesc] "last_seen_name" UserGroupName :> QueryParam' '[Optional, Strict, LastSeenCreatedAtDesc] "last_seen_created_at" UTCTimeMillis :> QueryParam' '[Optional, Strict, LastSeenIdDesc] "last_seen_id" UserGroupId - :> QueryFlag "include_channels_count" + :> QueryFlag "include_channels" :> QueryFlag "include_member_count" :> Get '[JSON] UserGroupPage ) diff --git a/libs/wire-api/src/Wire/API/UserGroup.hs b/libs/wire-api/src/Wire/API/UserGroup.hs index 3ac71d58b9..7b36a366e8 100644 --- a/libs/wire-api/src/Wire/API/UserGroup.hs +++ b/libs/wire-api/src/Wire/API/UserGroup.hs @@ -25,6 +25,7 @@ import Control.Applicative import Data.Aeson qualified as A import Data.Id import Data.Json.Util +import Data.Kind import Data.OpenApi qualified as OpenApi import Data.Qualified (Qualified) import Data.Range @@ -99,60 +100,86 @@ instance ToSchema UserGroupAddUsers where UserGroupAddUsers <$> (.members) .= field "members" (vector schema) -data UserGroup = UserGroup - { id_ :: UserGroupId, - name :: UserGroupName, - members :: Vector UserId, - channels :: Maybe (Vector (Qualified ConvId)), - managedBy :: ManagedBy, - createdAt :: UTCTimeMillis - } - deriving (Eq, Ord, Show, Generic) - deriving (A.ToJSON, A.FromJSON, OpenApi.ToSchema) via Schema UserGroup - deriving (Arbitrary) via GenericUniform UserGroup +type UserGroup = UserGroup_ Identity -data UserGroupMeta = UserGroupMeta - { id_ :: UserGroupId, - name :: UserGroupName, - membersCount :: Maybe Int, - channelsCount :: Maybe Int, - managedBy :: ManagedBy, - createdAt :: UTCTimeMillis - } - deriving (Eq, Ord, Show, Generic) - deriving (A.ToJSON, A.FromJSON, OpenApi.ToSchema) via Schema UserGroupMeta - deriving (Arbitrary) via GenericUniform UserGroupMeta +type UserGroupMeta = UserGroup_ (Const ()) userGroupToMeta :: UserGroup -> UserGroupMeta userGroupToMeta ug = - UserGroupMeta + UserGroup_ { id_ = ug.id_, name = ug.name, - membersCount = Just $ length ug.members, - channelsCount = length <$> ug.channels, + members = Const (), + channels = Const (), + membersCount = ug.membersCount, + channelsCount = ug.channelsCount, managedBy = ug.managedBy, createdAt = ug.createdAt } -instance ToSchema UserGroupMeta where +data UserGroup_ (f :: Type -> Type) = UserGroup_ + { id_ :: UserGroupId, + name :: UserGroupName, + members :: f (Vector UserId), + channels :: f (Maybe (Vector (Qualified ConvId))), + membersCount :: Maybe Int, + channelsCount :: Maybe Int, + managedBy :: ManagedBy, + createdAt :: UTCTimeMillis + } + deriving (Generic) + +deriving instance Eq (UserGroup_ (Const ())) + +deriving instance Ord (UserGroup_ (Const ())) + +deriving instance Show (UserGroup_ (Const ())) + +deriving via GenericUniform (UserGroup_ (Const ())) instance Arbitrary (UserGroup_ (Const ())) + +deriving via Schema (UserGroup_ (Const ())) instance A.ToJSON (UserGroup_ (Const ())) + +deriving via Schema (UserGroup_ (Const ())) instance A.FromJSON (UserGroup_ (Const ())) + +deriving via Schema (UserGroup_ (Const ())) instance OpenApi.ToSchema (UserGroup_ (Const ())) + +instance ToSchema (UserGroup_ (Const ())) where schema = object "UserGroupMeta" $ - UserGroupMeta + UserGroup_ <$> (.id_) .= field "id" schema <*> (.name) .= field "name" schema + <*> (.members) .= pure mempty + <*> (.channels) .= pure mempty <*> (.membersCount) .= maybe_ (optField "membersCount" schema) <*> (.channelsCount) .= maybe_ (optField "channelsCount" schema) <*> (.managedBy) .= field "managedBy" schema <*> (.createdAt) .= field "createdAt" schema -instance ToSchema UserGroup where +deriving instance Eq (UserGroup_ Identity) + +deriving instance Ord (UserGroup_ Identity) + +deriving instance Show (UserGroup_ Identity) + +deriving via GenericUniform (UserGroup_ Identity) instance Arbitrary (UserGroup_ Identity) + +deriving via Schema (UserGroup_ Identity) instance A.ToJSON (UserGroup_ Identity) + +deriving via Schema (UserGroup_ Identity) instance A.FromJSON (UserGroup_ Identity) + +deriving via Schema (UserGroup_ Identity) instance OpenApi.ToSchema (UserGroup_ Identity) + +instance ToSchema (UserGroup_ Identity) where schema = object "UserGroup" $ - UserGroup + UserGroup_ <$> (.id_) .= field "id" schema <*> (.name) .= field "name" schema - <*> (.members) .= field "members" (vector schema) - <*> (.channels) .= (maybe_ (optField "channels" (vector schema))) + <*> (runIdentity . (.members)) .= field "members" (Identity <$> vector schema) + <*> (runIdentity . (.channels)) .= (Identity <$> maybe_ (optField "channels" (vector schema))) + <*> (.membersCount) .= maybe_ (optField "membersCount" schema) + <*> (.channelsCount) .= maybe_ (optField "channelsCount" schema) <*> (.managedBy) .= field "managedBy" schema <*> (.createdAt) .= field "createdAt" schema diff --git a/libs/wire-api/test/golden/Test/Wire/API/Golden/Manual/Pagination.hs b/libs/wire-api/test/golden/Test/Wire/API/Golden/Manual/Pagination.hs index add8ee0cb8..791a8acb58 100644 --- a/libs/wire-api/test/golden/Test/Wire/API/Golden/Manual/Pagination.hs +++ b/libs/wire-api/test/golden/Test/Wire/API/Golden/Manual/Pagination.hs @@ -19,22 +19,24 @@ Just someOtherUTCTime = readUTCTimeMillis "2021-12-12T00:00:00.000Z" ug1 :: UserGroup ug1 = - UserGroup + UserGroup_ { id_ = Id UUID.nil, name = either (error . show) id (userGroupNameFromText "*"), members = mempty, channels = mempty, + membersCount = Nothing, + channelsCount = Just 1, managedBy = ManagedByWire, createdAt = someUTCTime } ug2 :: UserGroup ug2 = - UserGroup + UserGroup_ { id_ = Id . fromJust $ UUID.fromString "63dd98c0-552d-11f0-8df7-b3e03cd56036", name = either (error . show) id (userGroupNameFromText "##name1##"), members = - Vec.fromList $ + Identity . Vec.fromList $ ( Id . fromJust . UUID.fromString <$> [ "1f815fa2-552f-11f0-8642-77f29e68cbc9", "28a9c560-552f-11f0-9082-97e15e952720", @@ -42,18 +44,22 @@ ug2 = ] ), channels = mempty, + membersCount = Nothing, + channelsCount = Just 1, managedBy = ManagedByWire, createdAt = someUTCTime } ug3 :: UserGroup ug3 = - UserGroup + UserGroup_ { id_ = Id . fromJust $ UUID.fromString "60278b50-552d-11f0-892b-ebd66f6c2c30", name = either (error . show) id (userGroupNameFromText "!! user group !!"), members = - Vec.fromList (Id . fromJust . UUID.fromString <$> ["37b636e2-552f-11f0-abe8-5bf7b2ad08c9"]), + Identity $ Vec.fromList (Id . fromJust . UUID.fromString <$> ["37b636e2-552f-11f0-abe8-5bf7b2ad08c9"]), channels = mempty, + membersCount = Nothing, + channelsCount = Nothing, managedBy = ManagedByScim, createdAt = someOtherUTCTime } diff --git a/libs/wire-api/test/golden/Test/Wire/API/Golden/Manual/UserGroup.hs b/libs/wire-api/test/golden/Test/Wire/API/Golden/Manual/UserGroup.hs index 5109f6cca3..ae40851ffc 100644 --- a/libs/wire-api/test/golden/Test/Wire/API/Golden/Manual/UserGroup.hs +++ b/libs/wire-api/test/golden/Test/Wire/API/Golden/Manual/UserGroup.hs @@ -40,9 +40,11 @@ testObject_UserGroupUpdate_2 = UserGroupUpdate (unsafeToUserGroupName "some name testObject_UserGroup_1 :: UserGroupMeta testObject_UserGroup_1 = - UserGroupMeta + UserGroup_ { id_ = userGroupId1, name = (unsafeToUserGroupName "name"), + members = (Const ()), + channels = (Const ()), membersCount = Nothing, channelsCount = Just 0, managedBy = ManagedByWire, @@ -51,14 +53,16 @@ testObject_UserGroup_1 = testObject_UserGroup_2 :: UserGroup testObject_UserGroup_2 = - UserGroup + UserGroup_ { id_ = userGroupId2, name = (unsafeToUserGroupName "yet another one"), - members = fromList [userId1, userId2], + members = (Identity $ fromList [userId1, userId2]), channels = - Just . fromList $ + Identity . Just . fromList $ [ Qualified (Id (fromJust (UUID.fromString "445c08d2-a16b-49ea-a274-4208bb2efe8f"))) (Domain "example.com") ], + membersCount = Nothing, + channelsCount = Just 1, managedBy = ManagedByScim, createdAt = someUTCTime } diff --git a/libs/wire-api/test/golden/testObject_UserGroupPage_2.json b/libs/wire-api/test/golden/testObject_UserGroupPage_2.json index a06529dfce..8d25eee556 100644 --- a/libs/wire-api/test/golden/testObject_UserGroupPage_2.json +++ b/libs/wire-api/test/golden/testObject_UserGroupPage_2.json @@ -1,17 +1,17 @@ { "page": [ { + "channelsCount": 1, "createdAt": "2025-04-16T16:22:21.703Z", "id": "00000000-0000-0000-0000-000000000000", "managedBy": "wire", - "membersCount": 0, "name": "*" }, { + "channelsCount": 1, "createdAt": "2025-04-16T16:22:21.703Z", "id": "63dd98c0-552d-11f0-8df7-b3e03cd56036", "managedBy": "wire", - "membersCount": 3, "name": "##name1##" } ], diff --git a/libs/wire-api/test/golden/testObject_UserGroupPage_3.json b/libs/wire-api/test/golden/testObject_UserGroupPage_3.json index aa338298e8..4ae0bf5816 100644 --- a/libs/wire-api/test/golden/testObject_UserGroupPage_3.json +++ b/libs/wire-api/test/golden/testObject_UserGroupPage_3.json @@ -1,10 +1,10 @@ { "page": [ { + "channelsCount": 1, "createdAt": "2025-04-16T16:22:21.703Z", "id": "63dd98c0-552d-11f0-8df7-b3e03cd56036", "managedBy": "wire", - "membersCount": 3, "name": "##name1##" } ], diff --git a/libs/wire-api/test/golden/testObject_UserGroup_2.json b/libs/wire-api/test/golden/testObject_UserGroup_2.json index dad96c1366..a295351d44 100644 --- a/libs/wire-api/test/golden/testObject_UserGroup_2.json +++ b/libs/wire-api/test/golden/testObject_UserGroup_2.json @@ -5,6 +5,7 @@ "id": "445c08d2-a16b-49ea-a274-4208bb2efe8f" } ], + "channelsCount": 1, "createdAt": "2025-04-16T16:22:21.703Z", "id": "19bdd268-1adc-11f0-9a71-d351719dd165", "managedBy": "scim", diff --git a/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs b/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs index 158388f718..fad1b1e541 100644 --- a/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs +++ b/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs @@ -97,16 +97,21 @@ getUserGroup team id_ includeChannels = do session :: Session (Maybe UserGroup) session = runMaybeT do (name, managedBy, createdAt) <- MaybeT $ statement (id_, team) getGroupMetadataStatement - members <- lift $ statement id_ getGroupMembersStatement - let channels = mempty - pure $ UserGroup {..} + members <- lift $ Identity <$> statement id_ getGroupMembersStatement + -- TODO: add counts + let membersCount = Nothing + channelsCount = Nothing + channels = mempty + pure $ UserGroup_ {..} sessionWithChannels :: Local a -> Session (Maybe UserGroup) sessionWithChannels loc = runMaybeT do (name, managedBy, createdAt, memberIds, channelIds) <- MaybeT $ statement (id_, team) getGroupWithMembersAndChannelsStatement - let members = fmap Id memberIds - channels = Just (fmap (tUntagged . qualifyAs loc . Id) channelIds) - pure $ UserGroup {..} + let members = Identity (fmap Id memberIds) + membersCount = Just (fromIntegral (V.length memberIds)) + channels = Identity (Just (fmap (tUntagged . qualifyAs loc . Id) channelIds)) + channelsCount = Just (fromIntegral (V.length channelIds)) + pure $ UserGroup_ {..} decodeMetadataRow :: (Text, Int32, UTCTime) -> Either Text (UserGroupName, ManagedBy, UTCTimeMillis) decodeMetadataRow (name, managedByInt, utcTime) = @@ -287,9 +292,12 @@ getUserGroups req@(UserGroupPageRequest {..}) = do 1 -> pure ManagedByScim bad -> Left $ "Could not parse managedBy value: " <> T.pack (show bad) name <- userGroupNameFromText namePre - let membersCount = fromIntegral <$> membersCountRaw + let members = Const () + membersCount = fromIntegral <$> membersCountRaw channelsCount = Just (fromIntegral channelsCountRaw) - pure $ UserGroupMeta {..} + -- TODO: process channels + channels = mempty + pure $ UserGroup_ {..} -- \| Compile a pagination state into select query to return the next page. Result is the -- query string and the search string (which needs escaping). @@ -350,10 +358,12 @@ createUserGroup team newUserGroup managedBy = do (id_, name, managedBy_, createdAt) <- Tx.statement (newUserGroup.name, team, managedBy) insertGroupStatement Tx.statement (toUUID id_, newUserGroup.members) insertGroupMembersStatement pure - UserGroup - { members = newUserGroup.members, + UserGroup_ + { membersCount = Nothing, + members = Identity newUserGroup.members, channels = mempty, managedBy = managedBy_, + channelsCount = Nothing, id_, name, createdAt diff --git a/libs/wire-subsystems/src/Wire/UserGroupSubsystem/Interpreter.hs b/libs/wire-subsystems/src/Wire/UserGroupSubsystem/Interpreter.hs index e9705faf32..1a522145fb 100644 --- a/libs/wire-subsystems/src/Wire/UserGroupSubsystem/Interpreter.hs +++ b/libs/wire-subsystems/src/Wire/UserGroupSubsystem/Interpreter.hs @@ -149,7 +149,7 @@ getUserGroup getter gid includeChannels = runMaybeT $ do team <- MaybeT $ getUserTeam getter getterCanSeeAll <- mkGetterCanSeeAll getter team userGroup <- MaybeT $ Store.getUserGroup team gid includeChannels - if getterCanSeeAll || getter `elem` toList userGroup.members + if getterCanSeeAll || getter `elem` (toList (runIdentity userGroup.members)) then pure userGroup else MaybeT $ pure Nothing @@ -263,7 +263,7 @@ addUser adder groupId addeeId = do ug <- getUserGroup adder groupId False >>= note UserGroupNotFound team <- getTeamAsAdmin adder >>= note UserGroupNotATeamAdmin void $ internalGetTeamMember addeeId team >>= note UserGroupMemberIsNotInTheSameTeam - unless (addeeId `elem` ug.members) $ do + unless (addeeId `elem` runIdentity ug.members) $ do Store.addUser groupId addeeId admins <- fmap (^. TM.userId) . (^. teamMembers) <$> internalGetTeamAdmins team pushNotifications @@ -287,7 +287,7 @@ addUsers adder groupId addeeIds = do forM_ addeeIds $ \addeeId -> internalGetTeamMember addeeId team >>= note UserGroupMemberIsNotInTheSameTeam - let missingAddeeIds = toList addeeIds \\ toList ug.members + let missingAddeeIds = toList addeeIds \\ toList (runIdentity ug.members) unless (null missingAddeeIds) $ do mapM_ (Store.addUser groupId) missingAddeeIds admins <- fmap (^. TM.userId) . (^. teamMembers) <$> internalGetTeamAdmins team @@ -332,7 +332,7 @@ removeUser remover groupId removeeId = do ug <- getUserGroup remover groupId False >>= note UserGroupNotFound team <- getTeamAsAdmin remover >>= note UserGroupNotATeamAdmin void $ internalGetTeamMember removeeId team >>= note UserGroupMemberIsNotInTheSameTeam - when (removeeId `elem` ug.members) $ do + when (removeeId `elem` runIdentity ug.members) $ do Store.removeUser groupId removeeId admins <- fmap (^. TM.userId) . (^. teamMembers) <$> internalGetTeamAdmins team pushNotifications From 0167017422de2c796a1bfb67fff0e91e9b8c9053 Mon Sep 17 00:00:00 2001 From: Leif Battermann Date: Wed, 8 Oct 2025 08:38:00 +0000 Subject: [PATCH 11/18] fix types and include channels in user-groups response --- integration/test/API/Brig.hs | 16 ++++++- integration/test/Test/UserGroup.hs | 6 ++- libs/wire-api/src/Wire/API/UserGroup.hs | 8 ++-- .../Test/Wire/API/Golden/Manual/UserGroup.hs | 12 +++--- .../src/Wire/UserGroupStore/Postgres.hs | 43 ++++++++++--------- .../Wire/MockInterpreters/UserGroupStore.hs | 6 +-- services/brig/brig.cabal | 3 +- services/brig/default.nix | 1 - 8 files changed, 54 insertions(+), 41 deletions(-) diff --git a/integration/test/API/Brig.hs b/integration/test/API/Brig.hs index 9963e05709..3b4d72d3b0 100644 --- a/integration/test/API/Brig.hs +++ b/integration/test/API/Brig.hs @@ -1079,11 +1079,23 @@ data GetUserGroupsArgs = GetUserGroupsArgs lastName :: Maybe String, lastCreatedAt :: Maybe String, lastId :: Maybe String, - includeMemberCount :: Bool + includeMemberCount :: Bool, + includeChannels :: Bool } instance Default GetUserGroupsArgs where - def = GetUserGroupsArgs Nothing Nothing Nothing Nothing Nothing Nothing Nothing False + def = + GetUserGroupsArgs + { q = Nothing, + sortByKeys = Nothing, + sortOrder = Nothing, + pSize = Nothing, + lastName = Nothing, + lastCreatedAt = Nothing, + lastId = Nothing, + includeMemberCount = False, + includeChannels = False + } getUserGroups :: (MakesValue user) => user -> GetUserGroupsArgs -> App Response getUserGroups user GetUserGroupsArgs {..} = do diff --git a/integration/test/Test/UserGroup.hs b/integration/test/Test/UserGroup.hs index 68899407ed..b0f4e52f0f 100644 --- a/integration/test/Test/UserGroup.hs +++ b/integration/test/Test/UserGroup.hs @@ -339,7 +339,8 @@ testUserGroupGetGroupsAllInputs = do lastName = lastName', lastCreatedAt = lastCreatedAt', lastId = lastId', - includeMemberCount = includeMemberCount' + includeMemberCount = includeMemberCount', + includeChannels = includeChannels' } | q' <- qs, sortBy' <- sortByKeysList, @@ -348,7 +349,8 @@ testUserGroupGetGroupsAllInputs = do lastName' <- lastNames, lastCreatedAt' <- lastCreatedAts, lastId' <- lastIds, - includeMemberCount' <- [False, True] + includeMemberCount' <- [False, True], + includeChannels' <- [False, True] ] where qs = [Nothing, Just "A"] diff --git a/libs/wire-api/src/Wire/API/UserGroup.hs b/libs/wire-api/src/Wire/API/UserGroup.hs index 7b36a366e8..f4dfa28d94 100644 --- a/libs/wire-api/src/Wire/API/UserGroup.hs +++ b/libs/wire-api/src/Wire/API/UserGroup.hs @@ -110,7 +110,7 @@ userGroupToMeta ug = { id_ = ug.id_, name = ug.name, members = Const (), - channels = Const (), + channels = ug.channels, membersCount = ug.membersCount, channelsCount = ug.channelsCount, managedBy = ug.managedBy, @@ -121,8 +121,8 @@ data UserGroup_ (f :: Type -> Type) = UserGroup_ { id_ :: UserGroupId, name :: UserGroupName, members :: f (Vector UserId), - channels :: f (Maybe (Vector (Qualified ConvId))), membersCount :: Maybe Int, + channels :: Maybe (Vector (Qualified ConvId)), channelsCount :: Maybe Int, managedBy :: ManagedBy, createdAt :: UTCTimeMillis @@ -150,8 +150,8 @@ instance ToSchema (UserGroup_ (Const ())) where <$> (.id_) .= field "id" schema <*> (.name) .= field "name" schema <*> (.members) .= pure mempty - <*> (.channels) .= pure mempty <*> (.membersCount) .= maybe_ (optField "membersCount" schema) + <*> (.channels) .= maybe_ (optField "channels" (vector schema)) <*> (.channelsCount) .= maybe_ (optField "channelsCount" schema) <*> (.managedBy) .= field "managedBy" schema <*> (.createdAt) .= field "createdAt" schema @@ -177,8 +177,8 @@ instance ToSchema (UserGroup_ Identity) where <$> (.id_) .= field "id" schema <*> (.name) .= field "name" schema <*> (runIdentity . (.members)) .= field "members" (Identity <$> vector schema) - <*> (runIdentity . (.channels)) .= (Identity <$> maybe_ (optField "channels" (vector schema))) <*> (.membersCount) .= maybe_ (optField "membersCount" schema) + <*> (.channels) .= maybe_ (optField "channels" (vector schema)) <*> (.channelsCount) .= maybe_ (optField "channelsCount" schema) <*> (.managedBy) .= field "managedBy" schema <*> (.createdAt) .= field "createdAt" schema diff --git a/libs/wire-api/test/golden/Test/Wire/API/Golden/Manual/UserGroup.hs b/libs/wire-api/test/golden/Test/Wire/API/Golden/Manual/UserGroup.hs index ae40851ffc..4c2f9313df 100644 --- a/libs/wire-api/test/golden/Test/Wire/API/Golden/Manual/UserGroup.hs +++ b/libs/wire-api/test/golden/Test/Wire/API/Golden/Manual/UserGroup.hs @@ -42,9 +42,9 @@ testObject_UserGroup_1 :: UserGroupMeta testObject_UserGroup_1 = UserGroup_ { id_ = userGroupId1, - name = (unsafeToUserGroupName "name"), - members = (Const ()), - channels = (Const ()), + name = unsafeToUserGroupName "name", + members = Const (), + channels = Nothing, membersCount = Nothing, channelsCount = Just 0, managedBy = ManagedByWire, @@ -55,10 +55,10 @@ testObject_UserGroup_2 :: UserGroup testObject_UserGroup_2 = UserGroup_ { id_ = userGroupId2, - name = (unsafeToUserGroupName "yet another one"), - members = (Identity $ fromList [userId1, userId2]), + name = unsafeToUserGroupName "yet another one", + members = Identity $ fromList [userId1, userId2], channels = - Identity . Just . fromList $ + Just . fromList $ [ Qualified (Id (fromJust (UUID.fromString "445c08d2-a16b-49ea-a274-4208bb2efe8f"))) (Domain "example.com") ], membersCount = Nothing, diff --git a/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs b/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs index fad1b1e541..2694516a48 100644 --- a/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs +++ b/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs @@ -98,8 +98,7 @@ getUserGroup team id_ includeChannels = do session = runMaybeT do (name, managedBy, createdAt) <- MaybeT $ statement (id_, team) getGroupMetadataStatement members <- lift $ Identity <$> statement id_ getGroupMembersStatement - -- TODO: add counts - let membersCount = Nothing + let membersCount = Just . V.length $ runIdentity members channelsCount = Nothing channels = mempty pure $ UserGroup_ {..} @@ -108,9 +107,9 @@ getUserGroup team id_ includeChannels = do sessionWithChannels loc = runMaybeT do (name, managedBy, createdAt, memberIds, channelIds) <- MaybeT $ statement (id_, team) getGroupWithMembersAndChannelsStatement let members = Identity (fmap Id memberIds) - membersCount = Just (fromIntegral (V.length memberIds)) - channels = Identity (Just (fmap (tUntagged . qualifyAs loc . Id) channelIds)) - channelsCount = Just (fromIntegral (V.length channelIds)) + membersCount = Just $ V.length memberIds + channels = Just (fmap (tUntagged . qualifyAs loc . Id) channelIds) + channelsCount = Just $ V.length channelIds pure $ UserGroup_ {..} decodeMetadataRow :: (Text, Int32, UTCTime) -> Either Text (UserGroupName, ManagedBy, UTCTimeMillis) @@ -167,49 +166,52 @@ divide5 f a b c d e = divide (\p -> let (v, w, x, y, z) = f p in (v, (w, x, y, z getUserGroups :: forall r. - (UserGroupStorePostgresEffectConstraints r) => + ( UserGroupStorePostgresEffectConstraints r, + Member (Input (Local ())) r + ) => UserGroupPageRequest -> Sem r UserGroupPage getUserGroups req@(UserGroupPageRequest {..}) = do pool <- input + loc <- qualifyLocal () eitherResult <- liftIO $ use pool do TxSessions.transaction TxSessions.ReadCommitted TxSessions.Read do - UserGroupPage <$> getUserGroupsSession <*> getCountSession + UserGroupPage <$> getUserGroupsSession loc <*> getCountSession either throw pure eitherResult where - getUserGroupsSession :: Tx.Transaction [UserGroupMeta] - getUserGroupsSession = case (req.searchString, req.paginationState) of + getUserGroupsSession :: Local a -> Tx.Transaction [UserGroupMeta] + getUserGroupsSession loc = case (req.searchString, req.paginationState) of (Nothing, PaginationSortByName Nothing) -> do let encoder = divide id encodeId encodeInt - stmt = refineResult (mapM parseRow) $ Statement queryBS encoder decodeRow True + stmt = refineResult (mapM $ parseRow loc) $ Statement queryBS encoder decodeRow True Tx.statement (req.team, pageSizeInt) stmt (Nothing, PaginationSortByCreatedAt Nothing) -> do let encoder = divide id encodeId encodeInt - stmt = refineResult (mapM parseRow) $ Statement queryBS encoder decodeRow True + stmt = refineResult (mapM $ parseRow loc) $ Statement queryBS encoder decodeRow True Tx.statement (req.team, pageSizeInt) stmt (Nothing, PaginationSortByName (Just (name, gid))) -> do let encoder = divide4 id encodeId encodeGroupName encodeId encodeInt - stmt = refineResult (mapM parseRow) $ Statement queryBS encoder decodeRow True + stmt = refineResult (mapM $ parseRow loc) $ Statement queryBS encoder decodeRow True Tx.statement (req.team, name, gid, pageSizeInt) stmt (Nothing, PaginationSortByCreatedAt (Just (timestamp, gid))) -> do let encoder = divide4 id encodeId encodeTime encodeId encodeInt - stmt = refineResult (mapM parseRow) $ Statement queryBS encoder decodeRow True + stmt = refineResult (mapM $ parseRow loc) $ Statement queryBS encoder decodeRow True Tx.statement (req.team, timestamp, gid, pageSizeInt) stmt (Just st, PaginationSortByName Nothing) -> do let encoder = divide3 id encodeId encodeText encodeInt - stmt = refineResult (mapM parseRow) $ Statement queryBS encoder decodeRow True + stmt = refineResult (mapM $ parseRow loc) $ Statement queryBS encoder decodeRow True Tx.statement (req.team, fuzzy st, pageSizeInt) stmt (Just st, PaginationSortByCreatedAt Nothing) -> do let encoder = divide3 id encodeId encodeText encodeInt - stmt = refineResult (mapM parseRow) $ Statement queryBS encoder decodeRow True + stmt = refineResult (mapM $ parseRow loc) $ Statement queryBS encoder decodeRow True Tx.statement (req.team, fuzzy st, pageSizeInt) stmt (Just st, PaginationSortByName (Just (name, gid))) -> do let encoder = divide5 id encodeId encodeGroupName encodeId encodeText encodeInt - stmt = refineResult (mapM parseRow) $ Statement queryBS encoder decodeRow True + stmt = refineResult (mapM $ parseRow loc) $ Statement queryBS encoder decodeRow True Tx.statement (req.team, name, gid, fuzzy st, pageSizeInt) stmt (Just st, PaginationSortByCreatedAt (Just (timestamp, gid))) -> do let encoder = divide5 id encodeId encodeTime encodeId encodeText encodeInt - stmt = refineResult (mapM parseRow) $ Statement queryBS encoder decodeRow True + stmt = refineResult (mapM $ parseRow loc) $ Statement queryBS encoder decodeRow True Tx.statement (req.team, timestamp, gid, fuzzy st, pageSizeInt) stmt getCountSession :: Tx.Transaction Int @@ -285,8 +287,8 @@ getUserGroups req@(UserGroupPageRequest {..}) = do ) ) - parseRow :: (UUID, Text, Int32, UTCTime, Maybe Int32, Int32, Maybe (Vector UUID)) -> Either Text UserGroupMeta - parseRow (Id -> id_, namePre, managedByPre, toUTCTimeMillis -> createdAt, membersCountRaw, channelsCountRaw, _maybeChannels) = do + parseRow :: Local a -> (UUID, Text, Int32, UTCTime, Maybe Int32, Int32, Maybe (Vector UUID)) -> Either Text UserGroupMeta + parseRow loc (Id -> id_, namePre, managedByPre, toUTCTimeMillis -> createdAt, membersCountRaw, channelsCountRaw, maybeChannels) = do managedBy <- case managedByPre of 0 -> pure ManagedByWire 1 -> pure ManagedByScim @@ -295,8 +297,7 @@ getUserGroups req@(UserGroupPageRequest {..}) = do let members = Const () membersCount = fromIntegral <$> membersCountRaw channelsCount = Just (fromIntegral channelsCountRaw) - -- TODO: process channels - channels = mempty + channels = fmap (fmap (tUntagged . qualifyAs loc . Id)) maybeChannels pure $ UserGroup_ {..} -- \| Compile a pagination state into select query to return the next page. Result is the diff --git a/libs/wire-subsystems/test/unit/Wire/MockInterpreters/UserGroupStore.hs b/libs/wire-subsystems/test/unit/Wire/MockInterpreters/UserGroupStore.hs index bb82139310..9adee53c98 100644 --- a/libs/wire-subsystems/test/unit/Wire/MockInterpreters/UserGroupStore.hs +++ b/libs/wire-subsystems/test/unit/Wire/MockInterpreters/UserGroupStore.hs @@ -108,7 +108,7 @@ getUserGroupsImpl :: (UserGroupStoreInMemEffectConstraints r) => UserGroupPageRe getUserGroupsImpl UserGroupPageRequest {..} = do let filterChannels ug = if includeChannels - then (ug :: UserGroup) {channels = mempty, channelsCount = Just $ maybe 0 length ug.channels.runIdentity} + then (ug :: UserGroup) {channels = mempty, channelsCount = Just $ maybe 0 length ug.channels} else (ug :: UserGroup) {channels = mempty} meta <- ((snd <$>) . sieve . fmap (_2 %~ userGroupToMeta . filterChannels) . Map.toList) <$> get @UserGroupInMemState pure $ UserGroupPage meta (length meta) @@ -206,7 +206,7 @@ updateUserGroupChannelsImpl gid convIds = do f (Just g) = Just ( g - { channels = Identity $ Just $ tUntagged . qualifyLocal <$> convIds, + { channels = Just $ tUntagged . qualifyLocal <$> convIds, channelsCount = Nothing } :: UserGroup @@ -219,7 +219,7 @@ listUserGroupChannelsImpl :: UserGroupId -> Sem r (Vector ConvId) listUserGroupChannelsImpl gid = - foldMap (fmap qUnqualified) . (runIdentity . (.channels) . snd <=< find ((== gid) . snd . fst) . Map.toList) + foldMap (fmap qUnqualified) . ((.channels) . snd <=< find ((== gid) . snd . fst) . Map.toList) <$> get @(Map (TeamId, UserGroupId) UserGroup) ---------------------------------------------------------------------- diff --git a/services/brig/brig.cabal b/services/brig/brig.cabal index 4d921d809b..67449f7bb0 100644 --- a/services/brig/brig.cabal +++ b/services/brig/brig.cabal @@ -218,7 +218,7 @@ library , amqp , async >=2.1 , auto-update >=0.1 - , base >=4 && <5 + , base >=4 && <5 , base-prelude , base16-bytestring >=0.1 , base64-bytestring >=1.0 @@ -314,7 +314,6 @@ library , uri-bytestring >=0.2 , utf8-string , uuid >=1.3.5 - , vector >=0.13.2.0 , wai >=3.0 , wai-extra >=3.0 , wai-middleware-gunzip >=0.0.2 diff --git a/services/brig/default.nix b/services/brig/default.nix index 5827352d74..550ed4212b 100644 --- a/services/brig/default.nix +++ b/services/brig/default.nix @@ -266,7 +266,6 @@ mkDerivation { uri-bytestring utf8-string uuid - vector wai wai-extra wai-middleware-gunzip From 2d1c43629a2b5546afe882c5abd50b964b1ea1c9 Mon Sep 17 00:00:00 2001 From: Leif Battermann Date: Wed, 8 Oct 2025 08:44:01 +0000 Subject: [PATCH 12/18] test include_channels for user-groups search endpoint --- integration/test/API/Brig.hs | 3 ++- integration/test/Test/UserGroup.hs | 13 ++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/integration/test/API/Brig.hs b/integration/test/API/Brig.hs index 3b4d72d3b0..1cb9d53308 100644 --- a/integration/test/API/Brig.hs +++ b/integration/test/API/Brig.hs @@ -1111,7 +1111,8 @@ getUserGroups user GetUserGroupsArgs {..} = do ("last_seen_name",) <$> lastName, ("last_seen_created_at",) <$> lastCreatedAt, ("last_seen_id",) <$> lastId, - (if includeMemberCount then Just ("include_member_count", "true") else Nothing) + (if includeMemberCount then Just ("include_member_count", "true") else Nothing), + (if includeChannels then Just ("include_channels", "true") else Nothing) ] ) diff --git a/integration/test/Test/UserGroup.hs b/integration/test/Test/UserGroup.hs index b0f4e52f0f..c977d89cf4 100644 --- a/integration/test/Test/UserGroup.hs +++ b/integration/test/Test/UserGroup.hs @@ -446,25 +446,36 @@ testUserGroupUpdateChannelsSucceeds = do ======= >>>>>>> more test cases - -- TODO: also check the user-groups search endpoint reflects the channels updateUserGroupChannels alice gid ((.id_) <$> take 2 convs) >>= assertSuccess bindResponse (getUserGroupWithChannels alice gid) $ \resp -> do resp.status `shouldMatchInt` 200 (resp.json %. "channels" >>= asList >>= traverse objQid) `shouldMatchSet` for (take 2 convs) objQid + bindResponse (getUserGroups alice (def {includeChannels = True})) $ \resp -> do + resp.status `shouldMatchInt` 200 + (resp.json %. "page.0.channels" >>= asList >>= traverse objQid) `shouldMatchSet` for (take 2 convs) objQid + updateUserGroupChannels alice gid ((.id_) <$> drop 1 convs) >>= assertSuccess bindResponse (getUserGroupWithChannels alice gid) $ \resp -> do resp.status `shouldMatchInt` 200 (resp.json %. "channels" >>= asList >>= traverse objQid) `shouldMatchSet` for (drop 1 convs) objQid + bindResponse (getUserGroups alice (def {includeChannels = True})) $ \resp -> do + resp.status `shouldMatchInt` 200 + (resp.json %. "page.0.channels" >>= asList >>= traverse objQid) `shouldMatchSet` for (drop 1 convs) objQid + updateUserGroupChannels alice gid [] >>= assertSuccess bindResponse (getUserGroupWithChannels alice gid) $ \resp -> do resp.status `shouldMatchInt` 200 (resp.json %. "channels" >>= fmap length . asList) `shouldMatchInt` 0 + bindResponse (getUserGroups alice (def {includeChannels = True})) $ \resp -> do + resp.status `shouldMatchInt` 200 + (resp.json %. "page.0.channels" >>= fmap length . asList) `shouldMatchInt` 0 + testUserGroupUpdateChannelsNonAdmin :: (HasCallStack) => App () testUserGroupUpdateChannelsNonAdmin = do (alice, tid, [bob]) <- createTeam OwnDomain 2 From 00cd4033ad700cbfdcb3709d41019609b7b8cbae Mon Sep 17 00:00:00 2001 From: Leif Battermann Date: Wed, 8 Oct 2025 14:11:51 +0000 Subject: [PATCH 13/18] fix rebase error --- integration/test/Test/UserGroup.hs | 38 +++++------------------------- 1 file changed, 6 insertions(+), 32 deletions(-) diff --git a/integration/test/Test/UserGroup.hs b/integration/test/Test/UserGroup.hs index c977d89cf4..4e34207c1c 100644 --- a/integration/test/Test/UserGroup.hs +++ b/integration/test/Test/UserGroup.hs @@ -377,7 +377,6 @@ testUserGroupMembersCount = do resp.json %. "page.0.membersCount" `shouldMatchInt` 2 resp.json %. "total" `shouldMatchInt` 1 -<<<<<<< HEAD testUserGroupRemovalOnDelete :: (HasCallStack) => App () testUserGroupRemovalOnDelete = do (alice, tid, [bob, charlie]) <- createTeam OwnDomain 3 @@ -395,15 +394,8 @@ testUserGroupRemovalOnDelete = do resp.status `shouldMatchInt` 200 resp.json %. "members" `shouldMatch` [charlieId] -testUserGroupUpdateChannels :: (HasCallStack) => App () -testUserGroupUpdateChannels = do -||||||| constructed merge base -testUserGroupUpdateChannels :: (HasCallStack) => App () -testUserGroupUpdateChannels = do -======= testUserGroupUpdateChannelsSucceeds :: (HasCallStack) => App () testUserGroupUpdateChannelsSucceeds = do ->>>>>>> more test cases (alice, tid, [_bob]) <- createTeam OwnDomain 2 setTeamFeatureLockStatus alice tid "channels" "unlocked" let config = @@ -422,31 +414,13 @@ testUserGroupUpdateChannelsSucceeds = do >>= getJSON 200 gid <- ug %. "id" & asString -<<<<<<< HEAD - convId <- - postConversation alice (defMLS {team = Just tid, groupConvType = Just "channel"}) -||||||| constructed merge base - convId <- - postConversation alice (defProteus {team = Just tid}) -======= - convs <- - replicateM 5 - $ postConversation alice (defProteus {team = Just tid}) ->>>>>>> more test cases - >>= getJSON 201 - >>= objConvId -<<<<<<< HEAD + convs <- replicateM 5 $ postConversation alice (defMLS {team = Just tid, groupConvType = Just "channel"}) >>= getJSON 201 >>= objConvId + withWebSocket alice $ \wsAlice -> do - updateUserGroupChannels alice gid [convId.id_] >>= assertSuccess + updateUserGroupChannels alice gid ((.id_) <$> take 2 convs) >>= assertSuccess notif <- awaitMatch isUserGroupUpdatedNotif wsAlice notif %. "payload.0.user_group.id" `shouldMatch` gid -||||||| constructed merge base - updateUserGroupChannels alice gid [convId.id_] >>= assertSuccess -======= ->>>>>>> more test cases - - updateUserGroupChannels alice gid ((.id_) <$> take 2 convs) >>= assertSuccess bindResponse (getUserGroupWithChannels alice gid) $ \resp -> do resp.status `shouldMatchInt` 200 @@ -456,15 +430,15 @@ testUserGroupUpdateChannelsSucceeds = do resp.status `shouldMatchInt` 200 (resp.json %. "page.0.channels" >>= asList >>= traverse objQid) `shouldMatchSet` for (take 2 convs) objQid - updateUserGroupChannels alice gid ((.id_) <$> drop 1 convs) >>= assertSuccess + updateUserGroupChannels alice gid ((.id_) <$> tail convs) >>= assertSuccess bindResponse (getUserGroupWithChannels alice gid) $ \resp -> do resp.status `shouldMatchInt` 200 - (resp.json %. "channels" >>= asList >>= traverse objQid) `shouldMatchSet` for (drop 1 convs) objQid + (resp.json %. "channels" >>= asList >>= traverse objQid) `shouldMatchSet` for (tail convs) objQid bindResponse (getUserGroups alice (def {includeChannels = True})) $ \resp -> do resp.status `shouldMatchInt` 200 - (resp.json %. "page.0.channels" >>= asList >>= traverse objQid) `shouldMatchSet` for (drop 1 convs) objQid + (resp.json %. "page.0.channels" >>= asList >>= traverse objQid) `shouldMatchSet` for (tail convs) objQid updateUserGroupChannels alice gid [] >>= assertSuccess From 5aa333dddb8d1c4f0644078838f47211a1ab0abd Mon Sep 17 00:00:00 2001 From: Leif Battermann Date: Wed, 8 Oct 2025 14:22:52 +0000 Subject: [PATCH 14/18] fix and clean up tests --- .../Wire/MockInterpreters/UserGroupStore.hs | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/libs/wire-subsystems/test/unit/Wire/MockInterpreters/UserGroupStore.hs b/libs/wire-subsystems/test/unit/Wire/MockInterpreters/UserGroupStore.hs index 9adee53c98..ecf966865a 100644 --- a/libs/wire-subsystems/test/unit/Wire/MockInterpreters/UserGroupStore.hs +++ b/libs/wire-subsystems/test/unit/Wire/MockInterpreters/UserGroupStore.hs @@ -97,20 +97,17 @@ createUserGroupImpl tid nug managedBy = do pure ug getUserGroupImpl :: (UserGroupStoreInMemEffectConstraints r) => TeamId -> UserGroupId -> Bool -> Sem r (Maybe UserGroup) -getUserGroupImpl tid gid includeChannels = fmap filterChannels . Map.lookup (tid, gid) <$> get @UserGroupInMemState - where - filterChannels ug = - if includeChannels - then ug - else (ug :: UserGroup) {channels = mempty} +getUserGroupImpl tid gid includeChannels = fmap (filterChannels includeChannels) . Map.lookup (tid, gid) <$> get @UserGroupInMemState + +filterChannels :: Bool -> UserGroup -> UserGroup +filterChannels includeChannels ug = + if includeChannels + then (ug :: UserGroup) {channelsCount = Just $ maybe 0 length ug.channels} + else (ug :: UserGroup) {channels = mempty} getUserGroupsImpl :: (UserGroupStoreInMemEffectConstraints r) => UserGroupPageRequest -> Sem r UserGroupPage getUserGroupsImpl UserGroupPageRequest {..} = do - let filterChannels ug = - if includeChannels - then (ug :: UserGroup) {channels = mempty, channelsCount = Just $ maybe 0 length ug.channels} - else (ug :: UserGroup) {channels = mempty} - meta <- ((snd <$>) . sieve . fmap (_2 %~ userGroupToMeta . filterChannels) . Map.toList) <$> get @UserGroupInMemState + meta <- ((snd <$>) . sieve . fmap (_2 %~ userGroupToMeta . (filterChannels includeChannels)) . Map.toList) <$> get @UserGroupInMemState pure $ UserGroupPage meta (length meta) where sieve, From 12ac8f5f3e09324af83366949cd8c131b4e6d649 Mon Sep 17 00:00:00 2001 From: Gautier DI FOLCO Date: Wed, 8 Oct 2025 17:11:46 +0200 Subject: [PATCH 15/18] refactor: create `Wire.Qualified.Utils` --- .../src/Wire/Qualified/Utils.hs | 28 +++++++++++++++++++ .../src/Wire/UserGroupStore/Postgres.hs | 7 +---- libs/wire-subsystems/wire-subsystems.cabal | 1 + 3 files changed, 30 insertions(+), 6 deletions(-) create mode 100644 libs/wire-subsystems/src/Wire/Qualified/Utils.hs diff --git a/libs/wire-subsystems/src/Wire/Qualified/Utils.hs b/libs/wire-subsystems/src/Wire/Qualified/Utils.hs new file mode 100644 index 0000000000..e5bb878689 --- /dev/null +++ b/libs/wire-subsystems/src/Wire/Qualified/Utils.hs @@ -0,0 +1,28 @@ +-- This file is part of the Wire Server implementation. +-- +-- Copyright (C) 2022 Wire Swiss GmbH +-- +-- This program is free software: you can redistribute it and/or modify it under +-- the terms of the GNU Affero General Public License as published by the Free +-- Software Foundation, either version 3 of the License, or (at your option) any +-- later version. +-- +-- This program is distributed in the hope that it will be useful, but WITHOUT +-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more +-- details. +-- +-- You should have received a copy of the GNU Affero General Public License along +-- with this program. If not, see . + +module Wire.Qualified.Utils where + +import Data.Qualified +import Imports +import Polysemy +import Polysemy.Input + +qualifyLocal :: (Member (Input (Local ())) r) => a -> Sem r (Local a) +qualifyLocal a = do + l <- input + pure $ qualifyAs l a diff --git a/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs b/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs index 2694516a48..2cbb9e7e6a 100644 --- a/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs +++ b/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs @@ -33,6 +33,7 @@ import Wire.API.Pagination import Wire.API.User.Profile import Wire.API.UserGroup hiding (UpdateUserGroupChannels) import Wire.API.UserGroup.Pagination +import Wire.Qualified.Utils import Wire.UserGroupStore (PaginationState (..), UserGroupPageRequest (..), UserGroupStore (..), toSortBy) type UserGroupStorePostgresEffectConstraints r = @@ -75,12 +76,6 @@ updateUsers gid uids = do delete from user_group_member where user_group_id = ($1 :: uuid) |] --- TODO: move to a shared place -qualifyLocal :: (Member (Input (Local ())) r) => a -> Sem r (Local a) -qualifyLocal a = do - l <- input - pure $ qualifyAs l a - getUserGroup :: forall r. (UserGroupStorePostgresEffectConstraints r, Member (Input (Local ())) r) => diff --git a/libs/wire-subsystems/wire-subsystems.cabal b/libs/wire-subsystems/wire-subsystems.cabal index ff514fc848..d035747bff 100644 --- a/libs/wire-subsystems/wire-subsystems.cabal +++ b/libs/wire-subsystems/wire-subsystems.cabal @@ -244,6 +244,7 @@ library Wire.PropertyStore.Cassandra Wire.PropertySubsystem Wire.PropertySubsystem.Interpreter + Wire.Qualified.Utils Wire.RateLimit Wire.RateLimit.Interpreter Wire.Rpc From 895236a63d8bf88a356e486f5d4c23311367eaf7 Mon Sep 17 00:00:00 2001 From: Gautier DI FOLCO Date: Wed, 8 Oct 2025 19:10:16 +0200 Subject: [PATCH 16/18] refactor faor in inputQualifyLocal --- libs/types-common/src/Data/Qualified.hs | 11 +++++ libs/types-common/types-common.cabal | 1 + .../src/Wire/Qualified/Utils.hs | 28 ------------- .../src/Wire/UserGroupStore/Postgres.hs | 7 ++-- libs/wire-subsystems/wire-subsystems.cabal | 1 - services/galley/src/Galley/API/Action.hs | 4 +- services/galley/src/Galley/API/Clients.hs | 4 +- services/galley/src/Galley/API/Federation.hs | 40 +++++++++---------- services/galley/src/Galley/API/Internal.hs | 4 +- services/galley/src/Galley/API/LegalHold.hs | 8 ++-- .../src/Galley/API/LegalHold/Conflicts.hs | 3 +- .../galley/src/Galley/API/MLS/Proposal.hs | 3 +- services/galley/src/Galley/API/Query.hs | 8 ++-- services/galley/src/Galley/API/Teams.hs | 2 +- services/galley/src/Galley/API/Update.hs | 12 +++--- services/galley/src/Galley/API/Util.hs | 6 --- 16 files changed, 58 insertions(+), 84 deletions(-) delete mode 100644 libs/wire-subsystems/src/Wire/Qualified/Utils.hs diff --git a/libs/types-common/src/Data/Qualified.hs b/libs/types-common/src/Data/Qualified.hs index 18407f557f..fc59fd841c 100644 --- a/libs/types-common/src/Data/Qualified.hs +++ b/libs/types-common/src/Data/Qualified.hs @@ -47,6 +47,7 @@ module Data.Qualified deprecatedSchema, qualifiedSchema, qualifiedObjectSchema, + inputQualifyLocal, ) where @@ -61,6 +62,8 @@ import Data.OpenApi (deprecated) import Data.OpenApi qualified as S import Data.Schema import Imports hiding (local) +import Polysemy +import Polysemy.Input import Test.QuickCheck (Arbitrary (arbitrary)) ---------------------------------------------------------------------- @@ -234,3 +237,11 @@ instance S.ToSchema (Qualified Handle) where instance (Arbitrary a) => Arbitrary (Qualified a) where arbitrary = Qualified <$> arbitrary <*> arbitrary + +---------------------------------------------------------------------- +-- Polysemy + +inputQualifyLocal :: (Member (Input (Local ())) r) => a -> Sem r (Local a) +inputQualifyLocal a = do + l <- input @(Local ()) + pure $ qualifyAs l a diff --git a/libs/types-common/types-common.cabal b/libs/types-common/types-common.cabal index 393ee5b74b..096b81173b 100644 --- a/libs/types-common/types-common.cabal +++ b/libs/types-common/types-common.cabal @@ -125,6 +125,7 @@ library , openapi3 , optparse-applicative >=0.10 , pem + , polysemy , protobuf >=0.2 , QuickCheck >=2.9 , quickcheck-instances >=0.3.16 diff --git a/libs/wire-subsystems/src/Wire/Qualified/Utils.hs b/libs/wire-subsystems/src/Wire/Qualified/Utils.hs deleted file mode 100644 index e5bb878689..0000000000 --- a/libs/wire-subsystems/src/Wire/Qualified/Utils.hs +++ /dev/null @@ -1,28 +0,0 @@ --- This file is part of the Wire Server implementation. --- --- Copyright (C) 2022 Wire Swiss GmbH --- --- This program is free software: you can redistribute it and/or modify it under --- the terms of the GNU Affero General Public License as published by the Free --- Software Foundation, either version 3 of the License, or (at your option) any --- later version. --- --- This program is distributed in the hope that it will be useful, but WITHOUT --- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS --- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more --- details. --- --- You should have received a copy of the GNU Affero General Public License along --- with this program. If not, see . - -module Wire.Qualified.Utils where - -import Data.Qualified -import Imports -import Polysemy -import Polysemy.Input - -qualifyLocal :: (Member (Input (Local ())) r) => a -> Sem r (Local a) -qualifyLocal a = do - l <- input - pure $ qualifyAs l a diff --git a/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs b/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs index 2cbb9e7e6a..079e9cc810 100644 --- a/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs +++ b/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs @@ -9,7 +9,7 @@ import Data.Functor.Contravariant.Divisible import Data.Id import Data.Json.Util import Data.Profunctor -import Data.Qualified (Local, QualifiedWithTag (tUntagged), qualifyAs) +import Data.Qualified (Local, QualifiedWithTag (tUntagged), inputQualifyLocal, qualifyAs) import Data.Range import Data.Text qualified as T import Data.Text.Encoding qualified as TE @@ -33,7 +33,6 @@ import Wire.API.Pagination import Wire.API.User.Profile import Wire.API.UserGroup hiding (UpdateUserGroupChannels) import Wire.API.UserGroup.Pagination -import Wire.Qualified.Utils import Wire.UserGroupStore (PaginationState (..), UserGroupPageRequest (..), UserGroupStore (..), toSortBy) type UserGroupStorePostgresEffectConstraints r = @@ -85,7 +84,7 @@ getUserGroup :: Sem r (Maybe UserGroup) getUserGroup team id_ includeChannels = do pool <- input - loc <- qualifyLocal () + loc <- inputQualifyLocal () eitherUserGroup <- liftIO $ use pool (if includeChannels then sessionWithChannels loc else session) either throw pure eitherUserGroup where @@ -168,7 +167,7 @@ getUserGroups :: Sem r UserGroupPage getUserGroups req@(UserGroupPageRequest {..}) = do pool <- input - loc <- qualifyLocal () + loc <- inputQualifyLocal () eitherResult <- liftIO $ use pool do TxSessions.transaction TxSessions.ReadCommitted TxSessions.Read do UserGroupPage <$> getUserGroupsSession loc <*> getCountSession diff --git a/libs/wire-subsystems/wire-subsystems.cabal b/libs/wire-subsystems/wire-subsystems.cabal index d035747bff..ff514fc848 100644 --- a/libs/wire-subsystems/wire-subsystems.cabal +++ b/libs/wire-subsystems/wire-subsystems.cabal @@ -244,7 +244,6 @@ library Wire.PropertyStore.Cassandra Wire.PropertySubsystem Wire.PropertySubsystem.Interpreter - Wire.Qualified.Utils Wire.RateLimit Wire.RateLimit.Interpreter Wire.Rpc diff --git a/services/galley/src/Galley/API/Action.hs b/services/galley/src/Galley/API/Action.hs index b13b5ba0c3..bfefed936c 100644 --- a/services/galley/src/Galley/API/Action.hs +++ b/services/galley/src/Galley/API/Action.hs @@ -961,7 +961,7 @@ updateLocalStateOfRemoteConv :: Maybe ConnId -> Sem r (Maybe Event) updateLocalStateOfRemoteConv rcu con = do - loc <- qualifyLocal () + loc <- inputQualifyLocal () let cu = tUnqualified rcu rconvId = fmap (.convId) rcu qconvId = tUntagged rconvId @@ -1087,7 +1087,7 @@ notifyTypingIndicator :: Sem r TypingDataUpdated notifyTypingIndicator conv qusr mcon ts = do now <- Now.get - lconv <- qualifyLocal conv.id_ + lconv <- inputQualifyLocal conv.id_ let origDomain = qDomain qusr (remoteMemsOrig, remoteMemsOther) = List.partition (\m -> origDomain == tDomain m.id_) conv.remoteMembers localMembers = fmap (.id_) (tryRemoveSelfFromLocalUsers lconv conv.localMembers) diff --git a/services/galley/src/Galley/API/Clients.hs b/services/galley/src/Galley/API/Clients.hs index e39776d2b8..99c93c71c7 100644 --- a/services/galley/src/Galley/API/Clients.hs +++ b/services/galley/src/Galley/API/Clients.hs @@ -90,7 +90,7 @@ rmClient usr cid = do clients <- E.getClients [usr] if (cid `elem` clientIds usr clients) then do - lusr <- qualifyLocal usr + lusr <- inputQualifyLocal usr let nRange1000 = toRange (Proxy @1000) :: Range 1 1000 Int32 firstConvIds <- Query.conversationIdsPageFrom lusr (GetPaginatedConversationIds Nothing nRange1000) goConvs nRange1000 firstConvIds lusr @@ -108,7 +108,7 @@ rmClient usr cid = do for_ localConvs $ \convId -> do mConv <- getConversation convId for_ mConv $ \conv -> do - lconv <- qualifyLocal conv + lconv <- inputQualifyLocal conv removeClient lconv (tUntagged lusr) cid traverse_ removeRemoteMLSClients (rangedChunks remoteConvs) when (mtpHasMore page) $ do diff --git a/services/galley/src/Galley/API/Federation.hs b/services/galley/src/Galley/API/Federation.hs index 8f162bfdee..470ba08a41 100644 --- a/services/galley/src/Galley/API/Federation.hs +++ b/services/galley/src/Galley/API/Federation.hs @@ -154,7 +154,7 @@ onClientRemoved domain req = do for_ req.convs $ \convId -> do mConv <- E.getConversation convId for mConv $ \conv -> do - lconv <- qualifyLocal conv + lconv <- inputQualifyLocal conv removeClient lconv qusr (req.client) pure EmptyResponse @@ -171,7 +171,7 @@ onConversationCreated :: Sem r EmptyResponse onConversationCreated domain rc = do let qrc = fmap (toRemoteUnsafe domain) rc - loc <- qualifyLocal () + loc <- inputQualifyLocal () let (localUserIds, _) = partitionQualified loc (map omQualifiedId (toList (nonCreatorMembers rc))) addedUserIds <- @@ -223,7 +223,7 @@ getConversations :: Sem r GetConversationsResponseV2 getConversations domain (GetConversationsRequest uid cids) = do let ruid = toRemoteUnsafe domain uid - loc <- qualifyLocal () + loc <- inputQualifyLocal () GetConversationsResponseV2 . mapMaybe (Mapping.conversationToRemote (tDomain loc) ruid) <$> E.getConversations cids @@ -282,7 +282,7 @@ leaveConversation :: Sem r LeaveConversationResponse leaveConversation requestingDomain lc = do let leaver = Qualified lc.leaver requestingDomain - lcnv <- qualifyLocal lc.convId + lcnv <- inputQualifyLocal lc.convId res <- runError @@ -371,7 +371,7 @@ onMessageSent domain rmUnqualified = do \ users not in the conversation" :: ByteString ) - loc <- qualifyLocal () + loc <- inputQualifyLocal () void $ sendLocalMessages loc @@ -406,7 +406,7 @@ sendMessage :: sendMessage originDomain msr = do let sender = Qualified msr.sender originDomain msg <- either throwErr pure (fromProto (fromBase64ByteString msr.rawMessage)) - lcnv <- qualifyLocal msr.convId + lcnv <- inputQualifyLocal msr.convId MessageSendResponse <$> postQualifiedOtrMessage User sender Nothing lcnv msg where throwErr = throw . InvalidPayload . LT.pack @@ -435,7 +435,7 @@ onUserDeleted origDomain udcn = do E.spawnMany $ fromRange convIds <&> \c -> do - lc <- qualifyLocal c + lc <- inputQualifyLocal c mconv <- E.getConversation c E.deleteMembers c (UserList [] [deletedUser]) for_ mconv $ \conv -> do @@ -498,7 +498,7 @@ updateConversation :: ConversationUpdateRequest -> Sem r ConversationUpdateResponse updateConversation origDomain updateRequest = do - loc <- qualifyLocal () + loc <- inputQualifyLocal () let rusr = toRemoteUnsafe origDomain updateRequest.user lcnv = qualifyAs loc updateRequest.convId @@ -627,7 +627,7 @@ sendMLSCommitBundle :: Sem r MLSMessageResponse sendMLSCommitBundle remoteDomain msr = handleMLSMessageErrors $ do assertMLSEnabled - loc <- qualifyLocal () + loc <- inputQualifyLocal () let sender = toRemoteUnsafe remoteDomain msr.sender bundle <- either (throw . mlsProtocolError) pure $ @@ -680,7 +680,7 @@ sendMLSMessage :: Sem r MLSMessageResponse sendMLSMessage remoteDomain msr = handleMLSMessageErrors $ do assertMLSEnabled - loc <- qualifyLocal () + loc <- inputQualifyLocal () let sender = toRemoteUnsafe remoteDomain msr.sender raw <- either (throw . mlsProtocolError) pure $ decodeMLS' (fromBase64ByteString msr.rawMessage) msg <- noteS @'MLSUnsupportedMessage $ mkIncomingMessage raw @@ -710,7 +710,7 @@ getSubConversationForRemoteUser domain GetSubConversationsRequest {..} = . mapToGalleyError @MLSGetSubConvStaticErrors $ do let qusr = Qualified gsreqUser domain - lconv <- qualifyLocal gsreqConv + lconv <- inputQualifyLocal gsreqConv getLocalSubConversation qusr lconv gsreqSubConv leaveSubConversation :: @@ -726,7 +726,7 @@ leaveSubConversation :: leaveSubConversation domain lscr = do let rusr = toRemoteUnsafe domain (lscrUser lscr) cid = mkClientIdentity (tUntagged rusr) (lscrClient lscr) - lcnv <- qualifyLocal (lscrConv lscr) + lcnv <- inputQualifyLocal (lscrConv lscr) fmap (either (LeaveSubConversationResponseProtocolError . unTagged) Imports.id) . runError @MLSProtocolError . fmap (either LeaveSubConversationResponseError Imports.id) @@ -755,7 +755,7 @@ deleteSubConversationForRemoteUser domain DeleteSubConversationFedRequest {..} = $ do let qusr = Qualified dscreqUser domain dsc = MLSReset dscreqGroupId dscreqEpoch - lconv <- qualifyLocal dscreqConv + lconv <- inputQualifyLocal dscreqConv resetLocalSubConversation qusr lconv dscreqSubConv dsc getOne2OneConversationV1 :: @@ -770,7 +770,7 @@ getOne2OneConversationV1 domain (GetOne2OneConversationRequest self other) = fmap (Imports.fromRight GetOne2OneConversationNotConnected) . runError @(Tagged 'NotConnected ()) $ do - lother <- qualifyLocal other + lother <- inputQualifyLocal other let rself = toRemoteUnsafe domain self ensureConnectedToRemotes lother [rself] foldQualified @@ -795,7 +795,7 @@ getOne2OneConversation domain (GetOne2OneConversationRequest self other) = . fmap (Imports.fromRight GetOne2OneConversationV2NotConnected) . runError @(Tagged 'NotConnected ()) $ do - lother <- qualifyLocal other + lother <- inputQualifyLocal other let rself = toRemoteUnsafe domain self let getLocal lconv = do mconv <- E.getConversation (tUnqualified lconv) @@ -859,7 +859,7 @@ onMLSMessageSent domain rmm = . runError @(Tagged 'MLSNotEnabled ()) $ do assertMLSEnabled - loc <- qualifyLocal () + loc <- inputQualifyLocal () let rcnv = toRemoteUnsafe domain rmm.conversation let users = Map.keys rmm.recipients (members, allMembers) <- @@ -913,7 +913,7 @@ mlsSendWelcome origDomain req = do . runError @(Tagged 'MLSNotEnabled ()) $ do assertMLSEnabled - loc <- qualifyLocal () + loc <- inputQualifyLocal () now <- Now.get welcome <- either (throw . InternalErrorWithDescription . LT.fromStrict) pure $ @@ -937,10 +937,10 @@ queryGroupInfo origDomain req = let sender = toRemoteUnsafe origDomain . (.sender) $ req state <- case req.conv of Conv convId -> do - lconvId <- qualifyLocal convId + lconvId <- inputQualifyLocal convId getGroupInfoFromLocalConv (tUntagged sender) lconvId SubConv convId subConvId -> do - lconvId <- qualifyLocal convId + lconvId <- inputQualifyLocal convId getSubConversationGroupInfoFromLocalConv (tUntagged sender) subConvId lconvId pure . Base64ByteString @@ -960,7 +960,7 @@ updateTypingIndicator :: Sem r TypingDataUpdateResponse updateTypingIndicator origDomain TypingDataUpdateRequest {..} = do let qusr = Qualified userId origDomain - lcnv <- qualifyLocal convId + lcnv <- inputQualifyLocal convId ret <- runError . mapToRuntimeError @'ConvNotFound ConvNotFound diff --git a/services/galley/src/Galley/API/Internal.hs b/services/galley/src/Galley/API/Internal.hs index b4096b7787..2b2fba906b 100644 --- a/services/galley/src/Galley/API/Internal.hs +++ b/services/galley/src/Galley/API/Internal.hs @@ -145,7 +145,7 @@ ejpdGetConvInfo :: UserId -> Sem r [EJPDConvInfo] ejpdGetConvInfo uid = do - luid <- qualifyLocal uid + luid <- inputQualifyLocal uid firstPage <- Query.conversationIdsPageFrom luid initialPageRequest getPages luid firstPage where @@ -478,7 +478,7 @@ deleteLoop = do liftIO $ threadDelay 1000000 doDelete usr con tid = do - lusr <- qualifyLocal usr + lusr <- inputQualifyLocal usr Teams.uncheckedDeleteTeam lusr con tid safeForever :: String -> App () -> App () diff --git a/services/galley/src/Galley/API/LegalHold.hs b/services/galley/src/Galley/API/LegalHold.hs index 97c16b8369..11545331aa 100644 --- a/services/galley/src/Galley/API/LegalHold.hs +++ b/services/galley/src/Galley/API/LegalHold.hs @@ -288,7 +288,7 @@ removeSettings' tid = spawnMany (map removeLHForUser lhMembers) removeLHForUser :: TeamMember -> Sem r () removeLHForUser member = do - luid <- qualifyLocal (member ^. userId) + luid <- inputQualifyLocal (member ^. userId) removeLegalHoldClientFromUser (tUnqualified luid) LHService.removeLegalHold tid luid changeLegalholdStatusAndHandlePolicyConflicts tid luid (member ^. legalHoldStatus) UserLegalHoldDisabled -- (support for withdrawing consent is not planned yet.) @@ -375,7 +375,7 @@ requestDevice :: Sem r RequestDeviceResult requestDevice lzusr tid uid = do let zusr = tUnqualified lzusr - luid <- qualifyLocal uid + luid <- inputQualifyLocal uid assertLegalHoldEnabledForTeam tid P.debug $ Log.field "targets" (toByteString (tUnqualified luid)) @@ -470,7 +470,7 @@ approveDevice :: Sem r () approveDevice lzusr connId tid uid (Public.ApproveLegalHoldForUserRequest mPassword) = do let zusr = tUnqualified lzusr - luid <- qualifyLocal uid + luid <- inputQualifyLocal uid assertLegalHoldEnabledForTeam tid P.debug $ Log.field "targets" (toByteString (tUnqualified luid)) @@ -544,7 +544,7 @@ disableForUser :: Public.DisableLegalHoldForUserRequest -> Sem r DisableLegalHoldForUserResponse disableForUser lzusr tid uid (Public.DisableLegalHoldForUserRequest mPassword) = do - luid <- qualifyLocal uid + luid <- inputQualifyLocal uid P.debug $ Log.field "targets" (toByteString (tUnqualified luid)) . Log.field "action" (Log.val "LegalHold.disableForUser") diff --git a/services/galley/src/Galley/API/LegalHold/Conflicts.hs b/services/galley/src/Galley/API/LegalHold/Conflicts.hs index c8a1ed9cec..e6c3123c3c 100644 --- a/services/galley/src/Galley/API/LegalHold/Conflicts.hs +++ b/services/galley/src/Galley/API/LegalHold/Conflicts.hs @@ -32,7 +32,6 @@ import Data.Map qualified as Map import Data.Misc import Data.Qualified import Data.Set qualified as Set -import Galley.API.Util import Galley.Effects import Galley.Effects.TeamStore import Galley.Options @@ -66,7 +65,7 @@ guardQualifiedLegalholdPolicyConflicts :: QualifiedUserClients -> Sem r () guardQualifiedLegalholdPolicyConflicts protectee qclients = do - localDomain <- tDomain <$> qualifyLocal () + localDomain <- tDomain <$> inputQualifyLocal () guardLegalholdPolicyConflicts protectee . UserClients . Map.findWithDefault mempty localDomain diff --git a/services/galley/src/Galley/API/MLS/Proposal.hs b/services/galley/src/Galley/API/MLS/Proposal.hs index edd9beb581..154202b271 100644 --- a/services/galley/src/Galley/API/MLS/Proposal.hs +++ b/services/galley/src/Galley/API/MLS/Proposal.hs @@ -40,7 +40,6 @@ import Data.Qualified import Data.Set qualified as Set import Galley.API.Error import Galley.API.MLS.IncomingMessage -import Galley.API.Util import Galley.Effects import Galley.Effects.ProposalStore import Galley.Env @@ -293,7 +292,7 @@ checkExternalProposalUser :: Proposal -> Sem r () checkExternalProposalUser qusr prop = do - loc <- qualifyLocal () + loc <- inputQualifyLocal () foldQualified loc ( \lusr -> case prop of diff --git a/services/galley/src/Galley/API/Query.hs b/services/galley/src/Galley/API/Query.hs index 0c699e8d6b..4a18eb5d55 100644 --- a/services/galley/src/Galley/API/Query.hs +++ b/services/galley/src/Galley/API/Query.hs @@ -127,8 +127,8 @@ getBotConversation :: ConvId -> Sem r Public.BotConvView getBotConversation zbot cnv = do - lcnv <- qualifyLocal cnv - botQuid <- tUntagged <$> qualifyLocal (botUserId zbot) + lcnv <- inputQualifyLocal cnv + botQuid <- tUntagged <$> inputQualifyLocal (botUserId zbot) c <- maskConvAccessDenied $ getConversationAsMember botQuid lcnv let domain = tDomain lcnv cmems = mapMaybe (mkMember domain) (toList c.localMembers) @@ -253,7 +253,7 @@ getLocalConversationInternal :: ConvId -> Sem r Conversation getLocalConversationInternal cid = do - lcid <- qualifyLocal cid + lcid <- inputQualifyLocal cid conv <- getConversationWithError lcid pure $ conversationView (qualifyAs lcid ()) Nothing conv @@ -666,7 +666,7 @@ internalGetMember :: UserId -> Sem r (Maybe Public.Member) internalGetMember qcnv usr = do - lusr <- qualifyLocal usr + lusr <- inputQualifyLocal usr lcnv <- ensureLocal lusr qcnv getLocalSelf lusr (tUnqualified lcnv) diff --git a/services/galley/src/Galley/API/Teams.hs b/services/galley/src/Galley/API/Teams.hs index 652c6e1b3b..2dfbb63dbb 100644 --- a/services/galley/src/Galley/API/Teams.hs +++ b/services/galley/src/Galley/API/Teams.hs @@ -1283,7 +1283,7 @@ userIsTeamOwner :: UserId -> Sem r () userIsTeamOwner tid uid = do - asking <- qualifyLocal uid + asking <- inputQualifyLocal uid mem <- getTeamMember asking tid uid unless (isTeamOwner mem) $ throwS @'AccessDenied diff --git a/services/galley/src/Galley/API/Update.hs b/services/galley/src/Galley/API/Update.hs index 196eb857fe..d7e51730f4 100644 --- a/services/galley/src/Galley/API/Update.hs +++ b/services/galley/src/Galley/API/Update.hs @@ -550,8 +550,8 @@ addCodeUnqualified :: ConvId -> Sem r AddCodeResult addCodeUnqualified mReq usr mbZHost mZcon cnv = do - lusr <- qualifyLocal usr - lcnv <- qualifyLocal cnv + lusr <- inputQualifyLocal usr + lcnv <- inputQualifyLocal cnv addCode lusr mbZHost mZcon lcnv mReq addCode :: @@ -626,7 +626,7 @@ rmCodeUnqualified :: ConvId -> Sem r Event rmCodeUnqualified lusr zcon cnv = do - lcnv <- qualifyLocal cnv + lcnv <- inputQualifyLocal cnv rmCode lusr zcon lcnv rmCode :: @@ -1475,8 +1475,8 @@ postBotMessageUnqualified :: NewOtrMessage -> Sem r (PostOtrResponse ClientMismatch) postBotMessageUnqualified sender cnv ignoreMissing reportMissing message = do - lusr <- qualifyLocal (botUserId sender) - lcnv <- qualifyLocal cnv + lusr <- inputQualifyLocal (botUserId sender) + lcnv <- inputQualifyLocal cnv unqualifyEndpoint lusr (runLocalInput lusr . postQualifiedOtrMessage Bot (tUntagged lusr) Nothing lcnv) @@ -1660,7 +1660,7 @@ memberTypingUnqualified :: TypingStatus -> Sem r () memberTypingUnqualified lusr zcon cnv ts = do - lcnv <- qualifyLocal cnv + lcnv <- inputQualifyLocal cnv memberTyping lusr zcon (tUntagged lcnv) ts addBot :: diff --git a/services/galley/src/Galley/API/Util.hs b/services/galley/src/Galley/API/Util.hs index 500c92596c..55130da99c 100644 --- a/services/galley/src/Galley/API/Util.hs +++ b/services/galley/src/Galley/API/Util.hs @@ -833,12 +833,6 @@ ensureLocal loc = foldQualified loc pure (\_ -> throw FederationNotImplemented) -------------------------------------------------------------------------------- -- Federation -qualifyLocal :: (Member (Input (Local ())) r) => a -> Sem r (Local a) -qualifyLocal a = toLocalUnsafe <$> fmap getDomain input <*> pure a - where - getDomain :: Local () -> Domain - getDomain = tDomain - runLocalInput :: Local x -> Sem (Input (Local ()) ': r) a -> Sem r a runLocalInput = runInputConst . void From 8b6d57df33d007261803a95a95a752dc2f393955 Mon Sep 17 00:00:00 2001 From: Gautier DI FOLCO Date: Wed, 8 Oct 2025 19:25:34 +0200 Subject: [PATCH 17/18] fix: missing nix --- libs/types-common/default.nix | 2 ++ 1 file changed, 2 insertions(+) diff --git a/libs/types-common/default.nix b/libs/types-common/default.nix index 0cafcabbfd..63392d3321 100644 --- a/libs/types-common/default.nix +++ b/libs/types-common/default.nix @@ -35,6 +35,7 @@ , openapi3 , optparse-applicative , pem +, polysemy , protobuf , QuickCheck , quickcheck-instances @@ -93,6 +94,7 @@ mkDerivation { openapi3 optparse-applicative pem + polysemy protobuf QuickCheck quickcheck-instances From a0fdb3a5fb28a11edf648fea580bb76b09d93a5b Mon Sep 17 00:00:00 2001 From: Gautier DI FOLCO Date: Wed, 8 Oct 2025 20:42:10 +0200 Subject: [PATCH 18/18] fix: inputQualifyLocal --- services/galley/src/Galley/API/Action.hs | 4 +- services/galley/src/Galley/API/Clients.hs | 4 +- services/galley/src/Galley/API/Federation.hs | 40 +++++++++---------- services/galley/src/Galley/API/Internal.hs | 4 +- services/galley/src/Galley/API/LegalHold.hs | 8 ++-- .../src/Galley/API/LegalHold/Conflicts.hs | 3 +- .../galley/src/Galley/API/MLS/Proposal.hs | 3 +- services/galley/src/Galley/API/Query.hs | 8 ++-- services/galley/src/Galley/API/Teams.hs | 2 +- services/galley/src/Galley/API/Update.hs | 12 +++--- services/galley/src/Galley/API/Util.hs | 6 +++ 11 files changed, 51 insertions(+), 43 deletions(-) diff --git a/services/galley/src/Galley/API/Action.hs b/services/galley/src/Galley/API/Action.hs index bfefed936c..b13b5ba0c3 100644 --- a/services/galley/src/Galley/API/Action.hs +++ b/services/galley/src/Galley/API/Action.hs @@ -961,7 +961,7 @@ updateLocalStateOfRemoteConv :: Maybe ConnId -> Sem r (Maybe Event) updateLocalStateOfRemoteConv rcu con = do - loc <- inputQualifyLocal () + loc <- qualifyLocal () let cu = tUnqualified rcu rconvId = fmap (.convId) rcu qconvId = tUntagged rconvId @@ -1087,7 +1087,7 @@ notifyTypingIndicator :: Sem r TypingDataUpdated notifyTypingIndicator conv qusr mcon ts = do now <- Now.get - lconv <- inputQualifyLocal conv.id_ + lconv <- qualifyLocal conv.id_ let origDomain = qDomain qusr (remoteMemsOrig, remoteMemsOther) = List.partition (\m -> origDomain == tDomain m.id_) conv.remoteMembers localMembers = fmap (.id_) (tryRemoveSelfFromLocalUsers lconv conv.localMembers) diff --git a/services/galley/src/Galley/API/Clients.hs b/services/galley/src/Galley/API/Clients.hs index 99c93c71c7..e39776d2b8 100644 --- a/services/galley/src/Galley/API/Clients.hs +++ b/services/galley/src/Galley/API/Clients.hs @@ -90,7 +90,7 @@ rmClient usr cid = do clients <- E.getClients [usr] if (cid `elem` clientIds usr clients) then do - lusr <- inputQualifyLocal usr + lusr <- qualifyLocal usr let nRange1000 = toRange (Proxy @1000) :: Range 1 1000 Int32 firstConvIds <- Query.conversationIdsPageFrom lusr (GetPaginatedConversationIds Nothing nRange1000) goConvs nRange1000 firstConvIds lusr @@ -108,7 +108,7 @@ rmClient usr cid = do for_ localConvs $ \convId -> do mConv <- getConversation convId for_ mConv $ \conv -> do - lconv <- inputQualifyLocal conv + lconv <- qualifyLocal conv removeClient lconv (tUntagged lusr) cid traverse_ removeRemoteMLSClients (rangedChunks remoteConvs) when (mtpHasMore page) $ do diff --git a/services/galley/src/Galley/API/Federation.hs b/services/galley/src/Galley/API/Federation.hs index 470ba08a41..8f162bfdee 100644 --- a/services/galley/src/Galley/API/Federation.hs +++ b/services/galley/src/Galley/API/Federation.hs @@ -154,7 +154,7 @@ onClientRemoved domain req = do for_ req.convs $ \convId -> do mConv <- E.getConversation convId for mConv $ \conv -> do - lconv <- inputQualifyLocal conv + lconv <- qualifyLocal conv removeClient lconv qusr (req.client) pure EmptyResponse @@ -171,7 +171,7 @@ onConversationCreated :: Sem r EmptyResponse onConversationCreated domain rc = do let qrc = fmap (toRemoteUnsafe domain) rc - loc <- inputQualifyLocal () + loc <- qualifyLocal () let (localUserIds, _) = partitionQualified loc (map omQualifiedId (toList (nonCreatorMembers rc))) addedUserIds <- @@ -223,7 +223,7 @@ getConversations :: Sem r GetConversationsResponseV2 getConversations domain (GetConversationsRequest uid cids) = do let ruid = toRemoteUnsafe domain uid - loc <- inputQualifyLocal () + loc <- qualifyLocal () GetConversationsResponseV2 . mapMaybe (Mapping.conversationToRemote (tDomain loc) ruid) <$> E.getConversations cids @@ -282,7 +282,7 @@ leaveConversation :: Sem r LeaveConversationResponse leaveConversation requestingDomain lc = do let leaver = Qualified lc.leaver requestingDomain - lcnv <- inputQualifyLocal lc.convId + lcnv <- qualifyLocal lc.convId res <- runError @@ -371,7 +371,7 @@ onMessageSent domain rmUnqualified = do \ users not in the conversation" :: ByteString ) - loc <- inputQualifyLocal () + loc <- qualifyLocal () void $ sendLocalMessages loc @@ -406,7 +406,7 @@ sendMessage :: sendMessage originDomain msr = do let sender = Qualified msr.sender originDomain msg <- either throwErr pure (fromProto (fromBase64ByteString msr.rawMessage)) - lcnv <- inputQualifyLocal msr.convId + lcnv <- qualifyLocal msr.convId MessageSendResponse <$> postQualifiedOtrMessage User sender Nothing lcnv msg where throwErr = throw . InvalidPayload . LT.pack @@ -435,7 +435,7 @@ onUserDeleted origDomain udcn = do E.spawnMany $ fromRange convIds <&> \c -> do - lc <- inputQualifyLocal c + lc <- qualifyLocal c mconv <- E.getConversation c E.deleteMembers c (UserList [] [deletedUser]) for_ mconv $ \conv -> do @@ -498,7 +498,7 @@ updateConversation :: ConversationUpdateRequest -> Sem r ConversationUpdateResponse updateConversation origDomain updateRequest = do - loc <- inputQualifyLocal () + loc <- qualifyLocal () let rusr = toRemoteUnsafe origDomain updateRequest.user lcnv = qualifyAs loc updateRequest.convId @@ -627,7 +627,7 @@ sendMLSCommitBundle :: Sem r MLSMessageResponse sendMLSCommitBundle remoteDomain msr = handleMLSMessageErrors $ do assertMLSEnabled - loc <- inputQualifyLocal () + loc <- qualifyLocal () let sender = toRemoteUnsafe remoteDomain msr.sender bundle <- either (throw . mlsProtocolError) pure $ @@ -680,7 +680,7 @@ sendMLSMessage :: Sem r MLSMessageResponse sendMLSMessage remoteDomain msr = handleMLSMessageErrors $ do assertMLSEnabled - loc <- inputQualifyLocal () + loc <- qualifyLocal () let sender = toRemoteUnsafe remoteDomain msr.sender raw <- either (throw . mlsProtocolError) pure $ decodeMLS' (fromBase64ByteString msr.rawMessage) msg <- noteS @'MLSUnsupportedMessage $ mkIncomingMessage raw @@ -710,7 +710,7 @@ getSubConversationForRemoteUser domain GetSubConversationsRequest {..} = . mapToGalleyError @MLSGetSubConvStaticErrors $ do let qusr = Qualified gsreqUser domain - lconv <- inputQualifyLocal gsreqConv + lconv <- qualifyLocal gsreqConv getLocalSubConversation qusr lconv gsreqSubConv leaveSubConversation :: @@ -726,7 +726,7 @@ leaveSubConversation :: leaveSubConversation domain lscr = do let rusr = toRemoteUnsafe domain (lscrUser lscr) cid = mkClientIdentity (tUntagged rusr) (lscrClient lscr) - lcnv <- inputQualifyLocal (lscrConv lscr) + lcnv <- qualifyLocal (lscrConv lscr) fmap (either (LeaveSubConversationResponseProtocolError . unTagged) Imports.id) . runError @MLSProtocolError . fmap (either LeaveSubConversationResponseError Imports.id) @@ -755,7 +755,7 @@ deleteSubConversationForRemoteUser domain DeleteSubConversationFedRequest {..} = $ do let qusr = Qualified dscreqUser domain dsc = MLSReset dscreqGroupId dscreqEpoch - lconv <- inputQualifyLocal dscreqConv + lconv <- qualifyLocal dscreqConv resetLocalSubConversation qusr lconv dscreqSubConv dsc getOne2OneConversationV1 :: @@ -770,7 +770,7 @@ getOne2OneConversationV1 domain (GetOne2OneConversationRequest self other) = fmap (Imports.fromRight GetOne2OneConversationNotConnected) . runError @(Tagged 'NotConnected ()) $ do - lother <- inputQualifyLocal other + lother <- qualifyLocal other let rself = toRemoteUnsafe domain self ensureConnectedToRemotes lother [rself] foldQualified @@ -795,7 +795,7 @@ getOne2OneConversation domain (GetOne2OneConversationRequest self other) = . fmap (Imports.fromRight GetOne2OneConversationV2NotConnected) . runError @(Tagged 'NotConnected ()) $ do - lother <- inputQualifyLocal other + lother <- qualifyLocal other let rself = toRemoteUnsafe domain self let getLocal lconv = do mconv <- E.getConversation (tUnqualified lconv) @@ -859,7 +859,7 @@ onMLSMessageSent domain rmm = . runError @(Tagged 'MLSNotEnabled ()) $ do assertMLSEnabled - loc <- inputQualifyLocal () + loc <- qualifyLocal () let rcnv = toRemoteUnsafe domain rmm.conversation let users = Map.keys rmm.recipients (members, allMembers) <- @@ -913,7 +913,7 @@ mlsSendWelcome origDomain req = do . runError @(Tagged 'MLSNotEnabled ()) $ do assertMLSEnabled - loc <- inputQualifyLocal () + loc <- qualifyLocal () now <- Now.get welcome <- either (throw . InternalErrorWithDescription . LT.fromStrict) pure $ @@ -937,10 +937,10 @@ queryGroupInfo origDomain req = let sender = toRemoteUnsafe origDomain . (.sender) $ req state <- case req.conv of Conv convId -> do - lconvId <- inputQualifyLocal convId + lconvId <- qualifyLocal convId getGroupInfoFromLocalConv (tUntagged sender) lconvId SubConv convId subConvId -> do - lconvId <- inputQualifyLocal convId + lconvId <- qualifyLocal convId getSubConversationGroupInfoFromLocalConv (tUntagged sender) subConvId lconvId pure . Base64ByteString @@ -960,7 +960,7 @@ updateTypingIndicator :: Sem r TypingDataUpdateResponse updateTypingIndicator origDomain TypingDataUpdateRequest {..} = do let qusr = Qualified userId origDomain - lcnv <- inputQualifyLocal convId + lcnv <- qualifyLocal convId ret <- runError . mapToRuntimeError @'ConvNotFound ConvNotFound diff --git a/services/galley/src/Galley/API/Internal.hs b/services/galley/src/Galley/API/Internal.hs index 2b2fba906b..b4096b7787 100644 --- a/services/galley/src/Galley/API/Internal.hs +++ b/services/galley/src/Galley/API/Internal.hs @@ -145,7 +145,7 @@ ejpdGetConvInfo :: UserId -> Sem r [EJPDConvInfo] ejpdGetConvInfo uid = do - luid <- inputQualifyLocal uid + luid <- qualifyLocal uid firstPage <- Query.conversationIdsPageFrom luid initialPageRequest getPages luid firstPage where @@ -478,7 +478,7 @@ deleteLoop = do liftIO $ threadDelay 1000000 doDelete usr con tid = do - lusr <- inputQualifyLocal usr + lusr <- qualifyLocal usr Teams.uncheckedDeleteTeam lusr con tid safeForever :: String -> App () -> App () diff --git a/services/galley/src/Galley/API/LegalHold.hs b/services/galley/src/Galley/API/LegalHold.hs index 11545331aa..97c16b8369 100644 --- a/services/galley/src/Galley/API/LegalHold.hs +++ b/services/galley/src/Galley/API/LegalHold.hs @@ -288,7 +288,7 @@ removeSettings' tid = spawnMany (map removeLHForUser lhMembers) removeLHForUser :: TeamMember -> Sem r () removeLHForUser member = do - luid <- inputQualifyLocal (member ^. userId) + luid <- qualifyLocal (member ^. userId) removeLegalHoldClientFromUser (tUnqualified luid) LHService.removeLegalHold tid luid changeLegalholdStatusAndHandlePolicyConflicts tid luid (member ^. legalHoldStatus) UserLegalHoldDisabled -- (support for withdrawing consent is not planned yet.) @@ -375,7 +375,7 @@ requestDevice :: Sem r RequestDeviceResult requestDevice lzusr tid uid = do let zusr = tUnqualified lzusr - luid <- inputQualifyLocal uid + luid <- qualifyLocal uid assertLegalHoldEnabledForTeam tid P.debug $ Log.field "targets" (toByteString (tUnqualified luid)) @@ -470,7 +470,7 @@ approveDevice :: Sem r () approveDevice lzusr connId tid uid (Public.ApproveLegalHoldForUserRequest mPassword) = do let zusr = tUnqualified lzusr - luid <- inputQualifyLocal uid + luid <- qualifyLocal uid assertLegalHoldEnabledForTeam tid P.debug $ Log.field "targets" (toByteString (tUnqualified luid)) @@ -544,7 +544,7 @@ disableForUser :: Public.DisableLegalHoldForUserRequest -> Sem r DisableLegalHoldForUserResponse disableForUser lzusr tid uid (Public.DisableLegalHoldForUserRequest mPassword) = do - luid <- inputQualifyLocal uid + luid <- qualifyLocal uid P.debug $ Log.field "targets" (toByteString (tUnqualified luid)) . Log.field "action" (Log.val "LegalHold.disableForUser") diff --git a/services/galley/src/Galley/API/LegalHold/Conflicts.hs b/services/galley/src/Galley/API/LegalHold/Conflicts.hs index e6c3123c3c..c8a1ed9cec 100644 --- a/services/galley/src/Galley/API/LegalHold/Conflicts.hs +++ b/services/galley/src/Galley/API/LegalHold/Conflicts.hs @@ -32,6 +32,7 @@ import Data.Map qualified as Map import Data.Misc import Data.Qualified import Data.Set qualified as Set +import Galley.API.Util import Galley.Effects import Galley.Effects.TeamStore import Galley.Options @@ -65,7 +66,7 @@ guardQualifiedLegalholdPolicyConflicts :: QualifiedUserClients -> Sem r () guardQualifiedLegalholdPolicyConflicts protectee qclients = do - localDomain <- tDomain <$> inputQualifyLocal () + localDomain <- tDomain <$> qualifyLocal () guardLegalholdPolicyConflicts protectee . UserClients . Map.findWithDefault mempty localDomain diff --git a/services/galley/src/Galley/API/MLS/Proposal.hs b/services/galley/src/Galley/API/MLS/Proposal.hs index 154202b271..edd9beb581 100644 --- a/services/galley/src/Galley/API/MLS/Proposal.hs +++ b/services/galley/src/Galley/API/MLS/Proposal.hs @@ -40,6 +40,7 @@ import Data.Qualified import Data.Set qualified as Set import Galley.API.Error import Galley.API.MLS.IncomingMessage +import Galley.API.Util import Galley.Effects import Galley.Effects.ProposalStore import Galley.Env @@ -292,7 +293,7 @@ checkExternalProposalUser :: Proposal -> Sem r () checkExternalProposalUser qusr prop = do - loc <- inputQualifyLocal () + loc <- qualifyLocal () foldQualified loc ( \lusr -> case prop of diff --git a/services/galley/src/Galley/API/Query.hs b/services/galley/src/Galley/API/Query.hs index 4a18eb5d55..0c699e8d6b 100644 --- a/services/galley/src/Galley/API/Query.hs +++ b/services/galley/src/Galley/API/Query.hs @@ -127,8 +127,8 @@ getBotConversation :: ConvId -> Sem r Public.BotConvView getBotConversation zbot cnv = do - lcnv <- inputQualifyLocal cnv - botQuid <- tUntagged <$> inputQualifyLocal (botUserId zbot) + lcnv <- qualifyLocal cnv + botQuid <- tUntagged <$> qualifyLocal (botUserId zbot) c <- maskConvAccessDenied $ getConversationAsMember botQuid lcnv let domain = tDomain lcnv cmems = mapMaybe (mkMember domain) (toList c.localMembers) @@ -253,7 +253,7 @@ getLocalConversationInternal :: ConvId -> Sem r Conversation getLocalConversationInternal cid = do - lcid <- inputQualifyLocal cid + lcid <- qualifyLocal cid conv <- getConversationWithError lcid pure $ conversationView (qualifyAs lcid ()) Nothing conv @@ -666,7 +666,7 @@ internalGetMember :: UserId -> Sem r (Maybe Public.Member) internalGetMember qcnv usr = do - lusr <- inputQualifyLocal usr + lusr <- qualifyLocal usr lcnv <- ensureLocal lusr qcnv getLocalSelf lusr (tUnqualified lcnv) diff --git a/services/galley/src/Galley/API/Teams.hs b/services/galley/src/Galley/API/Teams.hs index 2dfbb63dbb..652c6e1b3b 100644 --- a/services/galley/src/Galley/API/Teams.hs +++ b/services/galley/src/Galley/API/Teams.hs @@ -1283,7 +1283,7 @@ userIsTeamOwner :: UserId -> Sem r () userIsTeamOwner tid uid = do - asking <- inputQualifyLocal uid + asking <- qualifyLocal uid mem <- getTeamMember asking tid uid unless (isTeamOwner mem) $ throwS @'AccessDenied diff --git a/services/galley/src/Galley/API/Update.hs b/services/galley/src/Galley/API/Update.hs index d7e51730f4..196eb857fe 100644 --- a/services/galley/src/Galley/API/Update.hs +++ b/services/galley/src/Galley/API/Update.hs @@ -550,8 +550,8 @@ addCodeUnqualified :: ConvId -> Sem r AddCodeResult addCodeUnqualified mReq usr mbZHost mZcon cnv = do - lusr <- inputQualifyLocal usr - lcnv <- inputQualifyLocal cnv + lusr <- qualifyLocal usr + lcnv <- qualifyLocal cnv addCode lusr mbZHost mZcon lcnv mReq addCode :: @@ -626,7 +626,7 @@ rmCodeUnqualified :: ConvId -> Sem r Event rmCodeUnqualified lusr zcon cnv = do - lcnv <- inputQualifyLocal cnv + lcnv <- qualifyLocal cnv rmCode lusr zcon lcnv rmCode :: @@ -1475,8 +1475,8 @@ postBotMessageUnqualified :: NewOtrMessage -> Sem r (PostOtrResponse ClientMismatch) postBotMessageUnqualified sender cnv ignoreMissing reportMissing message = do - lusr <- inputQualifyLocal (botUserId sender) - lcnv <- inputQualifyLocal cnv + lusr <- qualifyLocal (botUserId sender) + lcnv <- qualifyLocal cnv unqualifyEndpoint lusr (runLocalInput lusr . postQualifiedOtrMessage Bot (tUntagged lusr) Nothing lcnv) @@ -1660,7 +1660,7 @@ memberTypingUnqualified :: TypingStatus -> Sem r () memberTypingUnqualified lusr zcon cnv ts = do - lcnv <- inputQualifyLocal cnv + lcnv <- qualifyLocal cnv memberTyping lusr zcon (tUntagged lcnv) ts addBot :: diff --git a/services/galley/src/Galley/API/Util.hs b/services/galley/src/Galley/API/Util.hs index 55130da99c..500c92596c 100644 --- a/services/galley/src/Galley/API/Util.hs +++ b/services/galley/src/Galley/API/Util.hs @@ -833,6 +833,12 @@ ensureLocal loc = foldQualified loc pure (\_ -> throw FederationNotImplemented) -------------------------------------------------------------------------------- -- Federation +qualifyLocal :: (Member (Input (Local ())) r) => a -> Sem r (Local a) +qualifyLocal a = toLocalUnsafe <$> fmap getDomain input <*> pure a + where + getDomain :: Local () -> Domain + getDomain = tDomain + runLocalInput :: Local x -> Sem (Input (Local ()) ': r) a -> Sem r a runLocalInput = runInputConst . void