diff --git a/subscription.go b/subscription.go index 581efb6..adc6e85 100644 --- a/subscription.go +++ b/subscription.go @@ -15,6 +15,7 @@ package ro import ( + "context" "sync" "github.com/samber/lo" @@ -26,11 +27,13 @@ import ( // It is part of a Subscription, and is returned by the Observable creation. // It will be called only once, when the Subscription is canceled. type Teardown func() +type TeardownWithContext func(ctx context.Context) // Unsubscribable represents any type that can be unsubscribed from. // It provides a common interface for cancellation operations. type Unsubscribable interface { Unsubscribe() + UnsubscribeWithContext(ctx context.Context) } // Subscription represents an ongoing execution of an `Observable`, and has @@ -39,33 +42,51 @@ type Subscription interface { Unsubscribable Add(teardown Teardown) + AddWithContext(teardown TeardownWithContext) AddUnsubscribable(unsubscribable Unsubscribable) IsClosed() bool Wait() // Note: using .Wait() is not recommended. } +type subscriptionImpl struct { + done bool + mu sync.Mutex + finalizers []Teardown + ctxFinalizers []TeardownWithContext + +} + var _ Subscription = (*subscriptionImpl)(nil) // NewSubscription creates a new Subscription. When `teardown` is nil, nothing // is added. When the subscription is already disposed, the `teardown` callback // is triggered immediately. func NewSubscription(teardown Teardown) Subscription { - teardowns := []func(){} + s := &subscriptionImpl{ + finalizers: []Teardown{}, + ctxFinalizers: []TeardownWithContext{}, + } if teardown != nil { - teardowns = append(teardowns, teardown) + s.finalizers = append(s.finalizers, teardown) } - return &subscriptionImpl{ - done: false, - mu: sync.Mutex{}, - finalizers: teardowns, - } + return s } -type subscriptionImpl struct { - done bool - mu sync.Mutex // Should be a RWMutex because of the .IsClosed() method, but sync.RWMutex is 30% slower. - finalizers []func() +func NewSubscriptionWithContext(teardown TeardownWithContext) Subscription { + s := &subscriptionImpl{ + finalizers: []Teardown{}, + ctxFinalizers: []TeardownWithContext{}, + + } + + if teardown != nil { + s.ctxFinalizers = append(s.ctxFinalizers, teardown) + } + + + + return s } // Add receives a finalizer to execute upon unsubscription. When `teardown` @@ -84,10 +105,29 @@ func (s *subscriptionImpl) Add(teardown Teardown) { defer s.mu.Unlock() if s.done { - teardown() // not protected against panics - } else { - s.finalizers = append(s.finalizers, teardown) + _ = execFinalizer(teardown) + return + } + + s.finalizers = append(s.finalizers, teardown) +} + +// AddWithContext registers a teardown function that receives a context when +// the subscription is unsubscribed. +func (s *subscriptionImpl) AddWithContext(teardown TeardownWithContext) { + if teardown == nil { + return + } + + s.mu.Lock() + defer s.mu.Unlock() + + if s.done { + _ = execFinalizerWithContext(teardown, context.Background()) + return } + + s.ctxFinalizers = append(s.ctxFinalizers, teardown) } // AddUnsubscribable merges multiple subscriptions into one. The method does nothing @@ -101,7 +141,9 @@ func (s *subscriptionImpl) AddUnsubscribable(unsubscribable Unsubscribable) { return } - s.Add(unsubscribable.Unsubscribe) + s.Add(func() { + unsubscribable.Unsubscribe() + }) } // Unsubscribe disposes the resources held by the subscription. May, for @@ -120,35 +162,74 @@ func (s *subscriptionImpl) Unsubscribe() { } s.done = true + finals := s.finalizers + ctxFinals := s.ctxFinalizers + s.finalizers = nil + s.ctxFinalizers = nil + s.mu.Unlock() - if len(s.finalizers) == 0 { + var errs []error + + // Execute simple teardowns + for _, f := range finals { + if err := execFinalizer(f); err != nil { + errs = append(errs, err) + } + } + + // Execute context teardowns with a background context + for _, f := range ctxFinals { + if err := execFinalizerWithContext(f, context.Background()); err != nil { + errs = append(errs, err) + } + } + + if len(errs) > 0 { + panic(xerrors.Join(errs...)) + } +} + +// UnsubscribeWithContext cancels the subscription and executes all registered +// teardown functions with the provided context. This allows cancellation-aware +// cleanup logic (e.g. context timeout or cancellation). + +func (s *subscriptionImpl) UnsubscribeWithContext(ctx context.Context) { + + s.mu.Lock() + + if s.done { s.mu.Unlock() return } - finalizers := s.finalizers - s.finalizers = make([]func(), 0) + s.done = true + finals := s.finalizers + ctxFinals := s.ctxFinalizers + s.finalizers = nil + s.ctxFinalizers = nil s.mu.Unlock() var errs []error + // Execute simple teardowns + for _, f := range finals { + if err := execFinalizer(f); err != nil { + errs = append(errs, err) + } + } - // Note: we prefer not running this in parallel. - for i := range finalizers { - err := execFinalizer(finalizers[i]) // protected against panics - if err != nil { - // OnUnhandledError(err) + // Execute context teardowns with provided context + for _, f := range ctxFinals { + if err := execFinalizerWithContext(f, ctx); err != nil { errs = append(errs, err) } } - // Error is triggered after the recursive call to finalizers - // because we want to execute all finalizers before panicking. if len(errs) > 0 { - // errors.Join has been introduced in go 1.20 panic(xerrors.Join(errs...)) } } + // IsClosed returns true if the subscription has been disposed // or if unsubscription is in progress. // @@ -187,9 +268,6 @@ func execFinalizer(finalizer func()) (err error) { lo.TryCatchWithErrorValue( func() error { finalizer() - - err = nil - return nil }, func(e any) { @@ -200,6 +278,35 @@ func execFinalizer(finalizer func()) (err error) { return err } +func execFinalizerWithContext(finalizer any, ctx context.Context) (err error) { + switch f := finalizer.(type) { + case func(): + return execFinalizer(f) + case func(context.Context): + lo.TryCatchWithErrorValue( + func() error { + f(ctx) + return nil + }, + func(e any) { + err = newUnsubscriptionError(recoverValueToError(e)) + }, + ) + case TeardownWithContext: + lo.TryCatchWithErrorValue( + func() error { + f(ctx) + return nil + }, + func(e any) { + err = newUnsubscriptionError(recoverValueToError(e)) + }, + ) + } + return err +} + + // @TODO: Add methods Remove + RemoveSubscription. // Currently, Go does not support function address comparison, so we cannot // remove a finalizer from the list. diff --git a/subscription_test.go b/subscription_test.go index 67cb2f1..d6391a4 100644 --- a/subscription_test.go +++ b/subscription_test.go @@ -15,6 +15,7 @@ package ro import ( + "context" "errors" "sync" "sync/atomic" @@ -552,3 +553,10 @@ func (m *mockUnsubscribable) Unsubscribe() { m.unsubscribe() } } +func (m *mockUnsubscribable) UnsubscribeWithContext(ctx context.Context) { + if m.unsubscribe != nil { + m.unsubscribe() + } +} + +