2 changes: 1 addition & 1 deletion inference/convert.py
@@ -30,7 +30,7 @@
}


-def main(hf_ckpt_path, save_path, n_experts, mp):
+def main(hf_ckpt_path: str, save_path: str, n_experts: int, mp: int) -> None:
"""
Converts and saves model checkpoint files into a specified format.

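The annotated `main()` is the script's CLI entry point; `n_experts` and `mp` are counts, hence `int`. A minimal sketch of how the typed signature might be driven from argument parsing — the flag names below are assumptions for illustration, not copied from convert.py, which defines its own argparse block:

```python
# Illustrative only: convert.py already parses its own arguments.
import argparse

from convert import main  # the function annotated above

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--hf-ckpt-path", type=str, required=True)
    parser.add_argument("--save-path", type=str, required=True)
    parser.add_argument("--n-experts", type=int, required=True)      # expert count, so int
    parser.add_argument("--model-parallel", type=int, required=True)
    args = parser.parse_args()
    main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
```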
4 changes: 2 additions & 2 deletions inference/fp8_cast_bf16.py
@@ -9,7 +9,7 @@

from kernel import weight_dequant

-def main(fp8_path, bf16_path):
+def main(fp8_path: str, bf16_path: str) -> None:
"""
Converts FP8 weights to BF16 and saves the converted weights.

@@ -41,7 +41,7 @@ def main(fp8_path, bf16_path):
fp8_weight_names = []

# Helper function to get tensor from the correct file
-def get_tensor(tensor_name):
+def get_tensor(tensor_name: str) -> torch.Tensor:
"""
Retrieves a tensor from the cached safetensor files or loads it from disk if not cached.

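`get_tensor` is a nested helper that memoizes opened safetensor shards so each file is read once. A hedged sketch of the idea with the annotated signature; the names `weight_map` and `loaded_files` are illustrative and may differ from the script's:

```python
import os

import torch
from safetensors.torch import load_file

loaded_files: dict[str, dict[str, torch.Tensor]] = {}   # shard file -> loaded tensors

def get_tensor(tensor_name: str, weight_map: dict[str, str], fp8_path: str) -> torch.Tensor:
    file_name = weight_map[tensor_name]            # which shard holds this tensor
    if file_name not in loaded_files:              # load each shard at most once
        loaded_files[file_name] = load_file(os.path.join(fp8_path, file_name), device="cpu")
    return loaded_files[file_name][tensor_name]
```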
2 changes: 1 addition & 1 deletion inference/generate.py
@@ -11,7 +11,7 @@
from model import Transformer, ModelArgs


-def sample(logits, temperature: float = 1.0):
+def sample(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
"""
Samples a token from the logits using temperature scaling.

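`sample()` now advertises that it takes logits and returns token ids as tensors. An illustrative implementation of temperature sampling with that signature; the repository's version may use a different sampling trick (e.g. Gumbel-max), so treat this as a sketch:

```python
import torch

def sample(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    logits = logits / max(temperature, 1e-5)                     # rescale; guard against temperature == 0
    probs = torch.softmax(logits, dim=-1)                        # turn logits into a distribution
    return torch.multinomial(probs, num_samples=1).squeeze(-1)   # one sampled token id per row

next_token = sample(torch.randn(2, 32000), temperature=0.7)      # shape (2,)
```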
12 changes: 6 additions & 6 deletions inference/kernel.py
@@ -7,7 +7,7 @@


@triton.jit
-def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
+def act_quant_kernel(x_ptr: tl.tensor, y_ptr: tl.tensor, s_ptr: tl.tensor, BLOCK_SIZE: tl.constexpr) -> None:
"""
Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.

@@ -53,7 +53,7 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor


@triton.jit
-def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
+def weight_dequant_kernel(x_ptr: tl.tensor, s_ptr: tl.tensor, y_ptr: tl.tensor, M: int, N: int, BLOCK_SIZE: tl.constexpr) -> None:
"""
Dequantizes weights using the provided scaling factors and stores the result.

@@ -112,12 +112,12 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t

@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
@triton.jit
-def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
-                    a_s_ptr, b_s_ptr,
+def fp8_gemm_kernel(a_ptr: tl.tensor, b_ptr: tl.tensor, c_ptr: tl.tensor,
+                    a_s_ptr: tl.tensor, b_s_ptr: tl.tensor,
                     M, N: tl.constexpr, K: tl.constexpr,
                     BLOCK_SIZE_M: tl.constexpr,
                     BLOCK_SIZE_N: tl.constexpr,
-                    BLOCK_SIZE_K: tl.constexpr):
+                    BLOCK_SIZE_K: tl.constexpr) -> None:
"""
Performs a matrix multiplication operation on FP8 matrices with scaling factors.

@@ -167,7 +167,7 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
tl.store(c_ptrs, c, mask=mask)


-def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
+def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor) -> torch.Tensor:
"""
Perform a matrix multiplication using FP8 precision.

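One caveat on the kernel annotations: inside a `@triton.jit` function, only the `tl.constexpr` hints change behaviour (Triton specializes the compiled kernel on them); hints on pointer arguments are documentation for readers and tools. A small self-contained kernel, not from this repository, showing the distinction:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)   # BLOCK_SIZE must be a compile-time constant here
    mask = offsets < n_elements                             # guard the ragged tail block
    tl.store(dst_ptr + offsets, tl.load(src_ptr + offsets, mask=mask), mask=mask)

if torch.cuda.is_available():
    x = torch.randn(1000, device="cuda")
    y = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), 128),)
    copy_kernel[grid](x, y, x.numel(), BLOCK_SIZE=128)      # BLOCK_SIZE fixed at launch/compile time
    assert torch.equal(x, y)
```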
26 changes: 13 additions & 13 deletions inference/model.py
@@ -92,7 +92,7 @@ class ParallelEmbedding(nn.Module):
vocab_size (int): Vocabulary size.
dim (int): Embedding dimension.
"""
-def __init__(self, vocab_size: int, dim: int):
+def __init__(self, vocab_size: int, dim: int) -> None:
super().__init__()
self.vocab_size = vocab_size
self.dim = dim
@@ -173,7 +173,7 @@ class Linear(nn.Module):
"""
dtype = torch.bfloat16

-def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
@@ -212,7 +212,7 @@ class ColumnParallelLinear(Linear):
bias (bool): Whether to include a bias term. Defaults to False.
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
"""
-def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None) -> None:
assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
self.part_out_features = out_features // world_size
super().__init__(in_features, self.part_out_features, bias, dtype)
@@ -241,7 +241,7 @@ class RowParallelLinear(Linear):
bias (bool): Whether to include a bias term. Defaults to False.
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
"""
-def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None) -> None:
assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
self.part_in_features = in_features // world_size
super().__init__(self.part_in_features, out_features, bias, dtype)
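The divisibility asserts above exist because `ColumnParallelLinear` shards the output dimension and `RowParallelLinear` shards the input dimension across ranks. A single-process toy (with `world_size = 2` chosen purely for illustration) showing why both shardings reproduce the full matmul:

```python
import torch

world_size, in_features, out_features = 2, 8, 4
w = torch.randn(out_features, in_features)     # full weight, as nn.Linear stores it
x = torch.randn(3, in_features)
full = x @ w.t()

# Column parallel: each rank owns out_features // world_size rows of w; outputs are concatenated.
col_shards = torch.chunk(w, world_size, dim=0)
y_col = torch.cat([x @ shard.t() for shard in col_shards], dim=-1)

# Row parallel: each rank owns in_features // world_size columns of w; partial outputs are summed
# (an all-reduce in the real distributed setting).
row_shards = torch.chunk(w, world_size, dim=1)
x_shards = torch.chunk(x, world_size, dim=-1)
y_row = sum(xi @ shard.t() for xi, shard in zip(x_shards, row_shards))

assert torch.allclose(full, y_col, atol=1e-6)
assert torch.allclose(full, y_row, atol=1e-6)
```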
@@ -406,7 +406,7 @@ class MLA(nn.Module):
v_head_dim (int): Dimensionality of value projections.
softmax_scale (float): Scaling factor for softmax in attention computation.
"""
-def __init__(self, args: ModelArgs):
+def __init__(self, args: ModelArgs) -> None:
super().__init__()
self.dim = args.dim
self.n_heads = args.n_heads
@@ -440,7 +440,7 @@ def __init__(self, args: ModelArgs):
self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)

-def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
+def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
"""
Forward pass for the Multi-Head Latent Attention (MLA) Layer.

@@ -503,7 +503,7 @@ class MLP(nn.Module):
w2 (nn.Module): Linear layer for hidden-to-output transformation.
w3 (nn.Module): Additional linear layer for feature transformation.
"""
-def __init__(self, dim: int, inter_dim: int):
+def __init__(self, dim: int, inter_dim: int) -> None:
"""
Initializes the MLP layer.

@@ -543,7 +543,7 @@ class Gate(nn.Module):
weight (torch.nn.Parameter): Learnable weights for the gate.
bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
"""
-def __init__(self, args: ModelArgs):
+def __init__(self, args: ModelArgs) -> None:
"""
Initializes the Gate module.

@@ -604,7 +604,7 @@ class Expert(nn.Module):
w2 (nn.Module): Linear layer for hidden-to-output transformation.
w3 (nn.Module): Additional linear layer for feature transformation.
"""
-def __init__(self, dim: int, inter_dim: int):
+def __init__(self, dim: int, inter_dim: int) -> None:
"""
Initializes the Expert layer.

@@ -643,7 +643,7 @@ class MoE(nn.Module):
experts (nn.ModuleList): List of expert modules.
shared_experts (nn.Module): Shared experts applied to all inputs.
"""
-def __init__(self, args: ModelArgs):
+def __init__(self, args: ModelArgs) -> None:
"""
Initializes the MoE module.

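`Gate` scores each token and `MoE` dispatches it to its top-k experts. A toy sketch of that routing step with the same annotation style; the function and weight names are illustrative, and the exact scoring in model.py (sigmoid vs. softmax, bias correction, expert groups) differs:

```python
import torch

def route(x: torch.Tensor, gate_weight: torch.Tensor, k: int = 2) -> tuple[torch.Tensor, torch.Tensor]:
    scores = torch.softmax(x @ gate_weight.t(), dim=-1)   # (tokens, n_experts) routing probabilities
    weights, indices = torch.topk(scores, k, dim=-1)      # keep the k best experts per token
    return weights, indices

tokens = torch.randn(4, 8)          # 4 tokens, hidden dim 8
gate_weight = torch.randn(16, 8)    # 16 routed experts
weights, indices = route(tokens, gate_weight)
```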
@@ -700,7 +700,7 @@ class Block(nn.Module):
attn_norm (nn.Module): Layer normalization for attention.
ffn_norm (nn.Module): Layer normalization for feed-forward network.
"""
-def __init__(self, layer_id: int, args: ModelArgs):
+def __init__(self, layer_id: int, args: ModelArgs) -> None:
"""
Initializes the Transformer block.

@@ -744,7 +744,7 @@ class Transformer(nn.Module):
head (nn.Module): Output projection layer mapping to vocabulary size.
freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
"""
-def __init__(self, args: ModelArgs):
+def __init__(self, args: ModelArgs) -> None:
"""
Initializes the Transformer model.

@@ -766,7 +766,7 @@ def __init__(self, args: ModelArgs):
self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)

@torch.inference_mode()
-def forward(self, tokens: torch.Tensor, start_pos: int = 0):
+def forward(self, tokens: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
"""
Forward pass for the Transformer model.

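The pattern applied throughout model.py is the same: `-> None` on `__init__` and a concrete tensor type on `forward`, which is what static checkers key on. A minimal, self-contained module in that style; `TinyBlock` is a stand-in written for this note, not part of model.py, and `nn.LayerNorm` stands in for the model's RMSNorm:

```python
import torch
from torch import nn

class TinyBlock(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.proj(self.norm(x))   # pre-norm residual, in the spirit of Block.forward

out = TinyBlock(16)(torch.randn(2, 16))      # (2, 16) in, (2, 16) out
```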