Files
ProxyPool/internal/store/pg_store.go
2026-01-31 22:53:12 +08:00

699 lines
17 KiB
Go

package store
import (
	"context"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
	"proxyrotator/internal/model"
)
// PgStore is the PostgreSQL implementation of the proxy store. Every query
// issued by its methods goes through a single pgx connection pool.
type PgStore struct {
// pool is shared by all methods, released by Close, and exposed via Pool.
pool *pgxpool.Pool
}
// NewPgStore creates a PostgreSQL-backed store for connString.
//
// It opens a pgx connection pool and pings the database before returning, so
// an unreachable server is reported here rather than on first use. If the ping
// fails, the pool is closed again so no connections are leaked.
func NewPgStore(ctx context.Context, connString string) (*PgStore, error) {
	connPool, openErr := pgxpool.New(ctx, connString)
	if openErr != nil {
		return nil, fmt.Errorf("failed to create connection pool: %w", openErr)
	}

	if pingErr := connPool.Ping(ctx); pingErr != nil {
		connPool.Close()
		return nil, fmt.Errorf("failed to ping database: %w", pingErr)
	}

	store := &PgStore{pool: connPool}
	return store, nil
}
// Close shuts down the connection pool. pgxpool.Pool.Close reports no error,
// so this method always returns nil.
func (s *PgStore) Close() error {
	p := s.pool
	p.Close()
	return nil
}
// Pool returns the underlying connection pool so other modules can share it.
// It is the same pool this store uses; closing it directly also shuts down
// the store.
func (s *PgStore) Pool() *pgxpool.Pool {
	shared := s.pool
	return shared
}
// UpsertMany imports proxies in a single transaction.
//
// A row that conflicts on (protocol, host, port, username) is updated in
// place: password and group are overwritten, and the stored tag array is
// unioned with the incoming one (duplicates removed).
//
// imported counts rows that were newly inserted; duplicated counts rows that
// already existed and were merged. An empty input is a no-op. Any error
// aborts and rolls back the whole batch, and both counts are returned as zero.
func (s *PgStore) UpsertMany(ctx context.Context, proxies []model.Proxy) (imported, duplicated int, err error) {
	if len(proxies) == 0 {
		return 0, 0, nil
	}

	const sqlUpsert = `
INSERT INTO proxies (id, protocol, host, port, username, password, "group", tags)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (protocol, host, port, username)
DO UPDATE SET
password = EXCLUDED.password,
"group" = EXCLUDED."group",
tags = (
SELECT ARRAY(
SELECT DISTINCT unnest(proxies.tags || EXCLUDED.tags)
)
)
RETURNING (xmax = 0) AS inserted
`

	tx, beginErr := s.pool.Begin(ctx)
	if beginErr != nil {
		return 0, 0, fmt.Errorf("failed to begin transaction: %w", beginErr)
	}
	// Undoes everything on any early return; harmless after a successful Commit.
	defer tx.Rollback(ctx)

	for i := range proxies {
		p := &proxies[i]

		// Relies on the common "xmax = 0" heuristic: a freshly inserted tuple has
		// xmax 0, while a row touched by the ON CONFLICT update does not.
		var wasInserted bool
		scanErr := tx.QueryRow(ctx, sqlUpsert,
			p.ID, p.Protocol, p.Host, p.Port, p.Username, p.Password, p.Group, p.Tags,
		).Scan(&wasInserted)
		if scanErr != nil {
			return 0, 0, fmt.Errorf("failed to upsert proxy: %w", scanErr)
		}

		switch {
		case wasInserted:
			imported++
		default:
			duplicated++
		}
	}

	if commitErr := tx.Commit(ctx); commitErr != nil {
		return 0, 0, fmt.Errorf("failed to commit transaction: %w", commitErr)
	}
	return imported, duplicated, nil
}
// List returns proxies matching q.
//
// Filters (all optional, combined with AND):
//   - Group: exact group match.
//   - OnlyEnabled: excludes disabled proxies.
//   - StatusIn: status must be one of the listed values.
//   - TagsAny: the proxy must share at least one tag (array overlap, &&).
//
// OrderBy selects the ordering: "random", "latency" (ascending, NULLs last),
// or by default score descending then most recently checked. A positive
// Limit caps the number of rows returned.
func (s *PgStore) List(ctx context.Context, q model.ProxyQuery) ([]model.Proxy, error) {
	var (
		filters []string
		params  []any
	)
	// bind records v as the next positional parameter and returns its
	// placeholder ($1, $2, ...).
	bind := func(v any) string {
		params = append(params, v)
		return fmt.Sprintf("$%d", len(params))
	}

	if q.Group != "" {
		filters = append(filters, `"group" = `+bind(q.Group))
	}
	if q.OnlyEnabled {
		filters = append(filters, "disabled = false")
	}
	if n := len(q.StatusIn); n > 0 {
		marks := make([]string, 0, n)
		for _, st := range q.StatusIn {
			marks = append(marks, bind(st))
		}
		filters = append(filters, "status IN ("+strings.Join(marks, ",")+")")
	}
	if len(q.TagsAny) > 0 {
		filters = append(filters, "tags && "+bind(q.TagsAny))
	}

	orderClause := " ORDER BY score DESC, last_check_at DESC NULLS LAST"
	switch q.OrderBy {
	case "random":
		orderClause = " ORDER BY RANDOM()"
	case "latency":
		orderClause = " ORDER BY latency_ms ASC NULLS LAST"
	}

	var query strings.Builder
	query.WriteString(`SELECT id, protocol, host, port, username, password, "group", tags,
status, score, latency_ms, last_check_at, fail_count, success_count,
disabled, created_at, updated_at
FROM proxies`)
	if len(filters) > 0 {
		query.WriteString(" WHERE ")
		query.WriteString(strings.Join(filters, " AND "))
	}
	query.WriteString(orderClause)
	if q.Limit > 0 {
		// Limit is an integer, so formatting it inline cannot inject SQL.
		fmt.Fprintf(&query, " LIMIT %d", q.Limit)
	}

	rows, err := s.pool.Query(ctx, query.String(), params...)
	if err != nil {
		return nil, fmt.Errorf("failed to query proxies: %w", err)
	}
	defer rows.Close()
	return scanProxies(rows)
}
// GetByID loads the proxy with the given id.
//
// It returns model.ErrProxyNotFound when no row matches; any other database
// failure is returned wrapped.
func (s *PgStore) GetByID(ctx context.Context, id string) (*model.Proxy, error) {
	const sql = `SELECT id, protocol, host, port, username, password, "group", tags,
status, score, latency_ms, last_check_at, fail_count, success_count,
disabled, created_at, updated_at
FROM proxies WHERE id = $1`

	p, err := scanProxy(s.pool.QueryRow(ctx, sql, id))
	// errors.Is keeps working even if the scan path ever starts wrapping errors.
	if errors.Is(err, pgx.ErrNoRows) {
		return nil, model.ErrProxyNotFound
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get proxy: %w", err)
	}
	return p, nil
}
// UpdateHealth applies a health-check outcome to the proxy proxyID.
//
// Only the parts of patch that are set are written:
//   - Status, LatencyMs, CheckedAt (when non-nil) overwrite status,
//     latency_ms and last_check_at.
//   - ScoreDelta (when non-zero) is added to score, clamped to [-1000, 1000].
//   - FailInc / SuccessInc (when positive) increment the matching counters.
//
// An empty patch is a no-op. An unknown proxyID is not reported as an error:
// the UPDATE simply affects no rows.
func (s *PgStore) UpdateHealth(ctx context.Context, proxyID string, patch model.HealthPatch) error {
	var (
		assignments []string
		params      []any
	)
	// bind records v as the next positional parameter and returns its placeholder.
	bind := func(v any) string {
		params = append(params, v)
		return fmt.Sprintf("$%d", len(params))
	}

	if patch.Status != nil {
		assignments = append(assignments, "status = "+bind(*patch.Status))
	}
	if patch.ScoreDelta != 0 {
		assignments = append(assignments,
			"score = GREATEST(-1000, LEAST(1000, score + "+bind(patch.ScoreDelta)+"))")
	}
	if patch.LatencyMs != nil {
		assignments = append(assignments, "latency_ms = "+bind(*patch.LatencyMs))
	}
	if patch.CheckedAt != nil {
		assignments = append(assignments, "last_check_at = "+bind(*patch.CheckedAt))
	}
	if patch.FailInc > 0 {
		assignments = append(assignments, "fail_count = fail_count + "+bind(patch.FailInc))
	}
	if patch.SuccessInc > 0 {
		assignments = append(assignments, "success_count = success_count + "+bind(patch.SuccessInc))
	}

	if len(assignments) == 0 {
		return nil
	}

	query := "UPDATE proxies SET " + strings.Join(assignments, ", ") + " WHERE id = " + bind(proxyID)
	if _, err := s.pool.Exec(ctx, query, params...); err != nil {
		return fmt.Errorf("failed to update health: %w", err)
	}
	return nil
}
// NextIndex advances the round-robin cursor stored under key and returns the
// resulting slot in [0, modulo).
//
// The cursor lives in rr_cursors and is advanced with a single
// INSERT ... ON CONFLICT DO UPDATE statement, so concurrent callers each get a
// distinct counter value. A key seen for the first time starts at 0.
// It returns model.ErrBadModulo when modulo <= 0.
func (s *PgStore) NextIndex(ctx context.Context, key string, modulo int) (int, error) {
	if modulo <= 0 {
		return 0, model.ErrBadModulo
	}

	const sql = `
INSERT INTO rr_cursors (k, v)
VALUES ($1, 0)
ON CONFLICT (k)
DO UPDATE SET v = rr_cursors.v + 1, updated_at = now()
RETURNING v
`
	var v int64
	if err := s.pool.QueryRow(ctx, sql, key).Scan(&v); err != nil {
		return 0, fmt.Errorf("failed to get next index: %w", err)
	}

	m := int64(modulo)
	idx := v % m
	if idx < 0 {
		// Go's % keeps the sign of the dividend. Shift into [0, modulo) instead
		// of negating: -(v % m) picks the wrong slot and reverses the rotation
		// order for negative counter values.
		idx += m
	}
	return int(idx), nil
}
// CreateLease inserts a lease row recording the leased proxy, the site and
// group it was issued for, and its expiry time (lease.ExpireAt).
func (s *PgStore) CreateLease(ctx context.Context, lease model.Lease) error {
	const sql = `
INSERT INTO proxy_leases (lease_id, proxy_id, expire_at, site, "group")
VALUES ($1, $2, $3, $4, $5)
`
	values := []any{lease.LeaseID, lease.ProxyID, lease.ExpireAt, lease.Site, lease.Group}
	if _, err := s.pool.Exec(ctx, sql, values...); err != nil {
		return fmt.Errorf("failed to create lease: %w", err)
	}
	return nil
}
// GetLease returns the lease with the given ID, provided it has not expired.
//
// Unknown and expired lease IDs both yield model.ErrLeaseExpired, because the
// expiry check is part of the query's WHERE clause. Other database failures
// are returned wrapped.
func (s *PgStore) GetLease(ctx context.Context, leaseID string) (*model.Lease, error) {
	const sql = `
SELECT lease_id, proxy_id, expire_at, site, "group"
FROM proxy_leases
WHERE lease_id = $1 AND expire_at > now()
`
	var lease model.Lease
	err := s.pool.QueryRow(ctx, sql, leaseID).Scan(
		&lease.LeaseID, &lease.ProxyID, &lease.ExpireAt, &lease.Site, &lease.Group,
	)
	if errors.Is(err, pgx.ErrNoRows) {
		return nil, model.ErrLeaseExpired
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get lease: %w", err)
	}
	return &lease, nil
}
// DeleteExpiredLeases removes every lease whose expire_at is at or before the
// database's current time and reports how many rows were deleted.
func (s *PgStore) DeleteExpiredLeases(ctx context.Context) (int64, error) {
	const sql = `DELETE FROM proxy_leases WHERE expire_at <= now()`
	tag, execErr := s.pool.Exec(ctx, sql)
	if execErr != nil {
		return 0, fmt.Errorf("failed to delete expired leases: %w", execErr)
	}
	deleted := tag.RowsAffected()
	return deleted, nil
}
// InsertTestLog appends one check result for a proxy to proxy_test_logs,
// tagged with the site it was tested against.
func (s *PgStore) InsertTestLog(ctx context.Context, r model.TestResult, site string) error {
	const sql = `
INSERT INTO proxy_test_logs (proxy_id, site, ok, latency_ms, error_text, checked_at)
VALUES ($1, $2, $3, $4, $5, $6)
`
	values := []any{r.ProxyID, site, r.OK, r.LatencyMs, r.ErrorText, r.CheckedAt}
	if _, err := s.pool.Exec(ctx, sql, values...); err != nil {
		return fmt.Errorf("failed to insert test log: %w", err)
	}
	return nil
}
// ListPaginated returns one page of proxies matching q, together with the
// total number of matching rows.
//
// Filters (Group, OnlyEnabled, StatusIn, TagsAny) behave as in List. Rows are
// ordered by score descending, then most recently checked, then id. The id
// tiebreaker makes the order total, so LIMIT/OFFSET paging never repeats or
// skips rows that tie on score and last_check_at. That case is common: for
// example, freshly imported proxies have the same score and no check time.
// q.Limit and q.Offset are passed straight to LIMIT/OFFSET.
//
// The count and the page are fetched with two separate queries, so under
// concurrent writes total may differ slightly from what paging observes.
func (s *PgStore) ListPaginated(ctx context.Context, q model.ProxyListQuery) ([]model.Proxy, int, error) {
	var (
		conditions []string
		args       []any
	)
	// bind records v as the next positional parameter and returns its placeholder.
	bind := func(v any) string {
		args = append(args, v)
		return fmt.Sprintf("$%d", len(args))
	}

	if q.Group != "" {
		conditions = append(conditions, `"group" = `+bind(q.Group))
	}
	if q.OnlyEnabled {
		conditions = append(conditions, "disabled = false")
	}
	if len(q.StatusIn) > 0 {
		placeholders := make([]string, 0, len(q.StatusIn))
		for _, status := range q.StatusIn {
			placeholders = append(placeholders, bind(status))
		}
		conditions = append(conditions, "status IN ("+strings.Join(placeholders, ",")+")")
	}
	if len(q.TagsAny) > 0 {
		conditions = append(conditions, "tags && "+bind(q.TagsAny))
	}

	whereClause := ""
	if len(conditions) > 0 {
		whereClause = " WHERE " + strings.Join(conditions, " AND ")
	}

	// Count first, before LIMIT/OFFSET are bound, so args holds only filter values.
	countSQL := "SELECT COUNT(*) FROM proxies" + whereClause
	var total int
	if err := s.pool.QueryRow(ctx, countSQL, args...).Scan(&total); err != nil {
		return nil, 0, fmt.Errorf("failed to count proxies: %w", err)
	}

	dataSQL := `SELECT id, protocol, host, port, username, password, "group", tags,
status, score, latency_ms, last_check_at, fail_count, success_count,
disabled, created_at, updated_at FROM proxies` + whereClause +
		" ORDER BY score DESC, last_check_at DESC NULLS LAST, id ASC"
	dataSQL += " LIMIT " + bind(q.Limit) + " OFFSET " + bind(q.Offset)

	rows, err := s.pool.Query(ctx, dataSQL, args...)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to query proxies: %w", err)
	}
	defer rows.Close()

	proxies, err := scanProxies(rows)
	if err != nil {
		return nil, 0, err
	}
	return proxies, total, nil
}
// Update applies a partial update to the proxy identified by id.
//
// Only the fields set in patch are written:
//   - Group (non-nil) replaces the group.
//   - Tags (non-nil) replaces the tag list.
//   - AddTags (non-empty) unions extra tags into the stored list, dropping
//     duplicates.
//   - Disabled (non-nil) sets the disabled flag.
//
// When Tags and AddTags are both set, the result is Tags followed by each
// AddTags entry not already present.
//
// It returns model.ErrInvalidPatch if the patch sets nothing, and
// model.ErrProxyNotFound if no row has the given id.
func (s *PgStore) Update(ctx context.Context, id string, patch model.ProxyPatch) error {
	var sets []string
	var args []any
	argIdx := 1

	if patch.Group != nil {
		sets = append(sets, fmt.Sprintf(`"group" = $%d`, argIdx))
		args = append(args, *patch.Group)
		argIdx++
	}

	switch {
	case patch.Tags != nil:
		tags := *patch.Tags
		if len(patch.AddTags) > 0 {
			// PostgreSQL rejects an UPDATE that assigns the same column twice
			// ("multiple assignments to same column"), so when AddTags accompanies
			// a full replacement, fold the extra tags into the replacement list
			// here. The full-slice expression caps capacity so append copies
			// rather than writing into the caller's backing array.
			merged := tags[:len(tags):len(tags)]
			for _, add := range patch.AddTags {
				present := false
				for _, have := range merged {
					if have == add {
						present = true
						break
					}
				}
				if !present {
					merged = append(merged, add)
				}
			}
			tags = merged
		}
		sets = append(sets, fmt.Sprintf("tags = $%d", argIdx))
		args = append(args, tags)
		argIdx++
	case len(patch.AddTags) > 0:
		// Union the new tags with the stored ones, removing duplicates.
		sets = append(sets, fmt.Sprintf(`tags = (
SELECT ARRAY(SELECT DISTINCT unnest(tags || $%d))
)`, argIdx))
		args = append(args, patch.AddTags)
		argIdx++
	}

	if patch.Disabled != nil {
		sets = append(sets, fmt.Sprintf("disabled = $%d", argIdx))
		args = append(args, *patch.Disabled)
		argIdx++
	}

	if len(sets) == 0 {
		return model.ErrInvalidPatch
	}

	sql := fmt.Sprintf("UPDATE proxies SET %s WHERE id = $%d", strings.Join(sets, ", "), argIdx)
	args = append(args, id)
	result, err := s.pool.Exec(ctx, sql, args...)
	if err != nil {
		return fmt.Errorf("failed to update proxy: %w", err)
	}
	if result.RowsAffected() == 0 {
		return model.ErrProxyNotFound
	}
	return nil
}
// Delete removes the proxy with the given id together with its leases and
// test logs, all in one transaction.
//
// It returns model.ErrProxyNotFound (and rolls everything back) when no proxy
// has that id. Commit failures are wrapped like the other transactional
// methods in this store.
func (s *PgStore) Delete(ctx context.Context, id string) error {
	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Rolls back on every early return; harmless after a successful Commit.
	defer tx.Rollback(ctx)

	// Dependent rows go first. NOTE(review): presumably required by foreign
	// keys on proxy_id; confirm against the schema.
	if _, err := tx.Exec(ctx, "DELETE FROM proxy_leases WHERE proxy_id = $1", id); err != nil {
		return fmt.Errorf("failed to delete leases: %w", err)
	}
	if _, err := tx.Exec(ctx, "DELETE FROM proxy_test_logs WHERE proxy_id = $1", id); err != nil {
		return fmt.Errorf("failed to delete test logs: %w", err)
	}

	result, err := tx.Exec(ctx, "DELETE FROM proxies WHERE id = $1", id)
	if err != nil {
		return fmt.Errorf("failed to delete proxy: %w", err)
	}
	if result.RowsAffected() == 0 {
		return model.ErrProxyNotFound
	}

	if err := tx.Commit(ctx); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}
	return nil
}
// DeleteMany deletes every proxy matching all of the filters set in req
// (IDs, Status, Group, Disabled; combined with AND), together with their
// leases and test logs, in one transaction. It returns the number of proxies
// deleted.
//
// A request with no filters is rejected with model.ErrBulkDeleteEmpty rather
// than matching every row. If nothing matches, it returns 0 without error.
func (s *PgStore) DeleteMany(ctx context.Context, req model.BulkDeleteRequest) (int64, error) {
	var conditions []string
	var args []any
	argIdx := 1
	if len(req.IDs) > 0 {
		conditions = append(conditions, fmt.Sprintf("id = ANY($%d)", argIdx))
		args = append(args, req.IDs)
		argIdx++
	}
	if req.Status != "" {
		conditions = append(conditions, fmt.Sprintf("status = $%d", argIdx))
		args = append(args, req.Status)
		argIdx++
	}
	if req.Group != "" {
		conditions = append(conditions, fmt.Sprintf(`"group" = $%d`, argIdx))
		args = append(args, req.Group)
		argIdx++
	}
	if req.Disabled != nil {
		conditions = append(conditions, fmt.Sprintf("disabled = $%d", argIdx))
		args = append(args, *req.Disabled)
	}
	if len(conditions) == 0 {
		return 0, model.ErrBulkDeleteEmpty
	}
	whereClause := strings.Join(conditions, " AND ")

	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return 0, fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Rolls back on every early return; harmless after a successful Commit.
	defer tx.Rollback(ctx)

	// Resolve the filter to concrete IDs first, so dependent rows can be
	// deleted by the same ID set.
	rows, err := tx.Query(ctx, "SELECT id FROM proxies WHERE "+whereClause, args...)
	if err != nil {
		return 0, fmt.Errorf("failed to select proxies: %w", err)
	}
	var ids []string
	for rows.Next() {
		var id string
		if err := rows.Scan(&id); err != nil {
			rows.Close()
			return 0, fmt.Errorf("failed to scan proxy id: %w", err)
		}
		ids = append(ids, id)
	}
	// Rows must be closed before the transaction runs another statement, and
	// pgx only guarantees Err is accurate once rows are closed. Without this
	// check, a mid-stream error would delete only a partial set of proxies.
	rows.Close()
	if err := rows.Err(); err != nil {
		return 0, fmt.Errorf("failed to select proxies: %w", err)
	}
	if len(ids) == 0 {
		return 0, nil
	}

	if _, err := tx.Exec(ctx, "DELETE FROM proxy_leases WHERE proxy_id = ANY($1)", ids); err != nil {
		return 0, fmt.Errorf("failed to delete leases: %w", err)
	}
	if _, err := tx.Exec(ctx, "DELETE FROM proxy_test_logs WHERE proxy_id = ANY($1)", ids); err != nil {
		return 0, fmt.Errorf("failed to delete test logs: %w", err)
	}
	result, err := tx.Exec(ctx, "DELETE FROM proxies WHERE id = ANY($1)", ids)
	if err != nil {
		return 0, fmt.Errorf("failed to delete proxies: %w", err)
	}

	if err := tx.Commit(ctx); err != nil {
		return 0, fmt.Errorf("failed to commit: %w", err)
	}
	return result.RowsAffected(), nil
}
// GetStats returns aggregate statistics over the proxies table:
//   - Total and Disabled: the row count and the number of disabled rows.
//   - AvgLatencyMs: the rounded average latency_ms of proxies with status
//     'alive' and a positive latency, or 0 when there are none.
//   - AvgScore: the average score, or 0 for an empty table.
//   - ByStatus, ByGroup, ByProtocol: row counts per distinct value.
//
// The four queries run separately on the pool, not in one transaction, so
// under concurrent writes the breakdowns may not add up exactly to Total.
func (s *PgStore) GetStats(ctx context.Context) (*model.ProxyStats, error) {
	stats := &model.ProxyStats{
		ByStatus:   make(map[model.ProxyStatus]int),
		ByGroup:    make(map[string]int),
		ByProtocol: make(map[model.ProxyProtocol]int),
	}

	const sqlBasic = `
SELECT
COUNT(*) AS total,
COUNT(*) FILTER (WHERE disabled = true) AS disabled,
COALESCE(ROUND(AVG(latency_ms) FILTER (WHERE status = 'alive' AND latency_ms > 0)), 0)::BIGINT AS avg_latency,
COALESCE(AVG(score), 0)::DOUBLE PRECISION AS avg_score
FROM proxies;
`
	if err := s.pool.QueryRow(ctx, sqlBasic).Scan(
		&stats.Total, &stats.Disabled, &stats.AvgLatencyMs, &stats.AvgScore,
	); err != nil {
		return nil, fmt.Errorf("failed to get basic stats: %w", err)
	}

	const sqlByStatus = `SELECT status, COUNT(*) FROM proxies GROUP BY status`
	if err := collectCounts(ctx, s.pool, sqlByStatus, stats.ByStatus); err != nil {
		return nil, fmt.Errorf("failed to get status stats: %w", err)
	}

	const sqlByGroup = `SELECT "group", COUNT(*) FROM proxies GROUP BY "group"`
	if err := collectCounts(ctx, s.pool, sqlByGroup, stats.ByGroup); err != nil {
		return nil, fmt.Errorf("failed to get group stats: %w", err)
	}

	const sqlByProtocol = `SELECT protocol, COUNT(*) FROM proxies GROUP BY protocol`
	if err := collectCounts(ctx, s.pool, sqlByProtocol, stats.ByProtocol); err != nil {
		return nil, fmt.Errorf("failed to get protocol stats: %w", err)
	}

	return stats, nil
}

