Files
ipasso/app/ipasso/ldap.go

267 lines
7.3 KiB
Go

package main
import (
"errors"
"flag"
"git.thequux.com/thequux/ipasso/sso-proxy/backend"
"git.thequux.com/thequux/ipasso/util"
"git.thequux.com/thequux/ipasso/util/genpool"
"git.thequux.com/thequux/ipasso/util/startup"
"github.com/go-ldap/ldap/gssapi"
"github.com/go-ldap/ldap/v3"
"github.com/jcmturner/gokrb5/v8/config"
"go.uber.org/zap"
"math/rand"
"net/url"
"strings"
)
// Command-line configuration for the LDAP / Kerberos backend. Several empty
// defaults (-rootDN, -krb5-realm, -krb5-principal) are derived from -domain or
// the krb5 config by the startup.PostFlags hook registered in init.
var (
	ldapServerUrl = flag.String("ldap-url", "", "URL at which LDAP server can be reached")
	ldapRootDN    = flag.String("rootDN", "", "LDAP Root DN. Defaults to dc=host,dc=tld based on -domain")
	keytab        = flag.String("keytab", "ipasso.keytab", "Keytab file used to authenticate server")
	krb5Principal = flag.String("krb5-principal", "", "Default kerberos principal; default HTTP/sso.<domain>")
	krb5realm     = flag.String("krb5-realm", "", "Kerberos realm. Default based on krb5 config")
	krb5conf      = flag.String("krb5-conf", util.GetEnvDefault("KRB5_CONFIG", "/etc/krb5.conf"), "Config file for kerberos")
)

// Runtime state built from the flags during startup.
var (
	gssapiClient   *gssapi.Client          // GSSAPI (Kerberos) client
	ldapServerPool []ldapServerHost        // candidate servers consulted by selectServer
	ldapPool       genpool.Pool[ldap.Conn] // pool of bound LDAP connections

	// Not referenced by the code shown here; presumably used elsewhere in the package.
	ldapUserBase  string
	ldapServiceDn string
	ldapHostDn    string
)
// ErrNoValidServer is returned when no configured LDAP server has a positive
// selection weight, so no connection can be created.
var ErrNoValidServer = errors.New("no valid server")

// Loggers for the LDAP subsystem. They are nil until the startup.Logger hook
// registered in init assigns them.
var (
	ldapRootLogger *zap.Logger // named "ldap"
	ldapPoolLogger *zap.Logger // child of ldapRootLogger, named "pool"
)
// init registers the LDAP subsystem's hooks with the startup package:
//   - startup.Logger: create the "ldap" and "ldap.pool" loggers;
//   - startup.PostFlags: derive defaults from the flags and build the pool;
//   - startup.Startup: probe the pool once so misconfiguration is logged early.
func init() {
	startup.Logger.Add(func() {
		ldapRootLogger = zap.L().Named("ldap")
		ldapPoolLogger = ldapRootLogger.Named("pool")
	})
	startup.PostFlags.Add(ldapConfigureFromFlags)
	startup.Startup.Add(ldapStartupProbe)
}

// ldapConfigureFromFlags fills in flag defaults and creates the connection pool.
// Any unrecoverable configuration error is logged with Fatal.
func ldapConfigureFromFlags() {
	serverURL, err := url.Parse(*ldapServerUrl)
	if err != nil {
		ldapRootLogger.Fatal("Invalid LDAP server url", zap.String("url", *ldapServerUrl), zap.Error(err))
	}
	ldapServerPool = []ldapServerHost{{
		SPN:    "ldap/" + serverURL.Hostname(),
		Url:    *ldapServerUrl,
		Weight: 1,
	}}

	if *ldapRootDN == "" {
		*ldapRootDN = ldapRootDNFromDomain(*domain)
		ldapRootLogger.Debug("Configured LDAP rootDN", zap.String("rootDN", *ldapRootDN))
	}

	ldapResolveKrb5Realm()

	if *krb5Principal == "" {
		*krb5Principal = "HTTP/sso." + *domain
		ldapRootLogger.Debug("Configured local kerberos principal", zap.String("principal", *krb5Principal))
	}

	// Plain assignment (not :=) so the package-level gssapiClient is populated
	// instead of being shadowed by a local variable.
	gssapiClient, err = gssapi.NewClientWithKeytab(*krb5Principal, *krb5realm, *keytab, *krb5conf)
	if err != nil {
		ldapRootLogger.Fatal("Failed to initialize kerberos", zap.Error(err))
	}

	ldapPool = genpool.NewPool[ldap.Conn](&ldapPoolManager{gssapiClient: gssapiClient}, 5)
}

// ldapRootDNFromDomain maps "example.com" to "dc=example,dc=com".
func ldapRootDNFromDomain(domain string) string {
	labels := strings.Split(domain, ".")
	for i, label := range labels {
		labels[i] = "dc=" + label
	}
	return strings.Join(labels, ",")
}

