Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions .github/workflows/ci-static-checks.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
name: "CI: Static Checks"

on:
push:
branches: [main]
pull_request:
branches: [main]

jobs:
lint-and-fmt:
name: lint and fmt
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v6
- name: Setup Python
uses: actions/setup-python@v6
with:
python-version: "3.10"
- name: Run pre-commit
uses: pre-commit/action@v3.0.1
24 changes: 24 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: check-case-conflict
- id: check-yaml
- id: check-json
- id: check-merge-conflict
- id: destroyed-symlinks
- id: mixed-line-ending

- repo: https://github.com/tombi-toml/tombi-pre-commit
rev: v0.7.0
hooks:
- id: tombi-format
exclude: ^Cargo\.lock$
- id: tombi-lint

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.6
hooks:
- id: ruff-check
args: [--fix]
- id: ruff-format
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.10
29 changes: 19 additions & 10 deletions engine.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
from __future__ import annotations

import os

os.environ["FLAGS_use_system_allocator"] = "1"
os.environ["FLAGS_USE_SYSTEM_ALLOCATOR"] = "1"
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"

import argparse
from datetime import datetime

import paddle
import torch

from tester import (APIConfig, APITestAccuracy, APITestAccuracyStable,
APITestCINNVSDygraph, APITestPaddleGPUPerformance,
APITestPaddleOnly, APITestPaddleTorchGPUPerformance,
APITestTorchGPUPerformance,APITestCustomDeviceVSCPU,set_cfg)
from tester.api_config.log_writer import (close_process_files, read_log,
write_to_log)
from tester import (
APIConfig,
APITestAccuracy,
APITestAccuracyStable,
APITestCINNVSDygraph,
APITestCustomDeviceVSCPU,
APITestPaddleGPUPerformance,
APITestPaddleOnly,
APITestPaddleTorchGPUPerformance,
APITestTorchGPUPerformance,
set_cfg,
)
from tester.api_config.log_writer import close_process_files, read_log, write_to_log


def parse_bool(value):
Expand Down Expand Up @@ -163,8 +171,8 @@ def main():
paddle.device.cuda.empty_cache()
elif options.api_config_file != "":
finish_configs = read_log("checkpoint")
with open(options.api_config_file, "r") as f:
api_configs = set(line.strip() for line in f if line.strip())
with open(options.api_config_file) as f:
api_configs = {line.strip() for line in f if line.strip()}
api_configs = api_configs - finish_configs
api_configs = sorted(api_configs)
for api_config_str in api_configs:
Expand Down Expand Up @@ -223,5 +231,6 @@ def main():

close_process_files()


if __name__ == "__main__":
main()
Loading