Skip to content

Commit f308cae

Browse files
committed
feat(api): add Session.DoJSON for authenticated custom API requests
Add DoJSON method to Session that executes authenticated JSON requests against the Proton API using the session's UID, access token, and cookie jar. Uses net/http directly to avoid coupling to go-proton-api internals. Add APIError type for structured Proton API error responses containing HTTP status code, API error code, and message. Rename share info command to share show. Assisted-by: Kiro <noreply@kiro.dev>
1 parent 47290cb commit f308cae

File tree

4 files changed

+314
-4
lines changed

4 files changed

+314
-4
lines changed

api/dojson_test.go

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
package api
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"io"
7+
"net/http"
8+
"net/http/cookiejar"
9+
"net/http/httptest"
10+
"testing"
11+
12+
"github.com/ProtonMail/go-proton-api"
13+
)
14+
15+
// testSession creates a minimal Session pointing at the given test server.
16+
// It overrides proton.DefaultHostURL for the duration of the test.
17+
func testSession(t *testing.T, serverURL string) *Session {
18+
t.Helper()
19+
jar, _ := cookiejar.New(nil)
20+
return &Session{
21+
Auth: proton.Auth{
22+
UID: "test-uid-123",
23+
AccessToken: "test-token-abc",
24+
},
25+
cookieJar: jar,
26+
}
27+
}
28+
29+
func TestDoJSON_SuccessGet(t *testing.T) {
30+
type payload struct {
31+
Name string `json:"Name"`
32+
ID int `json:"ID"`
33+
}
34+
35+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
36+
if r.Method != http.MethodGet {
37+
t.Fatalf("expected GET, got %s", r.Method)
38+
}
39+
if r.Header.Get("x-pm-uid") != "test-uid-123" {
40+
t.Fatalf("missing x-pm-uid header")
41+
}
42+
if r.Header.Get("Authorization") != "Bearer test-token-abc" {
43+
t.Fatalf("missing Authorization header")
44+
}
45+
w.Header().Set("Content-Type", "application/json")
46+
json.NewEncoder(w).Encode(map[string]any{
47+
"Code": 1000,
48+
"Name": "test-share",
49+
"ID": 42,
50+
})
51+
}))
52+
defer srv.Close()
53+
54+
s := testSession(t, srv.URL)
55+
var result payload
56+
err := s.DoJSON(context.Background(), "GET", srv.URL+"/test", nil, &result)
57+
if err != nil {
58+
t.Fatalf("DoJSON GET: %v", err)
59+
}
60+
if result.Name != "test-share" || result.ID != 42 {
61+
t.Fatalf("unexpected result: %+v", result)
62+
}
63+
}
64+
65+
func TestDoJSON_SuccessPost(t *testing.T) {
66+
type reqBody struct {
67+
Email string `json:"Email"`
68+
}
69+
70+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
71+
if r.Method != http.MethodPost {
72+
t.Fatalf("expected POST, got %s", r.Method)
73+
}
74+
if r.Header.Get("Content-Type") != "application/json" {
75+
t.Fatalf("missing Content-Type header")
76+
}
77+
body, _ := io.ReadAll(r.Body)
78+
var req reqBody
79+
if err := json.Unmarshal(body, &req); err != nil {
80+
t.Fatalf("unmarshal request body: %v", err)
81+
}
82+
if req.Email != "user@example.com" {
83+
t.Fatalf("unexpected email: %s", req.Email)
84+
}
85+
w.Header().Set("Content-Type", "application/json")
86+
json.NewEncoder(w).Encode(map[string]any{"Code": 1000})
87+
}))
88+
defer srv.Close()
89+
90+
s := testSession(t, srv.URL)
91+
err := s.DoJSON(context.Background(), "POST", srv.URL+"/invite", reqBody{Email: "user@example.com"}, nil)
92+
if err != nil {
93+
t.Fatalf("DoJSON POST: %v", err)
94+
}
95+
}
96+
97+
func TestDoJSON_APIError(t *testing.T) {
98+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
99+
w.WriteHeader(http.StatusUnprocessableEntity)
100+
json.NewEncoder(w).Encode(map[string]any{
101+
"Code": 2011,
102+
"Error": "Share not found",
103+
})
104+
}))
105+
defer srv.Close()
106+
107+
s := testSession(t, srv.URL)
108+
err := s.DoJSON(context.Background(), "GET", srv.URL+"/test", nil, nil)
109+
if err == nil {
110+
t.Fatal("expected error, got nil")
111+
}
112+
113+
apiErr, ok := err.(*APIError)
114+
if !ok {
115+
t.Fatalf("expected *APIError, got %T: %v", err, err)
116+
}
117+
if apiErr.Status != http.StatusUnprocessableEntity {
118+
t.Fatalf("expected status 422, got %d", apiErr.Status)
119+
}
120+
if apiErr.Code != 2011 {
121+
t.Fatalf("expected code 2011, got %d", apiErr.Code)
122+
}
123+
if apiErr.Message != "Share not found" {
124+
t.Fatalf("expected message 'Share not found', got %q", apiErr.Message)
125+
}
126+
}
127+
128+
func TestDoJSON_AuthHeaders(t *testing.T) {
129+
var gotUID, gotAuth string
130+
131+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
132+
gotUID = r.Header.Get("x-pm-uid")
133+
gotAuth = r.Header.Get("Authorization")
134+
json.NewEncoder(w).Encode(map[string]any{"Code": 1000})
135+
}))
136+
defer srv.Close()
137+
138+
s := testSession(t, srv.URL)
139+
_ = s.DoJSON(context.Background(), "GET", srv.URL+"/test", nil, nil)
140+
141+
if gotUID != "test-uid-123" {
142+
t.Fatalf("x-pm-uid = %q, want %q", gotUID, "test-uid-123")
143+
}
144+
if gotAuth != "Bearer test-token-abc" {
145+
t.Fatalf("Authorization = %q, want %q", gotAuth, "Bearer test-token-abc")
146+
}
147+
}
148+
149+
func TestDoJSON_CookiesAttached(t *testing.T) {
150+
// First request sets a cookie, second request should send it back.
151+
var gotCookie string
152+
153+
callCount := 0
154+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
155+
callCount++
156+
if callCount == 1 {
157+
http.SetCookie(w, &http.Cookie{Name: "Session-Id", Value: "abc123", Path: "/"})
158+
} else {
159+
c, err := r.Cookie("Session-Id")
160+
if err != nil {
161+
gotCookie = ""
162+
} else {
163+
gotCookie = c.Value
164+
}
165+
}
166+
json.NewEncoder(w).Encode(map[string]any{"Code": 1000})
167+
}))
168+
defer srv.Close()
169+
170+
s := testSession(t, srv.URL)
171+
172+
// First call — server sets cookie.
173+
_ = s.DoJSON(context.Background(), "GET", srv.URL+"/first", nil, nil)
174+
// Second call — cookie should be sent.
175+
_ = s.DoJSON(context.Background(), "GET", srv.URL+"/second", nil, nil)
176+
177+
if gotCookie != "abc123" {
178+
t.Fatalf("cookie not attached on second request: got %q", gotCookie)
179+
}
180+
}
181+
182+
func TestDoJSON_NilBody(t *testing.T) {
183+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
184+
if r.Body != nil {
185+
body, _ := io.ReadAll(r.Body)
186+
if len(body) > 0 {
187+
t.Fatalf("expected empty body for GET, got %d bytes", len(body))
188+
}
189+
}
190+
if r.Header.Get("Content-Type") != "" {
191+
t.Fatalf("Content-Type should not be set for nil body, got %q", r.Header.Get("Content-Type"))
192+
}
193+
json.NewEncoder(w).Encode(map[string]any{"Code": 1000})
194+
}))
195+
defer srv.Close()
196+
197+
s := testSession(t, srv.URL)
198+
err := s.DoJSON(context.Background(), "GET", srv.URL+"/test", nil, nil)
199+
if err != nil {
200+
t.Fatalf("DoJSON: %v", err)
201+
}
202+
}
203+
204+
func TestDoJSON_Delete(t *testing.T) {
205+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
206+
if r.Method != http.MethodDelete {
207+
t.Fatalf("expected DELETE, got %s", r.Method)
208+
}
209+
json.NewEncoder(w).Encode(map[string]any{"Code": 1000})
210+
}))
211+
defer srv.Close()
212+
213+
s := testSession(t, srv.URL)
214+
err := s.DoJSON(context.Background(), "DELETE", srv.URL+"/member/123", nil, nil)
215+
if err != nil {
216+
t.Fatalf("DoJSON DELETE: %v", err)
217+
}
218+
}

