This commit is contained in:
dela
2026-01-31 22:53:12 +08:00
commit bc639cf460
30 changed files with 6836 additions and 0 deletions

136
internal/api/middleware.go Normal file
View File

@@ -0,0 +1,136 @@
package api
import (
"log/slog"
"net/http"
"strings"
"time"
)
// LoggingMiddleware 请求日志中间件
func LoggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// 包装 ResponseWriter 以获取状态码
wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
next.ServeHTTP(wrapped, r)
slog.Info("request",
"method", r.Method,
"path", r.URL.Path,
"status", wrapped.statusCode,
"duration_ms", time.Since(start).Milliseconds(),
"remote_addr", r.RemoteAddr,
)
})
}
// AuthMiddleware API Key 鉴权中间件
func AuthMiddleware(next http.Handler, apiKey string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 如果没有配置 API Key跳过鉴权
if apiKey == "" {
next.ServeHTTP(w, r)
return
}
// 检查 Authorization 头
auth := r.Header.Get("Authorization")
if auth != "" {
if strings.HasPrefix(auth, "Bearer ") {
token := strings.TrimPrefix(auth, "Bearer ")
if token == apiKey {
next.ServeHTTP(w, r)
return
}
}
}
// 检查 X-API-Key 头
key := r.Header.Get("X-API-Key")
if key == apiKey {
next.ServeHTTP(w, r)
return
}
http.Error(w, `{"error":"unauthorized","message":"invalid or missing API key"}`, http.StatusUnauthorized)
})
}
// RateLimitMiddleware 简单限流中间件(基于滑动窗口)
type RateLimitMiddleware struct {
requests map[string][]time.Time
limit int
window time.Duration
}
// NewRateLimitMiddleware 创建限流中间件
func NewRateLimitMiddleware(limit int, window time.Duration) *RateLimitMiddleware {
return &RateLimitMiddleware{
requests: make(map[string][]time.Time),
limit: limit,
window: window,
}
}
// Wrap 包装处理器
func (rl *RateLimitMiddleware) Wrap(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ip := getClientIP(r)
now := time.Now()
// 清理过期记录
windowStart := now.Add(-rl.window)
var valid []time.Time
for _, t := range rl.requests[ip] {
if t.After(windowStart) {
valid = append(valid, t)
}
}
rl.requests[ip] = valid
// 检查限制
if len(rl.requests[ip]) >= rl.limit {
http.Error(w, `{"error":"rate_limit","message":"too many requests"}`, http.StatusTooManyRequests)
return
}
// 记录请求
rl.requests[ip] = append(rl.requests[ip], now)
next.ServeHTTP(w, r)
})
}
// getClientIP 获取客户端 IP
func getClientIP(r *http.Request) string {
// 检查代理头
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
parts := strings.Split(xff, ",")
return strings.TrimSpace(parts[0])
}
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return xri
}
// 从 RemoteAddr 提取 IP
ip := r.RemoteAddr
if idx := strings.LastIndex(ip, ":"); idx != -1 {
ip = ip[:idx]
}
return ip
}
// responseWriter 包装 ResponseWriter 以获取状态码
type responseWriter struct {
http.ResponseWriter
statusCode int
}
func (rw *responseWriter) WriteHeader(code int) {
rw.statusCode = code
rw.ResponseWriter.WriteHeader(code)
}