Skip to content

Commit 3f41064

Browse files
committed
test: add unit tests for the variable repository
1 parent b16a188 commit 3f41064

File tree

1 file changed

+107
-0
lines changed
  • src/backend/tests/unit/services/database/models/variable

1 file changed

+107
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import pytest
2+
from uuid import uuid4
3+
4+
from sqlmodel import SQLModel, Session, create_engine
5+
6+
from langflow.services.database.models.variable.model import Variable
7+
from langflow.services.database.models.variable.repo import VariableRepository
8+
9+
10+
@pytest.fixture
11+
def client():
12+
pass
13+
14+
15+
@pytest.fixture
16+
def repo():
17+
engine = create_engine("sqlite:///:memory:")
18+
SQLModel.metadata.create_all(engine)
19+
with Session(engine) as session:
20+
return VariableRepository(session)
21+
22+
23+
def test_add(repo):
24+
user_id = uuid4()
25+
name = "test"
26+
value = "test"
27+
_type = "test"
28+
default_fields = ["test"]
29+
variable = Variable(user_id=user_id, name=name, value=value, type=_type, default_fields=["test"])
30+
31+
result = repo.add(variable)
32+
33+
assert result.id is not None
34+
assert result.user_id == user_id
35+
assert result.name == name
36+
assert result.value == value
37+
assert result.type == _type
38+
assert result.default_fields == default_fields
39+
40+
41+
def test_get(repo):
42+
user_id = uuid4()
43+
name = "test"
44+
value = "test"
45+
_type = "test"
46+
default_fields = ["test"]
47+
variable = Variable(user_id=user_id, name=name, value=value, type=_type, default_fields=["test"])
48+
saved = repo.add(variable)
49+
50+
result = repo.get(saved.id)
51+
52+
assert result == saved
53+
54+
55+
def test_list(repo):
56+
user_id = uuid4()
57+
name = "test"
58+
value = "test"
59+
_type = "test"
60+
default_fields = ["test"]
61+
quantity = 10
62+
for index, i in enumerate(range(quantity)):
63+
variable = Variable(user_id=user_id, name=name, value=f"value_{index}", type=_type, default_fields=["test"])
64+
repo.add(variable)
65+
66+
result = repo.list()
67+
68+
assert len(result) == quantity
69+
70+
71+
def test_update(repo):
72+
user_id = uuid4()
73+
name = "test"
74+
value = "test"
75+
_type = "test"
76+
default_fields = ["test"]
77+
variable = Variable(user_id=user_id, name=name, value=value, type=_type, default_fields=["test"])
78+
saved = repo.add(variable)
79+
saved.name = "test_updated"
80+
saved.value = "test_updated"
81+
saved.type = "test_updated"
82+
saved.default_fields = ["test_updated"]
83+
84+
repo.update(saved)
85+
result = repo.get(saved.id)
86+
87+
assert result.id == saved.id
88+
assert result.user_id == saved.user_id
89+
assert result.name == saved.name
90+
assert result.value == saved.value
91+
assert result.type == saved.type
92+
assert result.default_fields == saved.default_fields
93+
94+
95+
def test_delete(repo):
96+
user_id = uuid4()
97+
name = "test"
98+
value = "test"
99+
_type = "test"
100+
default_fields = ["test"]
101+
variable = Variable(user_id=user_id, name=name, value=value, type=_type, default_fields=["test"])
102+
saved = repo.add(variable)
103+
104+
repo.delete(saved.id)
105+
result = repo.get(saved.id)
106+
107+
assert result is None

0 commit comments

Comments
 (0)