package api

import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"sort"
	"strings"
	"sync"
	"unicode"

	"codex-pool/internal/client"
	"codex-pool/internal/config"
	"codex-pool/internal/database"
	"codex-pool/internal/logger"
)

const (
	// maxUploadBodyBytes caps the size of an upload request body (10 MiB).
	maxUploadBodyBytes = 10 << 20

	// accountIDFetchConcurrency bounds how many account_id lookups run at once.
	accountIDFetchConcurrency = 20

	// accountCheckURL is the endpoint queried to list the accounts visible to a token.
	accountCheckURL = "https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27"
)

// uploadValidateRequest is the JSON envelope accepted by HandleUploadValidate.
// Either Accounts (structured records) or Content (raw text that is parsed by
// parseAccountsFlexible) may be supplied; Accounts wins when both are present.
type uploadValidateRequest struct {
	Content  string          `json:"content"`
	Accounts []accountRecord `json:"accounts"`
}

// accountRecord is one uploaded account entry. Account/Email and
// Token/AccessTok are alternative spellings of the same piece of data; see
// normalizeOwnerBasic for the precedence between them.
type accountRecord struct {
	Account   string `json:"account"`
	Email     string `json:"email"`
	Password  string `json:"password"`
	Token     string `json:"token"`
	AccessTok string `json:"access_token"`
	AccountID string `json:"account_id"`
}

// HandleUploadValidate handles an account upload: it parses the request body,
// validates every record, looks up any missing account_id values concurrently,
// stores the owners in the database and responds with import statistics.
//
// Only POST is accepted. Any record that fails validation rejects the whole
// upload with 400. A failed account_id lookup does not reject the upload; the
// owner is stored with an empty account_id and counted in "account_id_fail".
func HandleUploadValidate(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		Error(w, http.StatusMethodNotAllowed, "仅支持 POST")
		return
	}
	if database.Instance == nil {
		Error(w, http.StatusInternalServerError, "数据库未初始化")
		return
	}

	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxUploadBodyBytes))
	if err != nil {
		Error(w, http.StatusBadRequest, "读取请求失败")
		return
	}

	records, err := recordsFromBody(body)
	if err != nil {
		Error(w, http.StatusBadRequest, err.Error())
		return
	}

	owners := make([]database.TeamOwner, 0, len(records))
	for i, rec := range records {
		owner, err := normalizeOwnerBasic(rec, i+1)
		if err != nil {
			Error(w, http.StatusBadRequest, err.Error())
			return
		}
		owners = append(owners, owner)
	}
	if len(owners) == 0 {
		Error(w, http.StatusBadRequest, "未解析到有效账号")
		return
	}

	// Count the owners that still need an account_id lookup.
	needFetch := 0
	for _, o := range owners {
		if o.AccountID == "" {
			needFetch++
		}
	}
	if needFetch > 0 {
		logger.Info(fmt.Sprintf("开始获取 account_id: 需要获取 %d 个", needFetch), "", "upload")
		fetchAccountIDsConcurrent(owners, accountIDFetchConcurrency)
		logger.Info("account_id 获取完成", "", "upload")
	}

	inserted, err := database.Instance.AddTeamOwners(owners)
	if err != nil {
		Error(w, http.StatusInternalServerError, fmt.Sprintf("写入数据库失败: %v", err))
		return
	}

	// Tally how many owners ended up with an account_id (supplied or fetched).
	successCount, failCount := 0, 0
	for _, o := range owners {
		if o.AccountID != "" {
			successCount++
		} else {
			failCount++
		}
	}

	logger.Info(fmt.Sprintf("母号导入成功: 成功=%d, 总数=%d, account_id获取成功=%d, 失败=%d",
		inserted, len(owners), successCount, failCount), "", "upload")

	stats := database.Instance.GetOwnerStats()
	Success(w, map[string]interface{}{
		"imported":        inserted,
		"total":           len(owners),
		"stats":           stats,
		"account_id_ok":   successCount,
		"account_id_fail": failCount,
	})
}

// recordsFromBody extracts the uploaded account records from a request body.
//
// Accepted shapes, in order of precedence:
//  1. a JSON envelope with a non-empty "accounts" array;
//  2. a JSON envelope with a non-blank "content" string, or a body that is not
//     a JSON object at all (both are parsed by parseAccountsFlexible);
//  3. a single bare account object, which decodes cleanly into the envelope
//     but leaves it empty (for example a one-line JSONL upload).
//
// It returns an error when none of these yields any account data.
func recordsFromBody(body []byte) ([]accountRecord, error) {
	var req uploadValidateRequest
	if err := json.Unmarshal(body, &req); err != nil {
		// Not a JSON envelope: treat the whole body as raw account text.
		req.Content = string(body)
	}

	switch {
	case len(req.Accounts) > 0:
		return req.Accounts, nil
	case strings.TrimSpace(req.Content) != "":
		return parseAccountsFlexible(req.Content)
	}

	// The body may be one account object with no envelope around it. Accept it
	// only if at least one known field is populated so that "{}" and empty
	// envelopes are still reported as missing content.
	var single accountRecord
	if err := json.Unmarshal(body, &single); err == nil && single != (accountRecord{}) {
		return []accountRecord{single}, nil
	}
	return nil, fmt.Errorf("未提供账号内容")
}

