419 lines
11 KiB
Go
419 lines
11 KiB
Go
package api
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"net/http"
|
||
"strconv"
|
||
"sync"
|
||
"sync/atomic"
|
||
"time"
|
||
|
||
"codex-pool/internal/config"
|
||
"codex-pool/internal/database"
|
||
"codex-pool/internal/invite"
|
||
"codex-pool/internal/logger"
|
||
)
|
||
|
||
// Package-level state of the ban-check service.
var (
	// banCheckRunning reports whether the periodic service has been started
	// and not yet stopped. StartBanCheckService and StopBanCheckService
	// write it while holding banCheckMu.
	banCheckRunning bool
	// banCheckStopChan is created by StartBanCheckService and closed by
	// StopBanCheckService to tell the background loop to exit.
	banCheckStopChan chan struct{}
	// banCheckMu guards banCheckRunning and banCheckStopChan in the
	// start/stop functions.
	banCheckMu sync.Mutex
	// lastBanCheckTime is stamped when runBanCheckTask returns.
	lastBanCheckTime time.Time
	// banCheckTaskState holds the progress of the current (or most recent)
	// check task. runBanCheckTask replaces it when a task begins and its
	// workers bump the counters with atomic adds.
	banCheckTaskState BanCheckTaskState
)
|
||
|
||
// BanCheckTaskState describes the progress of a ban-check task. It is
// serialized to JSON as-is by the status and settings handlers.
type BanCheckTaskState struct {
	// Running is true while runBanCheckTask is executing.
	Running bool `json:"running"`
	// StartedAt is the time the task began.
	StartedAt time.Time `json:"started_at"`
	// Total is the number of owners handed to the task.
	Total int32 `json:"total"`
	// Checked counts owners that have been processed so far.
	Checked int32 `json:"checked"`
	// Banned counts results with status "banned".
	Banned int32 `json:"banned"`
	// Valid counts results with status "valid".
	Valid int32 `json:"valid"`
	// Failed counts results with status "error".
	Failed int32 `json:"failed"`
}
|
||
|
||
// BanCheckResult is the outcome of checking a single owner account.
type BanCheckResult struct {
	// Email identifies the owner that was checked.
	Email string `json:"email"`
	// Status is one of "valid", "banned" or "error". An expired token is
	// also reported as "banned" by checkSingleOwnerBan.
	Status string `json:"status"`
	// Message carries optional detail such as the plan type or the error text.
	Message string `json:"message,omitempty"`
}
|
||
|
||
// StartBanCheckService 启动定期封禁检查服务
|
||
func StartBanCheckService() {
|
||
banCheckMu.Lock()
|
||
if banCheckRunning {
|
||
banCheckMu.Unlock()
|
||
return
|
||
}
|
||
banCheckRunning = true
|
||
banCheckStopChan = make(chan struct{})
|
||
banCheckMu.Unlock()
|
||
|
||
logger.Info("母号封禁检查服务已启动", "", "ban-check")
|
||
|
||
go func() {
|
||
// 默认检查间隔 30 分钟
|
||
checkInterval := 1800
|
||
|
||
for {
|
||
// 动态读取检查间隔配置
|
||
if database.Instance != nil {
|
||
if val, _ := database.Instance.GetConfig("ban_check_interval"); val != "" {
|
||
if v, err := strconv.Atoi(val); err == nil && v >= 60 {
|
||
checkInterval = v
|
||
}
|
||
}
|
||
}
|
||
|
||
select {
|
||
case <-banCheckStopChan:
|
||
logger.Info("母号封禁检查服务已停止", "", "ban-check")
|
||
return
|
||
case <-time.After(time.Duration(checkInterval) * time.Second):
|
||
runScheduledBanCheck()
|
||
}
|
||
}
|
||
}()
|
||
}
|
||
|
||
// StopBanCheckService 停止定期封禁检查服务
|
||
func StopBanCheckService() {
|
||
banCheckMu.Lock()
|
||
defer banCheckMu.Unlock()
|
||
|
||
if banCheckRunning && banCheckStopChan != nil {
|
||
close(banCheckStopChan)
|
||
banCheckRunning = false
|
||
}
|
||
}
|
||
|
||
// runScheduledBanCheck 执行定期封禁检查
|
||
func runScheduledBanCheck() {
|
||
if database.Instance == nil {
|
||
return
|
||
}
|
||
|
||
// 检查是否开启定期检查
|
||
enabled := false
|
||
if val, _ := database.Instance.GetConfig("ban_check_enabled"); val == "true" {
|
||
enabled = true
|
||
}
|
||
if !enabled {
|
||
return
|
||
}
|
||
|
||
// 检查是否有任务在运行
|
||
if banCheckTaskState.Running || teamProcessState.Running {
|
||
logger.Info("有其他任务在运行,跳过定期封禁检查", "", "ban-check")
|
||
return
|
||
}
|
||
|
||
// 获取检查间隔(小时)
|
||
checkIntervalHours := 24
|
||
if val, _ := database.Instance.GetConfig("ban_check_hours"); val != "" {
|
||
if v, err := strconv.Atoi(val); err == nil && v > 0 {
|
||
checkIntervalHours = v
|
||
}
|
||
}
|
||
|
||
// 获取需要检查的母号
|
||
owners, err := database.Instance.GetOwnersForBanCheck(checkIntervalHours)
|
||
if err != nil {
|
||
logger.Error(fmt.Sprintf("获取待检查母号失败: %v", err), "", "ban-check")
|
||
return
|
||
}
|
||
|
||
if len(owners) == 0 {
|
||
logger.Info("没有需要检查的母号", "", "ban-check")
|
||
return
|
||
}
|
||
|
||
logger.Info(fmt.Sprintf("定期封禁检查: 发现 %d 个需要检查的母号", len(owners)), "", "ban-check")
|
||
|
||
// 执行检查(并发数 20)
|
||
go runBanCheckTask(owners, 20)
|
||
}
|
||
|
||
// HandleBanCheckSettings 获取/设置封禁检查配置
|
||
func HandleBanCheckSettings(w http.ResponseWriter, r *http.Request) {
|
||
if database.Instance == nil {
|
||
Error(w, http.StatusInternalServerError, "数据库未初始化")
|
||
return
|
||
}
|
||
|
||
switch r.Method {
|
||
case http.MethodGet:
|
||
settings := map[string]interface{}{
|
||
"enabled": false,
|
||
"interval": 1800, // 检查服务间隔(秒)
|
||
"check_hours": 24, // 多少小时后重新检查
|
||
"last_check": lastBanCheckTime,
|
||
"task_state": banCheckTaskState,
|
||
"service_running": banCheckRunning,
|
||
}
|
||
|
||
if val, _ := database.Instance.GetConfig("ban_check_enabled"); val == "true" {
|
||
settings["enabled"] = true
|
||
}
|
||
if val, _ := database.Instance.GetConfig("ban_check_interval"); val != "" {
|
||
if v, err := strconv.Atoi(val); err == nil {
|
||
settings["interval"] = v
|
||
}
|
||
}
|
||
if val, _ := database.Instance.GetConfig("ban_check_hours"); val != "" {
|
||
if v, err := strconv.Atoi(val); err == nil {
|
||
settings["check_hours"] = v
|
||
}
|
||
}
|
||
|
||
Success(w, settings)
|
||
|
||
case http.MethodPut:
|
||
var req struct {
|
||
Enabled *bool `json:"enabled"`
|
||
Interval *int `json:"interval"`
|
||
CheckHours *int `json:"check_hours"`
|
||
}
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
Error(w, http.StatusBadRequest, "请求格式错误")
|
||
return
|
||
}
|
||
|
||
if req.Enabled != nil {
|
||
database.Instance.SetConfig("ban_check_enabled", strconv.FormatBool(*req.Enabled))
|
||
}
|
||
if req.Interval != nil && *req.Interval >= 60 {
|
||
database.Instance.SetConfig("ban_check_interval", strconv.Itoa(*req.Interval))
|
||
}
|
||
if req.CheckHours != nil && *req.CheckHours > 0 {
|
||
database.Instance.SetConfig("ban_check_hours", strconv.Itoa(*req.CheckHours))
|
||
}
|
||
|
||
logger.Success("封禁检查配置已更新", "", "ban-check")
|
||
Success(w, map[string]string{"message": "配置已更新"})
|
||
|
||
default:
|
||
Error(w, http.StatusMethodNotAllowed, "不支持的方法")
|
||
}
|
||
}
|
||
|
||
// HandleManualBanCheck 手动触发封禁检查
|
||
func HandleManualBanCheck(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
Error(w, http.StatusMethodNotAllowed, "不支持的方法")
|
||
return
|
||
}
|
||
|
||
if database.Instance == nil {
|
||
Error(w, http.StatusInternalServerError, "数据库未初始化")
|
||
return
|
||
}
|
||
|
||
// 检查是否有任务在运行
|
||
if banCheckTaskState.Running {
|
||
Error(w, http.StatusConflict, "封禁检查任务正在运行中")
|
||
return
|
||
}
|
||
|
||
if teamProcessState.Running {
|
||
Error(w, http.StatusConflict, "Team 处理任务正在运行中,请稍后再试")
|
||
return
|
||
}
|
||
|
||
var req struct {
|
||
IDs []int64 `json:"ids"` // 指定检查的母号 ID,为空则检查所有有效母号
|
||
Concurrency int `json:"concurrency"` // 并发数
|
||
ForceCheck bool `json:"force_check"` // 是否强制检查(忽略上次检查时间)
|
||
}
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
// 允许空 body
|
||
req.Concurrency = 20
|
||
}
|
||
|
||
if req.Concurrency <= 0 {
|
||
req.Concurrency = 20
|
||
}
|
||
|
||
var owners []database.TeamOwner
|
||
var err error
|
||
|
||
if len(req.IDs) > 0 {
|
||
// 检查指定的母号
|
||
for _, id := range req.IDs {
|
||
ownerList, _, err := database.Instance.GetTeamOwners("", 1000, 0)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
for _, o := range ownerList {
|
||
if o.ID == id && o.Status == "valid" {
|
||
owners = append(owners, o)
|
||
break
|
||
}
|
||
}
|
||
}
|
||
} else if req.ForceCheck {
|
||
// 强制检查所有有效母号
|
||
owners, err = database.Instance.GetPendingOwners()
|
||
} else {
|
||
// 检查需要检查的母号
|
||
checkIntervalHours := 24
|
||
if val, _ := database.Instance.GetConfig("ban_check_hours"); val != "" {
|
||
if v, err := strconv.Atoi(val); err == nil && v > 0 {
|
||
checkIntervalHours = v
|
||
}
|
||
}
|
||
owners, err = database.Instance.GetOwnersForBanCheck(checkIntervalHours)
|
||
}
|
||
|
||
if err != nil {
|
||
Error(w, http.StatusInternalServerError, fmt.Sprintf("获取母号失败: %v", err))
|
||
return
|
||
}
|
||
|
||
if len(owners) == 0 {
|
||
Success(w, map[string]interface{}{
|
||
"message": "没有需要检查的母号",
|
||
"total": 0,
|
||
})
|
||
return
|
||
}
|
||
|
||
// 异步执行检查
|
||
go runBanCheckTask(owners, req.Concurrency)
|
||
|
||
Success(w, map[string]interface{}{
|
||
"message": "封禁检查任务已启动",
|
||
"total": len(owners),
|
||
"concurrency": req.Concurrency,
|
||
})
|
||
}
|
||
|
||
// HandleBanCheckStatus 获取封禁检查任务状态
|
||
func HandleBanCheckStatus(w http.ResponseWriter, r *http.Request) {
|
||
Success(w, banCheckTaskState)
|
||
}
|
||
|
||
// runBanCheckTask 执行封禁检查任务
|
||
func runBanCheckTask(owners []database.TeamOwner, concurrency int) {
|
||
banCheckTaskState = BanCheckTaskState{
|
||
Running: true,
|
||
StartedAt: time.Now(),
|
||
Total: int32(len(owners)),
|
||
}
|
||
defer func() {
|
||
banCheckTaskState.Running = false
|
||
lastBanCheckTime = time.Now()
|
||
}()
|
||
|
||
logger.Status(fmt.Sprintf("封禁检查中: 共 %d 个母号, 并发数: %d", len(owners), concurrency), "", "ban-check")
|
||
|
||
// 任务队列
|
||
taskChan := make(chan database.TeamOwner, len(owners))
|
||
var wg sync.WaitGroup
|
||
|
||
// 获取代理配置
|
||
proxy := ""
|
||
if config.Global != nil {
|
||
proxy = config.Global.GetProxy()
|
||
}
|
||
|
||
// 启动 worker
|
||
for w := 0; w < concurrency; w++ {
|
||
wg.Add(1)
|
||
go func() {
|
||
defer wg.Done()
|
||
for owner := range taskChan {
|
||
result := checkSingleOwnerBan(owner, proxy)
|
||
|
||
// 更新计数
|
||
atomic.AddInt32(&banCheckTaskState.Checked, 1)
|
||
switch result.Status {
|
||
case "valid":
|
||
atomic.AddInt32(&banCheckTaskState.Valid, 1)
|
||
case "banned":
|
||
atomic.AddInt32(&banCheckTaskState.Banned, 1)
|
||
case "error":
|
||
atomic.AddInt32(&banCheckTaskState.Failed, 1)
|
||
}
|
||
}
|
||
}()
|
||
}
|
||
|
||
// 发送任务
|
||
for _, owner := range owners {
|
||
taskChan <- owner
|
||
}
|
||
close(taskChan)
|
||
|
||
// 等待完成
|
||
wg.Wait()
|
||
|
||
logger.Success(fmt.Sprintf("封禁检查完成: 共 %d, 有效 %d, 封禁 %d, 失败 %d",
|
||
banCheckTaskState.Total, banCheckTaskState.Valid, banCheckTaskState.Banned, banCheckTaskState.Failed), "", "ban-check")
|
||
}
|
||
|
||
// checkSingleOwnerBan 检查单个母号的封禁状态
|
||
// 使用 accounts/check API 直接检测,不发送邀请
|
||
func checkSingleOwnerBan(owner database.TeamOwner, proxy string) BanCheckResult {
|
||
result := BanCheckResult{
|
||
Email: owner.Email,
|
||
Status: "error",
|
||
}
|
||
|
||
// 创建检查器
|
||
var checker *invite.TeamInviter
|
||
if proxy != "" {
|
||
checker = invite.NewWithProxy(owner.Token, proxy)
|
||
} else {
|
||
checker = invite.New(owner.Token)
|
||
}
|
||
|
||
// 调用 accounts/check API 检测状态
|
||
accountStatus := checker.CheckAccountStatus()
|
||
|
||
// 更新最后检查时间
|
||
database.Instance.UpdateOwnerLastCheckedAtByEmail(owner.Email)
|
||
|
||
switch accountStatus.Status {
|
||
case "active":
|
||
// 账户正常
|
||
logger.Info(fmt.Sprintf("母号有效: %s (plan: %s)", owner.Email, accountStatus.PlanType), owner.Email, "ban-check")
|
||
result.Status = "valid"
|
||
result.Message = fmt.Sprintf("母号状态正常 (plan: %s)", accountStatus.PlanType)
|
||
|
||
// 如果获取到了 account_id 且数据库中没有,则更新
|
||
if accountStatus.AccountID != "" && owner.AccountID == "" {
|
||
database.Instance.UpdateOwnerAccountID(owner.ID, accountStatus.AccountID)
|
||
}
|
||
|
||
case "banned":
|
||
// 账户被封禁
|
||
logger.Warning(fmt.Sprintf("母号被封禁: %s - %s", owner.Email, accountStatus.Error), owner.Email, "ban-check")
|
||
database.Instance.MarkOwnerAsInvalid(owner.Email)
|
||
database.Instance.DeleteTeamOwnerByEmail(owner.Email)
|
||
logger.Info(fmt.Sprintf("母号被封禁已删除: %s", owner.Email), owner.Email, "ban-check")
|
||
result.Status = "banned"
|
||
result.Message = accountStatus.Error
|
||
|
||
case "token_expired":
|
||
// Token 过期
|
||
logger.Warning(fmt.Sprintf("母号 Token 过期: %s", owner.Email), owner.Email, "ban-check")
|
||
database.Instance.MarkOwnerAsInvalid(owner.Email)
|
||
database.Instance.DeleteTeamOwnerByEmail(owner.Email)
|
||
logger.Info(fmt.Sprintf("母号Token过期已删除: %s", owner.Email), owner.Email, "ban-check")
|
||
result.Status = "banned"
|
||
result.Message = "Token 已过期"
|
||
|
||
default:
|
||
// 其他错误
|
||
logger.Error(fmt.Sprintf("检查失败: %s - %s", owner.Email, accountStatus.Error), owner.Email, "ban-check")
|
||
result.Message = accountStatus.Error
|
||
}
|
||
|
||
return result
|
||
}
|