Files
ProxyPool/internal/api/middleware.go
2026-01-31 22:53:12 +08:00

137 lines
3.1 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package api
import (
	"crypto/subtle"
	"log/slog"
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"
)
// LoggingMiddleware emits one structured "request" record per handled request
// through the default slog logger. The record holds the method, the URL path,
// the status code captured by the wrapping writer, the elapsed time in
// milliseconds, and the peer address.
func LoggingMiddleware(next http.Handler) http.Handler {
	logRequest := func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()

		// Start from 200: a handler that never calls WriteHeader gets an
		// implicit 200 OK from net/http.
		rec := &responseWriter{statusCode: http.StatusOK, ResponseWriter: w}
		next.ServeHTTP(rec, r)

		elapsed := time.Since(began)
		slog.Info("request",
			"method", r.Method,
			"path", r.URL.Path,
			"status", rec.statusCode,
			"duration_ms", elapsed.Milliseconds(),
			"remote_addr", r.RemoteAddr,
		)
	}
	return http.HandlerFunc(logRequest)
}
// AuthMiddleware enforces API-key authentication in front of next.
//
// When apiKey is empty, authentication is disabled and every request is
// forwarded unchanged. Otherwise a request is forwarded only if it presents
// the key in one of two ways:
//
//   - Authorization: Bearer <apiKey>  (the scheme name is matched
//     case-insensitively, as RFC 7235 specifies for auth schemes)
//   - X-API-Key: <apiKey>
//
// Keys are compared with crypto/subtle so that response timing does not reveal
// how much of a guessed key was correct. A rejected request gets
// 401 Unauthorized with a JSON error body and a WWW-Authenticate challenge.
func AuthMiddleware(next http.Handler, apiKey string) http.Handler {
	want := []byte(apiKey)
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// No key configured: authentication is switched off.
		if apiKey == "" {
			next.ServeHTTP(w, r)
			return
		}

		const scheme = "Bearer "
		if auth := r.Header.Get("Authorization"); len(auth) >= len(scheme) &&
			strings.EqualFold(auth[:len(scheme)], scheme) &&
			subtle.ConstantTimeCompare([]byte(auth[len(scheme):]), want) == 1 {
			next.ServeHTTP(w, r)
			return
		}

		if subtle.ConstantTimeCompare([]byte(r.Header.Get("X-API-Key")), want) == 1 {
			next.ServeHTTP(w, r)
			return
		}

		// http.Error would label this JSON body as text/plain, so write it by hand.
		// The body keeps http.Error's trailing newline.
		h := w.Header()
		h.Set("Content-Type", "application/json")
		h.Set("X-Content-Type-Options", "nosniff")
		h.Set("WWW-Authenticate", "Bearer")
		w.WriteHeader(http.StatusUnauthorized)
		// A failed write means the client went away; there is nothing left to do.
		_, _ = w.Write([]byte(`{"error":"unauthorized","message":"invalid or missing API key"}` + "\n"))
	})
}
// RateLimitMiddleware is a simple per-client rate limiter based on a
// sliding-window log. For each client IP it keeps the timestamps of requests
// admitted during the last window. A new request is rejected once limit of
// those timestamps are still inside the window.
//
// It is safe for concurrent use. net/http runs handlers concurrently, so every
// access to requests happens under mu. An unguarded map here races and can
// crash the process with "concurrent map writes".
type RateLimitMiddleware struct {
	mu        sync.Mutex
	requests  map[string][]time.Time // client IP -> admitted request times, oldest first
	limit     int
	window    time.Duration
	lastSweep time.Time // when idle clients were last purged from requests
}

// NewRateLimitMiddleware returns a limiter that admits at most limit requests
// per client IP in any window-long interval. A limit <= 0 rejects every request.
func NewRateLimitMiddleware(limit int, window time.Duration) *RateLimitMiddleware {
	return &RateLimitMiddleware{
		requests: make(map[string][]time.Time),
		limit:    limit,
		window:   window,
	}
}

// Wrap returns a handler that applies the rate limit, keyed by getClientIP,
// before delegating to next. A rejected request receives 429 Too Many Requests
// with a JSON body and a Retry-After header (whole seconds, at least 1).
func (rl *RateLimitMiddleware) Wrap(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		wait, ok := rl.allow(getClientIP(r))
		if ok {
			next.ServeHTTP(w, r)
			return
		}

		// Round up so clients never retry before a slot is actually free.
		secs := int((wait + time.Second - 1) / time.Second)
		if secs < 1 {
			secs = 1
		}
		// http.Error would label this JSON body as text/plain, so write it by hand.
		// The body keeps http.Error's trailing newline.
		h := w.Header()
		h.Set("Content-Type", "application/json")
		h.Set("X-Content-Type-Options", "nosniff")
		h.Set("Retry-After", strconv.Itoa(secs))
		w.WriteHeader(http.StatusTooManyRequests)
		// A failed write means the client went away; there is nothing left to do.
		_, _ = w.Write([]byte(`{"error":"rate_limit","message":"too many requests"}` + "\n"))
	})
}

