ratelimiting via "withRateLimit"

This commit is contained in:
Oliver Bryan
2026-01-21 23:04:38 +00:00
parent be57b4d6df
commit f780725a23
3 changed files with 199 additions and 47 deletions

View File

@@ -1,5 +1,6 @@
import type { BunRequest } from "bun"; import type { BunRequest } from "bun";
import { getSession } from "../db/queries"; import { getSession } from "../db/queries";
import { GLOBAL_RATE_LIMIT, getClientIP, rateLimitResponse, recordRateLimitAttempt } from "./rate-limit";
import { parseCookies, verifyToken } from "./utils"; import { parseCookies, verifyToken } from "./utils";
export type AuthedRequest<T extends BunRequest = BunRequest> = T & { export type AuthedRequest<T extends BunRequest = BunRequest> = T & {
@@ -19,6 +20,19 @@ const extractTokenFromCookie = (req: Request) => {
return cookies.token || null; return cookies.token || null;
}; };
export const withRateLimit = <T extends BunRequest>(handler: RouteHandler<T>): RouteHandler<T> => {
return async (req: T) => {
const ip = getClientIP(req);
const key = `global:ip:${ip}`;
const attempt = recordRateLimitAttempt(key, GLOBAL_RATE_LIMIT);
if (!attempt.allowed) {
return rateLimitResponse(attempt.retryAfterMs);
}
return handler(req);
};
};
export const withAuth = <T extends BunRequest>(handler: AuthedRouteHandler<T>): RouteHandler<T> => { export const withAuth = <T extends BunRequest>(handler: AuthedRouteHandler<T>): RouteHandler<T> => {
return async (req: T) => { return async (req: T) => {
const token = extractTokenFromCookie(req); const token = extractTokenFromCookie(req);

View File

@@ -0,0 +1,133 @@
type RateLimitConfig = {
windowMs: number;
max: number;
backoffBaseMs?: number;
backoffMaxMs?: number;
};
type RateLimitState = {
count: number;
windowStart: number;
blockedUntil?: number;
};
type RateLimitResult = {
allowed: boolean;
retryAfterMs?: number;
};
const rateLimitStore = new Map<string, RateLimitState>();
export const LOGIN_RATE_LIMIT: RateLimitConfig = {
windowMs: 15 * 60 * 1000,
max: 5,
backoffBaseMs: 60 * 1000,
backoffMaxMs: 15 * 60 * 1000,
};
export const GLOBAL_RATE_LIMIT: RateLimitConfig = {
windowMs: 60 * 1000,
max: 300,
};
export const REGISTER_RATE_LIMIT: RateLimitConfig = {
windowMs: 60 * 60 * 1000,
max: 3,
};
export const getClientIP = (req: Request) => {
const forwardedFor = req.headers.get("x-forwarded-for");
if (forwardedFor) {
return forwardedFor.split(",")[0]?.trim() || "unknown";
}
return req.headers.get("x-real-ip") ?? "unknown";
};
const getRetryAfter = (state: RateLimitState, now: number, config: RateLimitConfig) => {
if (state.blockedUntil && state.blockedUntil > now) {
return state.blockedUntil - now;
}
const windowEndsAt = state.windowStart + config.windowMs;
return windowEndsAt > now ? windowEndsAt - now : 0;
};
export const checkRateLimit = (key: string, config: RateLimitConfig): RateLimitResult => {
const now = Date.now();
const state = rateLimitStore.get(key);
if (!state) {
return { allowed: true };
}
if (now - state.windowStart > config.windowMs) {
rateLimitStore.delete(key);
return { allowed: true };
}
if (state.blockedUntil && state.blockedUntil > now) {
return { allowed: false, retryAfterMs: state.blockedUntil - now };
}
if (state.count >= config.max) {
return { allowed: false, retryAfterMs: getRetryAfter(state, now, config) };
}
return { allowed: true };
};
export const recordRateLimitAttempt = (key: string, config: RateLimitConfig): RateLimitResult => {
const now = Date.now();
const existing = rateLimitStore.get(key);
const state: RateLimitState = existing
? { ...existing }
: {
count: 0,
windowStart: now,
};
if (now - state.windowStart > config.windowMs) {
state.count = 0;
state.windowStart = now;
state.blockedUntil = undefined;
}
state.count += 1;
if (state.count >= config.max) {
if (config.backoffBaseMs) {
const overage = state.count - config.max;
const delay = Math.min(
config.backoffMaxMs ?? config.backoffBaseMs,
config.backoffBaseMs * 2 ** Math.max(0, overage),
);
state.blockedUntil = now + delay;
} else {
state.blockedUntil = state.windowStart + config.windowMs;
}
}
rateLimitStore.set(key, state);
if (state.blockedUntil && state.blockedUntil > now) {
return { allowed: false, retryAfterMs: state.blockedUntil - now };
}
return { allowed: true };
};
export const resetRateLimit = (key: string) => {
rateLimitStore.delete(key);
};
export const rateLimitResponse = (retryAfterMs?: number) => {
const headers = new Headers();
if (retryAfterMs && retryAfterMs > 0) {
headers.set("Retry-After", Math.ceil(retryAfterMs / 1000).toString());
}
return Response.json(
{ error: "too many requests", code: "RATE_LIMITED" },
{
status: 429,
headers,
},
);
};

View File

@@ -1,4 +1,5 @@
import { withAuth, withCors, withCSRF } from "./auth/middleware"; import type { BunRequest } from "bun";
import { withAuth, withCors, withCSRF, withRateLimit } from "./auth/middleware";
import { testDB } from "./db/client"; import { testDB } from "./db/client";
import { cleanupExpiredSessions } from "./db/queries"; import { cleanupExpiredSessions } from "./db/queries";
import { routes } from "./routes"; import { routes } from "./routes";
@@ -20,69 +21,73 @@ const startSessionCleanup = () => {
setInterval(cleanup, SESSION_CLEANUP_INTERVAL); setInterval(cleanup, SESSION_CLEANUP_INTERVAL);
}; };
type RouteHandler<T extends BunRequest = BunRequest> = (req: T) => Response | Promise<Response>;
const withGlobal = <T extends BunRequest>(handler: RouteHandler<T>) => withCors(withRateLimit(handler));
const main = async () => { const main = async () => {
const server = Bun.serve({ const server = Bun.serve({
port: Number(PORT), port: Number(PORT),
routes: { routes: {
"/": withCors(() => new Response(`title: tnirps\ndev-mode: ${DEV}\nport: ${PORT}`)), "/": withGlobal(() => new Response(`title: tnirps\ndev-mode: ${DEV}\nport: ${PORT}`)),
"/health": withCors(() => new Response("OK")), "/health": withGlobal(() => new Response("OK")),
// routes that modify state require withCSRF middleware // routes that modify state require withCSRF middleware
"/auth/register": withCors(routes.authRegister), "/auth/register": withGlobal(routes.authRegister),
"/auth/login": withCors(routes.authLogin), "/auth/login": withGlobal(routes.authLogin),
"/auth/logout": withCors(withAuth(withCSRF(routes.authLogout))), "/auth/logout": withGlobal(withAuth(withCSRF(routes.authLogout))),
"/auth/me": withCors(withAuth(routes.authMe)), "/auth/me": withGlobal(withAuth(routes.authMe)),
"/user/by-username": withCors(withAuth(routes.userByUsername)), "/user/by-username": withGlobal(withAuth(routes.userByUsername)),
"/user/update": withCors(withAuth(withCSRF(routes.userUpdate))), "/user/update": withGlobal(withAuth(withCSRF(routes.userUpdate))),
"/user/upload-avatar": withCors(withAuth(withCSRF(routes.userUploadAvatar))), "/user/upload-avatar": withGlobal(withAuth(withCSRF(routes.userUploadAvatar))),
"/issue/create": withCors(withAuth(withCSRF(routes.issueCreate))), "/issue/create": withGlobal(withAuth(withCSRF(routes.issueCreate))),
"/issue/update": withCors(withAuth(withCSRF(routes.issueUpdate))), "/issue/update": withGlobal(withAuth(withCSRF(routes.issueUpdate))),
"/issue/delete": withCors(withAuth(withCSRF(routes.issueDelete))), "/issue/delete": withGlobal(withAuth(withCSRF(routes.issueDelete))),
"/issue-comment/create": withCors(withAuth(withCSRF(routes.issueCommentCreate))), "/issue-comment/create": withGlobal(withAuth(withCSRF(routes.issueCommentCreate))),
"/issue-comment/delete": withCors(withAuth(withCSRF(routes.issueCommentDelete))), "/issue-comment/delete": withGlobal(withAuth(withCSRF(routes.issueCommentDelete))),
"/issues/by-project": withCors(withAuth(routes.issuesByProject)), "/issues/by-project": withGlobal(withAuth(routes.issuesByProject)),
"/issues/replace-status": withCors(withAuth(withCSRF(routes.issuesReplaceStatus))), "/issues/replace-status": withGlobal(withAuth(withCSRF(routes.issuesReplaceStatus))),
"/issues/status-count": withCors(withAuth(routes.issuesStatusCount)), "/issues/status-count": withGlobal(withAuth(routes.issuesStatusCount)),
"/issues/all": withCors(withAuth(routes.issues)), "/issues/all": withGlobal(withAuth(routes.issues)),
"/issue-comments/by-issue": withCors(withAuth(routes.issueCommentsByIssue)), "/issue-comments/by-issue": withGlobal(withAuth(routes.issueCommentsByIssue)),
"/organisation/create": withCors(withAuth(withCSRF(routes.organisationCreate))), "/organisation/create": withGlobal(withAuth(withCSRF(routes.organisationCreate))),
"/organisation/by-id": withCors(withAuth(routes.organisationById)), "/organisation/by-id": withGlobal(withAuth(routes.organisationById)),
"/organisation/update": withCors(withAuth(withCSRF(routes.organisationUpdate))), "/organisation/update": withGlobal(withAuth(withCSRF(routes.organisationUpdate))),
"/organisation/delete": withCors(withAuth(withCSRF(routes.organisationDelete))), "/organisation/delete": withGlobal(withAuth(withCSRF(routes.organisationDelete))),
"/organisation/upload-icon": withCors(withAuth(withCSRF(routes.organisationUploadIcon))), "/organisation/upload-icon": withGlobal(withAuth(withCSRF(routes.organisationUploadIcon))),
"/organisation/add-member": withCors(withAuth(withCSRF(routes.organisationAddMember))), "/organisation/add-member": withGlobal(withAuth(withCSRF(routes.organisationAddMember))),
"/organisation/members": withCors(withAuth(routes.organisationMembers)), "/organisation/members": withGlobal(withAuth(routes.organisationMembers)),
"/organisation/remove-member": withCors(withAuth(withCSRF(routes.organisationRemoveMember))), "/organisation/remove-member": withGlobal(withAuth(withCSRF(routes.organisationRemoveMember))),
"/organisation/update-member-role": withCors( "/organisation/update-member-role": withGlobal(
withAuth(withCSRF(routes.organisationUpdateMemberRole)), withAuth(withCSRF(routes.organisationUpdateMemberRole)),
), ),
"/organisations/by-user": withCors(withAuth(routes.organisationsByUser)), "/organisations/by-user": withGlobal(withAuth(routes.organisationsByUser)),
"/project/create": withCors(withAuth(withCSRF(routes.projectCreate))), "/project/create": withGlobal(withAuth(withCSRF(routes.projectCreate))),
"/project/update": withCors(withAuth(withCSRF(routes.projectUpdate))), "/project/update": withGlobal(withAuth(withCSRF(routes.projectUpdate))),
"/project/delete": withCors(withAuth(withCSRF(routes.projectDelete))), "/project/delete": withGlobal(withAuth(withCSRF(routes.projectDelete))),
"/project/with-creator": withCors(withAuth(routes.projectWithCreator)), "/project/with-creator": withGlobal(withAuth(routes.projectWithCreator)),
"/projects/by-creator": withCors(withAuth(routes.projectsByCreator)), "/projects/by-creator": withGlobal(withAuth(routes.projectsByCreator)),
"/projects/by-organisation": withCors(withAuth(routes.projectsByOrganisation)), "/projects/by-organisation": withGlobal(withAuth(routes.projectsByOrganisation)),
"/projects/all": withCors(withAuth(routes.projectsAll)), "/projects/all": withGlobal(withAuth(routes.projectsAll)),
"/projects/with-creators": withCors(withAuth(routes.projectsWithCreators)), "/projects/with-creators": withGlobal(withAuth(routes.projectsWithCreators)),
"/sprint/create": withCors(withAuth(withCSRF(routes.sprintCreate))), "/sprint/create": withGlobal(withAuth(withCSRF(routes.sprintCreate))),
"/sprint/update": withCors(withAuth(withCSRF(routes.sprintUpdate))), "/sprint/update": withGlobal(withAuth(withCSRF(routes.sprintUpdate))),
"/sprint/delete": withCors(withAuth(withCSRF(routes.sprintDelete))), "/sprint/delete": withGlobal(withAuth(withCSRF(routes.sprintDelete))),
"/sprints/by-project": withCors(withAuth(routes.sprintsByProject)), "/sprints/by-project": withGlobal(withAuth(routes.sprintsByProject)),
"/timer/toggle": withCors(withAuth(withCSRF(routes.timerToggle))), "/timer/toggle": withGlobal(withAuth(withCSRF(routes.timerToggle))),
"/timer/end": withCors(withAuth(withCSRF(routes.timerEnd))), "/timer/end": withGlobal(withAuth(withCSRF(routes.timerEnd))),
"/timer/get": withCors(withAuth(withCSRF(routes.timerGet))), "/timer/get": withGlobal(withAuth(withCSRF(routes.timerGet))),
"/timer/get-inactive": withCors(withAuth(withCSRF(routes.timerGetInactive))), "/timer/get-inactive": withGlobal(withAuth(withCSRF(routes.timerGetInactive))),
"/timers": withCors(withAuth(withCSRF(routes.timers))), "/timers": withGlobal(withAuth(withCSRF(routes.timers))),
}, },
}); });