Skip to content

Commit d35785a

Browse files
authored
Merge pull request #18 from my3rs/fix#16
fix: #16
2 parents bf3a665 + 50a1900 commit d35785a

File tree

2 files changed

+163
-1
lines changed

2 files changed

+163
-1
lines changed

claims.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -296,12 +296,17 @@ func Merge(claims any, other any) []byte {
296296
return nil
297297
}
298298

299+
// Return the serialized claims if `other` is nil.
300+
if other == nil {
301+
return claimsB
302+
}
303+
299304
otherB, err := Marshal(other)
300305
if err != nil {
301306
return nil
302307
}
303308

304-
if len(otherB) == 0 {
309+
if len(otherB) == 0 || string(otherB) == "{}" {
305310
return claimsB
306311
}
307312

claims_test.go

+157
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package jwt
22

33
import (
4+
"encoding/json"
5+
"fmt"
46
"reflect"
57
"testing"
68
"time"
@@ -162,3 +164,158 @@ func TestClaimsSubAsInt(t *testing.T) {
162164
t.Fatalf("expected: %#+v but got: %#+v\n", expectedClaims, verifiedToken.StandardClaims)
163165
}
164166
}
167+
168+
func TestMerge(t *testing.T) {
169+
now := time.Now().Unix()
170+
expiry := time.Now().Add(15 * time.Minute).Unix()
171+
172+
tests := []struct {
173+
name string
174+
claims interface{}
175+
other interface{}
176+
expected map[string]interface{}
177+
}{
178+
{
179+
name: "merge with empty object",
180+
claims: map[string]interface{}{"foo": "bar"},
181+
other: map[string]interface{}{},
182+
expected: map[string]interface{}{"foo": "bar"},
183+
},
184+
{
185+
name: "merge two maps",
186+
claims: map[string]interface{}{"foo": "bar"},
187+
other: map[string]interface{}{"baz": "qux"},
188+
expected: map[string]interface{}{"foo": "bar", "baz": "qux"},
189+
},
190+
{
191+
name: "merge with Claims struct",
192+
claims: map[string]interface{}{
193+
"custom": "value",
194+
},
195+
other: Claims{
196+
Issuer: "test-issuer",
197+
IssuedAt: now,
198+
Expiry: expiry,
199+
},
200+
expected: map[string]interface{}{
201+
"custom": "value",
202+
"iss": "test-issuer",
203+
"iat": float64(now), // JSON numbers are decoded as float64
204+
"exp": float64(expiry),
205+
},
206+
},
207+
{
208+
name: "merge with nil",
209+
claims: map[string]interface{}{"foo": "bar"},
210+
other: nil,
211+
expected: map[string]interface{}{"foo": "bar"},
212+
},
213+
}
214+
215+
for _, tt := range tests {
216+
t.Run(tt.name, func(t *testing.T) {
217+
result := Merge(tt.claims, tt.other)
218+
219+
// 解码结果
220+
var got map[string]interface{}
221+
if err := json.Unmarshal(result, &got); err != nil {
222+
t.Fatalf("Failed to unmarshal result: %v", err)
223+
}
224+
225+
// 比较解码后的结果
226+
if !reflect.DeepEqual(got, tt.expected) {
227+
t.Errorf("Merge() = %v, want %v", got, tt.expected)
228+
}
229+
})
230+
}
231+
}
232+
233+
func TestMergeAndSign(t *testing.T) {
234+
now := time.Now().Unix()
235+
expiry := time.Now().Add(15 * time.Minute).Unix()
236+
237+
tests := []struct {
238+
name string
239+
claims interface{}
240+
other interface{}
241+
expected map[string]interface{}
242+
}{
243+
{
244+
name: "merge and sign with empty object",
245+
claims: map[string]interface{}{"foo": "bar"},
246+
other: map[string]interface{}{},
247+
expected: map[string]interface{}{"foo": "bar"},
248+
},
249+
{
250+
name: "merge and sign two maps",
251+
claims: map[string]interface{}{"foo": "bar"},
252+
other: map[string]interface{}{"baz": "qux"},
253+
expected: map[string]interface{}{"foo": "bar", "baz": "qux"},
254+
},
255+
{
256+
name: "merge and sign with Claims struct",
257+
claims: map[string]interface{}{
258+
"custom": "value",
259+
},
260+
other: Claims{
261+
Issuer: "test-issuer",
262+
IssuedAt: now,
263+
Expiry: expiry,
264+
},
265+
expected: map[string]interface{}{
266+
"custom": "value",
267+
"iss": "test-issuer",
268+
"iat": fmt.Sprintf("%d", now),
269+
"exp": fmt.Sprintf("%d", expiry),
270+
},
271+
},
272+
}
273+
274+
key := []byte("secret")
275+
for _, tt := range tests {
276+
t.Run(tt.name, func(t *testing.T) {
277+
// 合并 claims
278+
mergedClaims := Merge(tt.claims, tt.other)
279+
280+
// 使用合并后的 claims 生成 token
281+
token, err := Sign(HS256, key, mergedClaims)
282+
if err != nil {
283+
t.Fatalf("Failed to sign token: %v", err)
284+
}
285+
286+
// 打印生成的 token
287+
t.Logf("Generated token: %s", string(token))
288+
289+
// 验证并解析 token
290+
var verifiedClaims map[string]interface{}
291+
verifiedToken, err := Verify(HS256, key, token)
292+
if err != nil {
293+
t.Fatalf("Failed to verify token: %v", err)
294+
}
295+
296+
err = verifiedToken.Claims(&verifiedClaims)
297+
if err != nil {
298+
t.Fatalf("Failed to get claims from token: %v", err)
299+
}
300+
301+
// 将 json.Number 转换为字符串
302+
if exp, ok := verifiedClaims["exp"].(json.Number); ok {
303+
verifiedClaims["exp"] = exp.String()
304+
}
305+
if iat, ok := verifiedClaims["iat"].(json.Number); ok {
306+
verifiedClaims["iat"] = iat.String()
307+
}
308+
309+
// 打印类型信息
310+
t.Logf("Expected exp type: %T, value: %v", tt.expected["exp"], tt.expected["exp"])
311+
t.Logf("Actual exp type: %T, value: %v", verifiedClaims["exp"], verifiedClaims["exp"])
312+
t.Logf("Expected iat type: %T, value: %v", tt.expected["iat"], tt.expected["iat"])
313+
t.Logf("Actual iat type: %T, value: %v", verifiedClaims["iat"], verifiedClaims["iat"])
314+
315+
// 比较解析后的 claims 是否与预期一致
316+
if !reflect.DeepEqual(verifiedClaims, tt.expected) {
317+
t.Errorf("Claims after merge and verify = %#v, want %#v", verifiedClaims, tt.expected)
318+
}
319+
})
320+
}
321+
}

0 commit comments

Comments
 (0)