Skip to content

Commit 14948b2

Browse files
committed
fix: prevent ParseString overflow and stabilize fuzz round-trip
1 parent 5ce9163 commit 14948b2

File tree

13 files changed

+510
-90
lines changed

13 files changed

+510
-90
lines changed

.github/workflows/ci.yml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,12 @@ jobs:
1717
go-version: "1.22"
1818

1919
- run: go test ./... -count=1
20-
- run: go vet ./...
20+
- run: go vet ./...
21+
- run: go test ./... -count=1 -race
22+
- run: go test ./... -count=1 -cover
23+
- run: go test -bench=. -benchmem ./... # opsiyonel (PR'da biraz yavaşlatır)
24+
25+
- name: Fuzz (short)
26+
run: go test ./... -fuzz=FuzzParseString -fuzztime=10s
27+
28+

alloc_property_test.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package money_test
2+
3+
import (
4+
"math/rand"
5+
"testing"
6+
"time"
7+
8+
"github.com/dahaiyiyimcom/money"
9+
)
10+
11+
func TestAllocateProportional_Property_SumExact(t *testing.T) {
12+
r := rand.New(rand.NewSource(time.Now().UnixNano()))
13+
14+
for iter := 0; iter < 5000; iter++ {
15+
n := r.Intn(15) // 0..14
16+
bases := make([]money.Amount, n)
17+
18+
var total int64
19+
for i := 0; i < n; i++ {
20+
// Keep bases mostly positive; include some zeros
21+
v := int64(r.Intn(50000)) // up to 500.00
22+
bases[i] = money.NewMinor(v)
23+
total += v
24+
}
25+
26+
// Choose discount in [0, total] (if total==0 => 0)
27+
var d int64
28+
if total > 0 {
29+
d = int64(r.Intn(int(total + 1)))
30+
}
31+
discount := money.NewMinor(d)
32+
33+
out := money.AllocateProportional(bases, discount)
34+
35+
// Invariant 1: output length matches input length
36+
if len(out) != len(bases) {
37+
t.Fatalf("len(out)=%d len(bases)=%d", len(out), len(bases))
38+
}
39+
40+
// Invariant 2: sum(out) == discount
41+
var sum int64
42+
for _, x := range out {
43+
sum += x.Minor()
44+
}
45+
if sum != discount.Minor() {
46+
t.Fatalf("sum(out)=%d discount=%d bases=%v out=%v", sum, discount.Minor(), bases, out)
47+
}
48+
49+
// Optional invariant 3: if discount==0 -> all zeros
50+
if discount.Minor() == 0 {
51+
for i, x := range out {
52+
if x.Minor() != 0 {
53+
t.Fatalf("expected all zeros when discount=0; i=%d got=%d", i, x.Minor())
54+
}
55+
}
56+
}
57+
}
58+
}

alloc_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,21 @@ func TestAllocateProportional_EmptyOrZero(t *testing.T) {
7171
t.Fatalf("expected [0], got=%v", got)
7272
}
7373
}
74+
75+
func TestAllocateProportional_WithNegativeBase(t *testing.T) {
76+
bases := []money.Amount{
77+
money.NewMinor(-1000),
78+
money.NewMinor(2000),
79+
}
80+
discount := money.NewMinor(100)
81+
82+
out := money.AllocateProportional(bases, discount)
83+
84+
var sum int64
85+
for _, v := range out {
86+
sum += v.Minor()
87+
}
88+
if sum != 100 {
89+
t.Fatalf("sum got=%d want=100", sum)
90+
}
91+
}

amount_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,15 @@ func TestAmount_Arithmetic(t *testing.T) {
4343
t.Fatalf("MulQty got=%d want=3702", got)
4444
}
4545
}
46+
47+
func TestAmount_IsNegative(t *testing.T) {
48+
if money.NewMinor(0).IsNegative() {
49+
t.Fatalf("0 should not be negative")
50+
}
51+
if money.NewMinor(1).IsNegative() {
52+
t.Fatalf("positive should not be negative")
53+
}
54+
if !money.NewMinor(-1).IsNegative() {
55+
t.Fatalf("negative should be negative")
56+
}
57+
}

bench_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package money_test
2+
3+
import (
4+
"encoding/json"
5+
"testing"
6+
7+
"github.com/dahaiyiyimcom/money"
8+
)
9+
10+
func BenchmarkParseString(b *testing.B) {
11+
in := "419.29"
12+
b.ReportAllocs()
13+
b.ResetTimer()
14+
15+
for i := 0; i < b.N; i++ {
16+
_, _ = money.ParseString(in)
17+
}
18+
}
19+
20+
func BenchmarkStringFixed2(b *testing.B) {
21+
a := money.NewMinor(41929)
22+
b.ReportAllocs()
23+
b.ResetTimer()
24+
25+
for i := 0; i < b.N; i++ {
26+
_ = a.StringFixed2()
27+
}
28+
}
29+
30+
func BenchmarkJSONMarshalAmount(b *testing.B) {
31+
type DTO struct {
32+
Price money.Amount `json:"price"`
33+
}
34+
dto := DTO{Price: money.NewMinor(41929)}
35+
36+
b.ReportAllocs()
37+
b.ResetTimer()
38+
39+
for i := 0; i < b.N; i++ {
40+
_, _ = json.Marshal(dto)
41+
}
42+
}
43+
44+
func BenchmarkAllocateProportional(b *testing.B) {
45+
bases := make([]money.Amount, 50)
46+
for i := 0; i < len(bases); i++ {
47+
bases[i] = money.NewMinor(int64(1000 + i*3))
48+
}
49+
discount := money.NewMinor(1234)
50+
51+
b.ReportAllocs()
52+
b.ResetTimer()
53+
54+
for i := 0; i < b.N; i++ {
55+
_ = money.AllocateProportional(bases, discount)
56+
}
57+
}
58+
59+
func BenchmarkMulRatio(b *testing.B) {
60+
a := money.NewMinor(1234567) // 12,345.67
61+
b.ReportAllocs()
62+
b.ResetTimer()
63+
64+
for i := 0; i < b.N; i++ {
65+
_ = a.MulRatio(18, 100, money.RoundHalfUp)
66+
}
67+
}

