|
| 1 | +import logging |
1 | 2 | from typing import Optional, Union |
2 | 3 |
|
3 | | -import numpy as np |
4 | 4 | import tensorrt as trt |
5 | 5 | import torch |
6 | 6 | import torch_tensorrt.dynamo.conversion.impl as impl |
|
16 | 16 | cast_trt_tensor, |
17 | 17 | get_trt_tensor, |
18 | 18 | has_dynamic_shape, |
| 19 | + set_layer_name, |
19 | 20 | ) |
20 | 21 | from torch_tensorrt.dynamo.conversion.impl.elementwise.base import ( |
21 | 22 | convert_binary_elementwise, |
22 | 23 | ) |
23 | 24 | from torch_tensorrt.dynamo.conversion.impl.unary import atan, sign |
24 | 25 | from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary |
25 | 26 |
|
| 27 | +_LOGGER = logging.getLogger(__name__) |
| 28 | + |
26 | 29 |
|
27 | 30 | def trunc_div( |
28 | 31 | ctx: ConversionContext, |
@@ -250,12 +253,26 @@ def atan2( |
250 | 253 | A TensorRT tensor representing the result of the atan2 operation. |
251 | 254 | """ |
252 | 255 | pi_value = 3.141592653589793 |
253 | | - pi_tensor = get_trt_tensor(ctx, pi_value, f"{name}_pi") |
254 | 256 |
|
255 | | - if isinstance(input, TRTTensor): |
256 | | - input = cast_trt_tensor(ctx, input, trt.float32, f"{name}_input") |
257 | | - if isinstance(other, TRTTensor): |
258 | | - other = cast_trt_tensor(ctx, other, trt.float32, f"{name}_other") |
| 257 | + promoted_type = _enums.dtype._from( |
| 258 | + torch.promote_types( |
| 259 | + _enums.dtype._from(input.dtype).to(torch.dtype), |
| 260 | + _enums.dtype._from(other.dtype).to(torch.dtype), |
| 261 | + ) |
| 262 | + ) |
 | 263 | + # atan2's output is always floating point, so any non-floating promoted type is replaced by float32. |
 | 264 | + # This mirrors PyTorch's behavior where atan2(int, int) -> float. |
| 265 | + if not promoted_type.to(torch.dtype).is_floating_point: |
| 266 | + promoted_type = _enums.dtype.float32 |
| 267 | + |
| 268 | + trt_promoted_type = promoted_type.to(trt.DataType) |
| 269 | + |
| 270 | + pi_tensor = get_trt_tensor(ctx, pi_value, f"{name}_pi", dtype=trt_promoted_type) |
| 271 | + |
| 272 | + if input.dtype != trt_promoted_type: |
| 273 | + input = cast_trt_tensor(ctx, input, trt_promoted_type, f"{name}_input_casted") |
| 274 | + if other.dtype != trt_promoted_type: |
| 275 | + other = cast_trt_tensor(ctx, other, trt_promoted_type, f"{name}_other_casted") |
259 | 276 |
|
260 | 277 | input, other = broadcast(ctx, input, other, f"{name}_input", f"{name}_other") |
261 | 278 |
|
@@ -333,56 +350,43 @@ def atan2( |
333 | 350 | y_positive, |
334 | 351 | ) |
335 | 352 |
|
 | 353 | + # Create constant tensors for the boundary conditions (x=0 or y=0). |
 | 354 | + # impl.full is given the shape layer's output when the input shape is dynamic, and a static list of dims otherwise. |
336 | 355 | if has_dynamic_shape(input.shape): |
337 | | - pi_over_2_tensor = convert_binary_elementwise( |
338 | | - ctx, |
339 | | - target, |
340 | | - source_ir, |
341 | | - f"{name}_pi_over_2_tensor", |
342 | | - trt.ElementWiseOperation.PROD, |
343 | | - (pi_value / 2), |
344 | | - input, |
345 | | - ) |
346 | | - |
347 | | - minus_pi_over_2_tensor = convert_binary_elementwise( |
348 | | - ctx, |
349 | | - target, |
350 | | - source_ir, |
351 | | - f"{name}_minus_pi_over_2_tensor", |
352 | | - trt.ElementWiseOperation.PROD, |
353 | | - (-pi_value / 2), |
354 | | - input, |
355 | | - ) |
356 | | - zero_tensor = convert_binary_elementwise( |
357 | | - ctx, |
358 | | - target, |
359 | | - source_ir, |
360 | | - f"{name}_zero_tensor", |
361 | | - trt.ElementWiseOperation.PROD, |
362 | | - 0, |
363 | | - input, |
364 | | - ) |
| 356 | + shape_layer = ctx.net.add_shape(input) |
| 357 | + set_layer_name(shape_layer, target, f"{name}_shape", source_ir) |
| 358 | + shape = shape_layer.get_output(0) |
365 | 359 | else: |
366 | | - # on x or y-axis |
367 | | - pi_over_2_tensor = get_trt_tensor( |
368 | | - ctx, |
369 | | - (pi_value / 2) * np.ones(input.shape, dtype=np.float32), |
370 | | - f"{name}_pi_over_2_tensor", |
371 | | - dtype=trt.float32, |
372 | | - ) |
| 360 | + shape = list(input.shape) |
373 | 361 |
|
374 | | - minus_pi_over_2_tensor = get_trt_tensor( |
375 | | - ctx, |
376 | | - (-pi_value / 2) * np.ones(input.shape, dtype=np.float32), |
377 | | - f"{name}_minus_pi_over_2_tensor", |
378 | | - dtype=trt.float32, |
379 | | - ) |
380 | | - zero_tensor = get_trt_tensor( |
381 | | - ctx, |
382 | | - np.zeros(input.shape, dtype=np.float32), |
383 | | - f"{name}_zero_tensor", |
384 | | - dtype=trt.float32, |
385 | | - ) |
| 362 | + pi_over_2_tensor = impl.full.full( |
| 363 | + ctx, |
| 364 | + target, |
| 365 | + source_ir, |
| 366 | + f"{name}_pi_over_2_tensor", |
| 367 | + shape, |
| 368 | + pi_value / 2, |
| 369 | + dtype=trt_promoted_type, |
| 370 | + ) |
| 371 | + |
| 372 | + minus_pi_over_2_tensor = impl.full.full( |
| 373 | + ctx, |
| 374 | + target, |
| 375 | + source_ir, |
| 376 | + f"{name}_minus_pi_over_2_tensor", |
| 377 | + shape, |
| 378 | + -pi_value / 2, |
| 379 | + dtype=trt_promoted_type, |
| 380 | + ) |
| 381 | + zero_tensor = impl.full.full( |
| 382 | + ctx, |
| 383 | + target, |
| 384 | + source_ir, |
| 385 | + f"{name}_zero_tensor", |
| 386 | + shape, |
| 387 | + 0.0, |
| 388 | + dtype=trt_promoted_type, |
| 389 | + ) |
386 | 390 |
|
387 | 391 | # π/2 if x>0 and y=0, |
388 | 392 | pi_over_2_output = impl.condition.select( |
|
0 commit comments