From 19037b0919fbd7939eb71b04ae88f9319b1384fe Mon Sep 17 00:00:00 2001 From: BYND Date: Thu, 21 May 2026 18:28:11 +0300 Subject: [PATCH] Refresh expired OAuth access tokens --- src/server/api/routers/accounts.ts | 9 +- src/server/auth/refresh-token.ts | 185 +++++++++++++++++++++++++++++ 2 files changed, 192 insertions(+), 2 deletions(-) create mode 100644 src/server/auth/refresh-token.ts diff --git a/src/server/api/routers/accounts.ts b/src/server/api/routers/accounts.ts index bb6043ca..062f1b7e 100644 --- a/src/server/api/routers/accounts.ts +++ b/src/server/api/routers/accounts.ts @@ -1,4 +1,5 @@ import { createTRPCRouter, protectedProcedure } from "@/server/api/trpc"; +import { refreshAccountAccessTokenIfNeeded } from "@/server/auth/refresh-token"; import { z } from "zod"; export const accountsRouter = createTRPCRouter({ @@ -34,23 +35,27 @@ export const accountsRouter = createTRPCRouter({ getAccount: protectedProcedure .input(z.string()) .query(async ({ ctx, input }) => { - return ctx.db.account.findFirst({ + const account = await ctx.db.account.findFirst({ where: { id: input, userId: ctx.session.user.id, }, }); + + return refreshAccountAccessTokenIfNeeded(account); }), getAccountByProvider: protectedProcedure .input(z.string()) .query(async ({ ctx, input }) => { - return ctx.db.account.findFirst({ + const account = await ctx.db.account.findFirst({ where: { userId: ctx.session.user.id, provider: input, }, }); + + return refreshAccountAccessTokenIfNeeded(account); }), hasProviderAccount: protectedProcedure diff --git a/src/server/auth/refresh-token.ts b/src/server/auth/refresh-token.ts new file mode 100644 index 00000000..8bf8cd5c --- /dev/null +++ b/src/server/auth/refresh-token.ts @@ -0,0 +1,185 @@ +import { z } from "zod"; + +import { env } from "@/env"; +import { db } from "@/server/db"; + +import type { Account } from "@prisma/client"; + +const REFRESH_BUFFER_SECONDS = 60; + +const tokenResponseSchema = z + .object({ + access_token: z.string(), + expires_in: z.coerce.number().optional(), + expires_at: z.coerce.number().optional(), + refresh_token: z.string().optional(), + token_type: z.string().optional(), + scope: z.string().optional(), + }) + .passthrough(); + +type RefreshConfig = { + tokenEndpoint: string; + clientId: string; + clientSecret?: string; + useBasicAuth?: boolean; +}; + +export async function refreshAccountAccessTokenIfNeeded( + account: Account | null, +) { + if (!account || !shouldRefresh(account)) { + return account; + } + + if (!account.refresh_token) { + return account; + } + + const config = getRefreshConfig(account.provider); + + if (!config) { + return account; + } + + const response = await fetch(config.tokenEndpoint, { + method: "POST", + headers: getRefreshHeaders(config), + body: getRefreshBody(config, account.refresh_token), + }); + + if (!response.ok) { + throw new Error( + `Failed to refresh ${account.provider} access token: ${response.status}`, + ); + } + + const tokens = tokenResponseSchema.parse(await response.json()); + const expiresAt = + tokens.expires_at ?? + (tokens.expires_in + ? Math.floor(Date.now() / 1000 + tokens.expires_in) + : account.expires_at); + + return db.account.update({ + where: { + id: account.id, + }, + data: { + access_token: tokens.access_token, + expires_at: expiresAt, + refresh_token: tokens.refresh_token ?? account.refresh_token, + token_type: tokens.token_type ?? account.token_type, + scope: tokens.scope ?? account.scope, + }, + }); +} + +function shouldRefresh(account: Account) { + if (!account.expires_at) { + return false; + } + + return account.expires_at - REFRESH_BUFFER_SECONDS <= Date.now() / 1000; +} + +function getRefreshHeaders(config: RefreshConfig) { + const headers: Record = { + Accept: "application/json", + "Content-Type": "application/x-www-form-urlencoded", + }; + + if (config.useBasicAuth && config.clientSecret) { + headers.Authorization = `Basic ${Buffer.from( + `${config.clientId}:${config.clientSecret}`, + ).toString("base64")}`; + } + + return headers; +} + +function getRefreshBody(config: RefreshConfig, refreshToken: string) { + const body = new URLSearchParams({ + grant_type: "refresh_token", + refresh_token: refreshToken, + }); + + if (!config.useBasicAuth) { + body.set("client_id", config.clientId); + + if (config.clientSecret) { + body.set("client_secret", config.clientSecret); + } + } + + return body; +} + +function getRefreshConfig(provider: string): RefreshConfig | null { + switch (provider) { + case "google": + if (!("AUTH_GOOGLE_ID" in env && "AUTH_GOOGLE_SECRET" in env)) { + return null; + } + + return { + tokenEndpoint: "https://oauth2.googleapis.com/token", + clientId: env.AUTH_GOOGLE_ID, + clientSecret: env.AUTH_GOOGLE_SECRET, + }; + case "discord": + if (!("AUTH_DISCORD_ID" in env && "AUTH_DISCORD_SECRET" in env)) { + return null; + } + + return { + tokenEndpoint: "https://discord.com/api/oauth2/token", + clientId: env.AUTH_DISCORD_ID, + clientSecret: env.AUTH_DISCORD_SECRET, + }; + case "github": + if (!("AUTH_GITHUB_ID" in env && "AUTH_GITHUB_SECRET" in env)) { + return null; + } + + return { + tokenEndpoint: "https://github.com/login/oauth/access_token", + clientId: env.AUTH_GITHUB_ID, + clientSecret: env.AUTH_GITHUB_SECRET, + }; + case "spotify": + if (!("AUTH_SPOTIFY_ID" in env && "AUTH_SPOTIFY_SECRET" in env)) { + return null; + } + + return { + tokenEndpoint: "https://accounts.spotify.com/api/token", + clientId: env.AUTH_SPOTIFY_ID, + clientSecret: env.AUTH_SPOTIFY_SECRET, + useBasicAuth: true, + }; + case "strava": + if (!("AUTH_STRAVA_ID" in env && "AUTH_STRAVA_SECRET" in env)) { + return null; + } + + return { + tokenEndpoint: "https://www.strava.com/oauth/token", + clientId: env.AUTH_STRAVA_ID, + clientSecret: env.AUTH_STRAVA_SECRET, + }; + case "twitter": + if (!("AUTH_TWITTER_ID" in env && "AUTH_TWITTER_SECRET" in env)) { + return null; + } + + return { + tokenEndpoint: "https://api.twitter.com/2/oauth2/token", + clientId: env.AUTH_TWITTER_ID, + clientSecret: env.AUTH_TWITTER_SECRET, + useBasicAuth: true, + }; + default: + return null; + } +}