From 978ee7e3d19c69852c34f84000cd5d7845d1ef04 Mon Sep 17 00:00:00 2001 From: Yutong Li Date: Tue, 29 Oct 2024 07:12:52 +0000 Subject: [PATCH] Ensure all tensors in GaussianRasterizer's raster_settings are on the same device - Added _sanitizeSettings method in GaussianRasterizer class - Check if all tensors in raster_settings are on the same device - Raise an exception if tensors are on different devices - Raise an exception if any tensor is on the CPU, requiring the use of a CUDA device --- diff_gaussian_rasterization/__init__.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/diff_gaussian_rasterization/__init__.py b/diff_gaussian_rasterization/__init__.py index bbef37d1..a13a6f18 100644 --- a/diff_gaussian_rasterization/__init__.py +++ b/diff_gaussian_rasterization/__init__.py @@ -172,6 +172,15 @@ class GaussianRasterizer(nn.Module): def __init__(self, raster_settings): super().__init__() self.raster_settings = raster_settings + self._sanitizeSettings() + + def _sanitizeSettings(self): + # Check if all tensors in raster_settings is on same device + compute_devices = [_x.device for _x in self.raster_settings if isinstance(_x, torch.Tensor)] + if len(set(compute_devices)) > 1: + raise Exception('All tensors in raster_settings should be on the same device.') + if any([d == torch.device('cpu') for d in compute_devices]): + raise Exception('CUDA device is required for the rasterizer.') def markVisible(self, positions): # Mark visible points (based on frustum culling for camera) with a boolean