mirror of
https://github.com/decke/smtprelay.git
synced 2025-12-25 07:43:06 -07:00
Refactor parsing of -listen string out into separate config function
This makes the "for each listen address" loop in main() look even cleaner.
This commit is contained in:
32
config.go
32
config.go
@@ -3,8 +3,9 @@ package main
|
||||
import (
|
||||
"flag"
|
||||
"net"
|
||||
"regexp"
|
||||
"net/smtp"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/vharitonsky/iniflags"
|
||||
"github.com/sirupsen/logrus"
|
||||
@@ -21,7 +22,8 @@ var (
|
||||
logLevel = flag.String("log_level", "info", "Minimum log level to output")
|
||||
hostName = flag.String("hostname", "localhost.localdomain", "Server hostname")
|
||||
welcomeMsg = flag.String("welcome_msg", "", "Welcome message for SMTP session")
|
||||
listen = flag.String("listen", "127.0.0.1:25 [::1]:25", "Address and port to listen for incoming SMTP")
|
||||
listenStr = flag.String("listen", "127.0.0.1:25 [::1]:25", "Address and port to listen for incoming SMTP")
|
||||
listenAddrs = []protoAddr{}
|
||||
localCert = flag.String("local_cert", "", "SSL certificate for STARTTLS/TLS")
|
||||
localKey = flag.String("local_key", "", "SSL private key for STARTTLS/TLS")
|
||||
localForceTLS = flag.Bool("local_forcetls", false, "Force STARTTLS (needs local_cert and local_key)")
|
||||
@@ -130,6 +132,31 @@ func setupRemoteAuth() {
|
||||
}
|
||||
}
|
||||
|
||||
type protoAddr struct {
|
||||
protocol string
|
||||
address string
|
||||
}
|
||||
|
||||
func splitProto(s string) protoAddr {
|
||||
idx := strings.Index(s, "://")
|
||||
if idx == -1 {
|
||||
return protoAddr {
|
||||
address: s,
|
||||
}
|
||||
}
|
||||
return protoAddr {
|
||||
protocol: s[0 : idx],
|
||||
address: s[idx+3 : len(s)],
|
||||
}
|
||||
}
|
||||
|
||||
func setupListeners() {
|
||||
for _, listenAddr := range strings.Split(*listenStr, " ") {
|
||||
pa := splitProto(listenAddr)
|
||||
listenAddrs = append(listenAddrs, pa)
|
||||
}
|
||||
}
|
||||
|
||||
func ConfigLoad() {
|
||||
iniflags.Parse()
|
||||
|
||||
@@ -143,4 +170,5 @@ func ConfigLoad() {
|
||||
setupAllowedNetworks()
|
||||
setupAllowedPatterns()
|
||||
setupRemoteAuth()
|
||||
setupListeners()
|
||||
}
|
||||
|
||||
44
config_test.go
Normal file
44
config_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSplitProto(t *testing.T) {
|
||||
var tests = []struct {
|
||||
input string
|
||||
proto string
|
||||
addr string
|
||||
}{
|
||||
{
|
||||
input: "localhost",
|
||||
proto: "",
|
||||
addr: "localhost",
|
||||
},
|
||||
{
|
||||
input: "tls://my.local.domain",
|
||||
proto: "tls",
|
||||
addr: "my.local.domain",
|
||||
},
|
||||
{
|
||||
input: "starttls://my.local.domain",
|
||||
proto: "starttls",
|
||||
addr: "my.local.domain",
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
testName := test.input
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
pa := splitProto(test.input)
|
||||
if pa.protocol != test.proto {
|
||||
t.Errorf("Testcase %d: Incorrect proto: expected %v, got %v",
|
||||
i, test.proto, pa.protocol)
|
||||
}
|
||||
if pa.address != test.addr {
|
||||
t.Errorf("Testcase %d: Incorrect addr: expected %v, got %v",
|
||||
i, test.addr, pa.address)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
38
main.go
38
main.go
@@ -288,7 +288,9 @@ func main() {
|
||||
var servers []*smtpd.Server
|
||||
|
||||
// Create a server for each desired listen address
|
||||
for _, listenAddr := range strings.Split(*listen, " ") {
|
||||
for _, listen := range listenAddrs {
|
||||
logger := log.WithField("address", listen.address)
|
||||
|
||||
server := &smtpd.Server{
|
||||
Hostname: *hostName,
|
||||
WelcomeMessage: *welcomeMsg,
|
||||
@@ -305,37 +307,31 @@ func main() {
|
||||
var lsnr net.Listener
|
||||
var err error
|
||||
|
||||
if strings.Index(listenAddr, "://") == -1 {
|
||||
log.WithField("address", listenAddr).
|
||||
Info("listening on address")
|
||||
|
||||
lsnr, err = net.Listen("tcp", listenAddr)
|
||||
} else if strings.HasPrefix(listenAddr, "starttls://") {
|
||||
listenAddr = strings.TrimPrefix(listenAddr, "starttls://")
|
||||
switch listen.protocol {
|
||||
case "":
|
||||
logger.Info("listening on address")
|
||||
lsnr, err = net.Listen("tcp", listen.address)
|
||||
|
||||
case "starttls":
|
||||
server.TLSConfig = getTLSConfig()
|
||||
server.ForceTLS = *localForceTLS
|
||||
|
||||
log.WithField("address", listenAddr).
|
||||
Info("listening on address (STARTTLS)")
|
||||
lsnr, err = net.Listen("tcp", listenAddr)
|
||||
} else if strings.HasPrefix(listenAddr, "tls://") {
|
||||
listenAddr = strings.TrimPrefix(listenAddr, "tls://")
|
||||
logger.Info("listening on address (STARTTLS)")
|
||||
lsnr, err = net.Listen("tcp", listen.address)
|
||||
|
||||
case "tls":
|
||||
server.TLSConfig = getTLSConfig()
|
||||
|
||||
log.WithField("address", listenAddr).
|
||||
Info("listening on address (TLS)")
|
||||
lsnr, err = tls.Listen("tcp", listenAddr, server.TLSConfig)
|
||||
} else {
|
||||
log.WithField("address", listenAddr).
|
||||
logger.Info("listening on address (TLS)")
|
||||
lsnr, err = tls.Listen("tcp", listen.address, server.TLSConfig)
|
||||
|
||||
default:
|
||||
logger.WithField("protocol", listen.protocol).
|
||||
Fatal("unknown protocol in listen address")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.WithFields(logrus.Fields{
|
||||
"address": listenAddr,
|
||||
}).WithError(err).Fatal("error starting listener")
|
||||
logger.WithError(err).Fatal("error starting listener")
|
||||
}
|
||||
servers = append(servers, server)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user