|
| 1 | +import logging |
1 | 2 | from typing import Optional, Union
|
2 | 3 |
|
3 |
| -import numpy as np |
4 | 4 | import tensorrt as trt
|
5 | 5 | import torch
|
6 | 6 | import torch_tensorrt.dynamo.conversion.impl as impl
|
|
16 | 16 | cast_trt_tensor,
|
17 | 17 | get_trt_tensor,
|
18 | 18 | has_dynamic_shape,
|
| 19 | + set_layer_name, |
19 | 20 | )
|
20 | 21 | from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
|
21 | 22 | convert_binary_elementwise,
|
22 | 23 | )
|
23 | 24 | from torch_tensorrt.dynamo.conversion.impl.unary import atan, sign
|
24 | 25 | from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
|
25 | 26 |
|
| 27 | +_LOGGER = logging.getLogger(__name__) |
| 28 | + |
26 | 29 |
|
27 | 30 | def trunc_div(
|
28 | 31 | ctx: ConversionContext,
|
@@ -250,12 +253,26 @@ def atan2(
|
250 | 253 | A TensorRT tensor representing the result of the atan2 operation.
|
251 | 254 | """
|
252 | 255 | pi_value = 3.141592653589793
|
253 |
| - pi_tensor = get_trt_tensor(ctx, pi_value, f"{name}_pi") |
254 | 256 |
|
255 |
| - if isinstance(input, TRTTensor): |
256 |
| - input = cast_trt_tensor(ctx, input, trt.float32, f"{name}_input") |
257 |
| - if isinstance(other, TRTTensor): |
258 |
| - other = cast_trt_tensor(ctx, other, trt.float32, f"{name}_other") |
| 257 | + promoted_type = _enums.dtype._from( |
| 258 | + torch.promote_types( |
| 259 | + _enums.dtype._from(input.dtype).to(torch.dtype), |
| 260 | + _enums.dtype._from(other.dtype).to(torch.dtype), |
| 261 | + ) |
| 262 | + ) |
| 263 | + # atan2's output is always float, so we promote any integer types to float32 |
| 264 | + # This mirrors PyTorch's behavior where atan2(int, int) -> float. |
| 265 | + if not promoted_type.to(torch.dtype).is_floating_point: |
| 266 | + promoted_type = _enums.dtype.float32 |
| 267 | + |
| 268 | + trt_promoted_type = promoted_type.to(trt.DataType) |
| 269 | + |
| 270 | + pi_tensor = get_trt_tensor(ctx, pi_value, f"{name}_pi", dtype=trt_promoted_type) |
| 271 | + |
| 272 | + if input.dtype != trt_promoted_type: |
| 273 | + input = cast_trt_tensor(ctx, input, trt_promoted_type, f"{name}_input_casted") |
| 274 | + if other.dtype != trt_promoted_type: |
| 275 | + other = cast_trt_tensor(ctx, other, trt_promoted_type, f"{name}_other_casted") |
259 | 276 |
|
260 | 277 | input, other = broadcast(ctx, input, other, f"{name}_input", f"{name}_other")
|
261 | 278 |
|
@@ -333,56 +350,43 @@ def atan2(
|
333 | 350 | y_positive,
|
334 | 351 | )
|
335 | 352 |
|
| 353 | + # Create constant tensors for boundary conditions (x=0 or y=0) |
| 354 | + # Use impl.full which handles both dynamic and static shapes efficiently. |
336 | 355 | if has_dynamic_shape(input.shape):
|
337 |
| - pi_over_2_tensor = convert_binary_elementwise( |
338 |
| - ctx, |
339 |
| - target, |
340 |
| - source_ir, |
341 |
| - f"{name}_pi_over_2_tensor", |
342 |
| - trt.ElementWiseOperation.PROD, |
343 |
| - (pi_value / 2), |
344 |
| - input, |
345 |
| - ) |
346 |
| - |
347 |
| - minus_pi_over_2_tensor = convert_binary_elementwise( |
348 |
| - ctx, |
349 |
| - target, |
350 |
| - source_ir, |
351 |
| - f"{name}_minus_pi_over_2_tensor", |
352 |
| - trt.ElementWiseOperation.PROD, |
353 |
| - (-pi_value / 2), |
354 |
| - input, |
355 |
| - ) |
356 |
| - zero_tensor = convert_binary_elementwise( |
357 |
| - ctx, |
358 |
| - target, |
359 |
| - source_ir, |
360 |
| - f"{name}_zero_tensor", |
361 |
| - trt.ElementWiseOperation.PROD, |
362 |
| - 0, |
363 |
| - input, |
364 |
| - ) |
| 356 | + shape_layer = ctx.net.add_shape(input) |
| 357 | + set_layer_name(shape_layer, target, f"{name}_shape", source_ir) |
| 358 | + shape = shape_layer.get_output(0) |
365 | 359 | else:
|
366 |
| - # on x or y-axis |
367 |
| - pi_over_2_tensor = get_trt_tensor( |
368 |
| - ctx, |
369 |
| - (pi_value / 2) * np.ones(input.shape, dtype=np.float32), |
370 |
| - f"{name}_pi_over_2_tensor", |
371 |
| - dtype=trt.float32, |
372 |
| - ) |
| 360 | + shape = list(input.shape) |
373 | 361 |
|
374 |
| - minus_pi_over_2_tensor = get_trt_tensor( |
375 |
| - ctx, |
376 |
| - (-pi_value / 2) * np.ones(input.shape, dtype=np.float32), |
377 |
| - f"{name}_minus_pi_over_2_tensor", |
378 |
| - dtype=trt.float32, |
379 |
| - ) |
380 |
| - zero_tensor = get_trt_tensor( |
381 |
| - ctx, |
382 |
| - np.zeros(input.shape, dtype=np.float32), |
383 |
| - f"{name}_zero_tensor", |
384 |
| - dtype=trt.float32, |
385 |
| - ) |
| 362 | + pi_over_2_tensor = impl.full.full( |
| 363 | + ctx, |
| 364 | + target, |
| 365 | + source_ir, |
| 366 | + f"{name}_pi_over_2_tensor", |
| 367 | + shape, |
| 368 | + pi_value / 2, |
| 369 | + dtype=trt_promoted_type, |
| 370 | + ) |
| 371 | + |
| 372 | + minus_pi_over_2_tensor = impl.full.full( |
| 373 | + ctx, |
| 374 | + target, |
| 375 | + source_ir, |
| 376 | + f"{name}_minus_pi_over_2_tensor", |
| 377 | + shape, |
| 378 | + -pi_value / 2, |
| 379 | + dtype=trt_promoted_type, |
| 380 | + ) |
| 381 | + zero_tensor = impl.full.full( |
| 382 | + ctx, |
| 383 | + target, |
| 384 | + source_ir, |
| 385 | + f"{name}_zero_tensor", |
| 386 | + shape, |
| 387 | + 0.0, |
| 388 | + dtype=trt_promoted_type, |
| 389 | + ) |
386 | 390 |
|
387 | 391 | # π/2 if x>0 and y=0,
|
388 | 392 | pi_over_2_output = impl.condition.select(
|
|
0 commit comments