api/errors.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package api
22

33
import (
44
"errors"
5+
"fmt"
56
)
67

78
var (
@@ -16,3 +17,17 @@ var (
1617
// ErrNotLoggedIn indicates that no active session exists.
1718
ErrNotLoggedIn = errors.New("not logged in")
1819
)
20+
21+
// APIError represents a non-success response from the Proton API.
22+
type APIError struct {
23+
Status int // HTTP status code
24+
Code int // Proton API error code
25+
Message string // error description from the API
26+
}
27+
28+
func (e *APIError) Error() string {
29+
if e.Message != "" {
30+
return fmt.Sprintf("api: %d/%d: %s", e.Status, e.Code, e.Message)
31+
}
32+
return fmt.Sprintf("api: %d/%d", e.Status, e.Code)
33+
}

api/session.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
package api
22

33
import (
4+
"bytes"
45
"context"
6+
"encoding/json"
57
"errors"
68
"fmt"
9+
"io"
710
"log/slog"
811
"net/http"
912
"net/http/cookiejar"
1013
"net/url"
14+
"strings"
1115
"sync"
1216
"time"
1317

@@ -224,6 +228,79 @@ func (s *Session) Stop() {
224228
s.manager.Close()
225229
}
226230

231+
// apiEnvelope is the standard Proton API response wrapper.
232+
type apiEnvelope struct {
233+
Code int `json:"Code"`
234+
Error string `json:"Error,omitempty"`
235+
}
236+
237+
// DoJSON executes an authenticated JSON API request against the Proton API.
238+
// Method is "GET", "POST", "DELETE", etc. Path is relative to the API base
239+
// (e.g. "/drive/shares/{id}/members"). If body is non-nil it is JSON-encoded
240+
// as the request body. If result is non-nil the response body is JSON-decoded
241+
// into it. Returns an *APIError on non-success API responses.
242+
func (s *Session) DoJSON(ctx context.Context, method, path string, body, result any) error {
243+
reqURL := path
244+
if !strings.HasPrefix(path, "http") {
245+
reqURL = proton.DefaultHostURL + path
246+
}
247+
248+
var bodyReader io.Reader
249+
if body != nil {
250+
data, err := json.Marshal(body)
251+
if err != nil {
252+
return fmt.Errorf("doJSON: marshal body: %w", err)
253+
}
254+
bodyReader = bytes.NewReader(data)
255+
}
256+
257+
req, err := http.NewRequestWithContext(ctx, method, reqURL, bodyReader)
258+
if err != nil {
259+
return fmt.Errorf("doJSON: new request: %w", err)
260+
}
261+
262+
req.Header.Set("x-pm-uid", s.Auth.UID)
263+
req.Header.Set("Authorization", "Bearer "+s.Auth.AccessToken)
264+
if body != nil {
265+
req.Header.Set("Content-Type", "application/json")
266+
}
267+
req.Header.Set("Accept", "application/json")
268+
269+
httpClient := &http.Client{Jar: s.cookieJar}
270+
resp, err := httpClient.Do(req)
271+
if err != nil {
272+
return fmt.Errorf("doJSON: %s %s: %w", method, path, err)
273+
}
274+
defer resp.Body.Close()
275+
276+
respBody, err := io.ReadAll(resp.Body)
277+
if err != nil {
278+
return fmt.Errorf("doJSON: read response: %w", err)
279+
}
280+
281+
// Parse the envelope to check the API-level error code.
282+
var envelope apiEnvelope
283+
if err := json.Unmarshal(respBody, &envelope); err != nil {
284+
return fmt.Errorf("doJSON: unmarshal envelope: %w", err)
285+
}
286+
287+
if envelope.Code != 1000 {
288+
return &APIError{
289+
Status: resp.StatusCode,
290+
Code: envelope.Code,
291+
Message: envelope.Error,
292+
}
293+
}
294+
295+
if result != nil {
296+
if err := json.Unmarshal(respBody, result); err != nil {
297+
return fmt.Errorf("doJSON: unmarshal result: %w", err)
298+
}
299+
}
300+
301+
return nil
302+
}
303+
227304
// SessionRestore loads credentials from the store and creates an unlocked
228305
// session. Returns ErrNotLoggedIn if no session is stored.
229306
func SessionRestore(ctx context.Context, options []proton.Option, store SessionStore, managerHook func(*proton.Manager)) (*Session, error) {

cmd/share/share.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ func notImplemented(name string) func(*cobra.Command, []string) error {
2424
}
2525
}
2626

27-
var shareInfoCmd = &cobra.Command{
28-
Use: "info <share-name>",
27+
var shareShowCmd = &cobra.Command{
28+
Use: "show <share-name>",
2929
Short: "Show detailed share information",
3030
Long: "Show detailed information about a share including members and invitations",
3131
Args: cobra.ExactArgs(1),
32-
RunE: notImplemented("share info"),
32+
RunE: notImplemented("share show"),
3333
}
3434

3535
var shareInviteCmd = &cobra.Command{
@@ -50,7 +50,7 @@ var shareRevokeCmd = &cobra.Command{
5050

5151
func init() {
5252
cli.AddCommand(shareCmd)
53-
shareCmd.AddCommand(shareInfoCmd)
53+
shareCmd.AddCommand(shareShowCmd)
5454
shareCmd.AddCommand(shareInviteCmd)
5555
shareCmd.AddCommand(shareRevokeCmd)
5656
}

0 commit comments

Comments
 (0)