package store import ( "context" "fmt" "strings" "time" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "proxyrotator/internal/model" ) // PgStore PostgreSQL 存储实现 type PgStore struct { pool *pgxpool.Pool } // NewPgStore 创建 PostgreSQL 存储 func NewPgStore(ctx context.Context, connString string) (*PgStore, error) { pool, err := pgxpool.New(ctx, connString) if err != nil { return nil, fmt.Errorf("failed to create connection pool: %w", err) } if err := pool.Ping(ctx); err != nil { pool.Close() return nil, fmt.Errorf("failed to ping database: %w", err) } return &PgStore{pool: pool}, nil } // Close 关闭连接池 func (s *PgStore) Close() error { s.pool.Close() return nil } // Pool 返回连接池(用于其他模块共享) func (s *PgStore) Pool() *pgxpool.Pool { return s.pool } // UpsertMany 批量导入代理 func (s *PgStore) UpsertMany(ctx context.Context, proxies []model.Proxy) (imported, duplicated int, err error) { if len(proxies) == 0 { return 0, 0, nil } const sqlUpsert = ` INSERT INTO proxies (id, protocol, host, port, username, password, "group", tags) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT (protocol, host, port, username) DO UPDATE SET password = EXCLUDED.password, "group" = EXCLUDED."group", tags = ( SELECT ARRAY( SELECT DISTINCT unnest(proxies.tags || EXCLUDED.tags) ) ) RETURNING (xmax = 0) AS inserted ` tx, err := s.pool.Begin(ctx) if err != nil { return 0, 0, fmt.Errorf("failed to begin transaction: %w", err) } defer tx.Rollback(ctx) for _, p := range proxies { var inserted bool err := tx.QueryRow(ctx, sqlUpsert, p.ID, p.Protocol, p.Host, p.Port, p.Username, p.Password, p.Group, p.Tags, ).Scan(&inserted) if err != nil { return 0, 0, fmt.Errorf("failed to upsert proxy: %w", err) } if inserted { imported++ } else { duplicated++ } } if err := tx.Commit(ctx); err != nil { return 0, 0, fmt.Errorf("failed to commit transaction: %w", err) } return imported, duplicated, nil } // List 查询代理列表 func (s *PgStore) List(ctx context.Context, q model.ProxyQuery) ([]model.Proxy, error) { var conditions []string var args []any argIdx := 1 if q.Group != "" { conditions = append(conditions, fmt.Sprintf(`"group" = $%d`, argIdx)) args = append(args, q.Group) argIdx++ } if q.OnlyEnabled { conditions = append(conditions, "disabled = false") } if len(q.StatusIn) > 0 { placeholders := make([]string, len(q.StatusIn)) for i, status := range q.StatusIn { placeholders[i] = fmt.Sprintf("$%d", argIdx) args = append(args, status) argIdx++ } conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ","))) } if len(q.TagsAny) > 0 { conditions = append(conditions, fmt.Sprintf("tags && $%d", argIdx)) args = append(args, q.TagsAny) argIdx++ } sql := `SELECT id, protocol, host, port, username, password, "group", tags, status, score, latency_ms, last_check_at, fail_count, success_count, disabled, created_at, updated_at FROM proxies` if len(conditions) > 0 { sql += " WHERE " + strings.Join(conditions, " AND ") } // 排序方式 switch q.OrderBy { case "random": sql += " ORDER BY RANDOM()" case "latency": sql += " ORDER BY latency_ms ASC NULLS LAST" default: sql += " ORDER BY score DESC, last_check_at DESC NULLS LAST" } if q.Limit > 0 { sql += fmt.Sprintf(" LIMIT %d", q.Limit) } rows, err := s.pool.Query(ctx, sql, args...) if err != nil { return nil, fmt.Errorf("failed to query proxies: %w", err) } defer rows.Close() return scanProxies(rows) } // GetByID 根据 ID 获取代理 func (s *PgStore) GetByID(ctx context.Context, id string) (*model.Proxy, error) { const sql = `SELECT id, protocol, host, port, username, password, "group", tags, status, score, latency_ms, last_check_at, fail_count, success_count, disabled, created_at, updated_at FROM proxies WHERE id = $1` row := s.pool.QueryRow(ctx, sql, id) p, err := scanProxy(row) if err != nil { if err == pgx.ErrNoRows { return nil, model.ErrProxyNotFound } return nil, fmt.Errorf("failed to get proxy: %w", err) } return p, nil } // UpdateHealth 更新代理健康度 func (s *PgStore) UpdateHealth(ctx context.Context, proxyID string, patch model.HealthPatch) error { var sets []string var args []any argIdx := 1 if patch.Status != nil { sets = append(sets, fmt.Sprintf("status = $%d", argIdx)) args = append(args, *patch.Status) argIdx++ } if patch.ScoreDelta != 0 { sets = append(sets, fmt.Sprintf("score = GREATEST(-1000, LEAST(1000, score + $%d))", argIdx)) args = append(args, patch.ScoreDelta) argIdx++ } if patch.LatencyMs != nil { sets = append(sets, fmt.Sprintf("latency_ms = $%d", argIdx)) args = append(args, *patch.LatencyMs) argIdx++ } if patch.CheckedAt != nil { sets = append(sets, fmt.Sprintf("last_check_at = $%d", argIdx)) args = append(args, *patch.CheckedAt) argIdx++ } if patch.FailInc > 0 { sets = append(sets, fmt.Sprintf("fail_count = fail_count + $%d", argIdx)) args = append(args, patch.FailInc) argIdx++ } if patch.SuccessInc > 0 { sets = append(sets, fmt.Sprintf("success_count = success_count + $%d", argIdx)) args = append(args, patch.SuccessInc) argIdx++ } if len(sets) == 0 { return nil } sql := fmt.Sprintf("UPDATE proxies SET %s WHERE id = $%d", strings.Join(sets, ", "), argIdx) args = append(args, proxyID) _, err := s.pool.Exec(ctx, sql, args...) if err != nil { return fmt.Errorf("failed to update health: %w", err) } return nil } // NextIndex RR 原子游标 func (s *PgStore) NextIndex(ctx context.Context, key string, modulo int) (int, error) { if modulo <= 0 { return 0, model.ErrBadModulo } const sql = ` INSERT INTO rr_cursors (k, v) VALUES ($1, 0) ON CONFLICT (k) DO UPDATE SET v = rr_cursors.v + 1, updated_at = now() RETURNING v ` var v int64 err := s.pool.QueryRow(ctx, sql, key).Scan(&v) if err != nil { return 0, fmt.Errorf("failed to get next index: %w", err) } idx := int(v % int64(modulo)) if idx < 0 { idx = -idx } return idx, nil } // CreateLease 创建租约 func (s *PgStore) CreateLease(ctx context.Context, lease model.Lease) error { const sql = ` INSERT INTO proxy_leases (lease_id, proxy_id, expire_at, site, "group") VALUES ($1, $2, $3, $4, $5) ` _, err := s.pool.Exec(ctx, sql, lease.LeaseID, lease.ProxyID, lease.ExpireAt, lease.Site, lease.Group, ) if err != nil { return fmt.Errorf("failed to create lease: %w", err) } return nil } // GetLease 获取租约 func (s *PgStore) GetLease(ctx context.Context, leaseID string) (*model.Lease, error) { const sql = ` SELECT lease_id, proxy_id, expire_at, site, "group" FROM proxy_leases WHERE lease_id = $1 AND expire_at > now() ` var lease model.Lease err := s.pool.QueryRow(ctx, sql, leaseID).Scan( &lease.LeaseID, &lease.ProxyID, &lease.ExpireAt, &lease.Site, &lease.Group, ) if err != nil { if err == pgx.ErrNoRows { return nil, model.ErrLeaseExpired } return nil, fmt.Errorf("failed to get lease: %w", err) } return &lease, nil } // DeleteExpiredLeases 删除过期租约 func (s *PgStore) DeleteExpiredLeases(ctx context.Context) (int64, error) { const sql = `DELETE FROM proxy_leases WHERE expire_at <= now()` result, err := s.pool.Exec(ctx, sql) if err != nil { return 0, fmt.Errorf("failed to delete expired leases: %w", err) } return result.RowsAffected(), nil } // InsertTestLog 插入测试日志 func (s *PgStore) InsertTestLog(ctx context.Context, r model.TestResult, site string) error { const sql = ` INSERT INTO proxy_test_logs (proxy_id, site, ok, latency_ms, error_text, checked_at) VALUES ($1, $2, $3, $4, $5, $6) ` _, err := s.pool.Exec(ctx, sql, r.ProxyID, site, r.OK, r.LatencyMs, r.ErrorText, r.CheckedAt, ) if err != nil { return fmt.Errorf("failed to insert test log: %w", err) } return nil } // ListPaginated 分页查询代理列表 func (s *PgStore) ListPaginated(ctx context.Context, q model.ProxyListQuery) ([]model.Proxy, int, error) { var conditions []string var args []any argIdx := 1 if q.Group != "" { conditions = append(conditions, fmt.Sprintf(`"group" = $%d`, argIdx)) args = append(args, q.Group) argIdx++ } if q.OnlyEnabled { conditions = append(conditions, "disabled = false") } if len(q.StatusIn) > 0 { placeholders := make([]string, len(q.StatusIn)) for i, status := range q.StatusIn { placeholders[i] = fmt.Sprintf("$%d", argIdx) args = append(args, status) argIdx++ } conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ","))) } if len(q.TagsAny) > 0 { conditions = append(conditions, fmt.Sprintf("tags && $%d", argIdx)) args = append(args, q.TagsAny) argIdx++ } whereClause := "" if len(conditions) > 0 { whereClause = " WHERE " + strings.Join(conditions, " AND ") } // 查询总数 countSQL := "SELECT COUNT(*) FROM proxies" + whereClause var total int if err := s.pool.QueryRow(ctx, countSQL, args...).Scan(&total); err != nil { return nil, 0, fmt.Errorf("failed to count proxies: %w", err) } // 查询数据 dataSQL := `SELECT id, protocol, host, port, username, password, "group", tags, status, score, latency_ms, last_check_at, fail_count, success_count, disabled, created_at, updated_at FROM proxies` + whereClause + " ORDER BY score DESC, last_check_at DESC NULLS LAST" dataSQL += fmt.Sprintf(" LIMIT $%d OFFSET $%d", argIdx, argIdx+1) args = append(args, q.Limit, q.Offset) rows, err := s.pool.Query(ctx, dataSQL, args...) if err != nil { return nil, 0, fmt.Errorf("failed to query proxies: %w", err) } defer rows.Close() proxies, err := scanProxies(rows) if err != nil { return nil, 0, err } return proxies, total, nil } // Update 更新代理字段 func (s *PgStore) Update(ctx context.Context, id string, patch model.ProxyPatch) error { var sets []string var args []any argIdx := 1 if patch.Group != nil { sets = append(sets, fmt.Sprintf(`"group" = $%d`, argIdx)) args = append(args, *patch.Group) argIdx++ } if patch.Tags != nil { sets = append(sets, fmt.Sprintf("tags = $%d", argIdx)) args = append(args, *patch.Tags) argIdx++ } if len(patch.AddTags) > 0 { sets = append(sets, fmt.Sprintf(`tags = ( SELECT ARRAY(SELECT DISTINCT unnest(tags || $%d)) )`, argIdx)) args = append(args, patch.AddTags) argIdx++ } if patch.Disabled != nil { sets = append(sets, fmt.Sprintf("disabled = $%d", argIdx)) args = append(args, *patch.Disabled) argIdx++ } if len(sets) == 0 { return model.ErrInvalidPatch } sql := fmt.Sprintf("UPDATE proxies SET %s WHERE id = $%d", strings.Join(sets, ", "), argIdx) args = append(args, id) result, err := s.pool.Exec(ctx, sql, args...) if err != nil { return fmt.Errorf("failed to update proxy: %w", err) } if result.RowsAffected() == 0 { return model.ErrProxyNotFound } return nil } // Delete 删除单个代理 func (s *PgStore) Delete(ctx context.Context, id string) error { tx, err := s.pool.Begin(ctx) if err != nil { return fmt.Errorf("failed to begin transaction: %w", err) } defer tx.Rollback(ctx) // 删除租约 if _, err := tx.Exec(ctx, "DELETE FROM proxy_leases WHERE proxy_id = $1", id); err != nil { return fmt.Errorf("failed to delete leases: %w", err) } // 删除测试日志 if _, err := tx.Exec(ctx, "DELETE FROM proxy_test_logs WHERE proxy_id = $1", id); err != nil { return fmt.Errorf("failed to delete test logs: %w", err) } // 删除代理 result, err := tx.Exec(ctx, "DELETE FROM proxies WHERE id = $1", id) if err != nil { return fmt.Errorf("failed to delete proxy: %w", err) } if result.RowsAffected() == 0 { return model.ErrProxyNotFound } return tx.Commit(ctx) } // DeleteMany 批量删除代理 func (s *PgStore) DeleteMany(ctx context.Context, req model.BulkDeleteRequest) (int64, error) { var conditions []string var args []any argIdx := 1 if len(req.IDs) > 0 { conditions = append(conditions, fmt.Sprintf("id = ANY($%d)", argIdx)) args = append(args, req.IDs) argIdx++ } if req.Status != "" { conditions = append(conditions, fmt.Sprintf("status = $%d", argIdx)) args = append(args, req.Status) argIdx++ } if req.Group != "" { conditions = append(conditions, fmt.Sprintf(`"group" = $%d`, argIdx)) args = append(args, req.Group) argIdx++ } if req.Disabled != nil { conditions = append(conditions, fmt.Sprintf("disabled = $%d", argIdx)) args = append(args, *req.Disabled) argIdx++ } if len(conditions) == 0 { return 0, model.ErrBulkDeleteEmpty } whereClause := strings.Join(conditions, " AND ") tx, err := s.pool.Begin(ctx) if err != nil { return 0, fmt.Errorf("failed to begin transaction: %w", err) } defer tx.Rollback(ctx) // 先获取要删除的代理 ID 列表 selectSQL := "SELECT id FROM proxies WHERE " + whereClause rows, err := tx.Query(ctx, selectSQL, args...) if err != nil { return 0, fmt.Errorf("failed to select proxies: %w", err) } var ids []string for rows.Next() { var id string if err := rows.Scan(&id); err != nil { rows.Close() return 0, err } ids = append(ids, id) } rows.Close() if len(ids) == 0 { return 0, nil } // 删除关联数据 if _, err := tx.Exec(ctx, "DELETE FROM proxy_leases WHERE proxy_id = ANY($1)", ids); err != nil { return 0, fmt.Errorf("failed to delete leases: %w", err) } if _, err := tx.Exec(ctx, "DELETE FROM proxy_test_logs WHERE proxy_id = ANY($1)", ids); err != nil { return 0, fmt.Errorf("failed to delete test logs: %w", err) } // 删除代理 result, err := tx.Exec(ctx, "DELETE FROM proxies WHERE id = ANY($1)", ids) if err != nil { return 0, fmt.Errorf("failed to delete proxies: %w", err) } if err := tx.Commit(ctx); err != nil { return 0, fmt.Errorf("failed to commit: %w", err) } return result.RowsAffected(), nil } // GetStats 获取代理统计信息 func (s *PgStore) GetStats(ctx context.Context) (*model.ProxyStats, error) { stats := &model.ProxyStats{ ByStatus: make(map[model.ProxyStatus]int), ByGroup: make(map[string]int), ByProtocol: make(map[model.ProxyProtocol]int), } // 总数和禁用数 const sqlBasic = ` SELECT COUNT(*) AS total, COUNT(*) FILTER (WHERE disabled = true) AS disabled, COALESCE(ROUND(AVG(latency_ms) FILTER (WHERE status = 'alive' AND latency_ms > 0)), 0)::BIGINT AS avg_latency, COALESCE(AVG(score), 0)::DOUBLE PRECISION AS avg_score FROM proxies; ` if err := s.pool.QueryRow(ctx, sqlBasic).Scan( &stats.Total, &stats.Disabled, &stats.AvgLatencyMs, &stats.AvgScore, ); err != nil { return nil, fmt.Errorf("failed to get basic stats: %w", err) } // 按状态统计 const sqlByStatus = `SELECT status, COUNT(*) FROM proxies GROUP BY status` rows, err := s.pool.Query(ctx, sqlByStatus) if err != nil { return nil, fmt.Errorf("failed to get status stats: %w", err) } for rows.Next() { var status model.ProxyStatus var count int if err := rows.Scan(&status, &count); err != nil { rows.Close() return nil, err } stats.ByStatus[status] = count } rows.Close() // 按分组统计 const sqlByGroup = `SELECT "group", COUNT(*) FROM proxies GROUP BY "group"` rows, err = s.pool.Query(ctx, sqlByGroup) if err != nil { return nil, fmt.Errorf("failed to get group stats: %w", err) } for rows.Next() { var group string var count int if err := rows.Scan(&group, &count); err != nil { rows.Close() return nil, err } stats.ByGroup[group] = count } rows.Close() // 按协议统计 const sqlByProtocol = `SELECT protocol, COUNT(*) FROM proxies GROUP BY protocol` rows, err = s.pool.Query(ctx, sqlByProtocol) if err != nil { return nil, fmt.Errorf("failed to get protocol stats: %w", err) } for rows.Next() { var protocol model.ProxyProtocol var count int if err := rows.Scan(&protocol, &count); err != nil { rows.Close() return nil, err } stats.ByProtocol[protocol] = count } rows.Close() return stats, nil } // scanProxies 扫描多条代理记录 func scanProxies(rows pgx.Rows) ([]model.Proxy, error) { var proxies []model.Proxy for rows.Next() { p, err := scanProxyRow(rows) if err != nil { return nil, err } proxies = append(proxies, *p) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("error iterating rows: %w", err) } return proxies, nil } // scanProxy 扫描单条代理记录 func scanProxy(row pgx.Row) (*model.Proxy, error) { var p model.Proxy var lastCheckAt *time.Time err := row.Scan( &p.ID, &p.Protocol, &p.Host, &p.Port, &p.Username, &p.Password, &p.Group, &p.Tags, &p.Status, &p.Score, &p.LatencyMs, &lastCheckAt, &p.FailCount, &p.SuccessCount, &p.Disabled, &p.CreatedAt, &p.UpdatedAt, ) if err != nil { return nil, err } if lastCheckAt != nil { p.LastCheckAt = *lastCheckAt } return &p, nil } // scanProxyRow 从 Rows 扫描单条代理记录 func scanProxyRow(rows pgx.Rows) (*model.Proxy, error) { var p model.Proxy var lastCheckAt *time.Time err := rows.Scan( &p.ID, &p.Protocol, &p.Host, &p.Port, &p.Username, &p.Password, &p.Group, &p.Tags, &p.Status, &p.Score, &p.LatencyMs, &lastCheckAt, &p.FailCount, &p.SuccessCount, &p.Disabled, &p.CreatedAt, &p.UpdatedAt, ) if err != nil { return nil, err } if lastCheckAt != nil { p.LastCheckAt = *lastCheckAt } return &p, nil }