frist
This commit is contained in:
643
internal/api/handlers.go
Normal file
643
internal/api/handlers.go
Normal file
@@ -0,0 +1,643 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"proxyrotator/internal/config"
|
||||
"proxyrotator/internal/importer"
|
||||
"proxyrotator/internal/model"
|
||||
"proxyrotator/internal/security"
|
||||
"proxyrotator/internal/selector"
|
||||
"proxyrotator/internal/store"
|
||||
"proxyrotator/internal/tester"
|
||||
)
|
||||
|
||||
// Handlers API 处理器集合
|
||||
type Handlers struct {
|
||||
store store.ProxyStore
|
||||
importer *importer.Importer
|
||||
tester *tester.HTTPTester
|
||||
selector *selector.Selector
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewHandlers 创建处理器
|
||||
func NewHandlers(
|
||||
store store.ProxyStore,
|
||||
importer *importer.Importer,
|
||||
tester *tester.HTTPTester,
|
||||
selector *selector.Selector,
|
||||
cfg *config.Config,
|
||||
) *Handlers {
|
||||
return &Handlers{
|
||||
store: store,
|
||||
importer: importer,
|
||||
tester: tester,
|
||||
selector: selector,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleImportText 文本导入处理器
|
||||
// POST /v1/proxies/import/text
|
||||
func (h *Handlers) HandleImportText(w http.ResponseWriter, r *http.Request) {
|
||||
var req model.ImportTextRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Text == "" {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "text is required")
|
||||
return
|
||||
}
|
||||
|
||||
input := model.ImportInput{
|
||||
Group: coalesce(req.Group, "default"),
|
||||
Tags: req.Tags,
|
||||
ProtocolHint: req.ProtocolHint,
|
||||
}
|
||||
|
||||
proxies, invalid := h.importer.ParseText(r.Context(), input, req.Text)
|
||||
|
||||
imported, duplicated := 0, 0
|
||||
if len(proxies) > 0 {
|
||||
var err error
|
||||
imported, duplicated, err = h.store.UpsertMany(r.Context(), proxies)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, model.ImportResult{
|
||||
Imported: imported,
|
||||
Duplicated: duplicated,
|
||||
Invalid: len(invalid),
|
||||
InvalidItems: invalid,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleImportFile 文件上传导入处理器
|
||||
// POST /v1/proxies/import/file
|
||||
func (h *Handlers) HandleImportFile(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseMultipartForm(10 << 20); err != nil { // 10MB 限制
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "failed to parse multipart form")
|
||||
return
|
||||
}
|
||||
|
||||
file, header, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "file is required")
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
group := coalesce(r.FormValue("group"), "default")
|
||||
tagsStr := r.FormValue("tags")
|
||||
var tags []string
|
||||
if tagsStr != "" {
|
||||
tags = strings.Split(tagsStr, ",")
|
||||
for i := range tags {
|
||||
tags[i] = strings.TrimSpace(tags[i])
|
||||
}
|
||||
}
|
||||
|
||||
input := model.ImportInput{
|
||||
Group: group,
|
||||
Tags: tags,
|
||||
ProtocolHint: r.FormValue("protocol_hint"),
|
||||
}
|
||||
|
||||
// 读取文件内容
|
||||
content, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "failed to read file")
|
||||
return
|
||||
}
|
||||
|
||||
var proxies []model.Proxy
|
||||
var invalid []model.InvalidLine
|
||||
|
||||
// 根据文件类型解析
|
||||
fileType := r.FormValue("type")
|
||||
if fileType == "" {
|
||||
// 根据文件名推断
|
||||
if strings.HasSuffix(strings.ToLower(header.Filename), ".csv") {
|
||||
fileType = "csv"
|
||||
} else {
|
||||
fileType = "txt"
|
||||
}
|
||||
}
|
||||
|
||||
if fileType == "csv" {
|
||||
proxies, invalid = h.importer.ParseCSV(r.Context(), input, strings.NewReader(string(content)))
|
||||
} else {
|
||||
proxies, invalid = h.importer.ParseText(r.Context(), input, string(content))
|
||||
}
|
||||
|
||||
imported, duplicated := 0, 0
|
||||
if len(proxies) > 0 {
|
||||
imported, duplicated, err = h.store.UpsertMany(r.Context(), proxies)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, model.ImportResult{
|
||||
Imported: imported,
|
||||
Duplicated: duplicated,
|
||||
Invalid: len(invalid),
|
||||
InvalidItems: invalid,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleTest 测试代理处理器
|
||||
// POST /v1/proxies/test
|
||||
func (h *Handlers) HandleTest(w http.ResponseWriter, r *http.Request) {
|
||||
var req model.TestRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
// SSRF 防护
|
||||
if req.TestSpec.URL == "" {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "test_spec.url is required")
|
||||
return
|
||||
}
|
||||
if err := security.ValidateTestURL(req.TestSpec.URL); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 限制并发数
|
||||
concurrency := req.Concurrency
|
||||
if concurrency <= 0 {
|
||||
concurrency = 50
|
||||
}
|
||||
if concurrency > h.cfg.MaxConcurrency {
|
||||
concurrency = h.cfg.MaxConcurrency
|
||||
}
|
||||
|
||||
// 限制测试数量
|
||||
limit := req.Filter.Limit
|
||||
if limit <= 0 || limit > h.cfg.MaxTestLimit {
|
||||
limit = h.cfg.MaxTestLimit
|
||||
}
|
||||
|
||||
// 构建查询条件
|
||||
var statusIn []model.ProxyStatus
|
||||
for _, s := range req.Filter.Status {
|
||||
statusIn = append(statusIn, model.ProxyStatus(s))
|
||||
}
|
||||
if len(statusIn) == 0 {
|
||||
statusIn = []model.ProxyStatus{model.StatusUnknown, model.StatusAlive}
|
||||
}
|
||||
|
||||
proxies, err := h.store.List(r.Context(), model.ProxyQuery{
|
||||
Group: coalesce(req.Group, "default"),
|
||||
TagsAny: req.Filter.TagsAny,
|
||||
StatusIn: statusIn,
|
||||
OnlyEnabled: true,
|
||||
Limit: limit,
|
||||
})
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(proxies) == 0 {
|
||||
writeJSON(w, http.StatusOK, model.TestBatchResult{
|
||||
Summary: model.TestSummary{},
|
||||
Results: []model.TestResult{},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 构建测试规格
|
||||
timeout := time.Duration(req.TestSpec.TimeoutMs) * time.Millisecond
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
|
||||
spec := model.TestSpec{
|
||||
URL: req.TestSpec.URL,
|
||||
Method: coalesce(req.TestSpec.Method, "GET"),
|
||||
Timeout: timeout,
|
||||
ExpectStatus: req.TestSpec.ExpectStatus,
|
||||
ExpectContains: req.TestSpec.ExpectContains,
|
||||
}
|
||||
|
||||
// 执行测试
|
||||
results := h.tester.TestBatch(r.Context(), proxies, spec, concurrency)
|
||||
|
||||
// 统计结果
|
||||
summary := model.TestSummary{Tested: len(results)}
|
||||
for _, result := range results {
|
||||
if result.OK {
|
||||
summary.Alive++
|
||||
} else {
|
||||
summary.Dead++
|
||||
}
|
||||
|
||||
// 更新数据库
|
||||
if req.UpdateStore {
|
||||
now := result.CheckedAt
|
||||
if result.OK {
|
||||
status := model.StatusAlive
|
||||
_ = h.store.UpdateHealth(r.Context(), result.ProxyID, model.HealthPatch{
|
||||
Status: &status,
|
||||
ScoreDelta: 1,
|
||||
SuccessInc: 1,
|
||||
LatencyMs: &result.LatencyMs,
|
||||
CheckedAt: &now,
|
||||
})
|
||||
} else {
|
||||
status := model.StatusDead
|
||||
_ = h.store.UpdateHealth(r.Context(), result.ProxyID, model.HealthPatch{
|
||||
Status: &status,
|
||||
ScoreDelta: -3,
|
||||
FailInc: 1,
|
||||
CheckedAt: &now,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 写入测试日志
|
||||
if req.WriteLog {
|
||||
_ = h.store.InsertTestLog(r.Context(), result, req.TestSpec.URL)
|
||||
}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, model.TestBatchResult{
|
||||
Summary: summary,
|
||||
Results: results,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleNext 获取下一个可用代理
|
||||
// GET /v1/proxies/next
|
||||
func (h *Handlers) HandleNext(w http.ResponseWriter, r *http.Request) {
|
||||
req := model.SelectRequest{
|
||||
Group: coalesce(r.URL.Query().Get("group"), "default"),
|
||||
Site: r.URL.Query().Get("site"),
|
||||
Policy: r.URL.Query().Get("policy"),
|
||||
TagsAny: splitCSV(r.URL.Query().Get("tags_any")),
|
||||
}
|
||||
|
||||
lease, err := h.selector.Next(r.Context(), req)
|
||||
if err != nil {
|
||||
if err == model.ErrNoProxy {
|
||||
writeError(w, http.StatusNotFound, "not_found", "no available proxy")
|
||||
return
|
||||
}
|
||||
if err == model.ErrBadPolicy {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "invalid policy")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resp := model.NextProxyResponse{
|
||||
Proxy: model.ProxyInfo{
|
||||
ID: lease.Proxy.ID,
|
||||
Protocol: lease.Proxy.Protocol,
|
||||
Host: lease.Proxy.Host,
|
||||
Port: lease.Proxy.Port,
|
||||
},
|
||||
LeaseID: lease.LeaseID,
|
||||
TTLMs: time.Until(lease.ExpireAt).Milliseconds(),
|
||||
}
|
||||
|
||||
// 根据配置决定是否返回凭证
|
||||
if h.cfg.ReturnSecret {
|
||||
resp.Proxy.Username = lease.Proxy.Username
|
||||
resp.Proxy.Password = lease.Proxy.Password
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// HandleReport 上报使用结果
|
||||
// POST /v1/proxies/report
|
||||
func (h *Handlers) HandleReport(w http.ResponseWriter, r *http.Request) {
|
||||
var req model.ReportRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.ProxyID == "" {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "proxy_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.selector.Report(r.Context(), req.LeaseID, req.ProxyID, req.Success, req.LatencyMs, req.Error); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]bool{"ok": true})
|
||||
}
|
||||
|
||||
// HandleListProxies 列出代理
|
||||
// GET /v1/proxies
|
||||
func (h *Handlers) HandleListProxies(w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
|
||||
// 解析分页参数
|
||||
offset, _ := strconv.Atoi(query.Get("offset"))
|
||||
limit, _ := strconv.Atoi(query.Get("limit"))
|
||||
if limit <= 0 {
|
||||
limit = 20
|
||||
}
|
||||
if limit > 100 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
// 解析过滤参数
|
||||
var statusIn []model.ProxyStatus
|
||||
if statusStr := query.Get("status"); statusStr != "" {
|
||||
for _, s := range strings.Split(statusStr, ",") {
|
||||
statusIn = append(statusIn, model.ProxyStatus(strings.TrimSpace(s)))
|
||||
}
|
||||
}
|
||||
|
||||
q := model.ProxyListQuery{
|
||||
Group: query.Get("group"),
|
||||
TagsAny: splitCSV(query.Get("tags")),
|
||||
StatusIn: statusIn,
|
||||
OnlyEnabled: query.Get("only_enabled") == "true",
|
||||
Offset: offset,
|
||||
Limit: limit,
|
||||
}
|
||||
|
||||
proxies, total, err := h.store.ListPaginated(r.Context(), q)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, model.ProxyListResponse{
|
||||
Data: proxies,
|
||||
Total: total,
|
||||
Offset: offset,
|
||||
Limit: limit,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleGetProxy 获取单个代理
|
||||
// GET /v1/proxies/{id}
|
||||
func (h *Handlers) HandleGetProxy(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.PathValue("id")
|
||||
if id == "" {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "proxy id is required")
|
||||
return
|
||||
}
|
||||
|
||||
proxy, err := h.store.GetByID(r.Context(), id)
|
||||
if err != nil {
|
||||
if err == model.ErrProxyNotFound {
|
||||
writeError(w, http.StatusNotFound, "not_found", "proxy not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, proxy)
|
||||
}
|
||||
|
||||
// HandleDeleteProxy 删除单个代理
|
||||
// DELETE /v1/proxies/{id}
|
||||
func (h *Handlers) HandleDeleteProxy(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.PathValue("id")
|
||||
if id == "" {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "proxy id is required")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.Delete(r.Context(), id); err != nil {
|
||||
if err == model.ErrProxyNotFound {
|
||||
writeError(w, http.StatusNotFound, "not_found", "proxy not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]bool{"ok": true})
|
||||
}
|
||||
|
||||
// HandleBulkDeleteProxies 批量删除代理
|
||||
// DELETE /v1/proxies
|
||||
func (h *Handlers) HandleBulkDeleteProxies(w http.ResponseWriter, r *http.Request) {
|
||||
var req model.BulkDeleteRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
deleted, err := h.store.DeleteMany(r.Context(), req)
|
||||
if err != nil {
|
||||
if err == model.ErrBulkDeleteEmpty {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", err.Error())
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, model.BulkDeleteResponse{Deleted: int(deleted)})
|
||||
}
|
||||
|
||||
// HandleUpdateProxy 更新代理
|
||||
// PATCH /v1/proxies/{id}
|
||||
func (h *Handlers) HandleUpdateProxy(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.PathValue("id")
|
||||
if id == "" {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "proxy id is required")
|
||||
return
|
||||
}
|
||||
|
||||
var patch model.ProxyPatch
|
||||
if err := json.NewDecoder(r.Body).Decode(&patch); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.Update(r.Context(), id, patch); err != nil {
|
||||
if err == model.ErrProxyNotFound {
|
||||
writeError(w, http.StatusNotFound, "not_found", "proxy not found")
|
||||
return
|
||||
}
|
||||
if err == model.ErrInvalidPatch {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", err.Error())
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 返回更新后的代理
|
||||
proxy, err := h.store.GetByID(r.Context(), id)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, proxy)
|
||||
}
|
||||
|
||||
// HandleTestSingleProxy 测试单个代理
|
||||
// POST /v1/proxies/{id}/test
|
||||
func (h *Handlers) HandleTestSingleProxy(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.PathValue("id")
|
||||
if id == "" {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "proxy id is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req model.SingleTestRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "invalid JSON body")
|
||||
return
|
||||
}
|
||||
|
||||
// SSRF 防护
|
||||
if req.URL == "" {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", "url is required")
|
||||
return
|
||||
}
|
||||
if err := security.ValidateTestURL(req.URL); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "bad_request", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 获取代理
|
||||
proxy, err := h.store.GetByID(r.Context(), id)
|
||||
if err != nil {
|
||||
if err == model.ErrProxyNotFound {
|
||||
writeError(w, http.StatusNotFound, "not_found", "proxy not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 构建测试规格
|
||||
timeout := time.Duration(req.TimeoutMs) * time.Millisecond
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
|
||||
spec := model.TestSpec{
|
||||
URL: req.URL,
|
||||
Method: coalesce(req.Method, "GET"),
|
||||
Timeout: timeout,
|
||||
ExpectStatus: req.ExpectStatus,
|
||||
ExpectContains: req.ExpectContains,
|
||||
}
|
||||
|
||||
// 执行测试
|
||||
results := h.tester.TestBatch(r.Context(), []model.Proxy{*proxy}, spec, 1)
|
||||
if len(results) == 0 {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", "test failed")
|
||||
return
|
||||
}
|
||||
|
||||
result := results[0]
|
||||
|
||||
// 更新数据库
|
||||
if req.UpdateStore {
|
||||
now := result.CheckedAt
|
||||
if result.OK {
|
||||
status := model.StatusAlive
|
||||
_ = h.store.UpdateHealth(r.Context(), result.ProxyID, model.HealthPatch{
|
||||
Status: &status,
|
||||
ScoreDelta: 1,
|
||||
SuccessInc: 1,
|
||||
LatencyMs: &result.LatencyMs,
|
||||
CheckedAt: &now,
|
||||
})
|
||||
} else {
|
||||
status := model.StatusDead
|
||||
_ = h.store.UpdateHealth(r.Context(), result.ProxyID, model.HealthPatch{
|
||||
Status: &status,
|
||||
ScoreDelta: -3,
|
||||
FailInc: 1,
|
||||
CheckedAt: &now,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 写入测试日志
|
||||
if req.WriteLog {
|
||||
_ = h.store.InsertTestLog(r.Context(), result, req.URL)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
// HandleGetStats 获取代理统计信息
|
||||
// GET /v1/proxies/stats
|
||||
func (h *Handlers) HandleGetStats(w http.ResponseWriter, r *http.Request) {
|
||||
stats, err := h.store.GetStats(r.Context())
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "internal_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// writeJSON 写入 JSON 响应
|
||||
func writeJSON(w http.ResponseWriter, status int, data any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
// writeError 写入错误响应
|
||||
func writeError(w http.ResponseWriter, status int, code, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": code,
|
||||
"message": message,
|
||||
})
|
||||
}
|
||||
|
||||
// coalesce 返回第一个非空字符串
|
||||
func coalesce(values ...string) string {
|
||||
for _, v := range values {
|
||||
if v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// splitCSV 分割逗号分隔的字符串
|
||||
func splitCSV(s string) []string {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(s, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
result = append(result, p)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
136
internal/api/middleware.go
Normal file
136
internal/api/middleware.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// LoggingMiddleware 请求日志中间件
|
||||
func LoggingMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
// 包装 ResponseWriter 以获取状态码
|
||||
wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
|
||||
|
||||
next.ServeHTTP(wrapped, r)
|
||||
|
||||
slog.Info("request",
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"status", wrapped.statusCode,
|
||||
"duration_ms", time.Since(start).Milliseconds(),
|
||||
"remote_addr", r.RemoteAddr,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// AuthMiddleware API Key 鉴权中间件
|
||||
func AuthMiddleware(next http.Handler, apiKey string) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 如果没有配置 API Key,跳过鉴权
|
||||
if apiKey == "" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查 Authorization 头
|
||||
auth := r.Header.Get("Authorization")
|
||||
if auth != "" {
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
token := strings.TrimPrefix(auth, "Bearer ")
|
||||
if token == apiKey {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 X-API-Key 头
|
||||
key := r.Header.Get("X-API-Key")
|
||||
if key == apiKey {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, `{"error":"unauthorized","message":"invalid or missing API key"}`, http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
|
||||
// RateLimitMiddleware 简单限流中间件(基于滑动窗口)
|
||||
type RateLimitMiddleware struct {
|
||||
requests map[string][]time.Time
|
||||
limit int
|
||||
window time.Duration
|
||||
}
|
||||
|
||||
// NewRateLimitMiddleware 创建限流中间件
|
||||
func NewRateLimitMiddleware(limit int, window time.Duration) *RateLimitMiddleware {
|
||||
return &RateLimitMiddleware{
|
||||
requests: make(map[string][]time.Time),
|
||||
limit: limit,
|
||||
window: window,
|
||||
}
|
||||
}
|
||||
|
||||
// Wrap 包装处理器
|
||||
func (rl *RateLimitMiddleware) Wrap(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ip := getClientIP(r)
|
||||
now := time.Now()
|
||||
|
||||
// 清理过期记录
|
||||
windowStart := now.Add(-rl.window)
|
||||
var valid []time.Time
|
||||
for _, t := range rl.requests[ip] {
|
||||
if t.After(windowStart) {
|
||||
valid = append(valid, t)
|
||||
}
|
||||
}
|
||||
rl.requests[ip] = valid
|
||||
|
||||
// 检查限制
|
||||
if len(rl.requests[ip]) >= rl.limit {
|
||||
http.Error(w, `{"error":"rate_limit","message":"too many requests"}`, http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
// 记录请求
|
||||
rl.requests[ip] = append(rl.requests[ip], now)
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// getClientIP 获取客户端 IP
|
||||
func getClientIP(r *http.Request) string {
|
||||
// 检查代理头
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
parts := strings.Split(xff, ",")
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
// 从 RemoteAddr 提取 IP
|
||||
ip := r.RemoteAddr
|
||||
if idx := strings.LastIndex(ip, ":"); idx != -1 {
|
||||
ip = ip[:idx]
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// responseWriter 包装 ResponseWriter 以获取状态码
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (rw *responseWriter) WriteHeader(code int) {
|
||||
rw.statusCode = code
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
53
internal/api/router.go
Normal file
53
internal/api/router.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"proxyrotator/internal/config"
|
||||
"proxyrotator/internal/importer"
|
||||
"proxyrotator/internal/selector"
|
||||
"proxyrotator/internal/store"
|
||||
"proxyrotator/internal/tester"
|
||||
)
|
||||
|
||||
// NewRouter 创建 HTTP 路由
|
||||
func NewRouter(
|
||||
store store.ProxyStore,
|
||||
importer *importer.Importer,
|
||||
tester *tester.HTTPTester,
|
||||
selector *selector.Selector,
|
||||
cfg *config.Config,
|
||||
) http.Handler {
|
||||
handlers := NewHandlers(store, importer, tester, selector, cfg)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// 注册路由(Go 1.22+ 支持 METHOD /path 模式)
|
||||
mux.HandleFunc("POST /v1/proxies/import/text", handlers.HandleImportText)
|
||||
mux.HandleFunc("POST /v1/proxies/import/file", handlers.HandleImportFile)
|
||||
mux.HandleFunc("POST /v1/proxies/test", handlers.HandleTest)
|
||||
mux.HandleFunc("GET /v1/proxies/next", handlers.HandleNext)
|
||||
mux.HandleFunc("POST /v1/proxies/report", handlers.HandleReport)
|
||||
|
||||
// CRUD 路由(注意:/stats 需在 /{id} 之前注册)
|
||||
mux.HandleFunc("GET /v1/proxies/stats", handlers.HandleGetStats)
|
||||
mux.HandleFunc("GET /v1/proxies", handlers.HandleListProxies)
|
||||
mux.HandleFunc("GET /v1/proxies/{id}", handlers.HandleGetProxy)
|
||||
mux.HandleFunc("DELETE /v1/proxies/{id}", handlers.HandleDeleteProxy)
|
||||
mux.HandleFunc("DELETE /v1/proxies", handlers.HandleBulkDeleteProxies)
|
||||
mux.HandleFunc("PATCH /v1/proxies/{id}", handlers.HandleUpdateProxy)
|
||||
mux.HandleFunc("POST /v1/proxies/{id}/test", handlers.HandleTestSingleProxy)
|
||||
|
||||
// 健康检查
|
||||
mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"status":"ok"}`))
|
||||
})
|
||||
|
||||
// 应用中间件
|
||||
var handler http.Handler = mux
|
||||
handler = AuthMiddleware(handler, cfg.APIKey)
|
||||
handler = LoggingMiddleware(handler)
|
||||
|
||||
return handler
|
||||
}
|
||||
145
internal/config/config.go
Normal file
145
internal/config/config.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config 应用配置
|
||||
type Config struct {
|
||||
DatabaseURL string
|
||||
ListenAddr string
|
||||
APIKey string
|
||||
ReturnSecret bool
|
||||
MaxConcurrency int
|
||||
MaxTestLimit int
|
||||
LeaseTTL time.Duration
|
||||
|
||||
// Telegram Bot 配置
|
||||
TelegramBotToken string
|
||||
TelegramAdminIDs []int64
|
||||
TelegramNotifyChatID string
|
||||
TelegramTestIntervalMin int
|
||||
TelegramAlertThreshold int
|
||||
TelegramTestURL string
|
||||
TelegramTestTimeoutMs int
|
||||
}
|
||||
|
||||
// Load 从环境变量加载配置
|
||||
func Load() *Config {
|
||||
cfg := &Config{
|
||||
DatabaseURL: getEnv("DATABASE_URL", "postgres://postgres:postgres@localhost:5432/proxyrotator?sslmode=disable"),
|
||||
ListenAddr: getEnv("LISTEN_ADDR", ":8080"),
|
||||
APIKey: getEnv("API_KEY", ""),
|
||||
ReturnSecret: getEnvBool("RETURN_SECRET", true),
|
||||
MaxConcurrency: getEnvInt("MAX_CONCURRENCY", 200),
|
||||
MaxTestLimit: getEnvInt("MAX_TEST_LIMIT", 2000),
|
||||
LeaseTTL: getEnvDuration("LEASE_TTL", 60*time.Second),
|
||||
|
||||
// Telegram
|
||||
TelegramBotToken: getEnv("TELEGRAM_BOT_TOKEN", ""),
|
||||
TelegramAdminIDs: getEnvInt64Slice("TELEGRAM_ADMIN_IDS", nil),
|
||||
TelegramNotifyChatID: getEnv("TELEGRAM_NOTIFY_CHAT_ID", ""),
|
||||
TelegramTestIntervalMin: getEnvInt("TELEGRAM_TEST_INTERVAL_MIN", 60),
|
||||
TelegramAlertThreshold: getEnvInt("TELEGRAM_ALERT_THRESHOLD", 50),
|
||||
TelegramTestURL: getEnv("TELEGRAM_TEST_URL", "https://httpbin.org/ip"),
|
||||
TelegramTestTimeoutMs: getEnvInt("TELEGRAM_TEST_TIMEOUT_MS", 5000),
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func getEnv(key, defaultValue string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvBool(key string, defaultValue bool) bool {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
b, err := strconv.ParseBool(v)
|
||||
if err == nil {
|
||||
return b
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvInt(key string, defaultValue int) int {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
i, err := strconv.Atoi(v)
|
||||
if err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvDuration(key string, defaultValue time.Duration) time.Duration {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
d, err := time.ParseDuration(v)
|
||||
if err == nil {
|
||||
return d
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvInt64Slice(key string, defaultValue []int64) []int64 {
|
||||
v := os.Getenv(key)
|
||||
if v == "" {
|
||||
return defaultValue
|
||||
}
|
||||
parts := splitAndTrim(v, ",")
|
||||
result := make([]int64, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
i, err := strconv.ParseInt(p, 10, 64)
|
||||
if err == nil {
|
||||
result = append(result, i)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func splitAndTrim(s, sep string) []string {
|
||||
parts := make([]string, 0)
|
||||
for _, p := range stringsSplit(s, sep) {
|
||||
p = stringsTrim(p)
|
||||
if p != "" {
|
||||
parts = append(parts, p)
|
||||
}
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func stringsSplit(s, sep string) []string {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
result := make([]string, 0)
|
||||
start := 0
|
||||
for i := 0; i < len(s); i++ {
|
||||
if i+len(sep) <= len(s) && s[i:i+len(sep)] == sep {
|
||||
result = append(result, s[start:i])
|
||||
start = i + len(sep)
|
||||
i += len(sep) - 1
|
||||
}
|
||||
}
|
||||
result = append(result, s[start:])
|
||||
return result
|
||||
}
|
||||
|
||||
func stringsTrim(s string) string {
|
||||
start, end := 0, len(s)
|
||||
for start < end && (s[start] == ' ' || s[start] == '\t') {
|
||||
start++
|
||||
}
|
||||
for end > start && (s[end-1] == ' ' || s[end-1] == '\t') {
|
||||
end--
|
||||
}
|
||||
return s[start:end]
|
||||
}
|
||||
292
internal/importer/importer.go
Normal file
292
internal/importer/importer.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package importer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/csv"
|
||||
"errors"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"proxyrotator/internal/model"
|
||||
)
|
||||
|
||||
// Importer 代理导入器
|
||||
type Importer struct{}
|
||||
|
||||
// NewImporter 创建导入器
|
||||
func NewImporter() *Importer {
|
||||
return &Importer{}
|
||||
}
|
||||
|
||||
// ParseText 解析文本格式的代理列表
|
||||
func (im *Importer) ParseText(ctx context.Context, in model.ImportInput, text string) ([]model.Proxy, []model.InvalidLine) {
|
||||
lines := strings.Split(text, "\n")
|
||||
var proxies []model.Proxy
|
||||
var invalid []model.InvalidLine
|
||||
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, raw := range lines {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" || strings.HasPrefix(raw, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
p, err := ParseProxyLine(raw, in.ProtocolHint)
|
||||
if err != nil {
|
||||
invalid = append(invalid, model.InvalidLine{
|
||||
Raw: raw,
|
||||
Reason: err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
p.ID = uuid.New().String()
|
||||
p.Group = coalesce(p.Group, in.Group, "default")
|
||||
p.Tags = mergeTags(p.Tags, in.Tags)
|
||||
p.Status = model.StatusUnknown
|
||||
|
||||
// 内存去重
|
||||
key := dedupKey(p)
|
||||
if seen[key] {
|
||||
continue
|
||||
}
|
||||
seen[key] = true
|
||||
|
||||
proxies = append(proxies, *p)
|
||||
}
|
||||
|
||||
return proxies, invalid
|
||||
}
|
||||
|
||||
// ParseCSV 解析 CSV 格式的代理列表
|
||||
// 期望列: protocol,host,port,username,password,group,tags
|
||||
func (im *Importer) ParseCSV(ctx context.Context, in model.ImportInput, r io.Reader) ([]model.Proxy, []model.InvalidLine) {
|
||||
reader := csv.NewReader(r)
|
||||
reader.FieldsPerRecord = -1 // 允许不定列数
|
||||
reader.TrimLeadingSpace = true
|
||||
|
||||
var proxies []model.Proxy
|
||||
var invalid []model.InvalidLine
|
||||
seen := make(map[string]bool)
|
||||
|
||||
// 读取表头
|
||||
header, err := reader.Read()
|
||||
if err != nil {
|
||||
return nil, []model.InvalidLine{{Raw: "", Reason: "failed to read CSV header"}}
|
||||
}
|
||||
|
||||
// 解析列索引
|
||||
colIdx := parseHeader(header)
|
||||
|
||||
lineNum := 1
|
||||
for {
|
||||
record, err := reader.Read()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
lineNum++
|
||||
|
||||
if err != nil {
|
||||
invalid = append(invalid, model.InvalidLine{
|
||||
Raw: strings.Join(record, ","),
|
||||
Reason: err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
p, err := parseCSVRecord(record, colIdx, in)
|
||||
if err != nil {
|
||||
invalid = append(invalid, model.InvalidLine{
|
||||
Raw: strings.Join(record, ","),
|
||||
Reason: err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// 内存去重
|
||||
key := dedupKey(p)
|
||||
if seen[key] {
|
||||
continue
|
||||
}
|
||||
seen[key] = true
|
||||
|
||||
proxies = append(proxies, *p)
|
||||
}
|
||||
|
||||
return proxies, invalid
|
||||
}
|
||||
|
||||
// columnIndex CSV 列索引
|
||||
type columnIndex struct {
|
||||
protocol int
|
||||
host int
|
||||
port int
|
||||
username int
|
||||
password int
|
||||
group int
|
||||
tags int
|
||||
}
|
||||
|
||||
// parseHeader 解析 CSV 表头
|
||||
func parseHeader(header []string) columnIndex {
|
||||
idx := columnIndex{
|
||||
protocol: -1,
|
||||
host: -1,
|
||||
port: -1,
|
||||
username: -1,
|
||||
password: -1,
|
||||
group: -1,
|
||||
tags: -1,
|
||||
}
|
||||
|
||||
for i, col := range header {
|
||||
switch strings.ToLower(strings.TrimSpace(col)) {
|
||||
case "protocol":
|
||||
idx.protocol = i
|
||||
case "host":
|
||||
idx.host = i
|
||||
case "port":
|
||||
idx.port = i
|
||||
case "username", "user":
|
||||
idx.username = i
|
||||
case "password", "pass":
|
||||
idx.password = i
|
||||
case "group":
|
||||
idx.group = i
|
||||
case "tags":
|
||||
idx.tags = i
|
||||
}
|
||||
}
|
||||
|
||||
return idx
|
||||
}
|
||||
|
||||
// parseCSVRecord 解析 CSV 记录
|
||||
func parseCSVRecord(record []string, idx columnIndex, in model.ImportInput) (*model.Proxy, error) {
|
||||
get := func(i int) string {
|
||||
if i >= 0 && i < len(record) {
|
||||
return strings.TrimSpace(record[i])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// 如果没有表头,尝试按位置解析
|
||||
if idx.host == -1 && len(record) >= 2 {
|
||||
// 假设格式: host,port 或 host,port,username,password
|
||||
line := strings.Join(record, ":")
|
||||
if len(record) >= 4 {
|
||||
line = record[2] + ":" + record[3] + "@" + record[0] + ":" + record[1]
|
||||
} else {
|
||||
line = record[0] + ":" + record[1]
|
||||
}
|
||||
p, err := ParseProxyLine(line, in.ProtocolHint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.ID = uuid.New().String()
|
||||
p.Group = coalesce(in.Group, "default")
|
||||
p.Tags = in.Tags
|
||||
p.Status = model.StatusUnknown
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// 根据表头解析
|
||||
protocol := get(idx.protocol)
|
||||
host := get(idx.host)
|
||||
portStr := get(idx.port)
|
||||
|
||||
if host == "" {
|
||||
return nil, errors.New("missing host")
|
||||
}
|
||||
|
||||
var p model.Proxy
|
||||
p.ID = uuid.New().String()
|
||||
p.Host = host
|
||||
p.Status = model.StatusUnknown
|
||||
|
||||
// 解析协议
|
||||
switch strings.ToLower(protocol) {
|
||||
case "http":
|
||||
p.Protocol = model.ProtoHTTP
|
||||
case "https":
|
||||
p.Protocol = model.ProtoHTTPS
|
||||
case "socks5":
|
||||
p.Protocol = model.ProtoSOCKS5
|
||||
default:
|
||||
p.Protocol = model.ProtoHTTP
|
||||
}
|
||||
|
||||
// 解析端口
|
||||
if portStr != "" {
|
||||
port := 0
|
||||
for _, c := range portStr {
|
||||
if c >= '0' && c <= '9' {
|
||||
port = port*10 + int(c-'0')
|
||||
}
|
||||
}
|
||||
if port > 0 && port < 65536 {
|
||||
p.Port = port
|
||||
} else {
|
||||
p.Port = 80
|
||||
}
|
||||
} else {
|
||||
p.Port = 80
|
||||
}
|
||||
|
||||
p.Username = get(idx.username)
|
||||
p.Password = get(idx.password)
|
||||
p.Group = coalesce(get(idx.group), in.Group, "default")
|
||||
|
||||
// 解析 tags
|
||||
tagsStr := get(idx.tags)
|
||||
if tagsStr != "" {
|
||||
p.Tags = strings.Split(tagsStr, ";")
|
||||
}
|
||||
p.Tags = mergeTags(p.Tags, in.Tags)
|
||||
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// dedupKey 生成去重键
|
||||
func dedupKey(p *model.Proxy) string {
|
||||
return string(p.Protocol) + ":" + p.Host + ":" + strconv.Itoa(p.Port) + ":" + p.Username
|
||||
}
|
||||
|
||||
// coalesce 返回第一个非空字符串
|
||||
func coalesce(values ...string) string {
|
||||
for _, v := range values {
|
||||
if v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// mergeTags 合并去重 tags
|
||||
func mergeTags(a, b []string) []string {
|
||||
seen := make(map[string]bool)
|
||||
var result []string
|
||||
|
||||
for _, t := range a {
|
||||
t = strings.TrimSpace(t)
|
||||
if t != "" && !seen[t] {
|
||||
seen[t] = true
|
||||
result = append(result, t)
|
||||
}
|
||||
}
|
||||
|
||||
for _, t := range b {
|
||||
t = strings.TrimSpace(t)
|
||||
if t != "" && !seen[t] {
|
||||
seen[t] = true
|
||||
result = append(result, t)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
171
internal/importer/parser.go
Normal file
171
internal/importer/parser.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package importer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"proxyrotator/internal/model"
|
||||
)
|
||||
|
||||
var (
|
||||
// 匹配 host:port 格式
|
||||
hostPortRegex = regexp.MustCompile(`^([a-zA-Z0-9.-]+):(\d+)$`)
|
||||
// 匹配 user:pass@host:port 格式
|
||||
userPassHostPortRegex = regexp.MustCompile(`^([^:@]+):([^@]+)@([a-zA-Z0-9.-]+):(\d+)$`)
|
||||
// 匹配 host:port:user:pass 格式
|
||||
hostPortUserPassRegex = regexp.MustCompile(`^([a-zA-Z0-9.-]+):(\d+):([^:]+):(.+)$`)
|
||||
)
|
||||
|
||||
// ParseProxyLine 解析单行代理格式
|
||||
// 支持格式:
|
||||
// - host:port
|
||||
// - user:pass@host:port
|
||||
// - host:port:user:pass
|
||||
// - http://host:port
|
||||
// - http://user:pass@host:port
|
||||
// - socks5://host:port
|
||||
// - socks5://user:pass@host:port
|
||||
func ParseProxyLine(raw string, protocolHint string) (*model.Proxy, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil, fmt.Errorf("empty line")
|
||||
}
|
||||
|
||||
// 尝试解析为 URL
|
||||
if strings.Contains(raw, "://") {
|
||||
return parseAsURL(raw)
|
||||
}
|
||||
|
||||
// 尝试解析 host:port:user:pass 格式
|
||||
if matches := hostPortUserPassRegex.FindStringSubmatch(raw); matches != nil {
|
||||
port, err := strconv.Atoi(matches[2])
|
||||
if err != nil || port <= 0 || port >= 65536 {
|
||||
return nil, fmt.Errorf("invalid port: %s", matches[2])
|
||||
}
|
||||
|
||||
protocol := inferProtocol(protocolHint, port)
|
||||
return &model.Proxy{
|
||||
Protocol: protocol,
|
||||
Host: matches[1],
|
||||
Port: port,
|
||||
Username: matches[3],
|
||||
Password: matches[4],
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 尝试解析 user:pass@host:port 格式
|
||||
if matches := userPassHostPortRegex.FindStringSubmatch(raw); matches != nil {
|
||||
port, err := strconv.Atoi(matches[4])
|
||||
if err != nil || port <= 0 || port >= 65536 {
|
||||
return nil, fmt.Errorf("invalid port: %s", matches[4])
|
||||
}
|
||||
|
||||
protocol := inferProtocol(protocolHint, port)
|
||||
return &model.Proxy{
|
||||
Protocol: protocol,
|
||||
Host: matches[3],
|
||||
Port: port,
|
||||
Username: matches[1],
|
||||
Password: matches[2],
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 尝试解析 host:port 格式
|
||||
if matches := hostPortRegex.FindStringSubmatch(raw); matches != nil {
|
||||
port, err := strconv.Atoi(matches[2])
|
||||
if err != nil || port <= 0 || port >= 65536 {
|
||||
return nil, fmt.Errorf("invalid port: %s", matches[2])
|
||||
}
|
||||
|
||||
protocol := inferProtocol(protocolHint, port)
|
||||
return &model.Proxy{
|
||||
Protocol: protocol,
|
||||
Host: matches[1],
|
||||
Port: port,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unrecognized format")
|
||||
}
|
||||
|
||||
// parseAsURL 解析 URL 格式的代理
|
||||
func parseAsURL(raw string) (*model.Proxy, error) {
|
||||
u, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
|
||||
var protocol model.ProxyProtocol
|
||||
switch strings.ToLower(u.Scheme) {
|
||||
case "http":
|
||||
protocol = model.ProtoHTTP
|
||||
case "https":
|
||||
protocol = model.ProtoHTTPS
|
||||
case "socks5":
|
||||
protocol = model.ProtoSOCKS5
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported protocol: %s", u.Scheme)
|
||||
}
|
||||
|
||||
host := u.Hostname()
|
||||
if host == "" {
|
||||
return nil, fmt.Errorf("missing host")
|
||||
}
|
||||
|
||||
portStr := u.Port()
|
||||
if portStr == "" {
|
||||
// 默认端口
|
||||
switch protocol {
|
||||
case model.ProtoHTTP:
|
||||
portStr = "80"
|
||||
case model.ProtoHTTPS:
|
||||
portStr = "443"
|
||||
case model.ProtoSOCKS5:
|
||||
portStr = "1080"
|
||||
}
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil || port <= 0 || port >= 65536 {
|
||||
return nil, fmt.Errorf("invalid port: %s", portStr)
|
||||
}
|
||||
|
||||
var username, password string
|
||||
if u.User != nil {
|
||||
username = u.User.Username()
|
||||
password, _ = u.User.Password()
|
||||
}
|
||||
|
||||
return &model.Proxy{
|
||||
Protocol: protocol,
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// inferProtocol 根据提示和端口推断协议
|
||||
func inferProtocol(hint string, port int) model.ProxyProtocol {
|
||||
switch strings.ToLower(hint) {
|
||||
case "http":
|
||||
return model.ProtoHTTP
|
||||
case "https":
|
||||
return model.ProtoHTTPS
|
||||
case "socks5":
|
||||
return model.ProtoSOCKS5
|
||||
}
|
||||
|
||||
// 根据端口推断
|
||||
switch port {
|
||||
case 443:
|
||||
return model.ProtoHTTPS
|
||||
case 1080:
|
||||
return model.ProtoSOCKS5
|
||||
default:
|
||||
return model.ProtoHTTP
|
||||
}
|
||||
}
|
||||
16
internal/model/errors.go
Normal file
16
internal/model/errors.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package model
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrNoProxy = errors.New("no available proxy")
|
||||
ErrBadModulo = errors.New("modulo must be positive")
|
||||
ErrBadPolicy = errors.New("unknown selection policy")
|
||||
ErrLeaseExpired = errors.New("lease expired or not found")
|
||||
ErrProxyNotFound = errors.New("proxy not found")
|
||||
ErrInvalidURL = errors.New("invalid URL")
|
||||
ErrPrivateIP = errors.New("private IP address not allowed")
|
||||
ErrUnsafeScheme = errors.New("only http and https schemes are allowed")
|
||||
ErrInvalidPatch = errors.New("invalid patch: no fields to update")
|
||||
ErrBulkDeleteEmpty = errors.New("bulk delete requires at least one condition")
|
||||
)
|
||||
258
internal/model/types.go
Normal file
258
internal/model/types.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// ProxyProtocol 代理协议类型
|
||||
type ProxyProtocol string
|
||||
|
||||
const (
|
||||
ProtoHTTP ProxyProtocol = "http"
|
||||
ProtoHTTPS ProxyProtocol = "https"
|
||||
ProtoSOCKS5 ProxyProtocol = "socks5"
|
||||
)
|
||||
|
||||
// ProxyStatus 代理状态
|
||||
type ProxyStatus string
|
||||
|
||||
const (
|
||||
StatusUnknown ProxyStatus = "unknown"
|
||||
StatusAlive ProxyStatus = "alive"
|
||||
StatusDead ProxyStatus = "dead"
|
||||
)
|
||||
|
||||
// Proxy 代理实体
|
||||
type Proxy struct {
|
||||
ID string // uuid
|
||||
|
||||
Protocol ProxyProtocol
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
|
||||
Group string
|
||||
Tags []string
|
||||
|
||||
Status ProxyStatus
|
||||
Score int
|
||||
LatencyMs int64
|
||||
LastCheckAt time.Time
|
||||
|
||||
FailCount int
|
||||
SuccessCount int
|
||||
Disabled bool
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// HealthPatch 健康度更新补丁
|
||||
type HealthPatch struct {
|
||||
Status *ProxyStatus
|
||||
ScoreDelta int
|
||||
LatencyMs *int64
|
||||
CheckedAt *time.Time
|
||||
FailInc int
|
||||
SuccessInc int
|
||||
}
|
||||
|
||||
// TestSpec 测试规格
|
||||
type TestSpec struct {
|
||||
URL string
|
||||
Method string
|
||||
Timeout time.Duration
|
||||
ExpectStatus []int
|
||||
ExpectContains string
|
||||
}
|
||||
|
||||
// TestResult 测试结果
|
||||
type TestResult struct {
|
||||
ProxyID string
|
||||
OK bool
|
||||
LatencyMs int64
|
||||
ErrorText string
|
||||
CheckedAt time.Time
|
||||
}
|
||||
|
||||
// Lease 租约
|
||||
type Lease struct {
|
||||
LeaseID string
|
||||
ProxyID string
|
||||
Proxy Proxy
|
||||
ExpireAt time.Time
|
||||
Group string
|
||||
Site string
|
||||
}
|
||||
|
||||
// ProxyQuery 代理查询条件
|
||||
type ProxyQuery struct {
|
||||
Group string
|
||||
TagsAny []string
|
||||
StatusIn []ProxyStatus
|
||||
OnlyEnabled bool
|
||||
Limit int
|
||||
OrderBy string // "random", "score", "latency",默认按 score 降序
|
||||
}
|
||||
|
||||
// InvalidLine 无效行记录
|
||||
type InvalidLine struct {
|
||||
Raw string `json:"raw"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
// SelectRequest 代理选择请求
|
||||
type SelectRequest struct {
|
||||
Group string
|
||||
Site string
|
||||
Policy string // round_robin, random, weighted
|
||||
TagsAny []string
|
||||
}
|
||||
|
||||
// ImportInput 导入输入参数
|
||||
type ImportInput struct {
|
||||
Group string
|
||||
Tags []string
|
||||
ProtocolHint string // auto, http, https, socks5
|
||||
}
|
||||
|
||||
// ImportResult 导入结果
|
||||
type ImportResult struct {
|
||||
Imported int `json:"imported"`
|
||||
Duplicated int `json:"duplicated"`
|
||||
Invalid int `json:"invalid"`
|
||||
InvalidItems []InvalidLine `json:"invalid_items,omitempty"`
|
||||
}
|
||||
|
||||
// TestSummary 测试摘要
|
||||
type TestSummary struct {
|
||||
Tested int `json:"tested"`
|
||||
Alive int `json:"alive"`
|
||||
Dead int `json:"dead"`
|
||||
}
|
||||
|
||||
// TestBatchResult 批量测试结果
|
||||
type TestBatchResult struct {
|
||||
Summary TestSummary `json:"summary"`
|
||||
Results []TestResult `json:"results"`
|
||||
}
|
||||
|
||||
// NextProxyResponse 获取下一个代理的响应
|
||||
type NextProxyResponse struct {
|
||||
Proxy ProxyInfo `json:"proxy"`
|
||||
LeaseID string `json:"lease_id"`
|
||||
TTLMs int64 `json:"ttl_ms"`
|
||||
}
|
||||
|
||||
// ProxyInfo 代理信息(用于 API 响应)
|
||||
type ProxyInfo struct {
|
||||
ID string `json:"id"`
|
||||
Protocol ProxyProtocol `json:"protocol"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Username string `json:"username,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
}
|
||||
|
||||
// ReportRequest 上报请求
|
||||
type ReportRequest struct {
|
||||
LeaseID string `json:"lease_id"`
|
||||
ProxyID string `json:"proxy_id"`
|
||||
Success bool `json:"success"`
|
||||
Error string `json:"error,omitempty"`
|
||||
LatencyMs int64 `json:"latency_ms"`
|
||||
}
|
||||
|
||||
// TestRequest 测试请求
|
||||
type TestRequest struct {
|
||||
Group string `json:"group"`
|
||||
Filter ProxyFilter `json:"filter"`
|
||||
TestSpec TestSpecReq `json:"test_spec"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
UpdateStore bool `json:"update_store"`
|
||||
WriteLog bool `json:"write_log"`
|
||||
}
|
||||
|
||||
// ProxyFilter 代理过滤条件
|
||||
type ProxyFilter struct {
|
||||
Status []string `json:"status"`
|
||||
TagsAny []string `json:"tags_any"`
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
|
||||
// TestSpecReq 测试规格请求
|
||||
type TestSpecReq struct {
|
||||
URL string `json:"url"`
|
||||
Method string `json:"method"`
|
||||
TimeoutMs int `json:"timeout_ms"`
|
||||
ExpectStatus []int `json:"expect_status"`
|
||||
ExpectContains string `json:"expect_contains"`
|
||||
}
|
||||
|
||||
// ImportTextRequest 文本导入请求
|
||||
type ImportTextRequest struct {
|
||||
Group string `json:"group"`
|
||||
Tags []string `json:"tags"`
|
||||
ProtocolHint string `json:"protocol_hint"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// ProxyListQuery 代理列表查询条件(带分页)
|
||||
type ProxyListQuery struct {
|
||||
Group string
|
||||
TagsAny []string
|
||||
StatusIn []ProxyStatus
|
||||
OnlyEnabled bool
|
||||
Offset int
|
||||
Limit int
|
||||
}
|
||||
|
||||
// ProxyListResponse 代理列表分页响应
|
||||
type ProxyListResponse struct {
|
||||
Data []Proxy `json:"data"`
|
||||
Total int `json:"total"`
|
||||
Offset int `json:"offset"`
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
|
||||
// ProxyPatch 代理更新补丁
|
||||
type ProxyPatch struct {
|
||||
Group *string `json:"group,omitempty"`
|
||||
Tags *[]string `json:"tags,omitempty"`
|
||||
AddTags []string `json:"add_tags,omitempty"`
|
||||
Disabled *bool `json:"disabled,omitempty"`
|
||||
}
|
||||
|
||||
// BulkDeleteRequest 批量删除请求
|
||||
type BulkDeleteRequest struct {
|
||||
IDs []string `json:"ids,omitempty"`
|
||||
Status ProxyStatus `json:"status,omitempty"`
|
||||
Group string `json:"group,omitempty"`
|
||||
Disabled *bool `json:"disabled,omitempty"`
|
||||
}
|
||||
|
||||
// BulkDeleteResponse 批量删除响应
|
||||
type BulkDeleteResponse struct {
|
||||
Deleted int `json:"deleted"`
|
||||
}
|
||||
|
||||
// SingleTestRequest 单个代理测试请求
|
||||
type SingleTestRequest struct {
|
||||
URL string `json:"url"`
|
||||
Method string `json:"method,omitempty"`
|
||||
TimeoutMs int `json:"timeout_ms,omitempty"`
|
||||
ExpectStatus []int `json:"expect_status,omitempty"`
|
||||
ExpectContains string `json:"expect_contains,omitempty"`
|
||||
UpdateStore bool `json:"update_store"`
|
||||
WriteLog bool `json:"write_log"`
|
||||
}
|
||||
|
||||
// ProxyStats 代理统计信息
|
||||
type ProxyStats struct {
|
||||
Total int `json:"total"`
|
||||
ByStatus map[ProxyStatus]int `json:"by_status"`
|
||||
ByGroup map[string]int `json:"by_group"`
|
||||
ByProtocol map[ProxyProtocol]int `json:"by_protocol"`
|
||||
Disabled int `json:"disabled"`
|
||||
AvgLatencyMs int64 `json:"avg_latency_ms"`
|
||||
AvgScore float64 `json:"avg_score"`
|
||||
}
|
||||
71
internal/security/validate_url.go
Normal file
71
internal/security/validate_url.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/url"
|
||||
|
||||
"proxyrotator/internal/model"
|
||||
)
|
||||
|
||||
// ValidateTestURL 校验测试目标 URL,防止 SSRF
|
||||
func ValidateTestURL(rawURL string) error {
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return model.ErrInvalidURL
|
||||
}
|
||||
|
||||
// 只允许 http/https
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
return model.ErrUnsafeScheme
|
||||
}
|
||||
|
||||
// 解析主机名
|
||||
host := u.Hostname()
|
||||
if host == "" {
|
||||
return model.ErrInvalidURL
|
||||
}
|
||||
|
||||
// 解析 IP 地址
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return model.ErrInvalidURL
|
||||
}
|
||||
|
||||
// 检查是否为私网 IP
|
||||
for _, ip := range ips {
|
||||
if IsPrivateIP(ip) {
|
||||
return model.ErrPrivateIP
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsPrivateIP 判断是否为私网 IP
|
||||
func IsPrivateIP(ip net.IP) bool {
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 回环地址
|
||||
if ip.IsLoopback() {
|
||||
return true
|
||||
}
|
||||
|
||||
// 链路本地地址
|
||||
if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
return true
|
||||
}
|
||||
|
||||
// 私有地址
|
||||
if ip.IsPrivate() {
|
||||
return true
|
||||
}
|
||||
|
||||
// 未指定地址
|
||||
if ip.IsUnspecified() {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
177
internal/selector/selector.go
Normal file
177
internal/selector/selector.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package selector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"math/big"
|
||||
mathrand "math/rand"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"proxyrotator/internal/model"
|
||||
"proxyrotator/internal/store"
|
||||
)
|
||||
|
||||
// Selector 代理选择器
|
||||
type Selector struct {
|
||||
store store.ProxyStore
|
||||
leaseTTL time.Duration
|
||||
}
|
||||
|
||||
// NewSelector 创建选择器
|
||||
func NewSelector(store store.ProxyStore, leaseTTL time.Duration) *Selector {
|
||||
if leaseTTL <= 0 {
|
||||
leaseTTL = 60 * time.Second
|
||||
}
|
||||
return &Selector{
|
||||
store: store,
|
||||
leaseTTL: leaseTTL,
|
||||
}
|
||||
}
|
||||
|
||||
// Next 获取下一个可用代理
|
||||
func (s *Selector) Next(ctx context.Context, req model.SelectRequest) (*model.Lease, error) {
|
||||
policy := req.Policy
|
||||
if policy == "" {
|
||||
policy = "round_robin"
|
||||
}
|
||||
|
||||
// 查询可用代理
|
||||
proxies, err := s.store.List(ctx, model.ProxyQuery{
|
||||
Group: req.Group,
|
||||
TagsAny: req.TagsAny,
|
||||
StatusIn: []model.ProxyStatus{model.StatusAlive},
|
||||
OnlyEnabled: true,
|
||||
Limit: 5000,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(proxies) == 0 {
|
||||
return nil, model.ErrNoProxy
|
||||
}
|
||||
|
||||
// 根据策略选择
|
||||
var chosen model.Proxy
|
||||
switch policy {
|
||||
case "round_robin":
|
||||
key := "rr:" + req.Group + ":" + normalizeSite(req.Site)
|
||||
idx, err := s.store.NextIndex(ctx, key, len(proxies))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chosen = proxies[idx]
|
||||
|
||||
case "random":
|
||||
idx := mathrand.Intn(len(proxies))
|
||||
chosen = proxies[idx]
|
||||
|
||||
case "weighted":
|
||||
chosen = weightedPickByScore(proxies)
|
||||
|
||||
default:
|
||||
return nil, model.ErrBadPolicy
|
||||
}
|
||||
|
||||
// 创建租约
|
||||
lease := model.Lease{
|
||||
LeaseID: newLeaseID(),
|
||||
ProxyID: chosen.ID,
|
||||
Proxy: chosen,
|
||||
Group: req.Group,
|
||||
Site: req.Site,
|
||||
ExpireAt: time.Now().Add(s.leaseTTL),
|
||||
}
|
||||
|
||||
// 尝试保存租约(失败也可降级不存)
|
||||
_ = s.store.CreateLease(ctx, lease)
|
||||
|
||||
return &lease, nil
|
||||
}
|
||||
|
||||
// Report 上报使用结果
|
||||
func (s *Selector) Report(ctx context.Context, leaseID, proxyID string, success bool, latencyMs int64, errText string) error {
|
||||
now := time.Now()
|
||||
|
||||
if success {
|
||||
status := model.StatusAlive
|
||||
return s.store.UpdateHealth(ctx, proxyID, model.HealthPatch{
|
||||
Status: &status,
|
||||
ScoreDelta: 1,
|
||||
SuccessInc: 1,
|
||||
LatencyMs: &latencyMs,
|
||||
CheckedAt: &now,
|
||||
})
|
||||
}
|
||||
|
||||
status := model.StatusDead
|
||||
return s.store.UpdateHealth(ctx, proxyID, model.HealthPatch{
|
||||
Status: &status,
|
||||
ScoreDelta: -3,
|
||||
FailInc: 1,
|
||||
CheckedAt: &now,
|
||||
})
|
||||
}
|
||||
|
||||
// normalizeSite 规范化站点 URL(提取域名)
|
||||
func normalizeSite(site string) string {
|
||||
if site == "" {
|
||||
return "default"
|
||||
}
|
||||
|
||||
u, err := url.Parse(site)
|
||||
if err != nil {
|
||||
return site
|
||||
}
|
||||
|
||||
host := u.Hostname()
|
||||
if host == "" {
|
||||
return site
|
||||
}
|
||||
|
||||
// 去除 www 前缀
|
||||
host = strings.TrimPrefix(host, "www.")
|
||||
return host
|
||||
}
|
||||
|
||||
// weightedPickByScore 按分数加权随机选择
|
||||
func weightedPickByScore(proxies []model.Proxy) model.Proxy {
|
||||
// 计算权重(分数 + 偏移量确保正数)
|
||||
const offset = 100
|
||||
totalWeight := 0
|
||||
weights := make([]int, len(proxies))
|
||||
|
||||
for i, p := range proxies {
|
||||
w := p.Score + offset
|
||||
if w < 1 {
|
||||
w = 1
|
||||
}
|
||||
weights[i] = w
|
||||
totalWeight += w
|
||||
}
|
||||
|
||||
// 随机选择
|
||||
r := mathrand.Intn(totalWeight)
|
||||
for i, w := range weights {
|
||||
r -= w
|
||||
if r < 0 {
|
||||
return proxies[i]
|
||||
}
|
||||
}
|
||||
|
||||
return proxies[0]
|
||||
}
|
||||
|
||||
// newLeaseID 生成租约 ID
|
||||
func newLeaseID() string {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// fallback
|
||||
n, _ := rand.Int(rand.Reader, big.NewInt(1<<62))
|
||||
return "lease_" + n.String()
|
||||
}
|
||||
return "lease_" + hex.EncodeToString(b)
|
||||
}
|
||||
698
internal/store/pg_store.go
Normal file
698
internal/store/pg_store.go
Normal file
@@ -0,0 +1,698 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"proxyrotator/internal/model"
|
||||
)
|
||||
|
||||
// PgStore PostgreSQL 存储实现
|
||||
type PgStore struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
// NewPgStore 创建 PostgreSQL 存储
|
||||
func NewPgStore(ctx context.Context, connString string) (*PgStore, error) {
|
||||
pool, err := pgxpool.New(ctx, connString)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create connection pool: %w", err)
|
||||
}
|
||||
|
||||
if err := pool.Ping(ctx); err != nil {
|
||||
pool.Close()
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
return &PgStore{pool: pool}, nil
|
||||
}
|
||||
|
||||
// Close 关闭连接池
|
||||
func (s *PgStore) Close() error {
|
||||
s.pool.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pool 返回连接池(用于其他模块共享)
|
||||
func (s *PgStore) Pool() *pgxpool.Pool {
|
||||
return s.pool
|
||||
}
|
||||
|
||||
// UpsertMany 批量导入代理
|
||||
func (s *PgStore) UpsertMany(ctx context.Context, proxies []model.Proxy) (imported, duplicated int, err error) {
|
||||
if len(proxies) == 0 {
|
||||
return 0, 0, nil
|
||||
}
|
||||
|
||||
const sqlUpsert = `
|
||||
INSERT INTO proxies (id, protocol, host, port, username, password, "group", tags)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
ON CONFLICT (protocol, host, port, username)
|
||||
DO UPDATE SET
|
||||
password = EXCLUDED.password,
|
||||
"group" = EXCLUDED."group",
|
||||
tags = (
|
||||
SELECT ARRAY(
|
||||
SELECT DISTINCT unnest(proxies.tags || EXCLUDED.tags)
|
||||
)
|
||||
)
|
||||
RETURNING (xmax = 0) AS inserted
|
||||
`
|
||||
|
||||
tx, err := s.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
for _, p := range proxies {
|
||||
var inserted bool
|
||||
err := tx.QueryRow(ctx, sqlUpsert,
|
||||
p.ID, p.Protocol, p.Host, p.Port, p.Username, p.Password, p.Group, p.Tags,
|
||||
).Scan(&inserted)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to upsert proxy: %w", err)
|
||||
}
|
||||
if inserted {
|
||||
imported++
|
||||
} else {
|
||||
duplicated++
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return 0, 0, fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
||||
return imported, duplicated, nil
|
||||
}
|
||||
|
||||
// List 查询代理列表
|
||||
func (s *PgStore) List(ctx context.Context, q model.ProxyQuery) ([]model.Proxy, error) {
|
||||
var conditions []string
|
||||
var args []any
|
||||
argIdx := 1
|
||||
|
||||
if q.Group != "" {
|
||||
conditions = append(conditions, fmt.Sprintf(`"group" = $%d`, argIdx))
|
||||
args = append(args, q.Group)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if q.OnlyEnabled {
|
||||
conditions = append(conditions, "disabled = false")
|
||||
}
|
||||
|
||||
if len(q.StatusIn) > 0 {
|
||||
placeholders := make([]string, len(q.StatusIn))
|
||||
for i, status := range q.StatusIn {
|
||||
placeholders[i] = fmt.Sprintf("$%d", argIdx)
|
||||
args = append(args, status)
|
||||
argIdx++
|
||||
}
|
||||
conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ",")))
|
||||
}
|
||||
|
||||
if len(q.TagsAny) > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("tags && $%d", argIdx))
|
||||
args = append(args, q.TagsAny)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
sql := `SELECT id, protocol, host, port, username, password, "group", tags,
|
||||
status, score, latency_ms, last_check_at, fail_count, success_count,
|
||||
disabled, created_at, updated_at
|
||||
FROM proxies`
|
||||
|
||||
if len(conditions) > 0 {
|
||||
sql += " WHERE " + strings.Join(conditions, " AND ")
|
||||
}
|
||||
|
||||
// 排序方式
|
||||
switch q.OrderBy {
|
||||
case "random":
|
||||
sql += " ORDER BY RANDOM()"
|
||||
case "latency":
|
||||
sql += " ORDER BY latency_ms ASC NULLS LAST"
|
||||
default:
|
||||
sql += " ORDER BY score DESC, last_check_at DESC NULLS LAST"
|
||||
}
|
||||
|
||||
if q.Limit > 0 {
|
||||
sql += fmt.Sprintf(" LIMIT %d", q.Limit)
|
||||
}
|
||||
|
||||
rows, err := s.pool.Query(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query proxies: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanProxies(rows)
|
||||
}
|
||||
|
||||
// GetByID 根据 ID 获取代理
|
||||
func (s *PgStore) GetByID(ctx context.Context, id string) (*model.Proxy, error) {
|
||||
const sql = `SELECT id, protocol, host, port, username, password, "group", tags,
|
||||
status, score, latency_ms, last_check_at, fail_count, success_count,
|
||||
disabled, created_at, updated_at
|
||||
FROM proxies WHERE id = $1`
|
||||
|
||||
row := s.pool.QueryRow(ctx, sql, id)
|
||||
p, err := scanProxy(row)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return nil, model.ErrProxyNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get proxy: %w", err)
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// UpdateHealth 更新代理健康度
|
||||
func (s *PgStore) UpdateHealth(ctx context.Context, proxyID string, patch model.HealthPatch) error {
|
||||
var sets []string
|
||||
var args []any
|
||||
argIdx := 1
|
||||
|
||||
if patch.Status != nil {
|
||||
sets = append(sets, fmt.Sprintf("status = $%d", argIdx))
|
||||
args = append(args, *patch.Status)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if patch.ScoreDelta != 0 {
|
||||
sets = append(sets, fmt.Sprintf("score = GREATEST(-1000, LEAST(1000, score + $%d))", argIdx))
|
||||
args = append(args, patch.ScoreDelta)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if patch.LatencyMs != nil {
|
||||
sets = append(sets, fmt.Sprintf("latency_ms = $%d", argIdx))
|
||||
args = append(args, *patch.LatencyMs)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if patch.CheckedAt != nil {
|
||||
sets = append(sets, fmt.Sprintf("last_check_at = $%d", argIdx))
|
||||
args = append(args, *patch.CheckedAt)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if patch.FailInc > 0 {
|
||||
sets = append(sets, fmt.Sprintf("fail_count = fail_count + $%d", argIdx))
|
||||
args = append(args, patch.FailInc)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if patch.SuccessInc > 0 {
|
||||
sets = append(sets, fmt.Sprintf("success_count = success_count + $%d", argIdx))
|
||||
args = append(args, patch.SuccessInc)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if len(sets) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sql := fmt.Sprintf("UPDATE proxies SET %s WHERE id = $%d", strings.Join(sets, ", "), argIdx)
|
||||
args = append(args, proxyID)
|
||||
|
||||
_, err := s.pool.Exec(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update health: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NextIndex RR 原子游标
|
||||
func (s *PgStore) NextIndex(ctx context.Context, key string, modulo int) (int, error) {
|
||||
if modulo <= 0 {
|
||||
return 0, model.ErrBadModulo
|
||||
}
|
||||
|
||||
const sql = `
|
||||
INSERT INTO rr_cursors (k, v)
|
||||
VALUES ($1, 0)
|
||||
ON CONFLICT (k)
|
||||
DO UPDATE SET v = rr_cursors.v + 1, updated_at = now()
|
||||
RETURNING v
|
||||
`
|
||||
|
||||
var v int64
|
||||
err := s.pool.QueryRow(ctx, sql, key).Scan(&v)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get next index: %w", err)
|
||||
}
|
||||
|
||||
idx := int(v % int64(modulo))
|
||||
if idx < 0 {
|
||||
idx = -idx
|
||||
}
|
||||
return idx, nil
|
||||
}
|
||||
|
||||
// CreateLease 创建租约
|
||||
func (s *PgStore) CreateLease(ctx context.Context, lease model.Lease) error {
|
||||
const sql = `
|
||||
INSERT INTO proxy_leases (lease_id, proxy_id, expire_at, site, "group")
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
`
|
||||
|
||||
_, err := s.pool.Exec(ctx, sql,
|
||||
lease.LeaseID, lease.ProxyID, lease.ExpireAt, lease.Site, lease.Group,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create lease: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLease 获取租约
|
||||
func (s *PgStore) GetLease(ctx context.Context, leaseID string) (*model.Lease, error) {
|
||||
const sql = `
|
||||
SELECT lease_id, proxy_id, expire_at, site, "group"
|
||||
FROM proxy_leases
|
||||
WHERE lease_id = $1 AND expire_at > now()
|
||||
`
|
||||
|
||||
var lease model.Lease
|
||||
err := s.pool.QueryRow(ctx, sql, leaseID).Scan(
|
||||
&lease.LeaseID, &lease.ProxyID, &lease.ExpireAt, &lease.Site, &lease.Group,
|
||||
)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return nil, model.ErrLeaseExpired
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get lease: %w", err)
|
||||
}
|
||||
|
||||
return &lease, nil
|
||||
}
|
||||
|
||||
// DeleteExpiredLeases 删除过期租约
|
||||
func (s *PgStore) DeleteExpiredLeases(ctx context.Context) (int64, error) {
|
||||
const sql = `DELETE FROM proxy_leases WHERE expire_at <= now()`
|
||||
|
||||
result, err := s.pool.Exec(ctx, sql)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to delete expired leases: %w", err)
|
||||
}
|
||||
|
||||
return result.RowsAffected(), nil
|
||||
}
|
||||
|
||||
// InsertTestLog 插入测试日志
|
||||
func (s *PgStore) InsertTestLog(ctx context.Context, r model.TestResult, site string) error {
|
||||
const sql = `
|
||||
INSERT INTO proxy_test_logs (proxy_id, site, ok, latency_ms, error_text, checked_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
`
|
||||
|
||||
_, err := s.pool.Exec(ctx, sql,
|
||||
r.ProxyID, site, r.OK, r.LatencyMs, r.ErrorText, r.CheckedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert test log: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListPaginated 分页查询代理列表
|
||||
func (s *PgStore) ListPaginated(ctx context.Context, q model.ProxyListQuery) ([]model.Proxy, int, error) {
|
||||
var conditions []string
|
||||
var args []any
|
||||
argIdx := 1
|
||||
|
||||
if q.Group != "" {
|
||||
conditions = append(conditions, fmt.Sprintf(`"group" = $%d`, argIdx))
|
||||
args = append(args, q.Group)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if q.OnlyEnabled {
|
||||
conditions = append(conditions, "disabled = false")
|
||||
}
|
||||
|
||||
if len(q.StatusIn) > 0 {
|
||||
placeholders := make([]string, len(q.StatusIn))
|
||||
for i, status := range q.StatusIn {
|
||||
placeholders[i] = fmt.Sprintf("$%d", argIdx)
|
||||
args = append(args, status)
|
||||
argIdx++
|
||||
}
|
||||
conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ",")))
|
||||
}
|
||||
|
||||
if len(q.TagsAny) > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("tags && $%d", argIdx))
|
||||
args = append(args, q.TagsAny)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
whereClause := ""
|
||||
if len(conditions) > 0 {
|
||||
whereClause = " WHERE " + strings.Join(conditions, " AND ")
|
||||
}
|
||||
|
||||
// 查询总数
|
||||
countSQL := "SELECT COUNT(*) FROM proxies" + whereClause
|
||||
var total int
|
||||
if err := s.pool.QueryRow(ctx, countSQL, args...).Scan(&total); err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to count proxies: %w", err)
|
||||
}
|
||||
|
||||
// 查询数据
|
||||
dataSQL := `SELECT id, protocol, host, port, username, password, "group", tags,
|
||||
status, score, latency_ms, last_check_at, fail_count, success_count,
|
||||
disabled, created_at, updated_at FROM proxies` + whereClause +
|
||||
" ORDER BY score DESC, last_check_at DESC NULLS LAST"
|
||||
|
||||
dataSQL += fmt.Sprintf(" LIMIT $%d OFFSET $%d", argIdx, argIdx+1)
|
||||
args = append(args, q.Limit, q.Offset)
|
||||
|
||||
rows, err := s.pool.Query(ctx, dataSQL, args...)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to query proxies: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
proxies, err := scanProxies(rows)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return proxies, total, nil
|
||||
}
|
||||
|
||||
// Update 更新代理字段
|
||||
func (s *PgStore) Update(ctx context.Context, id string, patch model.ProxyPatch) error {
|
||||
var sets []string
|
||||
var args []any
|
||||
argIdx := 1
|
||||
|
||||
if patch.Group != nil {
|
||||
sets = append(sets, fmt.Sprintf(`"group" = $%d`, argIdx))
|
||||
args = append(args, *patch.Group)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if patch.Tags != nil {
|
||||
sets = append(sets, fmt.Sprintf("tags = $%d", argIdx))
|
||||
args = append(args, *patch.Tags)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if len(patch.AddTags) > 0 {
|
||||
sets = append(sets, fmt.Sprintf(`tags = (
|
||||
SELECT ARRAY(SELECT DISTINCT unnest(tags || $%d))
|
||||
)`, argIdx))
|
||||
args = append(args, patch.AddTags)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if patch.Disabled != nil {
|
||||
sets = append(sets, fmt.Sprintf("disabled = $%d", argIdx))
|
||||
args = append(args, *patch.Disabled)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if len(sets) == 0 {
|
||||
return model.ErrInvalidPatch
|
||||
}
|
||||
|
||||
sql := fmt.Sprintf("UPDATE proxies SET %s WHERE id = $%d", strings.Join(sets, ", "), argIdx)
|
||||
args = append(args, id)
|
||||
|
||||
result, err := s.pool.Exec(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update proxy: %w", err)
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
return model.ErrProxyNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除单个代理
|
||||
func (s *PgStore) Delete(ctx context.Context, id string) error {
|
||||
tx, err := s.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
// 删除租约
|
||||
if _, err := tx.Exec(ctx, "DELETE FROM proxy_leases WHERE proxy_id = $1", id); err != nil {
|
||||
return fmt.Errorf("failed to delete leases: %w", err)
|
||||
}
|
||||
|
||||
// 删除测试日志
|
||||
if _, err := tx.Exec(ctx, "DELETE FROM proxy_test_logs WHERE proxy_id = $1", id); err != nil {
|
||||
return fmt.Errorf("failed to delete test logs: %w", err)
|
||||
}
|
||||
|
||||
// 删除代理
|
||||
result, err := tx.Exec(ctx, "DELETE FROM proxies WHERE id = $1", id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete proxy: %w", err)
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
return model.ErrProxyNotFound
|
||||
}
|
||||
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
// DeleteMany 批量删除代理
|
||||
func (s *PgStore) DeleteMany(ctx context.Context, req model.BulkDeleteRequest) (int64, error) {
|
||||
var conditions []string
|
||||
var args []any
|
||||
argIdx := 1
|
||||
|
||||
if len(req.IDs) > 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("id = ANY($%d)", argIdx))
|
||||
args = append(args, req.IDs)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if req.Status != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("status = $%d", argIdx))
|
||||
args = append(args, req.Status)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if req.Group != "" {
|
||||
conditions = append(conditions, fmt.Sprintf(`"group" = $%d`, argIdx))
|
||||
args = append(args, req.Group)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if req.Disabled != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("disabled = $%d", argIdx))
|
||||
args = append(args, *req.Disabled)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if len(conditions) == 0 {
|
||||
return 0, model.ErrBulkDeleteEmpty
|
||||
}
|
||||
|
||||
whereClause := strings.Join(conditions, " AND ")
|
||||
|
||||
tx, err := s.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
// 先获取要删除的代理 ID 列表
|
||||
selectSQL := "SELECT id FROM proxies WHERE " + whereClause
|
||||
rows, err := tx.Query(ctx, selectSQL, args...)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to select proxies: %w", err)
|
||||
}
|
||||
|
||||
var ids []string
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
rows.Close()
|
||||
return 0, err
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// 删除关联数据
|
||||
if _, err := tx.Exec(ctx, "DELETE FROM proxy_leases WHERE proxy_id = ANY($1)", ids); err != nil {
|
||||
return 0, fmt.Errorf("failed to delete leases: %w", err)
|
||||
}
|
||||
|
||||
if _, err := tx.Exec(ctx, "DELETE FROM proxy_test_logs WHERE proxy_id = ANY($1)", ids); err != nil {
|
||||
return 0, fmt.Errorf("failed to delete test logs: %w", err)
|
||||
}
|
||||
|
||||
// 删除代理
|
||||
result, err := tx.Exec(ctx, "DELETE FROM proxies WHERE id = ANY($1)", ids)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to delete proxies: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return 0, fmt.Errorf("failed to commit: %w", err)
|
||||
}
|
||||
|
||||
return result.RowsAffected(), nil
|
||||
}
|
||||
|
||||
// GetStats 获取代理统计信息
|
||||
func (s *PgStore) GetStats(ctx context.Context) (*model.ProxyStats, error) {
|
||||
stats := &model.ProxyStats{
|
||||
ByStatus: make(map[model.ProxyStatus]int),
|
||||
ByGroup: make(map[string]int),
|
||||
ByProtocol: make(map[model.ProxyProtocol]int),
|
||||
}
|
||||
|
||||
// 总数和禁用数
|
||||
const sqlBasic = `
|
||||
SELECT
|
||||
COUNT(*) AS total,
|
||||
COUNT(*) FILTER (WHERE disabled = true) AS disabled,
|
||||
COALESCE(ROUND(AVG(latency_ms) FILTER (WHERE status = 'alive' AND latency_ms > 0)), 0)::BIGINT AS avg_latency,
|
||||
COALESCE(AVG(score), 0)::DOUBLE PRECISION AS avg_score
|
||||
FROM proxies;
|
||||
`
|
||||
if err := s.pool.QueryRow(ctx, sqlBasic).Scan(
|
||||
&stats.Total, &stats.Disabled, &stats.AvgLatencyMs, &stats.AvgScore,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed to get basic stats: %w", err)
|
||||
}
|
||||
|
||||
// 按状态统计
|
||||
const sqlByStatus = `SELECT status, COUNT(*) FROM proxies GROUP BY status`
|
||||
rows, err := s.pool.Query(ctx, sqlByStatus)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get status stats: %w", err)
|
||||
}
|
||||
for rows.Next() {
|
||||
var status model.ProxyStatus
|
||||
var count int
|
||||
if err := rows.Scan(&status, &count); err != nil {
|
||||
rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
stats.ByStatus[status] = count
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
// 按分组统计
|
||||
const sqlByGroup = `SELECT "group", COUNT(*) FROM proxies GROUP BY "group"`
|
||||
rows, err = s.pool.Query(ctx, sqlByGroup)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get group stats: %w", err)
|
||||
}
|
||||
for rows.Next() {
|
||||
var group string
|
||||
var count int
|
||||
if err := rows.Scan(&group, &count); err != nil {
|
||||
rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
stats.ByGroup[group] = count
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
// 按协议统计
|
||||
const sqlByProtocol = `SELECT protocol, COUNT(*) FROM proxies GROUP BY protocol`
|
||||
rows, err = s.pool.Query(ctx, sqlByProtocol)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get protocol stats: %w", err)
|
||||
}
|
||||
for rows.Next() {
|
||||
var protocol model.ProxyProtocol
|
||||
var count int
|
||||
if err := rows.Scan(&protocol, &count); err != nil {
|
||||
rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
stats.ByProtocol[protocol] = count
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// scanProxies 扫描多条代理记录
|
||||
func scanProxies(rows pgx.Rows) ([]model.Proxy, error) {
|
||||
var proxies []model.Proxy
|
||||
for rows.Next() {
|
||||
p, err := scanProxyRow(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
proxies = append(proxies, *p)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating rows: %w", err)
|
||||
}
|
||||
return proxies, nil
|
||||
}
|
||||
|
||||
// scanProxy 扫描单条代理记录
|
||||
func scanProxy(row pgx.Row) (*model.Proxy, error) {
|
||||
var p model.Proxy
|
||||
var lastCheckAt *time.Time
|
||||
|
||||
err := row.Scan(
|
||||
&p.ID, &p.Protocol, &p.Host, &p.Port, &p.Username, &p.Password,
|
||||
&p.Group, &p.Tags, &p.Status, &p.Score, &p.LatencyMs, &lastCheckAt,
|
||||
&p.FailCount, &p.SuccessCount, &p.Disabled, &p.CreatedAt, &p.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if lastCheckAt != nil {
|
||||
p.LastCheckAt = *lastCheckAt
|
||||
}
|
||||
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// scanProxyRow 从 Rows 扫描单条代理记录
|
||||
func scanProxyRow(rows pgx.Rows) (*model.Proxy, error) {
|
||||
var p model.Proxy
|
||||
var lastCheckAt *time.Time
|
||||
|
||||
err := rows.Scan(
|
||||
&p.ID, &p.Protocol, &p.Host, &p.Port, &p.Username, &p.Password,
|
||||
&p.Group, &p.Tags, &p.Status, &p.Score, &p.LatencyMs, &lastCheckAt,
|
||||
&p.FailCount, &p.SuccessCount, &p.Disabled, &p.CreatedAt, &p.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if lastCheckAt != nil {
|
||||
p.LastCheckAt = *lastCheckAt
|
||||
}
|
||||
|
||||
return &p, nil
|
||||
}
|
||||
55
internal/store/store.go
Normal file
55
internal/store/store.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"proxyrotator/internal/model"
|
||||
)
|
||||
|
||||
// ProxyStore 代理存储接口
|
||||
type ProxyStore interface {
|
||||
// UpsertMany 批量导入代理(upsert + 去重统计)
|
||||
UpsertMany(ctx context.Context, proxies []model.Proxy) (imported, duplicated int, err error)
|
||||
|
||||
// List 查询代理列表
|
||||
List(ctx context.Context, q model.ProxyQuery) ([]model.Proxy, error)
|
||||
|
||||
// ListPaginated 分页查询代理列表,返回数据和总数
|
||||
ListPaginated(ctx context.Context, q model.ProxyListQuery) ([]model.Proxy, int, error)
|
||||
|
||||
// GetByID 根据 ID 获取代理
|
||||
GetByID(ctx context.Context, id string) (*model.Proxy, error)
|
||||
|
||||
// UpdateHealth 更新代理健康度
|
||||
UpdateHealth(ctx context.Context, proxyID string, patch model.HealthPatch) error
|
||||
|
||||
// Update 更新代理字段
|
||||
Update(ctx context.Context, id string, patch model.ProxyPatch) error
|
||||
|
||||
// Delete 删除单个代理
|
||||
Delete(ctx context.Context, id string) error
|
||||
|
||||
// DeleteMany 批量删除代理
|
||||
DeleteMany(ctx context.Context, req model.BulkDeleteRequest) (int64, error)
|
||||
|
||||
// GetStats 获取代理统计信息
|
||||
GetStats(ctx context.Context) (*model.ProxyStats, error)
|
||||
|
||||
// NextIndex RR 原子游标:返回 [0, modulo) 的索引
|
||||
NextIndex(ctx context.Context, key string, modulo int) (int, error)
|
||||
|
||||
// CreateLease 创建租约
|
||||
CreateLease(ctx context.Context, lease model.Lease) error
|
||||
|
||||
// GetLease 获取租约
|
||||
GetLease(ctx context.Context, leaseID string) (*model.Lease, error)
|
||||
|
||||
// DeleteExpiredLeases 删除过期租约
|
||||
DeleteExpiredLeases(ctx context.Context) (int64, error)
|
||||
|
||||
// InsertTestLog 插入测试日志
|
||||
InsertTestLog(ctx context.Context, r model.TestResult, site string) error
|
||||
|
||||
// Close 关闭连接
|
||||
Close() error
|
||||
}
|
||||
197
internal/telegram/bot.go
Normal file
197
internal/telegram/bot.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"proxyrotator/internal/config"
|
||||
"proxyrotator/internal/store"
|
||||
|
||||
tele "gopkg.in/telebot.v3"
|
||||
)
|
||||
|
||||
// Bot Telegram Bot 管理器
|
||||
type Bot struct {
|
||||
mu sync.RWMutex
|
||||
bot *tele.Bot
|
||||
cfg *config.Config
|
||||
store store.ProxyStore
|
||||
|
||||
scheduler *Scheduler
|
||||
notifier *Notifier
|
||||
|
||||
running bool
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
// Status Bot 状态
|
||||
type Status struct {
|
||||
Running bool `json:"running"`
|
||||
Connected bool `json:"connected"`
|
||||
Username string `json:"username,omitempty"`
|
||||
}
|
||||
|
||||
// NewBot 创建 Bot 实例
|
||||
func NewBot(cfg *config.Config, proxyStore store.ProxyStore) *Bot {
|
||||
return &Bot{
|
||||
cfg: cfg,
|
||||
store: proxyStore,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动 Bot
|
||||
func (b *Bot) Start(ctx context.Context) error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.running {
|
||||
return nil
|
||||
}
|
||||
|
||||
if b.cfg.TelegramBotToken == "" {
|
||||
slog.Info("telegram bot token not configured, skipping")
|
||||
return nil
|
||||
}
|
||||
|
||||
return b.startInternal()
|
||||
}
|
||||
|
||||
// startInternal 内部启动(需要持有锁)
|
||||
func (b *Bot) startInternal() error {
|
||||
pref := tele.Settings{
|
||||
Token: b.cfg.TelegramBotToken,
|
||||
Poller: &tele.LongPoller{Timeout: 10 * time.Second},
|
||||
}
|
||||
|
||||
bot, err := tele.NewBot(pref)
|
||||
if err != nil {
|
||||
slog.Error("failed to create telegram bot", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
b.bot = bot
|
||||
b.notifier = NewNotifier(bot, b.cfg.TelegramNotifyChatID)
|
||||
b.scheduler = NewScheduler(b.store, b.notifier, b.cfg)
|
||||
|
||||
// 注册命令处理器
|
||||
b.registerCommands(b.cfg.TelegramAdminIDs)
|
||||
|
||||
// 启动调度器
|
||||
b.scheduler.Start()
|
||||
|
||||
// 注册命令菜单
|
||||
commands := []tele.Command{
|
||||
{Text: "stats", Description: "查看代理池统计"},
|
||||
{Text: "groups", Description: "查看分组统计"},
|
||||
{Text: "get", Description: "获取可用代理 (默认1个,如 /get 5)"},
|
||||
{Text: "import", Description: "导入代理 (如 /import groupname)"},
|
||||
{Text: "test", Description: "触发测活 (如 /test groupname)"},
|
||||
{Text: "purge", Description: "清理死代理"},
|
||||
{Text: "help", Description: "显示帮助信息"},
|
||||
}
|
||||
if err := bot.SetCommands(commands); err != nil {
|
||||
slog.Warn("failed to set bot commands", "error", err)
|
||||
}
|
||||
|
||||
// 启动 Bot
|
||||
b.stopChan = make(chan struct{})
|
||||
go func() {
|
||||
slog.Info("telegram bot started", "username", bot.Me.Username)
|
||||
bot.Start()
|
||||
}()
|
||||
|
||||
b.running = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop 停止 Bot
|
||||
func (b *Bot) Stop() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
b.stopInternal()
|
||||
}
|
||||
|
||||
// stopInternal 内部停止(需要持有锁)
|
||||
func (b *Bot) stopInternal() {
|
||||
if !b.running {
|
||||
return
|
||||
}
|
||||
|
||||
if b.scheduler != nil {
|
||||
b.scheduler.Stop()
|
||||
}
|
||||
|
||||
if b.bot != nil {
|
||||
b.bot.Stop()
|
||||
slog.Info("telegram bot stopped")
|
||||
}
|
||||
|
||||
close(b.stopChan)
|
||||
b.running = false
|
||||
}
|
||||
|
||||
// Status 获取 Bot 状态
|
||||
func (b *Bot) Status() Status {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
status := Status{
|
||||
Running: b.running,
|
||||
}
|
||||
|
||||
if b.bot != nil && b.running {
|
||||
status.Connected = true
|
||||
status.Username = b.bot.Me.Username
|
||||
}
|
||||
|
||||
return status
|
||||
}
|
||||
|
||||
// TriggerTest 手动触发测活
|
||||
func (b *Bot) TriggerTest(ctx context.Context) error {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
if b.scheduler == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return b.scheduler.RunTest(ctx)
|
||||
}
|
||||
|
||||
// registerCommands 注册命令
|
||||
func (b *Bot) registerCommands(adminIDs []int64) {
|
||||
// 管理员权限中间件
|
||||
adminOnly := func(next tele.HandlerFunc) tele.HandlerFunc {
|
||||
return func(c tele.Context) error {
|
||||
if len(adminIDs) == 0 {
|
||||
return next(c)
|
||||
}
|
||||
userID := c.Sender().ID
|
||||
for _, id := range adminIDs {
|
||||
if id == userID {
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
return c.Send("⛔ 无权限访问")
|
||||
}
|
||||
}
|
||||
|
||||
// 创建命令处理器
|
||||
cmds := NewCommands(b.store, b.scheduler)
|
||||
|
||||
b.bot.Handle("/start", adminOnly(cmds.HandleStart))
|
||||
b.bot.Handle("/help", adminOnly(cmds.HandleHelp))
|
||||
b.bot.Handle("/stats", adminOnly(cmds.HandleStats))
|
||||
b.bot.Handle("/groups", adminOnly(cmds.HandleGroups))
|
||||
b.bot.Handle("/get", adminOnly(cmds.HandleGet))
|
||||
b.bot.Handle("/test", adminOnly(cmds.HandleTest))
|
||||
b.bot.Handle("/purge", adminOnly(cmds.HandlePurge))
|
||||
b.bot.Handle("/import", adminOnly(cmds.HandleImport))
|
||||
b.bot.Handle(tele.OnDocument, adminOnly(cmds.HandleDocument))
|
||||
b.bot.Handle(tele.OnText, adminOnly(cmds.HandleText))
|
||||
}
|
||||
365
internal/telegram/commands.go
Normal file
365
internal/telegram/commands.go
Normal file
@@ -0,0 +1,365 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"proxyrotator/internal/importer"
|
||||
"proxyrotator/internal/model"
|
||||
"proxyrotator/internal/store"
|
||||
|
||||
tele "gopkg.in/telebot.v3"
|
||||
)
|
||||
|
||||
// Commands 命令处理器
|
||||
type Commands struct {
|
||||
store store.ProxyStore
|
||||
scheduler *Scheduler
|
||||
importer *importer.Importer
|
||||
|
||||
// 导入状态
|
||||
importState map[int64]*importSession
|
||||
}
|
||||
|
||||
type importSession struct {
|
||||
Group string
|
||||
Tags []string
|
||||
}
|
||||
|
||||
// NewCommands 创建命令处理器
|
||||
func NewCommands(store store.ProxyStore, scheduler *Scheduler) *Commands {
|
||||
return &Commands{
|
||||
store: store,
|
||||
scheduler: scheduler,
|
||||
importer: importer.NewImporter(),
|
||||
importState: make(map[int64]*importSession),
|
||||
}
|
||||
}
|
||||
|
||||
// HandleStart /start 命令
|
||||
func (c *Commands) HandleStart(ctx tele.Context) error {
|
||||
return ctx.Send(`🚀 *ProxyRotator Bot*
|
||||
|
||||
欢迎使用代理池管理机器人!
|
||||
|
||||
使用 /help 查看可用命令`, &tele.SendOptions{ParseMode: tele.ModeMarkdown})
|
||||
}
|
||||
|
||||
// HandleHelp /help 命令
|
||||
func (c *Commands) HandleHelp(ctx tele.Context) error {
|
||||
help := `📖 *可用命令*
|
||||
|
||||
*查询类*
|
||||
/stats - 代理池统计(总数/存活/死亡/未知)
|
||||
/groups - 分组统计
|
||||
/get [n] - 获取 n 个可用代理(默认 5)
|
||||
|
||||
*操作类*
|
||||
/import [group] - 导入代理(之后发送文本或文件)
|
||||
/test [group] - 触发测活
|
||||
/purge - 清理死代理
|
||||
|
||||
*其他*
|
||||
/help - 显示帮助信息`
|
||||
|
||||
return ctx.Send(help, &tele.SendOptions{ParseMode: tele.ModeMarkdown})
|
||||
}
|
||||
|
||||
// HandleStats /stats 命令
|
||||
func (c *Commands) HandleStats(ctx tele.Context) error {
|
||||
stats, err := c.store.GetStats(context.Background())
|
||||
if err != nil {
|
||||
return ctx.Send(fmt.Sprintf("❌ 获取统计失败: %v", err))
|
||||
}
|
||||
|
||||
alive := stats.ByStatus[model.StatusAlive]
|
||||
dead := stats.ByStatus[model.StatusDead]
|
||||
unknown := stats.ByStatus[model.StatusUnknown]
|
||||
|
||||
var alivePercent float64
|
||||
if stats.Total > 0 {
|
||||
alivePercent = float64(alive) / float64(stats.Total) * 100
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf(`📊 *代理池统计*
|
||||
|
||||
*总数:* %d
|
||||
*存活:* %d (%.1f%%)
|
||||
*死亡:* %d
|
||||
*未知:* %d
|
||||
*禁用:* %d
|
||||
|
||||
*平均延迟:* %d ms
|
||||
*平均分数:* %.1f`,
|
||||
stats.Total,
|
||||
alive, alivePercent,
|
||||
dead,
|
||||
unknown,
|
||||
stats.Disabled,
|
||||
stats.AvgLatencyMs,
|
||||
stats.AvgScore,
|
||||
)
|
||||
|
||||
return ctx.Send(msg, &tele.SendOptions{ParseMode: tele.ModeMarkdown})
|
||||
}
|
||||
|
||||
// HandleGroups /groups 命令
|
||||
func (c *Commands) HandleGroups(ctx tele.Context) error {
|
||||
stats, err := c.store.GetStats(context.Background())
|
||||
if err != nil {
|
||||
return ctx.Send(fmt.Sprintf("❌ 获取统计失败: %v", err))
|
||||
}
|
||||
|
||||
if len(stats.ByGroup) == 0 {
|
||||
return ctx.Send("📁 暂无分组数据")
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("📁 *分组统计*\n\n")
|
||||
for group, count := range stats.ByGroup {
|
||||
sb.WriteString(fmt.Sprintf("• `%s`: %d\n", group, count))
|
||||
}
|
||||
|
||||
return ctx.Send(sb.String(), &tele.SendOptions{ParseMode: tele.ModeMarkdown})
|
||||
}
|
||||
|
||||
// IPInfo ipinfo.io 返回结构
|
||||
type IPInfo struct {
|
||||
IP string `json:"ip"`
|
||||
City string `json:"city"`
|
||||
Region string `json:"region"`
|
||||
Country string `json:"country"`
|
||||
Org string `json:"org"`
|
||||
}
|
||||
|
||||
// HandleGet /get [n] 命令
|
||||
func (c *Commands) HandleGet(ctx tele.Context) error {
|
||||
n := 1
|
||||
args := ctx.Args()
|
||||
if len(args) > 0 {
|
||||
if parsed, err := strconv.Atoi(args[0]); err == nil && parsed > 0 {
|
||||
n = parsed
|
||||
}
|
||||
}
|
||||
if n > 20 {
|
||||
n = 20
|
||||
}
|
||||
|
||||
proxies, err := c.store.List(context.Background(), model.ProxyQuery{
|
||||
StatusIn: []model.ProxyStatus{model.StatusAlive},
|
||||
OnlyEnabled: true,
|
||||
OrderBy: "random",
|
||||
Limit: n,
|
||||
})
|
||||
if err != nil {
|
||||
return ctx.Send(fmt.Sprintf("❌ 获取代理失败: %v", err))
|
||||
}
|
||||
|
||||
if len(proxies) == 0 {
|
||||
return ctx.Send("😢 没有可用代理")
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("🔗 *可用代理 (%d)*\n\n", len(proxies)))
|
||||
|
||||
for _, p := range proxies {
|
||||
var proxyURL string
|
||||
if p.Username != "" {
|
||||
proxyURL = fmt.Sprintf("%s://%s:%s@%s:%d", p.Protocol, p.Username, p.Password, p.Host, p.Port)
|
||||
} else {
|
||||
proxyURL = fmt.Sprintf("%s://%s:%d", p.Protocol, p.Host, p.Port)
|
||||
}
|
||||
|
||||
// 获取 IP 位置信息
|
||||
ipInfo := fetchIPInfo(proxyURL)
|
||||
|
||||
sb.WriteString(fmt.Sprintf("`%s`\n", proxyURL))
|
||||
if ipInfo != nil {
|
||||
location := fmt.Sprintf("%s, %s, %s", ipInfo.City, ipInfo.Region, ipInfo.Country)
|
||||
sb.WriteString(fmt.Sprintf(" 📍 %s | %s\n", location, ipInfo.Org))
|
||||
} else {
|
||||
sb.WriteString(" 📍 位置获取失败\n")
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return ctx.Send(sb.String(), &tele.SendOptions{ParseMode: tele.ModeMarkdown})
|
||||
}
|
||||
|
||||
// fetchIPInfo 通过代理获取 IP 信息
|
||||
func fetchIPInfo(proxyURL string) *IPInfo {
|
||||
proxy, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(proxy),
|
||||
},
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
resp, err := client.Get("https://ipinfo.io/json")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var info IPInfo
|
||||
if err := json.NewDecoder(resp.Body).Decode(&info); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &info
|
||||
}
|
||||
|
||||
// HandleTest /test [group] 命令
|
||||
func (c *Commands) HandleTest(ctx tele.Context) error {
|
||||
if c.scheduler == nil {
|
||||
return ctx.Send("❌ 调度器未初始化")
|
||||
}
|
||||
|
||||
group := ""
|
||||
args := ctx.Args()
|
||||
if len(args) > 0 {
|
||||
group = args[0]
|
||||
}
|
||||
|
||||
_ = ctx.Send("🔄 正在执行测活...")
|
||||
|
||||
err := c.scheduler.RunTestWithGroup(context.Background(), group)
|
||||
if err != nil {
|
||||
return ctx.Send(fmt.Sprintf("❌ 测活失败: %v", err))
|
||||
}
|
||||
|
||||
return ctx.Send("✅ 测活完成")
|
||||
}
|
||||
|
||||
// HandlePurge /purge 命令
|
||||
func (c *Commands) HandlePurge(ctx tele.Context) error {
|
||||
deleted, err := c.store.DeleteMany(context.Background(), model.BulkDeleteRequest{
|
||||
Status: model.StatusDead,
|
||||
})
|
||||
if err != nil {
|
||||
return ctx.Send(fmt.Sprintf("❌ 清理失败: %v", err))
|
||||
}
|
||||
|
||||
return ctx.Send(fmt.Sprintf("🗑️ 已清理 %d 个死代理", deleted))
|
||||
}
|
||||
|
||||
// HandleImport /import [group] 命令
|
||||
func (c *Commands) HandleImport(ctx tele.Context) error {
|
||||
group := "default"
|
||||
args := ctx.Args()
|
||||
if len(args) > 0 {
|
||||
group = args[0]
|
||||
}
|
||||
|
||||
userID := ctx.Sender().ID
|
||||
c.importState[userID] = &importSession{
|
||||
Group: group,
|
||||
Tags: []string{"telegram"},
|
||||
}
|
||||
|
||||
return ctx.Send(fmt.Sprintf(`📥 *导入模式已开启*
|
||||
|
||||
分组: `+"`%s`"+`
|
||||
|
||||
请发送代理列表(文本或文件),支持格式:
|
||||
• host:port
|
||||
• host:port:user:pass
|
||||
• protocol://host:port
|
||||
• protocol://user:pass@host:port
|
||||
|
||||
发送 /cancel 取消导入`, group), &tele.SendOptions{ParseMode: tele.ModeMarkdown})
|
||||
}
|
||||
|
||||
// HandleDocument 处理文件上传
|
||||
func (c *Commands) HandleDocument(ctx tele.Context) error {
|
||||
userID := ctx.Sender().ID
|
||||
session, ok := c.importState[userID]
|
||||
if !ok {
|
||||
return nil // 不在导入模式,忽略
|
||||
}
|
||||
|
||||
doc := ctx.Message().Document
|
||||
if doc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 下载文件
|
||||
reader, err := ctx.Bot().File(&doc.File)
|
||||
if err != nil {
|
||||
return ctx.Send(fmt.Sprintf("❌ 获取文件失败: %v", err))
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
content, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return ctx.Send(fmt.Sprintf("❌ 读取文件失败: %v", err))
|
||||
}
|
||||
|
||||
return c.doImport(ctx, session, string(content))
|
||||
}
|
||||
|
||||
// HandleText 处理文本消息(用于导入)
|
||||
func (c *Commands) HandleText(ctx tele.Context) error {
|
||||
userID := ctx.Sender().ID
|
||||
session, ok := c.importState[userID]
|
||||
if !ok {
|
||||
return nil // 不在导入模式,忽略
|
||||
}
|
||||
|
||||
text := ctx.Text()
|
||||
if text == "/cancel" {
|
||||
delete(c.importState, userID)
|
||||
return ctx.Send("❌ 已取消导入")
|
||||
}
|
||||
|
||||
// 检查是否像代理格式
|
||||
if !strings.Contains(text, ":") {
|
||||
return nil // 不像代理,忽略
|
||||
}
|
||||
|
||||
return c.doImport(ctx, session, text)
|
||||
}
|
||||
|
||||
// doImport 执行导入
|
||||
func (c *Commands) doImport(ctx tele.Context, session *importSession, text string) error {
|
||||
userID := ctx.Sender().ID
|
||||
defer delete(c.importState, userID)
|
||||
|
||||
input := model.ImportInput{
|
||||
Group: session.Group,
|
||||
Tags: session.Tags,
|
||||
}
|
||||
|
||||
proxies, invalid := c.importer.ParseText(context.Background(), input, text)
|
||||
|
||||
if len(proxies) == 0 {
|
||||
return ctx.Send(fmt.Sprintf("❌ 未解析到有效代理\n无效行: %d", len(invalid)))
|
||||
}
|
||||
|
||||
imported, duplicated, err := c.store.UpsertMany(context.Background(), proxies)
|
||||
if err != nil {
|
||||
return ctx.Send(fmt.Sprintf("❌ 导入失败: %v", err))
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf(`✅ *导入完成*
|
||||
|
||||
• 新增: %d
|
||||
• 重复: %d
|
||||
• 无效: %d
|
||||
• 分组: `+"`%s`",
|
||||
imported, duplicated, len(invalid), session.Group)
|
||||
|
||||
return ctx.Send(msg, &tele.SendOptions{ParseMode: tele.ModeMarkdown})
|
||||
}
|
||||
83
internal/telegram/notifier.go
Normal file
83
internal/telegram/notifier.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
|
||||
tele "gopkg.in/telebot.v3"
|
||||
)
|
||||
|
||||
// Notifier 告警通知器
|
||||
type Notifier struct {
|
||||
bot *tele.Bot
|
||||
chatID string
|
||||
}
|
||||
|
||||
// NewNotifier 创建通知器
|
||||
func NewNotifier(bot *tele.Bot, chatID string) *Notifier {
|
||||
return &Notifier{
|
||||
bot: bot,
|
||||
chatID: chatID,
|
||||
}
|
||||
}
|
||||
|
||||
// SendAlert 发送告警
|
||||
func (n *Notifier) SendAlert(ctx context.Context, alive, dead, total int, alivePercent float64) {
|
||||
if n.chatID == "" {
|
||||
slog.Warn("notify_chat_id not configured, skipping alert")
|
||||
return
|
||||
}
|
||||
|
||||
chatID, err := strconv.ParseInt(n.chatID, 10, 64)
|
||||
if err != nil {
|
||||
slog.Error("invalid chat_id", "chat_id", n.chatID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf(`🚨 *代理池告警*
|
||||
|
||||
存活率低于阈值!
|
||||
|
||||
*统计:*
|
||||
• 存活: %d (%.1f%%)
|
||||
• 死亡: %d
|
||||
• 总数: %d
|
||||
|
||||
请及时补充代理或检查网络状况。`, alive, alivePercent, dead, total)
|
||||
|
||||
chat, err := n.bot.ChatByID(chatID)
|
||||
if err != nil {
|
||||
slog.Error("failed to get chat", "chat_id", n.chatID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = n.bot.Send(chat, msg, &tele.SendOptions{ParseMode: tele.ModeMarkdown})
|
||||
if err != nil {
|
||||
slog.Error("failed to send alert", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("alert sent", "chat_id", n.chatID, "alive_percent", alivePercent)
|
||||
}
|
||||
|
||||
// SendMessage 发送普通消息
|
||||
func (n *Notifier) SendMessage(ctx context.Context, message string) error {
|
||||
if n.chatID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
chatID, err := strconv.ParseInt(n.chatID, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chat, err := n.bot.ChatByID(chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = n.bot.Send(chat, message, &tele.SendOptions{ParseMode: tele.ModeMarkdown})
|
||||
return err
|
||||
}
|
||||
167
internal/telegram/scheduler.go
Normal file
167
internal/telegram/scheduler.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package telegram
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"proxyrotator/internal/config"
|
||||
"proxyrotator/internal/model"
|
||||
"proxyrotator/internal/store"
|
||||
"proxyrotator/internal/tester"
|
||||
)
|
||||
|
||||
// Scheduler 定时测活调度器
|
||||
type Scheduler struct {
|
||||
mu sync.Mutex
|
||||
store store.ProxyStore
|
||||
notifier *Notifier
|
||||
tester *tester.HTTPTester
|
||||
cfg *config.Config
|
||||
|
||||
ticker *time.Ticker
|
||||
stopChan chan struct{}
|
||||
running bool
|
||||
}
|
||||
|
||||
// NewScheduler 创建调度器
|
||||
func NewScheduler(store store.ProxyStore, notifier *Notifier, cfg *config.Config) *Scheduler {
|
||||
return &Scheduler{
|
||||
store: store,
|
||||
notifier: notifier,
|
||||
tester: tester.NewHTTPTester(),
|
||||
cfg: cfg,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动调度器
|
||||
func (s *Scheduler) Start() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.running {
|
||||
return
|
||||
}
|
||||
|
||||
interval := time.Duration(s.cfg.TelegramTestIntervalMin) * time.Minute
|
||||
if interval < 5*time.Minute {
|
||||
interval = 5 * time.Minute
|
||||
}
|
||||
|
||||
s.ticker = time.NewTicker(interval)
|
||||
s.stopChan = make(chan struct{})
|
||||
s.running = true
|
||||
|
||||
go s.loop()
|
||||
slog.Info("telegram scheduler started", "interval", interval)
|
||||
}
|
||||
|
||||
// Stop 停止调度器
|
||||
func (s *Scheduler) Stop() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.running {
|
||||
return
|
||||
}
|
||||
|
||||
if s.ticker != nil {
|
||||
s.ticker.Stop()
|
||||
}
|
||||
close(s.stopChan)
|
||||
s.running = false
|
||||
slog.Info("telegram scheduler stopped")
|
||||
}
|
||||
|
||||
// loop 调度循环
|
||||
func (s *Scheduler) loop() {
|
||||
for {
|
||||
select {
|
||||
case <-s.stopChan:
|
||||
return
|
||||
case <-s.ticker.C:
|
||||
if err := s.RunTest(context.Background()); err != nil {
|
||||
slog.Error("scheduled test failed", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RunTest 执行测活(所有分组)
|
||||
func (s *Scheduler) RunTest(ctx context.Context) error {
|
||||
return s.RunTestWithGroup(ctx, "")
|
||||
}
|
||||
|
||||
// RunTestWithGroup 执行测活(指定分组)
|
||||
func (s *Scheduler) RunTestWithGroup(ctx context.Context, group string) error {
|
||||
slog.Info("running scheduled proxy test", "group", group)
|
||||
|
||||
// 获取待测试代理
|
||||
query := model.ProxyQuery{
|
||||
Group: group,
|
||||
StatusIn: []model.ProxyStatus{model.StatusUnknown, model.StatusAlive},
|
||||
OnlyEnabled: true,
|
||||
Limit: 1000,
|
||||
}
|
||||
|
||||
proxies, err := s.store.List(ctx, query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(proxies) == 0 {
|
||||
slog.Info("no proxies to test")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 构建测试规格
|
||||
spec := model.TestSpec{
|
||||
URL: s.cfg.TelegramTestURL,
|
||||
Method: "GET",
|
||||
Timeout: time.Duration(s.cfg.TelegramTestTimeoutMs) * time.Millisecond,
|
||||
}
|
||||
|
||||
// 执行测试
|
||||
results := s.tester.TestBatch(ctx, proxies, spec, 50)
|
||||
|
||||
// 统计结果
|
||||
alive, dead := 0, 0
|
||||
for _, r := range results {
|
||||
now := r.CheckedAt
|
||||
if r.OK {
|
||||
alive++
|
||||
status := model.StatusAlive
|
||||
_ = s.store.UpdateHealth(ctx, r.ProxyID, model.HealthPatch{
|
||||
Status: &status,
|
||||
ScoreDelta: 1,
|
||||
SuccessInc: 1,
|
||||
LatencyMs: &r.LatencyMs,
|
||||
CheckedAt: &now,
|
||||
})
|
||||
} else {
|
||||
dead++
|
||||
status := model.StatusDead
|
||||
_ = s.store.UpdateHealth(ctx, r.ProxyID, model.HealthPatch{
|
||||
Status: &status,
|
||||
ScoreDelta: -3,
|
||||
FailInc: 1,
|
||||
CheckedAt: &now,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
slog.Info("scheduled test completed", "tested", len(results), "alive", alive, "dead", dead)
|
||||
|
||||
// 检查是否需要告警
|
||||
total := len(results)
|
||||
if total > 0 {
|
||||
alivePercent := float64(alive) / float64(total) * 100
|
||||
if alivePercent < float64(s.cfg.TelegramAlertThreshold) {
|
||||
s.notifier.SendAlert(ctx, alive, dead, total, alivePercent)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
246
internal/tester/http_tester.go
Normal file
246
internal/tester/http_tester.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package tester
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
|
||||
"proxyrotator/internal/model"
|
||||
)
|
||||
|
||||
// HTTPTester HTTP 代理测试器
|
||||
type HTTPTester struct {
|
||||
maxBodySize int64
|
||||
}
|
||||
|
||||
// NewHTTPTester 创建测试器
|
||||
func NewHTTPTester() *HTTPTester {
|
||||
return &HTTPTester{
|
||||
maxBodySize: 1024 * 1024, // 1MB
|
||||
}
|
||||
}
|
||||
|
||||
// TestOne 测试单个代理
|
||||
func (t *HTTPTester) TestOne(ctx context.Context, p model.Proxy, spec model.TestSpec) model.TestResult {
|
||||
result := model.TestResult{
|
||||
ProxyID: p.ID,
|
||||
CheckedAt: time.Now(),
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
// 创建 HTTP 客户端
|
||||
client, err := t.createClient(p, spec.Timeout)
|
||||
if err != nil {
|
||||
result.ErrorText = err.Error()
|
||||
result.LatencyMs = time.Since(start).Milliseconds()
|
||||
return result
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
method := spec.Method
|
||||
if method == "" {
|
||||
method = "GET"
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, spec.URL, nil)
|
||||
if err != nil {
|
||||
result.ErrorText = fmt.Sprintf("create request failed: %v", err)
|
||||
result.LatencyMs = time.Since(start).Milliseconds()
|
||||
return result
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36")
|
||||
|
||||
// 发起请求
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
result.ErrorText = err.Error()
|
||||
result.LatencyMs = time.Since(start).Milliseconds()
|
||||
return result
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
result.LatencyMs = time.Since(start).Milliseconds()
|
||||
|
||||
// 检查状态码
|
||||
if len(spec.ExpectStatus) > 0 {
|
||||
found := false
|
||||
for _, s := range spec.ExpectStatus {
|
||||
if resp.StatusCode == s {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
result.ErrorText = fmt.Sprintf("unexpected status: %d", resp.StatusCode)
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
// 检查响应体关键字
|
||||
if spec.ExpectContains != "" {
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, t.maxBodySize))
|
||||
if err != nil {
|
||||
result.ErrorText = fmt.Sprintf("read body failed: %v", err)
|
||||
return result
|
||||
}
|
||||
if !strings.Contains(string(body), spec.ExpectContains) {
|
||||
result.ErrorText = "expected content not found"
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
result.OK = true
|
||||
return result
|
||||
}
|
||||
|
||||
// TestBatch 并发测试多个代理
|
||||
func (t *HTTPTester) TestBatch(ctx context.Context, proxies []model.Proxy, spec model.TestSpec, concurrency int) []model.TestResult {
|
||||
if concurrency <= 0 {
|
||||
concurrency = 10
|
||||
}
|
||||
|
||||
jobs := make(chan model.Proxy, len(proxies))
|
||||
results := make(chan model.TestResult, len(proxies))
|
||||
|
||||
// 启动 worker
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for p := range jobs {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
results <- model.TestResult{
|
||||
ProxyID: p.ID,
|
||||
ErrorText: "context cancelled",
|
||||
CheckedAt: time.Now(),
|
||||
}
|
||||
default:
|
||||
results <- t.TestOne(ctx, p, spec)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 发送任务
|
||||
go func() {
|
||||
for _, p := range proxies {
|
||||
jobs <- p
|
||||
}
|
||||
close(jobs)
|
||||
}()
|
||||
|
||||
// 等待完成
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(results)
|
||||
}()
|
||||
|
||||
// 收集结果
|
||||
out := make([]model.TestResult, 0, len(proxies))
|
||||
for r := range results {
|
||||
out = append(out, r)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// createClient 根据代理类型创建 HTTP 客户端
|
||||
func (t *HTTPTester) createClient(p model.Proxy, timeout time.Duration) (*http.Client, error) {
|
||||
if timeout <= 0 {
|
||||
timeout = 10 * time.Second
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
DisableKeepAlives: true,
|
||||
TLSHandshakeTimeout: timeout,
|
||||
ResponseHeaderTimeout: timeout,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
// 代理连接设置
|
||||
ProxyConnectHeader: http.Header{},
|
||||
}
|
||||
|
||||
switch p.Protocol {
|
||||
case model.ProtoHTTP, model.ProtoHTTPS:
|
||||
proxyURL := t.buildProxyURL(p)
|
||||
transport.Proxy = http.ProxyURL(proxyURL)
|
||||
|
||||
case model.ProtoSOCKS5:
|
||||
dialer, err := t.createSOCKS5Dialer(p, timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol)
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: timeout,
|
||||
// 不自动跟随重定向,让我们检查原始响应
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 10 {
|
||||
return fmt.Errorf("too many redirects")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildProxyURL 构建代理 URL
|
||||
func (t *HTTPTester) buildProxyURL(p model.Proxy) *url.URL {
|
||||
scheme := "http"
|
||||
if p.Protocol == model.ProtoHTTPS {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
u := &url.URL{
|
||||
Scheme: scheme,
|
||||
Host: fmt.Sprintf("%s:%d", p.Host, p.Port),
|
||||
}
|
||||
|
||||
if p.Username != "" {
|
||||
u.User = url.UserPassword(p.Username, p.Password)
|
||||
}
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
// createSOCKS5Dialer 创建 SOCKS5 拨号器
|
||||
func (t *HTTPTester) createSOCKS5Dialer(p model.Proxy, timeout time.Duration) (proxy.Dialer, error) {
|
||||
addr := fmt.Sprintf("%s:%d", p.Host, p.Port)
|
||||
|
||||
var auth *proxy.Auth
|
||||
if p.Username != "" {
|
||||
auth = &proxy.Auth{
|
||||
User: p.Username,
|
||||
Password: p.Password,
|
||||
}
|
||||
}
|
||||
|
||||
// 创建基础拨号器带超时
|
||||
baseDialer := &net.Dialer{
|
||||
Timeout: timeout,
|
||||
}
|
||||
|
||||
return proxy.SOCKS5("tcp", addr, auth, baseDialer)
|
||||
}
|
||||
Reference in New Issue
Block a user