From cbf65ba74f21f4f7e6603d496fcc0affc1bd3e0f Mon Sep 17 00:00:00 2001 From: kyx236 Date: Fri, 30 Jan 2026 12:37:18 +0800 Subject: [PATCH] feat: Add account upload functionality with backend validation, concurrent ID fetching, and a new frontend page. --- backend/internal/api/upload.go | 201 ++++++++++++++++++++++++++++++++- frontend/src/pages/Upload.tsx | 50 +++++++- 2 files changed, 244 insertions(+), 7 deletions(-) diff --git a/backend/internal/api/upload.go b/backend/internal/api/upload.go index 79dcc57..0541b63 100644 --- a/backend/internal/api/upload.go +++ b/backend/internal/api/upload.go @@ -6,9 +6,12 @@ import ( "io" "net/http" "strings" + "sync" + "time" "unicode" "codex-pool/internal/database" + "codex-pool/internal/logger" ) type uploadValidateRequest struct { @@ -66,7 +69,7 @@ func HandleUploadValidate(w http.ResponseWriter, r *http.Request) { owners := make([]database.TeamOwner, 0, len(records)) for i, rec := range records { - owner, err := normalizeOwner(rec, i+1) + owner, err := normalizeOwnerBasic(rec, i+1) if err != nil { Error(w, http.StatusBadRequest, err.Error()) return @@ -78,21 +81,105 @@ func HandleUploadValidate(w http.ResponseWriter, r *http.Request) { return } + // 统计需要获取 account_id 的数量 + needFetch := 0 + for _, o := range owners { + if o.AccountID == "" { + needFetch++ + } + } + + // 并发获取 account_id (使用20并发) + if needFetch > 0 { + logger.Info(fmt.Sprintf("开始获取 account_id: 需要获取 %d 个", needFetch), "", "upload") + fetchAccountIDsConcurrent(owners, 20) + logger.Info("account_id 获取完成", "", "upload") + } + inserted, err := database.Instance.AddTeamOwners(owners) if err != nil { Error(w, http.StatusInternalServerError, fmt.Sprintf("写入数据库失败: %v", err)) return } + // 统计获取结果 + successCount := 0 + failCount := 0 + for _, o := range owners { + if o.AccountID != "" { + successCount++ + } else { + failCount++ + } + } + + // 输出导入日志 + logger.Info(fmt.Sprintf("母号导入成功: 成功=%d, 总数=%d, account_id获取成功=%d, 失败=%d", + inserted, len(owners), successCount, failCount), "", "upload") + stats := database.Instance.GetOwnerStats() Success(w, map[string]interface{}{ - "imported": inserted, - "total": len(owners), - "stats": stats, + "imported": inserted, + "total": len(owners), + "stats": stats, + "account_id_ok": successCount, + "account_id_fail": failCount, }) } -func normalizeOwner(rec accountRecord, index int) (database.TeamOwner, error) { +// fetchAccountIDsConcurrent 并发获取 account_id +func fetchAccountIDsConcurrent(owners []database.TeamOwner, concurrency int) { + var wg sync.WaitGroup + sem := make(chan struct{}, concurrency) + var mu sync.Mutex + completed := 0 + total := 0 + + // 统计需要获取的数量 + for _, o := range owners { + if o.AccountID == "" { + total++ + } + } + + for i := range owners { + if owners[i].AccountID != "" { + continue + } + + wg.Add(1) + sem <- struct{}{} // 获取信号量 + + go func(idx int) { + defer wg.Done() + defer func() { <-sem }() // 释放信号量 + + email := owners[idx].Email + token := owners[idx].Token + + accountID, err := fetchAccountID(token) + + mu.Lock() + completed++ + progress := completed + mu.Unlock() + + if err != nil { + logger.Warning(fmt.Sprintf("[%d/%d] 获取 account_id 失败 (%s): %v", progress, total, email, err), "", "upload") + } else { + mu.Lock() + owners[idx].AccountID = accountID + mu.Unlock() + logger.Info(fmt.Sprintf("[%d/%d] 获取 account_id 成功: %s -> %s", progress, total, email, accountID), "", "upload") + } + }(i) + } + + wg.Wait() +} + +// normalizeOwnerBasic 基础验证,不获取 account_id +func normalizeOwnerBasic(rec accountRecord, index int) (database.TeamOwner, error) { email := strings.TrimSpace(rec.Email) if email == "" { email = strings.TrimSpace(rec.Account) @@ -124,6 +211,110 @@ func normalizeOwner(rec accountRecord, index int) (database.TeamOwner, error) { }, nil } +func normalizeOwner(rec accountRecord, index int) (database.TeamOwner, error) { + email := strings.TrimSpace(rec.Email) + if email == "" { + email = strings.TrimSpace(rec.Account) + } + if email == "" { + return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 account/email 字段", index) + } + + password := strings.TrimSpace(rec.Password) + if password == "" { + return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 password 字段", index) + } + + token := strings.TrimSpace(rec.Token) + if token == "" { + token = strings.TrimSpace(rec.AccessTok) + } + if token == "" { + return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 token 字段", index) + } + + accountID := strings.TrimSpace(rec.AccountID) + // 如果 account_id 为空,自动通过 API 获取 + if accountID == "" { + fetchedID, err := fetchAccountID(token) + if err != nil { + logger.Warning(fmt.Sprintf("获取 account_id 失败 (%s): %v", email, err), "", "upload") + } else { + accountID = fetchedID + logger.Info(fmt.Sprintf("获取 account_id 成功: %s -> %s", email, accountID), "", "upload") + } + } + + return database.TeamOwner{ + Email: email, + Password: password, + Token: token, + AccountID: accountID, + }, nil +} + +// fetchAccountID 通过 token 获取 account_id +func fetchAccountID(token string) (string, error) { + client := &http.Client{Timeout: 15 * time.Second} + + req, err := http.NewRequest("GET", "https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27", nil) + if err != nil { + return "", fmt.Errorf("创建请求失败: %v", err) + } + + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36") + + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("请求失败: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("API 返回状态码: %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("读取响应失败: %v", err) + } + + // 解析响应 + var result struct { + Accounts map[string]struct { + Account struct { + ID string `json:"id"` + PlanType string `json:"plan_type"` + } `json:"account"` + } `json:"accounts"` + } + + if err := json.Unmarshal(body, &result); err != nil { + return "", fmt.Errorf("解析响应失败: %v", err) + } + + // 优先查找 plan_type 包含 "team" 的账户 + for key, acc := range result.Accounts { + if key == "default" { + continue + } + if strings.Contains(strings.ToLower(acc.Account.PlanType), "team") { + return acc.Account.ID, nil + } + } + + // 否则取第一个非 "default" 的账户 ID + for key, acc := range result.Accounts { + if key != "default" && acc.Account.ID != "" { + return acc.Account.ID, nil + } + } + + return "", fmt.Errorf("未找到有效的 account_id") +} + func parseAccountsFlexible(raw string) ([]accountRecord, error) { raw = strings.TrimSpace(strings.TrimPrefix(raw, "\uFEFF")) if raw == "" { diff --git a/frontend/src/pages/Upload.tsx b/frontend/src/pages/Upload.tsx index 66a1fad..423c8f2 100644 --- a/frontend/src/pages/Upload.tsx +++ b/frontend/src/pages/Upload.tsx @@ -58,6 +58,12 @@ export default function Upload() { const [status, setStatus] = useState(null) const [polling, setPolling] = useState(false) const [loading, setLoading] = useState(false) + const [importResult, setImportResult] = useState<{ + imported: number + total: number + account_id_ok: number + account_id_fail: number + } | null>(null) // 配置 const [membersPerTeam, setMembersPerTeam] = useState(4) @@ -117,6 +123,7 @@ export default function Upload() { async (file: File) => { setFileError(null) setValidating(true) + setImportResult(null) try { const text = await file.text() @@ -128,6 +135,12 @@ export default function Upload() { const data = await res.json() if (data.code === 0) { + setImportResult({ + imported: data.data.imported, + total: data.data.total, + account_id_ok: data.data.account_id_ok || 0, + account_id_fail: data.data.account_id_fail || 0, + }) loadStats() } else { setFileError(data.message || '验证失败') @@ -300,9 +313,42 @@ export default function Upload() { error={fileError} /> {validating && ( -
+
- 正在验证账号... + 正在验证并获取 account_id (20并发)... +
+ )} + {importResult && !validating && ( +
+
+ + 导入完成 +
+
+
+
成功导入
+
{importResult.imported} / {importResult.total}
+
+
+
Account ID
+
+ {importResult.account_id_ok} + {importResult.account_id_fail > 0 && ( + / {importResult.account_id_fail} 失败 + )} +
+
+
+ {importResult.account_id_ok > 0 && ( +
+
+
+
+
+ )}
)}