// HandleRefetchAccountIDs retries the account_id lookup for every stored owner
// that has none, persists each result, and responds with success/fail counts.
// Only POST is accepted.
func HandleRefetchAccountIDs(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		Error(w, http.StatusMethodNotAllowed, "仅支持 POST")
		return
	}
	if database.Instance == nil {
		Error(w, http.StatusInternalServerError, "数据库未初始化")
		return
	}

	owners, err := database.Instance.GetOwnersWithoutAccountID()
	if err != nil {
		Error(w, http.StatusInternalServerError, fmt.Sprintf("查询数据库失败: %v", err))
		return
	}
	if len(owners) == 0 {
		Success(w, map[string]interface{}{
			"message": "所有母号都已有 account_id",
			"total":   0,
			"success": 0,
			"fail":    0,
		})
		return
	}

	logger.Info(fmt.Sprintf("开始重新获取 account_id: 共 %d 个", len(owners)), "", "upload")

	var (
		wg           sync.WaitGroup
		mu           sync.Mutex // guards the counters and serializes the database updates
		successCount int
		failCount    int
	)
	sem := make(chan struct{}, accountIDFetchConcurrency)

	for _, owner := range owners {
		wg.Add(1)
		sem <- struct{}{} // blocks while the maximum number of lookups is in flight
		go func(o database.TeamOwner) {
			defer wg.Done()
			defer func() { <-sem }()

			// The network call runs outside the lock; everything after it is
			// bookkeeping and runs under the lock.
			accountID, err := fetchAccountID(o.Token)

			mu.Lock()
			defer mu.Unlock()

			if err != nil {
				failCount++
				logger.Warning(fmt.Sprintf("获取 account_id 失败 (%s): %v", o.Email, err), "", "upload")
				return
			}
			if updateErr := database.Instance.UpdateOwnerAccountID(o.ID, accountID); updateErr != nil {
				failCount++
				logger.Error(fmt.Sprintf("更新 account_id 失败 (%s): %v", o.Email, updateErr), "", "upload")
				return
			}
			successCount++
			logger.Info(fmt.Sprintf("获取 account_id 成功: %s -> %s", o.Email, accountID), "", "upload")
		}(owner)
	}
	wg.Wait()

	logger.Info(fmt.Sprintf("重新获取 account_id 完成: 成功=%d, 失败=%d", successCount, failCount), "", "upload")

	Success(w, map[string]interface{}{
		"message": "重新获取 account_id 完成",
		"total":   len(owners),
		"success": successCount,
		"fail":    failCount,
	})
}

// fetchAccountIDsConcurrent fills in the AccountID of every owner in owners
// that lacks one, running at most concurrency lookups at a time. The slice is
// modified in place. Failed lookups are logged and leave AccountID empty.
//
// A concurrency below 1 is treated as 1: a zero-capacity semaphore would block
// the first acquire forever and a negative capacity would panic.
func fetchAccountIDsConcurrent(owners []database.TeamOwner, concurrency int) {
	if concurrency < 1 {
		concurrency = 1
	}

	// Number of lookups to perform; only used for the progress log prefix.
	total := 0
	for _, o := range owners {
		if o.AccountID == "" {
			total++
		}
	}
	if total == 0 {
		return
	}

	var (
		wg        sync.WaitGroup
		mu        sync.Mutex // guards completed
		completed int
	)
	sem := make(chan struct{}, concurrency)

	for i := range owners {
		if owners[i].AccountID != "" {
			continue
		}

		wg.Add(1)
		sem <- struct{}{} // acquire a slot
		go func(idx int) {
			defer wg.Done()
			defer func() { <-sem }() // release the slot

			email := owners[idx].Email
			accountID, err := fetchAccountID(owners[idx].Token)

			mu.Lock()
			completed++
			progress := completed
			mu.Unlock()

			if err != nil {
				logger.Warning(fmt.Sprintf("[%d/%d] 获取 account_id 失败 (%s): %v", progress, total, email, err), "", "upload")
				return
			}

			// Each goroutine writes only its own element, so no lock is needed;
			// wg.Wait makes the write visible to the caller.
			owners[idx].AccountID = accountID
			logger.Info(fmt.Sprintf("[%d/%d] 获取 account_id 成功: %s -> %s", progress, total, email, accountID), "", "upload")
		}(i)
	}
	wg.Wait()
}

