2
0
forked from drew/smtprelay

Move parsing of "allowed_nets" out to ConfigLoad()

This has several benefits:
- Configuration errors are caught at startup rather than upon a connection
- connectionChecker() has less work to do for each connection
This commit is contained in:
Jonathon Reinhart
2021-03-13 02:42:30 -05:00
parent 4036213dd5
commit ef3f9c8ea0
2 changed files with 20 additions and 5 deletions

View File

@@ -2,6 +2,7 @@ package main
import (
"flag"
"net"
"github.com/vharitonsky/iniflags"
)
@@ -21,7 +22,8 @@ var (
localCert = flag.String("local_cert", "", "SSL certificate for STARTTLS/TLS")
localKey = flag.String("local_key", "", "SSL private key for STARTTLS/TLS")
localForceTLS = flag.Bool("local_forcetls", false, "Force STARTTLS (needs local_cert and local_key)")
allowedNets = flag.String("allowed_nets", "127.0.0.1/8 ::1/128", "Networks allowed to send mails")
allowedNetsStr = flag.String("allowed_nets", "127.0.0.1/8 ::1/128", "Networks allowed to send mails")
allowedNets = []*net.IPNet{}
allowedSender = flag.String("allowed_sender", "", "Regular expression for valid FROM EMail addresses")
allowedRecipients = flag.String("allowed_recipients", "", "Regular expression for valid TO EMail addresses")
allowedUsers = flag.String("allowed_users", "", "Path to file with valid users/passwords")
@@ -33,6 +35,20 @@ var (
versionInfo = flag.Bool("version", false, "Show version information")
)
func setupAllowedNetworks() {
for _, netstr := range splitstr(*allowedNetsStr, ' ') {
_, allowedNet, err := net.ParseCIDR(netstr)
if err != nil {
log.WithField("netstr", netstr).
WithError(err).
Fatal("Invalid CIDR notation in allowed_nets")
}
allowedNets = append(allowedNets, allowedNet)
}
}
func ConfigLoad() {
iniflags.Parse()
@@ -42,4 +58,6 @@ func ConfigLoad() {
if (*remoteHost == "") {
log.Warn("remote_host not set; mail will not be forwarded!")
}
setupAllowedNetworks()
}

View File

@@ -20,11 +20,8 @@ func connectionChecker(peer smtpd.Peer) error {
// This can't panic because we only have TCP listeners
peerIP := peer.Addr.(*net.TCPAddr).IP
nets := strings.Split(*allowedNets, " ")
for i := range nets {
_, allowedNet, _ := net.ParseCIDR(nets[i])
for _, allowedNet := range allowedNets {
if allowedNet.Contains(peerIP) {
return nil
}