
docs: a claude demo of using torch-tensorrt to compile qwen3-reranker #4085

Open
narendasan wants to merge 1 commit into main from narendasan/push-vyowusqpovxt

Conversation

@narendasan
Collaborator

Description

An LLM-generated demo for this new model class.

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that the relevant reviewers are notified

@meta-cla meta-cla bot added the cla signed label Feb 18, 2026

@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/dynamo/torch_export_qwen3_reranker.py	2026-02-18 17:30:20.512233+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/dynamo/torch_export_qwen3_reranker.py	2026-02-18 17:30:54.330404+00:00
@@ -125,11 +125,13 @@

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

-    def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, input_ids: torch.Tensor, position_ids: torch.Tensor
+    ) -> torch.Tensor:
        out = self.model(input_ids=input_ids, position_ids=position_ids)
        return out.logits


# ---------------------------------------------------------------------------
@@ -202,11 +204,13 @@
# Main
# ---------------------------------------------------------------------------


def parse_args():
-    p = argparse.ArgumentParser(description="Compile Qwen3-Reranker with Torch-TensorRT")
+    p = argparse.ArgumentParser(
+        description="Compile Qwen3-Reranker with Torch-TensorRT"
+    )
    p.add_argument("--model", default="Qwen/Qwen3-Reranker-0.6B", help="HF model name")
    p.add_argument("--precision", default="FP16", choices=["FP16", "BF16", "FP32"])
    p.add_argument(
        "--max_length",
        type=int,
@@ -231,11 +235,13 @@
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    token_true_id = tokenizer.convert_tokens_to_ids("yes")
    token_false_id = tokenizer.convert_tokens_to_ids("no")
-    print(f"  token_true_id (yes) = {token_true_id}, token_false_id (no) = {token_false_id}")
+    print(
+        f"  token_true_id (yes) = {token_true_id}, token_false_id (no) = {token_false_id}"
+    )

    base_model = (
        AutoModelForCausalLM.from_pretrained(
            args.model,
            use_cache=False,
@@ -252,22 +258,26 @@
    base_model = base_model.to(dtype_map[args.precision])

    # ------------------------------------------------------------------
    # 2. Build test inputs
    # ------------------------------------------------------------------
-    instruction = "Given a web search query, retrieve relevant passages that answer the query"
+    instruction = (
+        "Given a web search query, retrieve relevant passages that answer the query"
+    )
    queries = [
        "What is the capital of China?",
        "How does photosynthesis work?",
    ]
    documents = [
        "The capital of China is Beijing.",
        "Photosynthesis is the process by which plants convert sunlight into glucose.",
    ]

    print("Tokenizing inputs ...")
-    inputs = build_inputs(tokenizer, queries, documents, instruction, max_length=args.max_length)
+    inputs = build_inputs(
+        tokenizer, queries, documents, instruction, max_length=args.max_length
+    )
    input_ids = inputs["input_ids"]
    print(f"  input_ids shape: {input_ids.shape}")

    # ------------------------------------------------------------------
    # 3. PyTorch baseline
