Skip to content

Commit be0ec6b

Browse files
committed
target/smtp: Check-in accidentally reverted attempt_starttls changes
1 parent cff6cfa commit be0ec6b

File tree

5 files changed

+70
-116
lines changed

5 files changed

+70
-116
lines changed

framework/config/map.go

+15-6
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package config
2020

2121
import (
2222
"errors"
23+
"fmt"
2324
"reflect"
2425
"strconv"
2526
"strings"
@@ -305,6 +306,16 @@ func (m *Map) DataSize(name string, inheritGlobal, required bool, defaultVal int
305306
}, store)
306307
}
307308

309+
func ParseBool(s string) (bool, error) {
310+
switch strings.ToLower(s) {
311+
case "1", "true", "on", "yes":
312+
return true, nil
313+
case "0", "false", "off", "no":
314+
return false, nil
315+
}
316+
return false, fmt.Errorf("bool argument should be 'yes' or 'no'")
317+
}
318+
308319
// Bool maps presence of some configuration directive to a boolean variable.
309320
// Additionally, 'name yes' and 'name no' are mapped to true and false
310321
// correspondingly.
@@ -327,13 +338,11 @@ func (m *Map) Bool(name string, inheritGlobal, defaultVal bool, store *bool) {
327338
return nil, NodeErr(node, "expected exactly 1 argument")
328339
}
329340

330-
switch strings.ToLower(node.Args[0]) {
331-
case "1", "true", "on", "yes":
332-
return true, nil
333-
case "0", "false", "off", "no":
334-
return false, nil
341+
b, err := ParseBool(node.Args[0])
342+
if err != nil {
343+
return nil, NodeErr(node, "bool argument should be 'yes' or 'no'")
335344
}
336-
return nil, NodeErr(node, "bool argument should be 'yes' or 'no'")
345+
return b, nil
337346
}, store)
338347
}
339348

internal/smtpconn/smtpconn.go

+6-3
Original file line numberDiff line numberDiff line change
@@ -259,12 +259,15 @@ func (c *C) attemptConnect(ctx context.Context, lmtp bool, endp config.Endpoint,
259259
return false, nil, nil, err
260260
}
261261

