Skip to content

Commit bb0f65b

Browse files
sandeeppalsandeepvinayak
authored andcommitted
Add atomic writes in mem docstore
1 parent 76b0ecc commit bb0f65b

File tree

2 files changed

+95
-17
lines changed

2 files changed

+95
-17
lines changed

docstore/memdocstore/mem.go

Lines changed: 94 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
// Action lists are executed concurrently. Each action in an action list is executed
2424
// in a separate goroutine.
2525
//
26+
// memdocstore supports atomic writes. When using AtomicWrites(), all write actions
27+
// in the action list are executed atomically - either all succeed or all fail together.
28+
//
2629
// memdocstore calls the BeforeDo function of an ActionList once before executing the
2730
// actions. Its As function never returns true.
2831
//
@@ -198,14 +201,79 @@ func (c *collection) RunActions(ctx context.Context, actions []*driver.Action, o
198201
}
199202
}
200203

201-
beforeGets, gets, writes, _, afterGets := driver.GroupActions(actions)
204+
beforeGets, gets, writes, writesTx, afterGets := driver.GroupActions(actions)
202205
run(beforeGets)
203206
run(gets)
204207
run(writes)
208+
209+
// Handle atomic writes separately to ensure they are truly atomic
210+
if len(writesTx) > 0 {
211+
c.runAtomicWrites(ctx, writesTx, errs)
212+
}
213+
205214
run(afterGets)
206215
return driver.NewActionListError(errs)
207216
}
208217

218+
// runAtomicWrites executes multiple write actions atomically.
219+
// All writes either succeed or all fail together.
220+
func (c *collection) runAtomicWrites(ctx context.Context, actions []*driver.Action, errs []error) {
221+
// Stop if the context is done.
222+
if ctx.Err() != nil {
223+
for _, a := range actions {
224+
errs[a.Index] = ctx.Err()
225+
}
226+
return
227+
}
228+
229+
c.mu.Lock()
230+
defer c.mu.Unlock()
231+
232+
// First, validate all actions and collect current documents
233+
type actionInfo struct {
234+
action *driver.Action
235+
current storedDoc
236+
exists bool
237+
}
238+
239+
actionInfos := make([]actionInfo, len(actions))
240+
for i, a := range actions {
241+
info := &actionInfos[i]
242+
info.action = a
243+
244+
if a.Key != nil {
245+
info.current, info.exists = c.docs[a.Key]
246+
}
247+
248+
// Check for NotFound errors
249+
if !info.exists && (a.Kind == driver.Replace || a.Kind == driver.Update || a.Kind == driver.Get) {
250+
for _, a2 := range actions {
251+
errs[a2.Index] = gcerr.Newf(gcerr.NotFound, nil, "document with key %v does not exist", a.Key)
252+
}
253+
return
254+
}
255+
256+
// Check revision conflicts
257+
if err := c.checkRevision(a.Doc, info.current); err != nil {
258+
for _, a2 := range actions {
259+
errs[a2.Index] = err
260+
}
261+
return
262+
}
263+
}
264+
265+
// Now execute all actions atomically
266+
for _, info := range actionInfos {
267+
if err := c.executeAction(info.action, info.current, info.exists); err != nil {
268+
// If any action fails, mark all actions as failed
269+
for _, a2 := range actions {
270+
errs[a2.Index] = err
271+
}
272+
return
273+
}
274+
}
275+
}
276+
209277
// runAction executes a single action.
210278
func (c *collection) runAction(ctx context.Context, a *driver.Action) error {
211279
// Stop if the context is done.
@@ -227,6 +295,31 @@ func (c *collection) runAction(ctx context.Context, a *driver.Action) error {
227295
if !exists && (a.Kind == driver.Replace || a.Kind == driver.Update || a.Kind == driver.Get) {
228296
return gcerr.Newf(gcerr.NotFound, nil, "document with key %v does not exist", a.Key)
229297
}
298+
299+
// Check revision conflicts
300+
if a.Kind != driver.Get && a.Kind != driver.Create {
301+
if err := c.checkRevision(a.Doc, current); err != nil {
302+
return err
303+
}
304+
}
305+
306+
// Execute the action for Get
307+
if a.Kind == driver.Get {
308+
// Handle Get separately since it doesn't modify the document
309+
// We've already retrieved the document into current, above.
310+
// Now we copy its fields into the user-provided document.
311+
if err := decodeDoc(current, a.Doc, a.FieldPaths); err != nil {
312+
return err
313+
}
314+
return nil
315+
}
316+
317+
return c.executeAction(a, current, exists)
318+
}
319+
320+
// executeAction executes a single action. Must be called with the lock held.
321+
// This method is shared between runAction and runAtomicWrites to eliminate code duplication.
322+
func (c *collection) executeAction(a *driver.Action, current storedDoc, exists bool) error {
230323
switch a.Kind {
231324
case driver.Create:
232325
// It is an error to attempt to create an existing document.
@@ -244,9 +337,6 @@ func (c *collection) runAction(ctx context.Context, a *driver.Action) error {
244337
fallthrough
245338

246339
case driver.Replace, driver.Put:
247-
if err := c.checkRevision(a.Doc, current); err != nil {
248-
return err
249-
}
250340
doc, err := encodeDoc(a.Doc)
251341
if err != nil {
252342
return err
@@ -260,15 +350,9 @@ func (c *collection) runAction(ctx context.Context, a *driver.Action) error {
260350
c.docs[a.Key] = doc
261351

262352
case driver.Delete:
263-
if err := c.checkRevision(a.Doc, current); err != nil {
264-
return err
265-
}
266353
delete(c.docs, a.Key)
267354

268355
case driver.Update:
269-
if err := c.checkRevision(a.Doc, current); err != nil {
270-
return err
271-
}
272356
if err := c.update(current, a.Mods); err != nil {
273357
return err
274358
}
@@ -279,12 +363,6 @@ func (c *collection) runAction(ctx context.Context, a *driver.Action) error {
279363
}
280364
}
281365

282-
case driver.Get:
283-
// We've already retrieved the document into current, above.
284-
// Now we copy its fields into the user-provided document.
285-
if err := decodeDoc(current, a.Doc, a.FieldPaths); err != nil {
286-
return err
287-
}
288366
default:
289367
return gcerr.Newf(gcerr.Internal, nil, "unknown kind %v", a.Kind)
290368
}

docstore/memdocstore/mem_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func (*harness) BeforeQueryTypes() []interface{} { return nil }
5454

5555
func (*harness) RevisionsEqual(rev1, rev2 interface{}) bool { return rev1 == rev2 }
5656

57-
func (*harness) SupportsAtomicWrites() bool { return false }
57+
func (*harness) SupportsAtomicWrites() bool { return true }
5858

5959
func (*harness) Close() {}
6060

0 commit comments

Comments
 (0)