Skip to content

Commit f51bcb9

Browse files
authored
Merge pull request #5 from muehlemann-popp/4-get_query_from_queryset-not-works-if-value-passes
fix get_query_from_queryset
2 parents e008d01 + b3affa4 commit f51bcb9

File tree

3 files changed

+20
-33
lines changed

3 files changed

+20
-33
lines changed

django_materialized_view/base_model.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import datetime
22
import logging
33
import time
4-
from typing import Callable, Dict, Optional, Union
4+
from typing import Callable, Dict, Optional, Union, Tuple
55

66
from django.conf import settings
77
from django.db import DEFAULT_DB_ALIAS, connections, models
@@ -114,7 +114,7 @@ def refresh(cls, using: Optional[str] = None, concurrently: Optional[bool] = Non
114114
log.save()
115115

116116
@classmethod
117-
def view_definition(cls):
117+
def view_definition(cls) -> Tuple[str, tuple]:
118118
return cls.__get_query()
119119

120120
@classmethod
@@ -150,18 +150,18 @@ def get_query_from_queryset() -> QuerySet:
150150
pass
151151

152152
@classmethod
153-
def __get_query(cls) -> str:
153+
def __get_query(cls, *args) -> Tuple[str, tuple]:
154154
queryset = cls.get_query_from_queryset()
155-
156155
if isinstance(queryset, QuerySet):
157-
sql_query = f"{queryset.query}; {cls.__create_index_for_primary_key()}"
158-
return sql_query
156+
query, args = queryset.query.sql_with_params()
157+
sql_query = f"{query}; {cls.__create_index_for_primary_key()}"
158+
return sql_query, args
159159
try:
160160
with open(cls.__get_sql_file_path(), "r") as sql_file:
161161
sql_query = f"{sql_file.read()}; {cls.__create_index_for_primary_key()}"
162162
except FileNotFoundError as exc:
163163
raise FileNotFoundError(f"{exc}, - please create SQL file and put it to this directory")
164-
return sql_query
164+
return sql_query, args
165165

166166
@classmethod
167167
def __get_class_name(cls) -> str:

django_materialized_view/processor.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ def mark_to_be_applied_new_views(self) -> None:
5151
view_models = self.__get_current_view_models()
5252
for (app_label, model_name), view_model in view_models.items():
5353
view_name = self.__get_view_name(app_label, model_name)
54-
actual_view_definition = self.__get_actual_view_definition(view_name)
55-
actual_view_definition_hash = self.__get_hash_from_string(actual_view_definition)
54+
actual_view_definition, args = self.__get_actual_view_definition(view_name)
55+
actual_view_definition_hash = self.__get_hash_from_string(actual_view_definition % args)
5656

5757
previous_view_definition_hash = self.__get_previous_view_definition_hash(app_label, model_name)
5858

@@ -127,10 +127,10 @@ def get_view_list_sorted_by_dependencies(self, views: Set[str]) -> OrderedDict[s
127127
def _create_view(self, view_name: str) -> bool:
128128
logger.debug(f"Creating view: {view_name}")
129129

130-
view_definition = self.__get_actual_view_definition(view_name)
130+
view_definition, args = self.__get_actual_view_definition(view_name)
131131
with connection.cursor() as cursor:
132132
try:
133-
cursor.execute(self.CREATE_COMMAND_TEMPLATE % (view_name, view_definition))
133+
cursor.execute(self.CREATE_COMMAND_TEMPLATE % (view_name, view_definition), args)
134134
except ProgrammingError as exc:
135135
logger.debug(f"Unable to create view: {view_name}. Error: {exc.args}")
136136
if "already exists" in exc.args[0]:
@@ -224,11 +224,11 @@ def __get_ref_views(self, view_name: str) -> List[str]:
224224
def __get_actual_view_definition(self, view_name: str) -> str:
225225
view_model = DBViewsRegistry[view_name]
226226
if callable(view_model.view_definition):
227-
raw_view_definition = view_model.view_definition()
227+
raw_view_definition, args = view_model.view_definition()
228228
else:
229-
raw_view_definition = view_model.view_definition
229+
raise ValueError("view_definition must be callable")
230230
view_definition = self.__get_cleaned_view_definition_value(raw_view_definition)
231-
return view_definition
231+
return view_definition, args
232232

233233
def __prioritize_view(self, view: str, related_views: List[str], dependencies_story: set[str]) -> None:
234234
if related_views:
@@ -277,7 +277,7 @@ def __get_related_views(self) -> List[Dict[str, str]]:
277277
def __get_cleaned_view_definition_value(view_definition: str) -> str:
278278
assert isinstance(
279279
view_definition, str
280-
), "View definition must be callable and return string or be itself a string."
280+
), "View definition must be callable and return Tuple[str, Optional[tuple]]."
281281
return view_definition.strip()
282282

283283
@staticmethod

testproject/tests/test_base_mode_tests.py

+5-18
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def test__get_cleaned_view_definition_value__success(self):
9191
def test__get_cleaned_view_definition_value__invalid(self):
9292
with pytest.raises(AssertionError) as exc:
9393
self.view_processor._MaterializedViewsProcessor__get_cleaned_view_definition_value(1)
94-
assert exc.value.args == ("View definition must be callable and return string or be itself a string.",)
94+
assert exc.value.args == ("View definition must be callable and return Tuple[str, Optional[tuple]].",)
9595

9696
@pytest.mark.django_db
9797
def test__get_related_views__success(self):
@@ -130,6 +130,7 @@ def test__get_actual_view_definition__success(self, mocker, subtests):
130130
test_view_definition = "test"
131131
with subtests.test(msg="view_definition callable"):
132132
test_view_mock = MagicMock()
133+
test_view_mock.view_definition.return_value = ("test raw query", ())
133134
DBViewsRegistry[test_view_definition] = test_view_mock
134135
get_cleaned_view_mock = mocker.patch.object(
135136
MaterializedViewsProcessor,
@@ -139,22 +140,8 @@ def test__get_actual_view_definition__success(self, mocker, subtests):
139140

140141
result = self.view_processor._MaterializedViewsProcessor__get_actual_view_definition("test")
141142

142-
get_cleaned_view_mock.assert_called_once_with(test_view_mock.view_definition())
143-
assert result == test_view_definition
144-
145-
with subtests.test(msg="view_definition is str"):
146-
test_view_mock = MagicMock()
147-
test_view_mock.view_definition = "test"
148-
DBViewsRegistry[test_view_definition] = test_view_mock
149-
get_cleaned_view_mock = mocker.patch.object(
150-
MaterializedViewsProcessor,
151-
"_MaterializedViewsProcessor__get_cleaned_view_definition_value",
152-
return_value=test_view_definition,
153-
)
154-
155-
result = self.view_processor._MaterializedViewsProcessor__get_actual_view_definition("test")
156-
get_cleaned_view_mock.assert_called_once_with(test_view_mock.view_definition)
157-
assert result == test_view_definition
143+
get_cleaned_view_mock.assert_called_once_with("test raw query")
144+
assert result == (test_view_definition, ())
158145

159146
def test__get_ref_views__success(self, mocker):
160147
ref_view_name = "test_ref_view"
@@ -325,7 +312,7 @@ def test__create_view__success(self, mocker, subtests):
325312
test_view_name = "viewname"
326313
full_view_name = f"{test_app_name}_{test_view_name}"
327314

328-
view_definition = "SELECT * FROM pg_depend"
315+
view_definition = "SELECT * FROM pg_depend", ()
329316
get_actual_view_definition_mock = mocker.patch.object(
330317
MaterializedViewsProcessor,
331318
"_MaterializedViewsProcessor__get_actual_view_definition",

0 commit comments

Comments
 (0)