// ldapResolveKrb5Realm fills in -krb5-realm when it is empty. It prefers the
// default_realm from the krb5 config file and falls back to the upper-cased
// -domain. An explicitly set realm is never overridden.
func ldapResolveKrb5Realm() {
	realmSource := ""
	if krb5Config, err := config.Load(*krb5conf); err != nil {
		ldapRootLogger.Warn("Failed to load config", zap.Error(err))
	} else if *krb5realm == "" {
		*krb5realm = krb5Config.LibDefaults.DefaultRealm
		realmSource = *krb5conf
	}
	if *krb5realm == "" {
		// Neither the flag nor the config file supplied a realm.
		*krb5realm = strings.ToUpper(*domain)
		realmSource = "domain"
	}
	if realmSource != "" {
		ldapRootLogger.Debug("Configured KRB5 realm", zap.String("source", realmSource), zap.String("realm", *krb5realm))
	}
}

// ldapStartupProbe borrows one pooled connection and issues WhoAmI. Failures
// are logged as warnings only; startup continues either way.
func ldapStartupProbe() {
	conn, err := ldapPool.Get()
	if err != nil {
		ldapPoolLogger.Warn("Failed to connect to LDAP server at startup", zap.Error(err))
		return
	}
	defer ldapPool.Put(conn)

	whoami, err := conn.WhoAmI(nil)
	if err != nil {
		// whoami is not usable here; stop before dereferencing it.
		ldapPoolLogger.Warn("Failed to call whoami at startup", zap.Error(err))
		return
	}
	ldapPoolLogger.Info("Successfully connected to LDAP", zap.String("authzId", whoami.AuthzID))
}
// ldapServerHost describes one candidate LDAP server for new pooled connections.
type ldapServerHost struct {
	// SPN is the service principal passed to GSSAPIBind; init builds it as
	// "ldap/" followed by the URL's hostname.
	SPN string
	// Url is the LDAP URL handed to ldap.DialURL.
	Url string
	// Weight is the relative selection weight used by selectServer; entries
	// with Weight <= 0 are never selected.
	Weight int
}

