Skip to content

Commit e3f3942

Browse files
committed
feat: add disable web root
1 parent f508958 commit e3f3942

6 files changed

Lines changed: 226 additions & 6 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ docker-compose up -d
141141
```bash
142142
LISTEN=:3000 # Server listen address
143143
ADMIN_KEY=your-admin-key # Admin API key
144+
DISABLE_WEB_ROOT=true # Redirect only `/` to GitHub, keep other web routes available
144145
```
145146

146147
#### **Database Configuration**

README.zh.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ docker-compose up -d
142142
```bash
143143
LISTEN=:3000 # 服务器监听地址
144144
ADMIN_KEY=your-admin-key # 管理员 API 密钥
145+
DISABLE_WEB_ROOT=true # 仅将 `/` 重定向到 GitHub,其他 Web 路径保持可访问
145146
```
146147

147148
#### **数据库配置**

core/common/config/env.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ var (
1414
AdminKey string
1515
WebPath string
1616
DisableWeb bool
17+
DisableWebRoot bool
1718
FfmpegEnabled bool
1819
InternalToken string
1920
DisableModelConfig bool
@@ -34,6 +35,7 @@ func ReloadEnv() {
3435
AdminKey = os.Getenv("ADMIN_KEY")
3536
WebPath = os.Getenv("WEB_PATH")
3637
DisableWeb = env.Bool("DISABLE_WEB", false)
38+
DisableWebRoot = env.Bool("DISABLE_WEB_ROOT", false)
3739
FfmpegEnabled = env.Bool("FFMPEG_ENABLED", false)
3840
InternalToken = os.Getenv("INTERNAL_TOKEN")
3941
DisableModelConfig = env.Bool("DISABLE_MODEL_CONFIG", false)

core/relay/utils/utils_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ import (
44
"bytes"
55
"context"
66
"encoding/json"
7+
"errors"
78
"io"
89
"net/http"
910
"net/http/httptest"
11+
"net/url"
1012
"testing"
1113
"time"
1214

@@ -74,6 +76,43 @@ func TestDoRequest(t *testing.T) {
7476
})
7577
}
7678

79+
func TestDoRequestResponseHeaderTimeout(t *testing.T) {
80+
convey.Convey("DoRequest should timeout while awaiting response headers", t, func() {
81+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
82+
time.Sleep(1500 * time.Millisecond)
83+
w.WriteHeader(http.StatusOK)
84+
_, _ = w.Write([]byte("ok"))
85+
}))
86+
defer ts.Close()
87+
88+
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, ts.URL, nil)
89+
start := time.Now()
90+
91+
resp, err := utils.DoRequest(req, time.Second)
92+
if resp != nil {
93+
defer resp.Body.Close()
94+
}
95+
96+
elapsed := time.Since(start)
97+
98+
convey.So(resp, convey.ShouldBeNil)
99+
convey.So(err, convey.ShouldNotBeNil)
100+
101+
var urlErr *url.Error
102+
convey.So(errors.As(err, &urlErr), convey.ShouldBeTrue)
103+
convey.So(errors.Is(err, context.DeadlineExceeded), convey.ShouldBeTrue)
104+
convey.So(urlErr.Timeout(), convey.ShouldBeTrue)
105+
convey.So(
106+
urlErr.Err.Error(),
107+
convey.ShouldEqual,
108+
"net/http: timeout awaiting response headers",
109+
)
110+
111+
convey.So(elapsed >= time.Second, convey.ShouldBeTrue)
112+
convey.So(elapsed < 1400*time.Millisecond, convey.ShouldBeTrue)
113+
})
114+
}
115+
77116
func TestLoadHTTPClientReuse(t *testing.T) {
78117
convey.Convey("LoadHTTPClient reuse", t, func() {
79118
client1, err := utils.LoadHTTPClientE(time.Second, "")

core/router/static.go

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ import (
1616
"github.com/sirupsen/logrus"
1717
)
1818

19+
const (
20+
githubProjectURL = "https://github.com/labring/aiproxy"
21+
githubProjectInitialCountdown = 15
22+
)
23+
1924
func SetStaticFileRouter(router *gin.Engine) {
2025
router.SetHTMLTemplate(
2126
template.Must(
@@ -24,12 +29,7 @@ func SetStaticFileRouter(router *gin.Engine) {
2429
)
2530

2631
if config.DisableWeb {
27-
router.GET("/", func(ctx *gin.Context) {
28-
ctx.HTML(http.StatusOK, "index.tmpl", gin.H{
29-
"URL": "https://github.com/labring/aiproxy",
30-
"INITIAL_COUNTDOWN": 15,
31-
})
32-
})
32+
router.GET("/", renderWebRootRedirectPage)
3333

3434
return
3535
}
@@ -45,6 +45,8 @@ func SetStaticFileRouter(router *gin.Engine) {
4545
panic(err)
4646
}
4747

48+
registerWebRootRedirect(router)
49+
4850
fs := http.FS(public.Public)
4951
router.NoRoute(newIndexNoRouteHandler(fs))
5052
} else {
@@ -65,10 +67,27 @@ func SetStaticFileRouter(router *gin.Engine) {
6567
panic(err)
6668
}
6769

70+
registerWebRootRedirect(router)
6871
router.NoRoute(newDynamicNoRouteHandler(http.Dir(absPath)))
6972
}
7073
}
7174

75+
func registerWebRootRedirect(router *gin.Engine) {
76+
if !config.DisableWebRoot {
77+
return
78+
}
79+
80+
router.GET("/", renderWebRootRedirectPage)
81+
router.HEAD("/", renderWebRootRedirectPage)
82+
}
83+
84+
func renderWebRootRedirectPage(ctx *gin.Context) {
85+
ctx.HTML(http.StatusOK, "index.tmpl", gin.H{
86+
"URL": githubProjectURL,
87+
"INITIAL_COUNTDOWN": githubProjectInitialCountdown,
88+
})
89+
}
90+
7291
func checkNoRouteNotFound(req *http.Request) bool {
7392
if req.Method != http.MethodGet &&
7493
req.Method != http.MethodHead {

core/router/static_test.go

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
package router_test
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"net/http/httptest"
7+
"os"
8+
"path/filepath"
9+
"testing"
10+
11+
"github.com/gin-gonic/gin"
12+
"github.com/labring/aiproxy/core/common/config"
13+
corerouter "github.com/labring/aiproxy/core/router"
14+
"github.com/smartystreets/goconvey/convey"
15+
)
16+
17+
const testGitHubProjectURL = "https://github.com/labring/aiproxy"
18+
19+
func TestSetStaticFileRouter_DisableWebRoot(t *testing.T) {
20+
convey.Convey("SetStaticFileRouter with DISABLE_WEB_ROOT", t, func() {
21+
webPath := writeTestWebFiles(t)
22+
router := newTestStaticRouter(t, webPath, false, true)
23+
24+
convey.Convey("should redirect root path to github", func() {
25+
recorder := httptest.NewRecorder()
26+
req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, "/", nil)
27+
28+
router.ServeHTTP(recorder, req)
29+
30+
convey.So(recorder.Code, convey.ShouldEqual, http.StatusOK)
31+
convey.So(recorder.Body.String(), convey.ShouldContainSubstring, testGitHubProjectURL)
32+
convey.So(
33+
recorder.Body.String(),
34+
convey.ShouldContainSubstring,
35+
"id=\"countdown\">15</div>",
36+
)
37+
})
38+
39+
convey.Convey("should keep static assets accessible", func() {
40+
recorder := httptest.NewRecorder()
41+
req := httptest.NewRequestWithContext(
42+
context.Background(),
43+
http.MethodGet,
44+
"/assets/app.js",
45+
nil,
46+
)
47+
48+
router.ServeHTTP(recorder, req)
49+
50+
convey.So(recorder.Code, convey.ShouldEqual, http.StatusOK)
51+
convey.So(recorder.Body.String(), convey.ShouldContainSubstring, "console.log('ok');")
52+
})
53+
54+
convey.Convey("should keep SPA fallback accessible", func() {
55+
recorder := httptest.NewRecorder()
56+
req := httptest.NewRequestWithContext(
57+
context.Background(),
58+
http.MethodGet,
59+
"/dashboard",
60+
nil,
61+
)
62+
63+
router.ServeHTTP(recorder, req)
64+
65+
convey.So(recorder.Code, convey.ShouldEqual, http.StatusOK)
66+
convey.So(recorder.Body.String(), convey.ShouldContainSubstring, "test-spa")
67+
})
68+
})
69+
}
70+
71+
func TestSetStaticFileRouter_DisableWeb(t *testing.T) {
72+
convey.Convey("SetStaticFileRouter with DISABLE_WEB", t, func() {
73+
webPath := writeTestWebFiles(t)
74+
router := newTestStaticRouter(t, webPath, true, false)
75+
76+
convey.Convey("should serve the github redirect page on root", func() {
77+
recorder := httptest.NewRecorder()
78+
req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, "/", nil)
79+
80+
router.ServeHTTP(recorder, req)
81+
82+
convey.So(recorder.Code, convey.ShouldEqual, http.StatusOK)
83+
convey.So(recorder.Body.String(), convey.ShouldContainSubstring, testGitHubProjectURL)
84+
})
85+
86+
convey.Convey("should not expose static assets", func() {
87+
recorder := httptest.NewRecorder()
88+
req := httptest.NewRequestWithContext(
89+
context.Background(),
90+
http.MethodGet,
91+
"/assets/app.js",
92+
nil,
93+
)
94+
95+
router.ServeHTTP(recorder, req)
96+
97+
convey.So(recorder.Code, convey.ShouldEqual, http.StatusNotFound)
98+
})
99+
})
100+
}
101+
102+
func newTestStaticRouter(
103+
t *testing.T,
104+
webPath string,
105+
disableWeb, disableWebRoot bool,
106+
) *gin.Engine {
107+
t.Helper()
108+
109+
gin.SetMode(gin.TestMode)
110+
111+
oldWebPath := config.WebPath
112+
oldDisableWeb := config.DisableWeb
113+
oldDisableWebRoot := config.DisableWebRoot
114+
115+
t.Cleanup(func() {
116+
config.WebPath = oldWebPath
117+
config.DisableWeb = oldDisableWeb
118+
config.DisableWebRoot = oldDisableWebRoot
119+
})
120+
121+
config.WebPath = webPath
122+
config.DisableWeb = disableWeb
123+
config.DisableWebRoot = disableWebRoot
124+
125+
router := gin.New()
126+
corerouter.SetStaticFileRouter(router)
127+
128+
return router
129+
}
130+
131+
func writeTestWebFiles(t *testing.T) string {
132+
t.Helper()
133+
134+
webPath := t.TempDir()
135+
assetsPath := filepath.Join(webPath, "assets")
136+
137+
if err := os.MkdirAll(assetsPath, 0o755); err != nil {
138+
t.Fatalf("mkdir assets: %v", err)
139+
}
140+
141+
if err := os.WriteFile(
142+
filepath.Join(webPath, "index.html"),
143+
[]byte("<!doctype html><html><body>test-spa</body></html>"),
144+
0o600,
145+
); err != nil {
146+
t.Fatalf("write index.html: %v", err)
147+
}
148+
149+
if err := os.WriteFile(
150+
filepath.Join(assetsPath, "app.js"),
151+
[]byte("console.log('ok');"),
152+
0o600,
153+
); err != nil {
154+
t.Fatalf("write app.js: %v", err)
155+
}
156+
157+
return webPath
158+
}

0 commit comments

Comments
 (0)