137 lines
3.1 KiB
Go
137 lines
3.1 KiB
Go
package api
|
||
|
||
import (
	"crypto/subtle"
	"log/slog"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"
)
|
||
|
||
// LoggingMiddleware 请求日志中间件
|
||
func LoggingMiddleware(next http.Handler) http.Handler {
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
start := time.Now()
|
||
|
||
// 包装 ResponseWriter 以获取状态码
|
||
wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
|
||
|
||
next.ServeHTTP(wrapped, r)
|
||
|
||
slog.Info("request",
|
||
"method", r.Method,
|
||
"path", r.URL.Path,
|
||
"status", wrapped.statusCode,
|
||
"duration_ms", time.Since(start).Milliseconds(),
|
||
"remote_addr", r.RemoteAddr,
|
||
)
|
||
})
|
||
}
|
||
|
||
// AuthMiddleware enforces API-key authentication in front of next.
//
// When apiKey is empty, authentication is disabled and every request is
// forwarded unchanged. Otherwise, a request is forwarded only if it presents the
// key in one of two ways:
//   - "Authorization: Bearer <key>", with the scheme matched case-insensitively
//     as RFC 7235 requires.
//   - "X-API-Key: <key>".
//
// All other requests receive 401 Unauthorized with a JSON error body.
func AuthMiddleware(next http.Handler, apiKey string) http.Handler {
	want := []byte(apiKey)
	// matches compares in constant time, so response latency does not reveal
	// how many leading bytes of a guessed key were correct.
	matches := func(candidate string) bool {
		return subtle.ConstantTimeCompare([]byte(candidate), want) == 1
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// No key configured: authentication is disabled.
		if apiKey == "" {
			next.ServeHTTP(w, r)
			return
		}

		// Authorization: Bearer <key>
		if scheme, token, ok := strings.Cut(r.Header.Get("Authorization"), " "); ok &&
			strings.EqualFold(scheme, "Bearer") && matches(strings.TrimSpace(token)) {
			next.ServeHTTP(w, r)
			return
		}

		// X-API-Key: <key>
		if matches(r.Header.Get("X-API-Key")) {
			next.ServeHTTP(w, r)
			return
		}

		// Write the error by hand: http.Error would force a text/plain
		// Content-Type onto this JSON body.
		h := w.Header()
		h.Del("Content-Length")
		h.Set("Content-Type", "application/json")
		h.Set("X-Content-Type-Options", "nosniff")
		// RFC 7235 §3.1: a 401 response must carry a WWW-Authenticate challenge.
		h.Set("WWW-Authenticate", "Bearer")
		w.WriteHeader(http.StatusUnauthorized)
		// A write error means the client has gone away; there is nothing left to do.
		_, _ = w.Write([]byte(`{"error":"unauthorized","message":"invalid or missing API key"}` + "\n"))
	})
}
|
||
|
||
// RateLimitMiddleware 简单限流中间件(基于滑动窗口)
|
||
type RateLimitMiddleware struct {
|
||
requests map[string][]time.Time
|
||
limit int
|
||
window time.Duration
|
||
}
|
||
|
||
// NewRateLimitMiddleware 创建限流中间件
|
||
func NewRateLimitMiddleware(limit int, window time.Duration) *RateLimitMiddleware {
|
||
return &RateLimitMiddleware{
|
||
requests: make(map[string][]time.Time),
|
||
limit: limit,
|
||
window: window,
|
||
}
|
||
}
|
||
|
||
// Wrap 包装处理器
|
||
func (rl *RateLimitMiddleware) Wrap(next http.Handler) http.Handler {
|
||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
ip := getClientIP(r)
|
||
now := time.Now()
|
||
|
||
// 清理过期记录
|
||
windowStart := now.Add(-rl.window)
|
||
var valid []time.Time
|
||
for _, t := range rl.requests[ip] {
|
||
if t.After(windowStart) {
|
||
valid = append(valid, t)
|
||
}
|
||
}
|
||
rl.requests[ip] = valid
|
||
|
||
// 检查限制
|
||
if len(rl.requests[ip]) >= rl.limit {
|
||
http.Error(w, `{"error":"rate_limit","message":"too many requests"}`, http.StatusTooManyRequests)
|
||
return
|
||
}
|
||
|
||
// 记录请求
|
||
rl.requests[ip] = append(rl.requests[ip], now)
|
||
|
||
next.ServeHTTP(w, r)
|
||
})
|
||
}
|
||
|
||
// getClientIP returns the client IP address for r.
//
// Lookup order:
//  1. The first non-empty entry of X-Forwarded-For.
//  2. The trimmed X-Real-IP header.
//  3. The host part of r.RemoteAddr, or all of RemoteAddr if it has no port.
//
// NOTE(review): both proxy headers are supplied by the client unless a trusted
// reverse proxy overwrites them. If the server is reachable directly, a client
// can spoof them, for example to evade per-IP rate limiting.
func getClientIP(r *http.Request) string {
	// Proxy headers.
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}

	// net.SplitHostPort handles bracketed IPv6 ("[::1]:8080" -> "::1"), which a
	// split on the last colon gets wrong.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
|
||
|
||
// responseWriter wraps an http.ResponseWriter to record the status code that
// was actually sent, so middleware can inspect it after the handler returns.
// Callers initialise statusCode to the status net/http uses when WriteHeader
// is never called (200).
type responseWriter struct {
	http.ResponseWriter
	statusCode  int
	wroteHeader bool // final status already sent; later WriteHeader calls are no-ops
}

// WriteHeader forwards code and records it. Only the first final status is
// recorded. net/http ignores later WriteHeader calls, and any call made after
// the body has started, so recording those would report a status the client
// never received. Informational 1xx codes other than 101 may be followed by
// the final status, so they do not lock the recorded value.
func (rw *responseWriter) WriteHeader(code int) {
	if !rw.wroteHeader {
		rw.statusCode = code
		if code < 100 || code > 199 || code == http.StatusSwitchingProtocols {
			rw.wroteHeader = true
		}
	}
	rw.ResponseWriter.WriteHeader(code)
}

// Write forwards to the wrapped writer. A Write without a prior WriteHeader
// implicitly sends 200, so the header is marked as written.
func (rw *responseWriter) Write(b []byte) (int, error) {
	rw.wroteHeader = true
	return rw.ResponseWriter.Write(b)
}

// Flush forwards to the wrapped writer if it supports http.Flusher. Without
// this, the wrapper would hide Flusher from streaming handlers (SSE, chunked
// output). Flushing sends the header implicitly with the recorded status.
func (rw *responseWriter) Flush() {
	if f, ok := rw.ResponseWriter.(http.Flusher); ok {
		rw.wroteHeader = true
		f.Flush()
	}
}

// Unwrap returns the wrapped writer so that http.ResponseController can reach
// its optional capabilities (Hijack, deadlines, ...).
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}
|