// ldapPoolManager is handed to genpool.NewPool and provides the pool's
// Create, Validate and Destroy methods. It binds new connections with gssapiClient.
type ldapPoolManager struct {
	// TODO: Fill this with results from a SRV request...
	gssapiClient *gssapi.Client
}
// selectServer picks one entry of ldapServerPool at random, with probability
// proportional to its Weight. It uses a single pass: weighted reservoir
// sampling with a reservoir of one. Entries with Weight <= 0 are skipped.
//
// It returns a pointer to a copy of the chosen entry, so callers cannot mutate
// the shared pool. It returns nil when no entry has a positive weight.
func selectServer() *ldapServerHost {
	candidates := ldapServerPool
	var (
		pick  *ldapServerHost
		total int
	)
	for i := range candidates {
		host := candidates[i] // fresh copy on every iteration
		if host.Weight > 0 {
			total += host.Weight
			// Replace the current pick with probability host.Weight / total.
			if rand.Intn(total) < host.Weight {
				pick = &host
			}
		}
	}
	if pick == nil && total > 0 {
		// Unreachable: the first positive-weight entry is always picked.
		ldapPoolLogger.DPanic("Failed to select a server when one was on offer")
	}
	return pick
}
// Destroy closes a connection the pool is discarding. Close errors are logged
// as warnings rather than returned.
func (l *ldapPoolManager) Destroy(conn *ldap.Conn) {
	if err := conn.Close(); err != nil {
		ldapPoolLogger.Warn("Failed to close LDAP connection", zap.Error(err))
		return
	}
	ldapPoolLogger.Debug("Closed ldap connection")
}
// Create opens a new connection for the pool. It picks a server with
// selectServer, dials its URL, and performs a GSSAPI bind against the server's
// SPN using the manager's Kerberos client.
//
// It returns ErrNoValidServer when no server is selectable. Otherwise it returns
// the dial or bind error. A connection whose bind fails is closed before
// returning, so a failed Create never leaks a socket.
func (l *ldapPoolManager) Create() (*ldap.Conn, error) {
	server := selectServer()
	if server == nil {
		return nil, ErrNoValidServer
	}
	conn, err := ldap.DialURL(server.Url)
	if err != nil {
		return nil, err
	}
	if err := conn.GSSAPIBind(l.gssapiClient, server.SPN, ""); err != nil {
		// The pool never sees this connection, so closing it is our job. The
		// bind error is the one worth reporting, so the Close error is dropped.
		_ = conn.Close()
		return nil, err
	}
	return conn, nil
}
// Validate reports whether a pooled connection still works. It issues a WhoAmI
// request, and any error marks the connection as unusable.
func (l *ldapPoolManager) Validate(conn *ldap.Conn) bool {
	if _, err := conn.WhoAmI([]ldap.Control{}); err != nil {
		return false
	}
	return true
}
// buildSessionCache derives the cached user details for session b from an LDAP
// entry. When entry is nil, the entry is first fetched by b.LdapDN; if that
// lookup fails, the zero cache (Valid == false) is returned with the error.
//
// Field mapping:
//   - displayName -> DisplayName (first value)
//   - givenName   -> GivenName (values joined with spaces)
//   - sn          -> SurName (values joined with spaces)
//   - mail        -> Email (first value)
//   - memberOf    -> Groups: the cn of each group whose parent DN is
//     cn=groups,cn=accounts,<rootDN>. Other memberOf values are ignored.
func buildSessionCache(b *backend.Session, entry *ldap.Entry) (cache backend.SessionCache, err error) {
	if entry == nil {
		entry, err = getUserByDn(b.LdapDN)
		if err != nil {
			return cache, err
		}
	}

	cache.Valid = true
	groupSuffix := "cn=groups,cn=accounts," + *ldapRootDN
	for _, attr := range entry.Attributes {
		switch {
		case attr.Name == "displayName" && len(attr.Values) > 0:
			cache.DisplayName = attr.Values[0]
		case attr.Name == "givenName" && len(attr.Values) > 0:
			cache.GivenName = strings.Join(attr.Values, " ")
		case attr.Name == "sn":
			cache.SurName = strings.Join(attr.Values, " ")
		case attr.Name == "mail" && len(attr.Values) > 0:
			cache.Email = attr.Values[0]
		case attr.Name == "memberOf":
			for _, groupDN := range attr.Values {
				// Cut instead of indexing a SplitN result, so a value without a
				// comma is skipped instead of panicking.
				rdn, parent, ok := strings.Cut(groupDN, ",")
				// dc/cn components match case-insensitively in LDAP, so a
				// -rootDN given in different case must still match.
				if !ok || !strings.EqualFold(parent, groupSuffix) {
					continue
				}
				if cn, found := strings.CutPrefix(rdn, "cn="); found {
					cache.Groups = append(cache.Groups, cn)
				} else {
					ldapRootLogger.Warn("Unexpected group name", zap.String("gdn", groupDN), zap.String("dn", entry.DN))
				}
			}
		}
	}
	return cache, nil
}
// interestingUserAttr is the attribute list requested by getUserByPrincipal and
// getUserByDn. buildSessionCache reads displayName, givenName, sn, mail and
// memberOf from the returned entries.
var interestingUserAttr = []string{
	"displayName",
	"givenName",
	"sn",
	"uid",
	"mail",
	"memberOf",
	"serverHostName",
}
// ldapSearchSingle runs req on a pooled connection and returns the sole result
// entry. Zero or multiple results yield ErrInvalidResult. The connection is
// returned to the pool when the function exits.
func ldapSearchSingle(req *ldap.SearchRequest) (*ldap.Entry, error) {
	conn, err := ldapPool.Get()
	if err != nil {
		return nil, err
	}
	defer ldapPool.Put(conn)

	result, err := conn.Search(req)
	if err != nil {
		ldapRootLogger.Warn("Failed LDAP search", zap.Any("req", req), zap.Error(err))
		return nil, err
	}

	switch len(result.Entries) {
	case 1:
		return result.Entries[0], nil
	case 0:
		ldapRootLogger.Info("No entries found for search", zap.String("filter", req.Filter))
		return nil, ErrInvalidResult
	default:
		ldapRootLogger.Info("Multiple entries found for search", zap.String("filter", req.Filter))
		return nil, ErrInvalidResult
	}
}
// getUserByPrincipal returns the single inetOrgPerson entry whose
// krbPrincipalName equals principal. It searches the whole subtree under the
// root DN.
//
// principal is escaped per RFC 4515 before being embedded in the filter.
// Otherwise metacharacters such as '*', '(' , ')' or '\' in the input could
// turn the equality match into a wildcard or change the filter's structure.
func getUserByPrincipal(principal string) (*ldap.Entry, error) {
	return ldapSearchSingle(&ldap.SearchRequest{
		Filter:     "(&(krbprincipalname=" + ldap.EscapeFilter(principal) + ")(objectClass=inetorgperson))",
		BaseDN:     *ldapRootDN,
		Scope:      ldap.ScopeWholeSubtree,
		Attributes: interestingUserAttr,
	})
}
// getUserByDn reads the entry located exactly at dn (base-object scope) and
// requests the attributes in interestingUserAttr.
//
// go-ldap compiles SearchRequest.Filter and rejects an empty string, so a
// match-anything presence filter is supplied. The base-object scope already
// limits the result to dn itself.
func getUserByDn(dn string) (*ldap.Entry, error) {
	return ldapSearchSingle(&ldap.SearchRequest{
		BaseDN:     dn,
		Scope:      ldap.ScopeBaseObject,
		Filter:     "(objectClass=*)",
		Attributes: interestingUserAttr,
	})
}