Skip to content

Commit a846b5c

Browse files
committed
fix: more rebase fixes for v2
1 parent a886cef commit a846b5c

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

tidy3d/components/medium.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,7 +1378,18 @@ def _interp_axis(
13781378
elif component == "real":
13791379
values = values.real
13801380

1381-
return values.sum(axis=-1).reshape(eps_shape)
1381+
vjp_array = values.sum(axis=-1).reshape(eps_shape)
1382+
1383+
# match derivative dtype to the underlying dataset
1384+
target_array = getattr(spatial_data, "values", None)
1385+
if target_array is None and hasattr(spatial_data, "data"):
1386+
target_array = spatial_data.data
1387+
if target_array is not None:
1388+
target_dtype = np.asarray(target_array).dtype
1389+
if not np.issubdtype(target_dtype, np.complexfloating):
1390+
vjp_array = np.real(vjp_array).astype(target_dtype, copy=False)
1391+
1392+
return vjp_array
13821393

13831394

13841395
""" Dispersionless Medium """
@@ -3591,7 +3602,9 @@ def is_spatially_uniform(self) -> bool:
35913602
return True
35923603

35933604
@staticmethod
3594-
def _sorted_spatial_data(data: CustomSpatialDataTypeAnnotated):
3605+
def _sorted_spatial_data(
3606+
data: CustomSpatialDataTypeAnnotated,
3607+
) -> CustomSpatialDataTypeAnnotated:
35953608
"""Return spatial data sorted along its coordinates if applicable."""
35963609
if isinstance(data, SpatialDataArray):
35973610
return data._spatially_sorted

0 commit comments

Comments
 (0)