diff --git a/config.go b/config.go index bdcb54e..155069a 100644 --- a/config.go +++ b/config.go @@ -3,8 +3,9 @@ package main import ( "flag" "net" - "regexp" "net/smtp" + "regexp" + "strings" "github.com/vharitonsky/iniflags" "github.com/sirupsen/logrus" @@ -21,7 +22,8 @@ var ( logLevel = flag.String("log_level", "info", "Minimum log level to output") hostName = flag.String("hostname", "localhost.localdomain", "Server hostname") welcomeMsg = flag.String("welcome_msg", "", "Welcome message for SMTP session") - listen = flag.String("listen", "127.0.0.1:25 [::1]:25", "Address and port to listen for incoming SMTP") + listenStr = flag.String("listen", "127.0.0.1:25 [::1]:25", "Address and port to listen for incoming SMTP") + listenAddrs = []protoAddr{} localCert = flag.String("local_cert", "", "SSL certificate for STARTTLS/TLS") localKey = flag.String("local_key", "", "SSL private key for STARTTLS/TLS") localForceTLS = flag.Bool("local_forcetls", false, "Force STARTTLS (needs local_cert and local_key)") @@ -41,6 +43,10 @@ var ( versionInfo = flag.Bool("version", false, "Show version information") ) +func localAuthRequired() bool { + return *allowedUsers != "" +} + func setupAllowedNetworks() { for _, netstr := range splitstr(*allowedNetsStr, ' ') { @@ -130,6 +136,39 @@ func setupRemoteAuth() { } } +type protoAddr struct { + protocol string + address string +} + +func splitProto(s string) protoAddr { + idx := strings.Index(s, "://") + if idx == -1 { + return protoAddr { + address: s, + } + } + return protoAddr { + protocol: s[0 : idx], + address: s[idx+3 : len(s)], + } +} + +func setupListeners() { + for _, listenAddr := range strings.Split(*listenStr, " ") { + pa := splitProto(listenAddr) + + if localAuthRequired() && pa.protocol == "" { + log.WithField("address", pa.address). + Fatal("Local authentication (via allowed_users file) " + + "not allowed with non-TLS listener") + } + + + listenAddrs = append(listenAddrs, pa) + } +} + func ConfigLoad() { iniflags.Parse() @@ -143,4 +182,5 @@ func ConfigLoad() { setupAllowedNetworks() setupAllowedPatterns() setupRemoteAuth() + setupListeners() } diff --git a/config_test.go b/config_test.go new file mode 100644 index 0000000..602192b --- /dev/null +++ b/config_test.go @@ -0,0 +1,44 @@ +package main + +import ( + "testing" +) + +func TestSplitProto(t *testing.T) { + var tests = []struct { + input string + proto string + addr string + }{ + { + input: "localhost", + proto: "", + addr: "localhost", + }, + { + input: "tls://my.local.domain", + proto: "tls", + addr: "my.local.domain", + }, + { + input: "starttls://my.local.domain", + proto: "starttls", + addr: "my.local.domain", + }, + } + + for i, test := range tests { + testName := test.input + t.Run(testName, func(t *testing.T) { + pa := splitProto(test.input) + if pa.protocol != test.proto { + t.Errorf("Testcase %d: Incorrect proto: expected %v, got %v", + i, test.proto, pa.protocol) + } + if pa.address != test.addr { + t.Errorf("Testcase %d: Incorrect addr: expected %v, got %v", + i, test.addr, pa.address) + } + }) + } +} diff --git a/main.go b/main.go index 522a5fc..ba77aa5 100644 --- a/main.go +++ b/main.go @@ -81,7 +81,7 @@ func addrAllowed(addr string, allowedAddrs []string) bool { func senderChecker(peer smtpd.Peer, addr string) error { // check sender address from auth file if user is authenticated - if *allowedUsers != "" && peer.Username != "" { + if localAuthRequired() && peer.Username != "" { user, err := AuthFetch(peer.Username) if err != nil { // Shouldn't happen: authChecker already validated username+password @@ -276,7 +276,7 @@ func main() { Debug("starting smtprelay") // Load allowed users file - if *allowedUsers != "" { + if localAuthRequired() { err := AuthLoadFile(*allowedUsers) if err != nil { log.WithField("file", *allowedUsers). @@ -288,7 +288,9 @@ func main() { var servers []*smtpd.Server // Create a server for each desired listen address - for _, listenAddr := range strings.Split(*listen, " ") { + for _, listen := range listenAddrs { + logger := log.WithField("address", listen.address) + server := &smtpd.Server{ Hostname: *hostName, WelcomeMessage: *welcomeMsg, @@ -298,44 +300,38 @@ func main() { Handler: mailHandler, } - if *allowedUsers != "" { + if localAuthRequired() { server.Authenticator = authChecker } var lsnr net.Listener var err error - if strings.Index(listenAddr, "://") == -1 { - log.WithField("address", listenAddr). - Info("listening on address") - - lsnr, err = net.Listen("tcp", listenAddr) - } else if strings.HasPrefix(listenAddr, "starttls://") { - listenAddr = strings.TrimPrefix(listenAddr, "starttls://") + switch listen.protocol { + case "": + logger.Info("listening on address") + lsnr, err = net.Listen("tcp", listen.address) + case "starttls": server.TLSConfig = getTLSConfig() server.ForceTLS = *localForceTLS - log.WithField("address", listenAddr). - Info("listening on address (STARTTLS)") - lsnr, err = net.Listen("tcp", listenAddr) - } else if strings.HasPrefix(listenAddr, "tls://") { - listenAddr = strings.TrimPrefix(listenAddr, "tls://") + logger.Info("listening on address (STARTTLS)") + lsnr, err = net.Listen("tcp", listen.address) + case "tls": server.TLSConfig = getTLSConfig() - log.WithField("address", listenAddr). - Info("listening on address (TLS)") - lsnr, err = tls.Listen("tcp", listenAddr, server.TLSConfig) - } else { - log.WithField("address", listenAddr). + logger.Info("listening on address (TLS)") + lsnr, err = tls.Listen("tcp", listen.address, server.TLSConfig) + + default: + logger.WithField("protocol", listen.protocol). Fatal("unknown protocol in listen address") } if err != nil { - log.WithFields(logrus.Fields{ - "address": listenAddr, - }).WithError(err).Fatal("error starting listener") + logger.WithError(err).Fatal("error starting listener") } servers = append(servers, server)