267 lines
7.3 KiB
Go
267 lines
7.3 KiB
Go
package main
|
|
|
|
import (
|
|
"errors"
|
|
"flag"
|
|
"git.thequux.com/thequux/ipasso/sso-proxy/backend"
|
|
"git.thequux.com/thequux/ipasso/util"
|
|
"git.thequux.com/thequux/ipasso/util/genpool"
|
|
"git.thequux.com/thequux/ipasso/util/startup"
|
|
"github.com/go-ldap/ldap/gssapi"
|
|
"github.com/go-ldap/ldap/v3"
|
|
"github.com/jcmturner/gokrb5/v8/config"
|
|
"go.uber.org/zap"
|
|
"math/rand"
|
|
"net/url"
|
|
"strings"
|
|
)
|
|
|
|
var (
|
|
ldapServerUrl = flag.String("ldap-url", "", "URL at which LDAP server can be reached")
|
|
ldapRootDN = flag.String("rootDN", "", "LDAP Root DN. Defaults to dc=host,dc=tld based on -domain")
|
|
keytab = flag.String("keytab", "ipasso.keytab", "Keytab file used to authenticate server")
|
|
krb5Principal = flag.String("krb5-principal", "", "Default kerberos principal; default HTTP/sso.<domain>")
|
|
krb5realm = flag.String("krb5-realm", "", "Kerberos realm. Default based on krb5 config")
|
|
krb5conf = flag.String("krb5-conf", util.GetEnvDefault("KRB5_CONFIG", "/etc/krb5.conf"), "Config file for kerberos")
|
|
|
|
gssapiClient *gssapi.Client
|
|
|
|
ldapServerPool []ldapServerHost
|
|
ldapPool genpool.Pool[ldap.Conn]
|
|
|
|
ldapUserBase string
|
|
ldapServiceDn string
|
|
ldapHostDn string
|
|
)
|
|
|
|
var (
|
|
ErrNoValidServer = errors.New("no valid server")
|
|
ldapRootLogger, ldapPoolLogger *zap.Logger
|
|
)
|
|
|
|
func init() {
|
|
startup.Logger.Add(func() {
|
|
ldapRootLogger = zap.L().Named("ldap")
|
|
ldapPoolLogger = ldapRootLogger.Named("pool")
|
|
})
|
|
|
|
startup.PostFlags.Add(func() {
|
|
serverUrl, err := url.Parse(*ldapServerUrl)
|
|
if err != nil {
|
|
ldapRootLogger.Fatal("Invalid LDAP server url", zap.String("url", *ldapServerUrl), zap.Error(err))
|
|
}
|
|
ldapServerPool = []ldapServerHost{
|
|
{
|
|
SPN: "ldap/" + serverUrl.Hostname(),
|
|
Url: *ldapServerUrl,
|
|
Weight: 1,
|
|
},
|
|
}
|
|
|
|
if *ldapRootDN == "" {
|
|
rootDnElements := make([]string, 0, 4)
|
|
for _, v := range strings.Split(*domain, ".") {
|
|
rootDnElements = append(rootDnElements, "dc="+v)
|
|
}
|
|
*ldapRootDN = strings.Join(rootDnElements, ",")
|
|
ldapRootLogger.Debug("Configured LDAP rootDN", zap.String("rootDN", *ldapRootDN))
|
|
}
|
|
|
|
krb5Config, err := config.Load(*krb5conf)
|
|
var realmSource string = ""
|
|
if err != nil {
|
|
ldapRootLogger.Warn("Failed to load config", zap.Error(err))
|
|
} else {
|
|
if *krb5realm == "" {
|
|
*krb5realm = krb5Config.LibDefaults.DefaultRealm
|
|
realmSource = *krb5conf
|
|
}
|
|
}
|
|
if *krb5realm == "" {
|
|
// default from domain
|
|
*krb5realm = strings.ToUpper(*domain)
|
|
realmSource = "domain"
|
|
}
|
|
|
|
if realmSource != "" {
|
|
ldapRootLogger.Debug("Configured KRB5 realm", zap.String("source", realmSource), zap.String("realm", *krb5realm))
|
|
}
|
|
|
|
if *krb5Principal == "" {
|
|
*krb5Principal = "HTTP/sso." + *domain
|
|
ldapRootLogger.Debug("Configured local kerberos principal", zap.String("principal", *krb5Principal))
|
|
}
|
|
|
|
gssapiClient, err := gssapi.NewClientWithKeytab(*krb5Principal, *krb5realm, *keytab, *krb5conf)
|
|
if err != nil {
|
|
ldapRootLogger.Fatal("Failed to initialize kerberos", zap.Error(err))
|
|
}
|
|
|
|
// Create the LDAP pool
|
|
ldapPool = genpool.NewPool[ldap.Conn](&ldapPoolManager{gssapiClient: gssapiClient}, 5)
|
|
})
|
|
startup.Startup.Add(func() {
|
|
// Test the pool...
|
|
conn, err := ldapPool.Get()
|
|
if err != nil {
|
|
ldapPoolLogger.Warn("Failed to connect to LDAP server at startup", zap.Error(err))
|
|
} else {
|
|
defer ldapPool.Put(conn)
|
|
whoami, err := conn.WhoAmI(nil)
|
|
if err != nil {
|
|
ldapPoolLogger.Warn("Failed to call whoami at startup", zap.Error(err))
|
|
}
|
|
ldapPoolLogger.Info("Successfully connected to LDAP", zap.String("authzId", whoami.AuthzID))
|
|
}
|
|
})
|
|
}
|
|
|
|
type (
|
|
ldapServerHost struct {
|
|
SPN string
|
|
Url string
|
|
Weight int
|
|
}
|
|
|
|
ldapPoolManager struct {
|
|
// TODO: Fill this with results from a SRV request...
|
|
gssapiClient *gssapi.Client
|
|
}
|
|
)
|
|
|
|
func selectServer() *ldapServerHost {
|
|
serverSet := ldapServerPool
|
|
var selectedServer *ldapServerHost
|
|
var weight = 0
|
|
for _, server := range serverSet {
|
|
if server.Weight <= 0 {
|
|
continue
|
|
}
|
|
weight += server.Weight
|
|
if rand.Intn(weight) < server.Weight {
|
|
server := server // copy the server object
|
|
selectedServer = &server
|
|
}
|
|
}
|
|
if weight > 0 && selectedServer == nil {
|
|
ldapPoolLogger.DPanic("Failed to select a server when one was on offer")
|
|
}
|
|
return selectedServer
|
|
}
|
|
|
|
func (l *ldapPoolManager) Destroy(conn *ldap.Conn) {
|
|
err := conn.Close()
|
|
if err != nil {
|
|
ldapPoolLogger.Warn("Failed to close LDAP connection",
|
|
zap.Error(err),
|
|
)
|
|
} else {
|
|
ldapPoolLogger.Debug("Closed ldap connection")
|
|
}
|
|
}
|
|
|
|
func (l *ldapPoolManager) Create() (*ldap.Conn, error) {
|
|
|
|
server := selectServer()
|
|
if server == nil {
|
|
return nil, ErrNoValidServer
|
|
}
|
|
conn, err := ldap.DialURL(server.Url)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err := conn.GSSAPIBind(l.gssapiClient, server.SPN, ""); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return conn, nil
|
|
}
|
|
|
|
func (l *ldapPoolManager) Validate(conn *ldap.Conn) bool {
|
|
_, err := conn.WhoAmI([]ldap.Control{})
|
|
return err == nil
|
|
}
|
|
|
|
func buildSessionCache(b *backend.Session, entry *ldap.Entry) (cache backend.SessionCache, err error) {
|
|
if entry == nil {
|
|
entry, err = getUserByDn(b.LdapDN)
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
//ldapRootLogger.Info("Building session cache", zap.Any("entry", entry))
|
|
|
|
cache.Valid = true
|
|
for _, attr := range entry.Attributes {
|
|
if attr.Name == "displayName" && len(attr.Values) > 0 {
|
|
cache.DisplayName = attr.Values[0]
|
|
} else if attr.Name == "givenName" && len(attr.Values) > 0 {
|
|
cache.GivenName = strings.Join(attr.Values, " ")
|
|
} else if attr.Name == "sn" {
|
|
cache.SurName = strings.Join(attr.Values, " ")
|
|
} else if attr.Name == "mail" && len(attr.Values) > 0 {
|
|
cache.Email = attr.Values[0]
|
|
} else if attr.Name == "memberOf" {
|
|
grpSuffix := "cn=groups,cn=accounts," + *ldapRootDN
|
|
for _, grp := range attr.Values {
|
|
gnames := strings.SplitN(grp, ",", 2)
|
|
if gnames[1] != grpSuffix {
|
|
continue
|
|
}
|
|
cn, found := strings.CutPrefix(gnames[0], "cn=")
|
|
if found {
|
|
cache.Groups = append(cache.Groups, cn)
|
|
} else {
|
|
ldapRootLogger.Warn("Unexpected group name", zap.String("gdn", grp), zap.String("dn", entry.DN))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// interestingUserAttr lists the attributes requested by the user lookups
// getUserByPrincipal and getUserByDn. buildSessionCache reads displayName,
// givenName, sn, mail and memberOf; uid and serverHostName are requested too
// but not read by buildSessionCache.
var interestingUserAttr = []string{
	"displayName",
	"givenName",
	"sn",
	"uid",
	"mail",
	"memberOf",
	"serverHostName",
}
|
|
|
|
func ldapSearchSingle(req *ldap.SearchRequest) (*ldap.Entry, error) {
|
|
server, err := ldapPool.Get()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer ldapPool.Put(server)
|
|
|
|
searchRes, err := server.Search(req)
|
|
if err != nil {
|
|
ldapRootLogger.Warn("Failed LDAP search", zap.Any("req", req), zap.Error(err))
|
|
return nil, err
|
|
}
|
|
nEntries := len(searchRes.Entries)
|
|
if nEntries == 0 {
|
|
ldapRootLogger.Info("No entries found for search", zap.String("filter", req.Filter))
|
|
return nil, ErrInvalidResult
|
|
} else if nEntries > 1 {
|
|
ldapRootLogger.Info("Multiple entries found for search", zap.String("filter", req.Filter))
|
|
return nil, ErrInvalidResult
|
|
}
|
|
|
|
return searchRes.Entries[0], nil
|
|
|
|
}
|
|
|
|
func getUserByPrincipal(principal string) (*ldap.Entry, error) {
|
|
return ldapSearchSingle(&ldap.SearchRequest{
|
|
Filter: "(&(krbprincipalname=" + principal + ")(objectClass=inetorgperson))",
|
|
BaseDN: *ldapRootDN,
|
|
Scope: ldap.ScopeWholeSubtree,
|
|
Attributes: interestingUserAttr,
|
|
})
|
|
}
|
|
|
|
func getUserByDn(dn string) (*ldap.Entry, error) {
|
|
return ldapSearchSingle(&ldap.SearchRequest{
|
|
BaseDN: dn,
|
|
Scope: ldap.ScopeBaseObject,
|
|
Attributes: interestingUserAttr,
|
|
})
|
|
}
|