Skip to content

Commit 7093998

Browse files
authored
Merge pull request #86 from smallstep/cli/jose
Add Encrypt helper
2 parents 80b1f3d + e324992 commit 7093998

File tree

5 files changed

+235
-26
lines changed

5 files changed

+235
-26
lines changed

jose/encrypt.go

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,28 @@ type PasswordPrompter func(s string) ([]byte, error)
1919
// the parse of the key will fail.
2020
var PromptPassword PasswordPrompter
2121

22-
// EncryptJWK returns the given JWK encrypted with the default encryption
22+
// Encrypt returns the given data encrypted with the default encryption
2323
// algorithm (PBES2-HS256+A128KW).
24-
func EncryptJWK(jwk *JSONWebKey, passphrase []byte) (*JSONWebEncryption, error) {
25-
b, err := json.Marshal(jwk)
24+
func Encrypt(data []byte, opts ...Option) (*JSONWebEncryption, error) {
25+
ctx, err := new(context).apply(opts...)
2626
if err != nil {
27-
return nil, errors.Wrap(err, "error marshaling JWK")
27+
return nil, err
28+
}
29+
30+
var passphrase []byte
31+
switch {
32+
case len(ctx.password) > 0:
33+
passphrase = ctx.password
34+
case ctx.passwordPrompter != nil:
35+
if passphrase, err = ctx.passwordPrompter(ctx.passwordPrompt); err != nil {
36+
return nil, err
37+
}
38+
case PromptPassword != nil:
39+
if passphrase, err = PromptPassword("Please enter the password to encrypt the data"); err != nil {
40+
return nil, err
41+
}
42+
default:
43+
return nil, errors.New("failed to encrypt the data: missing password")
2844
}
2945

3046
salt, err := randutil.Salt(PBKDF2SaltSize)
@@ -40,22 +56,35 @@ func EncryptJWK(jwk *JSONWebKey, passphrase []byte) (*JSONWebEncryption, error)
4056
PBES2Salt: salt,
4157
}
4258

43-
opts := new(EncrypterOptions)
44-
opts.WithContentType(ContentType("jwk+json"))
59+
encrypterOptions := new(EncrypterOptions)
60+
if ctx.contentType != "" {
61+
encrypterOptions.WithContentType(ContentType(ctx.contentType))
62+
}
4563

46-
encrypter, err := NewEncrypter(DefaultEncAlgorithm, recipient, opts)
64+
encrypter, err := NewEncrypter(DefaultEncAlgorithm, recipient, encrypterOptions)
4765
if err != nil {
4866
return nil, errors.Wrap(err, "error creating cipher")
4967
}
5068

51-
jwe, err := encrypter.Encrypt(b)
69+
jwe, err := encrypter.Encrypt(data)
5270
if err != nil {
5371
return nil, errors.Wrap(err, "error encrypting data")
5472
}
5573

5674
return jwe, nil
5775
}
5876

