@@ -39,9 +39,14 @@ import (
3939//
4040// If returned walker is nil, then there are no more rooms left to traverse. This method does not modify the provided walker, so it
4141// can be cached.
42- func (querier * Queryer ) QueryNextRoomHierarchyPage (ctx context.Context , walker roomserver.RoomHierarchyWalker , limit int ) ([]fclient.RoomHierarchyRoom , * roomserver.RoomHierarchyWalker , error ) {
43- if authorised , _ := authorised (ctx , querier , walker .Caller , walker .RootRoomID , nil ); ! authorised {
44- return nil , nil , roomserver.ErrRoomUnknownOrNotAllowed {Err : fmt .Errorf ("room is unknown/forbidden" )}
42+ func (querier * Queryer ) QueryNextRoomHierarchyPage (ctx context.Context , walker roomserver.RoomHierarchyWalker , limit int ) (
43+ []fclient.RoomHierarchyRoom ,
44+ []string ,
45+ * roomserver.RoomHierarchyWalker ,
46+ error ,
47+ ) {
48+ if authorised , _ , _ := authorised (ctx , querier , walker .Caller , walker .RootRoomID , nil ); ! authorised {
49+ return nil , []string {walker .RootRoomID .String ()}, nil , roomserver.ErrRoomUnknownOrNotAllowed {Err : fmt .Errorf ("room is unknown/forbidden" )}
4550 }
4651
4752 discoveredRooms := []fclient.RoomHierarchyRoom {}
@@ -50,6 +55,7 @@ func (querier *Queryer) QueryNextRoomHierarchyPage(ctx context.Context, walker r
5055 unvisited := make ([]roomserver.RoomHierarchyWalkerQueuedRoom , len (walker .Unvisited ))
5156 copy (unvisited , walker .Unvisited )
5257 processed := walker .Processed .Copy ()
58+ inaccessible := []string {}
5359
5460 // Depth first -> stack data structure
5561 for len (unvisited ) > 0 {
@@ -108,7 +114,7 @@ func (querier *Queryer) QueryNextRoomHierarchyPage(ctx context.Context, walker r
108114 // as these children may be rooms we do know about.
109115 roomType = spec .MSpace
110116 }
111- } else if authorised , isJoinedOrInvited := authorised (ctx , querier , walker .Caller , queuedRoom .RoomID , queuedRoom .ParentRoomID ); authorised {
117+ } else if authorised , isJoinedOrInvited , allowedRoomIDs := authorised (ctx , querier , walker .Caller , queuedRoom .RoomID , queuedRoom .ParentRoomID ); authorised {
112118 // Get all `m.space.child` state events for this room
113119 events , err := childReferences (ctx , querier , walker .SuggestedOnly , queuedRoom .RoomID )
114120 if err != nil {
@@ -125,14 +131,18 @@ func (querier *Queryer) QueryNextRoomHierarchyPage(ctx context.Context, walker r
125131 }
126132
127133 discoveredRooms = append (discoveredRooms , fclient.RoomHierarchyRoom {
128- PublicRoom : * pubRoom ,
129- RoomType : roomType ,
130- ChildrenState : events ,
134+ PublicRoom : * pubRoom ,
135+ RoomType : roomType ,
136+ ChildrenState : events ,
137+ AllowedRoomIDs : allowedRoomIDs ,
131138 })
132139 // don't walk children if the user is not joined/invited to the space
133140 if ! isJoinedOrInvited {
134141 continue
135142 }
143+ } else if ! authorised {
144+ inaccessible = append (inaccessible , queuedRoom .RoomID .String ())
145+ continue
136146 } else {
137147 // room exists but user is not authorised
138148 continue
@@ -149,6 +159,7 @@ func (querier *Queryer) QueryNextRoomHierarchyPage(ctx context.Context, walker r
149159 // We need to invert the order here because the child events are lo->hi on the timestamp,
150160 // so we need to ensure we pop in the same lo->hi order, which won't be the case if we
151161 // insert the highest timestamp last in a stack.
162+ extendQueueLoop:
152163 for i := len (discoveredChildEvents ) - 1 ; i >= 0 ; i -- {
153164 spaceContent := struct {
154165 Via []string `json:"via"`
@@ -161,6 +172,12 @@ func (querier *Queryer) QueryNextRoomHierarchyPage(ctx context.Context, walker r
161172 if err != nil {
162173 util .GetLogger (ctx ).WithError (err ).WithField ("invalid_room_id" , ev .StateKey ).WithField ("parent_room_id" , queuedRoom .RoomID ).Warn ("Invalid room ID in m.space.child state event" )
163174 } else {
175+ // Make sure not to queue inaccessible rooms
176+ for _ , inaccessibleRoomID := range inaccessible {
177+ if inaccessibleRoomID == childRoomID .String () {
178+ continue extendQueueLoop
179+ }
180+ }
164181 unvisited = append (unvisited , roomserver.RoomHierarchyWalkerQueuedRoom {
165182 RoomID : * childRoomID ,
166183 ParentRoomID : & queuedRoom .RoomID ,
@@ -173,7 +190,7 @@ func (querier *Queryer) QueryNextRoomHierarchyPage(ctx context.Context, walker r
173190
174191 if len (unvisited ) == 0 {
175192 // If no more rooms to walk, then don't return a walker for future pages
176- return discoveredRooms , nil , nil
193+ return discoveredRooms , inaccessible , nil , nil
177194 } else {
178195 // If there are more rooms to walk, then return a new walker to resume walking from (for querying more pages)
179196 newWalker := roomserver.RoomHierarchyWalker {
@@ -185,22 +202,25 @@ func (querier *Queryer) QueryNextRoomHierarchyPage(ctx context.Context, walker r
185202 Processed : processed ,
186203 }
187204
188- return discoveredRooms , & newWalker , nil
205+ return discoveredRooms , inaccessible , & newWalker , nil
189206 }
190207
191208}
192209
193210// authorised returns true iff the user is joined this room or the room is world_readable
194- func authorised (ctx context.Context , querier * Queryer , caller types.DeviceOrServerName , roomID spec.RoomID , parentRoomID * spec.RoomID ) (authed , isJoinedOrInvited bool ) {
211+ func authorised (ctx context.Context , querier * Queryer , caller types.DeviceOrServerName , roomID spec.RoomID , parentRoomID * spec.RoomID ) (authed , isJoinedOrInvited bool , resultAllowedRoomIDs [] string ) {
195212 if clientCaller := caller .Device (); clientCaller != nil {
196213 return authorisedUser (ctx , querier , clientCaller , roomID , parentRoomID )
197- } else {
198- return authorisedServer (ctx , querier , roomID , * caller .ServerName ()), false
199214 }
215+ if serverCaller := caller .ServerName (); serverCaller != nil {
216+ authed , resultAllowedRoomIDs = authorisedServer (ctx , querier , roomID , * serverCaller )
217+ return authed , false , resultAllowedRoomIDs
218+ }
219+ return false , false , resultAllowedRoomIDs
200220}
201221
202222// authorisedServer returns true iff the server is joined this room or the room is world_readable, public, or knockable
203- func authorisedServer (ctx context.Context , querier * Queryer , roomID spec.RoomID , callerServerName spec.ServerName ) bool {
223+ func authorisedServer (ctx context.Context , querier * Queryer , roomID spec.RoomID , callerServerName spec.ServerName ) ( bool , [] string ) {
204224 // Check history visibility / join rules first
205225 hisVisTuple := gomatrixserverlib.StateKeyTuple {
206226 EventType : spec .MRoomHistoryVisibility ,
@@ -219,13 +239,13 @@ func authorisedServer(ctx context.Context, querier *Queryer, roomID spec.RoomID,
219239 }, & queryRoomRes )
220240 if err != nil {
221241 util .GetLogger (ctx ).WithError (err ).Error ("failed to QueryCurrentState" )
222- return false
242+ return false , [] string {}
223243 }
224244 hisVisEv := queryRoomRes .StateEvents [hisVisTuple ]
225245 if hisVisEv != nil {
226246 hisVis , _ := hisVisEv .HistoryVisibility ()
227247 if hisVis == "world_readable" {
228- return true
248+ return true , [] string {}
229249 }
230250 }
231251
@@ -238,19 +258,23 @@ func authorisedServer(ctx context.Context, querier *Queryer, roomID spec.RoomID,
238258 rule , ruleErr := joinRuleEv .JoinRule ()
239259 if ruleErr != nil {
240260 util .GetLogger (ctx ).WithError (ruleErr ).WithField ("parent_room_id" , roomID ).Warn ("failed to get join rule" )
241- return false
261+ return false , [] string {}
242262 }
243263
244264 if rule == spec .Public || rule == spec .Knock {
245- return true
265+ return true , [] string {}
246266 }
247267
248- if rule == spec .Restricted {
268+ if rule == spec .Restricted || rule == spec . KnockRestricted {
249269 allowJoinedToRoomIDs = append (allowJoinedToRoomIDs , restrictedJoinRuleAllowedRooms (ctx , joinRuleEv )... )
250270 }
251271 }
252272
253273 // check if server is joined to any allowed room
274+ resultAllowedRoomIDs := make ([]string , 0 , len (allowJoinedToRoomIDs ))
275+ for _ , allowedRoomID := range allowJoinedToRoomIDs {
276+ resultAllowedRoomIDs = append (resultAllowedRoomIDs , allowedRoomID .String ())
277+ }
254278 for _ , allowedRoomID := range allowJoinedToRoomIDs {
255279 var queryRes fs.QueryJoinedHostServerNamesInRoomResponse
256280 err = querier .FSAPI .QueryJoinedHostServerNamesInRoom (ctx , & fs.QueryJoinedHostServerNamesInRoomRequest {
@@ -262,18 +286,18 @@ func authorisedServer(ctx context.Context, querier *Queryer, roomID spec.RoomID,
262286 }
263287 for _ , srv := range queryRes .ServerNames {
264288 if srv == callerServerName {
265- return true
289+ return true , resultAllowedRoomIDs [ 1 :]
266290 }
267291 }
268292 }
269293
270- return false
294+ return false , resultAllowedRoomIDs [ 1 :]
271295}
272296
273297// authorisedUser returns true iff the user is invited/joined this room or the room is world_readable
274298// or if the room has a public or knock join rule.
275299// Failing that, if the room has a restricted join rule and belongs to the space parent listed, it will return true.
276- func authorisedUser (ctx context.Context , querier * Queryer , clientCaller * userapi.Device , roomID spec.RoomID , parentRoomID * spec.RoomID ) (authed bool , isJoinedOrInvited bool ) {
300+ func authorisedUser (ctx context.Context , querier * Queryer , clientCaller * userapi.Device , roomID spec.RoomID , parentRoomID * spec.RoomID ) (authed bool , isJoinedOrInvited bool , resultAllowedRoomIDs [] string ) {
277301 hisVisTuple := gomatrixserverlib.StateKeyTuple {
278302 EventType : spec .MRoomHistoryVisibility ,
279303 StateKey : "" ,
@@ -295,20 +319,20 @@ func authorisedUser(ctx context.Context, querier *Queryer, clientCaller *userapi
295319 }, & queryRes )
296320 if err != nil {
297321 util .GetLogger (ctx ).WithError (err ).Error ("failed to QueryCurrentState" )
298- return false , false
322+ return false , false , resultAllowedRoomIDs
299323 }
300324 memberEv := queryRes .StateEvents [roomMemberTuple ]
301325 if memberEv != nil {
302326 membership , _ := memberEv .Membership ()
303327 if membership == spec .Join || membership == spec .Invite {
304- return true , true
328+ return true , true , resultAllowedRoomIDs
305329 }
306330 }
307331 hisVisEv := queryRes .StateEvents [hisVisTuple ]
308332 if hisVisEv != nil {
309333 hisVis , _ := hisVisEv .HistoryVisibility ()
310334 if hisVis == "world_readable" {
311- return true , false
335+ return true , false , resultAllowedRoomIDs
312336 }
313337 }
314338 joinRuleEv := queryRes .StateEvents [joinRuleTuple ]
@@ -323,6 +347,7 @@ func authorisedUser(ctx context.Context, querier *Queryer, clientCaller *userapi
323347 allowedRoomIDs := restrictedJoinRuleAllowedRooms (ctx , joinRuleEv )
324348 // check parent is in the allowed set
325349 for _ , a := range allowedRoomIDs {
350+ resultAllowedRoomIDs = append (resultAllowedRoomIDs , a .String ())
326351 if * parentRoomID == a {
327352 allowed = true
328353 break
@@ -345,13 +370,13 @@ func authorisedUser(ctx context.Context, querier *Queryer, clientCaller *userapi
345370 if memberEv != nil {
346371 membership , _ := memberEv .Membership ()
347372 if membership == spec .Join {
348- return true , false
373+ return true , false , resultAllowedRoomIDs
349374 }
350375 }
351376 }
352377 }
353378 }
354- return false , false
379+ return false , false , resultAllowedRoomIDs
355380}
356381
357382// helper function to fetch a state event
0 commit comments