|
| 1 | +package cmd |
| 2 | + |
| 3 | +import ( |
| 4 | + "HelixDB/common" |
| 5 | + "HelixDB/db" |
| 6 | + "errors" |
| 7 | + "reflect" |
| 8 | + "testing" |
| 9 | +) |
| 10 | + |
| 11 | +func TestTTL(t *testing.T) { |
| 12 | + // Key with no TTL |
| 13 | + _, _ = Set(common.Cmd{Name: "SET", Args: []string{"ttl_persist", "v"}}) |
| 14 | + // Key with TTL (~100 seconds from now) |
| 15 | + _, _ = Set(common.Cmd{Name: "SET", Args: []string{"ttl_expiring", "v", "EX", "100"}}) |
| 16 | + // Expired key (TTL set to the past) |
| 17 | + db.DB.Store("ttl_expired", db.NewValue("v")) |
| 18 | + db.KeyTTL.Store("ttl_expired", int64(1)) // epoch ms well in the past |
| 19 | + |
| 20 | + tests := []struct { |
| 21 | + command common.Cmd |
| 22 | + want []byte |
| 23 | + wantErr error |
| 24 | + }{ |
| 25 | + // Key with no TTL returns -1 |
| 26 | + {common.Cmd{Name: "TTL", Args: []string{"ttl_persist"}}, common.RespInteger(-1), nil}, |
| 27 | + // Non-existent key returns -2 |
| 28 | + {common.Cmd{Name: "TTL", Args: []string{"no_such_key"}}, common.RespInteger(-2), nil}, |
| 29 | + // Expired key returns -2 |
| 30 | + {common.Cmd{Name: "TTL", Args: []string{"ttl_expired"}}, common.RespInteger(-2), nil}, |
| 31 | + // Wrong number of arguments |
| 32 | + {common.Cmd{Name: "TTL"}, common.RespError("wrong number of arguments for 'ttl' command"), common.ErrWrongNumberOfArgs}, |
| 33 | + {common.Cmd{Name: "TTL", Args: []string{"a", "b"}}, common.RespError("wrong number of arguments for 'ttl' command"), common.ErrWrongNumberOfArgs}, |
| 34 | + } |
| 35 | + for _, test := range tests { |
| 36 | + if got, gotErr := TTL(test.command); !reflect.DeepEqual(got, test.want) || !errors.Is(gotErr, test.wantErr) { |
| 37 | + t.Errorf("TTL(%v) = %v, %v; want %v, %v", test.command, got, gotErr, test.want, test.wantErr) |
| 38 | + } |
| 39 | + } |
| 40 | + |
| 41 | + // Key with TTL returns a positive integer (exact value is timing-sensitive) |
| 42 | + got, gotErr := TTL(common.Cmd{Name: "TTL", Args: []string{"ttl_expiring"}}) |
| 43 | + if gotErr != nil { |
| 44 | + t.Errorf("TTL(ttl_expiring) unexpected error: %v", gotErr) |
| 45 | + } |
| 46 | + if reflect.DeepEqual(got, common.RespInteger(-1)) || reflect.DeepEqual(got, common.RespInteger(-2)) { |
| 47 | + t.Errorf("TTL(ttl_expiring) = %v; want a positive integer", got) |
| 48 | + } |
| 49 | +} |
| 50 | + |
| 51 | +func TestPTTL(t *testing.T) { |
| 52 | + // Key with no TTL |
| 53 | + _, _ = Set(common.Cmd{Name: "SET", Args: []string{"pttl_persist", "v"}}) |
| 54 | + // Key with TTL (~100 seconds from now) |
| 55 | + _, _ = Set(common.Cmd{Name: "SET", Args: []string{"pttl_expiring", "v", "EX", "100"}}) |
| 56 | + // Expired key |
| 57 | + db.DB.Store("pttl_expired", db.NewValue("v")) |
| 58 | + db.KeyTTL.Store("pttl_expired", int64(1)) |
| 59 | + |
| 60 | + tests := []struct { |
| 61 | + command common.Cmd |
| 62 | + want []byte |
| 63 | + wantErr error |
| 64 | + }{ |
| 65 | + // Key with no TTL returns -1 |
| 66 | + {common.Cmd{Name: "PTTL", Args: []string{"pttl_persist"}}, common.RespInteger(-1), nil}, |
| 67 | + // Non-existent key returns -2 |
| 68 | + {common.Cmd{Name: "PTTL", Args: []string{"no_such_key"}}, common.RespInteger(-2), nil}, |
| 69 | + // Expired key returns -2 |
| 70 | + {common.Cmd{Name: "PTTL", Args: []string{"pttl_expired"}}, common.RespInteger(-2), nil}, |
| 71 | + // Wrong number of arguments |
| 72 | + {common.Cmd{Name: "PTTL"}, common.RespError("wrong number of arguments for 'pttl' command"), common.ErrWrongNumberOfArgs}, |
| 73 | + {common.Cmd{Name: "PTTL", Args: []string{"a", "b"}}, common.RespError("wrong number of arguments for 'pttl' command"), common.ErrWrongNumberOfArgs}, |
| 74 | + } |
| 75 | + for _, test := range tests { |
| 76 | + if got, gotErr := PTTL(test.command); !reflect.DeepEqual(got, test.want) || !errors.Is(gotErr, test.wantErr) { |
| 77 | + t.Errorf("PTTL(%v) = %v, %v; want %v, %v", test.command, got, gotErr, test.want, test.wantErr) |
| 78 | + } |
| 79 | + } |
| 80 | + |
| 81 | + // Key with TTL returns a positive integer |
| 82 | + got, gotErr := PTTL(common.Cmd{Name: "PTTL", Args: []string{"pttl_expiring"}}) |
| 83 | + if gotErr != nil { |
| 84 | + t.Errorf("PTTL(pttl_expiring) unexpected error: %v", gotErr) |
| 85 | + } |
| 86 | + if reflect.DeepEqual(got, common.RespInteger(-1)) || reflect.DeepEqual(got, common.RespInteger(-2)) { |
| 87 | + t.Errorf("PTTL(pttl_expiring) = %v; want a positive integer", got) |
| 88 | + } |
| 89 | +} |
0 commit comments