Skip to content

Commit 1eebed5

Browse files
committed
add: 新增 InjectValue
1 parent d33e064 commit 1eebed5

4 files changed

Lines changed: 359 additions & 7 deletions

File tree

bootstrap/providers/app_init.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ type Bean interface {
3838
GetBean(alias string) interface{}
3939
}
4040

41+
// InjectValue dest 为字段地址(toolset 生成 &field);config 实现在 ConfigProvider.InjectValue。
4142
type InjectValue interface {
42-
InjectValue(alias string, value interface{})
43+
InjectValue(alias string, dest interface{})
4344
}

bootstrap/providers/config_provider.go

Lines changed: 232 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,21 @@ package providers
22

33
import (
44
"embed"
5+
"encoding/json"
56
"flag"
6-
"github.com/go-home-admin/home/bootstrap/services"
7-
"github.com/go-home-admin/home/bootstrap/utils"
8-
"github.com/joho/godotenv"
9-
log "github.com/sirupsen/logrus"
10-
"gopkg.in/yaml.v2"
117
"io/fs"
128
"os"
139
"path"
1410
"path/filepath"
11+
"reflect"
12+
"strconv"
1513
"strings"
14+
15+
"github.com/go-home-admin/home/bootstrap/services"
16+
"github.com/go-home-admin/home/bootstrap/utils"
17+
"github.com/joho/godotenv"
18+
log "github.com/sirupsen/logrus"
19+
"gopkg.in/yaml.v2"
1620
)
1721

