@@ -221,7 +221,7 @@ def __init__(self, spatial_sigma, color_sigma):
221221 self .len_spatial_sigma = 3
222222 else :
223223 raise ValueError (
224- f"len(spatial_sigma) { spatial_sigma } must match number of spatial dims { self . len_spatial_sigma } ."
224+ f"len(spatial_sigma) { spatial_sigma } must match number of spatial dims (1, 2 or 3) ."
225225 )
226226
227227 # Register sigmas as trainable parameters.
@@ -394,7 +394,7 @@ def __init__(self, spatial_sigma, color_sigma):
394394 self .len_spatial_sigma = 3
395395 else :
396396 raise ValueError (
397- f"len(spatial_sigma) { spatial_sigma } must match number of spatial dims { self . len_spatial_sigma } ."
397+ f"len(spatial_sigma) { spatial_sigma } must match number of spatial dims (1, 2, or 3) ."
398398 )
399399
400400 # Register sigmas as trainable parameters.
@@ -404,9 +404,13 @@ def __init__(self, spatial_sigma, color_sigma):
404404 self .sigma_color = torch .nn .Parameter (torch .tensor (color_sigma ))
405405
406406 def forward (self , input_tensor , guidance_tensor ):
407+ if len (input_tensor .shape ) < 3 :
408+ raise ValueError (
409+ f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got { len (input_tensor .shape )} "
410+ )
407411 if input_tensor .shape [1 ] != 1 :
408412 raise ValueError (
409- f"Currently channel dimensions >1 ({ input_tensor .shape [1 ]} ) are not supported. "
413+ f"Currently channel dimensions > 1 ({ input_tensor .shape [1 ]} ) are not supported. "
410414 "Please use multiple parallel filter layers if you want "
411415 "to filter multiple channels."
412416 )
@@ -417,26 +421,27 @@ def forward(self, input_tensor, guidance_tensor):
417421 )
418422
419423 len_input = len (input_tensor .shape )
424+ spatial_dims = len_input - 2
420425
421426 # C++ extension so far only supports 5-dim inputs.
422- if len_input == 3 :
427+ if spatial_dims == 1 :
423428 input_tensor = input_tensor .unsqueeze (3 ).unsqueeze (4 )
424429 guidance_tensor = guidance_tensor .unsqueeze (3 ).unsqueeze (4 )
425- elif len_input == 4 :
430+ elif spatial_dims == 2 :
426431 input_tensor = input_tensor .unsqueeze (4 )
427432 guidance_tensor = guidance_tensor .unsqueeze (4 )
428433
429- if self .len_spatial_sigma != len_input :
430- raise ValueError (f"Spatial dimension ({ len_input } ) must match initialized len(spatial_sigma)." )
434+ if self .len_spatial_sigma != spatial_dims :
435+ raise ValueError (f"Spatial dimension ({ spatial_dims } ) must match initialized len(spatial_sigma)." )
431436
432437 prediction = TrainableJointBilateralFilterFunction .apply (
433438 input_tensor , guidance_tensor , self .sigma_x , self .sigma_y , self .sigma_z , self .sigma_color
434439 )
435440
436441 # Make sure to return tensor of the same shape as the input.
437- if len_input == 3 :
442+ if spatial_dims == 1 :
438443 prediction = prediction .squeeze (4 ).squeeze (3 )
439- elif len_input == 4 :
444+ elif spatial_dims == 2 :
440445 prediction = prediction .squeeze (4 )
441446
442447 return prediction
0 commit comments