Files
ProxyPool/internal/api/handlers.go
2026-01-31 22:53:12 +08:00

644 lines
16 KiB
Go

package api
import (
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"strconv"
	"strings"
	"time"

	"proxyrotator/internal/config"
	"proxyrotator/internal/importer"
	"proxyrotator/internal/model"
	"proxyrotator/internal/security"
	"proxyrotator/internal/selector"
	"proxyrotator/internal/store"
	"proxyrotator/internal/tester"
)
// Handlers groups the dependencies that the HTTP API handlers in this package
// share. Build one with NewHandlers.
type Handlers struct {
	// store is the proxy persistence layer the handlers read from and write to.
	store store.ProxyStore
	// importer parses uploaded text / CSV payloads into proxies.
	importer *importer.Importer
	// tester runs batch tests against proxies.
	tester *tester.HTTPTester
	// selector hands out proxy leases and records usage reports.
	selector *selector.Selector
	// cfg supplies the limits and flags the handlers consult
	// (e.g. MaxConcurrency, MaxTestLimit, ReturnSecret).
	cfg *config.Config
}
// NewHandlers returns a Handlers value wired with the given dependencies.
// The arguments are stored as passed; nothing is validated or copied here.
func NewHandlers(
	store store.ProxyStore,
	importer *importer.Importer,
	tester *tester.HTTPTester,
	selector *selector.Selector,
	cfg *config.Config,
) *Handlers {
	h := new(Handlers)
	h.store, h.importer = store, importer
	h.tester, h.selector = tester, selector
	h.cfg = cfg
	return h
}
// HandleImportText imports proxies from a block of text carried in a JSON
// request body.
//
// Route: POST /v1/proxies/import/text
//
// Responses:
//   - 400 when the body is not valid JSON or the request's Text field is empty
//   - 500 when the store fails to upsert the parsed proxies
//   - 200 with a model.ImportResult (imported / duplicated / invalid counts
//     plus the rejected lines) otherwise
func (h *Handlers) HandleImportText(w http.ResponseWriter, r *http.Request) {
	var body model.ImportTextRequest
	if decodeErr := json.NewDecoder(r.Body).Decode(&body); decodeErr != nil {
		writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
		return
	}
	if len(body.Text) == 0 {
		writeError(w, http.StatusBadRequest, "bad_request", "text is required")
		return
	}

	ctx := r.Context()
	parsed, rejected := h.importer.ParseText(ctx, model.ImportInput{
		Group:        coalesce(body.Group, "default"),
		Tags:         body.Tags,
		ProtocolHint: body.ProtocolHint,
	}, body.Text)

	// The invalid-line report is returned even when nothing could be imported.
	outcome := model.ImportResult{
		Invalid:      len(rejected),
		InvalidItems: rejected,
	}

	// Only touch the store when the parser produced at least one proxy.
	if len(parsed) != 0 {
		imported, duplicated, err := h.store.UpsertMany(ctx, parsed)
		if err != nil {
			writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
			return
		}
		outcome.Imported, outcome.Duplicated = imported, duplicated
	}

	writeJSON(w, http.StatusOK, outcome)
}
// HandleImportFile imports proxies from an uploaded file (multipart form).
//
// Route: POST /v1/proxies/import/file
//
// Form fields:
//   - file:          the upload (required)
//   - group:         target group, defaults to "default"
//   - tags:          comma-separated tags; blank entries are dropped
//   - protocol_hint: passed through to the importer
//   - type:          "csv" selects the CSV parser, anything else the text
//     parser; when empty it is inferred from the file name (".csv" => csv)
//
// Responses:
//   - 413 when the request body exceeds the upload cap
//   - 400 when the multipart form cannot be parsed or "file" is missing
//   - 500 when the file cannot be read or the store upsert fails
//   - 200 with a model.ImportResult otherwise
func (h *Handlers) HandleImportFile(w http.ResponseWriter, r *http.Request) {
	const (
		// maxFormMemory is how much of the form ParseMultipartForm may keep in
		// memory; anything beyond it is spooled to temporary files.
		maxFormMemory = 10 << 20
		// maxBodyBytes caps the whole request body. ParseMultipartForm's
		// argument is NOT a size limit, so without this cap an arbitrarily
		// large upload would be accepted and then read fully into memory by
		// io.ReadAll below. The extra 1 MiB leaves room for multipart framing
		// and the other form fields.
		maxBodyBytes = maxFormMemory + 1<<20
	)
	r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes)

	if err := r.ParseMultipartForm(maxFormMemory); err != nil {
		var tooLarge *http.MaxBytesError
		if errors.As(err, &tooLarge) {
			writeError(w, http.StatusRequestEntityTooLarge, "payload_too_large", "upload exceeds size limit")
			return
		}
		writeError(w, http.StatusBadRequest, "bad_request", "failed to parse multipart form")
		return
	}

	file, header, err := r.FormFile("file")
	if err != nil {
		writeError(w, http.StatusBadRequest, "bad_request", "file is required")
		return
	}
	defer file.Close()

	input := model.ImportInput{
		Group: coalesce(r.FormValue("group"), "default"),
		// splitCSV trims each entry and drops blanks, so "a,,b" or a trailing
		// comma no longer yields empty-string tags.
		Tags:         splitCSV(r.FormValue("tags")),
		ProtocolHint: r.FormValue("protocol_hint"),
	}

	content, err := io.ReadAll(file)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "internal_error", "failed to read file")
		return
	}
	text := string(content)

	// An explicit "type" wins; otherwise infer the format from the file name.
	fileType := r.FormValue("type")
	if fileType == "" {
		fileType = "txt"
		if strings.HasSuffix(strings.ToLower(header.Filename), ".csv") {
			fileType = "csv"
		}
	}

	var (
		proxies []model.Proxy
		invalid []model.InvalidLine
	)
	if fileType == "csv" {
		proxies, invalid = h.importer.ParseCSV(r.Context(), input, strings.NewReader(text))
	} else {
		proxies, invalid = h.importer.ParseText(r.Context(), input, text)
	}

	// Only touch the store when the parser produced at least one proxy.
	imported, duplicated := 0, 0
	if len(proxies) > 0 {
		imported, duplicated, err = h.store.UpsertMany(r.Context(), proxies)
		if err != nil {
			writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
			return
		}
	}

	writeJSON(w, http.StatusOK, model.ImportResult{
		Imported:     imported,
		Duplicated:   duplicated,
		Invalid:      len(invalid),
		InvalidItems: invalid,
	})
}
// HandleTest runs a batch test against the proxies selected by the request
// filter and reports per-proxy results plus an alive/dead summary.
//
// Route: POST /v1/proxies/test
//
// Responses:
//   - 400 when the body is not valid JSON, test_spec.url is empty, or the URL
//     is rejected by security.ValidateTestURL
//   - 500 when listing proxies from the store fails
//   - 200 with a model.TestBatchResult otherwise (empty results when no proxy
//     matches the filter)
//
// When the request sets UpdateStore, each result is written back as a health
// patch; when it sets WriteLog, each result is appended to the test log. Both
// writes are best-effort: their errors are discarded.
func (h *Handlers) HandleTest(w http.ResponseWriter, r *http.Request) {
	var req model.TestRequest
	if decodeErr := json.NewDecoder(r.Body).Decode(&req); decodeErr != nil {
		writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
		return
	}

	// SSRF guard: the target URL must be present and pass validation.
	if req.TestSpec.URL == "" {
		writeError(w, http.StatusBadRequest, "bad_request", "test_spec.url is required")
		return
	}
	if urlErr := security.ValidateTestURL(req.TestSpec.URL); urlErr != nil {
		writeError(w, http.StatusBadRequest, "bad_request", urlErr.Error())
		return
	}

	// Concurrency: default to 50 when unset, then cap at the configured maximum.
	workers := req.Concurrency
	if workers <= 0 {
		workers = 50
	}
	workers = min(workers, h.cfg.MaxConcurrency)

	// Number of proxies to test: unset or oversized values fall back to the
	// configured maximum.
	maxLimit := h.cfg.MaxTestLimit
	limit := req.Filter.Limit
	if limit <= 0 || limit > maxLimit {
		limit = maxLimit
	}

	// Status filter: default to unknown + alive when the caller gave none.
	statuses := make([]model.ProxyStatus, 0, len(req.Filter.Status))
	for _, raw := range req.Filter.Status {
		statuses = append(statuses, model.ProxyStatus(raw))
	}
	if len(statuses) == 0 {
		statuses = []model.ProxyStatus{model.StatusUnknown, model.StatusAlive}
	}

	ctx := r.Context()
	candidates, err := h.store.List(ctx, model.ProxyQuery{
		Group:       coalesce(req.Group, "default"),
		TagsAny:     req.Filter.TagsAny,
		StatusIn:    statuses,
		OnlyEnabled: true,
		Limit:       limit,
	})
	if err != nil {
		writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
		return
	}
	if len(candidates) == 0 {
		// Explicit empty slice so the response carries a non-nil Results value.
		writeJSON(w, http.StatusOK, model.TestBatchResult{
			Summary: model.TestSummary{},
			Results: []model.TestResult{},
		})
		return
	}

	// Per-request timeout, defaulting to five seconds when unset or non-positive.
	timeout := 5 * time.Second
	if requested := time.Duration(req.TestSpec.TimeoutMs) * time.Millisecond; requested > 0 {
		timeout = requested
	}

	results := h.tester.TestBatch(ctx, candidates, model.TestSpec{
		URL:            req.TestSpec.URL,
		Method:         coalesce(req.TestSpec.Method, "GET"),
		Timeout:        timeout,
		ExpectStatus:   req.TestSpec.ExpectStatus,
		ExpectContains: req.TestSpec.ExpectContains,
	}, workers)

	summary := model.TestSummary{Tested: len(results)}
	for _, res := range results {
		if res.OK {
			summary.Alive++
		} else {
			summary.Dead++
		}

		// Write the outcome back to the store (best-effort).
		if req.UpdateStore {
			checkedAt := res.CheckedAt
			patch := model.HealthPatch{CheckedAt: &checkedAt}
			if res.OK {
				alive := model.StatusAlive
				patch.Status = &alive
				patch.ScoreDelta = 1
				patch.SuccessInc = 1
				patch.LatencyMs = &res.LatencyMs
			} else {
				dead := model.StatusDead
				patch.Status = &dead
				patch.ScoreDelta = -3
				patch.FailInc = 1
			}
			_ = h.store.UpdateHealth(ctx, res.ProxyID, patch)
		}

		// Append to the test log (best-effort).
		if req.WriteLog {
			_ = h.store.InsertTestLog(ctx, res, req.TestSpec.URL)
		}
	}

	writeJSON(w, http.StatusOK, model.TestBatchResult{
		Summary: summary,
		Results: results,
	})
}
// HandleNext leases the next available proxy for the caller.
//
// Route: GET /v1/proxies/next
//
// Query parameters: group (defaults to "default"), site, policy, and
// tags_any (comma-separated).
//
// Responses:
//   - 404 when the selector reports model.ErrNoProxy
//   - 400 when the selector reports model.ErrBadPolicy
//   - 500 on any other selector error
//   - 200 with a model.NextProxyResponse otherwise; username and password are
//     included only when cfg.ReturnSecret is set
func (h *Handlers) HandleNext(w http.ResponseWriter, r *http.Request) {
	// Parse the query string once instead of once per parameter.
	params := r.URL.Query()
	req := model.SelectRequest{
		Group:   coalesce(params.Get("group"), "default"),
		Site:    params.Get("site"),
		Policy:  params.Get("policy"),
		TagsAny: splitCSV(params.Get("tags_any")),
	}

	lease, err := h.selector.Next(r.Context(), req)
	if err != nil {
		// errors.Is (rather than ==) still matches when a lower layer wraps
		// the sentinel with additional context.
		switch {
		case errors.Is(err, model.ErrNoProxy):
			writeError(w, http.StatusNotFound, "not_found", "no available proxy")
		case errors.Is(err, model.ErrBadPolicy):
			writeError(w, http.StatusBadRequest, "bad_request", "invalid policy")
		default:
			writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
		}
		return
	}

	resp := model.NextProxyResponse{
		Proxy: model.ProxyInfo{
			ID:       lease.Proxy.ID,
			Protocol: lease.Proxy.Protocol,
			Host:     lease.Proxy.Host,
			Port:     lease.Proxy.Port,
		},
		LeaseID: lease.LeaseID,
		// Remaining lease lifetime, in milliseconds, at response time.
		TTLMs: time.Until(lease.ExpireAt).Milliseconds(),
	}

	// Credentials are only exposed when the configuration allows it.
	if h.cfg.ReturnSecret {
		resp.Proxy.Username = lease.Proxy.Username
		resp.Proxy.Password = lease.Proxy.Password
	}

	writeJSON(w, http.StatusOK, resp)
}
// HandleReport records the outcome of using a leased proxy.
//
// Route: POST /v1/proxies/report
//
// Responses:
//   - 400 when the body is not valid JSON or proxy_id is empty
//   - 500 when the selector fails to record the report
//   - 200 with {"ok": true} otherwise
func (h *Handlers) HandleReport(w http.ResponseWriter, r *http.Request) {
	var report model.ReportRequest
	if decodeErr := json.NewDecoder(r.Body).Decode(&report); decodeErr != nil {
		writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
		return
	}
	if len(report.ProxyID) == 0 {
		writeError(w, http.StatusBadRequest, "bad_request", "proxy_id is required")
		return
	}

	err := h.selector.Report(
		r.Context(),
		report.LeaseID,
		report.ProxyID,
		report.Success,
		report.LatencyMs,
		report.Error,
	)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
		return
	}

	writeJSON(w, http.StatusOK, map[string]bool{"ok": true})
}
// HandleListProxies returns one page of proxies matching the query filters.
//
// Route: GET /v1/proxies
//
// Query parameters:
//   - offset:       page start; missing, malformed or negative values become 0
//   - limit:        page size; defaults to 20 and is capped at 100
//   - status:       comma-separated status filter; blank entries are dropped
//   - group:        group filter, passed through unchanged
//   - tags:         comma-separated tags (match-any)
//   - only_enabled: the literal string "true" restricts to enabled proxies
//
// Responses:
//   - 500 when the store query fails
//   - 200 with a model.ProxyListResponse echoing the effective offset/limit
func (h *Handlers) HandleListProxies(w http.ResponseWriter, r *http.Request) {
	params := r.URL.Query()

	// Atoi errors are deliberately ignored: a missing or malformed number
	// yields 0, which the clamps below turn into the defaults.
	offset, _ := strconv.Atoi(params.Get("offset"))
	// A negative offset is never meaningful; clamp it instead of handing it
	// to the store.
	offset = max(offset, 0)

	limit, _ := strconv.Atoi(params.Get("limit"))
	switch {
	case limit <= 0:
		limit = 20
	case limit > 100:
		limit = 100
	}

	// splitCSV trims entries and drops blanks, so "alive," does not add an
	// empty status to the filter.
	var statusIn []model.ProxyStatus
	for _, s := range splitCSV(params.Get("status")) {
		statusIn = append(statusIn, model.ProxyStatus(s))
	}

	q := model.ProxyListQuery{
		Group:       params.Get("group"),
		TagsAny:     splitCSV(params.Get("tags")),
		StatusIn:    statusIn,
		OnlyEnabled: params.Get("only_enabled") == "true",
		Offset:      offset,
		Limit:       limit,
	}

	proxies, total, err := h.store.ListPaginated(r.Context(), q)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
		return
	}

	writeJSON(w, http.StatusOK, model.ProxyListResponse{
		Data:   proxies,
		Total:  total,
		Offset: offset,
		Limit:  limit,
	})
}
// HandleGetProxy returns a single proxy by its ID.
//
// Route: GET /v1/proxies/{id}
//
// Responses:
//   - 400 when the path carries no id
//   - 404 when the store reports model.ErrProxyNotFound
//   - 500 on any other store error
//   - 200 with the proxy otherwise
func (h *Handlers) HandleGetProxy(w http.ResponseWriter, r *http.Request) {
	id := r.PathValue("id")
	if id == "" {
		writeError(w, http.StatusBadRequest, "bad_request", "proxy id is required")
		return
	}

	proxy, err := h.store.GetByID(r.Context(), id)
	if err != nil {
		// errors.Is (rather than ==) still matches when the store wraps the
		// sentinel with additional context.
		if errors.Is(err, model.ErrProxyNotFound) {
			writeError(w, http.StatusNotFound, "not_found", "proxy not found")
			return
		}
		writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
		return
	}

	writeJSON(w, http.StatusOK, proxy)
}
// HandleDeleteProxy deletes a single proxy by its ID.
//
// Route: DELETE /v1/proxies/{id}
//
// Responses:
//   - 400 when the path carries no id
//   - 404 when the store reports model.ErrProxyNotFound
//   - 500 on any other store error
//   - 200 with {"ok": true} otherwise
func (h *Handlers) HandleDeleteProxy(w http.ResponseWriter, r *http.Request) {
	id := r.PathValue("id")
	if id == "" {
		writeError(w, http.StatusBadRequest, "bad_request", "proxy id is required")
		return
	}

	if err := h.store.Delete(r.Context(), id); err != nil {
		// errors.Is (rather than ==) still matches when the store wraps the
		// sentinel with additional context.
		if errors.Is(err, model.ErrProxyNotFound) {
			writeError(w, http.StatusNotFound, "not_found", "proxy not found")
			return
		}
		writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
		return
	}

	writeJSON(w, http.StatusOK, map[string]bool{"ok": true})
}
// HandleBulkDeleteProxies deletes every proxy selected by the request body.
//
// Route: DELETE /v1/proxies
//
// Responses:
//   - 400 when the body is not valid JSON, or when the store reports
//     model.ErrBulkDeleteEmpty (its message is returned to the caller)
//   - 500 on any other store error
//   - 200 with a model.BulkDeleteResponse carrying the deleted count
func (h *Handlers) HandleBulkDeleteProxies(w http.ResponseWriter, r *http.Request) {
	var req model.BulkDeleteRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
		return
	}

	deleted, err := h.store.DeleteMany(r.Context(), req)
	if err != nil {
		// errors.Is (rather than ==) still matches when the store wraps the
		// sentinel with additional context.
		if errors.Is(err, model.ErrBulkDeleteEmpty) {
			writeError(w, http.StatusBadRequest, "bad_request", err.Error())
			return
		}
		writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
		return
	}

	writeJSON(w, http.StatusOK, model.BulkDeleteResponse{Deleted: int(deleted)})
}
// HandleUpdateProxy applies a partial update to a proxy and returns the
// updated record.
//
// Route: PATCH /v1/proxies/{id}
//
// Responses:
//   - 400 when the path carries no id, the body is not valid JSON, or the
//     store reports model.ErrInvalidPatch
//   - 404 when the store reports model.ErrProxyNotFound, either on update or
//     on the follow-up read (the proxy was removed in between)
//   - 500 on any other store error
//   - 200 with the proxy as re-read from the store after the update
func (h *Handlers) HandleUpdateProxy(w http.ResponseWriter, r *http.Request) {
	id := r.PathValue("id")
	if id == "" {
		writeError(w, http.StatusBadRequest, "bad_request", "proxy id is required")
		return
	}

	var patch model.ProxyPatch
	if err := json.NewDecoder(r.Body).Decode(&patch); err != nil {
		writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
		return
	}

	if err := h.store.Update(r.Context(), id, patch); err != nil {
		// errors.Is (rather than ==) still matches when the store wraps a
		// sentinel with additional context.
		switch {
		case errors.Is(err, model.ErrProxyNotFound):
			writeError(w, http.StatusNotFound, "not_found", "proxy not found")
		case errors.Is(err, model.ErrInvalidPatch):
			writeError(w, http.StatusBadRequest, "bad_request", err.Error())
		default:
			writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
		}
		return
	}

	// Re-read so the response reflects what the store now holds.
	proxy, err := h.store.GetByID(r.Context(), id)
	if err != nil {
		// The proxy can disappear between the update and this read; report
		// that as not-found rather than as a server fault.
		if errors.Is(err, model.ErrProxyNotFound) {
			writeError(w, http.StatusNotFound, "not_found", "proxy not found")
			return
		}
		writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
		return
	}

	writeJSON(w, http.StatusOK, proxy)
}
// HandleTestSingleProxy tests one proxy against a caller-supplied URL.
//
// Route: POST /v1/proxies/{id}/test
//
// Responses:
//   - 400 when the path carries no id, the body is not valid JSON, url is
//     empty, or the URL is rejected by security.ValidateTestURL
//   - 404 when the store reports model.ErrProxyNotFound
//   - 500 on any other store error, or when the tester returns no result
//   - 200 with the single test result otherwise
//
// When the request sets UpdateStore the result is written back as a health
// patch; when it sets WriteLog it is appended to the test log. Both writes
// are best-effort: their errors are discarded so the caller still receives
// the test result.
func (h *Handlers) HandleTestSingleProxy(w http.ResponseWriter, r *http.Request) {
	id := r.PathValue("id")
	if id == "" {
		writeError(w, http.StatusBadRequest, "bad_request", "proxy id is required")
		return
	}

	var req model.SingleTestRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
		return
	}

	// SSRF guard: the target URL must be present and pass validation.
	if req.URL == "" {
		writeError(w, http.StatusBadRequest, "bad_request", "url is required")
		return
	}
	if err := security.ValidateTestURL(req.URL); err != nil {
		writeError(w, http.StatusBadRequest, "bad_request", err.Error())
		return
	}

	proxy, err := h.store.GetByID(r.Context(), id)
	if err != nil {
		// errors.Is (rather than ==) still matches when the store wraps the
		// sentinel with additional context.
		if errors.Is(err, model.ErrProxyNotFound) {
			writeError(w, http.StatusNotFound, "not_found", "proxy not found")
			return
		}
		writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
		return
	}

	// Timeout defaults to five seconds when unset or non-positive.
	timeout := time.Duration(req.TimeoutMs) * time.Millisecond
	if timeout <= 0 {
		timeout = 5 * time.Second
	}
	spec := model.TestSpec{
		URL:            req.URL,
		Method:         coalesce(req.Method, "GET"),
		Timeout:        timeout,
		ExpectStatus:   req.ExpectStatus,
		ExpectContains: req.ExpectContains,
	}

	// Reuse the batch tester with a one-element batch and concurrency 1.
	results := h.tester.TestBatch(r.Context(), []model.Proxy{*proxy}, spec, 1)
	if len(results) == 0 {
		writeError(w, http.StatusInternalServerError, "internal_error", "test failed")
		return
	}
	result := results[0]

	// Write the outcome back to the store (best-effort).
	if req.UpdateStore {
		now := result.CheckedAt
		patch := model.HealthPatch{CheckedAt: &now}
		if result.OK {
			status := model.StatusAlive
			patch.Status = &status
			patch.ScoreDelta = 1
			patch.SuccessInc = 1
			patch.LatencyMs = &result.LatencyMs
		} else {
			status := model.StatusDead
			patch.Status = &status
			patch.ScoreDelta = -3
			patch.FailInc = 1
		}
		_ = h.store.UpdateHealth(r.Context(), result.ProxyID, patch)
	}

	// Append to the test log (best-effort).
	if req.WriteLog {
		_ = h.store.InsertTestLog(r.Context(), result, req.URL)
	}

	writeJSON(w, http.StatusOK, result)
}
// HandleGetStats returns the proxy statistics reported by the store.
//
// Route: GET /v1/proxies/stats
//
// Responds 500 with the store's error message when the lookup fails, and 200
// with the statistics otherwise.
func (h *Handlers) HandleGetStats(w http.ResponseWriter, r *http.Request) {
	stats, err := h.store.GetStats(r.Context())
	if err == nil {
		writeJSON(w, http.StatusOK, stats)
		return
	}
	writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
}
// writeJSON sends data as a JSON response with the given status code.
//
// The value is marshalled before any header is written. If marshalling fails
// (for example because data contains a value encoding/json cannot represent),
// the client receives a 500 with a generic JSON error body instead of the
// requested status followed by an empty or truncated payload.
//
// On success the bytes written match json.Encoder.Encode: the JSON encoding
// followed by a single newline.
func writeJSON(w http.ResponseWriter, status int, data any) {
	payload, err := json.Marshal(data)

	w.Header().Set("Content-Type", "application/json")
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		// Write errors are ignored: at this point the only possible failure
		// is a broken connection, and there is nobody left to tell.
		_, _ = w.Write([]byte(`{"error":"internal_error","message":"failed to encode response"}` + "\n"))
		return
	}

	w.WriteHeader(status)
	// Same reasoning as above for discarding the write error.
	_, _ = w.Write(append(payload, '\n'))
}
// writeError sends a JSON error response with the given status code. The body
// is an object with two string members: "error" (a machine-readable code) and
// "message" (a human-readable description).
func writeError(w http.ResponseWriter, status int, code, message string) {
	// Field order matches the alphabetical key order encoding/json uses for
	// maps, so the bytes on the wire are the same as with a map payload.
	payload := struct {
		Error   string `json:"error"`
		Message string `json:"message"`
	}{
		Error:   code,
		Message: message,
	}

	headers := w.Header()
	headers.Set("Content-Type", "application/json")
	w.WriteHeader(status)
	json.NewEncoder(w).Encode(payload)
}
// coalesce returns the first argument that is not the empty string, scanning
// left to right. It returns "" when called with no arguments or when every
// argument is empty.
func coalesce(values ...string) string {
	for i := 0; i < len(values); i++ {
		if candidate := values[i]; len(candidate) > 0 {
			return candidate
		}
	}
	return ""
}
// splitCSV splits a comma-separated string into its fields, trimming
// surrounding whitespace from each field and dropping fields that are empty
// after trimming.
//
// An empty input yields nil. A non-empty input whose fields are all blank
// (for example ",") yields an empty, non-nil slice.
func splitCSV(s string) []string {
	if len(s) == 0 {
		return nil
	}

	// One more field than there are commas.
	fields := make([]string, 0, strings.Count(s, ",")+1)
	for rest, more := s, true; more; {
		var field string
		field, rest, more = strings.Cut(rest, ",")
		if field = strings.TrimSpace(field); field != "" {
			fields = append(fields, field)
		}
	}
	return fields
}