|
9 | 9 | from sqlalchemy.ext.asyncio import AsyncSession |
10 | 10 |
|
11 | 11 | from aqueductcore.backend.context import UserInfo, UserScope |
12 | | -from aqueductcore.backend.errors import AQDDBExperimentNonExisting |
| 12 | +from aqueductcore.backend.errors import AQDDBExperimentNonExisting, AQDValidationError |
13 | 13 | from aqueductcore.backend.models import orm |
14 | 14 | from aqueductcore.backend.models.experiment import ExperimentCreate, TagCreate |
15 | 15 | from aqueductcore.backend.services.experiment import ( |
16 | | - add_tag_to_experiment, |
| 16 | + add_tags_to_experiment, |
17 | 17 | build_experiment_dir_absolute_path, |
18 | 18 | create_experiment, |
19 | 19 | generate_experiment_id_and_alias, |
@@ -384,20 +384,108 @@ async def test_add_db_tag_to_experiment( |
384 | 384 |
|
385 | 385 | await db_session.commit() |
386 | 386 |
|
387 | | - in_db_experiment = await add_tag_to_experiment( |
| 387 | + in_db_experiment = await add_tags_to_experiment( |
388 | 388 | user_info=UserInfo( |
389 | 389 | user_id=uuid4(), username=settings.default_username, scopes=set(UserScope) |
390 | 390 | ), |
391 | 391 | db_session=db_session, |
392 | 392 | experiment_id=experiments_data[0].id, |
393 | | - tag="important", |
| 393 | + tags=["important"], |
394 | 394 | ) |
395 | | - await db_session.commit() |
396 | 395 |
|
397 | 396 | in_db_experiment_tags = [tag.name for tag in in_db_experiment.tags] |
398 | 397 | assert "important" in in_db_experiment_tags |
399 | 398 |
|
400 | 399 |
|
| 400 | +@pytest.mark.asyncio |
| 401 | +async def test_add_db_unique_tags_to_experiment( |
| 402 | + db_session: AsyncSession, experiments_data: List[ExperimentCreate] |
| 403 | +): |
| 404 | + """Test update_db_experiment operation""" |
| 405 | + |
| 406 | + db_user = orm.User(id=UUID(int=0), username=settings.default_username) |
| 407 | + db_session.add(db_user) |
| 408 | + |
| 409 | + for experiment in experiments_data: |
| 410 | + db_experiment = experiment_model_to_orm(experiment) |
| 411 | + db_experiment.created_by_user = db_user |
| 412 | + db_session.add(db_experiment) |
| 413 | + |
| 414 | + await db_session.commit() |
| 415 | + |
| 416 | + expected_tags = ["test1", "test2", "test3"] |
| 417 | + in_db_experiment = await add_tags_to_experiment( |
| 418 | + user_info=UserInfo( |
| 419 | + user_id=uuid4(), username=settings.default_username, scopes=set(UserScope) |
| 420 | + ), |
| 421 | + db_session=db_session, |
| 422 | + experiment_id=experiments_data[0].id, |
| 423 | + tags=expected_tags, |
| 424 | + ) |
| 425 | + |
| 426 | + in_db_experiment_tags = [tag.name for tag in in_db_experiment.tags] |
| 427 | + for item in expected_tags: |
| 428 | + assert item in in_db_experiment_tags |
| 429 | + |
| 430 | + |
| 431 | +@pytest.mark.asyncio |
| 432 | +async def test_add_db_unique_tags_to_experiment_pre_existing_tags( |
| 433 | + db_session: AsyncSession, experiments_data: List[ExperimentCreate] |
| 434 | +): |
| 435 | + """Test update_db_experiment operation""" |
| 436 | + |
| 437 | + db_user = orm.User(id=UUID(int=0), username=settings.default_username) |
| 438 | + db_session.add(db_user) |
| 439 | + |
| 440 | + for experiment in experiments_data: |
| 441 | + db_experiment = experiment_model_to_orm(experiment) |
| 442 | + db_experiment.created_by_user = db_user |
| 443 | + db_session.add(db_experiment) |
| 444 | + |
| 445 | + await db_session.commit() |
| 446 | + |
| 447 | + expected_tags = {"test1", "test2", "test3", experiments_data[0].tags[0].name} |
| 448 | + in_db_experiment = await add_tags_to_experiment( |
| 449 | + user_info=UserInfo( |
| 450 | + user_id=uuid4(), username=settings.default_username, scopes=set(UserScope) |
| 451 | + ), |
| 452 | + db_session=db_session, |
| 453 | + experiment_id=experiments_data[0].id, |
| 454 | + tags=list(expected_tags), |
| 455 | + ) |
| 456 | + |
| 457 | + in_db_experiment_tags = [tag.name for tag in in_db_experiment.tags] |
| 458 | + assert expected_tags.issubset(in_db_experiment_tags) |
| 459 | + |
| 460 | + |
| 461 | +@pytest.mark.asyncio |
| 462 | +async def test_add_db_duplicate_tags_in_request_to_experiment( |
| 463 | + db_session: AsyncSession, experiments_data: List[ExperimentCreate] |
| 464 | +): |
| 465 | + """Test update_db_experiment operation""" |
| 466 | + |
| 467 | + db_user = orm.User(id=UUID(int=0), username=settings.default_username) |
| 468 | + db_session.add(db_user) |
| 469 | + |
| 470 | + for experiment in experiments_data: |
| 471 | + db_experiment = experiment_model_to_orm(experiment) |
| 472 | + db_experiment.created_by_user = db_user |
| 473 | + db_session.add(db_experiment) |
| 474 | + |
| 475 | + await db_session.commit() |
| 476 | + |
| 477 | + expected_tags = ["test1", "test1"] |
| 478 | + with pytest.raises(AQDValidationError): |
| 479 | + await add_tags_to_experiment( |
| 480 | + user_info=UserInfo( |
| 481 | + user_id=uuid4(), username=settings.default_username, scopes=set(UserScope) |
| 482 | + ), |
| 483 | + db_session=db_session, |
| 484 | + experiment_id=experiments_data[0].id, |
| 485 | + tags=expected_tags, |
| 486 | + ) |
| 487 | + |
| 488 | + |
401 | 489 | @pytest.mark.asyncio |
402 | 490 | async def test_remove_db_tag_from_experiment( |
403 | 491 | db_session: AsyncSession, experiments_data: List[ExperimentCreate] |
|
0 commit comments