
[Minor] Added NVILA prefilling tables #283

Open · wants to merge 2 commits into main
6 changes: 3 additions & 3 deletions awq/kernels/csrc/pybind.cpp
@@ -30,9 +30,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("w8a8_gemm_forward_cuda", &w8a8_gemm_forward_cuda, "our w8a8 gemm kernel");
m.def("w8a8_gemm_fuse_bias_forward_cuda", &w8a8_gemm_fuse_bias_forward_cuda, "our w8a8 gemm fused bias kernel");
m.def("invoke_quant", &invoke_quant, "fp16->int8 quantization");
m.def("rms_norm_general", &rms_norm_general, py::arg("out"), py::arg("input"),
m.def("layer_norm_general", &layer_norm_general, py::arg("out"), py::arg("input"),
py::arg("weight"), py::arg("bias"),py::arg("scaling"), py::arg("epsilon"), py::arg("use_per_token_quant") = true,
"Apply Root Mean Square (RMS) Normalization to the input tensor (TRTLLM kernel).");
"Apply Layer Normalization to the input tensor (TRTLLM kernel) and quantize the tensor into 8 bits.");
m.def("silu_and_mul", &silu_and_mul, "Activation function.");
m.def("gelu_and_quant",&gelu_and_quant, "Apply gelu act and quant output");
}
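
Note: the binding keeps the exact argument list of the old `rms_norm_general`; only the exposed name and the description change. A minimal sketch of how the renamed entry point can be called from Python once the extension is rebuilt (buffer shapes and dtypes below are illustrative assumptions, not taken from the diff):

```python
import torch
import awq_inference_engine  # extension built from awq/kernels with this PR applied

hidden_size = 1152           # assumed SigLIP-style hidden size, purely illustrative
tokens = 64 * 1024           # flattened bsz * seqlen

x = torch.randn(tokens, hidden_size, dtype=torch.float16, device="cuda")
weight = torch.ones(hidden_size, dtype=torch.float16, device="cuda")
bias = torch.zeros(hidden_size, dtype=torch.float16, device="cuda")

out = torch.empty(tokens, hidden_size, dtype=torch.int8, device="cuda")  # quantized output buffer
scaling = torch.empty(tokens, dtype=torch.float16, device="cuda")        # per-token scales

# Same positional signature as before the rename:
# out, input, weight, bias, scaling, epsilon, use_per_token_quant.
awq_inference_engine.layer_norm_general(out, x, weight, bias, scaling, 1e-6, True)
```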
2 changes: 1 addition & 1 deletion awq/kernels/csrc/w8a8/layernorm.cu
@@ -190,7 +190,7 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta,

} // namespace vllm

void rms_norm_general(torch::Tensor &out, // [..., hidden_size]
void layer_norm_general(torch::Tensor &out, // [..., hidden_size]
torch::Tensor &input, // [..., hidden_size]
torch::Tensor &weight, // [hidden_size]
torch::Tensor &bias, // [hidden_size]
2 changes: 1 addition & 1 deletion awq/kernels/csrc/w8a8/layernorm.h
@@ -11,7 +11,7 @@

#include <torch/extension.h>
#include <cuda_fp16.h>
void rms_norm_general(torch::Tensor &out, // [..., hidden_size]
void layer_norm_general(torch::Tensor &out, // [..., hidden_size]
torch::Tensor &input, // [..., hidden_size]
torch::Tensor &weight, // [hidden_size]
torch::Tensor &bias, // [hidden_size]
35 changes: 32 additions & 3 deletions tinychat/README.md
@@ -184,7 +184,21 @@ Time-To-First-Token (TTFT) of Llama-2-7B (Unit: Seconds):
| ----------- |:-------:|:-------:|:-------:|:-------:|:-------:|:-------:|
| FP16 | 0.029 | 0.058 | 0.100 | 0.211 | 0.329 | 0.441 |
| TinyChat | 0.018 | 0.031 | 0.060 | 0.124 | 0.193 | 0.265 |
| Speedup | 1.57x | 1.83x | 1.66x | 1.70x | 1.70x | 1.66x |

Time-To-First-Token (TTFT) of NVILA models processing 8-image inputs (Unit: seconds):

| Model | Precision | Vision Tower | LLM | Total |
|:---------------:|:----------:|:------------:|:------------:|:------------:|
| NVILA-lite-2B | FP16 | 0.074 | 0.024 | 0.097 |
| | TinyChat | 0.045 | 0.016 | 0.060 |
| | Speedup | 1.65x | 1.52x | 1.62x |
| NVILA-lite-8B | FP16 | 0.073 | 0.098 | 0.172 |
| | TinyChat | 0.045 | 0.059 | 0.104 |
| | Speedup | 1.63x | 1.67x | 1.65x |
| NVILA-8B | FP16 | 0.075 | 0.205 | 0.280 |
| | TinyChat | 0.046 | 0.122 | 0.168 |
| | Speedup | 1.61x | 1.69x | 1.66x |
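
The Total and Speedup rows follow directly from the per-stage timings: Total is the Vision Tower time plus the LLM time, and Speedup is the FP16 time divided by the TinyChat time. A quick check against the NVILA-lite-8B row above (the published timings are rounded to three decimals, so the recomputed ratios can differ in the last digit):

```python
fp16     = {"vision_tower": 0.073, "llm": 0.098}   # seconds, copied from the table above
tinychat = {"vision_tower": 0.045, "llm": 0.059}

fp16["total"] = sum(fp16.values())          # ~0.171 (table shows 0.172 before rounding)
tinychat["total"] = sum(tinychat.values())  # 0.104

speedup = {k: round(fp16[k] / tinychat[k], 2) for k in fp16}
print(speedup)  # ~{'vision_tower': 1.62, 'llm': 1.66, 'total': 1.64}
```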


#### Jetson Orin Results
@@ -197,6 +211,21 @@ Time-To-First-Token (TTFT) of Llama-3-8B (Unit: Seconds):
| TinyChat | 0.166 | 0.315 | 0.623 | 1.248 | 1.907 | 2.573 |
| Speedup | 1.24x | 1.26x | 0.91x | 1.22x | 1.21x | 1.21x |

Time-To-First-Token (TTFT) of NVILA models processing 8-image inputs (Unit: seconds):

| Model | Precision | Vision Tower | LLM | Total |
|:---------------:|:----------:|:------------:|:------------:|:------------:|
| NVILA-lite-2B | FP16 | 0.449 | 0.155 | 0.605 |
| | TinyChat | 0.419 | 0.145 | 0.564 |
| | Speedup | 1.07x | 1.07x | 1.07x |
| NVILA-lite-8B | FP16 | 0.449 | 0.733 | 1.183 |
| | TinyChat | 0.419 | 0.620 | 1.040 |
| | Speedup | 1.07x | 1.18x | 1.14x |
| NVILA-8B | FP16 | 0.449 | 1.798 | 2.247 |
| | TinyChat | 0.419 | 1.200 | 1.620 |
| | Speedup | 1.07x | 1.50x | 1.39x |


#### Comparison with Other Systems

@@ -500,11 +529,11 @@ python -m awq.entry --model_path PATH/TO/NVILA/llm \
```
Next, try chatting with it using the command below to experience shorter Time To First Token (TTFT) and higher decoding throughput.
```bash
python nvila_demo.py --model-path EPATH/TO/NVILA \
python nvila_demo.py --model-path PATH/TO/NVILA \
--quant_path PATH/TO/NVILA-w4-g128-v2.pt \
--media PATH/TO/MEDIA \
--act_scale_path PATH/TO/NVILA-smooth-scale.pt \
--quant_llm --chunk --model_type nvila
--all --chunk --model_type nvila
```


34 changes: 17 additions & 17 deletions tinychat/modules/fused_siglipdecoder.py
@@ -14,7 +14,7 @@
from typing import Optional, Tuple, Union
from flash_attn import flash_attn_func
import time

import argparse
CLIP_RANGE = 5


@@ -24,7 +24,11 @@
class QuantSiglipEncoder(nn.Module):
def __init__(self, module: SiglipEncoder, bsz=64, seqlen=1024):
super().__init__()
self.config = module.config
self.config=module.config
if "output_hidden_states" not in self.config:
self.config["output_hidden_states"]=False
self.config["use_return_dict"]=False

self.layers = [QuantSiglipEncoderLayer(layer) for layer in module.layers]
self.buffer = ActivationBuffer(module)
self.bsz = bsz
@@ -40,14 +44,12 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
# TODO Find why this code is necessary
# torch.sum(inputs_embeds!=inputs_embeds)
inputs_embeds=inputs_embeds.contiguous()
bsz, seqlen, _ = inputs_embeds.shape
if self.bsz != bsz or self.seqlen != seqlen:
self.buffer.allocate_activation_buffer(bsz * seqlen)
self.bsz = bsz
self.seqlen = seqlen

output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
@@ -68,7 +70,6 @@ def forward(
hidden_states = encoder_layer(
hidden_states, self.buffer, attention_mask, bsz, seqlen
)

if output_hidden_states:
encoder_states = encoder_states + (hidden_states.reshape(bsz, seqlen, -1),)
if not return_dict:
@@ -84,7 +85,7 @@ class QuantSiglipMLP(nn.Module):
def __init__(self, siglipmlp, init_only=False):
super().__init__()
self.config = siglipmlp.config
self.activation_fn = siglipmlp.activation_fn
self.activation_fn = getattr(siglipmlp, "activation_fn", None)
self.fc1 = W8A8OF16LinearDynamicInputScale.from_linear(
siglipmlp.fc1, init_only=init_only, fc1=False
)
@@ -182,14 +183,14 @@ def __init__(self, module: SiglipEncoderLayer):
super().__init__()
self.embed_dim = module.embed_dim
self.self_attn = QuantSiglipFlashAttention2(module.self_attn)
self.layer_norm1 = RMSNormGeneral(
self.layer_norm1 = LayerNormGeneral(
module.layer_norm1.weight.data,
module.layer_norm1.bias.data,
module.layer_norm1.eps,
True,
).cuda()
self.mlp = QuantSiglipMLP(module.mlp)
self.layer_norm2 = RMSNormGeneral(
self.layer_norm2 = LayerNormGeneral(
module.layer_norm2.weight.data,
module.layer_norm2.bias.data,
module.layer_norm2.eps,
@@ -220,7 +221,6 @@ def forward(
buffer.quantized_hidden_states_buffer,
buffer.quantized_scale_buffer,
)

# INT8 -> FP16
self.self_attn(buffer, bsz, seqlen)
hidden_states = (
@@ -234,7 +234,6 @@
buffer.quantized_hidden_states_buffer,
buffer.quantized_scale_buffer,
)

# INT8 -> FP16
self.mlp(buffer)
hidden_states = (
@@ -243,11 +242,11 @@
return hidden_states


class RMSNormGeneral(nn.Module):
"""Root mean square normalization (w/ per-token or per-tensor quant).
class LayerNormGeneral(nn.Module):
"""Layer normalization (w/ per-token or per-tensor quant).

Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
Refer to https://arxiv.org/abs/1910.07467
Computes x -> w * (x - E[x]) / sqrt(Var[x] + eps) + b where w is the learned weight and b the learned bias.
Refer to https://arxiv.org/abs/1607.06450
"""

def __init__(
@@ -271,12 +270,13 @@ def forward(
quantized_sum_buffer: torch.Tensor = None,
) -> torch.Tensor:
# quantized_sum_buffer is not used, only to keep the consistency of the interface
awq_inference_engine.rms_norm_general(

awq_inference_engine.layer_norm_general(
quantized_hidden_states_buffer,
x,
self.weight.data,
self.bias.data,
quantized_scale_buffer,
self.variance_epsilon,
self.use_per_token_quant,
)
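
For readers following the rename: per the updated docstring and the `use_per_token_quant` flag, the kernel fuses a standard LayerNorm with int8 quantization, writing the quantized activations into `quantized_hidden_states_buffer` and the scales into `quantized_scale_buffer`. A rough, unfused PyTorch sketch of that behaviour — the actual TRT-LLM kernel's rounding, clamping, and scale dtype may differ:

```python
import torch

def layer_norm_general_reference(x, weight, bias, eps=1e-5, per_token_quant=True):
    """Unfused sketch: LayerNorm over the last dim, then symmetric int8 quantization."""
    x32 = x.float()
    mean = x32.mean(dim=-1, keepdim=True)
    var = x32.var(dim=-1, unbiased=False, keepdim=True)
    normed = (x32 - mean) / torch.sqrt(var + eps) * weight.float() + bias.float()

    if per_token_quant:
        scale = normed.abs().amax(dim=-1, keepdim=True) / 127.0  # one scale per token
    else:
        scale = normed.abs().amax() / 127.0                      # one scale for the whole tensor

    scale = scale.clamp(min=1e-8)
    quantized = torch.clamp(torch.round(normed / scale), -128, 127).to(torch.int8)
    return quantized, scale
```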
2 changes: 1 addition & 1 deletion tinychat/scripts/nvila_demo.sh
@@ -17,6 +17,6 @@ python -m awq.entry --model_path $MODEL_PATH/llm \
# Run the TinyChat demo:
python nvila_demo.py --model-path $MODEL_PATH \
--quant_path quant_cache/$MODEL_NAME-w4-g128-awq.pt \
--media ../figures/nvila-logo.jpg \
--media ../figures/vila-logo.jpg \
--act_scale_path awq_cache/$MODEL_NAME-smooth-scale.pt \
--all --chunk --model_type nvila --vis_image