diff --git a/olm_lut.py b/olm_lut.py index 70b5a13..45cb45a 100644 --- a/olm_lut.py +++ b/olm_lut.py @@ -2,7 +2,7 @@ import os import torch import numpy as np -from scipy.interpolate import RegularGridInterpolator +import torch.nn.functional as F LUT_FOLDER = os.path.join(os.path.dirname(__file__), "luts") @@ -11,6 +11,10 @@ def __init__(self): pass class OlmLUT: + def __init__(self): + # Cache loaded LUTs to avoid reloading + self.lut_cache = {} + @classmethod def INPUT_TYPES(cls): return { @@ -34,12 +38,17 @@ def INPUT_TYPES(cls): } RETURN_TYPES = ("IMAGE",) - FUNCTION = "apply_lut" - CATEGORY = "LUT" def load_cube_lut(self, filename, debug_logging): + """Load LUT and cache it as a GPU tensor""" + # Check cache first + if filename in self.lut_cache: + if debug_logging: + print(f"Using cached LUT: {filename}") + return self.lut_cache[filename] + size = None lut_data = [] @@ -70,17 +79,24 @@ def load_cube_lut(self, filename, debug_logging): if len(lut_data) != size**3: raise ValueError(f"Expected {size**3} LUT entries, found {len(lut_data)} in file: {path}") + # Convert to torch tensor and reshape lut_array = np.array(lut_data, dtype=np.float32) lut_array = lut_array.reshape(size, size, size, 3) lut_array = np.transpose(lut_array, (2, 1, 0, 3)) + + # Convert to torch and move to GPU - store as (1, 3, D, H, W) for grid_sample + lut_tensor = torch.from_numpy(lut_array).permute(3, 2, 1, 0).unsqueeze(0) # [1, 3, D, H, W] if debug_logging: - print(f"LUT loaded: shape {lut_array.shape}, range [{lut_array.min():.3f}, {lut_array.max():.3f}]") + print(f"LUT loaded: shape {lut_tensor.shape}, range [{lut_tensor.min():.3f}, {lut_tensor.max():.3f}]") print(f"First few LUT entries:") for i in range(min(3, size)): - print(f" LUT[{i},0,0] = [{lut_array[i,0,0,0]:.3f}, {lut_array[i,0,0,1]:.3f}, {lut_array[i,0,0,2]:.3f}]") + r, g, b = lut_array[i, 0, 0] + print(f" LUT[{i},0,0] = [{r:.3f}, {g:.3f}, {b:.3f}]") - return lut_array + # Cache the tensor + self.lut_cache[filename] = lut_tensor + return lut_tensor @classmethod def get_lut_files(cls): @@ -93,190 +109,191 @@ def get_lut_files(cls): files = [f for f in os.listdir(folder) if f.lower().endswith(".cube")] return files if files else ["No LUTs found"] - def srgb_to_linear(self, img): - return np.where(img <= 0.04045, - img / 12.92, - np.power(np.clip((img + 0.055) / 1.055, 0, 1), 2.4)) - - def linear_to_srgb(self, img): - return np.where(img <= 0.0031308, - img * 12.92, - 1.055 * np.power(np.clip(img, 0, 1), 1.0/2.4) - 0.055) + def srgb_to_linear_gpu(self, img): + """GPU-accelerated sRGB to linear conversion""" + return torch.where( + img <= 0.04045, + img / 12.92, + torch.pow(torch.clamp((img + 0.055) / 1.055, 0, 1), 2.4) + ) + + def linear_to_srgb_gpu(self, img): + """GPU-accelerated linear to sRGB conversion""" + return torch.where( + img <= 0.0031308, + img * 12.92, + 1.055 * torch.pow(torch.clamp(img, 0, 1), 1.0/2.4) - 0.055 + ) + + def apply_lut_gpu(self, image, lut_tensor, device): + """ + GPU-accelerated 3D LUT interpolation using grid_sample + + Args: + image: [B, H, W, 3] tensor + lut_tensor: [1, 3, D, H, W] LUT tensor + device: torch device + """ + B, H, W, C = image.shape + lut_size = lut_tensor.shape[2] + + # Move LUT to device if needed + if lut_tensor.device != device: + lut_tensor = lut_tensor.to(device) + + # Reshape image for grid_sample: [B, H*W, 1, 1, 3] + img_flat = image.reshape(B, H * W, 1, 1, 3) + + # Scale RGB values from [0,1] to grid_sample coords [-1, 1] + # grid_sample expects (x, y, z) coordinates in range [-1, 1] + grid = img_flat * 2.0 - 1.0 # [B, H*W, 1, 1, 3] + + # grid_sample expects grid as [B, D_out, H_out, W_out, 3] + # We want to sample at each pixel location + grid = grid.permute(0, 1, 2, 3, 4) # Already correct order: [B, H*W, 1, 1, 3] + + # Expand LUT for batch + lut_batch = lut_tensor.expand(B, -1, -1, -1, -1) + + # Apply 3D interpolation + # grid_sample needs grid in (B, D, H, W, 3) and input in (B, C, D, H, W) + sampled = F.grid_sample( + lut_batch, # [B, 3, D, H, W] + grid, # [B, H*W, 1, 1, 3] + mode='bilinear', # Trilinear for 5D + padding_mode='border', + align_corners=True + ) + + # Reshape back to image: [B, 3, H*W, 1, 1] -> [B, H, W, 3] + result = sampled.squeeze(-1).squeeze(-1).permute(0, 2, 1).reshape(B, H, W, 3) + + return result + + def generate_test_pattern_gpu(self, operation_mode, shape, device, debug_logging=False): + """Generate test patterns directly on GPU""" + B, H, W, C = shape + + if operation_mode == "test_colors_horizontal": + if debug_logging: + print("TEST COLORS MODE: Generating test pattern on GPU") + test_image = torch.zeros(B, H, W, C, device=device) + test_image[:, :H//3, :, 0] = 1.0 + test_image[:, H//3:2*H//3, :, 1] = 1.0 + test_image[:, 2*H//3:, :, 2] = 1.0 + + elif operation_mode == "test_colors_vertical": + if debug_logging: + print("TEST COLORS VERTICAL MODE: Generating test pattern on GPU") + test_image = torch.zeros(B, H, W, C, device=device) + test_image[:, :, :W//3, 0] = 1.0 + test_image[:, :, W//3:2*W//3, 1] = 1.0 + test_image[:, :, 2*W//3:, 2] = 1.0 + + elif operation_mode == "test_gradient": + if debug_logging: + print("TEST GRADIENT MODE: Generating test pattern on GPU") + test_image = torch.zeros(B, H, W, C, device=device) + gradient = torch.linspace(0, 1, W, device=device).view(1, 1, W, 1).expand(B, H, W, 3) + test_image = gradient + + elif operation_mode == "test_hsv": + if debug_logging: + print("TEST HSV MODE: Generating HSV hue sweep gradient on GPU") + # Generate on CPU for colorsys, then move to GPU + hues = np.linspace(0.0, 1.0, W, dtype=np.float32) + rgb_row = np.array([colorsys.hsv_to_rgb(h, 1.0, 1.0) for h in hues], dtype=np.float32) + rgb_image = np.tile(rgb_row[np.newaxis, :, :], (H, 1, 1)) + test_image = torch.from_numpy(rgb_image).unsqueeze(0).expand(B, -1, -1, -1).to(device) + + elif operation_mode == "mid_gray_box": + if debug_logging: + print("TEST GRAY BOX: Generating gray test box on GPU") + test_image = torch.zeros(B, H, W, C, device=device) + test_image[:, H//4:3*H//4, W//4:3*W//4, :] = 0.5 + + else: + test_image = None + + return test_image def apply_lut(self, image, lut_file, strength, gamma_correction=False, operation_mode="normal", debug_logging=False): - + B, H, W, C = image.shape + device = image.device if C != 3: raise ValueError(f"Expected RGB image with 3 channels, got {C}") if debug_logging: - print(f"\n=== LUT APPLICATION DEBUG ===") + print(f"\n=== LUT APPLICATION DEBUG (GPU-OPTIMIZED) ===") print(f"Input image shape: {image.shape}") + print(f"Device: {device}") print(f"Input range: [{image.min():.6f}, {image.max():.6f}]") print(f"Sample input pixel (0,0): R={image[0, 0, 0, 0]:.6f}, G={image[0, 0, 0, 1]:.6f}, B={image[0, 0, 0, 2]:.6f}") print(f"Gamma correction enabled: {gamma_correction}") print(f"Operation mode: {operation_mode}") - test_image = [any] - + # Passthrough mode if operation_mode == "passthrough": if debug_logging: print("PASSTHROUGH MODE: Returning original image") return (image,) - elif operation_mode == "test_colors_horizontal": - if debug_logging: - print("TEST COLORS MODE: Generating test pattern") - - test_image = torch.zeros_like(image) - H, W = image.shape[1], image.shape[2] - test_image[0, :H//3, :, 0] = 1.0 - test_image[0, H//3:2*H//3, :, 1] = 1.0 - test_image[0, 2*H//3:, :, 2] = 1.0 - - elif operation_mode == "test_colors_vertical": - if debug_logging: - print("TEST COLORS VERTICAL MODE: Generating test pattern") - - test_image = torch.zeros_like(image) - H, W = image.shape[1], image.shape[2] - test_image[0, :, :W//3, 0] = 1.0 - test_image[0, :, W//3:2*W//3, 1] = 1.0 - test_image[0, :, 2*W//3:, 2] = 1.0 + # Generate test patterns on GPU + if operation_mode in ["test_colors_horizontal", "test_colors_vertical", "test_gradient", "test_hsv", "mid_gray_box"]: + img_to_process = self.generate_test_pattern_gpu(operation_mode, image.shape, device, debug_logging) + else: + img_to_process = image - elif operation_mode == "test_gradient": + # Apply gamma correction on GPU if needed + if gamma_correction: + img_for_lut = self.linear_to_srgb_gpu(img_to_process) if debug_logging: - print("TEST GRADIENT MODE: Generating test pattern") - - test_image = torch.zeros_like(image) - H, W = image.shape[1], image.shape[2] - for i in range(W): - luma = i / (W - 1) - test_image[0, :, i, 0] = luma - test_image[0, :, i, 1] = luma - test_image[0, :, i, 2] = luma - - elif operation_mode == "test_hsv": - if debug_logging: - print("TEST HSV MODE: Generating HSV hue sweep gradient") - - H, W = image.shape[1], image.shape[2] - hues = np.linspace(0.0, 1.0, W, dtype=np.float32) - rgb_row = np.array([colorsys.hsv_to_rgb(h, 1.0, 1.0) for h in hues], dtype=np.float32) # shape: [W, 3] - rgb_image = np.tile(rgb_row[np.newaxis, :, :], (H, 1, 1)) # shape: [H, W, 3] - test_image = torch.from_numpy(rgb_image).unsqueeze(0) - - elif operation_mode == "mid_gray_box": + print("Converting linear -> sRGB for LUT lookup (GPU)") + print(f"Linear input sample: R={img_to_process[0,0,0,0]:.6f}") + print(f"sRGB for LUT sample: R={img_for_lut[0,0,0,0]:.6f}") + else: if debug_logging: - print("TEST GRAY BOX: Generating gray test box") + print("No gamma correction - assuming input is already in LUT color space") + img_for_lut = img_to_process - lut = self.load_cube_lut(lut_file, debug_logging) - lut_size = lut.shape[0] + # Load LUT (cached as tensor) + lut_tensor = self.load_cube_lut(lut_file, debug_logging) + lut_size = lut_tensor.shape[2] - test_image = np.zeros((H, W, 3), dtype=np.float32) - test_image[H//4:3*H//4, W//4:3*W//4, 0] = 0.5 - test_image[H//4:3*H//4, W//4:3*W//4, 1] = 0.5 - test_image[H//4:3*H//4, W//4:3*W//4, 2] = 0.5 - - if debug_logging: - mid_idx = lut_size // 2 - print(f"Mid-gray input [0.5, 0.5, 0.5] should map to LUT[{mid_idx},{mid_idx},{mid_idx}]:") - print(f" LUT output: [{lut[mid_idx,mid_idx,mid_idx,0]:.3f}, {lut[mid_idx,mid_idx,mid_idx,1]:.3f}, {lut[mid_idx,mid_idx,mid_idx,2]:.3f}]") - - test_image = torch.from_numpy(test_image).unsqueeze(0) - - output_images = [] - - for b in range(B): - - if operation_mode in ["test_colors_horizontal", "test_colors_vertical", "test_gradient", "test_hsv", "mid_gray_box"]: - img_np = test_image[b].cpu().numpy().astype(np.float32) - else: - img_np = image[b].cpu().numpy().astype(np.float32) + if debug_logging: + print(f"\n--- LUT ANALYSIS ---") + print(f"LUT tensor shape: {lut_tensor.shape}") + print(f"LUT range: [{lut_tensor.min():.6f}, {lut_tensor.max():.6f}]") - if gamma_correction: - img_for_lut = self.linear_to_srgb(img_np) - if debug_logging: - print("Converting linear -> sRGB for LUT lookup") - print(f"Linear input sample: R={img_np[0,0,0]:.6f}, G={img_np[0,0,1]:.6f}, B={img_np[0,0,2]:.6f}") - print(f"sRGB for LUT sample: R={img_for_lut[0,0,0]:.6f}, G={img_for_lut[0,0,1]:.6f}, B={img_for_lut[0,0,2]:.6f}") - else: - if debug_logging: - print("No gamma correction - assuming input is already in LUT color space") - img_for_lut = img_np.copy() + # Clamp input to [0, 1] + img_for_lut = torch.clamp(img_for_lut, 0, 1) - lut = self.load_cube_lut(lut_file, debug_logging) - lut_size = lut.shape[0] + # Apply LUT interpolation on GPU + mapped_img = self.apply_lut_gpu(img_for_lut, lut_tensor, device) + # Convert back from sRGB if needed + if gamma_correction: + mapped_img = self.srgb_to_linear_gpu(mapped_img) if debug_logging: - print(f"\n--- LUT ANALYSIS ---") - print(f"LUT shape: {lut.shape}") - print(f"LUT range: [{lut.min():.6f}, {lut.max():.6f}]") - - print(f"LUT corners:") - print(f" Black [0,0,0] -> [{lut[0,0,0,0]:.3f}, {lut[0,0,0,1]:.3f}, {lut[0,0,0,2]:.3f}]") - print(f" White [{lut_size-1},{lut_size-1},{lut_size-1}] -> [{lut[-1,-1,-1,0]:.3f}, {lut[-1,-1,-1,1]:.3f}, {lut[-1,-1,-1,2]:.3f}]") - mid = lut_size // 2 - print(f" Mid-gray [{mid},{mid},{mid}] -> [{lut[mid,mid,mid,0]:.3f}, {lut[mid,mid,mid,1]:.3f}, {lut[mid,mid,mid,2]:.3f}]") - - coords = [np.linspace(0, 1, lut_size) for _ in range(3)] - - interpolators = [] - for channel in range(3): - interpolators.append( - RegularGridInterpolator( - coords, - lut[:, :, :, channel], - method='linear', - bounds_error=False, - fill_value=0.0 - ) - ) - - img_flat = img_for_lut.reshape(-1, 3) - img_flat = np.clip(img_flat, 0, 1) - - mapped_flat = np.zeros_like(img_flat, dtype=np.float32) - - chunk_size = 10000 - for i in range(0, len(img_flat), chunk_size): - end_idx = min(i + chunk_size, len(img_flat)) - chunk = img_flat[i:end_idx] - - for channel in range(3): - mapped_flat[i:end_idx, channel] = interpolators[channel](chunk) - - mapped_img = mapped_flat.reshape(H, W, 3) - - if gamma_correction: - mapped_img_linear = self.srgb_to_linear(mapped_img) - if debug_logging: - print("Converting sRGB LUT result -> linear") - print(f"sRGB LUT result sample: R={mapped_img[0,0,0]:.6f}, G={mapped_img[0,0,1]:.6f}, B={mapped_img[0,0,2]:.6f}") - print(f"Linear result sample: R={mapped_img_linear[0,0,0]:.6f}, G={mapped_img_linear[0,0,1]:.6f}, B={mapped_img_linear[0,0,2]:.6f}") - mapped_img = mapped_img_linear - - blended = (1 - strength) * img_np + strength * mapped_img - blended = np.clip(blended, 0, 1).astype(np.float32) - - if debug_logging: - print(f"Original vs LUT comparison for sample pixel (center of the image):") - print(f" Original: R={img_np[H//2,W//2,0]:.6f}, G={img_np[H//2,W//2,1]:.6f}, B={img_np[H//2,W//2,2]:.6f}") - print(f" LUT result: R={mapped_img[H//2,W//2,0]:.6f}, G={mapped_img[H//2,W//2,1]:.6f}, B={mapped_img[H//2,W//2,2]:.6f}") - print(f" Final blend: R={blended[H//2,W//2,0]:.6f}, G={blended[H//2,W//2,1]:.6f}, B={blended[H//2,W//2,2]:.6f}") - - output_images.append(blended) + print("Converting sRGB LUT result -> linear (GPU)") - output_tensor = torch.from_numpy(np.stack(output_images, axis=0)) + # Blend with original on GPU + blended = (1 - strength) * img_to_process + strength * mapped_img + blended = torch.clamp(blended, 0, 1) if debug_logging: + print(f"Original vs LUT comparison for sample pixel (center):") + print(f" Original: R={img_to_process[0,H//2,W//2,0]:.6f}") + print(f" LUT result: R={mapped_img[0,H//2,W//2,0]:.6f}") + print(f" Final blend: R={blended[0,H//2,W//2,0]:.6f}") print(f"\n=== FINAL OUTPUT ===") - print(f"Output shape: {output_tensor.shape}") - print(f"Output range: [{output_tensor.min():.6f}, {output_tensor.max():.6f}]") + print(f"Output shape: {blended.shape}") + print(f"Output range: [{blended.min():.6f}, {blended.max():.6f}]") print("=== END DEBUG ===\n") - return (output_tensor,) + return (blended,) NODE_CLASS_MAPPINGS = { @@ -284,5 +301,5 @@ def apply_lut(self, image, lut_file, strength, gamma_correction=False, operation } NODE_DISPLAY_NAME_MAPPINGS = { - "OlmLUT": "Olm LUT", + "OlmLUT": "Olm LUT (GPU-Optimized)", } \ No newline at end of file