diff --git a/packages/backend/src/auth/middleware.ts b/packages/backend/src/auth/middleware.ts index 9e8e624..9a3ddb4 100644 --- a/packages/backend/src/auth/middleware.ts +++ b/packages/backend/src/auth/middleware.ts @@ -1,7 +1,12 @@ import type { BunRequest } from "bun"; -import { verifyToken } from "./utils"; +import { getSession } from "../db/queries"; +import { parseCookies, verifyToken } from "./utils"; -export type AuthedRequest = T & { userId: number }; +export type AuthedRequest = T & { + userId: number; + sessionId: number; + csrfToken: string; +}; type RouteHandler = (req: T) => Response | Promise; @@ -9,53 +14,73 @@ type AuthedRouteHandler = ( req: AuthedRequest, ) => Response | Promise; -const extractBearerToken = (req: Request) => { - const header = req.headers.get("Authorization"); - if (!header) { - return null; - } - - const [type, token] = header.split(" "); - if (type !== "Bearer" || !token) { - return null; - } - - return token; +const extractTokenFromCookie = (req: Request) => { + const cookies = parseCookies(req.headers.get("Cookie")); + return cookies.token || null; }; export const withAuth = (handler: AuthedRouteHandler): RouteHandler => { return async (req: T) => { - const token = extractBearerToken(req); + const token = extractTokenFromCookie(req); if (!token) { return new Response("Unauthorized", { status: 401 }); } try { - const { userId } = verifyToken(token); - return handler(Object.assign(req, { userId }) as AuthedRequest); + const { sessionId, userId } = verifyToken(token); + + // validate session exists and is not expired + const session = await getSession(sessionId); + if (!session || session.expiresAt < new Date()) { + return new Response("Session expired", { status: 401 }); + } + + return handler( + Object.assign(req, { + userId, + sessionId, + csrfToken: session.csrfToken, + }) as AuthedRequest, + ); } catch { return new Response("Invalid token", { status: 401 }); } }; }; +export const withCSRF = (handler: AuthedRouteHandler): AuthedRouteHandler => { + return async (req: AuthedRequest) => { + // only validate CSRF for methods which modify state + if (["POST", "PUT", "PATCH", "DELETE"].includes(req.method)) { + const csrfHeader = req.headers.get("X-CSRF-Token"); + if (!csrfHeader || csrfHeader !== req.csrfToken) { + return new Response("Invalid CSRF token", { status: 403 }); + } + } + return handler(req); + }; +}; + const CORS_ALLOWED_ORIGINS = (process.env.CORS_ORIGIN ?? "http://localhost:1420") .split(",") .map((origin) => origin.trim()) .filter(Boolean); const CORS_ALLOW_METHODS = process.env.CORS_ALLOW_METHODS ?? "GET,POST,PUT,PATCH,DELETE,OPTIONS"; -const CORS_ALLOW_HEADERS_DEFAULT = process.env.CORS_ALLOW_HEADERS ?? "Content-Type, Authorization"; +const CORS_ALLOW_HEADERS_DEFAULT = + process.env.CORS_ALLOW_HEADERS ?? "Content-Type, Authorization, X-CSRF-Token"; const CORS_MAX_AGE = process.env.CORS_MAX_AGE ?? "86400"; const getCorsAllowOrigin = (req: Request) => { const requestOrigin = req.headers.get("Origin"); if (!requestOrigin) { - return "*"; + return null; } + // when wildcard is configured, reflect the request origin back + // this allows credentials to work with any origin if (CORS_ALLOWED_ORIGINS.includes("*")) { - return "*"; + return requestOrigin; } if (CORS_ALLOWED_ORIGINS.includes(requestOrigin)) { @@ -71,9 +96,8 @@ const buildCorsHeaders = (req: Request) => { const allowOrigin = getCorsAllowOrigin(req); if (allowOrigin) { headers.set("Access-Control-Allow-Origin", allowOrigin); - if (allowOrigin !== "*") { - headers.set("Vary", "Origin"); - } + headers.set("Access-Control-Allow-Credentials", "true"); + headers.set("Vary", "Origin"); } headers.set("Access-Control-Allow-Methods", CORS_ALLOW_METHODS);