forked from drew/smtprelay
Expand allowedUsers email field to support comma-separated and domains (#9)
* Expand allowedUsers email field to support comma-separated and domains Closes #8 * Refactor AuthFetch() to return AuthUser struct Also, this breaks out a parseLine() function which can be easily tested. * Ignore empty addrs after splitting commas This ignores a trailing comma * Add tests for auth parseLine() * Update documentation in smtprelay.ini * Fix bug where addrAllowed() was incorrectly case-sensitive * Update allowedUsers allowed domain format to require leading @ This disambiguates a local user ('john.smith') from a domain ('example.com')
This commit is contained in:
committed by
GitHub
parent
5c2e28ac36
commit
0e8986ca79
62
auth.go
62
auth.go
@@ -13,6 +13,12 @@ var (
|
||||
filename string
|
||||
)
|
||||
|
||||
type AuthUser struct {
|
||||
username string
|
||||
passwordHash string
|
||||
allowedAddresses []string
|
||||
}
|
||||
|
||||
func AuthLoadFile(file string) error {
|
||||
f, err := os.Open(file)
|
||||
if err != nil {
|
||||
@@ -28,50 +34,66 @@ func AuthReady() bool {
|
||||
return (filename != "")
|
||||
}
|
||||
|
||||
// Returns bcrypt-hash, email
|
||||
// email can be empty in which case it is not checked
|
||||
func AuthFetch(username string) (string, string, error) {
|
||||
// Split a string and ignore empty results
|
||||
// https://stackoverflow.com/a/46798310/119527
|
||||
func splitstr(s string, sep rune) []string {
|
||||
return strings.FieldsFunc(s, func(c rune) bool { return c == sep })
|
||||
}
|
||||
|
||||
func parseLine(line string) *AuthUser {
|
||||
parts := strings.Fields(line)
|
||||
|
||||
if len(parts) < 2 || len(parts) > 3 {
|
||||
return nil
|
||||
}
|
||||
|
||||
user := AuthUser{
|
||||
username: parts[0],
|
||||
passwordHash: parts[1],
|
||||
allowedAddresses: nil,
|
||||
}
|
||||
|
||||
if len(parts) >= 3 {
|
||||
user.allowedAddresses = splitstr(parts[2], ',')
|
||||
}
|
||||
|
||||
return &user
|
||||
}
|
||||
|
||||
func AuthFetch(username string) (*AuthUser, error) {
|
||||
if !AuthReady() {
|
||||
return "", "", errors.New("Authentication file not specified. Call LoadFile() first")
|
||||
return nil, errors.New("Authentication file not specified. Call LoadFile() first")
|
||||
}
|
||||
|
||||
file, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
parts := strings.Fields(scanner.Text())
|
||||
|
||||
if len(parts) < 2 || len(parts) > 3 {
|
||||
user := parseLine(scanner.Text())
|
||||
if user == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.ToLower(username) != strings.ToLower(parts[0]) {
|
||||
if strings.ToLower(username) != strings.ToLower(user.username) {
|
||||
continue
|
||||
}
|
||||
|
||||
hash := parts[1]
|
||||
email := ""
|
||||
|
||||
if len(parts) >= 3 {
|
||||
email = parts[2]
|
||||
return user, nil
|
||||
}
|
||||
|
||||
return hash, email, nil
|
||||
}
|
||||
|
||||
return "", "", errors.New("User not found")
|
||||
return nil, errors.New("User not found")
|
||||
}
|
||||
|
||||
func AuthCheckPassword(username string, secret string) error {
|
||||
hash, _, err := AuthFetch(username)
|
||||
user, err := AuthFetch(username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if bcrypt.CompareHashAndPassword([]byte(hash), []byte(secret)) == nil {
|
||||
if bcrypt.CompareHashAndPassword([]byte(user.passwordHash), []byte(secret)) == nil {
|
||||
return nil
|
||||
}
|
||||
return errors.New("Password invalid")
|
||||
|
||||
89
auth_test.go
Normal file
89
auth_test.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func stringsEqual(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i, _ := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestParseLine(t *testing.T) {
|
||||
var tests = []struct {
|
||||
name string
|
||||
expectFail bool
|
||||
line string
|
||||
username string
|
||||
addrs []string
|
||||
}{
|
||||
{
|
||||
name: "Empty line",
|
||||
expectFail: true,
|
||||
line: "",
|
||||
},
|
||||
{
|
||||
name: "Too few fields",
|
||||
expectFail: true,
|
||||
line: "joe",
|
||||
},
|
||||
{
|
||||
name: "Too many fields",
|
||||
expectFail: true,
|
||||
line: "joe xxx joe@example.com whatsthis",
|
||||
},
|
||||
{
|
||||
name: "Normal case",
|
||||
line: "joe xxx joe@example.com",
|
||||
username: "joe",
|
||||
addrs: []string{"joe@example.com"},
|
||||
},
|
||||
{
|
||||
name: "No allowed addrs given",
|
||||
line: "joe xxx",
|
||||
username: "joe",
|
||||
addrs: []string{},
|
||||
},
|
||||
{
|
||||
name: "Trailing comma",
|
||||
line: "joe xxx joe@example.com,",
|
||||
username: "joe",
|
||||
addrs: []string{"joe@example.com"},
|
||||
},
|
||||
{
|
||||
name: "Multiple allowed addrs",
|
||||
line: "joe xxx joe@example.com,@foo.example.com",
|
||||
username: "joe",
|
||||
addrs: []string{"joe@example.com", "@foo.example.com"},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
user := parseLine(test.line)
|
||||
if user == nil {
|
||||
if !test.expectFail {
|
||||
t.Errorf("parseLine() returned nil unexpectedly")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if user.username != test.username {
|
||||
t.Errorf("Testcase %d: Incorrect username: expected %v, got %v",
|
||||
i, test.username, user.username)
|
||||
}
|
||||
|
||||
if !stringsEqual(user.allowedAddresses, test.addrs) {
|
||||
t.Errorf("Testcase %d: Incorrect addresses: expected %v, got %v",
|
||||
i, test.addrs, user.allowedAddresses)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
47
main.go
47
main.go
@@ -36,15 +36,58 @@ func connectionChecker(peer smtpd.Peer) error {
|
||||
return smtpd.Error{Code: 421, Message: "Denied"}
|
||||
}
|
||||
|
||||
func addrAllowed(addr string, allowedAddrs []string) bool {
|
||||
if allowedAddrs == nil {
|
||||
// If absent, all addresses are allowed
|
||||
return true
|
||||
}
|
||||
|
||||
addr = strings.ToLower(addr)
|
||||
|
||||
// Extract optional domain part
|
||||
domain := ""
|
||||
if idx := strings.LastIndex(addr, "@"); idx != -1 {
|
||||
domain = strings.ToLower(addr[idx+1:])
|
||||
}
|
||||
|
||||
// Test each address from allowedUsers file
|
||||
for _, allowedAddr := range allowedAddrs {
|
||||
allowedAddr = strings.ToLower(allowedAddr)
|
||||
|
||||
// Three cases for allowedAddr format:
|
||||
if idx := strings.Index(allowedAddr, "@"); idx == -1 {
|
||||
// 1. local address (no @) -- must match exactly
|
||||
if allowedAddr == addr {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
if idx != 0 {
|
||||
// 2. email address (user@domain.com) -- must match exactly
|
||||
if allowedAddr == addr {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
// 3. domain (@domain.com) -- must match addr domain
|
||||
allowedDomain := allowedAddr[idx+1:]
|
||||
if allowedDomain == domain {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func senderChecker(peer smtpd.Peer, addr string) error {
|
||||
// check sender address from auth file if user is authenticated
|
||||
if *allowedUsers != "" && peer.Username != "" {
|
||||
_, email, err := AuthFetch(peer.Username)
|
||||
user, err := AuthFetch(peer.Username)
|
||||
if err != nil {
|
||||
return smtpd.Error{Code: 451, Message: "Bad sender address"}
|
||||
}
|
||||
|
||||
if email != "" && strings.ToLower(addr) != strings.ToLower(email) {
|
||||
if !addrAllowed(addr, user.allowedAddresses) {
|
||||
return smtpd.Error{Code: 451, Message: "Bad sender address"}
|
||||
}
|
||||
}
|
||||
|
||||
94
main_test.go
Normal file
94
main_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAddrAllowedNoDomain(t *testing.T) {
|
||||
allowedAddrs := []string{"joe@abc.com"}
|
||||
if addrAllowed("bob.com", allowedAddrs) {
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddrAllowedSingle(t *testing.T) {
|
||||
allowedAddrs := []string{"joe@abc.com"}
|
||||
|
||||
if !addrAllowed("joe@abc.com", allowedAddrs) {
|
||||
t.FailNow()
|
||||
}
|
||||
if addrAllowed("bob@abc.com", allowedAddrs) {
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddrAllowedDifferentCase(t *testing.T) {
|
||||
allowedAddrs := []string{"joe@abc.com"}
|
||||
testAddrs := []string{
|
||||
"joe@ABC.com",
|
||||
"Joe@abc.com",
|
||||
"JOE@abc.com",
|
||||
"JOE@ABC.COM",
|
||||
}
|
||||
for _, addr := range testAddrs {
|
||||
if !addrAllowed(addr, allowedAddrs) {
|
||||
t.Errorf("Address %v not allowed, but should be", addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddrAllowedLocal(t *testing.T) {
|
||||
allowedAddrs := []string{"joe"}
|
||||
|
||||
if !addrAllowed("joe", allowedAddrs) {
|
||||
t.FailNow()
|
||||
}
|
||||
if addrAllowed("bob", allowedAddrs) {
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddrAllowedMulti(t *testing.T) {
|
||||
allowedAddrs := []string{"joe@abc.com", "bob@def.com"}
|
||||
if !addrAllowed("joe@abc.com", allowedAddrs) {
|
||||
t.FailNow()
|
||||
}
|
||||
if !addrAllowed("bob@def.com", allowedAddrs) {
|
||||
t.FailNow()
|
||||
}
|
||||
if addrAllowed("bob@abc.com", allowedAddrs) {
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddrAllowedSingleDomain(t *testing.T) {
|
||||
allowedAddrs := []string{"@abc.com"}
|
||||
if !addrAllowed("joe@abc.com", allowedAddrs) {
|
||||
t.FailNow()
|
||||
}
|
||||
if addrAllowed("joe@def.com", allowedAddrs) {
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddrAllowedMixed(t *testing.T) {
|
||||
allowedAddrs := []string{"app", "app@example.com", "@appsrv.example.com"}
|
||||
if !addrAllowed("app", allowedAddrs) {
|
||||
t.FailNow()
|
||||
}
|
||||
if !addrAllowed("app@example.com", allowedAddrs) {
|
||||
t.FailNow()
|
||||
}
|
||||
if addrAllowed("ceo@example.com", allowedAddrs) {
|
||||
t.FailNow()
|
||||
}
|
||||
if !addrAllowed("root@appsrv.example.com", allowedAddrs) {
|
||||
t.FailNow()
|
||||
}
|
||||
if !addrAllowed("dev@appsrv.example.com", allowedAddrs) {
|
||||
t.FailNow()
|
||||
}
|
||||
if addrAllowed("appsrv@example.com", allowedAddrs) {
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
@@ -37,7 +37,14 @@
|
||||
|
||||
; File which contains username and password used for
|
||||
; authentication before they can send mail.
|
||||
; File format: username bcrypt-hash [email]
|
||||
; File format: username bcrypt-hash [email[,email[,...]]]
|
||||
; username: The SMTP auth username
|
||||
; bcrypt-hash: The bcrypt hash of the pasword (generate with "./hasher password")
|
||||
; email: Comma-separated list of allowed "from" addresses:
|
||||
; - If omitted, user can send from any address
|
||||
; - If @domain.com is given, user can send from any address @domain.com
|
||||
; - Otherwise, email address must match exactly (case-insensitive)
|
||||
; E.g. "app@example.com,@appsrv.example.com"
|
||||
;allowed_users =
|
||||
|
||||
; Relay all mails to this SMTP server
|
||||
|
||||
Reference in New Issue
Block a user