frist
This commit is contained in:
643
internal/api/handlers.go
Normal file
643
internal/api/handlers.go
Normal file
@@ -0,0 +1,643 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"proxyrotator/internal/config"
|
||||
"proxyrotator/internal/importer"
|
||||
"proxyrotator/internal/model"
|
||||
"proxyrotator/internal/security"
|
||||
"proxyrotator/internal/selector"
|
||||
"proxyrotator/internal/store"
|
||||
"proxyrotator/internal/tester"
|
||||
)
|
||||
|
||||
// Handlers API 处理器集合
|
||||
type Handlers struct {
|
||||
store store.ProxyStore
|
||||
importer *importer.Importer
|
||||
tester *tester.HTTPTester
|
||||
selector *selector.Selector
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewHandlers 创建处理器
|
||||
func NewHandlers(
|
||||
store store.ProxyStore,
|
||||
importer *importer.Importer,
|
||||
tester *tester.HTTPTester,
|
||||
selector *selector.Selector,
|
||||
cfg *config.Config,
|
||||
) *Handlers {
|
||||
return &Handlers{
|
||||
store: store,
|
||||
importer: importer,
|
||||
tester: tester,
|
||||
selector: selector,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleImportText 文本导入处理器
|
||||
// POST /v1/proxies/import/text
|
||||
func (h *Handlers) HandleImportText(w http.ResponseWriter, r *http.Request) {
|
||||
var req model.ImportTextRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Text == "" {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "text is required")
|
||||
return
|
||||
}
|
||||
|
||||
input := model.ImportInput{
|
||||
Group: coalesce(req.Group, "default"),
|
||||
Tags: req.Tags,
|
||||
ProtocolHint: req.ProtocolHint,
|
||||
}
|
||||
|
||||
proxies, invalid := h.importer.ParseText(r.Context(), input, req.Text)
|
||||
|
||||
imported, duplicated := 0, 0
|
||||
if len(proxies) > 0 {
|
||||
var err error
|
||||
imported, duplicated, err = h.store.UpsertMany(r.Context(), proxies)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, model.ImportResult{
|
||||
Imported: imported,
|
||||
Duplicated: duplicated,
|
||||
Invalid: len(invalid),
|
||||
InvalidItems: invalid,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleImportFile 文件上传导入处理器
|
||||
// POST /v1/proxies/import/file
|
||||
func (h *Handlers) HandleImportFile(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseMultipartForm(10 << 20); err != nil { // 10MB 限制
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "failed to parse multipart form")
|
||||
return
|
||||
}
|
||||
|
||||
file, header, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "file is required")
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
group := coalesce(r.FormValue("group"), "default")
|
||||
tagsStr := r.FormValue("tags")
|
||||
var tags []string
|
||||
if tagsStr != "" {
|
||||
tags = strings.Split(tagsStr, ",")
|
||||
for i := range tags {
|
||||
tags[i] = strings.TrimSpace(tags[i])
|
||||
}
|
||||
}
|
||||
|
||||
input := model.ImportInput{
|
||||
Group: group,
|
||||
Tags: tags,
|
||||
ProtocolHint: r.FormValue("protocol_hint"),
|
||||
}
|
||||
|
||||
// 读取文件内容
|
||||
content, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to read file")
|
||||
return
|
||||
}
|
||||
|
||||
var proxies []model.Proxy
|
||||
var invalid []model.InvalidLine
|
||||
|
||||
// 根据文件类型解析
|
||||
fileType := r.FormValue("type")
|
||||
if fileType == "" {
|
||||
// 根据文件名推断
|
||||
if strings.HasSuffix(strings.ToLower(header.Filename), ".csv") {
|
||||
fileType = "csv"
|
||||
} else {
|
||||
fileType = "txt"
|
||||
}
|
||||
}
|
||||
|
||||
if fileType == "csv" {
|
||||
proxies, invalid = h.importer.ParseCSV(r.Context(), input, strings.NewReader(string(content)))
|
||||
} else {
|
||||
proxies, invalid = h.importer.ParseText(r.Context(), input, string(content))
|
||||
}
|
||||
|
||||
imported, duplicated := 0, 0
|
||||
if len(proxies) > 0 {
|
||||
imported, duplicated, err = h.store.UpsertMany(r.Context(), proxies)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, model.ImportResult{
|
||||
Imported: imported,
|
||||
Duplicated: duplicated,
|
||||
Invalid: len(invalid),
|
||||
InvalidItems: invalid,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleTest 测试代理处理器
|
||||
// POST /v1/proxies/test
|
||||
func (h *Handlers) HandleTest(w http.ResponseWriter, r *http.Request) {
|
||||
var req model.TestRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
// SSRF 防护
|
||||
if req.TestSpec.URL == "" {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "test_spec.url is required")
|
||||
return
|
||||
}
|
||||
if err := security.ValidateTestURL(req.TestSpec.URL); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 限制并发数
|
||||
concurrency := req.Concurrency
|
||||
if concurrency <= 0 {
|
||||
concurrency = 50
|
||||
}
|
||||
if concurrency > h.cfg.MaxConcurrency {
|
||||
concurrency = h.cfg.MaxConcurrency
|
||||
}
|
||||
|
||||
// 限制测试数量
|
||||
limit := req.Filter.Limit
|
||||
if limit <= 0 || limit > h.cfg.MaxTestLimit {
|
||||
limit = h.cfg.MaxTestLimit
|
||||
}
|
||||
|
||||
// 构建查询条件
|
||||
var statusIn []model.ProxyStatus
|
||||
for _, s := range req.Filter.Status {
|
||||
statusIn = append(statusIn, model.ProxyStatus(s))
|
||||
}
|
||||
if len(statusIn) == 0 {
|
||||
statusIn = []model.ProxyStatus{model.StatusUnknown, model.StatusAlive}
|
||||
}
|
||||
|
||||
proxies, err := h.store.List(r.Context(), model.ProxyQuery{
|
||||
Group: coalesce(req.Group, "default"),
|
||||
TagsAny: req.Filter.TagsAny,
|
||||
StatusIn: statusIn,
|
||||
OnlyEnabled: true,
|
||||
Limit: limit,
|
||||
})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(proxies) == 0 {
|
||||
writeJSON(w, http.StatusOK, model.TestBatchResult{
|
||||
Summary: model.TestSummary{},
|
||||
Results: []model.TestResult{},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 构建测试规格
|
||||
timeout := time.Duration(req.TestSpec.TimeoutMs) * time.Millisecond
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
|
||||
spec := model.TestSpec{
|
||||
URL: req.TestSpec.URL,
|
||||
Method: coalesce(req.TestSpec.Method, "GET"),
|
||||
Timeout: timeout,
|
||||
ExpectStatus: req.TestSpec.ExpectStatus,
|
||||
ExpectContains: req.TestSpec.ExpectContains,
|
||||
}
|
||||
|
||||
// 执行测试
|
||||
results := h.tester.TestBatch(r.Context(), proxies, spec, concurrency)
|
||||
|
||||
// 统计结果
|
||||
summary := model.TestSummary{Tested: len(results)}
|
||||
for _, result := range results {
|
||||
if result.OK {
|
||||
summary.Alive++
|
||||
} else {
|
||||
summary.Dead++
|
||||
}
|
||||
|
||||
// 更新数据库
|
||||
if req.UpdateStore {
|
||||
now := result.CheckedAt
|
||||
if result.OK {
|
||||
status := model.StatusAlive
|
||||
_ = h.store.UpdateHealth(r.Context(), result.ProxyID, model.HealthPatch{
|
||||
Status: &status,
|
||||
ScoreDelta: 1,
|
||||
SuccessInc: 1,
|
||||
LatencyMs: &result.LatencyMs,
|
||||
CheckedAt: &now,
|
||||
})
|
||||
} else {
|
||||
status := model.StatusDead
|
||||
_ = h.store.UpdateHealth(r.Context(), result.ProxyID, model.HealthPatch{
|
||||
Status: &status,
|
||||
ScoreDelta: -3,
|
||||
FailInc: 1,
|
||||
CheckedAt: &now,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 写入测试日志
|
||||
if req.WriteLog {
|
||||
_ = h.store.InsertTestLog(r.Context(), result, req.TestSpec.URL)
|
||||
}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, model.TestBatchResult{
|
||||
Summary: summary,
|
||||
Results: results,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleNext 获取下一个可用代理
|
||||
// GET /v1/proxies/next
|
||||
func (h *Handlers) HandleNext(w http.ResponseWriter, r *http.Request) {
|
||||
req := model.SelectRequest{
|
||||
Group: coalesce(r.URL.Query().Get("group"), "default"),
|
||||
Site: r.URL.Query().Get("site"),
|
||||
Policy: r.URL.Query().Get("policy"),
|
||||
TagsAny: splitCSV(r.URL.Query().Get("tags_any")),
|
||||
}
|
||||
|
||||
lease, err := h.selector.Next(r.Context(), req)
|
||||
if err != nil {
|
||||
if err == model.ErrNoProxy {
|
||||
writeError(w, http.StatusNotFound, "not_found", "no available proxy")
|
||||
return
|
||||
}
|
||||
if err == model.ErrBadPolicy {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "invalid policy")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp := model.NextProxyResponse{
|
||||
Proxy: model.ProxyInfo{
|
||||
ID: lease.Proxy.ID,
|
||||
Protocol: lease.Proxy.Protocol,
|
||||
Host: lease.Proxy.Host,
|
||||
Port: lease.Proxy.Port,
|
||||
},
|
||||
LeaseID: lease.LeaseID,
|
||||
TTLMs: time.Until(lease.ExpireAt).Milliseconds(),
|
||||
}
|
||||
|
||||
// 根据配置决定是否返回凭证
|
||||
if h.cfg.ReturnSecret {
|
||||
resp.Proxy.Username = lease.Proxy.Username
|
||||
resp.Proxy.Password = lease.Proxy.Password
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// HandleReport 上报使用结果
|
||||
// POST /v1/proxies/report
|
||||
func (h *Handlers) HandleReport(w http.ResponseWriter, r *http.Request) {
|
||||
var req model.ReportRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.ProxyID == "" {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "proxy_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.selector.Report(r.Context(), req.LeaseID, req.ProxyID, req.Success, req.LatencyMs, req.Error); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]bool{"ok": true})
|
||||
}
|
||||
|
||||
// HandleListProxies 列出代理
|
||||
// GET /v1/proxies
|
||||
func (h *Handlers) HandleListProxies(w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
|
||||
// 解析分页参数
|
||||
offset, _ := strconv.Atoi(query.Get("offset"))
|
||||
limit, _ := strconv.Atoi(query.Get("limit"))
|
||||
if limit <= 0 {
|
||||
limit = 20
|
||||
}
|
||||
if limit > 100 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
// 解析过滤参数
|
||||
var statusIn []model.ProxyStatus
|
||||
if statusStr := query.Get("status"); statusStr != "" {
|
||||
for _, s := range strings.Split(statusStr, ",") {
|
||||
statusIn = append(statusIn, model.ProxyStatus(strings.TrimSpace(s)))
|
||||
}
|
||||
}
|
||||
|
||||
q := model.ProxyListQuery{
|
||||
Group: query.Get("group"),
|
||||
TagsAny: splitCSV(query.Get("tags")),
|
||||
StatusIn: statusIn,
|
||||
OnlyEnabled: query.Get("only_enabled") == "true",
|
||||
Offset: offset,
|
||||
Limit: limit,
|
||||
}
|
||||
|
||||
proxies, total, err := h.store.ListPaginated(r.Context(), q)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, model.ProxyListResponse{
|
||||
Data: proxies,
|
||||
Total: total,
|
||||
Offset: offset,
|
||||
Limit: limit,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleGetProxy 获取单个代理
|
||||
// GET /v1/proxies/{id}
|
||||
func (h *Handlers) HandleGetProxy(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.PathValue("id")
|
||||
if id == "" {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "proxy id is required")
|
||||
return
|
||||
}
|
||||
|
||||
proxy, err := h.store.GetByID(r.Context(), id)
|
||||
if err != nil {
|
||||
if err == model.ErrProxyNotFound {
|
||||
writeError(w, http.StatusNotFound, "not_found", "proxy not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, proxy)
|
||||
}
|
||||
|
||||
// HandleDeleteProxy 删除单个代理
|
||||
// DELETE /v1/proxies/{id}
|
||||
func (h *Handlers) HandleDeleteProxy(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.PathValue("id")
|
||||
if id == "" {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "proxy id is required")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.Delete(r.Context(), id); err != nil {
|
||||
if err == model.ErrProxyNotFound {
|
||||
writeError(w, http.StatusNotFound, "not_found", "proxy not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]bool{"ok": true})
|
||||
}
|
||||
|
||||
// HandleBulkDeleteProxies 批量删除代理
|
||||
// DELETE /v1/proxies
|
||||
func (h *Handlers) HandleBulkDeleteProxies(w http.ResponseWriter, r *http.Request) {
|
||||
var req model.BulkDeleteRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
deleted, err := h.store.DeleteMany(r.Context(), req)
|
||||
if err != nil {
|
||||
if err == model.ErrBulkDeleteEmpty {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", err.Error())
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, model.BulkDeleteResponse{Deleted: int(deleted)})
|
||||
}
|
||||
|
||||
// HandleUpdateProxy 更新代理
|
||||
// PATCH /v1/proxies/{id}
|
||||
func (h *Handlers) HandleUpdateProxy(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.PathValue("id")
|
||||
if id == "" {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "proxy id is required")
|
||||
return
|
||||
}
|
||||
|
||||
var patch model.ProxyPatch
|
||||
if err := json.NewDecoder(r.Body).Decode(&patch); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.Update(r.Context(), id, patch); err != nil {
|
||||
if err == model.ErrProxyNotFound {
|
||||
writeError(w, http.StatusNotFound, "not_found", "proxy not found")
|
||||
return
|
||||
}
|
||||
if err == model.ErrInvalidPatch {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", err.Error())
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 返回更新后的代理
|
||||
proxy, err := h.store.GetByID(r.Context(), id)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, proxy)
|
||||
}
|
||||
|
||||
// HandleTestSingleProxy 测试单个代理
|
||||
// POST /v1/proxies/{id}/test
|
||||
func (h *Handlers) HandleTestSingleProxy(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.PathValue("id")
|
||||
if id == "" {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "proxy id is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req model.SingleTestRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
// SSRF 防护
|
||||
if req.URL == "" {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "url is required")
|
||||
return
|
||||
}
|
||||
if err := security.ValidateTestURL(req.URL); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 获取代理
|
||||
proxy, err := h.store.GetByID(r.Context(), id)
|
||||
if err != nil {
|
||||
if err == model.ErrProxyNotFound {
|
||||
writeError(w, http.StatusNotFound, "not_found", "proxy not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 构建测试规格
|
||||
timeout := time.Duration(req.TimeoutMs) * time.Millisecond
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
|
||||
spec := model.TestSpec{
|
||||
URL: req.URL,
|
||||
Method: coalesce(req.Method, "GET"),
|
||||
Timeout: timeout,
|
||||
ExpectStatus: req.ExpectStatus,
|
||||
ExpectContains: req.ExpectContains,
|
||||
}
|
||||
|
||||
// 执行测试
|
||||
results := h.tester.TestBatch(r.Context(), []model.Proxy{*proxy}, spec, 1)
|
||||
if len(results) == 0 {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "test failed")
|
||||
return
|
||||
}
|
||||
|
||||
result := results[0]
|
||||
|
||||
// 更新数据库
|
||||
if req.UpdateStore {
|
||||
now := result.CheckedAt
|
||||
if result.OK {
|
||||
status := model.StatusAlive
|
||||
_ = h.store.UpdateHealth(r.Context(), result.ProxyID, model.HealthPatch{
|
||||
Status: &status,
|
||||
ScoreDelta: 1,
|
||||
SuccessInc: 1,
|
||||
LatencyMs: &result.LatencyMs,
|
||||
CheckedAt: &now,
|
||||
})
|
||||
} else {
|
||||
status := model.StatusDead
|
||||
_ = h.store.UpdateHealth(r.Context(), result.ProxyID, model.HealthPatch{
|
||||
Status: &status,
|
||||
ScoreDelta: -3,
|
||||
FailInc: 1,
|
||||
CheckedAt: &now,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 写入测试日志
|
||||
if req.WriteLog {
|
||||
_ = h.store.InsertTestLog(r.Context(), result, req.URL)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
// HandleGetStats 获取代理统计信息
|
||||
// GET /v1/proxies/stats
|
||||
func (h *Handlers) HandleGetStats(w http.ResponseWriter, r *http.Request) {
|
||||
stats, err := h.store.GetStats(r.Context())
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// writeJSON 写入 JSON 响应
|
||||
func writeJSON(w http.ResponseWriter, status int, data any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
// writeError 写入错误响应
|
||||
func writeError(w http.ResponseWriter, status int, code, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": code,
|
||||
"message": message,
|
||||
})
|
||||
}
|
||||
|
||||
// coalesce 返回第一个非空字符串
|
||||
func coalesce(values ...string) string {
|
||||
for _, v := range values {
|
||||
if v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// splitCSV 分割逗号分隔的字符串
|
||||
func splitCSV(s string) []string {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(s, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
result = append(result, p)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
136
internal/api/middleware.go
Normal file
136
internal/api/middleware.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// LoggingMiddleware 请求日志中间件
|
||||
func LoggingMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
// 包装 ResponseWriter 以获取状态码
|
||||
wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
|
||||
|
||||
next.ServeHTTP(wrapped, r)
|
||||
|
||||
slog.Info("request",
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"status", wrapped.statusCode,
|
||||
"duration_ms", time.Since(start).Milliseconds(),
|
||||
"remote_addr", r.RemoteAddr,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// AuthMiddleware API Key 鉴权中间件
|
||||
func AuthMiddleware(next http.Handler, apiKey string) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 如果没有配置 API Key,跳过鉴权
|
||||
if apiKey == "" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查 Authorization 头
|
||||
auth := r.Header.Get("Authorization")
|
||||
if auth != "" {
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
token := strings.TrimPrefix(auth, "Bearer ")
|
||||
if token == apiKey {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 X-API-Key 头
|
||||
key := r.Header.Get("X-API-Key")
|
||||
if key == apiKey {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, `{"error":"unauthorized","message":"invalid or missing API key"}`, http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
|
||||
// RateLimitMiddleware 简单限流中间件(基于滑动窗口)
|
||||
type RateLimitMiddleware struct {
|
||||
requests map[string][]time.Time
|
||||
limit int
|
||||
window time.Duration
|
||||
}
|
||||
|
||||
// NewRateLimitMiddleware 创建限流中间件
|
||||
func NewRateLimitMiddleware(limit int, window time.Duration) *RateLimitMiddleware {
|
||||
return &RateLimitMiddleware{
|
||||
requests: make(map[string][]time.Time),
|
||||
limit: limit,
|
||||
window: window,
|
||||
}
|
||||
}
|
||||
|
||||
// Wrap 包装处理器
|
||||
func (rl *RateLimitMiddleware) Wrap(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ip := getClientIP(r)
|
||||
now := time.Now()
|
||||
|
||||
// 清理过期记录
|
||||
windowStart := now.Add(-rl.window)
|
||||
var valid []time.Time
|
||||
for _, t := range rl.requests[ip] {
|
||||
if t.After(windowStart) {
|
||||
valid = append(valid, t)
|
||||
}
|
||||
}
|
||||
rl.requests[ip] = valid
|
||||
|
||||
// 检查限制
|
||||
if len(rl.requests[ip]) >= rl.limit {
|
||||
http.Error(w, `{"error":"rate_limit","message":"too many requests"}`, http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
// 记录请求
|
||||
rl.requests[ip] = append(rl.requests[ip], now)
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// getClientIP 获取客户端 IP
|
||||
func getClientIP(r *http.Request) string {
|
||||
// 检查代理头
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
parts := strings.Split(xff, ",")
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
// 从 RemoteAddr 提取 IP
|
||||
ip := r.RemoteAddr
|
||||
if idx := strings.LastIndex(ip, ":"); idx != -1 {
|
||||
ip = ip[:idx]
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// responseWriter 包装 ResponseWriter 以获取状态码
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (rw *responseWriter) WriteHeader(code int) {
|
||||
rw.statusCode = code
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
53
internal/api/router.go
Normal file
53
internal/api/router.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"proxyrotator/internal/config"
|
||||
"proxyrotator/internal/importer"
|
||||
"proxyrotator/internal/selector"
|
||||
"proxyrotator/internal/store"
|
||||
"proxyrotator/internal/tester"
|
||||
)
|
||||
|
||||
// NewRouter 创建 HTTP 路由
|
||||
func NewRouter(
|
||||
store store.ProxyStore,
|
||||
importer *importer.Importer,
|
||||
tester *tester.HTTPTester,
|
||||
selector *selector.Selector,
|
||||
cfg *config.Config,
|
||||
) http.Handler {
|
||||
handlers := NewHandlers(store, importer, tester, selector, cfg)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// 注册路由(Go 1.22+ 支持 METHOD /path 模式)
|
||||
mux.HandleFunc("POST /v1/proxies/import/text", handlers.HandleImportText)
|
||||
mux.HandleFunc("POST /v1/proxies/import/file", handlers.HandleImportFile)
|
||||
mux.HandleFunc("POST /v1/proxies/test", handlers.HandleTest)
|
||||
mux.HandleFunc("GET /v1/proxies/next", handlers.HandleNext)
|
||||
mux.HandleFunc("POST /v1/proxies/report", handlers.HandleReport)
|
||||
|
||||
// CRUD 路由(注意:/stats 需在 /{id} 之前注册)
|
||||
mux.HandleFunc("GET /v1/proxies/stats", handlers.HandleGetStats)
|
||||
mux.HandleFunc("GET /v1/proxies", handlers.HandleListProxies)
|
||||
mux.HandleFunc("GET /v1/proxies/{id}", handlers.HandleGetProxy)
|
||||
mux.HandleFunc("DELETE /v1/proxies/{id}", handlers.HandleDeleteProxy)
|
||||
mux.HandleFunc("DELETE /v1/proxies", handlers.HandleBulkDeleteProxies)
|
||||
mux.HandleFunc("PATCH /v1/proxies/{id}", handlers.HandleUpdateProxy)
|
||||
mux.HandleFunc("POST /v1/proxies/{id}/test", handlers.HandleTestSingleProxy)
|
||||
|
||||
// 健康检查
|
||||
mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"status":"ok"}`))
|
||||
})
|
||||
|
||||
// 应用中间件
|
||||
var handler http.Handler = mux
|
||||
handler = AuthMiddleware(handler, cfg.APIKey)
|
||||
handler = LoggingMiddleware(handler)
|
||||
|
||||
return handler
|
||||
}
|
||||
Reference in New Issue
Block a user