// collectCounts runs a two-column (key, count) aggregate query and stores
// each row in dst. It always closes the rows, and it reports iteration errors
// via rows.Err once Next is exhausted, so a truncated result is never mistaken
// for a complete one.
func collectCounts[K comparable](ctx context.Context, pool *pgxpool.Pool, query string, dst map[K]int) error {
	rows, err := pool.Query(ctx, query)
	if err != nil {
		return err
	}
	defer rows.Close()

	for rows.Next() {
		var key K
		var count int
		if err := rows.Scan(&key, &count); err != nil {
			return err
		}
		dst[key] = count
	}
	return rows.Err()
}
// scanProxies reads every remaining row into a slice of proxies, using the
// column order shared by the SELECTs in this file. It does not call
// rows.Close itself; callers (List, ListPaginated) defer it. A result with no
// rows yields a nil slice.
func scanProxies(rows pgx.Rows) ([]model.Proxy, error) {
	var out []model.Proxy
	for rows.Next() {
		proxy, scanErr := scanProxyRow(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		out = append(out, *proxy)
	}

	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("error iterating rows: %w", iterErr)
	}
	return out, nil
}
// scanProxy decodes one proxies row, in the 17-column order used by the
// SELECTs in this file, into a model.Proxy.
//
// A NULL last_check_at leaves LastCheckAt at its zero value. Scan errors,
// including pgx.ErrNoRows, are returned unwrapped so callers can detect them.
func scanProxy(row pgx.Row) (*model.Proxy, error) {
	p := new(model.Proxy)
	var checkedAt *time.Time // nullable column, hence the pointer

	dest := []any{
		&p.ID, &p.Protocol, &p.Host, &p.Port, &p.Username, &p.Password,
		&p.Group, &p.Tags, &p.Status, &p.Score, &p.LatencyMs, &checkedAt,
		&p.FailCount, &p.SuccessCount, &p.Disabled, &p.CreatedAt, &p.UpdatedAt,
	}
	if err := row.Scan(dest...); err != nil {
		return nil, err
	}

	if checkedAt != nil {
		p.LastCheckAt = *checkedAt
	}
	return p, nil
}
// scanProxyRow decodes the current row of rows into a model.Proxy.
//
// pgx.Rows provides the same Scan(dest ...any) error method as pgx.Row, so
// this delegates to scanProxy. The column list and NULL handling are therefore
// defined in exactly one place instead of two copies that could drift apart.
func scanProxyRow(rows pgx.Rows) (*model.Proxy, error) {
	return scanProxy(rows)
}