|
7 | 7 | from pathlib import Path |
8 | 8 | from typing import Callable |
9 | 9 |
|
| 10 | +from packaging import version |
10 | 11 | from sqlmodel import Session, text |
11 | 12 |
|
12 | 13 | from module.conf import DATA_PATH |
13 | | -from module.database.engine import engine |
| 14 | +from .engine import engine |
14 | 15 | from module.models import DatabaseVersion |
15 | 16 |
|
16 | 17 | logger = logging.getLogger(__name__) |
@@ -130,35 +131,34 @@ def _get_migration_path(self, from_version: str, to_version: str) -> list[str]: |
130 | 131 | migrations_needed.append(version) |
131 | 132 |
|
132 | 133 | # 按版本号排序 |
133 | | - migrations_needed.sort(key=lambda v: self._version_to_tuple(v)) |
| 134 | + migrations_needed.sort(key=lambda v: version.parse(v)) |
134 | 135 |
|
135 | 136 | return migrations_needed |
136 | 137 |
|
137 | 138 | def _version_compare(self, version1: str, version2: str) -> int: |
138 | 139 | """比较两个版本号 |
139 | 140 | 返回值: 1 表示 version1 > version2, 0 表示相等, -1 表示 version1 < version2 |
140 | 141 | """ |
141 | | - v1_tuple = self._version_to_tuple(version1) |
142 | | - v2_tuple = self._version_to_tuple(version2) |
143 | | - |
144 | | - if v1_tuple > v2_tuple: |
145 | | - return 1 |
146 | | - elif v1_tuple < v2_tuple: |
147 | | - return -1 |
148 | | - else: |
149 | | - return 0 |
150 | | - |
151 | | - def _version_to_tuple(self, version: str) -> tuple: |
152 | | - """将版本字符串转换为可比较的元组""" |
153 | 142 | try: |
154 | | - # 处理类似 "3.2.0" 或 "3.2.0-beta" 的版本号 |
155 | | - # 先移除后缀 (如 -beta, -alpha, -dev 等) |
156 | | - base_version = version.split("-")[0] |
157 | | - parts = base_version.split(".") |
158 | | - return tuple(int(part) for part in parts) |
159 | | - except (ValueError, AttributeError): |
160 | | - # 如果解析失败,返回一个默认值 |
161 | | - return (0, 0, 0) |
| 143 | + v1 = version.parse(version1) |
| 144 | + v2 = version.parse(version2) |
| 145 | + |
| 146 | + if v1 > v2: |
| 147 | + return 1 |
| 148 | + elif v1 < v2: |
| 149 | + return -1 |
| 150 | + else: |
| 151 | + return 0 |
| 152 | + except Exception as e: |
| 153 | + logger.warning(f"版本比较失败 {version1} vs {version2}: {e}") |
| 154 | + # 降级到字符串比较 |
| 155 | + if version1 > version2: |
| 156 | + return 1 |
| 157 | + elif version1 < version2: |
| 158 | + return -1 |
| 159 | + else: |
| 160 | + return 0 |
| 161 | + |
162 | 162 |
|
163 | 163 | def _migrate_to_3_2_0(self, session: Session): |
164 | 164 | """迁移到版本 3.2.0""" |
|
0 commit comments