package api

import (
	"context"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"strconv"
	"strings"
	"time"

	"proxyrotator/internal/config"
	"proxyrotator/internal/importer"
	"proxyrotator/internal/model"
	"proxyrotator/internal/security"
	"proxyrotator/internal/selector"
	"proxyrotator/internal/store"
	"proxyrotator/internal/tester"
)

const (
	// defaultGroup is used whenever a request does not name a proxy group.
	defaultGroup = "default"

	// defaultTestConcurrency is used when a batch test request does not
	// specify a positive concurrency.
	defaultTestConcurrency = 50

	// defaultTestTimeout is used when a test request does not specify a
	// positive timeout.
	defaultTestTimeout = 5 * time.Second

	// defaultPageSize and maxPageSize bound the "limit" query parameter of
	// the list endpoint.
	defaultPageSize = 20
	maxPageSize     = 100

	// maxUploadMemory is the amount of multipart data kept in memory by
	// ParseMultipartForm; anything beyond it is spooled to temporary files.
	maxUploadMemory = 10 << 20

	// maxUploadBodySize caps the whole upload request body. It leaves 1 MiB
	// of headroom over maxUploadMemory for multipart framing and the
	// accompanying form fields.
	maxUploadBodySize = maxUploadMemory + 1<<20
)

// Handlers bundles the dependencies shared by all API handlers.
type Handlers struct {
	store    store.ProxyStore
	importer *importer.Importer
	tester   *tester.HTTPTester
	selector *selector.Selector
	cfg      *config.Config
}

// NewHandlers returns a Handlers wired with the given dependencies.
func NewHandlers(
	store store.ProxyStore,
	importer *importer.Importer,
	tester *tester.HTTPTester,
	selector *selector.Selector,
	cfg *config.Config,
) *Handlers {
	return &Handlers{
		store:    store,
		importer: importer,
		tester:   tester,
		selector: selector,
		cfg:      cfg,
	}
}

// HandleImportText imports proxies from a text payload.
//
// POST /v1/proxies/import/text
//
// It responds with 400 when the body is not valid JSON or "text" is empty,
// with 500 when the store rejects the upsert, and otherwise with 200 and a
// model.ImportResult.
func (h *Handlers) HandleImportText(w http.ResponseWriter, r *http.Request) {
	var req model.ImportTextRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
		return
	}
	if req.Text == "" {
		writeError(w, http.StatusBadRequest, "bad_request", "text is required")
		return
	}

	input := model.ImportInput{
		Group:        coalesce(req.Group, defaultGroup),
		Tags:         req.Tags,
		ProtocolHint: req.ProtocolHint,
	}

	proxies, invalid := h.importer.ParseText(r.Context(), input, req.Text)
	h.respondImport(w, r, proxies, invalid)
}

// HandleImportFile imports proxies from an uploaded file.
//
// POST /v1/proxies/import/file
//
// The multipart form must contain a "file" part. Optional fields are "group",
// "tags" (comma separated), "protocol_hint" and "type". When "type" is empty
// it is inferred from the file name: ".csv" selects the CSV parser, anything
// else is parsed as plain text.
func (h *Handlers) HandleImportFile(w http.ResponseWriter, r *http.Request) {
	// ParseMultipartForm's argument only bounds memory usage; the body itself
	// must be capped separately, otherwise an arbitrarily large upload would
	// be spooled to disk and then read back into memory below.
	r.Body = http.MaxBytesReader(w, r.Body, maxUploadBodySize)
	if err := r.ParseMultipartForm(maxUploadMemory); err != nil {
		writeError(w, http.StatusBadRequest, "bad_request", "failed to parse multipart form")
		return
	}

	file, header, err := r.FormFile("file")
	if err != nil {
		writeError(w, http.StatusBadRequest, "bad_request", "file is required")
		return
	}
	defer file.Close()

	input := model.ImportInput{
		Group: coalesce(r.FormValue("group"), defaultGroup),
		// splitCSV trims each tag and drops empty ones, matching how tag
		// lists are parsed by the other handlers in this file.
		Tags:         splitCSV(r.FormValue("tags")),
		ProtocolHint: r.FormValue("protocol_hint"),
	}

	content, err := io.ReadAll(file)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "internal_error", "failed to read file")
		return
	}

	// An explicit "type" field wins; otherwise fall back to the extension.
	fileType := r.FormValue("type")
	if fileType == "" {
		if strings.HasSuffix(strings.ToLower(header.Filename), ".csv") {
			fileType = "csv"
		} else {
			fileType = "txt"
		}
	}

	var (
		proxies []model.Proxy
		invalid []model.InvalidLine
	)
	if fileType == "csv" {
		proxies, invalid = h.importer.ParseCSV(r.Context(), input, strings.NewReader(string(content)))
	} else {
		proxies, invalid = h.importer.ParseText(r.Context(), input, string(content))
	}

	h.respondImport(w, r, proxies, invalid)
}

