Skip to content

Commit 532be09

Browse files
authored
feature: train & infer scripts (#13)
* feature: train & infer * feature: train & infer scripts * feature: update README
1 parent faf1b51 commit 532be09

4 files changed

Lines changed: 80 additions & 5 deletions

File tree

README.md

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,26 +67,53 @@
6767
Our method takes as input a front-view image, a natural-language navigation command with a system prompt, and the ego-vehicle states, and outputs an 8-waypoint future trajectory spanning 4 seconds through parallel denoising. The model is first trained via supervised fine-tuning to learn accurate trajectory prediction. We then apply simulatorguided GRPO to further optimize closed-loop behavior. The GRPO reward function integrates safety constraints (collision avoidance, drivable-area compliance) with performance objectives (ego-progress, time-to-collision, comfort).
6868

6969

70-
## Preparation
7170

72-
### Environment
71+
## Quick Start
72+
73+
### Installation
74+
75+
Clone the repo:
76+
77+
```sh
78+
git clone https://github.com/fudan-generative-vision/WAM-Flow.git
79+
cd WAM-Flow
80+
```
81+
82+
Install dependencies:
83+
7384
```sh
7485
conda create --name wam-flow python=3.10
7586
conda activate wam-flow
7687
pip install -r requirements.txt
7788
```
7889

79-
## Training
90+
91+
### Model Download
92+
93+
Download models using huggingface-cli:
94+
8095
```sh
81-
sh script/sft_debug.sh
96+
pip install "huggingface_hub[cli]"
97+
huggingface-cli download fudan-generative-ai/WAM-Flow --local-dir ./pretrained_model/wam-flow
98+
huggingface-cli download LucasJinWang/FUDOKI --local-dir ./pretrained_model/fudoki
8299
```
83100

84-
## Inference
101+
102+
103+
### Inference
104+
85105
```sh
86106
sh script/infer.sh
87107
```
88108

89109

110+
### Training
111+
112+
```bash
113+
sh script/sft_debug.sh
114+
```
115+
116+
90117

91118
## 📝 Citation
92119

script/infer.sh

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#!/bin/bash
2+
3+
CKPT_PATH="pretrained_model/wam-flow/navsim"
4+
FUDOKI_PATH="pretrained_model/fudoki"
5+
IMAGE_PATH="data/navsim_data/sensor_blobs/test/2021.09.09.17.18.51_veh-48_00889_01147/CAM_F0/9a6f0331d98258a0.jpg"
6+
7+
torchrun --nproc_per_node 1 infer.py \
8+
--checkpoint_path $CKPT_PATH \
9+
--image_path $IMAGE_PATH \
10+
--processor_path $FUDOKI_PATH \
11+
--text_embedding_path $FUDOKI_PATH/text_embedding.pt \
12+
--image_embedding_path $FUDOKI_PATH/image_embedding.pt \
13+
--discrete_fm_steps 2 \
14+
--seed 123

script/sft_debug.sh

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#!/bin/bash
2+
3+
NUM_NODES=1
4+
NUM_GPUS=1
5+
6+
config=config/debug.yaml
7+
output_dir=output/train/debug
8+
9+
accelerate launch \
10+
--config_file ./config/accelerate_config_ds2.yaml \
11+
--machine_rank 0 \
12+
--main_process_port 12345 \
13+
--num_machines $NUM_NODES \
14+
--num_processes $NUM_GPUS \
15+
train.py \
16+
--config $config \
17+
--output_dir $output_dir

script/sft_navsim.sh

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#!/bin/bash
2+
3+
NUM_NODES=1
4+
NUM_GPUS=1
5+
6+
config=config/sft_navsim.yaml
7+
output_dir=output/train/debug
8+
9+
accelerate launch \
10+
--config_file ./config/accelerate_config_ds2.yaml \
11+
--machine_rank 0 \
12+
--main_process_port 12345 \
13+
--num_machines $NUM_NODES \
14+
--num_processes $NUM_GPUS \
15+
train.py \
16+
--config $config \
17+
--output_dir $output_dir

0 commit comments

Comments
 (0)