11import json
22import logging
3- from typing import Generator , Iterable , Iterator , List , Optional , Type
3+ from typing import Generator , List , Optional , Type
44
55from django .core .serializers .json import DjangoJSONEncoder
66from django .db .models import Q
1717)
1818from morango .registry import syncable_models
1919from morango .sync .stream .core import Buffer , Sink , Source , Transform , Unbuffer
20- from morango .utils import self_referential_fk
2120
2221logger = logging .getLogger (__name__ )
2322
@@ -31,6 +30,7 @@ class SerializeTask(object):
3130 "store" ,
3231 "counter" ,
3332 "_self_ref_fk_value" ,
33+ "_self_ref_fk_value" ,
3434 "_self_ref_order" ,
3535 )
3636
@@ -58,7 +58,7 @@ def set_counter(self, counter: RecordMaxCounter):
5858
5959 def self_referential_fk (self ) -> Optional [str ]:
6060 """Return the attname of the self-referential FK on *model*, or ``None``."""
61- return self_referential_fk (self .model )
61+ return syncable_models . get_self_referential_fk (self .model )
6262
6363 @property
6464 def self_ref_fk_value (self ) -> Optional [str ]:
@@ -172,46 +172,19 @@ class SelfRefOrderLookup(Transform[List[SerializeTask]]):
172172 """
173173
174174 def __init__ (self ):
175- # Carries resolved parent orders across buffered chunks in the same pipeline run.
176- self .known_order_by_id = {}
177- self .current_model = None
178-
179- def transform (self , tasks : List [SerializeTask ]) -> List [SerializeTask ]:
180175 # Cache of resolved order values keyed by record id.
181- known_order_by_id = self . known_order_by_id
182- unresolved_tasks = []
176+ self . known_order_by_id : dict [ str , int ] = {}
177+ self . current_model : Optional [ Type [ SyncableModel ]] = None
183178
179+ def transform (self , tasks : List [SerializeTask ]) -> List [SerializeTask ]:
180+ # The cache carries resolved parent orders across buffered chunks in the same pipeline run.
181+ # It is cleared when the model changes, which relies on the upstream buffer partitioning chunks by model.
184182 if tasks and self .current_model != tasks [0 ].model :
185- known_order_by_id .clear ()
183+ self . known_order_by_id .clear ()
186184 self .current_model = tasks [0 ].model
187185
188- # First pass:
189- # - identify roots (order = 0)
190- # - resolve children immediately if parent is already in cache
191- # - queue remaining children for DB fallback and secondary pass
192- for task in tasks :
193- self_ref_fk = task .self_referential_fk ()
194- if not self_ref_fk :
195- task .set_self_ref_fk_value (None )
196- task .set_self_ref_order (None )
197- continue
198-
199- self_ref_fk_value = getattr (task .obj , self_ref_fk ) or ""
200- task .set_self_ref_fk_value (self_ref_fk_value )
201-
202- if not self_ref_fk_value :
203- task .set_self_ref_order (0 )
204- known_order_by_id [task .obj .id ] = 0
205- continue
206-
207- parent_order = known_order_by_id .get (self_ref_fk_value )
208- if parent_order is not None :
209- child_order = parent_order + 1
210- task .set_self_ref_order (child_order )
211- known_order_by_id [task .obj .id ] = child_order
212- else :
213- task .set_self_ref_order (None )
214- unresolved_tasks .append (task )
186+ # First pass: assign orders from the cache where available and collect the tasks that still need a lookup.
187+ unresolved_tasks = self ._assign_from_cache (tasks )
215188
216189 # DB fallback for parents that were not available in cache during the first pass.
217190 unresolved_parent_ids = set (
@@ -222,22 +195,22 @@ def transform(self, tasks: List[SerializeTask]) -> List[SerializeTask]:
222195 id__in = unresolved_parent_ids
223196 ).values_list ("id" , "_self_ref_order" ):
224197 if parent_order is not None :
225- known_order_by_id [parent_id ] = parent_order
198+ self . known_order_by_id [parent_id ] = parent_order
226199
227200 # Resolve remaining children by repeatedly scanning only unresolved tasks.
228201 pending = unresolved_tasks
229202 while pending :
230203 next_pending = []
231204 progressed = False
232205 for task in pending :
233- parent_order = known_order_by_id .get (task .self_ref_fk_value )
206+ parent_order = self . known_order_by_id .get (task .self_ref_fk_value )
234207 if parent_order is None :
235208 next_pending .append (task )
236209 continue
237210
238211 child_order = parent_order + 1
239212 task .set_self_ref_order (child_order )
240- known_order_by_id [task .obj .id ] = child_order
213+ self . known_order_by_id [task .obj .id ] = child_order
241214 progressed = True
242215
243216 if not progressed :
@@ -246,6 +219,41 @@ def transform(self, tasks: List[SerializeTask]) -> List[SerializeTask]:
246219
247220 return tasks
248221
222+ def _assign_from_cache (self , tasks : List [SerializeTask ]) -> List [SerializeTask ]:  # first pass over one chunk; returns the tasks still needing a parent lookup
223+ """
224+ First pass:
225+ - identify roots (order = 0)
226+ - resolve children immediately if parent is already in cache
227+ - queue remaining children for DB fallback and secondary pass
228+ """
229+ unresolved_tasks = []  # children whose parent order is not in the cache yet
230+
231+ for task in tasks :
232+ self_ref_fk = task .self_referential_fk ()
233+ if not self_ref_fk :  # model has no self-referential FK: clear both fields and skip
234+ task .set_self_ref_fk_value (None )
235+ task .set_self_ref_order (None )
236+ continue
237+
238+ self_ref_fk_value = getattr (task .obj , self_ref_fk ) or ""  # falsy FK values (e.g. None) are normalized to ""
239+ task .set_self_ref_fk_value (self_ref_fk_value )
240+
241+ if not self_ref_fk_value :  # no parent: this record is a root
242+ task .set_self_ref_order (0 )
243+ self .known_order_by_id [task .obj .id ] = 0  # remember the root so its children can resolve from the cache
244+ continue
245+
246+ parent_order = self .known_order_by_id .get (self_ref_fk_value )  # cache is keyed by record id; the FK value is the parent's id
247+ if parent_order is not None :
248+ child_order = parent_order + 1
249+ task .set_self_ref_order (child_order )
250+ self .known_order_by_id [task .obj .id ] = child_order  # cache so deeper descendants can resolve from it
251+ else :
252+ task .set_self_ref_order (None )  # parent not cached yet: leave the order unset for now
253+ unresolved_tasks .append (task )  # handed back to the caller for the DB fallback and second pass
254+
255+ return unresolved_tasks
256+
249257
250258class StoreUpdate (Transform [SerializeTask ]):
251259 """Processes the updates to the Morango store and record counters."""
@@ -324,24 +332,6 @@ def _handle_store_create(self, task: SerializeTask):
324332 task .set_store (Store (** kwargs ))
325333
326334
327- class ModelPartitionBuffer (Buffer [List [SerializeTask ]]):
328- """Buffers tasks into chunks that have the same model class."""
329-
330- def __call__ (self , tasks : Iterable [SerializeTask ]) -> Iterator [List [SerializeTask ]]:
331- chunk = []
332- last_model = None
333-
334- for task in tasks :
335- if len (chunk ) >= self .size or (last_model and last_model != task .model ):
336- yield chunk
337- chunk = []
338- last_model = task .model
339- chunk .append (task )
340-
341- if chunk :
342- yield chunk
343-
344-
345335class WriteSink (Sink [List [SerializeTask ]]):
346336 """
347337 Consumes SerializeTask objects and writes the appropriate changes to the database.
@@ -494,6 +484,10 @@ def _update_counters(self):
494484 )
495485
496486
487+ def task_model_partition_fn (task : SerializeTask ) -> Type [SyncableModel ]:  # partition key passed to Buffer(partition_fn=...): the task's model class; presumably Buffer starts a new chunk when this key changes - confirm in morango.sync.stream.core
488+ return task .model
489+
490+
497491def serialize_into_store (
498492 profile : str , sync_filter : Optional [Filter ] = None , dirty_only : bool = True
499493):
@@ -510,12 +504,12 @@ def serialize_into_store(
510504 # Execute the main pipeline (consumes the source through to the sink).
511505 result_count = (
512506 AppModelSource (profile , sync_filter = sync_filter , dirty_only = dirty_only )
513- .pipe (Buffer (size = 500 ))
507+ .pipe (Buffer (size = 500 , partition_fn = task_model_partition_fn ))
514508 .pipe (StoreLookup (current_id ))
515509 .pipe (SelfRefOrderLookup ())
516510 .pipe (Unbuffer ())
517511 .pipe (StoreUpdate (current_id ))
518- .pipe (ModelPartitionBuffer (size = 500 ))
512+ .pipe (Buffer (size = 500 , partition_fn = task_model_partition_fn ))
519513 .end (WriteSink (profile , current_id , sync_filter = sync_filter ))
520514 )
521515 logger .info (f"Serialization done: { result_count } records" )
0 commit comments