|
4 | 4 | "bufio" |
5 | 5 | "bytes" |
6 | 6 | "encoding/json" |
| 7 | + "fmt" |
7 | 8 | "io" |
8 | 9 | "mime/multipart" |
9 | 10 | "net/http" |
@@ -619,6 +620,111 @@ func TestGuestUpload(t *testing.T) { |
619 | 620 | } |
620 | 621 | } |
621 | 622 |
|
| 623 | +func TestGuestUploadAcceptHeader(t *testing.T) { |
| 624 | + authenticator, err := shared_secret.New("dummypass") |
| 625 | + if err != nil { |
| 626 | + t.Fatalf("failed to create shared secret: %v", err) |
| 627 | + } |
| 628 | + |
| 629 | + for _, tt := range []struct { |
| 630 | + explanation string |
| 631 | + acceptHeader string |
| 632 | + expectJSON bool |
| 633 | + expectedContentType string |
| 634 | + }{ |
| 635 | + { |
| 636 | + "no Accept header returns plain text URL", |
| 637 | + "", |
| 638 | + false, |
| 639 | + "text/plain", |
| 640 | + }, |
| 641 | + { |
| 642 | + "Accept header with wildcard returns plain text URL", |
| 643 | + "*/*", |
| 644 | + false, |
| 645 | + "text/plain", |
| 646 | + }, |
| 647 | + { |
| 648 | + "Accept header with application/json returns JSON", |
| 649 | + "application/json", |
| 650 | + true, |
| 651 | + "application/json", |
| 652 | + }, |
| 653 | + { |
| 654 | + "Accept header with text/html returns plain text URL", |
| 655 | + "text/html", |
| 656 | + false, |
| 657 | + "text/plain", |
| 658 | + }, |
| 659 | + } { |
| 660 | + t.Run(fmt.Sprintf("%s [%s]", tt.explanation, tt.acceptHeader), func(t *testing.T) { |
| 661 | + dataStore := test_sqlite.New() |
| 662 | + guestLink := picoshare.GuestLink{ |
| 663 | + ID: picoshare.GuestLinkID("abcdefgh23456789"), |
| 664 | + Created: mustParseTime("2022-05-26T00:00:00Z"), |
| 665 | + UrlExpires: mustParseExpirationTime("2030-01-02T03:04:25Z"), |
| 666 | + MaxFileLifetime: picoshare.FileLifetimeInfinite, |
| 667 | + } |
| 668 | + if err := dataStore.InsertGuestLink(guestLink); err != nil { |
| 669 | + t.Fatalf("failed to insert dummy guest link: %v", err) |
| 670 | + } |
| 671 | + |
| 672 | + c := mockClock{mustParseTime("2024-01-01T00:00:00Z")} |
| 673 | + s := handlers.New(authenticator, &dataStore, nilSpaceChecker, nilGarbageCollector, c) |
| 674 | + |
| 675 | + filename := "dummyimage.png" |
| 676 | + contents := "dummy bytes" |
| 677 | + formData, contentType := createMultipartFormBody(filename, "", strings.NewReader(contents)) |
| 678 | + |
| 679 | + req, err := http.NewRequest("POST", "/api/guest/abcdefgh23456789", formData) |
| 680 | + if err != nil { |
| 681 | + t.Fatal(err) |
| 682 | + } |
| 683 | + req.Header.Add("Content-Type", contentType) |
| 684 | + if tt.acceptHeader != "" { |
| 685 | + req.Header.Add("Accept", tt.acceptHeader) |
| 686 | + } |
| 687 | + |
| 688 | + rec := httptest.NewRecorder() |
| 689 | + s.Router().ServeHTTP(rec, req) |
| 690 | + res := rec.Result() |
| 691 | + |
| 692 | + if got, want := res.StatusCode, http.StatusOK; got != want { |
| 693 | + t.Fatalf("status=%d, want=%d", got, want) |
| 694 | + } |
| 695 | + |
| 696 | + if got, want := res.Header.Get("Content-Type"), tt.expectedContentType; got != want { |
| 697 | + t.Errorf("Content-Type=%v, want=%v", got, want) |
| 698 | + } |
| 699 | + |
| 700 | + body, err := io.ReadAll(res.Body) |
| 701 | + if err != nil { |
| 702 | + t.Fatalf("failed to read response body") |
| 703 | + } |
| 704 | + |
| 705 | + if tt.expectJSON { |
| 706 | + var response handlers.EntryPostResponse |
| 707 | + err = json.Unmarshal(body, &response) |
| 708 | + if err != nil { |
| 709 | + t.Fatalf("response is not valid JSON: %v", string(body)) |
| 710 | + } |
| 711 | + if got, want := len(response.ID), 10; got != want { |
| 712 | + t.Errorf("ID length=%d, want=%d", got, want) |
| 713 | + } |
| 714 | + } else { |
| 715 | + // Should be plain text URL. |
| 716 | + bodyStr := string(body) |
| 717 | + if !strings.Contains(bodyStr, "http") { |
| 718 | + t.Errorf("expected URL in response, got: %v", bodyStr) |
| 719 | + } |
| 720 | + if !strings.HasSuffix(bodyStr, "\r\n") { |
| 721 | + t.Errorf("expected response to end with \\r\\n, got: %v", bodyStr) |
| 722 | + } |
| 723 | + } |
| 724 | + }) |
| 725 | + } |
| 726 | +} |
| 727 | + |
622 | 728 | func createMultipartFormBody(filename, note string, r io.Reader) (io.Reader, string) { |
623 | 729 | var b bytes.Buffer |
624 | 730 | bw := bufio.NewWriter(&b) |
|
0 commit comments