1822
var envPath string
@@ -172,6 +176,9 @@ func (c *ConfigProvider) GetBean(alias string) interface{} {
172176
case int:
173177
got := v.(int)
174178
return &got
179+
case int64:
180+
got := v.(int64)
181+
return &got
175182
case uint:
176183
got := v.(uint)
177184
return &got
@@ -228,3 +235,223 @@ func (c *ConfigProvider) GetRoot() string {
228235
ROOT = parDir
229236
return parDir
230237
}
238+
239+
// InjectValue 将配置写入 dest。dest 必须为字段地址(toolset 生成 &field),与 GetBean 返回值配合用反射赋值。
240+
func (c *ConfigProvider) InjectValue(alias string, dest interface{}) {
241+
c.writeInjectDest(alias, dest, c.GetBean(alias))
242+
}
243+
244+
func (c *ConfigProvider) writeInjectDest(configKey string, dest, src interface{}) {
245+
if dest == nil {
246+
c.panicInjectAssign(configKey, "dest 为 nil")
247+
}
248+
destVal := reflect.ValueOf(dest)
249+
if destVal.Kind() != reflect.Ptr || destVal.IsNil() {
250+
c.panicInjectAssign(configKey, "dest 必须为非 nil 的字段地址(生成代码请使用 &field)")
251+
}
252+
fieldVal := destVal.Elem()
253+
if !fieldVal.CanSet() {
254+
c.panicInjectAssign(configKey, "dest 指向的字段不可赋值(生成代码请使用 &field)")
255+
}
256+
257+
if src != nil && c.tryInjectUnmarshaler(configKey, dest, fieldVal, src) {
258+
return
259+
}
260+
261+
if src == nil {
262+
if fieldVal.Kind() == reflect.Ptr {
263+
fieldVal.Set(reflect.Zero(fieldVal.Type()))
264+
return
265+
}
266+
c.panicInjectAssign(configKey, "GetBean 返回 nil 且目标字段非指针类型")
267+
}
268+
269+
srcVal := reflect.ValueOf(src)
270+
271+
// GetBean 返回 *T,写入值类型字段 T(如 bool、int)
272+
if fieldVal.Kind() != reflect.Ptr && srcVal.Kind() == reflect.Ptr {
273+
if srcVal.IsNil() {
274+
fieldVal.Set(reflect.Zero(fieldVal.Type()))
275+
return
276+
}
277+
if configSetAssignable(fieldVal, srcVal.Elem()) {
278+
return
279+
}
280+
c.panicInjectAssign(configKey, "GetBean 指针与值类型字段不匹配")
281+
}
282+
283+
// 目标为 *T:只写入元素值(复制),不把 GetBean 返回的指针地址赋给字段
284+
if fieldVal.Kind() == reflect.Ptr {
285+
if srcVal.Kind() == reflect.Ptr {
286+
if srcVal.IsNil() {
287+
fieldVal.Set(reflect.Zero(fieldVal.Type()))
288+
return
289+
}
290+
srcVal = srcVal.Elem()
291+
}
292+
elemTyp := fieldVal.Type().Elem()
293+
if fieldVal.IsNil() {
294+
newPtr := reflect.New(elemTyp)
295+
if configSetAssignable(newPtr.Elem(), srcVal) {
296+
fieldVal.Set(newPtr)
297+
return
298+
}
299+
} else if configSetAssignable(fieldVal.Elem(), srcVal) {
300+
return
301+
}
302+
c.panicInjectAssign(configKey, "GetBean 与指针字段元素类型不匹配")
303+
}
304+
305+
if configSetAssignable(fieldVal, srcVal) {
306+
return
307+
}
308+
c.panicInjectAssign(configKey, "GetBean 值与字段类型不匹配")
309+
}
310+
311+
// tryInjectUnmarshaler 当注入目标实现 json.Unmarshaler 且配置值为 JSON 文本(string / *string / []byte)时直接 UnmarshalJSON。
312+
func (c *ConfigProvider) tryInjectUnmarshaler(configKey string, dest interface{}, fieldVal reflect.Value, src interface{}) bool {
313+
u := configJSONUnmarshalerTarget(dest, fieldVal)
314+
if u == nil {
315+
return false
316+
}
317+
data, ok := configJSONPayload(src)
318+
if !ok {
319+
return false
320+
}
321+
if err := u.UnmarshalJSON(data); err != nil {
322+
c.panicInjectAssign(configKey, "UnmarshalJSON: "+err.Error())
323+
}
324+
return true
325+
}
326+
327+
func configJSONUnmarshalerTarget(dest interface{}, fieldVal reflect.Value) json.Unmarshaler {
328+
if u, ok := dest.(json.Unmarshaler); ok {
329+
return u
330+
}
331+
if fieldVal.Kind() == reflect.Ptr {
332+
if fieldVal.IsNil() {
333+
fieldVal.Set(reflect.New(fieldVal.Type().Elem()))
334+
}
335+
if u, ok := fieldVal.Interface().(json.Unmarshaler); ok {
336+
return u
337+
}
338+
}
339+
if fieldVal.CanAddr() {
340+
if u, ok := fieldVal.Addr().Interface().(json.Unmarshaler); ok {
341+
return u
342+
}
343+
}
344+
return nil
345+
}
346+
347+
func configJSONPayload(src interface{}) ([]byte, bool) {
348+
if src == nil {
349+
return nil, false
350+
}
351+
switch v := src.(type) {
352+
case string:
353+
return []byte(v), true
354+
case *string:
355+
if v == nil {
356+
return nil, false
357+
}
358+
return []byte(*v), true
359+
case []byte:
360+
return v, true
361+
case *[]byte:
362+
if v == nil {
363+
return nil, false
364+
}
365+
return *v, true
366+
}
367+
sv := reflect.ValueOf(src)
368+
if sv.Kind() == reflect.Ptr {
369+
if sv.IsNil() {
370+
return nil, false
371+
}
372+
return configJSONPayload(sv.Elem().Interface())
373+
}
374+
if sv.Kind() == reflect.String {
375+
return []byte(sv.String()), true
376+
}
377+
return nil, false
378+
}
379+
380+
func configSetAssignable(field, src reflect.Value) bool {
381+
if src.Kind() == reflect.Ptr {
382+
if src.IsNil() {
383+
return false
384+
}
385+
src = src.Elem()
386+
}
387+
if !src.IsValid() {
388+
return false
389+
}
390+
if configTryCoerce(field, src) {
391+
return true
392+
}
393+
if src.Type().AssignableTo(field.Type()) {
394+
field.Set(src)
395+
return true
396+
}
397+
// 勿对 string 使用 reflect.Convert(如 int→string 会得到乱码)
398+
if field.Kind() != reflect.String && src.Type().ConvertibleTo(field.Type()) {
399+
field.Set(src.Convert(field.Type()))
400+
return true
401+
}
402+
return false
403+
}
404+
405+
// configTryCoerce 处理 yaml/env 常见类型与注入字段不一致(如 app.port 为 int 8080,字段为 string / *string)。
406+
func configTryCoerce(field, src reflect.Value) bool {
407+
if !src.IsValid() {
408+
return false
409+
}
410+
if field.Kind() == reflect.String {
411+
return configScalarToString(field, src)
412+
}
413+
if field.Kind() == reflect.Ptr && field.Type().Elem().Kind() == reflect.String {
414+
s := configFormatScalarString(src)
415+
if s == "" && src.Kind() != reflect.String && src.Kind() != reflect.Bool {
416+
return false
417+
}
418+
str := s
419+
field.Set(reflect.ValueOf(&str))
420+
return true
421+
}
422+
return false
423+
}
424+
425+
func configScalarToString(field, src reflect.Value) bool {
426+
s := configFormatScalarString(src)
427+
if s == "" && src.Kind() != reflect.String && src.Kind() != reflect.Bool {
428+
return false
429+
}
430+
field.SetString(s)
431+
return true
432+
}
433+
434+
func configFormatScalarString(src reflect.Value) string {
435+
switch src.Kind() {
436+
case reflect.String:
437+
return src.String()
438+
case reflect.Bool:
439+
return strconv.FormatBool(src.Bool())
440+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
441+
return strconv.FormatInt(src.Int(), 10)
442+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
443+
return strconv.FormatUint(src.Uint(), 10)
444+
case reflect.Float32, reflect.Float64:
445+
return strconv.FormatFloat(src.Float(), 'f', -1, 64)
446+
default:
447+
return ""
448+
}
449+
}
450+
451+
func (c *ConfigProvider) panicInjectAssign(configKey, detail string) {
452+
keyPart := "空键"
453+
if configKey != "" {
454+
keyPart = "键 " + strconv.Quote(configKey)
455+
}
456+
panic("注入 " + keyPart + " 失败:无法写入字段(" + detail + ")")
457+
}
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
package providers
2+
3+
import (
4+
"encoding/json"
5+
"reflect"
6+
"testing"
7+
)
8+
9+
type injectJSONSettings struct {
10+
Host string
11+
}
12+
13+
func (s *injectJSONSettings) UnmarshalJSON(b []byte) error {
14+
type alias injectJSONSettings
15+
return json.Unmarshal(b, (*alias)(s))
16+
}
17+
18+
func TestConfigProvider_writeInjectDest_pointerFieldFromStarInt(t *testing.T) {
19+
c := &ConfigProvider{}
20+
var port *int
21+
v := 8080
22+
src := &v
23+
c.writeInjectDest("app.port", &port, src)
24+
if port == nil || *port != 8080 {
25+
t.Fatalf("got %v", port)
26+
}
27+
if port == src {
28+
t.Fatal("field must hold copied value, not GetBean pointer address")
29+
}
30+
}
31+
32+
func TestConfigProvider_writeInjectDest_pointerFieldFromInt64Star(t *testing.T) {
33+
c := &ConfigProvider{}
34+
var port *int
35+
var v int64 = 9090
36+
c.writeInjectDest("app.port", &port, &v)
37+
if port == nil || *port != 9090 {
38+
t.Fatalf("got %v", port)
39+
}
40+
}
41+
42+
func TestConfigProvider_writeInjectDest_pointerFieldNilSrc(t *testing.T) {
43+
c := &ConfigProvider{}
44+
v := 1
45+
port := &v
46+
c.writeInjectDest("app.port", &port, (*int)(nil))
47+
if port != nil {
48+
t.Fatal("expected nil *int")
49+
}
50+
}
51+
52+
func TestConfigProvider_writeInjectDest_structPointerField(t *testing.T) {
53+
c := &ConfigProvider{}
54+
type holder struct {
55+
P *int
56+
}
57+
h := holder{}
58+
v := 3000
59+
c.writeInjectDest("app.port", &h.P, &v)
60+
if h.P == nil || *h.P != 3000 {
61+
t.Fatalf("got %v", h.P)
62+
}
63+
}
64+
65+
func TestConfigProvider_writeInjectDest_stringFieldFromInt(t *testing.T) {
66+
c := &ConfigProvider{}
67+
var port string
68+
v := 8080
69+
c.writeInjectDest("app.port", &port, &v)
70+
if port != "8080" {
71+
t.Fatalf("got %q", port)
72+
}
73+
}
74+
75+
func TestConfigProvider_writeInjectDest_starStringFromInt(t *testing.T) {
76+
c := &ConfigProvider{}
77+
var port *string
78+
v := 8080
79+
c.writeInjectDest("app.port", &port, &v)
80+
if port == nil || *port != "8080" {
81+
t.Fatalf("got %v", port)
82+
}
83+
}
84+
85+
func TestConfigProvider_writeInjectDest_jsonUnmarshaler(t *testing.T) {
86+
c := &ConfigProvider{}
87+
var cfg injectJSONSettings
88+
raw := `{"Host":"127.0.0.1"}`
89+
c.writeInjectDest("app.settings", &cfg, &raw)
90+
if cfg.Host != "127.0.0.1" {
91+
t.Fatalf("got %+v", cfg)
92+
}
93+
}
94+
95+
func TestConfigProvider_writeInjectDest_jsonUnmarshalerPointerField(t *testing.T) {
96+
c := &ConfigProvider{}
97+
var cfg *injectJSONSettings
98+
raw := `{"Host":"localhost"}`
99+
c.writeInjectDest("app.settings", &cfg, raw)
100+
if cfg == nil || cfg.Host != "localhost" {
101+
t.Fatalf("got %+v", cfg)
102+
}
103+
}
104+
105+
func TestConfigProvider_writeInjectDest_boolValueField(t *testing.T) {
106+
c := &ConfigProvider{}
107+
var debug bool
108+
src := true
109+
c.writeInjectDest("app.debug", &debug, &src)
110+
if !debug {
111+
t.Fatal("expected true")
112+
}
113+
}
114+
115+
func TestConfigProvider_InjectValue_fieldAddressAssignable(t *testing.T) {
116+
type holder struct {
117+
P *int
118+
}
119+
h := holder{}
120+
rv := reflect.ValueOf(&h.P)
121+
if !rv.Elem().CanSet() {
122+
t.Fatal("&h.P should be assignable")
123+
}
124+
}

0 commit comments

Comments
 (0)