frist
This commit is contained in:
136
internal/api/middleware.go
Normal file
136
internal/api/middleware.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// LoggingMiddleware 请求日志中间件
|
||||
func LoggingMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
// 包装 ResponseWriter 以获取状态码
|
||||
wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
|
||||
|
||||
next.ServeHTTP(wrapped, r)
|
||||
|
||||
slog.Info("request",
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"status", wrapped.statusCode,
|
||||
"duration_ms", time.Since(start).Milliseconds(),
|
||||
"remote_addr", r.RemoteAddr,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// AuthMiddleware API Key 鉴权中间件
|
||||
func AuthMiddleware(next http.Handler, apiKey string) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 如果没有配置 API Key,跳过鉴权
|
||||
if apiKey == "" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查 Authorization 头
|
||||
auth := r.Header.Get("Authorization")
|
||||
if auth != "" {
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
token := strings.TrimPrefix(auth, "Bearer ")
|
||||
if token == apiKey {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 X-API-Key 头
|
||||
key := r.Header.Get("X-API-Key")
|
||||
if key == apiKey {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, `{"error":"unauthorized","message":"invalid or missing API key"}`, http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
|
||||
// RateLimitMiddleware 简单限流中间件(基于滑动窗口)
|
||||
type RateLimitMiddleware struct {
|
||||
requests map[string][]time.Time
|
||||
limit int
|
||||
window time.Duration
|
||||
}
|
||||
|
||||
// NewRateLimitMiddleware 创建限流中间件
|
||||
func NewRateLimitMiddleware(limit int, window time.Duration) *RateLimitMiddleware {
|
||||
return &RateLimitMiddleware{
|
||||
requests: make(map[string][]time.Time),
|
||||
limit: limit,
|
||||
window: window,
|
||||
}
|
||||
}
|
||||
|
||||
// Wrap 包装处理器
|
||||
func (rl *RateLimitMiddleware) Wrap(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ip := getClientIP(r)
|
||||
now := time.Now()
|
||||
|
||||
// 清理过期记录
|
||||
windowStart := now.Add(-rl.window)
|
||||
var valid []time.Time
|
||||
for _, t := range rl.requests[ip] {
|
||||
if t.After(windowStart) {
|
||||
valid = append(valid, t)
|
||||
}
|
||||
}
|
||||
rl.requests[ip] = valid
|
||||
|
||||
// 检查限制
|
||||
if len(rl.requests[ip]) >= rl.limit {
|
||||
http.Error(w, `{"error":"rate_limit","message":"too many requests"}`, http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
// 记录请求
|
||||
rl.requests[ip] = append(rl.requests[ip], now)
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// getClientIP 获取客户端 IP
|
||||
func getClientIP(r *http.Request) string {
|
||||
// 检查代理头
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
parts := strings.Split(xff, ",")
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
// 从 RemoteAddr 提取 IP
|
||||
ip := r.RemoteAddr
|
||||
if idx := strings.LastIndex(ip, ":"); idx != -1 {
|
||||
ip = ip[:idx]
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// responseWriter 包装 ResponseWriter 以获取状态码
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (rw *responseWriter) WriteHeader(code int) {
|
||||
rw.statusCode = code
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
Reference in New Issue
Block a user