11# Add this import
22from keras import backend
3- from keras import ops
43from keras import layers
4+ from keras import ops
55from keras import random
66
7+
78class RandomElasticDeformation3D (layers .Layer ):
89 """
910 A high-performance 3D elastic deformation layer optimized for TPUs.
1011 """
1112
12- def __init__ (self ,
13- grid_size = (4 , 4 , 4 ),
14- alpha = 35.0 ,
15- sigma = 2.5 ,
16- data_format = "channels_last" ,
17- ** kwargs ):
13+ def __init__ (
14+ self ,
15+ grid_size = (4 , 4 , 4 ),
16+ alpha = 35.0 ,
17+ sigma = 2.5 ,
18+ data_format = "channels_last" ,
19+ ** kwargs ,
20+ ):
1821 super ().__init__ (** kwargs )
1922 self .grid_size = grid_size
2023 self .alpha = alpha
@@ -28,21 +31,36 @@ def __init__(self,
2831 raise ValueError (message )
2932
3033 def build (self , input_shape ):
31- self ._alpha_tensor = ops .convert_to_tensor (self .alpha , dtype = self .compute_dtype )
32- self ._sigma_tensor = ops .convert_to_tensor (self .sigma , dtype = self .compute_dtype )
33- kernel_size = ops .cast (2 * ops .round (3 * self ._sigma_tensor ) + 1 , dtype = "int32" )
34- ax = ops .arange (- ops .cast (kernel_size // 2 , self .compute_dtype ) + 1.0 , ops .cast (kernel_size // 2 , self .compute_dtype ) + 1.0 )
34+ self ._alpha_tensor = ops .convert_to_tensor (
35+ self .alpha , dtype = self .compute_dtype
36+ )
37+ self ._sigma_tensor = ops .convert_to_tensor (
38+ self .sigma , dtype = self .compute_dtype
39+ )
40+ kernel_size = ops .cast (
41+ 2 * ops .round (3 * self ._sigma_tensor ) + 1 , dtype = "int32"
42+ )
43+ ax = ops .arange (
44+ - ops .cast (kernel_size // 2 , self .compute_dtype ) + 1.0 ,
45+ ops .cast (kernel_size // 2 , self .compute_dtype ) + 1.0 ,
46+ )
3547 kernel_1d = ops .exp (- (ax ** 2 ) / (2.0 * self ._sigma_tensor ** 2 ))
3648 self .kernel_1d = kernel_1d / ops .sum (kernel_1d )
3749 self .built = True
3850
3951 def _separable_gaussian_filter_3d (self , tensor ):
4052 depth_kernel = ops .reshape (self .kernel_1d , (- 1 , 1 , 1 , 1 , 1 ))
41- tensor = ops .conv (tensor , ops .cast (depth_kernel , dtype = tensor .dtype ), padding = 'same' )
53+ tensor = ops .conv (
54+ tensor , ops .cast (depth_kernel , dtype = tensor .dtype ), padding = "same"
55+ )
4256 height_kernel = ops .reshape (self .kernel_1d , (1 , - 1 , 1 , 1 , 1 ))
43- tensor = ops .conv (tensor , ops .cast (height_kernel , dtype = tensor .dtype ), padding = 'same' )
57+ tensor = ops .conv (
58+ tensor , ops .cast (height_kernel , dtype = tensor .dtype ), padding = "same"
59+ )
4460 width_kernel = ops .reshape (self .kernel_1d , (1 , 1 , - 1 , 1 , 1 ))
45- tensor = ops .conv (tensor , ops .cast (width_kernel , dtype = tensor .dtype ), padding = 'same' )
61+ tensor = ops .conv (
62+ tensor , ops .cast (width_kernel , dtype = tensor .dtype ), padding = "same"
63+ )
4664 return tensor
4765
4866 def call (self , inputs ):
@@ -61,33 +79,75 @@ def call(self, inputs):
6179 label_volume = ops .cast (label_volume , dtype = compute_dtype )
6280
6381 input_shape = ops .shape (image_volume )
64- B , D , H , W , C = input_shape [0 ], input_shape [1 ], input_shape [2 ], input_shape [3 ], input_shape [4 ]
65-
66- coarse_flow = random .uniform (shape = (B , self .grid_size [0 ], self .grid_size [1 ], self .grid_size [2 ], 3 ), minval = - 1 , maxval = 1 , dtype = compute_dtype )
67-
82+ B , D , H , W , C = (
83+ input_shape [0 ],
84+ input_shape [1 ],
85+ input_shape [2 ],
86+ input_shape [3 ],
87+ input_shape [4 ],
88+ )
89+
90+ coarse_flow = random .uniform (
91+ shape = (
92+ B ,
93+ self .grid_size [0 ],
94+ self .grid_size [1 ],
95+ self .grid_size [2 ],
96+ 3 ,
97+ ),
98+ minval = - 1 ,
99+ maxval = 1 ,
100+ dtype = compute_dtype ,
101+ )
102+
68103 flow = coarse_flow
69104 flow_shape = ops .shape (flow )
70- flow = ops .reshape (flow , (flow_shape [0 ] * flow_shape [1 ], flow_shape [2 ], flow_shape [3 ], 3 ))
105+ flow = ops .reshape (
106+ flow ,
107+ (flow_shape [0 ] * flow_shape [1 ], flow_shape [2 ], flow_shape [3 ], 3 ),
108+ )
71109 flow = ops .image .resize (flow , (H , W ), interpolation = "bicubic" )
72110 flow = ops .reshape (flow , (flow_shape [0 ], flow_shape [1 ], H , W , 3 ))
73111 flow = ops .transpose (flow , (0 , 2 , 3 , 1 , 4 ))
74112 flow_shape = ops .shape (flow )
75- flow = ops .reshape (flow , (flow_shape [0 ] * flow_shape [1 ] * flow_shape [2 ], flow_shape [3 ], 1 , 3 ))
113+ flow = ops .reshape (
114+ flow ,
115+ (
116+ flow_shape [0 ] * flow_shape [1 ] * flow_shape [2 ],
117+ flow_shape [3 ],
118+ 1 ,
119+ 3 ,
120+ ),
121+ )
76122 flow = ops .image .resize (flow , (D , 1 ), interpolation = "bicubic" )
77- flow = ops .reshape (flow , (flow_shape [0 ], flow_shape [1 ], flow_shape [2 ], D , 3 ))
123+ flow = ops .reshape (
124+ flow , (flow_shape [0 ], flow_shape [1 ], flow_shape [2 ], D , 3 )
125+ )
78126 flow = ops .transpose (flow , (0 , 3 , 1 , 2 , 4 ))
79-
127+
80128 flow_components = ops .unstack (flow , axis = - 1 )
81129 smoothed_components = []
82130 for component in flow_components :
83- smoothed_components .append (ops .squeeze (self ._separable_gaussian_filter_3d (ops .expand_dims (component , axis = - 1 )), axis = - 1 ))
131+ smoothed_components .append (
132+ ops .squeeze (
133+ self ._separable_gaussian_filter_3d (
134+ ops .expand_dims (component , axis = - 1 )
135+ ),
136+ axis = - 1 ,
137+ )
138+ )
84139 smoothed_flow = ops .stack (smoothed_components , axis = - 1 )
85-
140+
86141 flow = smoothed_flow * self ._alpha_tensor
87- grid_d , grid_h , grid_w = ops .meshgrid (ops .arange (D , dtype = compute_dtype ), ops .arange (H , dtype = compute_dtype ), ops .arange (W , dtype = compute_dtype ), indexing = 'ij' )
142+ grid_d , grid_h , grid_w = ops .meshgrid (
143+ ops .arange (D , dtype = compute_dtype ),
144+ ops .arange (H , dtype = compute_dtype ),
145+ ops .arange (W , dtype = compute_dtype ),
146+ indexing = "ij" ,
147+ )
88148 grid = ops .stack ([grid_d , grid_h , grid_w ], axis = - 1 )
89149 warp_grid = ops .expand_dims (grid , 0 ) + flow
90-
150+
91151 batched_coords = ops .transpose (warp_grid , (0 , 4 , 1 , 2 , 3 ))
92152
93153 def perform_map (elems ):
@@ -96,25 +156,45 @@ def perform_map(elems):
96156 image_slice_transposed = ops .transpose (image_slice , (3 , 0 , 1 , 2 ))
97157 # The channel dimension C is a static value when the graph is built
98158 for c in range (C ):
99- deformed_channels .append (ops .image .map_coordinates (image_slice_transposed [c ], coords , order = 1 ))
159+ deformed_channels .append (
160+ ops .image .map_coordinates (
161+ image_slice_transposed [c ], coords , order = 1
162+ )
163+ )
100164 deformed_image_slice = ops .stack (deformed_channels , axis = 0 )
101- deformed_image_slice = ops .transpose (deformed_image_slice , (1 , 2 , 3 , 0 ))
165+ deformed_image_slice = ops .transpose (
166+ deformed_image_slice , (1 , 2 , 3 , 0 )
167+ )
102168 label_channel = ops .squeeze (label_slice , axis = - 1 )
103- deformed_label_channel = ops .image .map_coordinates (label_channel , coords , order = 0 )
104- deformed_label_slice = ops .expand_dims (deformed_label_channel , axis = - 1 )
169+ deformed_label_channel = ops .image .map_coordinates (
170+ label_channel , coords , order = 0
171+ )
172+ deformed_label_slice = ops .expand_dims (
173+ deformed_label_channel , axis = - 1
174+ )
105175 return deformed_image_slice , deformed_label_slice
106176
107177 if backend .backend () == "tensorflow" :
108178 import tensorflow as tf
109- deformed_image , deformed_label = tf .map_fn (perform_map , elems = (image_volume , label_volume , batched_coords ), dtype = (compute_dtype , compute_dtype ))
179+
180+ deformed_image , deformed_label = tf .map_fn (
181+ perform_map ,
182+ elems = (image_volume , label_volume , batched_coords ),
183+ dtype = (compute_dtype , compute_dtype ),
184+ )
110185 elif backend .backend () == "jax" :
111186 import jax
112- deformed_image , deformed_label = jax .lax .map (perform_map , xs = (image_volume , label_volume , batched_coords ))
187+
188+ deformed_image , deformed_label = jax .lax .map (
189+ perform_map , xs = (image_volume , label_volume , batched_coords )
190+ )
113191 else :
114192 deformed_images_list = []
115193 deformed_labels_list = []
116194 for i in range (B ):
117- img_slice , lbl_slice = perform_map ((image_volume [i ], label_volume [i ], batched_coords [i ]))
195+ img_slice , lbl_slice = perform_map (
196+ (image_volume [i ], label_volume [i ], batched_coords [i ])
197+ )
118198 deformed_images_list .append (img_slice )
119199 deformed_labels_list .append (lbl_slice )
120200 deformed_image = ops .stack (deformed_images_list , axis = 0 )
@@ -135,5 +215,12 @@ def compute_output_shape(self, input_shape):
135215
136216 def get_config (self ):
137217 config = super ().get_config ()
138- config .update ({"grid_size" : self .grid_size , "alpha" : self .alpha , "sigma" : self .sigma , "data_format" : self .data_format })
139- return config
218+ config .update (
219+ {
220+ "grid_size" : self .grid_size ,
221+ "alpha" : self .alpha ,
222+ "sigma" : self .sigma ,
223+ "data_format" : self .data_format ,
224+ }
225+ )
226+ return config
0 commit comments