
What kind of layers are optimized by torchao on an RTX 4090? #1805

Open
@naiveen

Description

I am trying to quantize a model, and I am running this on a 4090. Since many of the available quantization benchmarks are done on higher-end GPUs, I am trying to establish a baseline for the performance gain I can expect from quantization.

I tried the tutorial at torchao_demo on a GPU and it worked great. My model has similar transformer layers with q, k, v projections, but I am not able to see the same kind of performance; the profile log shows a large chunk of aten::_copy() operations.
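
For reference, this is roughly how I looked at the profile. A minimal sketch; the Linear layer and input sizes here are just placeholders for my actual model and batch:

import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

# Placeholder stand-ins for the real model and batch (illustration only).
model = nn.Linear(4096, 4096).cuda()
x = torch.randn(16, 4096, device="cuda")

with torch.no_grad(), profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    model(x)
    torch.cuda.synchronize()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))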

To debug, I wanted to benchmark a single linear layer, since the majority of the modified layers seem to be of this type. But I am not able to see any performance gain in this experiment. I would appreciate more context on the specific layers that get optimized by torchao.
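
My current understanding, which may well be wrong, is that quantize_ by default only swaps the weights of nn.Linear modules, and that a filter_fn argument can narrow that down further. A minimal sketch of what I mean (Block and only_qkv are made-up names, just for illustration):

import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only

# Toy stand-in for a transformer block with named q/k/v projections.
class Block(nn.Module):
    def __init__(self, dim=1024):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.mlp = nn.Linear(dim, dim)

# Hypothetical predicate: only quantize the q/k/v projection linears.
def only_qkv(module, fqn):
    return isinstance(module, nn.Linear) and any(k in fqn for k in ("q_proj", "k_proj", "v_proj"))

model = Block().cuda()
quantize_(model, int8_weight_only(), filter_fn=only_qkv)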

# Adapted from:
# https://github.com/ethanshenley/PyTorch-Conference-Recipes/blob/main/torchao_demo.ipynb
import gc
import psutil
import torch
import torch.nn as nn
import time

from torchao.quantization import quantize_, int8_weight_only, float8_weight_only


device = "cuda:0"
def get_memory_usage():
    # Note: psutil reports host (CPU) RSS, not GPU memory; for GPU memory,
    # torch.cuda.memory_allocated(device) would be the relevant number.
    return psutil.Process().memory_info().rss / 1024 / 1024  # in MB

def run_inference(model, inputs, num_runs=10):
    torch.cuda.synchronize(device)  # finish pending GPU work before starting the clock
    start_time = time.time()
    with torch.no_grad():
        for i in range(num_runs):
            outputs = model(inputs[i])  # inputs[i] is already (bsz, sz)
    torch.cuda.synchronize(device)  # wait for queued kernels before stopping the clock
    end_time = time.time()
    return (end_time - start_time) / num_runs

# Sweep over square linear layers of increasing size
bsz = 16
n_runs = 100
for sz in range(1024, 20480, 1024):
    print('====================================================')
    print(f"Running with linear layer of size {sz}...")
    model = nn.Linear(sz, sz).to(device)
    inputs = torch.randn(n_runs, bsz, sz).to(device)

    print("\nRunning baseline model...")
    baseline_memory = get_memory_usage()
    baseline_time = run_inference(model, inputs, n_runs)
    print(f"Baseline - Time: {baseline_time:.4f}s, Memory: {baseline_memory:.2f}MB")


    print("\nRunning int8 weight-only quantized model...")
    model_int8 = nn.Linear(sz, sz).to(device)
    quantize_(model_int8, int8_weight_only())
    int8_memory = get_memory_usage()
    int8_time = run_inference(model_int8, inputs, n_runs)
    print(f"Int8 Weight-Only - Time: {int8_time:.4f}s, Memory: {int8_memory:.2f}MB")

    print("\nRunning fp8 weight-only quantized model...")
    model_fp8 = nn.Linear(sz, sz).to(device)
    quantize_(model_fp8, float8_weight_only())  
    fp8_memory = get_memory_usage()
    fp8_time = run_inference(model_fp8, inputs, n_runs)
    print(f"fp8 Weight-Only  - Time: {fp8_time:.4f}s, Memory: {fp8_memory:.2f}MB")


    print("\nPerformance Improvements:")
    print(f"Int8 weight-only speedup: {baseline_time / int8_time:.2f}x")
    print(f"Int8 weight-only memory reduction: {baseline_memory / int8_memory:.2f}x")
    print(f"fp8 weight-only speedup: {baseline_time / fp8_time:.2f}x")
    print(f"fp8 weight-only memory reduction: {baseline_memory / fp8_memory:.2f}x")

    del model, model_int8, model_fp8, inputs
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize(device)
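
One variant I am also planning to try, in case the weight-only kernels only pay off when the model is compiled (a minimal sketch with CUDA-event timing; mode="max-autotune" is just my guess at a reasonable setting):

import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only

device = "cuda:0"
sz, bsz, n_runs = 4096, 16, 100

model_int8 = nn.Linear(sz, sz).to(device)
quantize_(model_int8, int8_weight_only())
# Assumption: the weight-only kernels are meant to be paired with torch.compile.
model_int8 = torch.compile(model_int8, mode="max-autotune")

x = torch.randn(bsz, sz, device=device)
with torch.no_grad():
    for _ in range(10):  # warm-up (includes compilation)
        model_int8(x)
    torch.cuda.synchronize(device)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(n_runs):
        model_int8(x)
    end.record()
    torch.cuda.synchronize(device)
print(f"avg time per call: {start.elapsed_time(end) / n_runs:.3f} ms")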