262-
if endp.IsTLS() || !starttls {
263-
return endp.IsTLS(), cl, conn, nil
262+
if !starttls {
263+
return false, cl, conn, nil
264264
}
265265

266266
if ok, _ := cl.Extension("STARTTLS"); !ok {
267-
return false, cl, conn, nil
267+
if err := cl.Quit(); err != nil {
268+
cl.Close()
269+
}
270+
return false, nil, nil, fmt.Errorf("TLS required but unsupported by downstream")
268271
}
269272

270273
cfg := tlsConfig.Clone()

internal/target/smtp/smtp_downstream.go

+38-21
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ package smtp_downstream
2929
import (
3030
"context"
3131
"crypto/tls"
32-
"errors"
3332
"fmt"
3433
"net"
3534
"runtime/trace"
@@ -54,12 +53,11 @@ type Downstream struct {
5453
lmtp bool
5554
targetsArg []string
5655

57-
requireTLS bool
58-
attemptStartTLS bool
59-
hostname string
60-
endpoints []config.Endpoint
61-
saslFactory saslClientFactory
62-
tlsConfig tls.Config
56+
starttls bool
57+
hostname string
58+
endpoints []config.Endpoint
59+
saslFactory saslClientFactory
60+
tlsConfig tls.Config
6361

6462
connectTimeout time.Duration
6563
commandTimeout time.Duration
@@ -89,10 +87,34 @@ func NewDownstream(modName, instName string, _, inlineArgs []string) (module.Mod
8987
}
9088

9189
func (u *Downstream) Init(cfg *config.Map) error {
90+
var attemptTLS *bool
91+
9292
var targetsArg []string
9393
cfg.Bool("debug", true, false, &u.log.Debug)
94-
cfg.Bool("require_tls", false, false, &u.requireTLS)
95-
cfg.Bool("attempt_starttls", false, !u.lmtp, &u.attemptStartTLS)
94+
cfg.Callback("require_tls", func(m *config.Map, node config.Node) error {
95+
u.log.Msg("require_tls directive is deprecated and ignored")
96+
return nil
97+
})
98+
cfg.Callback("attempt_starttls", func(m *config.Map, node config.Node) error {
99+
u.log.Msg("attempt_starttls directive is deprecated and equivalent to starttls")
100+
101+
if len(node.Args) == 0 {
102+
trueVal := true
103+
attemptTLS = &trueVal
104+
return nil
105+
}
106+
if len(node.Args) != 1 {
107+
return config.NodeErr(node, "expected exactly 1 argument")
108+
}
109+
110+
b, err := config.ParseBool(node.Args[0])
111+
if err != nil {
112+
return err
113+
}
114+
attemptTLS = &b
115+
return nil
116+
})
117+
cfg.Bool("starttls", false, !u.lmtp, &u.starttls)
96118
cfg.String("hostname", true, true, "", &u.hostname)
97119
cfg.StringList("targets", false, false, nil, &targetsArg)
98120
cfg.Custom("auth", false, false, func() (interface{}, error) {
@@ -109,6 +131,10 @@ func (u *Downstream) Init(cfg *config.Map) error {
109131
return err
110132
}
111133

134+
if attemptTLS != nil {
135+
u.starttls = *attemptTLS
136+
}
137+
112138
// INTERNATIONALIZATION: See RFC 6531 Section 3.7.1.
113139
var err error
114140
u.hostname, err = idna.ToASCII(u.hostname)
@@ -201,14 +227,11 @@ func (d *delivery) connect(ctx context.Context) error {
201227
}
202228

203229
for _, endp := range d.u.endpoints {
204-
var (
205-
didTLS bool
206-
err error
207-
)
230+
var err error
208231
if d.u.lmtp {
209-
didTLS, err = conn.ConnectLMTP(ctx, endp, d.u.attemptStartTLS, &d.u.tlsConfig)
232+
_, err = conn.ConnectLMTP(ctx, endp, d.u.starttls, &d.u.tlsConfig)
210233
} else {
211-
didTLS, err = conn.Connect(ctx, endp, d.u.attemptStartTLS, &d.u.tlsConfig)
234+
_, err = conn.Connect(ctx, endp, d.u.starttls, &d.u.tlsConfig)
212235
}
213236
if err != nil {
214237
if len(d.u.endpoints) != 1 {
@@ -220,12 +243,6 @@ func (d *delivery) connect(ctx context.Context) error {
220243

221244
d.log.DebugMsg("connected", "downstream_server", conn.ServerName())
222245

223-
if !didTLS && d.u.requireTLS {
224-
conn.Close()
225-
lastErr = errors.New("TLS is required, but unsupported by downstream")
226-
continue
227-
}
228-
229246
lastErr = nil
230247
break
231248
}

internal/target/smtp/smtp_downstream_test.go

+7-86
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ func TestDownstreamDelivery_MAILErr(t *testing.T) {
207207
testutils.CheckSMTPErr(t, err, 550, exterrors.EnhancedCode{5, 1, 2}, "Hey")
208208
}
209209

210-
func TestDownstreamDelivery_AttemptTLS(t *testing.T) {
210+
func TestDownstreamDelivery_StartTLS(t *testing.T) {
211211
clientCfg, be, srv := testutils.SMTPServerSTARTTLS(t, "127.0.0.1:"+testPort)
212212
defer srv.Close()
213213
defer testutils.CheckSMTPConnLeak(t, srv)
@@ -221,9 +221,9 @@ func TestDownstreamDelivery_AttemptTLS(t *testing.T) {
221221
Port: testPort,
222222
},
223223
},
224-
tlsConfig: *clientCfg.Clone(),
225-
attemptStartTLS: true,
226-
log: testutils.Logger(t, "target.smtp"),
224+
tlsConfig: *clientCfg.Clone(),
225+
starttls: true,
226+
log: testutils.Logger(t, "target.smtp"),
227227
}
228228

229229
testutils.DoTestDelivery(t, mod, "[email protected]", []string{"[email protected]"})
@@ -235,85 +235,7 @@ func TestDownstreamDelivery_AttemptTLS(t *testing.T) {
235235
}
236236
}
237237

238-
func TestDownstreamDelivery_AttemptTLS_Fallback(t *testing.T) {
239-
be, srv := testutils.SMTPServer(t, "127.0.0.1:"+testPort)
240-
defer srv.Close()
241-
defer testutils.CheckSMTPConnLeak(t, srv)
242-
243-
mod := &Downstream{
244-
hostname: "mx.example.invalid",
245-
endpoints: []config.Endpoint{
246-
{
247-
Scheme: "tcp",
248-
Host: "127.0.0.1",
249-
Port: testPort,
250-
},
251-
},
252-
attemptStartTLS: true,
253-
log: testutils.Logger(t, "target.smtp"),
254-
}
255-
256-
testutils.DoTestDelivery(t, mod, "[email protected]", []string{"[email protected]"})
257-
be.CheckMsg(t, 0, "[email protected]", []string{"[email protected]"})
258-
}
259-
260-
func TestDownstreamDelivery_RequireTLS(t *testing.T) {
261-
clientCfg, be, srv := testutils.SMTPServerSTARTTLS(t, "127.0.0.1:"+testPort)
262-
defer srv.Close()
263-
defer testutils.CheckSMTPConnLeak(t, srv)
264-
265-
mod := &Downstream{
266-
hostname: "mx.example.invalid",
267-
endpoints: []config.Endpoint{
268-
{
269-
Scheme: "tcp",
270-
Host: "127.0.0.1",
271-
Port: testPort,
272-
},
273-
},
274-
tlsConfig: *clientCfg.Clone(),
275-
attemptStartTLS: true,
276-
requireTLS: true,
277-
log: testutils.Logger(t, "target.smtp"),
278-
}
279-
280-
testutils.DoTestDelivery(t, mod, "[email protected]", []string{"[email protected]"})
281-
be.CheckMsg(t, 0, "[email protected]", []string{"[email protected]"})
282-
tlsState, ok := be.Messages[0].Conn.TLSConnectionState()
283-
if !ok || !tlsState.HandshakeComplete {
284-
t.Fatal("Message was not delivered over TLS")
285-
}
286-
}
287-
288-
func TestDownstreamDelivery_RequireTLS_Implicit(t *testing.T) {
289-
clientCfg, be, srv := testutils.SMTPServerTLS(t, "127.0.0.1:"+testPort)
290-
defer srv.Close()
291-
defer testutils.CheckSMTPConnLeak(t, srv)
292-
293-
mod := &Downstream{
294-
hostname: "mx.example.invalid",
295-
endpoints: []config.Endpoint{
296-
{
297-
Scheme: "tls",
298-
Host: "127.0.0.1",
299-
Port: testPort,
300-
},
301-
},
302-
tlsConfig: *clientCfg.Clone(),
303-
attemptStartTLS: true,
304-
requireTLS: true,
305-
log: testutils.Logger(t, "target.smtp"),
306-
}
307-
308-
testutils.DoTestDelivery(t, mod, "[email protected]", []string{"[email protected]"})
309-
be.CheckMsg(t, 0, "[email protected]", []string{"[email protected]"})
310-
tlsState, ok := be.Messages[0].Conn.TLSConnectionState()
311-
if !ok || !tlsState.HandshakeComplete {
312-
t.Fatal("Message was not delivered over TLS")
313-
}
314-
}
315-
316-
func TestDownstreamDelivery_RequireTLS_Fail(t *testing.T) {
238+
func TestDownstreamDelivery_StartTLS_NoFallback(t *testing.T) {
317239
_, srv := testutils.SMTPServer(t, "127.0.0.1:"+testPort)
318240
defer srv.Close()
319241
defer testutils.CheckSMTPConnLeak(t, srv)
@@ -327,9 +249,8 @@ func TestDownstreamDelivery_RequireTLS_Fail(t *testing.T) {
327249
Port: testPort,
328250
},
329251
},
330-
attemptStartTLS: true,
331-
requireTLS: true,
332-
log: testutils.Logger(t, "target.smtp"),
252+
starttls: true,
253+
log: testutils.Logger(t, "target.smtp"),
333254
}
334255

335256
_, err := testutils.DoTestDeliveryErr(t, mod, "[email protected]", []string{"[email protected]"})

internal/target/smtp/smtputf8_test.go

+4
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ func TestDownstreamDelivery_EHLO_ALabel(t *testing.T) {
4040
Name: "hostname",
4141
Args: []string{"тест.invalid"},
4242
},
43+
{
44+
Name: "starttls",
45+
Args: []string{"no"},
46+
},
4347
},
4448
})); err != nil {
4549
t.Fatal(err)

0 commit comments

Comments
 (0)