package api import ( "encoding/json" "fmt" "net/http" "strconv" "sync" "sync/atomic" "time" "codex-pool/internal/config" "codex-pool/internal/database" "codex-pool/internal/invite" "codex-pool/internal/logger" ) // 封禁检查服务状态 var ( banCheckRunning bool banCheckStopChan chan struct{} banCheckMu sync.Mutex lastBanCheckTime time.Time banCheckTaskState BanCheckTaskState ) // BanCheckTaskState 封禁检查任务状态 type BanCheckTaskState struct { Running bool `json:"running"` StartedAt time.Time `json:"started_at"` Total int32 `json:"total"` Checked int32 `json:"checked"` Banned int32 `json:"banned"` Valid int32 `json:"valid"` Failed int32 `json:"failed"` } // BanCheckResult 单个检查结果 type BanCheckResult struct { Email string `json:"email"` Status string `json:"status"` // valid, banned, error Message string `json:"message,omitempty"` } // StartBanCheckService 启动定期封禁检查服务 func StartBanCheckService() { banCheckMu.Lock() if banCheckRunning { banCheckMu.Unlock() return } banCheckRunning = true banCheckStopChan = make(chan struct{}) banCheckMu.Unlock() logger.Info("母号封禁检查服务已启动", "", "ban-check") go func() { // 默认检查间隔 30 分钟 checkInterval := 1800 for { // 动态读取检查间隔配置 if database.Instance != nil { if val, _ := database.Instance.GetConfig("ban_check_interval"); val != "" { if v, err := strconv.Atoi(val); err == nil && v >= 60 { checkInterval = v } } } select { case <-banCheckStopChan: logger.Info("母号封禁检查服务已停止", "", "ban-check") return case <-time.After(time.Duration(checkInterval) * time.Second): runScheduledBanCheck() } } }() } // StopBanCheckService 停止定期封禁检查服务 func StopBanCheckService() { banCheckMu.Lock() defer banCheckMu.Unlock() if banCheckRunning && banCheckStopChan != nil { close(banCheckStopChan) banCheckRunning = false } } // runScheduledBanCheck 执行定期封禁检查 func runScheduledBanCheck() { if database.Instance == nil { return } // 检查是否开启定期检查 enabled := false if val, _ := database.Instance.GetConfig("ban_check_enabled"); val == "true" { enabled = true } if !enabled { return } // 检查是否有任务在运行 if banCheckTaskState.Running || teamProcessState.Running { logger.Info("有其他任务在运行,跳过定期封禁检查", "", "ban-check") return } // 获取检查间隔(小时) checkIntervalHours := 24 if val, _ := database.Instance.GetConfig("ban_check_hours"); val != "" { if v, err := strconv.Atoi(val); err == nil && v > 0 { checkIntervalHours = v } } // 获取需要检查的母号 owners, err := database.Instance.GetOwnersForBanCheck(checkIntervalHours) if err != nil { logger.Error(fmt.Sprintf("获取待检查母号失败: %v", err), "", "ban-check") return } if len(owners) == 0 { logger.Info("没有需要检查的母号", "", "ban-check") return } logger.Info(fmt.Sprintf("定期封禁检查: 发现 %d 个需要检查的母号", len(owners)), "", "ban-check") // 执行检查(并发数 20) go runBanCheckTask(owners, 20) } // HandleBanCheckSettings 获取/设置封禁检查配置 func HandleBanCheckSettings(w http.ResponseWriter, r *http.Request) { if database.Instance == nil { Error(w, http.StatusInternalServerError, "数据库未初始化") return } switch r.Method { case http.MethodGet: settings := map[string]interface{}{ "enabled": false, "interval": 1800, // 检查服务间隔(秒) "check_hours": 24, // 多少小时后重新检查 "last_check": lastBanCheckTime, "task_state": banCheckTaskState, "service_running": banCheckRunning, } if val, _ := database.Instance.GetConfig("ban_check_enabled"); val == "true" { settings["enabled"] = true } if val, _ := database.Instance.GetConfig("ban_check_interval"); val != "" { if v, err := strconv.Atoi(val); err == nil { settings["interval"] = v } } if val, _ := database.Instance.GetConfig("ban_check_hours"); val != "" { if v, err := strconv.Atoi(val); err == nil { settings["check_hours"] = v } } Success(w, settings) case http.MethodPut: var req struct { Enabled *bool `json:"enabled"` Interval *int `json:"interval"` CheckHours *int `json:"check_hours"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { Error(w, http.StatusBadRequest, "请求格式错误") return } if req.Enabled != nil { database.Instance.SetConfig("ban_check_enabled", strconv.FormatBool(*req.Enabled)) } if req.Interval != nil && *req.Interval >= 60 { database.Instance.SetConfig("ban_check_interval", strconv.Itoa(*req.Interval)) } if req.CheckHours != nil && *req.CheckHours > 0 { database.Instance.SetConfig("ban_check_hours", strconv.Itoa(*req.CheckHours)) } logger.Success("封禁检查配置已更新", "", "ban-check") Success(w, map[string]string{"message": "配置已更新"}) default: Error(w, http.StatusMethodNotAllowed, "不支持的方法") } } // HandleManualBanCheck 手动触发封禁检查 func HandleManualBanCheck(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { Error(w, http.StatusMethodNotAllowed, "不支持的方法") return } if database.Instance == nil { Error(w, http.StatusInternalServerError, "数据库未初始化") return } // 检查是否有任务在运行 if banCheckTaskState.Running { Error(w, http.StatusConflict, "封禁检查任务正在运行中") return } if teamProcessState.Running { Error(w, http.StatusConflict, "Team 处理任务正在运行中,请稍后再试") return } var req struct { IDs []int64 `json:"ids"` // 指定检查的母号 ID,为空则检查所有有效母号 Concurrency int `json:"concurrency"` // 并发数 ForceCheck bool `json:"force_check"` // 是否强制检查(忽略上次检查时间) } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { // 允许空 body req.Concurrency = 20 } if req.Concurrency <= 0 { req.Concurrency = 20 } var owners []database.TeamOwner var err error if len(req.IDs) > 0 { // 检查指定的母号 for _, id := range req.IDs { ownerList, _, err := database.Instance.GetTeamOwners("", 1000, 0) if err != nil { continue } for _, o := range ownerList { if o.ID == id && o.Status == "valid" { owners = append(owners, o) break } } } } else if req.ForceCheck { // 强制检查所有有效母号 owners, err = database.Instance.GetPendingOwners() } else { // 检查需要检查的母号 checkIntervalHours := 24 if val, _ := database.Instance.GetConfig("ban_check_hours"); val != "" { if v, err := strconv.Atoi(val); err == nil && v > 0 { checkIntervalHours = v } } owners, err = database.Instance.GetOwnersForBanCheck(checkIntervalHours) } if err != nil { Error(w, http.StatusInternalServerError, fmt.Sprintf("获取母号失败: %v", err)) return } if len(owners) == 0 { Success(w, map[string]interface{}{ "message": "没有需要检查的母号", "total": 0, }) return } // 异步执行检查 go runBanCheckTask(owners, req.Concurrency) Success(w, map[string]interface{}{ "message": "封禁检查任务已启动", "total": len(owners), "concurrency": req.Concurrency, }) } // HandleBanCheckStatus 获取封禁检查任务状态 func HandleBanCheckStatus(w http.ResponseWriter, r *http.Request) { Success(w, banCheckTaskState) } // runBanCheckTask 执行封禁检查任务 func runBanCheckTask(owners []database.TeamOwner, concurrency int) { banCheckTaskState = BanCheckTaskState{ Running: true, StartedAt: time.Now(), Total: int32(len(owners)), } defer func() { banCheckTaskState.Running = false lastBanCheckTime = time.Now() }() logger.Info(fmt.Sprintf("开始封禁检查: 共 %d 个母号, 并发数: %d", len(owners), concurrency), "", "ban-check") // 任务队列 taskChan := make(chan database.TeamOwner, len(owners)) var wg sync.WaitGroup // 获取代理配置 proxy := "" if config.Global != nil { proxy = config.Global.GetProxy() } // 启动 worker for w := 0; w < concurrency; w++ { wg.Add(1) go func() { defer wg.Done() for owner := range taskChan { result := checkSingleOwnerBan(owner, proxy) // 更新计数 atomic.AddInt32(&banCheckTaskState.Checked, 1) switch result.Status { case "valid": atomic.AddInt32(&banCheckTaskState.Valid, 1) case "banned": atomic.AddInt32(&banCheckTaskState.Banned, 1) case "error": atomic.AddInt32(&banCheckTaskState.Failed, 1) } } }() } // 发送任务 for _, owner := range owners { taskChan <- owner } close(taskChan) // 等待完成 wg.Wait() logger.Success(fmt.Sprintf("封禁检查完成: 共 %d, 有效 %d, 封禁 %d, 失败 %d", banCheckTaskState.Total, banCheckTaskState.Valid, banCheckTaskState.Banned, banCheckTaskState.Failed), "", "ban-check") } // checkSingleOwnerBan 检查单个母号的封禁状态 // 使用 accounts/check API 直接检测,不发送邀请 func checkSingleOwnerBan(owner database.TeamOwner, proxy string) BanCheckResult { result := BanCheckResult{ Email: owner.Email, Status: "error", } // 创建检查器 var checker *invite.TeamInviter if proxy != "" { checker = invite.NewWithProxy(owner.Token, proxy) } else { checker = invite.New(owner.Token) } // 调用 accounts/check API 检测状态 accountStatus := checker.CheckAccountStatus() // 更新最后检查时间 database.Instance.UpdateOwnerLastCheckedAtByEmail(owner.Email) switch accountStatus.Status { case "active": // 账户正常 logger.Info(fmt.Sprintf("母号有效: %s (plan: %s)", owner.Email, accountStatus.PlanType), owner.Email, "ban-check") result.Status = "valid" result.Message = fmt.Sprintf("母号状态正常 (plan: %s)", accountStatus.PlanType) // 如果获取到了 account_id 且数据库中没有,则更新 if accountStatus.AccountID != "" && owner.AccountID == "" { database.Instance.UpdateOwnerAccountID(owner.ID, accountStatus.AccountID) } case "banned": // 账户被封禁 logger.Warning(fmt.Sprintf("母号被封禁: %s - %s", owner.Email, accountStatus.Error), owner.Email, "ban-check") database.Instance.MarkOwnerAsInvalid(owner.Email) result.Status = "banned" result.Message = accountStatus.Error case "token_expired": // Token 过期 logger.Warning(fmt.Sprintf("母号 Token 过期: %s", owner.Email), owner.Email, "ban-check") database.Instance.MarkOwnerAsInvalid(owner.Email) result.Status = "banned" result.Message = "Token 已过期" default: // 其他错误 logger.Error(fmt.Sprintf("检查失败: %s - %s", owner.Email, accountStatus.Error), owner.Email, "ban-check") result.Message = accountStatus.Error } return result }