db.go

Lines changed: 39 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,10 @@ package money
22

33
import (
44
"database/sql/driver"
5-
"errors"
65
"fmt"
7-
"strings"
6+
"strconv"
87
)
98

10-
// DBAmount is a DB-facing wrapper for DECIMAL(10,2) columns.
11-
// Use it in structs scanned from sql rows.
129
type DBAmount struct {
1310
A Amount
1411
}
@@ -19,100 +16,53 @@ func (m *DBAmount) Scan(value any) error {
1916
return nil
2017
}
2118

22-
var s string
2319
switch v := value.(type) {
2420
case []byte:
25-
s = string(v)
26-
case string:
27-
s = v
28-
default:
29-
return fmt.Errorf("money: unsupported scan type %T", value)
30-
}
31-
32-
a, err := ParseString(s)
33-
if err != nil {
34-
return err
35-
}
36-
m.A = a
37-
return nil
38-
}
39-
40-
func (m DBAmount) Value() (driver.Value, error) {
41-
// write back as "12.34" for DECIMAL(10,2)
42-
return m.A.StringFixed2(), nil
43-
}
21+
a, err := ParseString(string(v))
22+
if err != nil {
23+
return err
24+
}
25+
m.A = a
26+
return nil
4427

45-
// ParseString parses decimal string with optional sign, 2 decimals max.
46-
// Accepts: "12.34", "12", "-0.50", " 12.30 "
47-
func ParseString(s string) (Amount, error) {
48-
s = strings.TrimSpace(s)
49-
if s == "" {
50-
return 0, errors.New("money: empty string")
51-
}
28+
case string:
29+
a, err := ParseString(v)
30+
if err != nil {
31+
return err
32+
}
33+
m.A = a
34+
return nil
5235

53-
sign := int64(1)
54-
if s[0] == '-' {
55-
sign = -1
56-
s = s[1:]
57-
} else if s[0] == '+' {
58-
s = s[1:]
59-
}
36+
case float64:
37+
// Force 2 decimals, then parse strictly.
38+
s := strconv.FormatFloat(v, 'f', 2, 64)
39+
a, err := ParseString(s)
40+
if err != nil {
41+
return err
42+
}
43+
m.A = a
44+
return nil
6045

61-
parts := strings.Split(s, ".")
62-
if len(parts) > 2 {
63-
return 0, fmt.Errorf("money: invalid format: %q", s)
64-
}
46+
case float32:
47+
s := strconv.FormatFloat(float64(v), 'f', 2, 64)
48+
a, err := ParseString(s)
49+
if err != nil {
50+
return err
51+
}
52+
m.A = a
53+
return nil
6554

66-
wholeStr := parts[0]
67-
if wholeStr == "" {
68-
wholeStr = "0"
69-
}
70-
var fracStr string
71-
if len(parts) == 2 {
72-
fracStr = parts[1]
73-
}
55+
case int64:
56+
// If driver returns integer, assume it's already minor? (ambiguous)
57+
// Better to treat as major units without decimals is dangerous.
58+
// If you WANT to support this, define it clearly. For safety, reject:
59+
return fmt.Errorf("money: unsupported scan int64=%d (ambiguous units)", v)
7460

75-
// normalize fraction to 2 digits
76-
switch len(fracStr) {
77-
case 0:
78-
fracStr = "00"
79-
case 1:
80-
fracStr = fracStr + "0"
81-
case 2:
82-
// ok
8361
default:
84-
// If more than 2 decimals exist, you can either reject or round.
85-
// For strict DECIMAL(10,2), reject is safest.
86-
return 0, fmt.Errorf("money: too many decimal places: %q", s)
87-
}
88-
89-
whole, err := parseUint(wholeStr)
90-
if err != nil {
91-
return 0, fmt.Errorf("money: invalid whole part: %w", err)
92-
}
93-
frac, err := parseUint(fracStr)
94-
if err != nil {
95-
return 0, fmt.Errorf("money: invalid fractional part: %w", err)
96-
}
97-
if frac > 99 {
98-
return 0, fmt.Errorf("money: invalid fractional part: %q", fracStr)
62+
return fmt.Errorf("money: unsupported scan type %T", value)
9963
}
100-
101-
minor := int64(whole)*100 + int64(frac)
102-
return Amount(sign * minor), nil
10364
}
10465

105-
func parseUint(s string) (uint64, error) {
106-
if s == "" {
107-
return 0, errors.New("empty")
108-
}
109-
var n uint64
110-
for i := 0; i < len(s); i++ {
111-
c := s[i]
112-
if c < '0' || c > '9' {
113-
return 0, fmt.Errorf("non-digit %q", c)
114-
}
115-
n = n*10 + uint64(c-'0')
116-
}
117-
return n, nil
66+
func (m DBAmount) Value() (driver.Value, error) {
67+
return m.A.StringFixed2(), nil
11868
}

0 commit comments

Comments
 (0)