// respondImport upserts the parsed proxies (when there are any) and writes
// the import summary, or a 500 response if the store returns an error.
func (h *Handlers) respondImport(
	w http.ResponseWriter,
	r *http.Request,
	proxies []model.Proxy,
	invalid []model.InvalidLine,
) {
	imported, duplicated := 0, 0
	if len(proxies) > 0 {
		var err error
		imported, duplicated, err = h.store.UpsertMany(r.Context(), proxies)
		if err != nil {
			writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
			return
		}
	}

	writeJSON(w, http.StatusOK, model.ImportResult{
		Imported:     imported,
		Duplicated:   duplicated,
		Invalid:      len(invalid),
		InvalidItems: invalid,
	})
}

// HandleTest tests a filtered batch of proxies against a target URL.
//
// POST /v1/proxies/test
//
// The target URL is validated with security.ValidateTestURL before any proxy
// is contacted. Concurrency and the number of tested proxies are capped by
// the configured MaxConcurrency and MaxTestLimit. When requested, each result
// is written back to the store and/or the test log.
func (h *Handlers) HandleTest(w http.ResponseWriter, r *http.Request) {
	var req model.TestRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
		return
	}

	// SSRF protection: reject the target before doing any work.
	if req.TestSpec.URL == "" {
		writeError(w, http.StatusBadRequest, "bad_request", "test_spec.url is required")
		return
	}
	if err := security.ValidateTestURL(req.TestSpec.URL); err != nil {
		writeError(w, http.StatusBadRequest, "bad_request", err.Error())
		return
	}

	// Bound the concurrency.
	concurrency := req.Concurrency
	if concurrency <= 0 {
		concurrency = defaultTestConcurrency
	}
	if concurrency > h.cfg.MaxConcurrency {
		concurrency = h.cfg.MaxConcurrency
	}

	// Bound the number of proxies tested in one request.
	limit := req.Filter.Limit
	if limit <= 0 || limit > h.cfg.MaxTestLimit {
		limit = h.cfg.MaxTestLimit
	}

	// Build the query; without an explicit status filter only proxies that
	// are unknown or alive are tested.
	statusIn := make([]model.ProxyStatus, 0, len(req.Filter.Status))
	for _, s := range req.Filter.Status {
		statusIn = append(statusIn, model.ProxyStatus(s))
	}
	if len(statusIn) == 0 {
		statusIn = []model.ProxyStatus{model.StatusUnknown, model.StatusAlive}
	}

	proxies, err := h.store.List(r.Context(), model.ProxyQuery{
		Group:       coalesce(req.Group, defaultGroup),
		TagsAny:     req.Filter.TagsAny,
		StatusIn:    statusIn,
		OnlyEnabled: true,
		Limit:       limit,
	})
	if err != nil {
		writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
		return
	}

	if len(proxies) == 0 {
		// Use an empty, non-nil slice so the JSON body carries "[]" rather
		// than "null".
		writeJSON(w, http.StatusOK, model.TestBatchResult{
			Summary: model.TestSummary{},
			Results: []model.TestResult{},
		})
		return
	}

	timeout := time.Duration(req.TestSpec.TimeoutMs) * time.Millisecond
	if timeout <= 0 {
		timeout = defaultTestTimeout
	}
	spec := model.TestSpec{
		URL:            req.TestSpec.URL,
		Method:         coalesce(req.TestSpec.Method, "GET"),
		Timeout:        timeout,
		ExpectStatus:   req.TestSpec.ExpectStatus,
		ExpectContains: req.TestSpec.ExpectContains,
	}

	results := h.tester.TestBatch(r.Context(), proxies, spec, concurrency)

	summary := model.TestSummary{Tested: len(results)}
	for _, result := range results {
		if result.OK {
			summary.Alive++
		} else {
			summary.Dead++
		}
		h.recordTestResult(r.Context(), result, req.TestSpec.URL, req.UpdateStore, req.WriteLog)
	}

	writeJSON(w, http.StatusOK, model.TestBatchResult{
		Summary: summary,
		Results: results,
	})
}

