diff --git a/diff_gaussian_rasterization/__init__.py b/diff_gaussian_rasterization/__init__.py index bbef37d1..a13a6f18 100644 --- a/diff_gaussian_rasterization/__init__.py +++ b/diff_gaussian_rasterization/__init__.py @@ -172,6 +172,15 @@ class GaussianRasterizer(nn.Module): def __init__(self, raster_settings): super().__init__() self.raster_settings = raster_settings + self._sanitizeSettings() + + def _sanitizeSettings(self): + # Check if all tensors in raster_settings is on same device + compute_devices = [_x.device for _x in self.raster_settings if isinstance(_x, torch.Tensor)] + if len(set(compute_devices)) > 1: + raise Exception('All tensors in raster_settings should be on the same device.') + if any([d == torch.device('cpu') for d in compute_devices]): + raise Exception('CUDA device is required for the rasterizer.') def markVisible(self, positions): # Mark visible points (based on frustum culling for camera) with a boolean