// normalizeOwnerBasic validates rec and converts it to a database.TeamOwner
// without performing any network lookup. index is the 1-based position of the
// record and is used only in error messages.
//
// Email falls back to Account and Token falls back to AccessTok; all values
// are whitespace-trimmed. An error is returned when the email, password or
// token is missing. AccountID is optional and may be empty in the result.
func normalizeOwnerBasic(rec accountRecord, index int) (database.TeamOwner, error) {
	email := strings.TrimSpace(rec.Email)
	if email == "" {
		email = strings.TrimSpace(rec.Account)
	}
	if email == "" {
		return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 account/email 字段", index)
	}

	password := strings.TrimSpace(rec.Password)
	if password == "" {
		return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 password 字段", index)
	}

	token := strings.TrimSpace(rec.Token)
	if token == "" {
		token = strings.TrimSpace(rec.AccessTok)
	}
	if token == "" {
		return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 token 字段", index)
	}

	return database.TeamOwner{
		Email:     email,
		Password:  password,
		Token:     token,
		AccountID: strings.TrimSpace(rec.AccountID),
	}, nil
}

// normalizeOwner validates rec exactly like normalizeOwnerBasic and, when the
// record carries no account_id, additionally tries to fetch one with the
// record's token. A failed lookup is logged and is not an error: the owner is
// returned with an empty AccountID.
func normalizeOwner(rec accountRecord, index int) (database.TeamOwner, error) {
	owner, err := normalizeOwnerBasic(rec, index)
	if err != nil {
		return database.TeamOwner{}, err
	}
	if owner.AccountID != "" {
		return owner, nil
	}

	fetchedID, err := fetchAccountID(owner.Token)
	if err != nil {
		logger.Warning(fmt.Sprintf("获取 account_id 失败 (%s): %v", owner.Email, err), "", "upload")
		return owner, nil
	}
	owner.AccountID = fetchedID
	logger.Info(fmt.Sprintf("获取 account_id 成功: %s -> %s", owner.Email, owner.AccountID), "", "upload")
	return owner, nil
}