// recordTestResult optionally writes a test result back to the proxy's health
// record (updateStore) and to the test log (writeLog).
//
// A passing result marks the proxy alive with a score delta of +1; a failing
// one marks it dead with a score delta of -3.
func (h *Handlers) recordTestResult(
	ctx context.Context,
	result model.TestResult,
	targetURL string,
	updateStore, writeLog bool,
) {
	if updateStore {
		checkedAt := result.CheckedAt
		// Persistence is best-effort: the test outcome is still returned to
		// the caller even if the store update fails, so the error is
		// intentionally discarded.
		if result.OK {
			status := model.StatusAlive
			_ = h.store.UpdateHealth(ctx, result.ProxyID, model.HealthPatch{
				Status:     &status,
				ScoreDelta: 1,
				SuccessInc: 1,
				LatencyMs:  &result.LatencyMs,
				CheckedAt:  &checkedAt,
			})
		} else {
			status := model.StatusDead
			_ = h.store.UpdateHealth(ctx, result.ProxyID, model.HealthPatch{
				Status:     &status,
				ScoreDelta: -3,
				FailInc:    1,
				CheckedAt:  &checkedAt,
			})
		}
	}

	if writeLog {
		// Best-effort as well; a failed log insert must not fail the request.
		_ = h.store.InsertTestLog(ctx, result, targetURL)
	}
}

// HandleNext leases the next available proxy.
//
// GET /v1/proxies/next
//
// Query parameters: "group", "site", "policy" and "tags_any" (comma
// separated). Credentials are only included in the response when the
// ReturnSecret configuration flag is set.
func (h *Handlers) HandleNext(w http.ResponseWriter, r *http.Request) {
	query := r.URL.Query()
	req := model.SelectRequest{
		Group:   coalesce(query.Get("group"), defaultGroup),
		Site:    query.Get("site"),
		Policy:  query.Get("policy"),
		TagsAny: splitCSV(query.Get("tags_any")),
	}

	lease, err := h.selector.Next(r.Context(), req)
	if err != nil {
		// errors.Is keeps the mapping correct even if the selector wraps
		// its sentinel errors.
		switch {
		case errors.Is(err, model.ErrNoProxy):
			writeError(w, http.StatusNotFound, "not_found", "no available proxy")
		case errors.Is(err, model.ErrBadPolicy):
			writeError(w, http.StatusBadRequest, "bad_request", "invalid policy")
		default:
			writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
		}
		return
	}

	resp := model.NextProxyResponse{
		Proxy: model.ProxyInfo{
			ID:       lease.Proxy.ID,
			Protocol: lease.Proxy.Protocol,
			Host:     lease.Proxy.Host,
			Port:     lease.Proxy.Port,
		},
		LeaseID: lease.LeaseID,
		TTLMs:   time.Until(lease.ExpireAt).Milliseconds(),
	}

	// Only expose credentials when the configuration allows it.
	if h.cfg.ReturnSecret {
		resp.Proxy.Username = lease.Proxy.Username
		resp.Proxy.Password = lease.Proxy.Password
	}

	writeJSON(w, http.StatusOK, resp)
}

// HandleReport records the outcome of using a leased proxy.
//
// POST /v1/proxies/report
func (h *Handlers) HandleReport(w http.ResponseWriter, r *http.Request) {
	var req model.ReportRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
		return
	}
	if req.ProxyID == "" {
		writeError(w, http.StatusBadRequest, "bad_request", "proxy_id is required")
		return
	}

	err := h.selector.Report(r.Context(), req.LeaseID, req.ProxyID, req.Success, req.LatencyMs, req.Error)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
		return
	}

	writeJSON(w, http.StatusOK, map[string]bool{"ok": true})
}

