diff --git a/pkg/server/tokenrequest/tokenrequest.go b/pkg/server/tokenrequest/tokenrequest.go index 7ed86387d..e45e92cff 100644 --- a/pkg/server/tokenrequest/tokenrequest.go +++ b/pkg/server/tokenrequest/tokenrequest.go @@ -8,6 +8,7 @@ import ( "net/http" "net/url" "path" + "strings" "github.com/openshift/osincli" @@ -144,9 +145,23 @@ func (t *tokenRequest) displayTokenPost(osinOAuthClient *osincli.Client, w http. } data.AccessToken = accessData.AccessToken + data.OcLoginCommand = formatOcLoginCommand(data.AccessToken, data.PublicMasterURL) + data.CurlCommand = formatCurlCommand(data.AccessToken, data.PublicMasterURL) renderToken(w, data) } +func formatOcLoginCommand(accessToken, publicMasterURL string) string { + return fmt.Sprintf("oc login --token=%s --server=%s", accessToken, publicMasterURL) +} + +func formatCurlCommand(accessToken, publicMasterURL string) string { + return fmt.Sprintf( + `curl -H "Authorization: Bearer %s" "%s/apis/user.openshift.io/v1/users/~"`, + accessToken, + strings.TrimRight(publicMasterURL, "/"), + ) +} + func displayTokenStart(osinOAuthClient *osincli.Client, w http.ResponseWriter, req *http.Request, data *sharedData) (*osincli.AuthorizeData, bool) { w.Header().Set("Content-Type", "text/html; charset=UTF-8") @@ -179,6 +194,8 @@ type tokenData struct { sharedData AccessToken string + OcLoginCommand string + CurlCommand string PublicMasterURL string LogoutURL string } @@ -216,27 +233,86 @@ const cssStyle = ` pre { padding-left: 1em; border-radius: 5px; color: #003d6e; background-color: #EAEDF0; padding: 1.5em 0 1.5em 4.5em; white-space: normal; text-indent: -2em; } a { color: #00f; text-decoration: none; } a:hover { text-decoration: underline; } - button { background: none; border: none; color: #00f; text-decoration: none; font: inherit; padding: 0; } - button:hover { text-decoration: underline; cursor: pointer; } + button, .copy-button { background: none; border: none; color: #00f; text-decoration: none; font: inherit; padding: 0; } + button:hover, .copy-button:hover { text-decoration: underline; cursor: pointer; } + .copy-heading { display: inline; } + .copy-heading .copy-button { font-size: 0.85em; margin-left: 0.75em; } @media (min-width: 768px) { .nowrap { white-space: nowrap; } } ` +const copyToClipboardScript = ` + +` + var tokenTemplate = template.Must(template.New("tokenTemplate").Parse( cssStyle + ` {{ if .Error }} {{ .Error }} {{ else }} -

Your API token is

+

Your API token is

{{.AccessToken}} -

Log in with this token

+

Log in with this token

oc login --token={{.AccessToken}} --server={{.PublicMasterURL}}
-

Use this token directly against the API

+

Use this token directly against the API

curl -H "Authorization: Bearer {{.AccessToken}}" "{{.PublicMasterURL}}/apis/user.openshift.io/v1/users/~"
+ ` + copyToClipboardScript + ` {{ end }}

diff --git a/pkg/server/tokenrequest/tokenrequest_test.go b/pkg/server/tokenrequest/tokenrequest_test.go new file mode 100644 index 000000000..3584ab6fc --- /dev/null +++ b/pkg/server/tokenrequest/tokenrequest_test.go @@ -0,0 +1,182 @@ +package tokenrequest + +import ( + "bytes" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/openshift/oauth-server/pkg/server/csrf" +) + +func TestFormatOcLoginCommand(t *testing.T) { + token := "sha256~abc123" + server := "https://api.example.com:6443" + + got := formatOcLoginCommand(token, server) + want := "oc login --token=sha256~abc123 --server=https://api.example.com:6443" + if got != want { + t.Fatalf("formatOcLoginCommand() = %q, want %q", got, want) + } +} + +func TestFormatCurlCommand(t *testing.T) { + token := "sha256~abc123" + server := "https://api.example.com:6443/" + + got := formatCurlCommand(token, server) + want := `curl -H "Authorization: Bearer sha256~abc123" "https://api.example.com:6443/apis/user.openshift.io/v1/users/~"` + if got != want { + t.Fatalf("formatCurlCommand() = %q, want %q", got, want) + } +} + +func TestRenderTokenIncludesCopyToClipboardControls(t *testing.T) { + data := tokenData{ + sharedData: sharedData{ + RequestURL: "/oauth/token/request", + }, + AccessToken: "sha256~token-value", + OcLoginCommand: formatOcLoginCommand("sha256~token-value", "https://api.example.com:6443"), + CurlCommand: formatCurlCommand("sha256~token-value", "https://api.example.com:6443"), + PublicMasterURL: "https://api.example.com:6443", + } + + var buf bytes.Buffer + renderToken(&buf, data) + body := buf.String() + + expectContains(t, body, []string{ + "Your API token is", + "Log in with this token", + "Use this token directly against the API", + `class="copy-button"`, + `aria-label="Copy to clipboard"`, + `data-copy-text="sha256~token-value"`, + `data-copy-text="oc login --token=sha256~token-value --server=https://api.example.com:6443"`, + `data-copy-text="curl -H "Authorization: Bearer sha256~token-value" "https://api.example.com:6443/apis/user.openshift.io/v1/users/~""`, + "navigator.clipboard", + "document.execCommand('copy')", + `Request another token`, + }) + + if strings.Count(body, `class="copy-button"`) != 3 { + t.Fatalf("expected 3 copy buttons, got %d in:\n%s", strings.Count(body, `class="copy-button"`), body) + } +} + +func TestRenderTokenEscapesSpecialCharactersInCopyAttributes(t *testing.T) { + token := `sha256~test"onclick='alert(1)'` + server := "https://api.example.com:6443" + + data := tokenData{ + sharedData: sharedData{ + RequestURL: "/oauth/token/request", + }, + AccessToken: token, + OcLoginCommand: formatOcLoginCommand(token, server), + CurlCommand: formatCurlCommand(token, server), + PublicMasterURL: server, + } + + var buf bytes.Buffer + renderToken(&buf, data) + body := buf.String() + + if strings.Contains(body, `data-copy-text="sha256~test"onclick`) { + t.Fatalf("token copy attribute was not HTML-escaped:\n%s", body) + } + if strings.Contains(body, "") { + t.Fatalf("unexpected unescaped script content in output:\n%s", body) + } + expectContains(t, body, []string{ + `data-copy-text="sha256~test"onclick='alert(1)'"`, + }) +} + +func TestRenderTokenErrorStateOmitsCopyControls(t *testing.T) { + data := tokenData{ + sharedData: sharedData{ + Error: "Error checking token", + RequestURL: "/oauth/token/request", + }, + } + + var buf bytes.Buffer + renderToken(&buf, data) + body := buf.String() + + expectContains(t, body, []string{ + "Error checking token", + `Request another token`, + }) + if strings.Contains(body, `class="copy-button"`) { + t.Fatalf("error state should not render copy buttons:\n%s", body) + } +} + +func TestRenderFormDisplayTokenStepUnchanged(t *testing.T) { + data := formData{ + sharedData: sharedData{ + RequestURL: "/oauth/token/request", + }, + Action: "https://oauth.example.com/oauth/token/display", + Code: "auth-code", + CSRF: "csrf-token", + } + + var buf bytes.Buffer + renderForm(&buf, data) + body := buf.String() + + expectContains(t, body, []string{ + `action="https://oauth.example.com/oauth/token/display"`, + `name="code" value="auth-code"`, + `name="csrf" value="csrf-token"`, + "Display Token", + }) + if strings.Contains(body, `class="copy-button"`) { + t.Fatalf("display token form should not include copy buttons:\n%s", body) + } +} + +func TestDisplayTokenPostRejectsInvalidCSRF(t *testing.T) { + handler := &tokenRequest{ + csrf: &csrf.FakeCSRF{Token: "expected-csrf"}, + } + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/oauth/token/display", strings.NewReader("csrf=wrong")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + handler.displayTokenPost(nil, rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d: %s", http.StatusBadRequest, rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), "Could not check CSRF token") { + t.Fatalf("unexpected body: %s", rec.Body.String()) + } +} + +func TestDisplayTokenRejectsUnsupportedMethod(t *testing.T) { + handler := &tokenRequest{} + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, "/oauth/token/display", nil) + handler.displayToken(nil, rec, req) + + if rec.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected status %d, got %d: %s", http.StatusMethodNotAllowed, rec.Code, rec.Body.String()) + } +} + +func expectContains(t *testing.T, body string, expected []string) { + t.Helper() + for _, fragment := range expected { + if !strings.Contains(body, fragment) { + t.Fatalf("expected body to contain %q, got:\n%s", fragment, body) + } + } +}