Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/_logprob_test_linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ jobs:
-d "{\"messages\": [{\"role\": \"user\", \"content\": \"1+1=?\"}], \"logprobs\": true}"
set +e
rm -rf ./baseline_output
cp -r baseline_dev_0311/ERNIE-4.5-0.3B-Paddle ./baseline_output
cp -r baseline_0419/ERNIE-4.5-0.3B-Paddle ./baseline_output
LOGPROB_EXIT_CODE=0
python3.10 lanucher.py --request_template TOKEN_LOGPROB --url http://localhost:${FD_API_PORT}/v1/chat/completions --case ./cases/demo.yaml --concurrency 1 --name demo --exe logprob || LOGPROB_EXIT_CODE=$?
echo "LOGPROB_EXIT_CODE=${LOGPROB_EXIT_CODE}" > /workspace/exit_code.env
Expand Down
8 changes: 4 additions & 4 deletions tests/e2e/4cards_cases/test_ernie_21b_tp1_dp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,9 +609,9 @@ def test_non_stream_with_logprobs(api_url):

base_path = os.getenv("MODEL_PATH")
if base_path:
base_file = os.path.join(base_path, "21b_tp1_dp4_logprobs_non_stream_static_baseline.txt")
base_file = os.path.join(base_path, "21b_tp1_dp4_logprobs_non_stream_static_baseline_0419.txt")
else:
base_file = "21b_tp1_dp4_logprobs_non_stream_static_baseline.txt"
base_file = "21b_tp1_dp4_logprobs_non_stream_static_baseline_0419.txt"

with open(base_file, "r", encoding="utf-8") as f:
baseline = json.load(f)
Expand Down Expand Up @@ -647,9 +647,9 @@ def test_stream_with_logprobs(api_url):

base_path = os.getenv("MODEL_PATH")
if base_path:
base_file = os.path.join(base_path, "21b_tp1_dp4_logprobs_stream_static_baseline.txt")
base_file = os.path.join(base_path, "21b_tp1_dp4_logprobs_stream_static_baseline_0419.txt")
else:
base_file = "21b_tp1_dp4_logprobs_stream_static_baseline.txt"
base_file = "21b_tp1_dp4_logprobs_stream_static_baseline_0419.txt"

with open(base_file, "r", encoding="utf-8") as f:
baseline = json.load(f)
Expand Down
8 changes: 4 additions & 4 deletions tests/e2e/4cards_cases/test_ernie_21b_tp1_dp4_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,9 +516,9 @@ def test_non_stream_with_logprobs(api_url):
base_path = os.getenv("MODEL_PATH")

if base_path:
base_file = os.path.join(base_path, "21b_tp1_dp4_mtp_logprobs_non_stream_static_baseline.txt")
base_file = os.path.join(base_path, "21b_tp1_dp4_mtp_logprobs_non_stream_static_baseline_0419.txt")
else:
base_file = "21b_tp1_dp4_mtp_logprobs_non_stream_static_baseline.txt"
base_file = "21b_tp1_dp4_mtp_logprobs_non_stream_static_baseline_0419.txt"

with open(base_file, "r", encoding="utf-8") as f:
baseline = json.load(f)
Expand Down Expand Up @@ -555,9 +555,9 @@ def test_stream_with_logprobs(api_url):
base_path = os.getenv("MODEL_PATH")

if base_path:
base_file = os.path.join(base_path, "21b_tp1_dp4_mtp_logprobs_stream_static_baseline.txt")
base_file = os.path.join(base_path, "21b_tp1_dp4_mtp_logprobs_stream_static_baseline_0419.txt")
else:
base_file = "21b_tp1_dp4_mtp_logprobs_stream_static_baseline.txt"
base_file = "21b_tp1_dp4_mtp_logprobs_stream_static_baseline_0419.txt"

with open(base_file, "r", encoding="utf-8") as f:
baseline = json.load(f)
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/utils/baseline_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def __init__(self, base_dir="/ModelData"):
self.base_dir = base_dir

def _get_path(self, name: str):
branch = os.getenv("TEST_BRANCH", "default")
branch = os.getenv("TEST_BRANCH", "0419")
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 默认值从 "default" 改为硬编码日期 "0419",意味着每次 Paddle 更新导致 baseline 变化时,都需要同步修改此处代码。

建议考虑以下方案之一:

  1. 通过统一的配置文件(如 baseline_version.txt)管理版本号,代码中读取该文件作为默认值
  2. 保留 "default" 作为默认值,在 CI workflow 中统一通过 TEST_BRANCH 环境变量传入实际版本号

这样可以将 baseline 版本与代码解耦,减少未来更新时的改动点。

if os.getenv("STATIC_C8") == "1":
c8_mode = "_static_c8"
elif os.getenv("DYNAMIC_C8") == "1":
Expand Down
Loading