Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions internal/controller/user_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ func (controller *UserController) loginHandler(c *gin.Context) {

if search.Type == model.UserLDAP {
sessionCookie.Provider = "ldap"
if search.Email != "" {
sessionCookie.Email = search.Email
}
}

tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
Expand Down
10 changes: 9 additions & 1 deletion internal/middleware/context_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,11 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model

userContext.LDAP.Groups = user.Groups
userContext.LDAP.Name = utils.Capitalize(userContext.LDAP.Username)

userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.config.CookieDomain)
if search.Email != "" {
userContext.LDAP.Email = search.Email
}
case model.ProviderOAuth:
_, exists := m.broker.GetService(userContext.OAuth.ID)

Expand Down Expand Up @@ -240,11 +244,15 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model.
BaseContext: model.BaseContext{
Username: username,
Name: utils.Capitalize(username),
Email: utils.CompileUserEmail(username, m.config.CookieDomain),
},
Groups: user.Groups,
}
userContext.Provider = model.ProviderLDAP

userContext.LDAP.Email = utils.CompileUserEmail(username, m.config.CookieDomain)
if search.Email != "" {
userContext.LDAP.Email = search.Email
}
}

userContext.Authenticated = true
Expand Down
1 change: 1 addition & 0 deletions internal/model/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ type LocalUser struct {

type UserSearch struct {
Username string
Email string // used for LDAP, we can't throw it to LDAPUser because it would need another cache or an LDAP lookup every time
Type UserSearchType
}
3 changes: 2 additions & 1 deletion internal/service/auth_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,15 @@ func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error)
}

if auth.ldap.IsConfigured() {
userDN, err := auth.ldap.GetUserDN(username)
userDN, email, err := auth.ldap.GetUserInfo(username)

if err != nil {
return nil, fmt.Errorf("failed to get ldap user: %w", err)
}

return &model.UserSearch{
Username: userDN,
Email: email,
Type: model.UserLDAP,
}, nil
}
Expand Down
13 changes: 6 additions & 7 deletions internal/service/ldap_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,15 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
return ldap.conn, nil
}

func (ldap *LdapService) GetUserDN(username string) (string, error) {
// Escape the username to prevent LDAP injection
func (ldap *LdapService) GetUserInfo(username string) (dn string, email string, err error) {
escapedUsername := ldapgo.EscapeFilter(username)
filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername)

searchRequest := ldapgo.NewSearchRequest(
ldap.config.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
filter,
[]string{"dn"},
[]string{"dn", "mail"},
nil,
)

Expand All @@ -161,15 +160,15 @@ func (ldap *LdapService) GetUserDN(username string) (string, error) {

searchResult, err := ldap.conn.Search(searchRequest)
if err != nil {
return "", err
return "", "", err
}

if len(searchResult.Entries) != 1 {
return "", fmt.Errorf("multiple or no entries found for user %s", username)
return "", "", fmt.Errorf("multiple or no entries found for user %s", username)
}

userDN := searchResult.Entries[0].DN
return userDN, nil
entry := searchResult.Entries[0]
return entry.DN, entry.GetAttributeValue("mail"), nil
}

func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
Expand Down