Skip to content

Commit 0476320

Browse files
authored
Refactoring to use actual Pydantic V2 API (#190)
1 parent 1e8d502 commit 0476320

17 files changed

+115
-111
lines changed

Diff for: .github/workflows/docs.yml

-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ jobs:
3838
run: |
3939
python -m pip install --upgrade pip
4040
pip install ".[docs]"
41-
pip install "autodoc_pydantic==1.9.1"
4241
- name: Set environment
4342
run: |
4443
REPO_OWNER="${GITHUB_REPOSITORY%%/*}"

Diff for: pyproject.toml

+4-3
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ dependencies = [
3333
"psycopg2-binary >=2.9.6",
3434
"psygnal >=0.9.0",
3535
"pyarrow >=16.1.0,<20",
36-
"pydantic >= 2",
36+
"pydantic >=2",
37+
"pydantic-settings",
3738
"pydot >=2.0.0",
3839
"qtawesome >=1.3.1",
3940
"qtpy",
@@ -53,8 +54,7 @@ dependencies = [
5354
[project.optional-dependencies]
5455
docs = [
5556
"sphinx == 6.2.1",
56-
"autodoc_pydantic", # NOT WORKING, needs to install 'autodoc_pydantic==1.9.1' to build docs
57-
"pydantic-settings",
57+
"autodoc_pydantic",
5858
"furo",
5959
"myst-parser >= 2.0.0",
6060
"nbsphinx >= 0.9.3",
@@ -105,6 +105,7 @@ pillow = ">=10.0.0"
105105
psycopg2-binary = ">=2.9.6"
106106
psygnal = ">=0.9.0"
107107
pydantic = ">=2"
108+
pydantic-settings = "*"
108109
pydot = ">=2.0.0"
109110
qtawesome = ">=1.3.1"
110111
qtpy = "*"

Diff for: ultrack/api/_test/test_api.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def experiment_instance(
7777

7878
experiment = Experiment(
7979
name="PyTest",
80-
config=config_instance,
80+
config=config_instance.model_dump(by_alias=True),
8181
data_url=ome_zarr_dataset_path,
8282
image_channel_or_path="image",
8383
edges_channel_or_path="edges",
@@ -103,7 +103,7 @@ def test_config():
103103
default_config = MainConfig()
104104
default_config.data_config = None
105105

106-
assert response.json() == default_config.dict()
106+
assert response.json() == default_config.model_dump(by_alias=True)
107107

108108

109109
def test_manual_segment(experiment_instance: Experiment):
@@ -130,7 +130,7 @@ def test_manual_segment(experiment_instance: Experiment):
130130
with client.websocket_connect("/segment/manual") as websocket:
131131
json_exp = json.loads(
132132
json.dumps(
133-
experiment.dict(),
133+
experiment.model_dump(),
134134
default=lambda o: o.isoformat()
135135
if isinstance(o, (datetime.date, datetime.datetime))
136136
else None,
@@ -139,7 +139,7 @@ def test_manual_segment(experiment_instance: Experiment):
139139
websocket.send_json({"experiment": json_exp})
140140
while experiment.status != ExperimentStatus.SUCCESS:
141141
response = websocket.receive_json()
142-
experiment = Experiment.parse_obj(response)
142+
experiment = Experiment.model_validate(response)
143143
assert experiment.status != ExperimentStatus.ERROR
144144

145145
tracks_df_api, graph_api = to_tracks_layer(experiment.get_config())
@@ -179,7 +179,7 @@ def test_auto_detect(
179179
with client.websocket_connect("/segment/auto_detect") as websocket:
180180
json_exp = json.loads(
181181
json.dumps(
182-
experiment.dict(),
182+
experiment.model_dump(),
183183
default=lambda o: o.isoformat()
184184
if isinstance(o, (datetime.date, datetime.datetime))
185185
else None,
@@ -201,7 +201,7 @@ def test_auto_detect(
201201
websocket.send_json(json_request)
202202
while experiment.status != ExperimentStatus.SUCCESS:
203203
response = websocket.receive_json()
204-
experiment = Experiment.parse_obj(response)
204+
experiment = Experiment.model_validate(response)
205205
assert experiment.status != ExperimentStatus.ERROR
206206

207207
detection = np.zeros_like(image_data, dtype=float)
@@ -250,7 +250,7 @@ def test_from_labels(experiment_instance: Experiment, label_to_edges_kwargs: dic
250250
with client.websocket_connect("/segment/labels") as websocket:
251251
json_exp = json.loads(
252252
json.dumps(
253-
experiment.dict(),
253+
experiment.model_dump(),
254254
default=lambda o: o.isoformat()
255255
if isinstance(o, (datetime.date, datetime.datetime))
256256
else o,
@@ -265,7 +265,7 @@ def test_from_labels(experiment_instance: Experiment, label_to_edges_kwargs: dic
265265
label_to_edges_kwargs = {}
266266
while experiment.status != ExperimentStatus.SUCCESS:
267267
response = websocket.receive_json()
268-
experiment = Experiment.parse_obj(response)
268+
experiment = Experiment.model_validate(response)
269269
assert experiment.status != ExperimentStatus.ERROR
270270

271271
# compare the results with the ones obtained from the ultrack module
@@ -293,7 +293,7 @@ def test_output_experiment(experiment_instance: Experiment):
293293
with client.websocket_connect("/segment/auto_detect") as websocket:
294294
json_exp = json.loads(
295295
json.dumps(
296-
experiment_instance.dict(),
296+
experiment_instance.model_dump(),
297297
default=lambda o: o.isoformat()
298298
if isinstance(o, (datetime.date, datetime.datetime))
299299
else None,
@@ -306,7 +306,7 @@ def test_output_experiment(experiment_instance: Experiment):
306306

307307
while experiment_instance.status != ExperimentStatus.SUCCESS:
308308
response = websocket.receive_json()
309-
experiment_instance = Experiment.parse_obj(response)
309+
experiment_instance = Experiment.model_validate(response)
310310
assert experiment_instance.status != ExperimentStatus.ERROR
311311

312312
assert "robust_invert" in experiment_instance.err_log
@@ -353,7 +353,7 @@ def test_available_configs(experiment_instance: Experiment):
353353

354354
while experiment.status != ExperimentStatus.SUCCESS:
355355
response = websocket.receive_json()
356-
experiment = Experiment.parse_obj(response)
356+
experiment = Experiment.model_validate(response)
357357
assert experiment.status != ExperimentStatus.ERROR
358358

359359

@@ -377,7 +377,7 @@ def test_available_configs(experiment_instance: Experiment):
377377
#
378378
# while experiment_instance.status != ExperimentStatus.ERROR:
379379
# response = websocket.receive_json()
380-
# experiment_instance = Experiment.parse_obj(response)
380+
# experiment_instance = Experiment.model_validate(response)
381381
#
382382
# assert "Exception" in experiment_instance.err_log
383383
# assert "Traceback" in experiment_instance.err_log
@@ -399,9 +399,9 @@ def test_available_configs(experiment_instance: Experiment):
399399
# )
400400
# websocket.send_json({"experiment": json_exp})
401401
# response = websocket.receive_json()
402-
# experiment_instance = Experiment.parse_obj(response)
402+
# experiment_instance = Experiment.model_validate(response)
403403
# response = client.post(f"/stop/{experiment_instance.id}")
404404
# assert response.status_code == 200
405405
# response = websocket.receive_json()
406-
# experiment_instance = Experiment.parse_obj(response)
406+
# experiment_instance = Experiment.model_validate(response)
407407
# assert experiment_instance.status == ExperimentStatus.ERROR

Diff for: ultrack/api/app.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ async def finish_experiment(ws: WebSocket, exp: Experiment) -> Experiment:
151151
The experiment instance to be finished.
152152
153153
"""
154-
exp.end_time = datetime.now()
154+
exp.end_time = datetime.now().isoformat()
155155
exp.status = ExperimentStatus.SUCCESS
156156
update_experiment(exp)
157157
app.state.queue.task_done()
@@ -184,7 +184,7 @@ async def get_default_config() -> Dict:
184184
"""
185185
config = MainConfig()
186186
config.data_config = None
187-
return config.dict()
187+
return config.model_dump(by_alias=True)
188188

189189

190190
@app.get("/config/available")
@@ -201,7 +201,7 @@ async def get_available_configs() -> Dict:
201201

202202
experiment = {
203203
"name": "Unnamed Experiment",
204-
"config": default_config.dict(),
204+
"config": default_config.model_dump(by_alias=True),
205205
}
206206

207207
auto_detect_config = {
@@ -316,7 +316,7 @@ async def auto_detect(websocket: WebSocket) -> None:
316316
)
317317
return
318318

319-
experiment = Experiment.parse_obj(data["experiment"])
319+
experiment = Experiment.model_validate(data["experiment"])
320320

321321
try:
322322
detect_foreground_kwargs = data["detect_foreground_kwargs"]
@@ -432,7 +432,7 @@ async def manual_segment(websocket: WebSocket) -> None:
432432
experiment = json.loads(data["experiment"])
433433
else:
434434
experiment = data["experiment"]
435-
experiment = Experiment.parse_obj(experiment)
435+
experiment = Experiment.model_validate(experiment)
436436

437437
async with UltrackWebsocketLogger(websocket, experiment):
438438
await start_experiment(websocket, experiment)
@@ -495,7 +495,7 @@ async def auto_from_labels(websocket: WebSocket) -> None:
495495
)
496496
return
497497

498-
experiment = Experiment.parse_obj(data["experiment"])
498+
experiment = Experiment.model_validate(data["experiment"])
499499

500500
try:
501501
label_to_edges_kwargs = data["label_to_edges_kwargs"]

Diff for: ultrack/api/database.py

+22-19
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Optional
77

88
import sqlalchemy as sqla
9-
from pydantic.v1 import BaseModel, Json, validator
9+
from pydantic import BaseModel, Json, ValidationInfo, field_validator
1010
from sqlalchemy import JSON, Column, Enum, Integer, String, Text
1111
from sqlalchemy.orm import declarative_base, sessionmaker
1212

@@ -22,7 +22,7 @@
2222
def _clean_db_on_exit():
2323
try:
2424
Session._temp_dir.cleanup()
25-
except:
25+
except Exception:
2626
pass
2727

2828

@@ -104,9 +104,9 @@ class Experiment(BaseModel):
104104
The experiment status. Defaults to ExperimentStatus.NOT_PERSISTED.
105105
name : str
106106
The experiment name. Defaults to "Untitled Experiment".
107-
start_time : Optional[datetime]
107+
start_time : Optional[str]
108108
The experiment start time. Defaults to the datetime where it was created.
109-
end_time : Optional[datetime]
109+
end_time : Optional[str]
110110
The experiment end time. Defaults to None and is set when the experiment
111111
finishes.
112112
std_log : str
@@ -153,34 +153,37 @@ class Experiment(BaseModel):
153153
tracks: Optional[Json] = None
154154

155155
def get_config(self) -> MainConfig:
156-
config = MainConfig.parse_obj(self.config)
156+
config = MainConfig.model_validate(self.config)
157157
config.data_config = settings.ultrack_data_config
158158
return config
159159

160-
@validator("status", pre=True, always=True)
161-
def check_if_id_is_valid(cls, v, values, **kwargs):
162-
if (
163-
v == ExperimentStatus.NOT_PERSISTED
164-
and "id" in values
165-
and values["id"] is not None
166-
):
160+
@field_validator("status", mode="before")
161+
@classmethod
162+
def check_if_id_is_valid(cls, v, info: ValidationInfo) -> ExperimentStatus:
163+
exp_id = info.data.get("id")
164+
if v == ExperimentStatus.NOT_PERSISTED and exp_id is not None:
167165
raise ValueError(
168166
"The id cannot be set if the experiment was never "
169167
"persisted in the database."
170168
)
171-
elif v != ExperimentStatus.NOT_PERSISTED and "id" not in values:
169+
elif v != ExperimentStatus.NOT_PERSISTED and exp_id is None:
172170
raise ValueError(
173171
"The id must be set if the experiment was persisted in the database."
174172
)
175173
return v
176174

177-
@validator("start_time", pre=True, always=True)
178-
def set_start_time(cls, v):
175+
@field_validator("start_time", mode="before")
176+
@classmethod
177+
def set_start_time(cls, v) -> str:
179178
return v or datetime.now().isoformat()
180179

181-
@validator("end_time", always=True)
182-
def set_end_time(cls, v, values):
183-
if values["status"] in [ExperimentStatus.SUCCESS, ExperimentStatus.FAILED]:
180+
@field_validator("end_time")
181+
@classmethod
182+
def set_end_time(cls, v, info: ValidationInfo) -> Optional[str]:
183+
if info.data.get("status") in [
184+
ExperimentStatus.SUCCESS,
185+
ExperimentStatus.FAILED,
186+
]:
184187
return v or datetime.now().isoformat()
185188
return None
186189

@@ -262,7 +265,7 @@ def update_experiment(experiment: Experiment) -> None:
262265
experiment_db = session.query(ExperimentDB).filter_by(id=experiment.id).first()
263266
if experiment_db is None:
264267
raise ValueError(f"Experiment {experiment.id} not found.")
265-
for key, value in experiment.dict().items():
268+
for key, value in experiment.model_dump().items():
266269
if key != "id":
267270
setattr(experiment_db, key, value)
268271
session.commit()

Diff for: ultrack/api/settings.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from pathlib import Path
22
from typing import Union
33

4-
from pydantic.v1 import BaseSettings
4+
from pydantic_settings import BaseSettings
55

66
from ultrack.config import DataConfig
77

Diff for: ultrack/cli/config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ def config_cli(output_path: Path) -> None:
1616
raise ValueError(f"{output_path} already exists.")
1717

1818
with open(output_path, mode="w") as f:
19-
toml.dump(config.dict(by_alias=True), f)
19+
toml.dump(config.model_dump(by_alias=True), f)

Diff for: ultrack/config/_test/test_config.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import pytest
55
import toml
6-
from pydantic.v1 import ValidationError
76

87
from ultrack.config import load_config
98

@@ -14,7 +13,7 @@ def _assert_input_in_target(input: Dict, target: Dict) -> None:
1413
if isinstance(input[k], dict):
1514
_assert_input_in_target(input[k], target[k])
1615
else:
17-
assert target[k] == input[k]
16+
assert target[k] == input[k], f"Key {k} not equal"
1817

1918

2019
def _format_config(config: Dict) -> None:
@@ -29,7 +28,7 @@ def test_config_content(config_path: Path, config_content: Dict[str, Any]) -> No
2928
"""Tests if content is loaded correctly"""
3029
config = load_config(config_path)
3130
_format_config(config_content)
32-
_assert_input_in_target(config_content, config.dict())
31+
_assert_input_in_target(config_content, config.model_dump())
3332

3433

3534
def test_invalid_config_content(tmp_path: Path, config_content: Dict[str, Any]) -> None:
@@ -40,5 +39,5 @@ def test_invalid_config_content(tmp_path: Path, config_content: Dict[str, Any])
4039
with open(path, mode="w") as f:
4140
toml.dump(config_content, f)
4241

43-
with pytest.raises(ValidationError):
42+
with pytest.raises(ValueError):
4443
load_config(path)

Diff for: ultrack/config/config.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Optional, Union
44

55
import toml
6-
from pydantic.v1 import BaseModel, Extra, Field
6+
from pydantic import BaseModel, ConfigDict, Field
77

88
from ultrack.config.dataconfig import DataConfig
99
from ultrack.config.segmentationconfig import SegmentationConfig
@@ -38,13 +38,13 @@ class LinkingConfig(BaseModel):
3838
the segmentation masks of neighboring segments
3939
"""
4040

41-
class Config:
42-
extra = Extra.forbid
41+
model_config = ConfigDict(extra="forbid")
4342

4443

4544
class MainConfig(BaseModel):
4645
data_config: Optional[DataConfig] = Field(
47-
default_factory=DataConfig, alias="data", nullable=True
46+
default_factory=DataConfig,
47+
alias="data",
4848
)
4949
"""
5050
Configuration for intermediate data storage and retrieval.
@@ -71,4 +71,4 @@ def load_config(path: Union[str, Path]) -> MainConfig:
7171
with open(path) as f:
7272
data = toml.load(f)
7373
LOG.info(data)
74-
return MainConfig.parse_obj(data)
74+
return MainConfig.model_validate(data)

0 commit comments

Comments
 (0)