From 0e8986ca79f94792343d93027ff28352bd1c212c Mon Sep 17 00:00:00 2001 From: Jonathon Reinhart Date: Sun, 14 Feb 2021 16:16:18 -0500 Subject: [PATCH] Expand allowedUsers email field to support comma-separated and domains (#9) * Expand allowedUsers email field to support comma-separated and domains Closes #8 * Refactor AuthFetch() to return AuthUser struct Also, this breaks out a parseLine() function which can be easily tested. * Ignore empty addrs after splitting commas This ignores a trailing comma * Add tests for auth parseLine() * Update documentation in smtprelay.ini * Fix bug where addrAllowed() was incorrectly case-sensitive * Update allowedUsers allowed domain format to require leading @ This disambiguates a local user ('john.smith') from a domain ('example.com') --- auth.go | 62 ++++++++++++++++++++++----------- auth_test.go | 89 ++++++++++++++++++++++++++++++++++++++++++++++++ main.go | 47 ++++++++++++++++++++++++-- main_test.go | 94 +++++++++++++++++++++++++++++++++++++++++++++++++++ smtprelay.ini | 9 ++++- 5 files changed, 278 insertions(+), 23 deletions(-) create mode 100644 auth_test.go create mode 100644 main_test.go diff --git a/auth.go b/auth.go index fcde86e..2b7517f 100644 --- a/auth.go +++ b/auth.go @@ -13,6 +13,12 @@ var ( filename string ) +type AuthUser struct { + username string + passwordHash string + allowedAddresses []string +} + func AuthLoadFile(file string) error { f, err := os.Open(file) if err != nil { @@ -28,50 +34,66 @@ func AuthReady() bool { return (filename != "") } -// Returns bcrypt-hash, email -// email can be empty in which case it is not checked -func AuthFetch(username string) (string, string, error) { +// Split a string and ignore empty results +// https://stackoverflow.com/a/46798310/119527 +func splitstr(s string, sep rune) []string { + return strings.FieldsFunc(s, func(c rune) bool { return c == sep }) +} + +func parseLine(line string) *AuthUser { + parts := strings.Fields(line) + + if len(parts) < 2 || len(parts) > 3 { + return nil + } + + user := AuthUser{ + username: parts[0], + passwordHash: parts[1], + allowedAddresses: nil, + } + + if len(parts) >= 3 { + user.allowedAddresses = splitstr(parts[2], ',') + } + + return &user +} + +func AuthFetch(username string) (*AuthUser, error) { if !AuthReady() { - return "", "", errors.New("Authentication file not specified. Call LoadFile() first") + return nil, errors.New("Authentication file not specified. Call LoadFile() first") } file, err := os.Open(filename) if err != nil { - return "", "", err + return nil, err } defer file.Close() scanner := bufio.NewScanner(file) for scanner.Scan() { - parts := strings.Fields(scanner.Text()) - - if len(parts) < 2 || len(parts) > 3 { + user := parseLine(scanner.Text()) + if user == nil { continue } - if strings.ToLower(username) != strings.ToLower(parts[0]) { + if strings.ToLower(username) != strings.ToLower(user.username) { continue } - hash := parts[1] - email := "" - - if len(parts) >= 3 { - email = parts[2] - } - - return hash, email, nil + return user, nil } - return "", "", errors.New("User not found") + return nil, errors.New("User not found") } func AuthCheckPassword(username string, secret string) error { - hash, _, err := AuthFetch(username) + user, err := AuthFetch(username) if err != nil { return err } - if bcrypt.CompareHashAndPassword([]byte(hash), []byte(secret)) == nil { + if bcrypt.CompareHashAndPassword([]byte(user.passwordHash), []byte(secret)) == nil { return nil } return errors.New("Password invalid") diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000..82dd452 --- /dev/null +++ b/auth_test.go @@ -0,0 +1,89 @@ +package main + +import ( + "testing" +) + +func stringsEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i, _ := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +func TestParseLine(t *testing.T) { + var tests = []struct { + name string + expectFail bool + line string + username string + addrs []string + }{ + { + name: "Empty line", + expectFail: true, + line: "", + }, + { + name: "Too few fields", + expectFail: true, + line: "joe", + }, + { + name: "Too many fields", + expectFail: true, + line: "joe xxx joe@example.com whatsthis", + }, + { + name: "Normal case", + line: "joe xxx joe@example.com", + username: "joe", + addrs: []string{"joe@example.com"}, + }, + { + name: "No allowed addrs given", + line: "joe xxx", + username: "joe", + addrs: []string{}, + }, + { + name: "Trailing comma", + line: "joe xxx joe@example.com,", + username: "joe", + addrs: []string{"joe@example.com"}, + }, + { + name: "Multiple allowed addrs", + line: "joe xxx joe@example.com,@foo.example.com", + username: "joe", + addrs: []string{"joe@example.com", "@foo.example.com"}, + }, + } + + for i, test := range tests { + t.Run(test.name, func(t *testing.T) { + user := parseLine(test.line) + if user == nil { + if !test.expectFail { + t.Errorf("parseLine() returned nil unexpectedly") + } + return + } + + if user.username != test.username { + t.Errorf("Testcase %d: Incorrect username: expected %v, got %v", + i, test.username, user.username) + } + + if !stringsEqual(user.allowedAddresses, test.addrs) { + t.Errorf("Testcase %d: Incorrect addresses: expected %v, got %v", + i, test.addrs, user.allowedAddresses) + } + }) + } +} diff --git a/main.go b/main.go index 89b74f9..e045325 100644 --- a/main.go +++ b/main.go @@ -36,15 +36,58 @@ func connectionChecker(peer smtpd.Peer) error { return smtpd.Error{Code: 421, Message: "Denied"} } +func addrAllowed(addr string, allowedAddrs []string) bool { + if allowedAddrs == nil { + // If absent, all addresses are allowed + return true + } + + addr = strings.ToLower(addr) + + // Extract optional domain part + domain := "" + if idx := strings.LastIndex(addr, "@"); idx != -1 { + domain = strings.ToLower(addr[idx+1:]) + } + + // Test each address from allowedUsers file + for _, allowedAddr := range allowedAddrs { + allowedAddr = strings.ToLower(allowedAddr) + + // Three cases for allowedAddr format: + if idx := strings.Index(allowedAddr, "@"); idx == -1 { + // 1. local address (no @) -- must match exactly + if allowedAddr == addr { + return true + } + } else { + if idx != 0 { + // 2. email address (user@domain.com) -- must match exactly + if allowedAddr == addr { + return true + } + } else { + // 3. domain (@domain.com) -- must match addr domain + allowedDomain := allowedAddr[idx+1:] + if allowedDomain == domain { + return true + } + } + } + } + + return false +} + func senderChecker(peer smtpd.Peer, addr string) error { // check sender address from auth file if user is authenticated if *allowedUsers != "" && peer.Username != "" { - _, email, err := AuthFetch(peer.Username) + user, err := AuthFetch(peer.Username) if err != nil { return smtpd.Error{Code: 451, Message: "Bad sender address"} } - if email != "" && strings.ToLower(addr) != strings.ToLower(email) { + if !addrAllowed(addr, user.allowedAddresses) { return smtpd.Error{Code: 451, Message: "Bad sender address"} } } diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..6d56b55 --- /dev/null +++ b/main_test.go @@ -0,0 +1,94 @@ +package main + +import ( + "testing" +) + +func TestAddrAllowedNoDomain(t *testing.T) { + allowedAddrs := []string{"joe@abc.com"} + if addrAllowed("bob.com", allowedAddrs) { + t.FailNow() + } +} + +func TestAddrAllowedSingle(t *testing.T) { + allowedAddrs := []string{"joe@abc.com"} + + if !addrAllowed("joe@abc.com", allowedAddrs) { + t.FailNow() + } + if addrAllowed("bob@abc.com", allowedAddrs) { + t.FailNow() + } +} + +func TestAddrAllowedDifferentCase(t *testing.T) { + allowedAddrs := []string{"joe@abc.com"} + testAddrs := []string{ + "joe@ABC.com", + "Joe@abc.com", + "JOE@abc.com", + "JOE@ABC.COM", + } + for _, addr := range testAddrs { + if !addrAllowed(addr, allowedAddrs) { + t.Errorf("Address %v not allowed, but should be", addr) + } + } +} + +func TestAddrAllowedLocal(t *testing.T) { + allowedAddrs := []string{"joe"} + + if !addrAllowed("joe", allowedAddrs) { + t.FailNow() + } + if addrAllowed("bob", allowedAddrs) { + t.FailNow() + } +} + +func TestAddrAllowedMulti(t *testing.T) { + allowedAddrs := []string{"joe@abc.com", "bob@def.com"} + if !addrAllowed("joe@abc.com", allowedAddrs) { + t.FailNow() + } + if !addrAllowed("bob@def.com", allowedAddrs) { + t.FailNow() + } + if addrAllowed("bob@abc.com", allowedAddrs) { + t.FailNow() + } +} + +func TestAddrAllowedSingleDomain(t *testing.T) { + allowedAddrs := []string{"@abc.com"} + if !addrAllowed("joe@abc.com", allowedAddrs) { + t.FailNow() + } + if addrAllowed("joe@def.com", allowedAddrs) { + t.FailNow() + } +} + +func TestAddrAllowedMixed(t *testing.T) { + allowedAddrs := []string{"app", "app@example.com", "@appsrv.example.com"} + if !addrAllowed("app", allowedAddrs) { + t.FailNow() + } + if !addrAllowed("app@example.com", allowedAddrs) { + t.FailNow() + } + if addrAllowed("ceo@example.com", allowedAddrs) { + t.FailNow() + } + if !addrAllowed("root@appsrv.example.com", allowedAddrs) { + t.FailNow() + } + if !addrAllowed("dev@appsrv.example.com", allowedAddrs) { + t.FailNow() + } + if addrAllowed("appsrv@example.com", allowedAddrs) { + t.FailNow() + } +} diff --git a/smtprelay.ini b/smtprelay.ini index 8cfdcfb..5417451 100644 --- a/smtprelay.ini +++ b/smtprelay.ini @@ -37,7 +37,14 @@ ; File which contains username and password used for ; authentication before they can send mail. -; File format: username bcrypt-hash [email] +; File format: username bcrypt-hash [email[,email[,...]]] +; username: The SMTP auth username +; bcrypt-hash: The bcrypt hash of the pasword (generate with "./hasher password") +; email: Comma-separated list of allowed "from" addresses: +; - If omitted, user can send from any address +; - If @domain.com is given, user can send from any address @domain.com +; - Otherwise, email address must match exactly (case-insensitive) +; E.g. "app@example.com,@appsrv.example.com" ;allowed_users = ; Relay all mails to this SMTP server