// HandleListProxies returns a page of proxies.
//
// GET /v1/proxies
//
// Query parameters: "offset", "limit" (default 20, maximum 100), "group",
// "tags" (comma separated), "status" (comma separated) and "only_enabled"
// ("true" to enable).
func (h *Handlers) HandleListProxies(w http.ResponseWriter, r *http.Request) {
	query := r.URL.Query()

	// Parse errors are deliberately ignored so that missing or non-numeric
	// parameters fall back to the defaults below.
	offset, _ := strconv.Atoi(query.Get("offset"))
	limit, _ := strconv.Atoi(query.Get("limit"))
	// A negative offset is meaningless for pagination; treat it as the start.
	offset = max(offset, 0)
	if limit <= 0 {
		limit = defaultPageSize
	}
	if limit > maxPageSize {
		limit = maxPageSize
	}

	var statusIn []model.ProxyStatus
	if statusStr := query.Get("status"); statusStr != "" {
		for _, s := range strings.Split(statusStr, ",") {
			statusIn = append(statusIn, model.ProxyStatus(strings.TrimSpace(s)))
		}
	}

	q := model.ProxyListQuery{
		Group:       query.Get("group"),
		TagsAny:     splitCSV(query.Get("tags")),
		StatusIn:    statusIn,
		OnlyEnabled: query.Get("only_enabled") == "true",
		Offset:      offset,
		Limit:       limit,
	}

	proxies, total, err := h.store.ListPaginated(r.Context(), q)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
		return
	}

	writeJSON(w, http.StatusOK, model.ProxyListResponse{
		Data:   proxies,
		Total:  total,
		Offset: offset,
		Limit:  limit,
	})
}

// HandleGetProxy returns a single proxy by ID.
//
// GET /v1/proxies/{id}
func (h *Handlers) HandleGetProxy(w http.ResponseWriter, r *http.Request) {
	id := r.PathValue("id")
	if id == "" {
		writeError(w, http.StatusBadRequest, "bad_request", "proxy id is required")
		return
	}

	proxy, err := h.store.GetByID(r.Context(), id)
	if err != nil {
		writeProxyError(w, err)
		return
	}

	writeJSON(w, http.StatusOK, proxy)
}

// HandleDeleteProxy deletes a single proxy by ID.
//
// DELETE /v1/proxies/{id}
func (h *Handlers) HandleDeleteProxy(w http.ResponseWriter, r *http.Request) {
	id := r.PathValue("id")
	if id == "" {
		writeError(w, http.StatusBadRequest, "bad_request", "proxy id is required")
		return
	}

	if err := h.store.Delete(r.Context(), id); err != nil {
		writeProxyError(w, err)
		return
	}

	writeJSON(w, http.StatusOK, map[string]bool{"ok": true})
}

// HandleBulkDeleteProxies deletes every proxy matched by the request body.
//
// DELETE /v1/proxies
func (h *Handlers) HandleBulkDeleteProxies(w http.ResponseWriter, r *http.Request) {
	var req model.BulkDeleteRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
		return
	}

	deleted, err := h.store.DeleteMany(r.Context(), req)
	if err != nil {
		if errors.Is(err, model.ErrBulkDeleteEmpty) {
			writeError(w, http.StatusBadRequest, "bad_request", err.Error())
			return
		}
		writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
		return
	}

	writeJSON(w, http.StatusOK, model.BulkDeleteResponse{Deleted: int(deleted)})
}

// HandleUpdateProxy applies a partial update to a proxy and returns the
// updated record.
//
// PATCH /v1/proxies/{id}
func (h *Handlers) HandleUpdateProxy(w http.ResponseWriter, r *http.Request) {
	id := r.PathValue("id")
	if id == "" {
		writeError(w, http.StatusBadRequest, "bad_request", "proxy id is required")
		return
	}

	var patch model.ProxyPatch
	if err := json.NewDecoder(r.Body).Decode(&patch); err != nil {
		writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
		return
	}

	if err := h.store.Update(r.Context(), id, patch); err != nil {
		if errors.Is(err, model.ErrInvalidPatch) {
			writeError(w, http.StatusBadRequest, "bad_request", err.Error())
			return
		}
		writeProxyError(w, err)
		return
	}

	// Re-read the proxy so the response reflects the stored state. The proxy
	// may have been removed between the two calls, hence the 404 mapping.
	proxy, err := h.store.GetByID(r.Context(), id)
	if err != nil {
		writeProxyError(w, err)
		return
	}

	writeJSON(w, http.StatusOK, proxy)
}