// fetchAccountID asks the account-check endpoint which accounts token can see
// and returns the ID of the one to use.
//
// The account ID is the key of the "accounts" map in the response, not the
// nested account.id field. The "default" key is always skipped. An account
// whose plan_type contains "team" (case-insensitive) is preferred; otherwise
// any remaining account is used. Candidates are examined in sorted key order
// so that the choice is stable across calls (Go map iteration order is
// random).
//
// The request goes through the proxy from the configuration, if one is set.
func fetchAccountID(token string) (string, error) {
	proxy := ""
	if cfg := config.Get(); cfg != nil {
		proxy = cfg.GetProxy()
	}

	tlsClient, err := client.New(proxy)
	if err != nil {
		return "", fmt.Errorf("创建 TLS 客户端失败: %w", err)
	}

	req, err := http.NewRequest("GET", accountCheckURL, nil)
	if err != nil {
		return "", fmt.Errorf("创建请求失败: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+token)
	req.Header.Set("Content-Type", "application/json")

	resp, err := tlsClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("请求失败: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("读取响应失败: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		// Include at most the first 200 bytes of the body in the error.
		return "", fmt.Errorf("API 返回状态码: %d, 响应: %s", resp.StatusCode, string(body)[:min(200, len(body))])
	}

	var result struct {
		Accounts map[string]struct {
			Account struct {
				ID       string `json:"id"`
				PlanType string `json:"plan_type"`
			} `json:"account"`
		} `json:"accounts"`
	}
	if err := json.Unmarshal(body, &result); err != nil {
		return "", fmt.Errorf("解析响应失败: %w", err)
	}

	// Collect the candidate IDs in a stable order.
	ids := make([]string, 0, len(result.Accounts))
	for id := range result.Accounts {
		if id != "default" {
			ids = append(ids, id)
		}
	}
	sort.Strings(ids)

	// Prefer a team-plan account.
	for _, id := range ids {
		if strings.Contains(strings.ToLower(result.Accounts[id].Account.PlanType), "team") {
			return id, nil
		}
	}
	// Otherwise fall back to the first remaining account.
	if len(ids) > 0 {
		return ids[0], nil
	}
	return "", fmt.Errorf("未找到有效的 account_id")
}

// parseAccountsFlexible parses uploaded account text in any of three shapes:
// a JSON array of records, a single JSON object, or JSONL (one object per
// line). A leading UTF-8 BOM, // and /* */ comments and trailing commas are
// tolerated. In JSONL mode, blank lines and lines starting with '#' are
// skipped.
//
// When whole-document parsing fails the input is re-parsed line by line, so
// the error for a malformed document names the first line that is not valid
// JSON on its own.
func parseAccountsFlexible(raw string) ([]accountRecord, error) {
	raw = strings.TrimSpace(strings.TrimPrefix(raw, "\uFEFF"))
	if raw == "" {
		return nil, fmt.Errorf("内容为空")
	}

	// First attempt: treat the whole input as one (relaxed) JSON document.
	doc := strings.TrimSpace(removeTrailingCommas(stripJSONComments(raw)))
	if doc == "" {
		return nil, fmt.Errorf("内容为空")
	}
	switch {
	case strings.HasPrefix(doc, "["):
		var arr []accountRecord
		if err := json.Unmarshal([]byte(doc), &arr); err == nil {
			return arr, nil
		}
	case strings.HasPrefix(doc, "{"):
		var single accountRecord
		if err := json.Unmarshal([]byte(doc), &single); err == nil {
			return []accountRecord{single}, nil
		}
	}

	// Fallback: JSONL. Lines are taken from the original text so that the
	// line numbers in error messages match what the user uploaded.
	lines := strings.Split(raw, "\n")
	records := make([]accountRecord, 0, len(lines))
	for i, line := range lines {
		line = strings.TrimSpace(stripJSONComments(line))
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		line = strings.TrimSpace(removeTrailingCommas(line))
		if line == "" {
			continue
		}

		var rec accountRecord
		if err := json.Unmarshal([]byte(line), &rec); err != nil {
			return nil, fmt.Errorf("第 %d 行解析失败: %v", i+1, err)
		}
		records = append(records, rec)
	}
	if len(records) == 0 {
		return nil, fmt.Errorf("未解析到有效账号")
	}
	return records, nil
}

// stripJSONComments returns input with // line comments and /* */ block
// comments removed. Comment markers inside double-quoted strings are left
// alone, and backslash escapes inside strings are honored. The newline that
// ends a line comment is kept; an unterminated block comment runs to the end
// of the input.
func stripJSONComments(input string) string {
	var b strings.Builder
	b.Grow(len(input))

	inString := false
	escaped := false
	for i := 0; i < len(input); i++ {
		ch := input[i]

		if inString {
			b.WriteByte(ch)
			switch {
			case escaped:
				// The byte after a backslash never ends the string.
				escaped = false
			case ch == '\\':
				escaped = true
			case ch == '"':
				inString = false
			}
			continue
		}

		if ch == '"' {
			inString = true
			b.WriteByte(ch)
			continue
		}

		if ch == '/' && i+1 < len(input) && input[i+1] == '/' {
			// Advance to the byte just before the newline; the loop increment
			// then lands on the newline, which is copied as usual.
			for i+1 < len(input) && input[i+1] != '\n' {
				i++
			}
			continue
		}

		if ch == '/' && i+1 < len(input) && input[i+1] == '*' {
			// Skip past the opening "/*", then scan for the closing "*/" and
			// leave i on its '/' so the loop increment steps over it.
			i += 2
			for i+1 < len(input) {
				if input[i] == '*' && input[i+1] == '/' {
					i++
					break
				}
				i++
			}
			continue
		}

		b.WriteByte(ch)
	}
	return b.String()
}

// removeTrailingCommas returns input with every comma dropped that is
// followed, after optional whitespace, by ']' or '}'. Commas inside
// double-quoted strings are left alone, and backslash escapes inside strings
// are honored.
func removeTrailingCommas(input string) string {
	var b strings.Builder
	b.Grow(len(input))

	inString := false
	escaped := false
	for i := 0; i < len(input); i++ {
		ch := input[i]

		if inString {
			b.WriteByte(ch)
			switch {
			case escaped:
				// The byte after a backslash never ends the string.
				escaped = false
			case ch == '\\':
				escaped = true
			case ch == '"':
				inString = false
			}
			continue
		}

		if ch == '"' {
			inString = true
			b.WriteByte(ch)
			continue
		}

		if ch == ',' {
			// Look ahead past whitespace for a closing bracket or brace.
			j := i + 1
			for j < len(input) && unicode.IsSpace(rune(input[j])) {
				j++
			}
			if j < len(input) && (input[j] == ']' || input[j] == '}') {
				continue // drop the comma; the whitespace after it is kept
			}
		}

		b.WriteByte(ch)
	}
	return b.String()
}