Skip to content

Commit 481af0a

Browse files
committed
Calculate idle timeout expiry time in Load rather than Commit; Add Expiry method
1 parent 796112f commit 481af0a

File tree

2 files changed

+102
-40
lines changed

2 files changed

+102
-40
lines changed

data.go

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,37 @@ const (
3131

3232
type sessionData struct {
3333
deadline time.Time
34+
expiry time.Time
3435
status Status
3536
token string
3637
values map[string]interface{}
3738
mu sync.Mutex
3839
}
3940

40-
func newSessionData(lifetime time.Duration) *sessionData {
41-
return &sessionData{
42-
deadline: time.Now().Add(lifetime).UTC(),
41+
func newSessionData(lifetime, idleTimeout time.Duration) *sessionData {
42+
now := time.Now()
43+
deadline := now.Add(lifetime).UTC()
44+
45+
sd := sessionData{
46+
deadline: deadline,
47+
expiry: deadline,
4348
status: Unmodified,
4449
values: make(map[string]interface{}),
4550
}
51+
52+
if idleTimeout > 0 {
53+
idleExpiry := time.Now().Add(idleTimeout).UTC()
54+
sd.expiry = minTime(sd.deadline, idleExpiry)
55+
}
56+
57+
return &sd
58+
}
59+
60+
func minTime(a, b time.Time) time.Time {
61+
if a.Before(b) {
62+
return a
63+
}
64+
return b
4665
}
4766

4867
// Load retrieves the session data for the given token from the session store,
@@ -57,28 +76,37 @@ func (s *SessionManager) Load(ctx context.Context, token string) (context.Contex
5776
}
5877

5978
if token == "" {
60-
return s.addSessionDataToContext(ctx, newSessionData(s.Lifetime)), nil
79+
return s.addSessionDataToContext(ctx, newSessionData(s.Lifetime, s.IdleTimeout)), nil
6180
}
6281

6382
b, found, err := s.doStoreFind(ctx, token)
6483
if err != nil {
6584
return nil, err
6685
} else if !found {
67-
return s.addSessionDataToContext(ctx, newSessionData(s.Lifetime)), nil
86+
return s.addSessionDataToContext(ctx, newSessionData(s.Lifetime, s.IdleTimeout)), nil
6887
}
6988

7089
sd := &sessionData{
7190
status: Unmodified,
7291
token: token,
7392
}
74-
if sd.deadline, sd.values, err = s.Codec.Decode(b); err != nil {
93+
94+
sd.deadline, sd.values, err = s.Codec.Decode(b)
95+
if err != nil {
7596
return nil, err
7697
}
7798

78-
// Mark the session data as modified if an idle timeout is being used. This
79-
// will force the session data to be re-committed to the session store with
80-
// a new expiry time.
99+
// By default, set the expiry time to the deadline.
100+
sd.expiry = sd.deadline
101+
102+
// If an idle timeout is being used, set the expiry to whichever comes first
103+
// between the session deadline and idleTimeout expiry. We also set the status
104+
// to Modified, which will force the session data to be re-committed to the
105+
// session store with the updated new expiry time.
81106
if s.IdleTimeout > 0 {
107+
idleExpiry := time.Now().Add(s.IdleTimeout).UTC()
108+
109+
sd.expiry = minTime(sd.deadline, idleExpiry)
82110
sd.status = Modified
83111
}
84112

@@ -108,19 +136,11 @@ func (s *SessionManager) Commit(ctx context.Context) (string, time.Time, error)
108136
return "", time.Time{}, err
109137
}
110138

111-
expiry := sd.deadline
112-
if s.IdleTimeout > 0 {
113-
ie := time.Now().Add(s.IdleTimeout).UTC()
114-
if ie.Before(expiry) {
115-
expiry = ie
116-
}
117-
}
118-
119-
if err := s.doStoreCommit(ctx, sd.token, b, expiry); err != nil {
139+
if err := s.doStoreCommit(ctx, sd.token, b, sd.expiry); err != nil {
120140
return "", time.Time{}, err
121141
}
122142

123-
return sd.token, expiry, nil
143+
return sd.token, sd.expiry, nil
124144
}
125145

126146
// Destroy deletes the session data from the session store and sets the session
@@ -591,6 +611,17 @@ func (s *SessionManager) SetDeadline(ctx context.Context, expire time.Time) {
591611
sd.status = Modified
592612
}
593613

614+
// Expiry returns the expiry time for the session. If you are not using an idle
615+
// timeout, the value returned will be the same as calling the Deadline method.
616+
func (s *SessionManager) Expiry(ctx context.Context) time.Time {
617+
sd := s.getSessionDataFromContext(ctx)
618+
619+
sd.mu.Lock()
620+
defer sd.mu.Unlock()
621+
622+
return sd.expiry
623+
}
624+
594625
// Token returns the session token. Please note that this will return the
595626
// empty string "" if it is called before the session has been committed to
596627
// the store.

data_test.go

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ func TestSessionManager_Load(T *testing.T) {
209209

210210
initialCtx := context.WithValue(context.Background(), s.ContextKey, &sessionData{
211211
deadline: expectedExpiry,
212+
expiry: expectedExpiry,
212213
token: expectedToken,
213214
values: map[string]interface{}{
214215
"blah": "blah",
@@ -256,6 +257,7 @@ func TestSessionManager_Commit(T *testing.T) {
256257

257258
ctx := context.WithValue(context.Background(), s.ContextKey, &sessionData{
258259
deadline: expectedExpiry,
260+
expiry: expectedExpiry,
259261
token: expectedToken,
260262
values: map[string]interface{}{
261263
"blah": "blah",
@@ -284,6 +286,7 @@ func TestSessionManager_Commit(T *testing.T) {
284286

285287
ctx := context.WithValue(context.Background(), s.ContextKey, &sessionData{
286288
deadline: expectedExpiry,
289+
expiry: expectedExpiry,
287290
token: expectedToken,
288291
values: map[string]interface{}{
289292
"blah": "blah",
@@ -337,9 +340,11 @@ func TestSessionManager_Commit(T *testing.T) {
337340

338341
store := &mockstore.MockStore{}
339342
expectedErr := errors.New("arbitrary")
343+
deadline := time.Now().Add(time.Hour)
340344

341345
sd := &sessionData{
342-
deadline: time.Now().Add(time.Hour),
346+
deadline: deadline,
347+
expiry: deadline,
343348
token: "example",
344349
values: map[string]interface{}{
345350
"blah": "blah",
@@ -375,6 +380,7 @@ func TestSessionManager_Commit(T *testing.T) {
375380

376381
ctx := context.WithValue(context.Background(), s.ContextKey, &sessionData{
377382
deadline: expectedExpiry,
383+
expiry: expectedExpiry,
378384
token: expectedToken,
379385
values: map[string]interface{}{
380386
"blah": "blah",
@@ -399,7 +405,7 @@ func TestPut(t *testing.T) {
399405
t.Parallel()
400406

401407
s := New()
402-
sd := newSessionData(time.Hour)
408+
sd := newSessionData(time.Hour, 0)
403409
ctx := s.addSessionDataToContext(context.Background(), sd)
404410

405411
s.Put(ctx, "foo", "bar")
@@ -417,7 +423,7 @@ func TestGet(t *testing.T) {
417423
t.Parallel()
418424

419425
s := New()
420-
sd := newSessionData(time.Hour)
426+
sd := newSessionData(time.Hour, 0)
421427
sd.values["foo"] = "bar"
422428
ctx := s.addSessionDataToContext(context.Background(), sd)
423429

@@ -435,7 +441,7 @@ func TestPop(t *testing.T) {
435441
t.Parallel()
436442

437443
s := New()
438-
sd := newSessionData(time.Hour)
444+
sd := newSessionData(time.Hour, 0)
439445
sd.values["foo"] = "bar"
440446
ctx := s.addSessionDataToContext(context.Background(), sd)
441447

@@ -462,7 +468,7 @@ func TestRemove(t *testing.T) {
462468
t.Parallel()
463469

464470
s := New()
465-
sd := newSessionData(time.Hour)
471+
sd := newSessionData(time.Hour, 0)
466472
sd.values["foo"] = "bar"
467473
ctx := s.addSessionDataToContext(context.Background(), sd)
468474

@@ -481,7 +487,7 @@ func TestClear(t *testing.T) {
481487
t.Parallel()
482488

483489
s := New()
484-
sd := newSessionData(time.Hour)
490+
sd := newSessionData(time.Hour, 0)
485491
sd.values["foo"] = "bar"
486492
sd.values["baz"] = "boz"
487493
ctx := s.addSessionDataToContext(context.Background(), sd)
@@ -507,7 +513,7 @@ func TestExists(t *testing.T) {
507513
t.Parallel()
508514

509515
s := New()
510-
sd := newSessionData(time.Hour)
516+
sd := newSessionData(time.Hour, 0)
511517
sd.values["foo"] = "bar"
512518
ctx := s.addSessionDataToContext(context.Background(), sd)
513519

@@ -524,7 +530,7 @@ func TestKeys(t *testing.T) {
524530
t.Parallel()
525531

526532
s := New()
527-
sd := newSessionData(time.Hour)
533+
sd := newSessionData(time.Hour, 0)
528534
sd.values["foo"] = "bar"
529535
sd.values["woo"] = "waa"
530536
ctx := s.addSessionDataToContext(context.Background(), sd)
@@ -539,7 +545,7 @@ func TestGetString(t *testing.T) {
539545
t.Parallel()
540546

541547
s := New()
542-
sd := newSessionData(time.Hour)
548+
sd := newSessionData(time.Hour, 0)
543549
sd.values["foo"] = "bar"
544550
ctx := s.addSessionDataToContext(context.Background(), sd)
545551

@@ -558,7 +564,7 @@ func TestGetBool(t *testing.T) {
558564
t.Parallel()
559565

560566
s := New()
561-
sd := newSessionData(time.Hour)
567+
sd := newSessionData(time.Hour, 0)
562568
sd.values["foo"] = true
563569
ctx := s.addSessionDataToContext(context.Background(), sd)
564570

@@ -577,7 +583,7 @@ func TestGetInt(t *testing.T) {
577583
t.Parallel()
578584

579585
s := New()
580-
sd := newSessionData(time.Hour)
586+
sd := newSessionData(time.Hour, 0)
581587
sd.values["foo"] = 123
582588
ctx := s.addSessionDataToContext(context.Background(), sd)
583589

@@ -596,7 +602,7 @@ func TestGetFloat(t *testing.T) {
596602
t.Parallel()
597603

598604
s := New()
599-
sd := newSessionData(time.Hour)
605+
sd := newSessionData(time.Hour, 0)
600606
sd.values["foo"] = 123.456
601607
ctx := s.addSessionDataToContext(context.Background(), sd)
602608

@@ -615,7 +621,7 @@ func TestGetBytes(t *testing.T) {
615621
t.Parallel()
616622

617623
s := New()
618-
sd := newSessionData(time.Hour)
624+
sd := newSessionData(time.Hour, 0)
619625
sd.values["foo"] = []byte("bar")
620626
ctx := s.addSessionDataToContext(context.Background(), sd)
621627

@@ -636,7 +642,7 @@ func TestGetTime(t *testing.T) {
636642
now := time.Now()
637643

638644
s := New()
639-
sd := newSessionData(time.Hour)
645+
sd := newSessionData(time.Hour, 0)
640646
sd.values["foo"] = now
641647
ctx := s.addSessionDataToContext(context.Background(), sd)
642648

@@ -655,7 +661,7 @@ func TestPopString(t *testing.T) {
655661
t.Parallel()
656662

657663
s := New()
658-
sd := newSessionData(time.Hour)
664+
sd := newSessionData(time.Hour, 0)
659665
sd.values["foo"] = "bar"
660666
ctx := s.addSessionDataToContext(context.Background(), sd)
661667

@@ -683,7 +689,7 @@ func TestPopBool(t *testing.T) {
683689
t.Parallel()
684690

685691
s := New()
686-
sd := newSessionData(time.Hour)
692+
sd := newSessionData(time.Hour, 0)
687693
sd.values["foo"] = true
688694
ctx := s.addSessionDataToContext(context.Background(), sd)
689695

@@ -711,7 +717,7 @@ func TestPopInt(t *testing.T) {
711717
t.Parallel()
712718

713719
s := New()
714-
sd := newSessionData(time.Hour)
720+
sd := newSessionData(time.Hour, 0)
715721
sd.values["foo"] = 123
716722
ctx := s.addSessionDataToContext(context.Background(), sd)
717723

@@ -739,7 +745,7 @@ func TestPopFloat(t *testing.T) {
739745
t.Parallel()
740746

741747
s := New()
742-
sd := newSessionData(time.Hour)
748+
sd := newSessionData(time.Hour, 0)
743749
sd.values["foo"] = 123.456
744750
ctx := s.addSessionDataToContext(context.Background(), sd)
745751

@@ -767,7 +773,7 @@ func TestPopBytes(t *testing.T) {
767773
t.Parallel()
768774

769775
s := New()
770-
sd := newSessionData(time.Hour)
776+
sd := newSessionData(time.Hour, 0)
771777
sd.values["foo"] = []byte("bar")
772778
ctx := s.addSessionDataToContext(context.Background(), sd)
773779

@@ -795,7 +801,7 @@ func TestPopTime(t *testing.T) {
795801

796802
now := time.Now()
797803
s := New()
798-
sd := newSessionData(time.Hour)
804+
sd := newSessionData(time.Hour, 0)
799805
sd.values["foo"] = now
800806
ctx := s.addSessionDataToContext(context.Background(), sd)
801807

@@ -824,7 +830,7 @@ func TestStatus(t *testing.T) {
824830
t.Parallel()
825831

826832
s := New()
827-
sd := newSessionData(time.Hour)
833+
sd := newSessionData(time.Hour, 0)
828834
ctx := s.addSessionDataToContext(context.Background(), sd)
829835

830836
status := s.Status(ctx)
@@ -847,3 +853,28 @@ func TestStatus(t *testing.T) {
847853
t.Errorf("got %d: expected %d", status, Destroyed)
848854
}
849855
}
856+
857+
func TestDeadlineAndExpiry(t *testing.T) {
858+
t.Parallel()
859+
860+
s := New()
861+
sd := newSessionData(time.Hour, 0)
862+
ctx := s.addSessionDataToContext(context.Background(), sd)
863+
864+
expiry := s.Expiry(ctx)
865+
deadline := s.Deadline(ctx)
866+
867+
if !expiry.Equal(deadline) {
868+
t.Errorf("got %v: expected %v", expiry, deadline)
869+
}
870+
871+
sd = newSessionData(time.Hour, time.Minute)
872+
ctx = s.addSessionDataToContext(context.Background(), sd)
873+
874+
expiry = s.Expiry(ctx)
875+
deadline = s.Deadline(ctx)
876+
877+
if !expiry.Before(deadline) {
878+
t.Errorf("got %v: expected before %v", expiry, deadline)
879+
}
880+
}

0 commit comments

Comments
 (0)