@@ -30,7 +30,10 @@ class TestQTensor:
3030 )
3131 @pytest .mark .parametrize ("device" , ["cpu" , "cuda" ])
3232 @pytest .mark .parametrize ("input_dtype" , [torch .float32 , torch .float16 , torch .bfloat16 ])
33- def test_qtensor (self , num_bits , block_sizes , device , input_dtype ):
33+ @pytest .mark .parametrize (
34+ ("input_shape" , "check_memory" ), [((256 , 64 ), True ), ((256 , 32 ), False )]
35+ ) # check_memory=False skips the memory-saving ratio assertion for the (256, 32) shape
36+ def test_qtensor (self , num_bits , block_sizes , device , input_dtype , input_shape , check_memory ):
3437 nf4_attr_cfg = QuantizerAttributeConfig (
3538 num_bits = num_bits ,
3639 block_sizes = block_sizes ,
@@ -40,7 +43,7 @@ def test_qtensor(self, num_bits, block_sizes, device, input_dtype):
4043
4144 # Original tensor
4245 base_mem = torch .cuda .memory_allocated ("cuda" )
43- x = torch .rand (256 , 64 ).to (device ).to (dtype = input_dtype )
46+ x = torch .rand (input_shape ).to (device ).to (dtype = input_dtype )
4447 x_allocated = torch .cuda .memory_allocated ("cuda" )
4548 bf16_mem_usage = x_allocated - base_mem
4649
@@ -51,7 +54,7 @@ def test_qtensor(self, num_bits, block_sizes, device, input_dtype):
5154 nf4_mem_usage = nf4_x_allocated - base_mem
5255
5356 # Check the memory saving
54- if bf16_mem_usage > 0 :
57+ if bf16_mem_usage > 0 and check_memory :
5558 assert (nf4_mem_usage ) / bf16_mem_usage < 0.3
5659
5760 # De-quantize to origin dtype
0 commit comments