diff --git a/README.md b/README.md index 867aff8d..802ff76a 100644 --- a/README.md +++ b/README.md @@ -153,6 +153,14 @@ $ chisel server --help authfile with {"": [""]}. If unset, it will use the environment variable AUTH. + --authurl, An optional URL to an external authentication service. + On each connection attempt chisel POSTs {"username": "...", "password": "..."} + as JSON to this URL. A 200 response must return a JSON array of address + regexes (in the same format as the values in --authfile) to grant access; + any non-200 response denies access. Use [""] or ["*"] to allow all addresses. + Supports all address-matching functionality of --authfile. Cannot be + combined with --authfile or --auth. + --keepalive, An optional keepalive interval. Since the underlying transport is HTTP, in many instances we'll be traversing through proxies, often these proxies will close idle connections. You must diff --git a/main.go b/main.go index 7af1f45e..0cb6dced 100644 --- a/main.go +++ b/main.go @@ -138,6 +138,14 @@ var serverHelp = ` authfile with {"": [""]}. If unset, it will use the environment variable AUTH. + --authurl, An optional URL to an external authentication service. + On each connection attempt chisel POSTs {"username": "...", "password": "..."} + as JSON to this URL. A 200 response must return a JSON array of address + regexes (in the same format as the values in --authfile) to grant access; + any non-200 response denies access. Use [""] or ["*"] to allow all addresses. + Supports all address-matching functionality of --authfile. Cannot be + combined with --authfile or --auth. + --keepalive, An optional keepalive interval. Since the underlying transport is HTTP, in many instances we'll be traversing through proxies, often these proxies will close idle connections. You must @@ -185,6 +193,7 @@ func server(args []string) { flags.StringVar(&config.KeyFile, "keyfile", "", "") flags.StringVar(&config.AuthFile, "authfile", "", "") flags.StringVar(&config.Auth, "auth", "", "") + flags.StringVar(&config.AuthURL, "authurl", "", "") flags.DurationVar(&config.KeepAlive, "keepalive", 25*time.Second, "") flags.StringVar(&config.Proxy, "proxy", "", "") flags.StringVar(&config.Proxy, "backend", "", "") diff --git a/server/server.go b/server/server.go index 8a702fce..8de6a3ae 100644 --- a/server/server.go +++ b/server/server.go @@ -27,6 +27,7 @@ type Config struct { KeyFile string AuthFile string Auth string + AuthURL string Proxy string Socks5 bool Reverse bool @@ -45,6 +46,7 @@ type Server struct { sessions *settings.Users sshConfig *ssh.ServerConfig users *settings.UserIndex + urlUsers *settings.URLUserIndex } var upgrader = websocket.Upgrader{ @@ -62,6 +64,9 @@ func NewServer(c *Config) (*Server, error) { sessions: settings.NewUsers(), } server.Info = true + if c.AuthURL != "" && (c.AuthFile != "" || c.Auth != "") { + return nil, errors.New("--authurl cannot be combined with --authfile or --auth") + } server.users = settings.NewUserIndex(server.Logger) if c.AuthFile != "" { if err := server.users.LoadUsers(c.AuthFile); err != nil { @@ -75,6 +80,9 @@ func NewServer(c *Config) (*Server, error) { server.users.AddUser(u) } } + if c.AuthURL != "" { + server.urlUsers = settings.NewURLUserIndex(c.AuthURL, server.Logger) + } var pemBytes []byte var err error @@ -161,7 +169,7 @@ func (s *Server) Start(host, port string) error { // and can be closed by cancelling the provided context func (s *Server) StartContext(ctx context.Context, host, port string) error { s.Infof("Fingerprint %s", s.fingerprint) - if s.users.Len() > 0 { + if s.users.Len() > 0 || s.urlUsers != nil { s.Infof("User authentication enabled") } if s.reverseProxy != nil { @@ -198,15 +206,25 @@ func (s *Server) GetFingerprint() string { // authUser is responsible for validating the ssh user / password combination func (s *Server) authUser(c ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { // check if user authentication is enabled and if not, allow all - if s.users.Len() == 0 { + if s.users.Len() == 0 && s.urlUsers == nil { return nil, nil } - // check the user exists and has matching password n := c.User() + // URL-based auth: delegate credential check to external service + if s.urlUsers != nil { + user, err := s.urlUsers.GetUser(n, string(password)) + if err != nil { + s.Debugf("Login failed for user: %s", n) + return nil, errors.New("Invalid authentication for username: " + n) + } + s.sessions.Set(string(c.SessionID()), user) + return nil, nil + } + // file/inline user auth user, found := s.users.Get(n) if !found || user.Pass != string(password) { s.Debugf("Login failed for user: %s", n) - return nil, errors.New("Invalid authentication for username: %s") + return nil, errors.New("Invalid authentication for username: " + n) } // insert the user session map // TODO this should probably have a lock on it given the map isn't thread-safe diff --git a/server/server_handler.go b/server/server_handler.go index 8b5a68fd..da8b86a0 100644 --- a/server/server_handler.go +++ b/server/server_handler.go @@ -66,7 +66,7 @@ func (s *Server) handleWebsocket(w http.ResponseWriter, req *http.Request) { } // pull the users from the session map var user *settings.User - if s.users.Len() > 0 { + if s.users.Len() > 0 || s.urlUsers != nil { sid := string(sshConn.SessionID()) u, ok := s.sessions.Get(sid) if !ok { diff --git a/share/settings/user_url.go b/share/settings/user_url.go new file mode 100644 index 00000000..69e9ebd6 --- /dev/null +++ b/share/settings/user_url.go @@ -0,0 +1,75 @@ +package settings + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "regexp" + + "github.com/jpillora/chisel/share/cio" +) + +// URLUserIndex authenticates users against an external HTTP endpoint. +// On every login attempt it POSTs {"username": "...", "password": "..."} +// to the configured URL. A 200 response must contain a JSON array of address +// regexes (matching the values format of --authfile) to grant access; any +// other status code denies access. +type URLUserIndex struct { + *cio.Logger + url string + httpClient *http.Client +} + +// NewURLUserIndex creates a URLUserIndex that will POST credentials to authURL. +func NewURLUserIndex(authURL string, logger *cio.Logger) *URLUserIndex { + return &URLUserIndex{ + Logger: logger.Fork("url-users"), + url: authURL, + httpClient: &http.Client{}, + } +} + +// GetUser authenticates a user against the external URL and returns the +// resolved User (with compiled address regexes) on success, or an error on +// failure. An empty string or "*" in the address list grants full access +// (equivalent to UserAllowAll). +func (u *URLUserIndex) GetUser(name, pass string) (*User, error) { + body, err := json.Marshal(struct { + Username string `json:"username"` + Password string `json:"password"` + }{Username: name, Password: pass}) + if err != nil { + return nil, err + } + resp, err := u.httpClient.Post(u.url, "application/json", bytes.NewReader(body)) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("auth denied (status %d)", resp.StatusCode) + } + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + var addrStrs []string + if err := json.Unmarshal(raw, &addrStrs); err != nil { + return nil, fmt.Errorf("invalid JSON in auth response: %w", err) + } + addrs := make([]*regexp.Regexp, 0, len(addrStrs)) + for _, s := range addrStrs { + if s == "" || s == "*" { + addrs = append(addrs, UserAllowAll) + } else { + re, err := regexp.Compile(s) + if err != nil { + return nil, fmt.Errorf("invalid address regex %q: %w", s, err) + } + addrs = append(addrs, re) + } + } + return &User{Name: name, Addrs: addrs}, nil +} diff --git a/share/settings/user_url_test.go b/share/settings/user_url_test.go new file mode 100644 index 00000000..8cb53602 --- /dev/null +++ b/share/settings/user_url_test.go @@ -0,0 +1,137 @@ +package settings + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/jpillora/chisel/share/cio" +) + +func newTestURLUserIndex(t *testing.T, handler http.HandlerFunc) *URLUserIndex { + t.Helper() + srv := httptest.NewServer(handler) + t.Cleanup(srv.Close) + return NewURLUserIndex(srv.URL, cio.NewLogger("test")) +} + +// assertPostJSON verifies the request is a JSON POST. +func assertPostJSON(t *testing.T, r *http.Request) { + t.Helper() + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + } + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("expected Content-Type application/json, got %q", ct) + } +} + +func TestURLUserIndex_200WithAddresses(t *testing.T) { + idx := newTestURLUserIndex(t, func(w http.ResponseWriter, r *http.Request) { + assertPostJSON(t, r) + json.NewEncoder(w).Encode([]string{`^127\.0\.0\.1:\d+$`, `^10\.`}) + }) + user, err := idx.GetUser("alice", "secret") + if err != nil { + t.Fatal(err) + } + if user.Name != "alice" { + t.Fatalf("expected name alice, got %s", user.Name) + } + if len(user.Addrs) != 2 { + t.Fatalf("expected 2 addrs, got %d", len(user.Addrs)) + } + if !user.HasAccess("127.0.0.1:8080") { + t.Fatal("expected access to 127.0.0.1:8080") + } + if user.HasAccess("1.2.3.4:8080") { + t.Fatal("expected no access to 1.2.3.4:8080") + } +} + +func TestURLUserIndex_200AllowAll(t *testing.T) { + idx := newTestURLUserIndex(t, func(w http.ResponseWriter, r *http.Request) { + assertPostJSON(t, r) + json.NewEncoder(w).Encode([]string{""}) + }) + user, err := idx.GetUser("bob", "pass") + if err != nil { + t.Fatal(err) + } + if !user.HasAccess("anything:1234") { + t.Fatal("expected allow-all access") + } +} + +func TestURLUserIndex_200EmptyAddrs(t *testing.T) { + idx := newTestURLUserIndex(t, func(w http.ResponseWriter, r *http.Request) { + assertPostJSON(t, r) + json.NewEncoder(w).Encode([]string{}) + }) + user, err := idx.GetUser("carol", "pass") + if err != nil { + t.Fatal(err) + } + if user == nil { + t.Fatal("expected non-nil user") + } + if user.HasAccess("127.0.0.1:9000") { + t.Fatal("expected no access with empty addr list") + } +} + +func TestURLUserIndex_NonOKDenied(t *testing.T) { + idx := newTestURLUserIndex(t, func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "forbidden", http.StatusForbidden) + }) + user, err := idx.GetUser("eve", "wrong") + if err == nil { + t.Fatal("expected error for non-200 response") + } + if user != nil { + t.Fatal("expected nil user on denial") + } +} + +func TestURLUserIndex_InvalidJSON(t *testing.T) { + idx := newTestURLUserIndex(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("not-json")) + }) + user, err := idx.GetUser("frank", "pass") + if err == nil { + t.Fatal("expected error for invalid JSON body") + } + if user != nil { + t.Fatal("expected nil user on parse error") + } +} + +func TestURLUserIndex_RequestFormat(t *testing.T) { + var gotBody []byte + var gotContentType string + idx := newTestURLUserIndex(t, func(w http.ResponseWriter, r *http.Request) { + gotContentType = r.Header.Get("Content-Type") + gotBody, _ = io.ReadAll(r.Body) + json.NewEncoder(w).Encode([]string{""}) + }) + _, err := idx.GetUser("grace", "hunter2") + if err != nil { + t.Fatal(err) + } + if gotContentType != "application/json" { + t.Fatalf("expected Content-Type application/json, got %q", gotContentType) + } + var payload struct { + Username string `json:"username"` + Password string `json:"password"` + } + if err := json.Unmarshal(gotBody, &payload); err != nil { + t.Fatalf("could not parse request body: %v", err) + } + if payload.Username != "grace" || payload.Password != "hunter2" { + t.Fatalf("unexpected payload: %+v", payload) + } +} diff --git a/test/e2e/auth_test.go b/test/e2e/auth_test.go index cd758c5d..08906770 100644 --- a/test/e2e/auth_test.go +++ b/test/e2e/auth_test.go @@ -1,10 +1,19 @@ package e2e_test import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" "testing" + "time" + "github.com/gorilla/websocket" chclient "github.com/jpillora/chisel/client" chserver "github.com/jpillora/chisel/server" + "github.com/jpillora/chisel/share/cnet" + "github.com/jpillora/chisel/share/settings" + "golang.org/x/crypto/ssh" ) //TODO tests for: @@ -46,3 +55,160 @@ func TestAuth(t *testing.T) { t.Fatalf("expected exclamation mark added again") } } + +// TestAuthURL verifies that a chisel server configured with --authurl +// delegates authentication to an external HTTP service. +func TestAuthURL(t *testing.T) { + // mock auth backend: accepts alice/secret with full access + authSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var creds struct { + Username string `json:"username"` + Password string `json:"password"` + } + if err := json.NewDecoder(r.Body).Decode(&creds); err != nil { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + if creds.Username == "alice" && creds.Password == "secret" { + json.NewEncoder(w).Encode([]string{""}) // allow all + return + } + http.Error(w, "unauthorized", http.StatusUnauthorized) + })) + defer authSrv.Close() + + tmpPort := availablePort() + teardown := simpleSetup(t, + &chserver.Config{ + KeySeed: "authurl-test", + AuthURL: authSrv.URL, + }, + &chclient.Config{ + Remotes: []string{"0.0.0.0:" + tmpPort + ":127.0.0.1:$FILEPORT"}, + Auth: "alice:secret", + }) + defer teardown() + + result, err := post("http://localhost:"+tmpPort, "hello") + if err != nil { + t.Fatal(err) + } + if result != "hello!" { + t.Fatalf("expected 'hello!', got %q", result) + } +} + +// TestAuthURLDenied verifies that a client with wrong credentials is rejected +// when the server uses --authurl. +func TestAuthURLDenied(t *testing.T) { + // mock auth backend that always denies + authSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + })) + defer authSrv.Close() + + s, err := chserver.NewServer(&chserver.Config{ + KeySeed: "authurl-deny-test", + AuthURL: authSrv.URL, + }) + if err != nil { + t.Fatal(err) + } + port := availablePort() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := s.StartContext(ctx, "127.0.0.1", port); err != nil { + t.Fatal(err) + } + time.Sleep(20 * time.Millisecond) + + // dial directly at the SSH level; authentication must be rejected + ws, _, err := (&websocket.Dialer{ + HandshakeTimeout: 5 * time.Second, + Subprotocols: []string{"chisel-v3"}, + }).Dial("ws://127.0.0.1:"+port, http.Header{}) + if err != nil { + t.Fatalf("websocket dial: %v", err) + } + conn := cnet.NewWebSocketConn(ws) + _, _, _, sshErr := ssh.NewClientConn(conn, "", &ssh.ClientConfig{ + User: "baduser", + Auth: []ssh.AuthMethod{ssh.Password("wrongpass")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + if sshErr == nil { + t.Fatal("expected SSH auth to fail with bad credentials, but it succeeded") + } +} + +// When the auth URL returns a restrictive address list, the server must enforce it. +func TestAuthURLAddressRestriction(t *testing.T) { + // Auth server returns a regex that will never match any real remote address. + authSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var creds struct { + Username string `json:"username"` + Password string `json:"password"` + } + if err := json.NewDecoder(r.Body).Decode(&creds); err != nil { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + if creds.Username == "alice" && creds.Password == "secret" { + // Return a regex that matches nothing a real remote would look like. + json.NewEncoder(w).Encode([]string{"^NOMATCH$"}) + return + } + http.Error(w, "unauthorized", http.StatusUnauthorized) + })) + defer authSrv.Close() + + s, err := chserver.NewServer(&chserver.Config{ + KeySeed: "addr-restriction-test", + AuthURL: authSrv.URL, + }) + if err != nil { + t.Fatal(err) + } + port := availablePort() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := s.StartContext(ctx, "127.0.0.1", port); err != nil { + t.Fatal(err) + } + time.Sleep(20 * time.Millisecond) + + // Dial at the SSH level with valid credentials; the config request for a + // real remote should be rejected because "^NOMATCH$" does not match it. + ws, _, wsErr := (&websocket.Dialer{ + HandshakeTimeout: 5 * time.Second, + Subprotocols: []string{"chisel-v3"}, + }).Dial("ws://127.0.0.1:"+port, http.Header{}) + if wsErr != nil { + t.Fatalf("websocket dial: %v", wsErr) + } + conn := cnet.NewWebSocketConn(ws) + sc, _, reqs, sshErr := ssh.NewClientConn(conn, "", &ssh.ClientConfig{ + User: "alice", + Auth: []ssh.AuthMethod{ssh.Password("secret")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + if sshErr != nil { + t.Fatalf("SSH auth should succeed (valid credentials): %v", sshErr) + } + go ssh.DiscardRequests(reqs) + + // Send a config requesting a tunnel to an address that won't match "^NOMATCH$". + targetPort := availablePort() + remotes := []*settings.Remote{{ + RemoteHost: "127.0.0.1", + RemotePort: targetPort, + }} + cfg, _ := json.Marshal(settings.Config{Version: "0", Remotes: remotes}) + ok, reply, err := sc.SendRequest("config", true, cfg) + if err != nil { + t.Fatalf("config request error: %v", err) + } + if ok { + t.Fatalf("expected config to be rejected due to address restriction, but it was accepted (reply: %s)", reply) + } +}