// allow decides whether a request from ip arriving now fits within the limit,
// and records it if so. When the request is rejected, wait is the time until
// the oldest admitted request leaves the window. If nothing is recorded, which
// only happens when limit <= 0, wait is rl.window.
func (rl *RateLimitMiddleware) allow(ip string) (wait time.Duration, ok bool) {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	if rl.requests == nil { // tolerate a zero-value RateLimitMiddleware
		rl.requests = make(map[string][]time.Time)
	}
	// Taking the timestamp under the lock keeps each client's log in
	// nondecreasing order, because time.Now carries a monotonic clock reading.
	now := time.Now()
	windowStart := now.Add(-rl.window)
	rl.sweep(now, windowStart)

	// Drop expired entries in place, reusing the slice's backing array.
	times := rl.requests[ip]
	recent := times[:0]
	for _, t := range times {
		if t.After(windowStart) {
			recent = append(recent, t)
		}
	}

	if len(recent) >= rl.limit {
		if len(recent) == 0 {
			delete(rl.requests, ip)
			return rl.window, false
		}
		rl.requests[ip] = recent
		return recent[0].Add(rl.window).Sub(now), false
	}
	rl.requests[ip] = append(recent, now)
	return 0, true
}

// sweep forgets clients that have no request left inside the window. This
// stops the map from growing without bound as new client IPs appear. It scans
// the whole map at most once per window. Callers must hold rl.mu.
func (rl *RateLimitMiddleware) sweep(now, windowStart time.Time) {
	if now.Sub(rl.lastSweep) < rl.window {
		return
	}
	rl.lastSweep = now
	for ip, times := range rl.requests {
		// Logs are in chronological order, so checking the newest entry suffices.
		if len(times) == 0 || !times[len(times)-1].After(windowStart) {
			delete(rl.requests, ip)
		}
	}
}
// getClientIP returns the client IP used to key per-client state such as rate
// limiting. It checks three sources in order and returns the first non-blank
// one:
//
//  1. the first entry of X-Forwarded-For
//  2. X-Real-IP
//  3. the host part of r.RemoteAddr (r.RemoteAddr verbatim if it has no port)
//
// SECURITY NOTE(review): the proxy headers are trusted unconditionally. A
// client that reaches this server directly can set them to any value and so
// evade per-IP rate limiting. This is only safe behind a reverse proxy that
// overwrites these headers. Consider honouring them only when RemoteAddr
// belongs to a known proxy.
func getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}
	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}
	// SplitHostPort understands bracketed IPv6 ("[::1]:8080" -> "::1"); a plain
	// split on the last ':' would return "[::1]".
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
// responseWriter wraps an http.ResponseWriter to record the status code sent
// to the client, so it can be logged.
//
// Whoever creates it must initialise statusCode to http.StatusOK. That is the
// status a handler gets implicitly if it never calls WriteHeader.
type responseWriter struct {
	http.ResponseWriter
	statusCode  int
	wroteHeader bool // a final (non-informational) status has been sent
}

// WriteHeader records the first final status code and forwards every call.
// Informational 1xx codes (other than 101 Switching Protocols, which net/http
// treats as final) may precede the real status. net/http ignores any
// WriteHeader after the first final one, so only that first final code
// reaches the client.
func (rw *responseWriter) WriteHeader(code int) {
	final := code >= 200 || code == http.StatusSwitchingProtocols
	if final && !rw.wroteHeader {
		rw.statusCode = code
		rw.wroteHeader = true
	}
	rw.ResponseWriter.WriteHeader(code)
}

// Write sends an explicit 200 OK first if no final status has been written.
// net/http would do this implicitly, but doing it here keeps statusCode
// accurate and makes later WriteHeader calls no-ops for logging.
func (rw *responseWriter) Write(b []byte) (int, error) {
	if !rw.wroteHeader {
		rw.WriteHeader(http.StatusOK)
	}
	return rw.ResponseWriter.Write(b)
}

// Unwrap exposes the underlying writer. This lets http.ResponseController
// reach optional capabilities (Flush, Hijack, deadlines) that the wrapper
// would otherwise hide.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}