11# Add this import
22from keras import backend
3- from keras import ops
43from keras import layers
4+ from keras import ops
55from keras import random
66
7+
78class RandomElasticDeformation3D (layers .Layer ):
89 """
910 A high-performance 3D elastic deformation layer optimized for TPUs.
1011 """
1112
12- def __init__ (self ,
13- grid_size = (4 , 4 , 4 ),
14- alpha = 35.0 ,
15- sigma = 2.5 ,
16- data_format = "channels_last" ,
17- ** kwargs ):
13+ def __init__ (
14+ self ,
15+ grid_size = (4 , 4 , 4 ),
16+ alpha = 35.0 ,
17+ sigma = 2.5 ,
18+ data_format = "channels_last" ,
19+ seed = None ,
20+ ** kwargs ,
21+ ):
1822 super ().__init__ (** kwargs )
1923 self .grid_size = grid_size
24+ self .seed = seed
2025 self .alpha = alpha
2126 self .sigma = sigma
2227 self .data_format = data_format
28+ self ._rng = random .SeedGenerator (seed ) if seed is not None else None
2329 if data_format not in ["channels_last" , "channels_first" ]:
2430 message = (
2531 "`data_format` must be one of 'channels_last' or "
@@ -28,21 +34,36 @@ def __init__(self,
2834 raise ValueError (message )
2935
3036 def build (self , input_shape ):
31- self ._alpha_tensor = ops .convert_to_tensor (self .alpha , dtype = self .compute_dtype )
32- self ._sigma_tensor = ops .convert_to_tensor (self .sigma , dtype = self .compute_dtype )
33- kernel_size = ops .cast (2 * ops .round (3 * self ._sigma_tensor ) + 1 , dtype = "int32" )
34- ax = ops .arange (- ops .cast (kernel_size // 2 , self .compute_dtype ) + 1.0 , ops .cast (kernel_size // 2 , self .compute_dtype ) + 1.0 )
37+ self ._alpha_tensor = ops .convert_to_tensor (
38+ self .alpha , dtype = self .compute_dtype
39+ )
40+ self ._sigma_tensor = ops .convert_to_tensor (
41+ self .sigma , dtype = self .compute_dtype
42+ )
43+ kernel_size = ops .cast (
44+ 2 * ops .round (3 * self ._sigma_tensor ) + 1 , dtype = "int32"
45+ )
46+ ax = ops .arange (
47+ - ops .cast (kernel_size // 2 , self .compute_dtype ) + 1.0 ,
48+ ops .cast (kernel_size // 2 , self .compute_dtype ) + 1.0 ,
49+ )
3550 kernel_1d = ops .exp (- (ax ** 2 ) / (2.0 * self ._sigma_tensor ** 2 ))
3651 self .kernel_1d = kernel_1d / ops .sum (kernel_1d )
3752 self .built = True
3853
3954 def _separable_gaussian_filter_3d (self , tensor ):
4055 depth_kernel = ops .reshape (self .kernel_1d , (- 1 , 1 , 1 , 1 , 1 ))
41- tensor = ops .conv (tensor , ops .cast (depth_kernel , dtype = tensor .dtype ), padding = 'same' )
56+ tensor = ops .conv (
57+ tensor , ops .cast (depth_kernel , dtype = tensor .dtype ), padding = "same"
58+ )
4259 height_kernel = ops .reshape (self .kernel_1d , (1 , - 1 , 1 , 1 , 1 ))
43- tensor = ops .conv (tensor , ops .cast (height_kernel , dtype = tensor .dtype ), padding = 'same' )
60+ tensor = ops .conv (
61+ tensor , ops .cast (height_kernel , dtype = tensor .dtype ), padding = "same"
62+ )
4463 width_kernel = ops .reshape (self .kernel_1d , (1 , 1 , - 1 , 1 , 1 ))
45- tensor = ops .conv (tensor , ops .cast (width_kernel , dtype = tensor .dtype ), padding = 'same' )
64+ tensor = ops .conv (
65+ tensor , ops .cast (width_kernel , dtype = tensor .dtype ), padding = "same"
66+ )
4667 return tensor
4768
4869 def call (self , inputs ):
@@ -61,33 +82,90 @@ def call(self, inputs):
6182 label_volume = ops .cast (label_volume , dtype = compute_dtype )
6283
6384 input_shape = ops .shape (image_volume )
64- B , D , H , W , C = input_shape [0 ], input_shape [1 ], input_shape [2 ], input_shape [3 ], input_shape [4 ]
65-
66- coarse_flow = random .uniform (shape = (B , self .grid_size [0 ], self .grid_size [1 ], self .grid_size [2 ], 3 ), minval = - 1 , maxval = 1 , dtype = compute_dtype )
67-
85+ B , D , H , W , C = (
86+ input_shape [0 ],
87+ input_shape [1 ],
88+ input_shape [2 ],
89+ input_shape [3 ],
90+ input_shape [4 ],
91+ )
92+
93+ if self ._rng is not None :
94+ coarse_flow = random .uniform (
95+ shape = (
96+ B ,
97+ self .grid_size [0 ],
98+ self .grid_size [1 ],
99+ self .grid_size [2 ],
100+ 3 ,
101+ ),
102+ minval = - 1 ,
103+ maxval = 1 ,
104+ dtype = compute_dtype ,
105+ seed = self ._rng ,
106+ )
107+ else :
108+ coarse_flow = random .uniform (
109+ shape = (
110+ B ,
111+ self .grid_size [0 ],
112+ self .grid_size [1 ],
113+ self .grid_size [2 ],
114+ 3 ,
115+ ),
116+ minval = - 1 ,
117+ maxval = 1 ,
118+ dtype = compute_dtype ,
119+ )
120+
68121 flow = coarse_flow
69122 flow_shape = ops .shape (flow )
70- flow = ops .reshape (flow , (flow_shape [0 ] * flow_shape [1 ], flow_shape [2 ], flow_shape [3 ], 3 ))
123+ flow = ops .reshape (
124+ flow ,
125+ (flow_shape [0 ] * flow_shape [1 ], flow_shape [2 ], flow_shape [3 ], 3 ),
126+ )
71127 flow = ops .image .resize (flow , (H , W ), interpolation = "bicubic" )
72128 flow = ops .reshape (flow , (flow_shape [0 ], flow_shape [1 ], H , W , 3 ))
73129 flow = ops .transpose (flow , (0 , 2 , 3 , 1 , 4 ))
74130 flow_shape = ops .shape (flow )
75- flow = ops .reshape (flow , (flow_shape [0 ] * flow_shape [1 ] * flow_shape [2 ], flow_shape [3 ], 1 , 3 ))
131+ flow = ops .reshape (
132+ flow ,
133+ (
134+ flow_shape [0 ] * flow_shape [1 ] * flow_shape [2 ],
135+ flow_shape [3 ],
136+ 1 ,
137+ 3 ,
138+ ),
139+ )
76140 flow = ops .image .resize (flow , (D , 1 ), interpolation = "bicubic" )
77- flow = ops .reshape (flow , (flow_shape [0 ], flow_shape [1 ], flow_shape [2 ], D , 3 ))
141+ flow = ops .reshape (
142+ flow , (flow_shape [0 ], flow_shape [1 ], flow_shape [2 ], D , 3 )
143+ )
78144 flow = ops .transpose (flow , (0 , 3 , 1 , 2 , 4 ))
79-
145+
80146 flow_components = ops .unstack (flow , axis = - 1 )
81147 smoothed_components = []
82148 for component in flow_components :
83- smoothed_components .append (ops .squeeze (self ._separable_gaussian_filter_3d (ops .expand_dims (component , axis = - 1 )), axis = - 1 ))
149+ smoothed_components .append (
150+ ops .squeeze (
151+ self ._separable_gaussian_filter_3d (
152+ ops .expand_dims (component , axis = - 1 )
153+ ),
154+ axis = - 1 ,
155+ )
156+ )
84157 smoothed_flow = ops .stack (smoothed_components , axis = - 1 )
85-
158+
86159 flow = smoothed_flow * self ._alpha_tensor
87- grid_d , grid_h , grid_w = ops .meshgrid (ops .arange (D , dtype = compute_dtype ), ops .arange (H , dtype = compute_dtype ), ops .arange (W , dtype = compute_dtype ), indexing = 'ij' )
160+ grid_d , grid_h , grid_w = ops .meshgrid (
161+ ops .arange (D , dtype = compute_dtype ),
162+ ops .arange (H , dtype = compute_dtype ),
163+ ops .arange (W , dtype = compute_dtype ),
164+ indexing = "ij" ,
165+ )
88166 grid = ops .stack ([grid_d , grid_h , grid_w ], axis = - 1 )
89167 warp_grid = ops .expand_dims (grid , 0 ) + flow
90-
168+
91169 batched_coords = ops .transpose (warp_grid , (0 , 4 , 1 , 2 , 3 ))
92170
93171 def perform_map (elems ):
@@ -96,25 +174,45 @@ def perform_map(elems):
96174 image_slice_transposed = ops .transpose (image_slice , (3 , 0 , 1 , 2 ))
97175 # The channel dimension C is a static value when the graph is built
98176 for c in range (C ):
99- deformed_channels .append (ops .image .map_coordinates (image_slice_transposed [c ], coords , order = 1 ))
177+ deformed_channels .append (
178+ ops .image .map_coordinates (
179+ image_slice_transposed [c ], coords , order = 1
180+ )
181+ )
100182 deformed_image_slice = ops .stack (deformed_channels , axis = 0 )
101- deformed_image_slice = ops .transpose (deformed_image_slice , (1 , 2 , 3 , 0 ))
183+ deformed_image_slice = ops .transpose (
184+ deformed_image_slice , (1 , 2 , 3 , 0 )
185+ )
102186 label_channel = ops .squeeze (label_slice , axis = - 1 )
103- deformed_label_channel = ops .image .map_coordinates (label_channel , coords , order = 0 )
104- deformed_label_slice = ops .expand_dims (deformed_label_channel , axis = - 1 )
187+ deformed_label_channel = ops .image .map_coordinates (
188+ label_channel , coords , order = 0
189+ )
190+ deformed_label_slice = ops .expand_dims (
191+ deformed_label_channel , axis = - 1
192+ )
105193 return deformed_image_slice , deformed_label_slice
106194
107195 if backend .backend () == "tensorflow" :
108196 import tensorflow as tf
109- deformed_image , deformed_label = tf .map_fn (perform_map , elems = (image_volume , label_volume , batched_coords ), dtype = (compute_dtype , compute_dtype ))
197+
198+ deformed_image , deformed_label = tf .map_fn (
199+ perform_map ,
200+ elems = (image_volume , label_volume , batched_coords ),
201+ dtype = (compute_dtype , compute_dtype ),
202+ )
110203 elif backend .backend () == "jax" :
111204 import jax
112- deformed_image , deformed_label = jax .lax .map (perform_map , xs = (image_volume , label_volume , batched_coords ))
205+
206+ deformed_image , deformed_label = jax .lax .map (
207+ perform_map , xs = (image_volume , label_volume , batched_coords )
208+ )
113209 else :
114210 deformed_images_list = []
115211 deformed_labels_list = []
116212 for i in range (B ):
117- img_slice , lbl_slice = perform_map ((image_volume [i ], label_volume [i ], batched_coords [i ]))
213+ img_slice , lbl_slice = perform_map (
214+ (image_volume [i ], label_volume [i ], batched_coords [i ])
215+ )
118216 deformed_images_list .append (img_slice )
119217 deformed_labels_list .append (lbl_slice )
120218 deformed_image = ops .stack (deformed_images_list , axis = 0 )
@@ -135,5 +233,13 @@ def compute_output_shape(self, input_shape):
135233
136234 def get_config (self ):
137235 config = super ().get_config ()
138- config .update ({"grid_size" : self .grid_size , "alpha" : self .alpha , "sigma" : self .sigma , "data_format" : self .data_format })
139- return config
236+ config .update (
237+ {
238+ "grid_size" : self .grid_size ,
239+ "alpha" : self .alpha ,
240+ "sigma" : self .sigma ,
241+ "data_format" : self .data_format ,
242+ "seed" : self .seed ,
243+ }
244+ )
245+ return config
0 commit comments