package api

import (
	"crypto/subtle"
	"log/slog"
	"net"
	"net/http"
	"slices"
	"strings"
	"sync"
	"time"
)

// LoggingMiddleware logs one structured "request" entry (via slog.Info) after
// next has handled each request, recording the method, path, status code,
// duration in milliseconds and the connection's remote address.
func LoggingMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()

		// Wrap the ResponseWriter so the status code chosen by next can be
		// observed; a handler that never calls WriteHeader sends 200.
		wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
		next.ServeHTTP(wrapped, r)

		slog.Info("request",
			"method", r.Method,
			"path", r.URL.Path,
			"status", wrapped.statusCode,
			"duration_ms", time.Since(start).Milliseconds(),
			"remote_addr", r.RemoteAddr,
		)
	})
}

// AuthMiddleware rejects requests that do not present apiKey.
//
// A request is accepted when it carries either "Authorization: Bearer <key>"
// (the scheme name is matched case-insensitively, as RFC 7235 requires) or
// "X-API-Key: <key>". Anything else gets 401 with a JSON error body. An empty
// apiKey disables authentication entirely.
func AuthMiddleware(next http.Handler, apiKey string) http.Handler {
	want := []byte(apiKey)
	// matches compares in constant time (for equal lengths) so response timing
	// does not reveal how much of a guessed key was correct.
	matches := func(got string) bool {
		return subtle.ConstantTimeCompare([]byte(got), want) == 1
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// No key configured: authentication is disabled.
		if apiKey == "" {
			next.ServeHTTP(w, r)
			return
		}

		scheme, token, ok := strings.Cut(r.Header.Get("Authorization"), " ")
		if ok && strings.EqualFold(scheme, "Bearer") && matches(token) {
			next.ServeHTTP(w, r)
			return
		}

		if matches(r.Header.Get("X-API-Key")) {
			next.ServeHTTP(w, r)
			return
		}

		writeMiddlewareError(w, http.StatusUnauthorized,
			`{"error":"unauthorized","message":"invalid or missing API key"}`)
	})
}

// writeMiddlewareError sends status followed by the JSON document body.
// http.Error is not used because it always labels the body text/plain.
func writeMiddlewareError(w http.ResponseWriter, status int, body string) {
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	w.Header().Set("X-Content-Type-Options", "nosniff")
	w.WriteHeader(status)
	// A failed write means the client is gone; there is nobody left to tell.
	_, _ = w.Write([]byte(body + "\n"))
}

// RateLimitMiddleware is a per-client sliding-window rate limiter: each client
// (keyed by getClientIP) may make at most limit requests within any
// window-long interval. Requests over the limit get 429 with a JSON body.
//
// It is safe for concurrent use: net/http serves each request on its own
// goroutine, so all access to requests and lastSweep is guarded by mu.
type RateLimitMiddleware struct {
	mu        sync.Mutex
	requests  map[string][]time.Time // client key -> times of admitted requests
	lastSweep time.Time              // when idle clients were last purged
	limit     int
	window    time.Duration
}

// NewRateLimitMiddleware returns a limiter admitting at most limit requests
// per client within any interval of length window.
func NewRateLimitMiddleware(limit int, window time.Duration) *RateLimitMiddleware {
	return &RateLimitMiddleware{
		requests: make(map[string][]time.Time),
		limit:    limit,
		window:   window,
	}
}

// Wrap returns a handler that applies the rate limit before calling next.
func (rl *RateLimitMiddleware) Wrap(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !rl.allow(getClientIP(r), time.Now()) {
			writeMiddlewareError(w, http.StatusTooManyRequests,
				`{"error":"rate_limit","message":"too many requests"}`)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// allow reports whether a request from client key ip at time now fits within
// the limit, recording it if so. Rejected requests are not recorded, so they
// do not extend a client's lockout.
func (rl *RateLimitMiddleware) allow(ip string, now time.Time) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	windowStart := now.Add(-rl.window)
	expired := func(t time.Time) bool { return !t.After(windowStart) }

	// Purge clients with no request inside the window so the map cannot grow
	// without bound as distinct clients come and go. Doing this at most once
	// per window keeps the amortized cost per request constant.
	if now.Sub(rl.lastSweep) >= rl.window {
		rl.lastSweep = now
		for key, times := range rl.requests {
			if times = slices.DeleteFunc(times, expired); len(times) == 0 {
				delete(rl.requests, key)
			} else {
				rl.requests[key] = times
			}
		}
	}

	// Drop this client's timestamps that slid out of the window; DeleteFunc
	// filters in place, reusing the backing array instead of reallocating.
	times := slices.DeleteFunc(rl.requests[ip], expired)
	if len(times) >= rl.limit {
		rl.requests[ip] = times
		return false
	}
	rl.requests[ip] = append(times, now)
	return true
}

// getClientIP returns the client address used as the rate-limit key.
//
// It prefers the first entry of X-Forwarded-For, then X-Real-IP, and finally
// the host part of r.RemoteAddr (or RemoteAddr itself if it has no port).
//
// NOTE(review): both proxy headers are taken from the request unconditionally,
// so any client can set them to pick an arbitrary key and bypass the rate
// limit. They should only be honoured when RemoteAddr is a trusted proxy.
func getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		return strings.TrimSpace(first)
	}
	if xri := r.Header.Get("X-Real-IP"); xri != "" {
		return strings.TrimSpace(xri)
	}
	// SplitHostPort handles bracketed IPv6 ("[::1]:8080" -> "::1"), which a
	// naive split on the last colon gets wrong.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}

// responseWriter wraps an http.ResponseWriter to record the final status code
// sent to the client.
type responseWriter struct {
	http.ResponseWriter
	statusCode  int
	wroteHeader bool
}

// WriteHeader records the first final status code and forwards code. Later
// calls are forwarded but not recorded, because net/http ignores superfluous
// WriteHeader calls. 1xx informational codes (except 101) are not final and
// are never recorded.
func (rw *responseWriter) WriteHeader(code int) {
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !rw.wroteHeader && !informational {
		rw.statusCode = code
		rw.wroteHeader = true
	}
	rw.ResponseWriter.WriteHeader(code)
}

// Write forwards b. The first Write implicitly sends the status recorded so
// far (200 unless WriteHeader said otherwise), so the status is latched here.
func (rw *responseWriter) Write(b []byte) (int, error) {
	rw.wroteHeader = true
	return rw.ResponseWriter.Write(b)
}

// Unwrap exposes the underlying ResponseWriter so http.ResponseController can
// reach optional capabilities such as Flush and Hijack.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}