
Commit 33b8bbc

fix: atan2 strong type support & bug fix for integer dynamic shape (#3751)
1 parent: b156b6e

File tree

2 files changed: +59 -54 lines changed

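For context on the first change: in eager PyTorch, atan2 follows standard type promotion, and integer inputs always produce a floating-point result; that is the behavior the converter below now matches. A minimal, repo-independent sketch:

import torch

# Two float dtypes promote to the wider one; two integer dtypes promote
# to an integer dtype, which atan2 then lifts to the default float dtype.
print(torch.promote_types(torch.float16, torch.float32))  # torch.float32
print(torch.promote_types(torch.int32, torch.int64))      # torch.int64

y = torch.tensor([1, 2], dtype=torch.int64)
x = torch.tensor([3, 4], dtype=torch.int64)
print(torch.atan2(y, x).dtype)  # torch.float32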

py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

Lines changed: 57 additions & 53 deletions
@@ -1,6 +1,6 @@
+import logging
 from typing import Optional, Union
 
-import numpy as np
 import tensorrt as trt
 import torch
 import torch_tensorrt.dynamo.conversion.impl as impl
@@ -16,13 +16,16 @@
     cast_trt_tensor,
     get_trt_tensor,
     has_dynamic_shape,
+    set_layer_name,
 )
 from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
     convert_binary_elementwise,
 )
 from torch_tensorrt.dynamo.conversion.impl.unary import atan, sign
 from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
 
+_LOGGER = logging.getLogger(__name__)
+
 
 def trunc_div(
     ctx: ConversionContext,
@@ -250,12 +253,26 @@ def atan2(
         A TensorRT tensor representing the result of the atan2 operation.
     """
     pi_value = 3.141592653589793
-    pi_tensor = get_trt_tensor(ctx, pi_value, f"{name}_pi")
 
-    if isinstance(input, TRTTensor):
-        input = cast_trt_tensor(ctx, input, trt.float32, f"{name}_input")
-    if isinstance(other, TRTTensor):
-        other = cast_trt_tensor(ctx, other, trt.float32, f"{name}_other")
+    promoted_type = _enums.dtype._from(
+        torch.promote_types(
+            _enums.dtype._from(input.dtype).to(torch.dtype),
+            _enums.dtype._from(other.dtype).to(torch.dtype),
+        )
+    )
+    # atan2's output is always float, so we promote any integer types to float32
+    # This mirrors PyTorch's behavior where atan2(int, int) -> float.
+    if not promoted_type.to(torch.dtype).is_floating_point:
+        promoted_type = _enums.dtype.float32
+
+    trt_promoted_type = promoted_type.to(trt.DataType)
+
+    pi_tensor = get_trt_tensor(ctx, pi_value, f"{name}_pi", dtype=trt_promoted_type)
+
+    if input.dtype != trt_promoted_type:
+        input = cast_trt_tensor(ctx, input, trt_promoted_type, f"{name}_input_casted")
+    if other.dtype != trt_promoted_type:
+        other = cast_trt_tensor(ctx, other, trt_promoted_type, f"{name}_other_casted")
 
     input, other = broadcast(ctx, input, other, f"{name}_input", f"{name}_other")
 
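The promotion block above relies on torch_tensorrt's _enums.dtype bridge to translate between torch and TensorRT dtypes. A rough standalone sketch of the same logic (assuming torch_tensorrt and tensorrt are importable; the int32/int64 pair is only an illustration):

import tensorrt as trt
import torch
from torch_tensorrt import _enums

# Promote the operand dtypes, lift any non-float result to float32,
# then translate into TensorRT's enum, mirroring the diff above.
promoted = _enums.dtype._from(torch.promote_types(torch.int32, torch.int64))
if not promoted.to(torch.dtype).is_floating_point:
    promoted = _enums.dtype.float32
print(promoted.to(trt.DataType))  # DataType.FLOAT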
@@ -333,56 +350,43 @@ def atan2(
         y_positive,
     )
 
+    # Create constant tensors for boundary conditions (x=0 or y=0)
+    # Use impl.full which handles both dynamic and static shapes efficiently.
     if has_dynamic_shape(input.shape):
-        pi_over_2_tensor = convert_binary_elementwise(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_pi_over_2_tensor",
-            trt.ElementWiseOperation.PROD,
-            (pi_value / 2),
-            input,
-        )
-
-        minus_pi_over_2_tensor = convert_binary_elementwise(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_minus_pi_over_2_tensor",
-            trt.ElementWiseOperation.PROD,
-            (-pi_value / 2),
-            input,
-        )
-        zero_tensor = convert_binary_elementwise(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_zero_tensor",
-            trt.ElementWiseOperation.PROD,
-            0,
-            input,
-        )
+        shape_layer = ctx.net.add_shape(input)
+        set_layer_name(shape_layer, target, f"{name}_shape", source_ir)
+        shape = shape_layer.get_output(0)
     else:
-        # on x or y-axis
-        pi_over_2_tensor = get_trt_tensor(
-            ctx,
-            (pi_value / 2) * np.ones(input.shape, dtype=np.float32),
-            f"{name}_pi_over_2_tensor",
-            dtype=trt.float32,
-        )
+        shape = list(input.shape)
 
-        minus_pi_over_2_tensor = get_trt_tensor(
-            ctx,
-            (-pi_value / 2) * np.ones(input.shape, dtype=np.float32),
-            f"{name}_minus_pi_over_2_tensor",
-            dtype=trt.float32,
-        )
-        zero_tensor = get_trt_tensor(
-            ctx,
-            np.zeros(input.shape, dtype=np.float32),
-            f"{name}_zero_tensor",
-            dtype=trt.float32,
-        )
+    pi_over_2_tensor = impl.full.full(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_pi_over_2_tensor",
+        shape,
+        pi_value / 2,
+        dtype=trt_promoted_type,
+    )
+
+    minus_pi_over_2_tensor = impl.full.full(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_minus_pi_over_2_tensor",
+        shape,
+        -pi_value / 2,
+        dtype=trt_promoted_type,
+    )
+    zero_tensor = impl.full.full(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_zero_tensor",
+        shape,
+        0.0,
+        dtype=trt_promoted_type,
+    )
 
     # π/2 if x>0 and y=0,
     pi_over_2_output = impl.condition.select(
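In the dynamic-shape branch, the old code materialized each constant by multiplying a scalar with the input tensor; the new code reads the runtime shape once via an IShapeLayer and sizes the constants with impl.full. A minimal sketch of that shape-tensor pattern in the raw TensorRT API (assuming TensorRT 10, where networks are always explicit-batch; tensor names are illustrative):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(0)

# A dynamic dimension appears as -1 in the build-time view of the shape,
# so constants sized from input.shape would be wrong; add_shape yields a
# 1-D tensor carrying the actual extents at runtime.
x = network.add_input("x", trt.float32, (-1, 4))
shape_tensor = network.add_shape(x).get_output(0)
print(tuple(x.shape))             # (-1, 4)
print(tuple(shape_tensor.shape))  # (2,): one entry per input dimension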

py/torch_tensorrt/dynamo/conversion/impl/full.py

Lines changed: 2 additions & 1 deletion
@@ -34,7 +34,8 @@ def full(
     # in static shape scenario, shape is a list of int
     if all(isinstance(dim, int) for dim in shape):
         output_np_dtype = output_dtype.try_to(np.dtype, use_default=True)
-        return np.full(shape, fill_value, dtype=output_np_dtype)
+        np_array = np.full(shape, fill_value, dtype=output_np_dtype)
+        return get_trt_tensor(ctx, np_array, name, dtype=output_dtype)
     else:
         shape = impl.cat.cat(
             ctx, target, source_ir, name + "_concat_shape", shape, 0
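This change makes full return a TensorRT tensor on both paths: the static-shape branch previously returned the raw NumPy array from np.full, while the dynamic branch returned an ITensor, so callers such as the atan2 converter above saw inconsistent types. A toy illustration of the old mismatch (not the library's API, just the shape of the bug):

import numpy as np

def full_static_old(shape, fill_value):
    # Old behavior: the constant was never registered as a network
    # tensor here, leaking an ndarray to the caller.
    return np.full(shape, fill_value, dtype=np.float32)

print(type(full_static_old([2, 3], 0.5)))  # <class 'numpy.ndarray'>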
