Commit 3bf9c85

Authored by garrett4wade, meizhiyu.mzy, zhaochenyang20, zhaochen20, and Jayon02
[Fix] Merge previous contributions from fw/refactor to lite (areal-project#163)
* initial proposal
* add arealite
* .
* change api
* .
* remove LOG_ROOT
* remove MODEL_SAVE_PATH
* remove PARAM_REALLOC_PATH, DATASET_CACHE
* prepare for testing
* prepare for testing
* ready for run
* local run
* tests mainly pass
* format
* .
* amend cluster.py
* .
* .
* client test pass
* pass rollout test
* remove unused imports
* add arealite readme
* change api
* .
* .
* .
* .
* .
* .
* .
* .
* format
* .
* implement iteraptable generation (areal-project#112)
Co-authored-by: zhaochenyang <zhaochenyang20@gmail.com>
* .
* fix
* .
* .
* .
* pass controller generate batch test
* .
* refactor rollout controller into worker and controller
* .
* .
* .
* change to async rollout
* pass rollout controller test
* pass test
* .
* update readme
* .
* sft debug
* .
* add lisence
* remove unused files
* remove unsed args in ppo
* add hf engine wrapper (areal-project#116)
* add hf engine
* fix issues
* fix ppo bugs and add test
* add hf client interface and modify cli args
* fix bugs
* fix issues
* Merge fw/refactor
* Finish hf wrapper test
* add test
---------
Co-authored-by: Wei Fu <36355462+garrett4wade@users.noreply.github.com>
* format
* format
* .
* refine hf engine
* .
* fix
* add fsdp engine and sft tests
* .
* .
* .
* pass ppo unittest
* pass ppo and rollout controller tests
* clear unused imports
* rename ppo to grpo
* change reward function organization
* reorganize code
* add dataset api
* .
* .
* .
* format
* chmod fix
* .
* rename workflow to collector
* refactor llm_client location
* .
* .
* fix llm server api
* refactor config structure
* .
* fix tests
* .
* .
* .
* Fix unresolved issue in SFTTrainer PR (areal-project#139)
* .
* .
* efficient loading
* format
* .
* .
* .
* .
* .
* .
* Add CI for testing AReaLite (areal-project#150)
* ci: add test-arealite
* ci: add checkout before running test-arealite
* ci: add USERNAME
* ci: add test script
* ci: add GitHub mirror
* ci: fix typo
* ci: clone one commit
* ci: fix condition
* ci: set command timeout to 60m
* ci: enable pip cache
* ci: optimize container lifecycle
* ci: split into many stages
* ci(test-arealite): fix typo
* ci: fix wrong env
* ci: fix pytest
* ci: uninstall transformer-engine
* ci: uninstall transformer-engine
* ci: fix model paths
* ci: show stdout/stderr
* ci: fix not clean up
* ci: backup sglang
* ci: remove tmp repo dir when run
* ci: fix docker run exit 1 condition
* ci(test-arealite): limit the concurrency and extend command timeout
* .
* merge fw/refactor
* revert some changes
* fix
---------
Co-authored-by: meizhiyu.mzy <meizhiyu.mzy@antgroup.com>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
Co-authored-by: zhaochenyang <zhaochenyang20@gmail.com>
Co-authored-by: Jayon02 <qiujiangc@outlook.com>
Co-authored-by: root <meizhiyu.mzy>
Co-authored-by: Zijian Zhang <futrime@outlook.com>
1 parent d48bf00, commit 3bf9c85

5 files changed

Lines changed: 9 additions & 4 deletions

arealite/README.md

Lines changed: 1 addition & 1 deletion
@@ -737,4 +737,4 @@ dataloader = StatefulDataLoader(
 )
 for data in dataloader:
     assert isinstance(data, list)
-```
+```

arealite/api/cli_args.py

Lines changed: 3 additions & 0 deletions
@@ -4,6 +4,9 @@
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
+import uvloop
+
+uvloop.install()
 from hydra import compose as hydra_compose
 from hydra import initialize as hydra_init
 from omegaconf import MISSING, OmegaConf
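
The hunk above calls `uvloop.install()` at import time, which swaps asyncio's default event loop policy for uvloop's before any loop is created. As a rough illustration only (not code from this commit), the sketch below wraps the same call in a guard so it degrades gracefully where uvloop is not installed. The helper name `install_fast_event_loop` and the coroutine `add` are hypothetical.

```python
import asyncio


def install_fast_event_loop() -> str:
    """Try to switch asyncio to uvloop; report which backend is in use.

    Hypothetical helper: the commit calls uvloop.install() unconditionally.
    """
    try:
        import uvloop  # third-party package; may be missing on some platforms
    except ImportError:
        return "asyncio-default"
    uvloop.install()  # must run before the event loop is created
    return "uvloop"


async def add(a: int, b: int) -> int:
    # Yield to the event loop once so the coroutine is actually scheduled.
    await asyncio.sleep(0)
    return a + b


backend = install_fast_event_loop()
result = asyncio.run(add(2, 3))
print(backend, result)
```

Application code does not change either way: `asyncio.run` behaves the same, only the loop implementation underneath differs.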

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -53,9 +53,9 @@ dependencies = [
 "hydra-core==1.4.0.dev1",
 "packaging",
 "tabulate",
+"gymnasium>=1.1.1",
 "torchdata",
 "autoflake",
-"gymnasium",
 "tensordict",
 
 # Monitoring and logging

realhf/api/core/data_api.py

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,7 @@
 import random
 import time
 from contextlib import contextmanager
+from functools import lru_cache
 
 # NOTE: We don't sue wildcard importing here because the type
 # `Sequence` has a very similar name to `SequenceSample`.
@@ -47,6 +48,7 @@
 RL_TASKS = ["math", "code", "rlhf", "stem"]
 
 
+@lru_cache(maxsize=8)
 def load_hf_tokenizer(
     model_name_or_path: str,
     fast_tokenizer=True,
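
The second hunk memoizes `load_hf_tokenizer` with `@lru_cache(maxsize=8)`, so repeated calls with the same arguments presumably reuse one loaded tokenizer instead of reloading it. The sketch below shows that behavior with a hypothetical stand-in, `load_tokenizer_stub`, that returns a plain dict rather than a real tokenizer. Its signature only mirrors the two parameters visible in the diff.

```python
from functools import lru_cache

load_calls = 0  # counts how often the expensive body actually runs


@lru_cache(maxsize=8)
def load_tokenizer_stub(model_name_or_path: str, fast_tokenizer: bool = True) -> dict:
    """Hypothetical stand-in for an expensive tokenizer load."""
    global load_calls
    load_calls += 1
    return {"name": model_name_or_path, "fast": fast_tokenizer}


first = load_tokenizer_stub("some/model")
second = load_tokenizer_stub("some/model")  # cache hit: body is skipped
third = load_tokenizer_stub("some/model", fast_tokenizer=False)  # new key: body runs

print(first is second, load_calls, load_tokenizer_stub.cache_info())
```

Two consequences follow from how `lru_cache` works: every argument must be hashable, and all callers share the same cached object, so mutating it in one place is visible everywhere. Calls that pass the same values in a different form (positional versus keyword) may also be cached as separate entries.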

requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -69,8 +69,8 @@ word2number
 Pebble
 timeout-decorator
 prettytable
+gymnasium>=1.1.1
 swanlab[dashboard]
 torchdata
 autoflake
-gymnasium
-tensordict
+tensordict
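
Both `pyproject.toml` and `requirements.txt` now require `gymnasium>=1.1.1` instead of an unpinned `gymnasium`. A minimal sketch of checking such a lower bound at runtime, using only the standard library, might look like the following. The helpers `version_tuple`, `meets_minimum`, and `installed_meets` are hypothetical, and the parsing is deliberately simplified: it reads only leading numeric components, so a pre-release such as `1.1.1rc1` is treated as `1.1.1`. For real comparisons, `packaging.version` (already listed among the dependencies) is the more robust choice.

```python
from importlib import metadata
from itertools import takewhile


def version_tuple(version: str) -> tuple:
    """Turn '1.1.1' into (1, 1, 1), stopping at the first non-numeric piece."""
    parts = []
    for piece in version.split("."):
        digits = "".join(takewhile(str.isdigit, piece))
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    """Simplified '>=' check on numeric release components only."""
    return version_tuple(installed) >= version_tuple(minimum)


def installed_meets(package: str, minimum: str) -> bool:
    """Return False when the package is missing or older than `minimum`."""
    try:
        installed = metadata.version(package)
    except metadata.PackageNotFoundError:
        return False
    return meets_minimum(installed, minimum)


print(installed_meets("gymnasium", "1.1.1"))
```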
