-
Notifications
You must be signed in to change notification settings - Fork 266
Expand file tree
/
Copy pathhelper.py
More file actions
100 lines (87 loc) · 3.45 KB
/
helper.py
File metadata and controls
100 lines (87 loc) · 3.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import AutoProcessor, AutoTokenizer
from typing import Any
import torch
import collections.abc
# Pixel-count bounds forwarded to the image processor (see get_processor);
# presumably they constrain per-image resolution after resizing for the
# Qwen3-VL vision tower — TODO confirm against the training configuration.
MIN_PIXELS = 163840
MAX_PIXELS = 196608
# Hub id of the checkpoint whose processor configuration is loaded.
BASE_PROCESSOR_NAME = "Qwen/Qwen3-VL-2B-Instruct"
def create_message(frames: torch.Tensor):
    """Build the three-turn chat message list from camera frames.

    Args:
        frames: Image batch of shape (N, C, H, W); each frame becomes one
            image entry in the user turn.

    Returns:
        A list of three message dicts (system, user, assistant) in the
        chat format expected by the VLM processor. The assistant turn is
        pre-seeded with the chain-of-thought start token.
    """
    assert frames.ndim == 4, f"{frames.ndim=}, expected (N, C, H, W)"
    # NOTE: the trajectory-history padding tokens are expanded here to match
    # training, so the VLM's native processor can be applied directly.
    traj_token_count = 48
    traj_history_block = (
        "<|traj_history_start|>"
        + "<|traj_history|>" * traj_token_count
        + "<|traj_history_end|>"
    )

    system_turn = {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": "You are a driving assistant that generates safe and accurate actions.",
            }
        ],
    }

    # One image entry per frame, followed by the trajectory-history
    # placeholder and the task instruction.
    image_entries = [{"type": "image", "image": frame} for frame in frames]
    user_turn = {
        "role": "user",
        "content": image_entries
        + [
            {
                "type": "text",
                "text": f"{traj_history_block}output the chain-of-thought reasoning of the driving process, then output the future trajectory.",
            }
        ],
    }

    # Seed the assistant turn so generation continues from the CoT section.
    assistant_turn = {
        "role": "assistant",
        "content": [
            {
                "type": "text",
                "text": "<|cot_start|>",
            }
        ],
    }

    return [system_turn, user_turn, assistant_turn]
def get_processor(tokenizer: AutoTokenizer) -> AutoProcessor:
    """Load the Qwen3-VL-2B-Instruct processor and attach *tokenizer*.

    Args:
        tokenizer: Tokenizer to swap into the loaded processor — presumably
            it carries extra special tokens (trajectory/CoT) relative to the
            base checkpoint; TODO confirm against the caller.

    Returns:
        The processor for BASE_PROCESSOR_NAME, configured with the module's
        pixel bounds and the supplied tokenizer.
    """
    processor = AutoProcessor.from_pretrained(
        BASE_PROCESSOR_NAME,
        min_pixels=MIN_PIXELS,
        max_pixels=MAX_PIXELS,
    )
    # Keep the downloaded image-processing config but use the caller's tokenizer.
    processor.tokenizer = tokenizer
    return processor
def to_device(
data: Any,
device: str | torch.device | None = None,
dtype: torch.dtype | None = None,
) -> Any:
"""Recursively cast data into the specified device, dtype."""
if isinstance(data, torch.Tensor):
# Only apply dtype conversion to floating-point tensors.
# Integer tensors (e.g., input_ids, attention_mask) must preserve their dtype
# for compatibility with Hugging Face models during mixed-precision inference.
if dtype is not None and data.is_floating_point():
return data.to(device=device, dtype=dtype)
return data.to(device=device)
elif isinstance(data, collections.abc.Mapping):
return {key: to_device(data[key], device=device, dtype=dtype) for key in data}
elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
return [to_device(elem, device=device, dtype=dtype) for elem in data]
else:
return data