Commit d2b1422

test

1 parent 1d172ce commit d2b1422

5 files changed, +36 -23 lines


.pre-commit-config.yaml

Lines changed: 6 additions & 6 deletions

@@ -21,14 +21,14 @@ repos:
       - id: clang-format
         types_or: [c++, c, cuda]
   - repo: https://github.com/keith/pre-commit-buildifier
-    rev: 6.4.0
+    rev: 8.0.3
     hooks:
       - id: buildifier
         args:
           - --warnings=all
       - id: buildifier-lint
   - repo: https://github.com/abravalheri/validate-pyproject
-    rev: v0.23
+    rev: v0.24.1
     hooks:
       - id: validate-pyproject
   - repo: https://github.com/pycqa/isort
@@ -37,17 +37,17 @@ repos:
       - id: isort
         name: isort (python)
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: "v1.9.0"
+    rev: "v1.15.0"
     hooks:
       - id: mypy
         exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py"
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.3.3
+    rev: v0.11.7
     hooks:
       - id: ruff
   - repo: https://github.com/psf/black
-    rev: 24.3.0
+    rev: 25.1.0
     hooks:
       - id: black
         exclude: ^examples/custom_converters/elu_converter/setup.py|^docs
@@ -57,7 +57,7 @@ repos:
       - id: typos
   - repo: https://github.com/astral-sh/uv-pre-commit
     # uv version.
-    rev: 0.5.5
+    rev: 0.7.1
     hooks:
       # Update the uv lockfile
       - id: uv-lock

py/torch_tensorrt/_enums.py

Lines changed: 2 additions & 2 deletions

@@ -76,10 +76,10 @@ class dtype(Enum):
 
     f8 = auto()
     """8 bit floating-point number, equivalent to ``dtype.fp8`` and ``dtype.float8``
-
+
     :meta hide-value:
     """
-
+
     f4 = auto()
     """4 bit floating-point number, equivalent to ``dtype.fp4`` and ``dtype.float4``
 

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 8 additions & 1 deletion

@@ -29,7 +29,14 @@
 REQUIRE_FULL_COMPILATION = False
 DRYRUN = False
 HARDWARE_COMPATIBLE = False
-SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
+SUPPORTED_KERNEL_PRECISIONS = {
+    dtype.f32,
+    dtype.f16,
+    dtype.bf16,
+    dtype.i8,
+    dtype.f8,
+    dtype.f4,
+}
 TIMING_CACHE_PATH = os.path.join(
     tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin"
 )
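With dtype.f4 added to SUPPORTED_KERNEL_PRECISIONS, FP4 can in principle be requested as a kernel precision at compile time. A minimal sketch, assuming the existing enabled_precisions argument of torch_tensorrt.compile accepts the new entry; the model and input shapes below are placeholders:

    import torch
    import torch_tensorrt
    from torch_tensorrt import dtype

    # Placeholder model; any small FP32 module works for illustration.
    model = torch.nn.Sequential(torch.nn.Linear(32, 16), torch.nn.ReLU()).eval().cuda()
    inputs = [torch.randn(1, 32).cuda()]

    # Request FP4 alongside FP32/FP16 as allowed TensorRT kernel precisions.
    trt_model = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=inputs,
        enabled_precisions={dtype.f32, dtype.f16, dtype.f4},
    )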

py/torch_tensorrt/dynamo/conversion/impl/quantize.py

Lines changed: 14 additions & 7 deletions

@@ -68,6 +68,7 @@ def quantize(
 
     return dq_output
 
+
 def dynamic_block_quantize(
     ctx: ConversionContext,
     target: Target,
@@ -99,23 +100,29 @@ def dynamic_block_quantize(
         raise ValueError(
             f"dynamic_block_quantize converter received an input of {input_tensor.shape} shape. Supported shapes: 2D or 3D"
         )
-    print(f"input_tensor.shape: {input_tensor.shape} {block_size=} {amax=} {num_bits=} {exponent_bits=} {scale_num_bits=} {scale_exponent_bits=}")
     max_bound = 6
     amax = to_torch(amax, None)
     scale = torch.divide(amax, max_bound)
     scale = get_trt_tensor(ctx, scale, name + "_scale")
 
-    output_type=trt.DataType.FP4
     # Add Q node
-    dynamic_quantize_layer = ctx.net.add_dynamic_quantize(input_tensor, axis=-1, block_size=16, output_type=output_type)
-    quantize_layer.set_output_type(0, output_type)
+    dynamic_quantize_layer = ctx.net.add_dynamic_quantize(
+        input_tensor,
+        axis=-1,
+        block_size=16,
+        output_type=trt.DataType.FP4,
+        scale_type=trt.DataType.FP8,
+    )
+    dynamic_quantize_layer.set_output_type(0, trt.DataType.FP4)
 
-    set_layer_name(quantize_layer, target, name + "_quantize", source_ir)
-    q_output = quantize_layer.get_output(0)
+    set_layer_name(
+        dynamic_quantize_layer, target, name + "_dynamic_quantize", source_ir
+    )
+    q_output = dynamic_quantize_layer.get_output(0)
     # Add DQ node
     dequantize_layer = ctx.net.add_dequantize(q_output, scale)
     set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
-    dequantize_layer.precision = output_type
+    dequantize_layer.precision = trt.DataType.FP4
     dq_output = dequantize_layer.get_output(0)
 
     return dq_output
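For reference, the converter above builds a dynamic-quantize/dequantize pair: a DynamicQuantize layer that emits FP4 values with FP8 block scales (block size 16 along the last axis), followed by a Dequantize layer driven by a per-tensor scale of amax / 6 (6 being the largest magnitude representable in FP4). A standalone sketch of the same pattern, assuming a TensorRT build whose Python API exposes add_dynamic_quantize with these arguments; the helper name, amax value, and input shape are illustrative only:

    import numpy as np
    import tensorrt as trt

    def build_fp4_qdq(network: trt.INetworkDefinition, x: trt.ITensor, amax: float) -> trt.ITensor:
        # Per-tensor dequantize scale, mirroring scale = amax / max_bound above.
        scale_np = np.array([amax / 6.0], dtype=np.float32)
        scale = network.add_constant((1,), trt.Weights(scale_np)).get_output(0)

        # Q: blockwise FP4 quantization with FP8 block scales, block size 16
        # along the last axis, as in the converter change above.
        q_layer = network.add_dynamic_quantize(
            x,
            axis=-1,
            block_size=16,
            output_type=trt.DataType.FP4,
            scale_type=trt.DataType.FP8,
        )
        q_layer.set_output_type(0, trt.DataType.FP4)

        # DQ: dequantize back to higher precision using the per-tensor scale.
        dq_layer = network.add_dequantize(q_layer.get_output(0), scale)
        dq_layer.precision = trt.DataType.FP4
        return dq_layer.get_output(0)

    builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
    network = builder.create_network()
    x = network.add_input("x", trt.DataType.FLOAT, (1, 32))
    network.mark_output(build_fp4_qdq(network, x, amax=4.0))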

tests/py/dynamo/models/test_models_export.py

Lines changed: 6 additions & 7 deletions

@@ -199,10 +199,9 @@ def test_resnet18_half(ir):
     torch._dynamo.reset()
 
 
-
 @unittest.skipIf(
-    torch.cuda.get_device_capability() < (8, 9),
-    "FP4 quantization requires compute capability 8.9 or later",
+    torch.cuda.get_device_capability() < (10, 0),
+    "FP4 quantization requires compute capability 10.0 or later",
 )
 @unittest.skipIf(
     not importlib.util.find_spec("modelopt"),
@@ -216,8 +215,8 @@ def test_base_fp4(ir):
     class SimpleNetwork(torch.nn.Module):
         def __init__(self):
             super(SimpleNetwork, self).__init__()
-            self.linear1 = torch.nn.Linear(in_features=10, out_features=5)
-            self.linear2 = torch.nn.Linear(in_features=5, out_features=1)
+            self.linear1 = torch.nn.Linear(in_features=32, out_features=16)
+            self.linear2 = torch.nn.Linear(in_features=16, out_features=1)
 
         def forward(self, x):
             x = self.linear1(x)
@@ -229,12 +228,12 @@ def calibrate_loop(model):
         """Simple calibration function for testing."""
         model(input_tensor)
 
-    input_tensor = torch.randn(1, 10).cuda()
+    input_tensor = torch.randn(1, 32).cuda()
     model = SimpleNetwork().eval().cuda()
 
     quant_cfg = mtq.NVFP4_DEFAULT_CFG
     mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
-    # model has FP8 qdq nodes at this point
+    # model has FP4 qdq nodes at this point
     output_pyt = model(input_tensor)
 
     with torch.no_grad():
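The revised skip condition gates the test on torch.cuda.get_device_capability(): compute capability (10, 0) corresponds to NVIDIA's Blackwell generation, which is what the NVFP4 path targets. A quick, illustrative check of whether the test is expected to run rather than skip on the local GPU:

    import torch

    # (major, minor) compute capability of the current CUDA device,
    # e.g. (9, 0) on Hopper or (10, 0) on Blackwell.
    major, minor = torch.cuda.get_device_capability()
    if (major, minor) < (10, 0):
        print("test_base_fp4 would be skipped on this GPU")
    else:
        print("this GPU meets the FP4 capability requirement")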
