diff --git a/example/src/components/sign-in-button.tsx b/example/src/components/sign-in-button.tsx index ba03fa0..6dd051f 100644 --- a/example/src/components/sign-in-button.tsx +++ b/example/src/components/sign-in-button.tsx @@ -2,7 +2,7 @@ import { Button, Flex } from '@radix-ui/themes'; import { Link } from '@tanstack/react-router'; import type { User } from '@workos/authkit-tanstack-react-start'; -export default function SignInButton({ large, user, url }: { large?: boolean; user: User | null; url: string }) { +export default function SignInButton({ large, user }: { large?: boolean; user: User | null }) { if (user) { return ( @@ -17,7 +17,7 @@ export default function SignInButton({ large, user, url }: { large?: boolean; us return ( ); } diff --git a/example/src/routeTree.gen.ts b/example/src/routeTree.gen.ts index 49009fc..0f8c78d 100644 --- a/example/src/routeTree.gen.ts +++ b/example/src/routeTree.gen.ts @@ -14,6 +14,7 @@ import { Route as ClientRouteImport } from './routes/client' import { Route as AuthenticatedRouteImport } from './routes/_authenticated' import { Route as IndexRouteImport } from './routes/index' import { Route as AuthenticatedAccountRouteImport } from './routes/_authenticated/account' +import { Route as ApiAuthSignInRouteImport } from './routes/api/auth/sign-in' import { Route as ApiAuthCallbackRouteImport } from './routes/api/auth/callback' const LogoutRoute = LogoutRouteImport.update({ @@ -40,6 +41,11 @@ const AuthenticatedAccountRoute = AuthenticatedAccountRouteImport.update({ path: '/account', getParentRoute: () => AuthenticatedRoute, } as any) +const ApiAuthSignInRoute = ApiAuthSignInRouteImport.update({ + id: '/api/auth/sign-in', + path: '/api/auth/sign-in', + getParentRoute: () => rootRouteImport, +} as any) const ApiAuthCallbackRoute = ApiAuthCallbackRouteImport.update({ id: '/api/auth/callback', path: '/api/auth/callback', @@ -52,6 +58,7 @@ export interface FileRoutesByFullPath { '/logout': typeof LogoutRoute '/account': typeof AuthenticatedAccountRoute '/api/auth/callback': typeof ApiAuthCallbackRoute + '/api/auth/sign-in': typeof ApiAuthSignInRoute } export interface FileRoutesByTo { '/': typeof IndexRoute @@ -59,6 +66,7 @@ export interface FileRoutesByTo { '/logout': typeof LogoutRoute '/account': typeof AuthenticatedAccountRoute '/api/auth/callback': typeof ApiAuthCallbackRoute + '/api/auth/sign-in': typeof ApiAuthSignInRoute } export interface FileRoutesById { __root__: typeof rootRouteImport @@ -68,12 +76,25 @@ export interface FileRoutesById { '/logout': typeof LogoutRoute '/_authenticated/account': typeof AuthenticatedAccountRoute '/api/auth/callback': typeof ApiAuthCallbackRoute + '/api/auth/sign-in': typeof ApiAuthSignInRoute } export interface FileRouteTypes { fileRoutesByFullPath: FileRoutesByFullPath - fullPaths: '/' | '/client' | '/logout' | '/account' | '/api/auth/callback' + fullPaths: + | '/' + | '/client' + | '/logout' + | '/account' + | '/api/auth/callback' + | '/api/auth/sign-in' fileRoutesByTo: FileRoutesByTo - to: '/' | '/client' | '/logout' | '/account' | '/api/auth/callback' + to: + | '/' + | '/client' + | '/logout' + | '/account' + | '/api/auth/callback' + | '/api/auth/sign-in' id: | '__root__' | '/' @@ -82,6 +103,7 @@ export interface FileRouteTypes { | '/logout' | '/_authenticated/account' | '/api/auth/callback' + | '/api/auth/sign-in' fileRoutesById: FileRoutesById } export interface RootRouteChildren { @@ -90,6 +112,7 @@ export interface RootRouteChildren { ClientRoute: typeof ClientRoute LogoutRoute: typeof LogoutRoute ApiAuthCallbackRoute: typeof ApiAuthCallbackRoute + ApiAuthSignInRoute: typeof ApiAuthSignInRoute } declare module '@tanstack/react-router' { @@ -129,6 +152,13 @@ declare module '@tanstack/react-router' { preLoaderRoute: typeof AuthenticatedAccountRouteImport parentRoute: typeof AuthenticatedRoute } + '/api/auth/sign-in': { + id: '/api/auth/sign-in' + path: '/api/auth/sign-in' + fullPath: '/api/auth/sign-in' + preLoaderRoute: typeof ApiAuthSignInRouteImport + parentRoute: typeof rootRouteImport + } '/api/auth/callback': { id: '/api/auth/callback' path: '/api/auth/callback' @@ -157,6 +187,7 @@ const rootRouteChildren: RootRouteChildren = { ClientRoute: ClientRoute, LogoutRoute: LogoutRoute, ApiAuthCallbackRoute: ApiAuthCallbackRoute, + ApiAuthSignInRoute: ApiAuthSignInRoute, } export const routeTree = rootRouteImport ._addFileChildren(rootRouteChildren) diff --git a/example/src/routes/__root.tsx b/example/src/routes/__root.tsx index 97a960d..3dcb1fc 100644 --- a/example/src/routes/__root.tsx +++ b/example/src/routes/__root.tsx @@ -3,7 +3,6 @@ import { HeadContent, Link, Outlet, Scripts, createRootRoute } from '@tanstack/r import appCssUrl from '../app.css?url'; import { TanStackRouterDevtools } from '@tanstack/react-router-devtools'; import { Suspense } from 'react'; -import { getSignInUrl } from '@workos/authkit-tanstack-react-start'; import { AuthKitProvider, Impersonation, getAuthAction } from '@workos/authkit-tanstack-react-start/client'; import Footer from '../components/footer'; import SignInButton from '../components/sign-in-button'; @@ -29,18 +28,14 @@ export const Route = createRootRoute({ // getAuthAction() returns auth state without accessToken, safe for client // Pass to AuthKitProvider as initialAuth to avoid loading flicker const auth = await getAuthAction(); - const url = await getSignInUrl(); - return { - auth, - url, - }; + return { auth }; }, component: RootComponent, notFoundComponent: () =>
Not Found
, }); function RootComponent() { - const { auth, url } = Route.useLoaderData(); + const { auth } = Route.useLoaderData(); return ( @@ -67,7 +62,7 @@ function RootComponent() {
Loading...}> - + diff --git a/example/src/routes/_authenticated.tsx b/example/src/routes/_authenticated.tsx index 10269ba..c15cffa 100644 --- a/example/src/routes/_authenticated.tsx +++ b/example/src/routes/_authenticated.tsx @@ -1,14 +1,13 @@ import { redirect, createFileRoute } from '@tanstack/react-router'; -import { getAuth, getSignInUrl } from '@workos/authkit-tanstack-react-start'; +import { getAuth } from '@workos/authkit-tanstack-react-start'; export const Route = createFileRoute('/_authenticated')({ loader: async ({ location }) => { // Loader runs on server (even during client-side navigation via RPC) const { user } = await getAuth(); if (!user) { - const path = location.pathname; - const href = await getSignInUrl({ data: { returnPathname: path } }); - throw redirect({ href }); + const returnPathname = encodeURIComponent(location.pathname); + throw redirect({ href: `/api/auth/sign-in?returnPathname=${returnPathname}` }); } }, }); diff --git a/example/src/routes/api/auth/sign-in.tsx b/example/src/routes/api/auth/sign-in.tsx new file mode 100644 index 0000000..18c89da --- /dev/null +++ b/example/src/routes/api/auth/sign-in.tsx @@ -0,0 +1,17 @@ +import { createFileRoute } from '@tanstack/react-router'; +import { getSignInUrl } from '@workos/authkit-tanstack-react-start'; + +export const Route = createFileRoute('/api/auth/sign-in')({ + server: { + handlers: { + GET: async ({ request }: { request: Request }) => { + const returnPathname = new URL(request.url).searchParams.get('returnPathname'); + const url = await getSignInUrl(returnPathname ? { data: { returnPathname } } : undefined); + return new Response(null, { + status: 307, + headers: { Location: url }, + }); + }, + }, + }, +}); diff --git a/example/src/routes/index.tsx b/example/src/routes/index.tsx index c69922a..fd4d230 100644 --- a/example/src/routes/index.tsx +++ b/example/src/routes/index.tsx @@ -1,19 +1,18 @@ import { Button, Flex, Heading, Text } from '@radix-ui/themes'; import { Link, createFileRoute } from '@tanstack/react-router'; -import { getAuth, getSignInUrl } from '@workos/authkit-tanstack-react-start'; +import { getAuth } from '@workos/authkit-tanstack-react-start'; import SignInButton from '../components/sign-in-button'; export const Route = createFileRoute('/')({ component: Home, loader: async () => { const { user } = await getAuth(); - const url = await getSignInUrl(); - return { user, url }; + return { user }; }, }); function Home() { - const { user, url } = Route.useLoaderData(); + const { user } = Route.useLoaderData(); return ( @@ -27,7 +26,7 @@ function Home() { - + ) : ( @@ -37,7 +36,7 @@ function Home() { Sign in to view your account details - + )} diff --git a/package.json b/package.json index 0ad6558..a1b52c4 100644 --- a/package.json +++ b/package.json @@ -64,7 +64,7 @@ "url": "https://github.com/workos/authkit-tanstack-start/issues" }, "dependencies": { - "@workos/authkit-session": "0.3.4" + "@workos/authkit-session": "0.4.0" }, "peerDependencies": { "@tanstack/react-router": ">=1.0.0", @@ -97,6 +97,9 @@ "onlyBuiltDependencies": [ "@parcel/watcher", "esbuild" - ] + ], + "overrides": { + "@workos/authkit-session": "link:../authkit-session" + } } } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 9eaa994..79dd2df 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -4,13 +4,16 @@ settings: autoInstallPeers: true excludeLinksFromLockfile: false +overrides: + '@workos/authkit-session': link:../authkit-session + importers: .: dependencies: '@workos/authkit-session': - specifier: 0.3.4 - version: 0.3.4 + specifier: link:../authkit-session + version: link:../authkit-session devDependencies: '@tanstack/react-router': specifier: ^1.154.8 @@ -2026,14 +2029,6 @@ packages: '@vitest/utils@4.0.15': resolution: {integrity: sha512-HXjPW2w5dxhTD0dLwtYHDnelK3j8sR8cWIaLxr22evTyY6q8pRCjZSmhRWVjBaOVXChQd6AwMzi9pucorXCPZA==} - '@workos-inc/node@8.0.0': - resolution: {integrity: sha512-D8VDfx0GXeiVm8vccAl0rElW7taebRnrteKPJzZwehwzI9W/Usa4qKfmwxj+7Lh1Z1deEocDRCpZpV7ml4GpWQ==} - engines: {node: '>=20.15.0'} - - '@workos/authkit-session@0.3.4': - resolution: {integrity: sha512-lbLP1y8MHWL1Op9athZ3SrzKLcL0+xBVpADCMQLI39mPgSQj+/lopVdOx0Cku96hYnJBOJTLVTK3Zox4FbZl4A==} - engines: {node: '>=20.0.0'} - acorn@8.15.0: resolution: {integrity: sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==} engines: {node: '>=0.4.0'} @@ -2365,9 +2360,6 @@ packages: iron-webcrypto@1.2.1: resolution: {integrity: sha512-feOM6FaSr6rEABp/eDfVseKyTMDt+KGpeB35SkVn9Tyn0CqvVsY3EwI0v5i8nMHyJnzCIQf7nsy3p41TPkJZhg==} - iron-webcrypto@2.0.0: - resolution: {integrity: sha512-rtffZKDUHciZElM8mjFCufBC7nVhCxHYyWHESqs89OioEDz4parOofd8/uhrejh/INhQFfYQfByS22LlezR9sQ==} - is-binary-path@2.1.0: resolution: {integrity: sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw==} engines: {node: '>=8'} @@ -2842,10 +2834,6 @@ packages: ufo@1.6.1: resolution: {integrity: sha512-9a4/uxlTWJ4+a5i0ooc1rU7C7YOw3wT+UGqdeNNHWnOF9qcMBgLRS+4IYUqbczewFx4mLEig6gawh7X6mFlEkA==} - uint8array-extras@1.5.0: - resolution: {integrity: sha512-rvKSBiC5zqCCiDZ9kAOszZcDvdAHwwIKJG33Ykj43OKcWsnmcBRL09YTU4nOeHZ8Y2a7l1MgTd08SBe9A8Qj6A==} - engines: {node: '>=18'} - uncrypto@0.1.3: resolution: {integrity: sha512-Ql87qFHB3s/De2ClA9e0gsnS6zXG27SkTiSJwjCc9MebbfapQfuPzumMIUMi38ezPZVNFcHI9sUIepeQfw8J8Q==} @@ -4759,17 +4747,6 @@ snapshots: '@vitest/pretty-format': 4.0.15 tinyrainbow: 3.0.3 - '@workos-inc/node@8.0.0': - dependencies: - iron-webcrypto: 2.0.0 - jose: 6.1.3 - - '@workos/authkit-session@0.3.4': - dependencies: - '@workos-inc/node': 8.0.0 - iron-webcrypto: 2.0.0 - jose: 6.1.3 - acorn@8.15.0: {} ansi-regex@5.0.1: {} @@ -5125,10 +5102,6 @@ snapshots: iron-webcrypto@1.2.1: {} - iron-webcrypto@2.0.0: - dependencies: - uint8array-extras: 1.5.0 - is-binary-path@2.1.0: dependencies: binary-extensions: 2.3.0 @@ -5660,8 +5633,6 @@ snapshots: ufo@1.6.1: {} - uint8array-extras@1.5.0: {} - uncrypto@0.1.3: {} undici-types@7.18.2: {} diff --git a/src/server/auth-helpers.spec.ts b/src/server/auth-helpers.spec.ts index e7842f2..7a627c3 100644 --- a/src/server/auth-helpers.spec.ts +++ b/src/server/auth-helpers.spec.ts @@ -23,13 +23,7 @@ vi.mock('./authkit-loader', () => ({ getAuthkit: vi.fn(() => Promise.resolve(mockAuthkit)), })); -import { - getRawAuthFromContext, - isAuthConfigured, - getSessionWithRefreshToken, - refreshSession, - decodeState, -} from './auth-helpers'; +import { getRawAuthFromContext, isAuthConfigured, getSessionWithRefreshToken, refreshSession } from './auth-helpers'; describe('Auth Helpers', () => { beforeEach(() => { @@ -212,55 +206,4 @@ describe('Auth Helpers', () => { expect(mockAuthkit.saveSession).not.toHaveBeenCalled(); }); }); - - describe('decodeState', () => { - it('returns default when state is null', () => { - expect(decodeState(null)).toEqual({ returnPathname: '/' }); - }); - - it('returns default when state is "null" string', () => { - expect(decodeState('null')).toEqual({ returnPathname: '/' }); - }); - - it('decodes valid base64 state', () => { - const internal = btoa(JSON.stringify({ returnPathname: '/dashboard' })); - - const result = decodeState(internal); - - expect(result).toEqual({ returnPathname: '/dashboard' }); - }); - - it('extracts custom state after dot separator', () => { - const internal = btoa(JSON.stringify({ returnPathname: '/profile' })); - const state = `${internal}.custom-user-state`; - - const result = decodeState(state); - - expect(result).toEqual({ - returnPathname: '/profile', - customState: 'custom-user-state', - }); - }); - - it('handles multiple dots in custom state', () => { - const internal = btoa(JSON.stringify({ returnPathname: '/' })); - const state = `${internal}.part1.part2.part3`; - - const result = decodeState(state); - - expect(result).toEqual({ - returnPathname: '/', - customState: 'part1.part2.part3', - }); - }); - - it('returns root with custom state when decode fails', () => { - const result = decodeState('invalid-base64'); - - expect(result).toEqual({ - returnPathname: '/', - customState: 'invalid-base64', - }); - }); - }); }); diff --git a/src/server/auth-helpers.ts b/src/server/auth-helpers.ts index 3ecd4dc..6163cf8 100644 --- a/src/server/auth-helpers.ts +++ b/src/server/auth-helpers.ts @@ -90,26 +90,3 @@ export async function refreshSession(organizationId?: string) { return result; } - -/** - * Decodes a state parameter from OAuth callback. - * Format: base64EncodedInternal.customUserState (dot-separated) - */ -export function decodeState(state: string | null): { returnPathname: string; customState?: string } { - if (!state || state === 'null') { - return { returnPathname: '/' }; - } - - const [internal, ...rest] = state.split('.'); - const customState = rest.length > 0 ? rest.join('.') : undefined; - - try { - const decoded = JSON.parse(atob(internal)); - return { - returnPathname: decoded.returnPathname || '/', - customState, - }; - } catch { - return { returnPathname: '/', customState: customState ?? state }; - } -} diff --git a/src/server/cookie-utils.spec.ts b/src/server/cookie-utils.spec.ts new file mode 100644 index 0000000..71c0a95 --- /dev/null +++ b/src/server/cookie-utils.spec.ts @@ -0,0 +1,24 @@ +import { describe, it, expect } from 'vitest'; +import { parseCookies } from './cookie-utils'; + +describe('parseCookies', () => { + it('parses a single cookie', () => { + expect(parseCookies('a=1')).toEqual({ a: '1' }); + }); + + it('parses multiple cookies', () => { + expect(parseCookies('a=1; b=2; c=3')).toEqual({ a: '1', b: '2', c: '3' }); + }); + + it('preserves = characters within cookie values', () => { + expect(parseCookies('token=base64==padding==')).toEqual({ token: 'base64==padding==' }); + }); + + it('returns an empty entry for an empty header', () => { + expect(parseCookies('')).toEqual({ '': '' }); + }); + + it('trims whitespace around each pair', () => { + expect(parseCookies('a=1 ; b=2')).toEqual({ a: '1', b: '2' }); + }); +}); diff --git a/src/server/cookie-utils.ts b/src/server/cookie-utils.ts new file mode 100644 index 0000000..1d2109d --- /dev/null +++ b/src/server/cookie-utils.ts @@ -0,0 +1,8 @@ +export function parseCookies(cookieHeader: string): Record { + return Object.fromEntries( + cookieHeader.split(';').map((cookie) => { + const [key, ...valueParts] = cookie.trim().split('='); + return [key, valueParts.join('=')]; + }), + ); +} diff --git a/src/server/index.ts b/src/server/index.ts index 8792256..e6c09bb 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -27,3 +27,5 @@ export { getOrganizationAction, type OrganizationInfo, } from './actions.js'; + +export { OAuthStateMismatchError, PKCECookieMissingError } from '@workos/authkit-session'; diff --git a/src/server/server-functions.spec.ts b/src/server/server-functions.spec.ts index f2d484e..ba3f244 100644 --- a/src/server/server-functions.spec.ts +++ b/src/server/server-functions.spec.ts @@ -5,6 +5,16 @@ vi.mock('@tanstack/react-start/server', () => ({ getRequest: vi.fn(() => new Request('http://test.local')), })); +// Upstream's new shape: `{ url, response?, headers? }`. Matches the library's +// storage-owned cookie flow — the adapter is no longer in the business of +// serializing the PKCE verifier cookie itself. +const authorizationResult = (url: string) => ({ + url, + headers: { + 'Set-Cookie': 'wos-auth-verifier=sealed-blob-abc; Path=/; HttpOnly; SameSite=Lax; Max-Age=600; Secure', + }, +}); + const mockAuthkit = { withAuth: vi.fn(), getWorkOS: vi.fn(() => ({ @@ -17,11 +27,14 @@ const mockAuthkit = { headers: { 'Set-Cookie': 'wos-session=; Path=/; Max-Age=0; HttpOnly; Secure; SameSite=Lax' }, }), handleCallback: vi.fn(), - getAuthorizationUrl: vi.fn().mockResolvedValue('https://auth.workos.com/authorize'), - getSignInUrl: vi.fn().mockResolvedValue('https://auth.workos.com/signin'), - getSignUpUrl: vi.fn().mockResolvedValue('https://auth.workos.com/signup'), + createAuthorization: vi.fn().mockResolvedValue(authorizationResult('https://auth.workos.com/authorize')), + createSignIn: vi.fn().mockResolvedValue(authorizationResult('https://auth.workos.com/signin')), + createSignUp: vi.fn().mockResolvedValue(authorizationResult('https://auth.workos.com/signup')), }; +const mockSetPendingHeader = vi.fn(); +let mockContextAvailable = true; + vi.mock('./authkit-loader', () => ({ getAuthkit: vi.fn(() => Promise.resolve(mockAuthkit)), getConfig: vi.fn((key: string) => { @@ -75,10 +88,16 @@ vi.mock('@tanstack/react-start', () => ({ return fn; }, }), - getGlobalStartContext: () => ({ - auth: mockAuthContext, - request: new Request('http://test.local'), - }), + getGlobalStartContext: () => { + if (!mockContextAvailable) { + throw new Error('TanStack context not available'); + } + return { + auth: mockAuthContext, + request: new Request('http://test.local'), + __setPendingHeader: mockSetPendingHeader, + }; + }, })); // Now import everything after mocks are set up @@ -90,6 +109,8 @@ import * as serverFunctions from './server-functions'; describe('Server Functions', () => { beforeEach(() => { vi.clearAllMocks(); + mockContextAvailable = true; + mockAuthContext = () => ({ user: null }); }); describe('getAuth', () => { @@ -199,7 +220,7 @@ describe('Server Functions', () => { describe('getAuthorizationUrl', () => { it('generates authorization URL with all options', async () => { const authUrl = 'https://auth.workos.com/authorize?client_id=test'; - mockAuthkit.getAuthorizationUrl.mockResolvedValue(authUrl); + mockAuthkit.createAuthorization.mockResolvedValue(authorizationResult(authUrl)); const result = await serverFunctions.getAuthorizationUrl({ data: { @@ -210,7 +231,7 @@ describe('Server Functions', () => { }); expect(result).toBe(authUrl); - expect(mockAuthkit.getAuthorizationUrl).toHaveBeenCalledWith({ + expect(mockAuthkit.createAuthorization).toHaveBeenCalledWith(undefined, { screenHint: 'sign-up', returnPathname: '/dashboard', redirectUri: 'http://custom.local/callback', @@ -219,7 +240,7 @@ describe('Server Functions', () => { it('works with minimal options', async () => { const authUrl = 'https://auth.workos.com/authorize'; - mockAuthkit.getAuthorizationUrl.mockResolvedValue(authUrl); + mockAuthkit.createAuthorization.mockResolvedValue(authorizationResult(authUrl)); const result = await serverFunctions.getAuthorizationUrl({ data: {} }); @@ -228,46 +249,46 @@ describe('Server Functions', () => { it('handles undefined data', async () => { const authUrl = 'https://auth.workos.com/authorize'; - mockAuthkit.getAuthorizationUrl.mockResolvedValue(authUrl); + mockAuthkit.createAuthorization.mockResolvedValue(authorizationResult(authUrl)); const result = await serverFunctions.getAuthorizationUrl({ data: undefined }); expect(result).toBe(authUrl); - expect(mockAuthkit.getAuthorizationUrl).toHaveBeenCalledWith({}); + expect(mockAuthkit.createAuthorization).toHaveBeenCalledWith(undefined, {}); }); }); describe('getSignInUrl', () => { it('generates sign-in URL with return path string', async () => { const signInUrl = 'https://auth.workos.com/sign-in'; - mockAuthkit.getSignInUrl.mockResolvedValue(signInUrl); + mockAuthkit.createSignIn.mockResolvedValue(authorizationResult(signInUrl)); const result = await serverFunctions.getSignInUrl({ data: '/profile' }); expect(result).toBe(signInUrl); - expect(mockAuthkit.getSignInUrl).toHaveBeenCalledWith({ returnPathname: '/profile' }); + expect(mockAuthkit.createSignIn).toHaveBeenCalledWith(undefined, { returnPathname: '/profile' }); }); it('works without options', async () => { const signInUrl = 'https://auth.workos.com/sign-in'; - mockAuthkit.getSignInUrl.mockResolvedValue(signInUrl); + mockAuthkit.createSignIn.mockResolvedValue(authorizationResult(signInUrl)); const result = await serverFunctions.getSignInUrl({ data: undefined }); expect(result).toBe(signInUrl); - expect(mockAuthkit.getSignInUrl).toHaveBeenCalledWith(undefined); + expect(mockAuthkit.createSignIn).toHaveBeenCalledWith(undefined, {}); }); it('passes state option through', async () => { const signInUrl = 'https://auth.workos.com/sign-in'; - mockAuthkit.getSignInUrl.mockResolvedValue(signInUrl); + mockAuthkit.createSignIn.mockResolvedValue(authorizationResult(signInUrl)); const result = await serverFunctions.getSignInUrl({ data: { returnPathname: '/dashboard', state: 'custom-state' }, }); expect(result).toBe(signInUrl); - expect(mockAuthkit.getSignInUrl).toHaveBeenCalledWith({ + expect(mockAuthkit.createSignIn).toHaveBeenCalledWith(undefined, { returnPathname: '/dashboard', state: 'custom-state', }); @@ -275,7 +296,7 @@ describe('Server Functions', () => { it('passes all options through', async () => { const signInUrl = 'https://auth.workos.com/sign-in'; - mockAuthkit.getSignInUrl.mockResolvedValue(signInUrl); + mockAuthkit.createSignIn.mockResolvedValue(authorizationResult(signInUrl)); const result = await serverFunctions.getSignInUrl({ data: { @@ -287,7 +308,7 @@ describe('Server Functions', () => { }); expect(result).toBe(signInUrl); - expect(mockAuthkit.getSignInUrl).toHaveBeenCalledWith({ + expect(mockAuthkit.createSignIn).toHaveBeenCalledWith(undefined, { returnPathname: '/dashboard', state: 'my-state', organizationId: 'org_123', @@ -299,34 +320,34 @@ describe('Server Functions', () => { describe('getSignUpUrl', () => { it('generates sign-up URL with return path string', async () => { const signUpUrl = 'https://auth.workos.com/sign-up'; - mockAuthkit.getSignUpUrl.mockResolvedValue(signUpUrl); + mockAuthkit.createSignUp.mockResolvedValue(authorizationResult(signUpUrl)); const result = await serverFunctions.getSignUpUrl({ data: '/welcome' }); expect(result).toBe(signUpUrl); - expect(mockAuthkit.getSignUpUrl).toHaveBeenCalledWith({ returnPathname: '/welcome' }); + expect(mockAuthkit.createSignUp).toHaveBeenCalledWith(undefined, { returnPathname: '/welcome' }); }); it('accepts object with returnPathname', async () => { const signUpUrl = 'https://auth.workos.com/sign-up'; - mockAuthkit.getSignUpUrl.mockResolvedValue(signUpUrl); + mockAuthkit.createSignUp.mockResolvedValue(authorizationResult(signUpUrl)); const result = await serverFunctions.getSignUpUrl({ data: { returnPathname: '/onboarding' } }); expect(result).toBe(signUpUrl); - expect(mockAuthkit.getSignUpUrl).toHaveBeenCalledWith({ returnPathname: '/onboarding' }); + expect(mockAuthkit.createSignUp).toHaveBeenCalledWith(undefined, { returnPathname: '/onboarding' }); }); it('passes state option through', async () => { const signUpUrl = 'https://auth.workos.com/sign-up'; - mockAuthkit.getSignUpUrl.mockResolvedValue(signUpUrl); + mockAuthkit.createSignUp.mockResolvedValue(authorizationResult(signUpUrl)); const result = await serverFunctions.getSignUpUrl({ data: { returnPathname: '/welcome', state: 'signup-flow' }, }); expect(result).toBe(signUpUrl); - expect(mockAuthkit.getSignUpUrl).toHaveBeenCalledWith({ + expect(mockAuthkit.createSignUp).toHaveBeenCalledWith(undefined, { returnPathname: '/welcome', state: 'signup-flow', }); @@ -334,7 +355,7 @@ describe('Server Functions', () => { it('passes all options through', async () => { const signUpUrl = 'https://auth.workos.com/sign-up'; - mockAuthkit.getSignUpUrl.mockResolvedValue(signUpUrl); + mockAuthkit.createSignUp.mockResolvedValue(authorizationResult(signUpUrl)); const result = await serverFunctions.getSignUpUrl({ data: { @@ -346,7 +367,7 @@ describe('Server Functions', () => { }); expect(result).toBe(signUpUrl); - expect(mockAuthkit.getSignUpUrl).toHaveBeenCalledWith({ + expect(mockAuthkit.createSignUp).toHaveBeenCalledWith(undefined, { returnPathname: '/onboarding', state: 'invite-123', organizationId: 'org_456', @@ -466,4 +487,56 @@ describe('Server Functions', () => { expect(typeof serverFunctions.getSignUpUrl).toBe('function'); }); }); + + describe('PKCE cookie wiring', () => { + const cases = [ + { + name: 'getAuthorizationUrl', + call: () => serverFunctions.getAuthorizationUrl({ data: {} }), + mockFn: () => mockAuthkit.createAuthorization, + url: 'https://auth.workos.com/authorize?client_id=test', + }, + { + name: 'getSignInUrl', + call: () => serverFunctions.getSignInUrl({ data: undefined }), + mockFn: () => mockAuthkit.createSignIn, + url: 'https://auth.workos.com/sign-in', + }, + { + name: 'getSignUpUrl', + call: () => serverFunctions.getSignUpUrl({ data: undefined }), + mockFn: () => mockAuthkit.createSignUp, + url: 'https://auth.workos.com/sign-up', + }, + ]; + + cases.forEach(({ name, call, mockFn, url }) => { + describe(name, () => { + it('writes Set-Cookie with wos-auth-verifier exactly once', async () => { + mockFn().mockResolvedValue(authorizationResult(url)); + + await call(); + + expect(mockSetPendingHeader).toHaveBeenCalledTimes(1); + expect(mockSetPendingHeader).toHaveBeenCalledWith('Set-Cookie', expect.stringMatching(/^wos-auth-verifier=/)); + }); + + it('returns only the URL (no sealedState leak)', async () => { + mockFn().mockResolvedValue(authorizationResult(url)); + + const result = await call(); + + expect(result).toBe(url); + expect(typeof result).toBe('string'); + }); + + it('throws actionable error when middleware context is unavailable', async () => { + mockContextAvailable = false; + mockFn().mockResolvedValue(authorizationResult(url)); + + await expect(call()).rejects.toThrow(/authkitMiddleware is registered/); + }); + }); + }); + }); }); diff --git a/src/server/server-functions.ts b/src/server/server-functions.ts index 10d63c4..b15769e 100644 --- a/src/server/server-functions.ts +++ b/src/server/server-functions.ts @@ -3,9 +3,57 @@ import { createServerFn } from '@tanstack/react-start'; import type { Impersonator, User } from '../types.js'; import { getRawAuthFromContext, refreshSession, getRedirectUriFromContext } from './auth-helpers.js'; import { getAuthkit } from './authkit-loader.js'; +import { getAuthKitContextOrNull } from './context.js'; // Type-only import - safe for bundling -import type { GetAuthorizationUrlOptions as GetAuthURLOptions } from '@workos/authkit-session'; +import type { GetAuthorizationUrlOptions as GetAuthURLOptions, HeadersBag } from '@workos/authkit-session'; + +type AuthorizationResult = { + url: string; + response?: Response; + headers?: HeadersBag; +}; + +/** + * Forward every `Set-Cookie` (and any other header) emitted by the upstream + * authorization-URL call through middleware's pending-header channel so the + * PKCE verifier cookie lands on the outgoing response. Each `Set-Cookie` entry + * is appended as its own header — never comma-joined — so multi-cookie + * emissions survive as distinct HTTP headers. + */ +function forwardAuthorizationCookies(result: AuthorizationResult): string { + const ctx = getAuthKitContextOrNull(); + if (!ctx?.__setPendingHeader) { + throw new Error( + '[authkit-tanstack-react-start] PKCE cookie could not be set: middleware context unavailable. Ensure authkitMiddleware is registered in your request middleware stack.', + ); + } + + // Prefer the `headers` bag when present — it's the library's primary channel. + if (result.headers) { + for (const [key, value] of Object.entries(result.headers)) { + if (Array.isArray(value)) { + for (const v of value) ctx.__setPendingHeader(key, v); + } else if (typeof value === 'string') { + ctx.__setPendingHeader(key, value); + } + } + } else if (result.response) { + // Fallback: storage mutated the Response directly (context-unavailable path). + for (const value of result.response.headers.getSetCookie()) { + ctx.__setPendingHeader('Set-Cookie', value); + } + } + + return result.url; +} + +/** Inject middleware-configured redirectUri only when caller did not provide one. */ +function applyContextRedirectUri(options: T): T { + const contextRedirectUri = getRedirectUriFromContext(); + if (!contextRedirectUri || options?.redirectUri) return options; + return { ...options, redirectUri: contextRedirectUri } as T; +} // Type exports - re-export shared types from authkit-session export type { GetAuthURLOptions }; @@ -159,17 +207,7 @@ export const getAuthorizationUrl = createServerFn({ method: 'GET' }) .inputValidator((options?: GetAuthURLOptions) => options) .handler(async ({ data: options = {} }) => { const authkit = await getAuthkit(); - const contextRedirectUri = getRedirectUriFromContext(); - - // Only inject context redirectUri if it exists and user didn't provide one - if (contextRedirectUri && !options.redirectUri) { - return authkit.getAuthorizationUrl({ - ...options, - redirectUri: contextRedirectUri, - }); - } - - return authkit.getAuthorizationUrl(options); + return forwardAuthorizationCookies(await authkit.createAuthorization(undefined, applyContextRedirectUri(options))); }); /** Options for getSignInUrl/getSignUpUrl - all GetAuthURLOptions except screenHint */ @@ -195,18 +233,8 @@ export const getSignInUrl = createServerFn({ method: 'GET' }) .inputValidator((data?: string | SignInUrlOptions) => data) .handler(async ({ data }) => { const options = typeof data === 'string' ? { returnPathname: data } : data; - const contextRedirectUri = getRedirectUriFromContext(); const authkit = await getAuthkit(); - - // Only inject context redirectUri if it exists and user didn't provide one - if (contextRedirectUri && !options?.redirectUri) { - return authkit.getSignInUrl({ - ...options, - redirectUri: contextRedirectUri, - }); - } - - return authkit.getSignInUrl(options); + return forwardAuthorizationCookies(await authkit.createSignIn(undefined, applyContextRedirectUri(options ?? {}))); }); /** @@ -229,18 +257,8 @@ export const getSignUpUrl = createServerFn({ method: 'GET' }) .inputValidator((data?: string | SignInUrlOptions) => data) .handler(async ({ data }) => { const options = typeof data === 'string' ? { returnPathname: data } : data; - const contextRedirectUri = getRedirectUriFromContext(); const authkit = await getAuthkit(); - - // Only inject context redirectUri if it exists and user didn't provide one - if (contextRedirectUri && !options?.redirectUri) { - return authkit.getSignUpUrl({ - ...options, - redirectUri: contextRedirectUri, - }); - } - - return authkit.getSignUpUrl(options); + return forwardAuthorizationCookies(await authkit.createSignUp(undefined, applyContextRedirectUri(options ?? {}))); }); /** diff --git a/src/server/server.spec.ts b/src/server/server.spec.ts index 40a5f0f..b28cf9f 100644 --- a/src/server/server.spec.ts +++ b/src/server/server.spec.ts @@ -1,18 +1,19 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'; -// Setup mocks before imports const mockHandleCallback = vi.fn(); const mockWithAuth = vi.fn(); -const mockGetSignInUrl = vi.fn(); +const mockCreateSignIn = vi.fn(); +const mockClearPendingVerifier = vi.fn(async () => ({ + headers: { + 'Set-Cookie': + 'wos-auth-verifier=; Path=/; HttpOnly; SameSite=Lax; Max-Age=0; Expires=Thu, 01 Jan 1970 00:00:00 GMT', + }, +})); + +let mockGetAuthkitImpl: () => Promise; vi.mock('./authkit-loader', () => ({ - getAuthkit: vi.fn(() => - Promise.resolve({ - withAuth: mockWithAuth, - handleCallback: mockHandleCallback, - getSignInUrl: mockGetSignInUrl, - }), - ), + getAuthkit: vi.fn(() => mockGetAuthkitImpl()), })); vi.mock('@tanstack/react-router', () => ({ @@ -23,258 +24,255 @@ vi.mock('@tanstack/react-router', () => ({ import { handleCallbackRoute } from './server'; +const baseAuthResponse = { + accessToken: 'access_token', + refreshToken: 'refresh_token', + user: { id: 'user_123', email: 'test@example.com' }, +}; + +const successResult = (overrides: Record = {}) => ({ + response: { headers: new Map() }, + returnPathname: '/', + state: undefined, + authResponse: baseAuthResponse, + ...overrides, +}); + describe('handleCallbackRoute', () => { beforeEach(() => { vi.clearAllMocks(); + mockGetAuthkitImpl = () => + Promise.resolve({ + withAuth: mockWithAuth, + handleCallback: mockHandleCallback, + createSignIn: mockCreateSignIn, + clearPendingVerifier: mockClearPendingVerifier, + }); }); - it('rejects missing code', async () => { - const request = new Request('http://example.com/callback'); - const handler = handleCallbackRoute(); - const response = await handler({ request }); - - expect(response.status).toBe(400); - const body = await response.json(); - expect(body.error.message).toBe('Missing authorization code'); - }); + describe('missing code', () => { + it('returns 400 with generic body and delete-cookie header', async () => { + const request = new Request('http://example.com/callback'); + const response = await handleCallbackRoute()({ request }); - it('processes valid callback', async () => { - const request = new Request('http://example.com/callback?code=auth_123'); - mockHandleCallback.mockResolvedValue({ - response: { headers: new Map() }, - authResponse: { - accessToken: 'access_token', - refreshToken: 'refresh_token', - user: { id: 'user_123', email: 'test@example.com' }, - }, + expect(response.status).toBe(400); + const body = await response.json(); + expect(body.error.message).toBe('Authentication failed'); + expect(body.error).not.toHaveProperty('details'); + expect(response.headers.getSetCookie()).toEqual([expect.stringContaining('wos-auth-verifier=')]); }); - const handler = handleCallbackRoute(); - const response = await handler({ request }); + it('calls onError hook when provided', async () => { + const request = new Request('http://example.com/callback'); + const onError = vi.fn().mockReturnValue(new Response('Custom error', { status: 403 })); - expect(response.status).toBe(307); - expect(response.headers.get('Location')).toBe('http://example.com/'); - }); + const response = await handleCallbackRoute({ onError })({ request }); - it('decodes state for return path', async () => { - const state = btoa(JSON.stringify({ returnPathname: '/dashboard' })); - const request = new Request(`http://example.com/callback?code=auth_123&state=${state}`); - mockHandleCallback.mockResolvedValue({ - response: { headers: new Map() }, - authResponse: { - accessToken: 'access_token', - refreshToken: 'refresh_token', - user: { id: 'user_123', email: 'test@example.com' }, - }, + expect(onError).toHaveBeenCalledWith({ error: expect.any(Error), request }); + expect(response.status).toBe(403); + expect(await response.text()).toBe('Custom error'); + expect(response.headers.getSetCookie().some((c) => c.startsWith('wos-auth-verifier='))).toBe(true); }); + }); - const handler = handleCallbackRoute(); - const response = await handler({ request }); + describe('success path', () => { + it('returns 307 with Location from result.returnPathname', async () => { + const request = new Request('http://example.com/callback?code=auth_123'); + mockHandleCallback.mockResolvedValue(successResult({ returnPathname: '/dashboard' })); - expect(response.headers.get('Location')).toBe('http://example.com/dashboard'); - }); + const response = await handleCallbackRoute()({ request }); - it('handles state with query params in return path', async () => { - const state = btoa(JSON.stringify({ returnPathname: '/search?q=test&page=2' })); - const request = new Request(`http://example.com/callback?code=auth_123&state=${state}`); - mockHandleCallback.mockResolvedValue({ - response: { headers: new Map() }, - authResponse: { - accessToken: 'access_token', - refreshToken: 'refresh_token', - user: { id: 'user_123', email: 'test@example.com' }, - }, + expect(response.status).toBe(307); + expect(response.headers.get('Location')).toBe('http://example.com/dashboard'); }); - const handler = handleCallbackRoute(); - const response = await handler({ request }); + it('honors returnPathname with query params', async () => { + const request = new Request('http://example.com/callback?code=auth_123'); + mockHandleCallback.mockResolvedValue(successResult({ returnPathname: '/search?q=test&page=2' })); - expect(response.headers.get('Location')).toBe('http://example.com/search?q=test&page=2'); - }); + const response = await handleCallbackRoute()({ request }); - it('handles invalid state gracefully', async () => { - const request = new Request('http://example.com/callback?code=auth_123&state=invalid_base64'); - mockHandleCallback.mockResolvedValue({ - response: { headers: new Map() }, - authResponse: { - accessToken: 'access_token', - refreshToken: 'refresh_token', - user: { id: 'user_123', email: 'test@example.com' }, - }, + expect(response.headers.get('Location')).toBe('http://example.com/search?q=test&page=2'); }); - const handler = handleCallbackRoute(); - const response = await handler({ request }); + it('defaults to / when result.returnPathname is empty', async () => { + const request = new Request('http://example.com/callback?code=auth_123'); + mockHandleCallback.mockResolvedValue(successResult({ returnPathname: undefined })); - // Should default to root path - expect(response.headers.get('Location')).toBe('http://example.com/'); - }); + const response = await handleCallbackRoute()({ request }); - it('handles null state', async () => { - const request = new Request('http://example.com/callback?code=auth_123&state=null'); - mockHandleCallback.mockResolvedValue({ - response: { headers: new Map() }, - authResponse: { - accessToken: 'access_token', - refreshToken: 'refresh_token', - user: { id: 'user_123', email: 'test@example.com' }, - }, + expect(response.headers.get('Location')).toBe('http://example.com/'); }); - const handler = handleCallbackRoute(); - const response = await handler({ request }); + it('prefers options.returnPathname when provided', async () => { + const request = new Request('http://example.com/callback?code=auth_123'); + mockHandleCallback.mockResolvedValue(successResult({ returnPathname: '/dashboard' })); - expect(response.headers.get('Location')).toBe('http://example.com/'); - }); + const response = await handleCallbackRoute({ returnPathname: '/custom' })({ request }); - it('extracts session headers from response', async () => { - const request = new Request('http://example.com/callback?code=auth_123'); - mockHandleCallback.mockResolvedValue({ - headers: { - 'Set-Cookie': 'session=abc123', - 'X-Custom': 'value', - }, - authResponse: { - accessToken: 'access_token', - refreshToken: 'refresh_token', - user: { id: 'user_123', email: 'test@example.com' }, - }, + expect(response.headers.get('Location')).toBe('http://example.com/custom'); }); - const handler = handleCallbackRoute(); - const response = await handler({ request }); + it('passes code and state to authkit.handleCallback without a cookieValue arg', async () => { + const request = new Request('http://example.com/callback?code=auth_123&state=s', { + headers: { cookie: 'wos-auth-verifier=sealed-abc-123' }, + }); + mockHandleCallback.mockResolvedValue(successResult()); - expect(response.headers.get('Set-Cookie')).toBe('session=abc123'); - expect(response.headers.get('X-Custom')).toBe('value'); - }); - - it('handles callback errors', async () => { - const request = new Request('http://example.com/callback?code=invalid'); - mockHandleCallback.mockRejectedValue(new Error('Invalid code')); + await handleCallbackRoute()({ request }); - // Suppress expected error log - const consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); + expect(mockHandleCallback).toHaveBeenCalledWith(request, expect.any(Response), { code: 'auth_123', state: 's' }); + const passedOptions = mockHandleCallback.mock.calls[0]![2]; + expect(passedOptions).not.toHaveProperty('cookieValue'); + }); - const handler = handleCallbackRoute(); - const response = await handler({ request }); + it('passes state as undefined when absent from the URL', async () => { + const request = new Request('http://example.com/callback?code=auth_123'); + mockHandleCallback.mockResolvedValue(successResult()); - expect(response.status).toBe(500); - const body = await response.json(); - expect(body.error.message).toBe('Authentication failed'); - expect(body.error.description).toContain("Couldn't sign in"); + await handleCallbackRoute()({ request }); - consoleErrorSpy.mockRestore(); - }); + expect(mockHandleCallback).toHaveBeenCalledWith(request, expect.any(Response), { + code: 'auth_123', + state: undefined, + }); + }); - it('calls onSuccess hook with auth data', async () => { - const request = new Request('http://example.com/callback?code=auth_123'); - const mockAuthResponse = { - accessToken: 'access_token_123', - refreshToken: 'refresh_token_123', - user: { id: 'user_123', email: 'test@example.com', firstName: 'Test', lastName: 'User' }, - impersonator: { email: 'admin@example.com', reason: 'Support' }, - oauthTokens: { provider: 'google', accessToken: 'google_token' }, - authenticationMethod: 'GoogleOAuth', - organizationId: 'org_123', - }; - - mockHandleCallback.mockResolvedValue({ - response: { headers: new Map() }, - authResponse: mockAuthResponse, + it('appends both the session cookie and the PKCE delete cookie from the library', async () => { + const request = new Request('http://example.com/callback?code=auth_123'); + mockHandleCallback.mockResolvedValue({ + headers: { + 'Set-Cookie': ['wos-session=abc123', 'wos-auth-verifier=; Path=/; Max-Age=0; HttpOnly; SameSite=Lax'], + }, + returnPathname: '/', + state: undefined, + authResponse: baseAuthResponse, + }); + + const response = await handleCallbackRoute()({ request }); + + const setCookies = response.headers.getSetCookie(); + expect(setCookies.some((c) => c.startsWith('wos-session=abc123'))).toBe(true); + expect(setCookies.some((c) => c.startsWith('wos-auth-verifier='))).toBe(true); + expect(setCookies).toHaveLength(2); }); - const onSuccess = vi.fn(); - const handler = handleCallbackRoute({ onSuccess }); - await handler({ request }); - - expect(onSuccess).toHaveBeenCalledOnce(); - expect(onSuccess).toHaveBeenCalledWith({ - accessToken: 'access_token_123', - refreshToken: 'refresh_token_123', - user: mockAuthResponse.user, - impersonator: mockAuthResponse.impersonator, - oauthTokens: mockAuthResponse.oauthTokens, - authenticationMethod: 'GoogleOAuth', - organizationId: 'org_123', - state: undefined, + it('extracts session headers from plain-object shape', async () => { + const request = new Request('http://example.com/callback?code=auth_123'); + mockHandleCallback.mockResolvedValue({ + headers: { + 'Set-Cookie': 'session=abc123', + 'X-Custom': 'value', + }, + returnPathname: '/', + state: undefined, + authResponse: baseAuthResponse, + }); + + const response = await handleCallbackRoute()({ request }); + + expect(response.headers.get('X-Custom')).toBe('value'); + expect(response.headers.getSetCookie().some((c) => c.startsWith('session=abc123'))).toBe(true); }); - }); - it('calls onSuccess with custom state', async () => { - const customState = 'custom.user.state'; - const request = new Request(`http://example.com/callback?code=auth_123&state=${customState}`); - mockHandleCallback.mockResolvedValue({ - response: { headers: new Map() }, - authResponse: { + it('calls onSuccess with result.state (unsealed customState) and auth data', async () => { + const request = new Request('http://example.com/callback?code=auth_123&state=encoded'); + mockHandleCallback.mockResolvedValue( + successResult({ + state: 'user.custom.state', + authResponse: { + ...baseAuthResponse, + impersonator: { email: 'admin@example.com', reason: 'Support' }, + oauthTokens: { provider: 'google', accessToken: 'google_token' }, + authenticationMethod: 'GoogleOAuth', + organizationId: 'org_123', + }, + }), + ); + + const onSuccess = vi.fn(); + await handleCallbackRoute({ onSuccess })({ request }); + + expect(onSuccess).toHaveBeenCalledWith({ accessToken: 'access_token', refreshToken: 'refresh_token', - user: { id: 'user_123', email: 'test@example.com' }, - }, + user: baseAuthResponse.user, + impersonator: { email: 'admin@example.com', reason: 'Support' }, + oauthTokens: { provider: 'google', accessToken: 'google_token' }, + authenticationMethod: 'GoogleOAuth', + organizationId: 'org_123', + state: 'user.custom.state', + }); }); - const onSuccess = vi.fn(); - const handler = handleCallbackRoute({ onSuccess }); - await handler({ request }); + it('passes through undefined state when core returns no customState', async () => { + const request = new Request('http://example.com/callback?code=auth_123'); + mockHandleCallback.mockResolvedValue(successResult()); - expect(onSuccess).toHaveBeenCalledWith( - expect.objectContaining({ - state: 'user.state', - }), - ); - }); + const onSuccess = vi.fn(); + await handleCallbackRoute({ onSuccess })({ request }); - it('uses custom returnPathname from options', async () => { - const request = new Request('http://example.com/callback?code=auth_123'); - mockHandleCallback.mockResolvedValue({ - response: { headers: new Map() }, - authResponse: { - accessToken: 'access_token', - refreshToken: 'refresh_token', - user: { id: 'user_123', email: 'test@example.com' }, - }, + expect(onSuccess).toHaveBeenCalledWith(expect.objectContaining({ state: undefined })); }); + }); - const handler = handleCallbackRoute({ returnPathname: '/custom-redirect' }); - const response = await handler({ request }); + describe('error path', () => { + it('returns 500 with generic body on handleCallback failure', async () => { + const request = new Request('http://example.com/callback?code=invalid'); + mockHandleCallback.mockRejectedValue(new Error('Invalid code')); + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); - expect(response.headers.get('Location')).toBe('http://example.com/custom-redirect'); - }); + const response = await handleCallbackRoute()({ request }); - it('calls onError hook on missing code', async () => { - const request = new Request('http://example.com/callback'); - const onError = vi.fn().mockReturnValue(new Response('Custom error', { status: 403 })); + expect(response.status).toBe(500); + const body = await response.json(); + expect(body.error.message).toBe('Authentication failed'); + expect(body.error.description).toContain("Couldn't sign in"); + expect(body.error).not.toHaveProperty('details'); + expect(response.headers.getSetCookie().some((c) => c.startsWith('wos-auth-verifier='))).toBe(true); - const handler = handleCallbackRoute({ onError }); - const response = await handler({ request }); + consoleSpy.mockRestore(); + }); - expect(onError).toHaveBeenCalledOnce(); - expect(onError).toHaveBeenCalledWith({ - error: expect.any(Error), - request, + it('calls onError with the underlying error and appends delete-cookie', async () => { + const request = new Request('http://example.com/callback?code=invalid'); + const err = new Error('Auth failed'); + mockHandleCallback.mockRejectedValue(err); + const onError = vi.fn().mockReturnValue( + new Response('Custom error page', { + status: 418, + headers: { 'X-Custom': 'preserved' }, + }), + ); + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); + + const response = await handleCallbackRoute({ onError })({ request }); + + expect(onError).toHaveBeenCalledWith({ error: err, request }); + expect(response.status).toBe(418); + expect(response.headers.get('X-Custom')).toBe('preserved'); + expect(await response.text()).toBe('Custom error page'); + expect(response.headers.getSetCookie().some((c) => c.startsWith('wos-auth-verifier='))).toBe(true); + + consoleSpy.mockRestore(); }); - expect(response.status).toBe(403); - expect(await response.text()).toBe('Custom error'); - }); - it('calls onError hook on callback failure', async () => { - const request = new Request('http://example.com/callback?code=invalid'); - const error = new Error('Auth failed'); - mockHandleCallback.mockRejectedValue(error); + it('emits static fallback delete-cookies when getAuthkit() rejects', async () => { + const request = new Request('http://example.com/callback?code=auth_123'); + mockGetAuthkitImpl = () => Promise.reject(new Error('Config missing')); + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); - const onError = vi.fn().mockReturnValue(new Response('Custom error page', { status: 500 })); - const consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); + const response = await handleCallbackRoute()({ request }); - const handler = handleCallbackRoute({ onError }); - const response = await handler({ request }); + expect(response.status).toBe(500); + const setCookies = response.headers.getSetCookie(); + expect(setCookies).toHaveLength(2); + expect(setCookies[0]).toContain('SameSite=Lax'); + expect(setCookies[1]).toContain('SameSite=None'); + expect(setCookies[1]).toContain('Secure'); + expect(setCookies.every((c) => c.includes('Max-Age=0'))).toBe(true); - expect(onError).toHaveBeenCalledOnce(); - expect(onError).toHaveBeenCalledWith({ - error, - request, + consoleSpy.mockRestore(); }); - expect(response.status).toBe(500); - expect(await response.text()).toBe('Custom error page'); - - consoleErrorSpy.mockRestore(); }); }); diff --git a/src/server/server.ts b/src/server/server.ts index 9e4c0e0..bcbcdb9 100644 --- a/src/server/server.ts +++ b/src/server/server.ts @@ -1,7 +1,11 @@ import { getAuthkit } from './authkit-loader.js'; -import { decodeState } from './auth-helpers.js'; import type { HandleCallbackOptions } from './types.js'; +const STATIC_FALLBACK_DELETE_HEADERS: readonly string[] = [ + 'wos-auth-verifier=; Path=/; HttpOnly; SameSite=Lax; Max-Age=0; Expires=Thu, 01 Jan 1970 00:00:00 GMT', + 'wos-auth-verifier=; Path=/; HttpOnly; SameSite=None; Secure; Max-Age=0; Expires=Thu, 01 Jan 1970 00:00:00 GMT', +]; + /** * Creates a callback route handler for OAuth authentication. * This should be used in your callback route to complete the authentication flow. @@ -33,9 +37,7 @@ import type { HandleCallbackOptions } from './types.js'; * handlers: { * GET: handleCallbackRoute({ * onSuccess: async ({ user, authenticationMethod }) => { - * // Create user record in your database * await db.users.upsert({ id: user.id, email: user.email }); - * // Track analytics * analytics.track('User Signed In', { method: authenticationMethod }); * }, * }), @@ -50,72 +52,112 @@ export function handleCallbackRoute(options: HandleCallbackOptions = {}) { }; } +/** + * Extract the `Set-Cookie` header(s) produced by `authkit.clearPendingVerifier()` + * so we can attach them to whatever response we emit on an error path. + * + * The library returns a `HeadersBag` whose `Set-Cookie` is either a string or a + * `string[]`. We coerce to an array so callers can append each entry in turn. + */ +async function buildVerifierDeleteHeaders(authkit: Awaited>): Promise { + try { + const { headers } = await authkit.clearPendingVerifier(new Response()); + const setCookie = headers?.['Set-Cookie']; + if (!setCookie) return STATIC_FALLBACK_DELETE_HEADERS; + return Array.isArray(setCookie) ? setCookie : [setCookie]; + } catch (error) { + console.error('[authkit-tanstack-react-start] clearPendingVerifier failed:', error); + return STATIC_FALLBACK_DELETE_HEADERS; + } +} + async function handleCallbackInternal(request: Request, options: HandleCallbackOptions): Promise { + let deleteCookieHeaders: readonly string[] = STATIC_FALLBACK_DELETE_HEADERS; + let authkit: Awaited> | undefined; + + try { + authkit = await getAuthkit(); + deleteCookieHeaders = await buildVerifierDeleteHeaders(authkit); + } catch (setupError) { + console.error('[authkit-tanstack-react-start] Callback setup failed:', setupError); + } + const url = new URL(request.url); const code = url.searchParams.get('code'); const state = url.searchParams.get('state'); if (!code) { - if (options.onError) { - return options.onError({ error: new Error('Missing authorization code'), request }); - } - - return new Response(JSON.stringify({ error: { message: 'Missing authorization code' } }), { - status: 400, - headers: { 'Content-Type': 'application/json' }, - }); + return errorResponse(new Error('Missing authorization code'), request, options, deleteCookieHeaders, 400); + } + if (!authkit) { + return errorResponse(new Error('AuthKit not initialized'), request, options, deleteCookieHeaders, 500); } try { - const { returnPathname: stateReturnPathname, customState } = decodeState(state); - const returnPathname = options.returnPathname ?? stateReturnPathname; - const response = new Response(); - const authkit = await getAuthkit(); - const result = await authkit.handleCallback(request, response, { code, state: state ?? undefined }); - const { authResponse } = result; + const result = await authkit.handleCallback(request, response, { + code, + state: state ?? undefined, + }); if (options.onSuccess) { await options.onSuccess({ - accessToken: authResponse.accessToken, - refreshToken: authResponse.refreshToken, - user: authResponse.user, - impersonator: authResponse.impersonator, - oauthTokens: authResponse.oauthTokens, - authenticationMethod: authResponse.authenticationMethod, - organizationId: authResponse.organizationId, - state: customState, + accessToken: result.authResponse.accessToken, + refreshToken: result.authResponse.refreshToken, + user: result.authResponse.user, + impersonator: result.authResponse.impersonator, + oauthTokens: result.authResponse.oauthTokens, + authenticationMethod: result.authResponse.authenticationMethod, + organizationId: result.authResponse.organizationId, + state: result.state, }); } + const returnPathname = options.returnPathname ?? result.returnPathname ?? '/'; const redirectUrl = buildRedirectUrl(url, returnPathname); - const sessionHeaders = extractSessionHeaders(result); - return new Response(null, { - status: 307, - headers: { - Location: redirectUrl.toString(), - ...sessionHeaders, - }, - }); + const headers = new Headers({ Location: redirectUrl.toString() }); + // `result` now carries BOTH the session Set-Cookie and the verifier-delete + // Set-Cookie as a `string[]`. `appendSessionHeaders` preserves each entry + // via `.append` so they survive as distinct HTTP headers. + appendSessionHeaders(headers, result); + + return new Response(null, { status: 307, headers }); } catch (error) { console.error('OAuth callback failed:', error); + return errorResponse(error, request, options, deleteCookieHeaders, 500); + } +} - if (options.onError) { - return options.onError({ error, request }); - } - - return new Response( - JSON.stringify({ - error: { - message: 'Authentication failed', - description: "Couldn't sign in. Please contact your organization admin if the issue persists.", - details: error instanceof Error ? error.message : String(error), - }, - }), - { status: 500, headers: { 'Content-Type': 'application/json' } }, - ); +async function errorResponse( + error: unknown, + request: Request, + options: HandleCallbackOptions, + deleteCookieHeaders: readonly string[], + defaultStatus: number, +): Promise { + if (options.onError) { + const userResponse = await options.onError({ error, request }); + const headers = new Headers(userResponse.headers); + for (const h of deleteCookieHeaders) headers.append('Set-Cookie', h); + return new Response(userResponse.body, { + status: userResponse.status, + statusText: userResponse.statusText, + headers, + }); } + + const headers = new Headers({ 'Content-Type': 'application/json' }); + for (const h of deleteCookieHeaders) headers.append('Set-Cookie', h); + return new Response( + JSON.stringify({ + error: { + message: 'Authentication failed', + description: "Couldn't sign in. Please contact your organization admin if the issue persists.", + }, + }), + { status: defaultStatus, headers }, + ); } function buildRedirectUrl(originalUrl: URL, returnPathname: string): URL { @@ -134,15 +176,31 @@ function buildRedirectUrl(originalUrl: URL, returnPathname: string): URL { return url; } -function extractSessionHeaders(result: any): Record { - const setCookie = result?.response?.headers?.get?.('Set-Cookie'); - if (setCookie) { - return { 'Set-Cookie': setCookie }; - } - +function appendSessionHeaders(target: Headers, result: any): void { + // Prefer the plain-object `headers` bag when present — it's the library's + // primary channel and carries a `string[]` when multiple cookies are emitted. if (result?.headers && typeof result.headers === 'object') { - return result.headers; + for (const [key, value] of Object.entries(result.headers)) { + if (typeof value === 'string') { + target.append(key, value); + } else if (Array.isArray(value)) { + for (const v of value) { + target.append(key, typeof v === 'string' ? v : String(v)); + } + } + } + return; } - return {}; + // Fallback: the library routed its output through a mutated Response + // (storage's context-unavailable path). + const responseHeaders: Headers | undefined = result?.response?.headers; + if (responseHeaders && typeof responseHeaders.getSetCookie === 'function') { + for (const value of responseHeaders.getSetCookie()) { + target.append('Set-Cookie', value); + } + } else if (responseHeaders && typeof responseHeaders.get === 'function') { + const setCookie = responseHeaders.get('Set-Cookie'); + if (setCookie) target.append('Set-Cookie', setCookie); + } } diff --git a/src/server/storage.spec.ts b/src/server/storage.spec.ts index bf4db99..0bf061c 100644 --- a/src/server/storage.spec.ts +++ b/src/server/storage.spec.ts @@ -36,7 +36,71 @@ describe('TanStackStartCookieSessionStorage', () => { mockContextAvailable = true; }); - describe('getSession', () => { + describe('getCookie', () => { + it('returns the named cookie value', async () => { + const request = new Request('http://example.com', { + headers: { cookie: 'wos-auth-verifier=sealed-abc' }, + }); + + const result = await storage.getCookie(request, 'wos-auth-verifier'); + expect(result).toBe('sealed-abc'); + }); + + it('returns null without cookies', async () => { + const request = new Request('http://example.com'); + + const result = await storage.getCookie(request, 'wos-auth-verifier'); + expect(result).toBeNull(); + }); + + it('returns null when the named cookie is absent', async () => { + const request = new Request('http://example.com', { + headers: { cookie: 'other=value' }, + }); + + const result = await storage.getCookie(request, 'wos-auth-verifier'); + expect(result).toBeNull(); + }); + + it('URI-decodes the cookie value', async () => { + const encoded = encodeURIComponent('value with spaces & symbols'); + const request = new Request('http://example.com', { + headers: { cookie: `wos-auth-verifier=${encoded}` }, + }); + + const result = await storage.getCookie(request, 'wos-auth-verifier'); + expect(result).toBe('value with spaces & symbols'); + }); + + it('returns the named cookie when mixed with others', async () => { + const request = new Request('http://example.com', { + headers: { cookie: 'other=x; wos-auth-verifier=target; another=y' }, + }); + + const result = await storage.getCookie(request, 'wos-auth-verifier'); + expect(result).toBe('target'); + }); + + it('preserves = padding inside a sealed cookie value', async () => { + const request = new Request('http://example.com', { + headers: { cookie: 'wos-auth-verifier=abc==' }, + }); + + const result = await storage.getCookie(request, 'wos-auth-verifier'); + expect(result).toBe('abc=='); + }); + + it('returns null on malformed percent-encoding instead of throwing', async () => { + const request = new Request('http://example.com', { + headers: { cookie: 'wos-auth-verifier=%E0%A4%A' }, + }); + + const result = await storage.getCookie(request, 'wos-auth-verifier'); + expect(result).toBeNull(); + }); + }); + + describe('getSession (inherited wrapper)', () => { it('extracts session from cookies', async () => { const request = new Request('http://example.com', { headers: { cookie: 'wos_session=test-value' }, diff --git a/src/server/storage.ts b/src/server/storage.ts index b11a94e..07cc881 100644 --- a/src/server/storage.ts +++ b/src/server/storage.ts @@ -1,14 +1,21 @@ import { CookieSessionStorage } from '@workos/authkit-session'; import { getAuthKitContextOrNull } from './context.js'; +import { parseCookies } from './cookie-utils.js'; export class TanStackStartCookieSessionStorage extends CookieSessionStorage { - async getSession(request: Request): Promise { + async getCookie(request: Request, name: string): Promise { const cookieHeader = request.headers.get('cookie'); if (!cookieHeader) return null; - const cookies = this.parseCookies(cookieHeader); - const value = cookies[this.cookieName]; - return value ? decodeURIComponent(value) : null; + const cookies = parseCookies(cookieHeader); + const raw = cookies[name]; + if (raw === undefined) return null; + try { + return decodeURIComponent(raw); + } catch { + // Malformed percent-encoding — surface as missing rather than throwing. + return null; + } } protected async applyHeaders( @@ -36,13 +43,4 @@ export class TanStackStartCookieSessionStorage extends CookieSessionStorage newResponse.headers.append(key, value)); return { response: newResponse }; } - - private parseCookies(cookieHeader: string): Record { - return Object.fromEntries( - cookieHeader.split(';').map((cookie) => { - const [key, ...valueParts] = cookie.trim().split('='); - return [key, valueParts.join('=')]; - }), - ); - } } diff --git a/tests/exports.spec.ts b/tests/exports.spec.ts index e874b92..5b0a93f 100644 --- a/tests/exports.spec.ts +++ b/tests/exports.spec.ts @@ -16,6 +16,10 @@ describe('SDK exports', () => { // Middleware expect(exports.authkitMiddleware).toBeDefined(); + + // Error classes re-exported from authkit-session for adopter error handling + expect(exports.OAuthStateMismatchError).toBeDefined(); + expect(exports.PKCECookieMissingError).toBeDefined(); }); it('exports expected types', () => {