package api import ( "bytes" "encoding/json" "fmt" "io" "net/http" "sync" "sync/atomic" "time" "codex-pool/internal/config" "codex-pool/internal/logger" ) // ============================================================================ // 批量 RT 导入模块 // 功能: 读取 Refresh Token 列表,通过 S2A API 验证并创建账号 // ============================================================================ // rtImportRequest 导入请求 type rtImportRequest struct { Tokens []string `json:"tokens"` Prefix string `json:"prefix"` // "team" 或 "free" Concurrency int `json:"concurrency"` // S2A 账号并发数 Priority int `json:"priority"` // S2A 账号优先级 GroupIDs []int `json:"group_ids"` // S2A 分组ID ProxyID *int `json:"proxy_id"` // S2A 代理ID RateMultiplier float64 `json:"rate_multiplier"` // 计费倍率 } // rtImportResult 单条导入结果 type rtImportResult struct { Index int `json:"index"` RT string `json:"rt"` // 脱敏后的 RT 前缀 Email string `json:"email"` // 验证得到的邮箱 Success bool `json:"success"` Error string `json:"error,omitempty"` AcctID int `json:"account_id,omitempty"` // S2A 创建的账号ID } // rtImportState 导入任务状态 type rtImportState struct { Running bool `json:"running"` StartedAt time.Time `json:"started_at"` Total int `json:"total"` Completed int32 `json:"completed"` Success int32 `json:"success"` Failed int32 `json:"failed"` Results []rtImportResult `json:"results"` mu sync.Mutex stopCh chan struct{} } var rtImportTaskState = &rtImportState{} // isRTImportStopped 检查导入任务是否已被停止 func isRTImportStopped() bool { select { case <-rtImportTaskState.stopCh: return true default: return false } } // maskRT 脱敏处理 RT,只显示前16个字符 func maskRT(rt string) string { if len(rt) <= 16 { return rt } return rt[:16] + "..." } // HandleRTImportStart POST /api/rt-import/start - 启动批量 RT 导入 func HandleRTImportStart(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { Error(w, http.StatusMethodNotAllowed, "仅支持 POST") return } // 检查是否正在运行 if rtImportTaskState.Running { Error(w, http.StatusConflict, "已有导入任务正在运行") return } // 检查 S2A 配置 if config.Global == nil || config.Global.S2AApiBase == "" || config.Global.S2AAdminKey == "" { Error(w, http.StatusBadRequest, "S2A 配置未设置,请先配置 S2A API 地址和 Admin Key") return } var req rtImportRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { Error(w, http.StatusBadRequest, fmt.Sprintf("请求格式错误: %v", err)) return } // 参数校验 if len(req.Tokens) == 0 { Error(w, http.StatusBadRequest, "没有提供 Refresh Token") return } if req.Prefix != "team" && req.Prefix != "free" { Error(w, http.StatusBadRequest, "前缀必须是 team 或 free") return } // 默认值 if req.Concurrency <= 0 { if config.Global != nil && config.Global.Concurrency > 0 { req.Concurrency = config.Global.Concurrency } else { req.Concurrency = 3 } } if req.Priority <= 0 { if config.Global != nil && config.Global.Priority > 0 { req.Priority = config.Global.Priority } else { req.Priority = 10 } } if req.RateMultiplier <= 0 { req.RateMultiplier = 1.0 } if len(req.GroupIDs) == 0 && config.Global != nil && len(config.Global.GroupIDs) > 0 { req.GroupIDs = config.Global.GroupIDs } // 初始化状态 rtImportTaskState.Running = true rtImportTaskState.stopCh = make(chan struct{}) rtImportTaskState.StartedAt = time.Now() rtImportTaskState.Total = len(req.Tokens) rtImportTaskState.Completed = 0 rtImportTaskState.Success = 0 rtImportTaskState.Failed = 0 rtImportTaskState.Results = make([]rtImportResult, 0, len(req.Tokens)) // 异步执行 go runRTImport(req) logger.Info(fmt.Sprintf("RT 导入任务已启动: 共 %d 个 Token, 前缀: %s", len(req.Tokens), req.Prefix), "", "rt-import") Success(w, map[string]interface{}{ "message": "导入任务已启动", "total": len(req.Tokens), "prefix": req.Prefix, }) } // HandleRTImportStatus GET /api/rt-import/status - 获取导入状态 func HandleRTImportStatus(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { Error(w, http.StatusMethodNotAllowed, "仅支持 GET") return } rtImportTaskState.mu.Lock() defer rtImportTaskState.mu.Unlock() elapsed := int64(0) if !rtImportTaskState.StartedAt.IsZero() { elapsed = time.Since(rtImportTaskState.StartedAt).Milliseconds() } Success(w, map[string]interface{}{ "running": rtImportTaskState.Running, "started_at": rtImportTaskState.StartedAt, "total": rtImportTaskState.Total, "completed": rtImportTaskState.Completed, "success": rtImportTaskState.Success, "failed": rtImportTaskState.Failed, "results": rtImportTaskState.Results, "elapsed_ms": elapsed, }) } // HandleRTImportStop POST /api/rt-import/stop - 停止导入 func HandleRTImportStop(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { Error(w, http.StatusMethodNotAllowed, "仅支持 POST") return } if !rtImportTaskState.Running { Error(w, http.StatusBadRequest, "没有正在运行的导入任务") return } rtImportTaskState.Running = false if rtImportTaskState.stopCh != nil { select { case <-rtImportTaskState.stopCh: // 已关闭 default: close(rtImportTaskState.stopCh) } } logger.Warning("RT 导入任务已收到停止信号", "", "rt-import") Success(w, map[string]string{"message": "已发送停止信号"}) } // runRTImport 执行批量导入 func runRTImport(req rtImportRequest) { defer func() { rtImportTaskState.Running = false completed := atomic.LoadInt32(&rtImportTaskState.Completed) success := atomic.LoadInt32(&rtImportTaskState.Success) failed := atomic.LoadInt32(&rtImportTaskState.Failed) logger.Success(fmt.Sprintf("RT 导入完成: 总数 %d, 成功 %d, 失败 %d", completed, success, failed), "", "rt-import") }() for i, rt := range req.Tokens { // 检查停止信号 if isRTImportStopped() { logger.Warning(fmt.Sprintf("RT 导入已停止,跳过剩余 %d 个 Token", len(req.Tokens)-i), "", "rt-import") break } result := processOneRT(i, rt, req) rtImportTaskState.mu.Lock() rtImportTaskState.Results = append(rtImportTaskState.Results, result) rtImportTaskState.mu.Unlock() atomic.AddInt32(&rtImportTaskState.Completed, 1) if result.Success { atomic.AddInt32(&rtImportTaskState.Success, 1) } else { atomic.AddInt32(&rtImportTaskState.Failed, 1) } // 限速,避免请求过快 time.Sleep(500 * time.Millisecond) } } // processOneRT 处理单条 RT: 验证 → 创建账号 func processOneRT(index int, rt string, req rtImportRequest) rtImportResult { result := rtImportResult{ Index: index + 1, RT: maskRT(rt), } logger.Info(fmt.Sprintf("[%d/%d] 开始处理 RT: %s", index+1, req.Prefix, maskRT(rt)), "", "rt-import") // Step 1: 通过 S2A 验证 RT tokenInfo, err := validateRTViaS2A(rt, req.ProxyID) if err != nil { result.Error = fmt.Sprintf("验证失败: %v", err) logger.Error(fmt.Sprintf("[%d/%d] %s", index+1, len(req.Tokens), result.Error), "", "rt-import") return result } email, _ := tokenInfo["email"].(string) if email == "" { email = "unknown" } result.Email = email logger.Info(fmt.Sprintf("[%d/%d] 验证成功: %s", index+1, len(req.Tokens), email), email, "rt-import") // Step 2: 通过 S2A 创建账号 acctID, err := createAccountViaS2A(tokenInfo, req) if err != nil { result.Error = fmt.Sprintf("创建账号失败: %v", err) logger.Error(fmt.Sprintf("[%d/%d] %s", index+1, len(req.Tokens), result.Error), email, "rt-import") return result } result.Success = true result.AcctID = acctID logger.Success(fmt.Sprintf("[%d/%d] 账号创建成功: %s-%s (ID: %d)", index+1, len(req.Tokens), req.Prefix, email, acctID), email, "rt-import") return result } // validateRTViaS2A 通过 S2A API 验证 Refresh Token func validateRTViaS2A(rt string, proxyID *int) (map[string]interface{}, error) { payload := map[string]interface{}{ "refresh_token": rt, } if proxyID != nil { payload["proxy_id"] = *proxyID } body, err := json.Marshal(payload) if err != nil { return nil, fmt.Errorf("序列化请求失败: %v", err) } url := config.Global.S2AApiBase + "/api/v1/admin/openai/refresh-token" httpReq, err := http.NewRequest("POST", url, bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("创建请求失败: %v", err) } setS2AHeaders(httpReq) client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(httpReq) if err != nil { return nil, fmt.Errorf("请求失败: %v", err) } defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("读取响应失败: %v", err) } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody)) } var result map[string]interface{} if err := json.Unmarshal(respBody, &result); err != nil { return nil, fmt.Errorf("解析响应失败: %v", err) } // 检查 S2A 包装的 data 字段 if data, ok := result["data"].(map[string]interface{}); ok { return data, nil } return result, nil } // createAccountViaS2A 通过 S2A API 创建账号 func createAccountViaS2A(tokenInfo map[string]interface{}, req rtImportRequest) (int, error) { email, _ := tokenInfo["email"].(string) if email == "" { email = "unknown" } name := fmt.Sprintf("%s-%s", req.Prefix, email) // 构建 credentials credentials := map[string]interface{}{} for _, key := range []string{"access_token", "refresh_token", "token_type", "expires_in", "expires_at", "scope"} { if v, ok := tokenInfo[key]; ok && v != nil { credentials[key] = v } } // 可选字段 for _, key := range []string{"chatgpt_account_id", "chatgpt_user_id", "organization_id"} { if v, ok := tokenInfo[key]; ok && v != nil { credentials[key] = v } } // 构建 extra extra := map[string]interface{}{} if email != "" && email != "unknown" { extra["email"] = email } if nameVal, ok := tokenInfo["name"]; ok && nameVal != nil { extra["name"] = nameVal } payload := map[string]interface{}{ "name": name, "platform": "openai", "type": "oauth", "credentials": credentials, "concurrency": req.Concurrency, "priority": req.Priority, "group_ids": req.GroupIDs, "rate_multiplier": req.RateMultiplier, } if len(extra) > 0 { payload["extra"] = extra } if req.ProxyID != nil { payload["proxy_id"] = *req.ProxyID } body, err := json.Marshal(payload) if err != nil { return 0, fmt.Errorf("序列化请求失败: %v", err) } url := config.Global.S2AApiBase + "/api/v1/admin/accounts" httpReq, err := http.NewRequest("POST", url, bytes.NewReader(body)) if err != nil { return 0, fmt.Errorf("创建请求失败: %v", err) } setS2AHeaders(httpReq) client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(httpReq) if err != nil { return 0, fmt.Errorf("请求失败: %v", err) } defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { return 0, fmt.Errorf("读取响应失败: %v", err) } if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { return 0, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody)) } var result map[string]interface{} if err := json.Unmarshal(respBody, &result); err != nil { return 0, fmt.Errorf("解析响应失败: %v", err) } // 从响应中提取账号ID if data, ok := result["data"].(map[string]interface{}); ok { if id, ok := data["id"].(float64); ok { return int(id), nil } } if id, ok := result["id"].(float64); ok { return int(id), nil } return 0, nil }