// HandleTestSingleProxy tests one proxy against a target URL.
//
// POST /v1/proxies/{id}/test
//
// The target URL is validated with security.ValidateTestURL before the proxy
// is looked up. When requested, the result is written back to the store
// and/or the test log.
func (h *Handlers) HandleTestSingleProxy(w http.ResponseWriter, r *http.Request) {
	id := r.PathValue("id")
	if id == "" {
		writeError(w, http.StatusBadRequest, "bad_request", "proxy id is required")
		return
	}

	var req model.SingleTestRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
		return
	}

	// SSRF protection: reject the target before doing any work.
	if req.URL == "" {
		writeError(w, http.StatusBadRequest, "bad_request", "url is required")
		return
	}
	if err := security.ValidateTestURL(req.URL); err != nil {
		writeError(w, http.StatusBadRequest, "bad_request", err.Error())
		return
	}

	proxy, err := h.store.GetByID(r.Context(), id)
	if err != nil {
		writeProxyError(w, err)
		return
	}

	timeout := time.Duration(req.TimeoutMs) * time.Millisecond
	if timeout <= 0 {
		timeout = defaultTestTimeout
	}
	spec := model.TestSpec{
		URL:            req.URL,
		Method:         coalesce(req.Method, "GET"),
		Timeout:        timeout,
		ExpectStatus:   req.ExpectStatus,
		ExpectContains: req.ExpectContains,
	}

	results := h.tester.TestBatch(r.Context(), []model.Proxy{*proxy}, spec, 1)
	if len(results) == 0 {
		writeError(w, http.StatusInternalServerError, "internal_error", "test failed")
		return
	}
	result := results[0]

	h.recordTestResult(r.Context(), result, req.URL, req.UpdateStore, req.WriteLog)

	writeJSON(w, http.StatusOK, result)
}

// HandleGetStats returns aggregate proxy statistics.
//
// GET /v1/proxies/stats
func (h *Handlers) HandleGetStats(w http.ResponseWriter, r *http.Request) {
	stats, err := h.store.GetStats(r.Context())
	if err != nil {
		writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
		return
	}

	writeJSON(w, http.StatusOK, stats)
}

// writeProxyError maps an error from a single-proxy store operation to an
// HTTP response: model.ErrProxyNotFound becomes 404, anything else 500.
func writeProxyError(w http.ResponseWriter, err error) {
	if errors.Is(err, model.ErrProxyNotFound) {
		writeError(w, http.StatusNotFound, "not_found", "proxy not found")
		return
	}
	writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
}

// writeJSON writes data as a JSON response with the given status code.
func writeJSON(w http.ResponseWriter, status int, data any) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// The status code has already been written, so an encoding or write
	// failure can no longer be reported to the client; discard it.
	_ = json.NewEncoder(w).Encode(data)
}

// writeError writes a JSON error response of the form
// {"error": code, "message": message} with the given status code.
func writeError(w http.ResponseWriter, status int, code, message string) {
	writeJSON(w, status, map[string]string{
		"error":   code,
		"message": message,
	})
}

// coalesce returns the first non-empty string among values, or "" if there
// is none.
func coalesce(values ...string) string {
	for _, v := range values {
		if v != "" {
			return v
		}
	}
	return ""
}

// splitCSV splits a comma-separated string, trimming surrounding whitespace
// from each element and dropping empty ones. It returns nil for an empty
// input.
func splitCSV(s string) []string {
	if s == "" {
		return nil
	}
	parts := strings.Split(s, ",")
	result := make([]string, 0, len(parts))
	for _, p := range parts {
		if p = strings.TrimSpace(p); p != "" {
			result = append(result, p)
		}
	}
	return result
}