Skip to content

Commit e7d2cec

Browse files
committed
实现BaseModelSerializer,支持序列化字段过滤
1 parent 43315f5 commit e7d2cec

File tree

3 files changed

+62
-48
lines changed

3 files changed

+62
-48
lines changed

sql_api/api_views/sql_workflow.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
ExecuteCheckSerializer,
2828
ExecuteCheckResultSerializer,
2929
SqlWorkflowSerializer,
30-
SqlWorkflowDetailSerializer,
3130
)
3231

3332

@@ -64,12 +63,12 @@ def post(self, request):
6463
summary="获取SQL工单列表",
6564
description="获取SQL工单列表,支持筛选、分页、检索等",
6665
request=SqlWorkflowSerializer,
67-
responses={200: SqlWorkflowSerializer},
66+
responses={200: SqlWorkflowSerializer(exclude=["sql_content", "display_content"])},
6867
),
6968
retrieve=extend_schema(
7069
summary="获取SQL工单详情",
7170
description="通过工单ID获取工单详情",
72-
responses={200: SqlWorkflowDetailSerializer},
71+
responses={200: SqlWorkflowSerializer},
7372
),
7473
rollback_sql=extend_schema(
7574
summary="获取SQL工单回滚语句",
@@ -86,6 +85,7 @@ def post(self, request):
8685
)
8786
class SqlWorkflowView(viewsets.ModelViewSet):
8887
permission_classes = [IsAuthenticated, SqlWorkFlowViewPermission]
88+
serializer_class = SqlWorkflowSerializer
8989
pagination_class = BootStrapTablePagination
9090
filter_backends = [
9191
filters.SearchFilter,
@@ -118,10 +118,13 @@ def get_queryset(self):
118118
queryset = SqlWorkflow.objects.filter(**filter_dict).order_by("-id")
119119
return self.get_serializer_class().setup_eager_loading(queryset)
120120

121-
def get_serializer_class(self):
121+
def get_serializer(self, *args, **kwargs):
122+
123+
serializer_class = self.get_serializer_class()
124+
kwargs.setdefault('context', self.get_serializer_context())
122125
if self.action == "retrieve":
123-
return SqlWorkflowDetailSerializer
124-
return SqlWorkflowSerializer
126+
return serializer_class(*args, **kwargs)
127+
return serializer_class(exclude=["sql_content", "display_content"], *args, **kwargs)
125128

126129
@action(methods=["get"], detail=True, pagination_class=None, filter_backends=[], search_fields=None)
127130
def rollback_sql(self, request, *args, **kwargs):

sql_api/serializers/__init__.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# -*- coding: UTF-8 -*-
2+
"""
3+
@author: hhyo
4+
@license: Apache Licence
5+
@file: __init__.py
6+
@time: 2022/10/22
7+
"""
8+
__author__ = 'hhyo'
9+
10+
from rest_framework import serializers
11+
12+
13+
class BaseModelSerializer(serializers.ModelSerializer):
14+
"""BaseModelSerializer,主要是引入过滤和排除字段的方法"""
15+
16+
def __init__(self, *args, **kwargs):
17+
"""
18+
``fields`` 需要保留的字段列表
19+
``exclude`` 需要排除的字段列表
20+
"""
21+
fields = kwargs.pop("fields", None)
22+
exclude = kwargs.pop("exclude", None)
23+
super(BaseModelSerializer, self).__init__(*args, **kwargs)
24+
25+
for field_name in set(self.fields.keys()):
26+
if not any([fields, exclude]):
27+
break
28+
if fields and field_name in fields:
29+
continue
30+
if exclude and field_name not in exclude:
31+
continue
32+
self.fields.pop(field_name, None)
33+
34+
@staticmethod
35+
def setup_eager_loading(queryset):
36+
"""
37+
Perform necessary eager loading of data.
38+
https://ses4j.github.io/2015/11/23/optimizing-slow-django-rest-framework-performance/
39+
"""
40+
pass

sql_api/serializers/sql_workflow.py

Lines changed: 13 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from sql.engines import ReviewSet, get_engine
1818
from sql.engines.models import ReviewResult
1919
from sql.models import Instance, SqlWorkflow
20+
from sql_api.serializers import BaseModelSerializer
2021

2122
logger = logging.getLogger("default")
2223

@@ -69,57 +70,18 @@ class ExecuteCheckResultSerializer(serializers.Serializer):
6970
affected_rows = serializers.IntegerField(read_only=True)
7071

7172

72-
class SqlWorkflowSerializer(serializers.ModelSerializer):
73+
class SqlWorkflowSerializer(BaseModelSerializer):
7374
"""SQL工单"""
7475

7576
instance_name = serializers.CharField(source="instance.instance_name")
76-
77-
def __init__(self, *args, **kwargs):
78-
"""
79-
``fields`` 需要保留的字段列表
80-
``exclude`` 需要排除的字段列表
81-
"""
82-
fields = kwargs.pop("fields", None)
83-
exclude = kwargs.pop("exclude", None)
84-
super(SqlWorkflowSerializer, self).__init__(*args, **kwargs)
85-
86-
for field_name in set(self.fields.keys()):
87-
if not any([fields, exclude]):
88-
break
89-
if fields and field_name in fields:
90-
continue
91-
if exclude and field_name not in exclude:
92-
continue
93-
self.fields.pop(field_name, None)
77+
sql_content = serializers.CharField(source="sqlworkflowcontent.sql_content")
78+
display_content = serializers.SerializerMethodField()
9479

9580
@staticmethod
9681
def setup_eager_loading(queryset):
97-
"""
98-
Perform necessary eager loading of data.
99-
https://ses4j.github.io/2015/11/23/optimizing-slow-django-rest-framework-performance/
100-
"""
10182
queryset = queryset.select_related("instance")
10283
return queryset
10384

104-
@staticmethod
105-
def rollback_sql(obj):
106-
try:
107-
query_engine = get_engine(instance=obj.instance)
108-
return query_engine.get_rollback(workflow=obj)
109-
except Exception as msg:
110-
logger.error(traceback.format_exc())
111-
raise serializers.ValidationError({"errors": msg})
112-
113-
class Meta:
114-
model = SqlWorkflow
115-
fields = "__all__"
116-
117-
118-
class SqlWorkflowDetailSerializer(SqlWorkflowSerializer):
119-
instance_name = serializers.CharField(source="instance.instance_name")
120-
sql_content = serializers.CharField(source="sqlworkflowcontent.sql_content")
121-
display_content = serializers.SerializerMethodField()
122-
12385
@extend_schema_field(field=serializers.ListField(child=ReviewResultSerializer()))
12486
def get_display_content(self, obj):
12587
"""获取工单详情用于列表展示的内容,区分不同的状态进行转换"""
@@ -161,6 +123,15 @@ def get_display_content(self, obj):
161123
rows = obj.sqlworkflowcontent.review_content
162124
return json.loads(rows)
163125

126+
@staticmethod
127+
def rollback_sql(obj):
128+
try:
129+
query_engine = get_engine(instance=obj.instance)
130+
return query_engine.get_rollback(workflow=obj)
131+
except Exception as msg:
132+
logger.error(traceback.format_exc())
133+
raise serializers.ValidationError({"errors": msg})
134+
164135
class Meta:
165136
model = SqlWorkflow
166137
fields = "__all__"

0 commit comments

Comments
 (0)