|
4 | 4 | package keysetpagination |
5 | 5 |
|
6 | 6 | import ( |
| 7 | + "crypto/rand" |
7 | 8 | "database/sql" |
| 9 | + "encoding/base64" |
8 | 10 | "encoding/json" |
| 11 | + "io" |
9 | 12 | "time" |
10 | 13 |
|
11 | 14 | "github.com/gofrs/uuid" |
12 | | - "github.com/pkg/errors" |
13 | | - "github.com/ssoready/hyrumtoken" |
14 | | - |
15 | 15 | "github.com/ory/herodot" |
| 16 | + "github.com/pkg/errors" |
| 17 | + "golang.org/x/crypto/nacl/secretbox" |
16 | 18 | ) |
17 | 19 |
|
18 | 20 | var fallbackEncryptionKey = &[32]byte{} |
@@ -58,7 +60,16 @@ func (t PageToken) Encrypt(keys [][32]byte) string { |
58 | 60 | if len(keys) > 0 { |
59 | 61 | key = &keys[0] |
60 | 62 | } |
61 | | - return hyrumtoken.Marshal(key, t) |
| 63 | + enc, err := t.encrypt(key) |
| 64 | + if err != nil { |
| 65 | + // This should basically never happen, only if reading from the random source or marshaling the token fails. |
| 66 | + // In both cases, we have a bigger problem than just not being able to generate the page token. |
| 67 | + // Therefore, if we do get an error, there is no point in returning it to the client, |
| 68 | + // as we already have a working result set. With this string, the next page will return an error, |
| 69 | + // but that is better than breaking the current page. |
| 70 | + return "internal error: failed to generate page token" |
| 71 | + } |
| 72 | + return enc |
62 | 73 | } |
63 | 74 |
|
64 | 75 | func (t PageToken) MarshalJSON() ([]byte, error) { |
@@ -109,7 +120,7 @@ func (t PageToken) MarshalJSON() ([]byte, error) { |
109 | 120 | return json.Marshal(toEncode) |
110 | 121 | } |
111 | 122 |
|
112 | | -var ErrPageTokenExpired = herodot.ErrBadRequest.WithReason("page token expired, do not persist page tokens") |
| 123 | +var ErrPageTokenExpired = herodot.ErrBadRequest.WithError("page token expired, do not persist page tokens") |
113 | 124 |
|
114 | 125 | func (t *PageToken) UnmarshalJSON(data []byte) error { |
115 | 126 | rawToken := jsonPageToken{} |
@@ -149,3 +160,46 @@ func (t *PageToken) UnmarshalJSON(data []byte) error { |
149 | 160 | } |
150 | 161 |
|
151 | 162 | func NewPageToken(cols ...Column) PageToken { return PageToken{cols: cols} } |
| 163 | + |
| 164 | +func (t *PageToken) encrypt(key *[32]byte) (string, error) { |
| 165 | + var nonce [24]byte |
| 166 | + if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil { |
| 167 | + return "", errors.Wrap(err, "cannot seed nonce") |
| 168 | + } |
| 169 | + |
| 170 | + raw, err := json.Marshal(t) |
| 171 | + if err != nil { |
| 172 | + return "", errors.Wrap(err, "cannot marshal page token") |
| 173 | + } |
| 174 | + |
| 175 | + enc := secretbox.Seal(nonce[:], raw, &nonce, key) |
| 176 | + return base64.URLEncoding.EncodeToString(enc), nil |
| 177 | +} |
| 178 | + |
| 179 | +func (t *PageToken) decrypt(key *[32]byte, s string) error { |
| 180 | + if s == "" { |
| 181 | + return errors.WithStack(ErrInvalidPaginationToken) |
| 182 | + } |
| 183 | + |
| 184 | + raw, err := base64.URLEncoding.DecodeString(s) |
| 185 | + if err != nil || len(raw) < 24 { |
| 186 | + return errors.WithStack(ErrInvalidPaginationToken) |
| 187 | + } |
| 188 | + |
| 189 | + var nonce [24]byte |
| 190 | + copy(nonce[:], raw[:24]) |
| 191 | + |
| 192 | + dec, ok := secretbox.Open(nil, raw[24:], &nonce, key) |
| 193 | + if !ok { |
| 194 | + return errors.WithStack(ErrInvalidPaginationToken) |
| 195 | + } |
| 196 | + |
| 197 | + if err := json.Unmarshal(dec, t); err != nil { |
| 198 | + if errors.As(err, new(*herodot.DefaultError)) { |
| 199 | + return err |
| 200 | + } |
| 201 | + return errors.WithStack(herodot.ErrInternalServerError.WithReason("unable to unmarshal page token").WithDebug(err.Error())) |
| 202 | + } |
| 203 | + |
| 204 | + return nil |
| 205 | +} |
0 commit comments