@@ -326,7 +326,7 @@ def test_simple_sphere_batched(self):
326326 )
327327 self .assertClose (rgb , image_ref )
328328
329- def test_compositor_background_color (self ):
329+ def test_compositor_background_color_rgba (self ):
330330
331331 N , H , W , K , C , P = 1 , 15 , 15 , 20 , 4 , 225
332332 ptclds = torch .randn ((C , P ))
@@ -357,7 +357,7 @@ def test_compositor_background_color(self):
357357 torch .masked_select (images , is_foreground [:, None ]),
358358 )
359359
360- is_background = ~ is_foreground [..., None ].expand (- 1 , - 1 , - 1 , 4 )
360+ is_background = ~ is_foreground [..., None ].expand (- 1 , - 1 , - 1 , C )
361361
362362 # permute masked_images to correctly get rgb values
363363 masked_images = masked_images .permute (0 , 2 , 3 , 1 )
@@ -367,12 +367,58 @@ def test_compositor_background_color(self):
367367 # check if background colors are properly changed
368368 self .assertTrue (
369369 masked_images [is_background ]
370- .view (- 1 , 4 )[..., i ]
370+ .view (- 1 , C )[..., i ]
371371 .eq (channel_color )
372372 .all ()
373373 )
374374
375375 # check background color alpha values
376376 self .assertTrue (
377- masked_images [is_background ].view (- 1 , 4 )[..., 3 ].eq (1 ).all ()
377+ masked_images [is_background ].view (- 1 , C )[..., 3 ].eq (1 ).all ()
378378 )
379+
380+ def test_compositor_background_color_rgb (self ):
381+
382+ N , H , W , K , C , P = 1 , 15 , 15 , 20 , 3 , 225
383+ ptclds = torch .randn ((C , P ))
384+ alphas = torch .rand ((N , K , H , W ))
385+ pix_idxs = torch .randint (- 1 , 20 , (N , K , H , W )) # 20 < P, large amount of -1
386+ background_color = [0.5 , 0 , 1 ]
387+
388+ compositor_funcs = [
389+ (NormWeightedCompositor , norm_weighted_sum ),
390+ (AlphaCompositor , alpha_composite ),
391+ ]
392+
393+ for (compositor_class , composite_func ) in compositor_funcs :
394+
395+ compositor = compositor_class (background_color )
396+
397+ # run the forward method to generate masked images
398+ masked_images = compositor .forward (pix_idxs , alphas , ptclds )
399+
400+ # generate unmasked images for testing purposes
401+ images = composite_func (pix_idxs , alphas , ptclds )
402+
403+ is_foreground = pix_idxs [:, 0 ] >= 0
404+
405+ # make sure foreground values are unchanged
406+ self .assertClose (
407+ torch .masked_select (masked_images , is_foreground [:, None ]),
408+ torch .masked_select (images , is_foreground [:, None ]),
409+ )
410+
411+ is_background = ~ is_foreground [..., None ].expand (- 1 , - 1 , - 1 , C )
412+
413+ # permute masked_images to correctly get rgb values
414+ masked_images = masked_images .permute (0 , 2 , 3 , 1 )
415+ for i in range (3 ):
416+ channel_color = background_color [i ]
417+
418+ # check if background colors are properly changed
419+ self .assertTrue (
420+ masked_images [is_background ]
421+ .view (- 1 , C )[..., i ]
422+ .eq (channel_color )
423+ .all ()
424+ )
0 commit comments