77+
// EncryptJWK returns the given JWK encrypted with the default encryption
78+
// algorithm (PBES2-HS256+A128KW).
79+
func EncryptJWK(jwk *JSONWebKey, passphrase []byte) (*JSONWebEncryption, error) {
80+
b, err := json.Marshal(jwk)
81+
if err != nil {
82+
return nil, errors.Wrap(err, "error marshaling JWK")
83+
}
84+
85+
return Encrypt(b, WithPassword(passphrase), WithContentType("jwk+json"))
86+
}
87+
5988
// Decrypt returns the decrypted version of the given data if it's encrypted,
6089
// it will return the raw data if it's not encrypted or the format is not
6190
// valid.
@@ -82,11 +111,16 @@ func Decrypt(data []byte, opts ...Option) ([]byte, error) {
82111
if ctx.passwordPrompter != nil || PromptPassword != nil {
83112
var pass []byte
84113
for i := 0; i < MaxDecryptTries; i++ {
85-
if ctx.passwordPrompter != nil {
114+
switch {
115+
case ctx.passwordPrompter != nil:
86116
if pass, err = ctx.passwordPrompter(ctx.passwordPrompt); err != nil {
87117
return nil, err
88118
}
89-
} else {
119+
case ctx.filename != "":
120+
if pass, err = PromptPassword("Please enter the password to decrypt " + ctx.filename); err != nil {
121+
return nil, err
122+
}
123+
default:
90124
if pass, err = PromptPassword("Please enter the password to decrypt the JWE"); err != nil {
91125
return nil, err
92126
}

jose/encrypt_test.go

Lines changed: 131 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,112 @@ func rsaEqual(priv *rsa.PrivateKey, x crypto.PrivateKey) bool {
114114
return true
115115
}
116116

117+
func TestEncrypt(t *testing.T) {
118+
jwk := fixJWK(mustGenerateJWK(t, "EC", "P-256", "ES256", "", "", 0))
119+
data, err := json.Marshal(jwk)
120+
if err != nil {
121+
t.Fatal(err)
122+
}
123+
124+
type args struct {
125+
data []byte
126+
opts []Option
127+
}
128+
tests := []struct {
129+
name string
130+
args args
131+
wantFn func(t *testing.T) *JSONWebEncryption
132+
wantErr bool
133+
}{
134+
{"ok", args{data, []Option{WithPassword([]byte("password")), WithContentType("jwk+json")}},
135+
func(t *testing.T) *JSONWebEncryption {
136+
reader := mustTeeReader(t)
137+
jwe := mustEncryptJWK(t, jwk, []byte("password"))
138+
rand.Reader = reader
139+
jose.RandReader = reader
140+
return jwe
141+
}, false},
142+
{"ok WithPasswordPrompter", args{data, []Option{
143+
WithContentType("jwk+json"),
144+
WithPasswordPrompter("Enter the password", func(s string) ([]byte, error) {
145+
return []byte("password"), nil
146+
})}},
147+
func(t *testing.T) *JSONWebEncryption {
148+
reader := mustTeeReader(t)
149+
jwe := mustEncryptJWK(t, jwk, []byte("password"))
150+
rand.Reader = reader
151+
jose.RandReader = reader
152+
return jwe
153+
}, false},
154+
{"ok with PromptPassword", args{data, []Option{WithContentType("jwk+json")}},
155+
func(t *testing.T) *JSONWebEncryption {
156+
tmp := PromptPassword
157+
t.Cleanup(func() { PromptPassword = tmp })
158+
PromptPassword = func(s string) ([]byte, error) {
159+
return []byte("password"), nil
160+
}
161+
reader := mustTeeReader(t)
162+
jwe := mustEncryptJWK(t, jwk, []byte("password"))
163+
rand.Reader = reader
164+
jose.RandReader = reader
165+
return jwe
166+
}, false},
167+
{"fail apply", args{data, []Option{WithPasswordFile("testdata/missing.txt")}},
168+
func(t *testing.T) *JSONWebEncryption {
169+
return nil
170+
}, true},
171+
{"fail WithPasswordPrompter", args{data, []Option{
172+
WithContentType("jwk+json"),
173+
WithPasswordPrompter("Enter the password", func(s string) ([]byte, error) {
174+
return nil, errors.New("test error")
175+
})}},
176+
func(t *testing.T) *JSONWebEncryption {
177+
return nil
178+
}, true},
179+
{"fail with PromptPassword", args{data, []Option{WithContentType("jwk+json")}},
180+
func(t *testing.T) *JSONWebEncryption {
181+
tmp := PromptPassword
182+
t.Cleanup(func() { PromptPassword = tmp })
183+
PromptPassword = func(s string) ([]byte, error) {
184+
return nil, errors.New("test error")
185+
}
186+
return nil
187+
}, true},
188+
{"fail no passowrd", args{data, nil},
189+
func(t *testing.T) *JSONWebEncryption {
190+
return nil
191+
}, true},
192+
{"fail encrypt", args{data, []Option{WithPassword([]byte("password"))}},
193+
func(t *testing.T) *JSONWebEncryption {
194+
reader := mustTeeReader(t)
195+
_, _ = randutil.Salt(PBKDF2SaltSize)
196+
rand.Reader = reader
197+
jose.RandReader = reader
198+
return nil
199+
}, true},
200+
{"fail salt", args{data, []Option{WithPassword([]byte("password"))}},
201+
func(t *testing.T) *JSONWebEncryption {
202+
reader := mustTeeReader(t)
203+
rand.Reader = reader
204+
jose.RandReader = reader
205+
return nil
206+
}, true},
207+
}
208+
for _, tt := range tests {
209+
t.Run(tt.name, func(t *testing.T) {
210+
want := tt.wantFn(t)
211+
got, err := Encrypt(tt.args.data, tt.args.opts...)
212+
if (err != nil) != tt.wantErr {
213+
t.Errorf("Encrypt() error = %v, wantErr %v", err, tt.wantErr)
214+
return
215+
}
216+
if !reflect.DeepEqual(got, want) {
217+
t.Errorf("Encrypt() = %v, want %v", got, want)
218+
}
219+
})
220+
}
221+
}
222+
117223
func TestEncryptJWK(t *testing.T) {
118224
jwk := fixJWK(mustGenerateJWK(t, "EC", "P-256", "ES256", "", "", 0))
119225

@@ -266,35 +372,47 @@ func TestDecrypt(t *testing.T) {
266372
want []byte
267373
wantErr bool
268374
}{
269-
{"okNotEncrypted", args{[]byte("foobar"), nil, nil}, []byte("foobar"), false},
270-
{"okWithPassword", args{encryptedData, []Option{WithPassword(testPassword)}, nil}, data, false},
271-
{"okWithPasswordFile", args{encryptedData, []Option{WithPasswordFile("testdata/passphrase.txt")}, nil}, data, false},
272-
{"okWithPasswordPrompter", args{encryptedData, []Option{WithPasswordPrompter("What's the password?", func(s string) ([]byte, error) {
375+
{"ok not encrypted", args{[]byte("foobar"), nil, nil}, []byte("foobar"), false},
376+
{"ok WithPassword", args{encryptedData, []Option{WithPassword(testPassword)}, nil}, data, false},
377+
{"ok WithPasswordFile", args{encryptedData, []Option{WithPasswordFile("testdata/passphrase.txt")}, nil}, data, false},
378+
{"ok WithPasswordPrompter", args{encryptedData, []Option{WithPasswordPrompter("What's the password?", func(s string) ([]byte, error) {
273379
return testPassword, nil
274380
})}, nil}, data, false},
275-
{"okGlobalPasswordPrompter", args{encryptedData, []Option{}, func(s string) ([]byte, error) {
381+
{"ok PasswordPrompter", args{encryptedData, []Option{}, func(s string) ([]byte, error) {
382+
return testPassword, nil
383+
}}, data, false},
384+
{"ok WithFilename and PasswordPrompter", args{encryptedData, []Option{WithFilename("test.jwk")}, func(s string) ([]byte, error) {
276385
return testPassword, nil
277386
}}, data, false},
278-
{"failBadData", args{badEncryptedData, []Option{WithPassword(testPassword)}, nil}, nil, true},
279-
{"failWithPassword", args{encryptedData, []Option{WithPassword([]byte("bad-password"))}, nil}, nil, true},
280-
{"failWithPasswordFile", args{encryptedData, []Option{WithPasswordFile("testdata/oct.txt")}, nil}, nil, true},
281-
{"failWithPasswordPrompter", args{encryptedData, []Option{WithPasswordPrompter("What's the password?", func(s string) ([]byte, error) {
387+
{"fail bad data", args{badEncryptedData, []Option{WithPassword(testPassword)}, nil}, nil, true},
388+
{"fail WithPassword", args{encryptedData, []Option{WithPassword([]byte("bad-password"))}, nil}, nil, true},
389+
{"fail WithPasswordFile", args{encryptedData, []Option{WithPasswordFile("testdata/oct.txt")}, nil}, nil, true},
390+
{"fail WithPasswordPrompter", args{encryptedData, []Option{WithPasswordPrompter("What's the password?", func(s string) ([]byte, error) {
282391
return []byte("bad-password"), nil
283392
})}, nil}, nil, true},
284-
{"failGlobalPasswordPrompter", args{encryptedData, []Option{}, func(s string) ([]byte, error) {
393+
{"fail PasswordPrompter", args{encryptedData, []Option{}, func(s string) ([]byte, error) {
285394
return []byte("bad-password"), nil
286395
}}, nil, true},
287-
{"failApplyWithPassword", args{encryptedData, []Option{WithPasswordFile("testdata/missing.txt")}, nil}, nil, true},
288-
{"failApplyWithPasswordPrompter", args{encryptedData, []Option{WithPasswordPrompter("What's the password?", func(s string) ([]byte, error) {
396+
{"fail apply WithPassword", args{encryptedData, []Option{WithPasswordFile("testdata/missing.txt")}, nil}, nil, true},
397+
{"fail apply WithPasswordPrompter", args{encryptedData, []Option{WithPasswordPrompter("What's the password?", func(s string) ([]byte, error) {
289398
return nil, errors.New("unexpected error")
290399
})}, nil}, nil, true},
291-
{"failGlobalPasswordPrompterError", args{encryptedData, []Option{}, func(s string) ([]byte, error) {
400+
{"fail PasswordPrompter", args{encryptedData, []Option{}, func(s string) ([]byte, error) {
401+
return nil, errors.New("unexpected error")
402+
}}, nil, true},
403+
{"fail WithFilename and PasswordPrompter", args{encryptedData, []Option{WithFilename("test.jwk")}, func(s string) ([]byte, error) {
292404
return nil, errors.New("unexpected error")
293405
}}, nil, true},
294406
}
295407
for _, tt := range tests {
296408
t.Run(tt.name, func(t *testing.T) {
409+
if tt.name == "okGlobalPasswordPrompter" {
410+
t.Log("foo")
411+
}
412+
tmp := PromptPassword
413+
t.Cleanup(func() { PromptPassword = tmp })
297414
PromptPassword = tt.args.passwordPrompter
415+
298416
got, err := Decrypt(tt.args.data, tt.args.opts...)
299417
if (err != nil) != tt.wantErr {
300418
t.Errorf("Decrypt() error = %v, wantErr %v", err, tt.wantErr)

jose/options.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ type context struct {
1212
password []byte
1313
passwordPrompt string
1414
passwordPrompter PasswordPrompter
15+
contentType string
1516
}
1617

1718
// apply the options to the context and returns an error if one of the options
@@ -22,9 +23,6 @@ func (ctx *context) apply(opts ...Option) (*context, error) {
2223
return nil, err
2324
}
2425
}
25-
if ctx.filename == "" {
26-
ctx.filename = "key"
27-
}
2826
return ctx, nil
2927
}
3028

@@ -117,3 +115,11 @@ func WithPasswordPrompter(prompt string, fn PasswordPrompter) Option {
117115
return nil
118116
}
119117
}
118+
119+
// WithContentType adds the content type when encrypting data.
120+
func WithContentType(cty string) Option {
121+
return func(ctx *context) error {
122+
ctx.contentType = cty
123+
return nil
124+
}
125+
}

jose/parse.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ func ParseKey(b []byte, opts ...Option) (*JSONWebKey, error) {
7373
if err != nil {
7474
return nil, err
7575
}
76+
if ctx.filename == "" {
77+
ctx.filename = "key"
78+
}
7679

7780
jwk := new(JSONWebKey)
7881
switch guessKeyType(ctx, b) {

jose/parse_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,54 @@ func TestParseKey(t *testing.T) {
371371
}
372372
}
373373

374+
func TestParseKeyPemutilPromptPassword(t *testing.T) {
375+
pemKey, err := pemutil.Read("../pemutil/testdata/pkcs8/openssl.ed25519.pem")
376+
assert.FatalError(t, err)
377+
378+
pemBytes, err := os.ReadFile("../pemutil/testdata/pkcs8/openssl.ed25519.enc.pem")
379+
assert.FatalError(t, err)
380+
381+
tmp0 := pemutil.PromptPassword
382+
tmp1 := PromptPassword
383+
t.Cleanup(func() {
384+
pemutil.PromptPassword = tmp0
385+
PromptPassword = tmp1
386+
})
387+
388+
tests := []struct {
389+
name string
390+
promptPassword PasswordPrompter
391+
want *JSONWebKey
392+
wantErr bool
393+
}{
394+
{"ok", func(s string) ([]byte, error) {
395+
return []byte("mypassword"), nil
396+
}, &JSONWebKey{
397+
Key: pemKey,
398+
KeyID: "vEk4UARa85PrW0eea2zeVLqGBF-n5Jzd9GVmKAc0AHQ",
399+
Algorithm: "EdDSA",
400+
}, false},
401+
{"fail", func(s string) ([]byte, error) {
402+
return []byte("not-mypassword"), nil
403+
}, nil, true},
404+
}
405+
406+
for _, tt := range tests {
407+
t.Run(tt.name, func(t *testing.T) {
408+
PromptPassword = tt.promptPassword
409+
pemutil.PromptPassword = nil
410+
got, err := ParseKey(pemBytes)
411+
if (err != nil) != tt.wantErr {
412+
t.Errorf("ParseKey() error = %v, wantErr %v", err, tt.wantErr)
413+
return
414+
}
415+
if !reflect.DeepEqual(got, tt.want) {
416+
t.Errorf("ParseKey() = %v, want %v", got, tt.want)
417+
}
418+
})
419+
}
420+
}
421+
374422
func TestReadKeySet(t *testing.T) {
375423
jwk, err := ReadKeySet("testdata/jwks.json", WithKid("qiCJG7r2L80rmWRrZMPfpanQHmZRcncOG7A7MBWn9qM"))
376424
assert.NoError(t, err)

0 commit comments

Comments
 (0)