feat: Add account upload functionality with backend validation, concurrent ID fetching, and a new frontend page.
This commit is contained in:
@@ -6,9 +6,12 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"codex-pool/internal/database"
|
||||
"codex-pool/internal/logger"
|
||||
)
|
||||
|
||||
type uploadValidateRequest struct {
|
||||
@@ -66,7 +69,7 @@ func HandleUploadValidate(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
owners := make([]database.TeamOwner, 0, len(records))
|
||||
for i, rec := range records {
|
||||
owner, err := normalizeOwner(rec, i+1)
|
||||
owner, err := normalizeOwnerBasic(rec, i+1)
|
||||
if err != nil {
|
||||
Error(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
@@ -78,21 +81,105 @@ func HandleUploadValidate(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// 统计需要获取 account_id 的数量
|
||||
needFetch := 0
|
||||
for _, o := range owners {
|
||||
if o.AccountID == "" {
|
||||
needFetch++
|
||||
}
|
||||
}
|
||||
|
||||
// 并发获取 account_id (使用20并发)
|
||||
if needFetch > 0 {
|
||||
logger.Info(fmt.Sprintf("开始获取 account_id: 需要获取 %d 个", needFetch), "", "upload")
|
||||
fetchAccountIDsConcurrent(owners, 20)
|
||||
logger.Info("account_id 获取完成", "", "upload")
|
||||
}
|
||||
|
||||
inserted, err := database.Instance.AddTeamOwners(owners)
|
||||
if err != nil {
|
||||
Error(w, http.StatusInternalServerError, fmt.Sprintf("写入数据库失败: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 统计获取结果
|
||||
successCount := 0
|
||||
failCount := 0
|
||||
for _, o := range owners {
|
||||
if o.AccountID != "" {
|
||||
successCount++
|
||||
} else {
|
||||
failCount++
|
||||
}
|
||||
}
|
||||
|
||||
// 输出导入日志
|
||||
logger.Info(fmt.Sprintf("母号导入成功: 成功=%d, 总数=%d, account_id获取成功=%d, 失败=%d",
|
||||
inserted, len(owners), successCount, failCount), "", "upload")
|
||||
|
||||
stats := database.Instance.GetOwnerStats()
|
||||
Success(w, map[string]interface{}{
|
||||
"imported": inserted,
|
||||
"total": len(owners),
|
||||
"stats": stats,
|
||||
"account_id_ok": successCount,
|
||||
"account_id_fail": failCount,
|
||||
})
|
||||
}
|
||||
|
||||
func normalizeOwner(rec accountRecord, index int) (database.TeamOwner, error) {
|
||||
// fetchAccountIDsConcurrent 并发获取 account_id
|
||||
func fetchAccountIDsConcurrent(owners []database.TeamOwner, concurrency int) {
|
||||
var wg sync.WaitGroup
|
||||
sem := make(chan struct{}, concurrency)
|
||||
var mu sync.Mutex
|
||||
completed := 0
|
||||
total := 0
|
||||
|
||||
// 统计需要获取的数量
|
||||
for _, o := range owners {
|
||||
if o.AccountID == "" {
|
||||
total++
|
||||
}
|
||||
}
|
||||
|
||||
for i := range owners {
|
||||
if owners[i].AccountID != "" {
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
sem <- struct{}{} // 获取信号量
|
||||
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
defer func() { <-sem }() // 释放信号量
|
||||
|
||||
email := owners[idx].Email
|
||||
token := owners[idx].Token
|
||||
|
||||
accountID, err := fetchAccountID(token)
|
||||
|
||||
mu.Lock()
|
||||
completed++
|
||||
progress := completed
|
||||
mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
logger.Warning(fmt.Sprintf("[%d/%d] 获取 account_id 失败 (%s): %v", progress, total, email, err), "", "upload")
|
||||
} else {
|
||||
mu.Lock()
|
||||
owners[idx].AccountID = accountID
|
||||
mu.Unlock()
|
||||
logger.Info(fmt.Sprintf("[%d/%d] 获取 account_id 成功: %s -> %s", progress, total, email, accountID), "", "upload")
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// normalizeOwnerBasic 基础验证,不获取 account_id
|
||||
func normalizeOwnerBasic(rec accountRecord, index int) (database.TeamOwner, error) {
|
||||
email := strings.TrimSpace(rec.Email)
|
||||
if email == "" {
|
||||
email = strings.TrimSpace(rec.Account)
|
||||
@@ -124,6 +211,110 @@ func normalizeOwner(rec accountRecord, index int) (database.TeamOwner, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeOwner(rec accountRecord, index int) (database.TeamOwner, error) {
|
||||
email := strings.TrimSpace(rec.Email)
|
||||
if email == "" {
|
||||
email = strings.TrimSpace(rec.Account)
|
||||
}
|
||||
if email == "" {
|
||||
return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 account/email 字段", index)
|
||||
}
|
||||
|
||||
password := strings.TrimSpace(rec.Password)
|
||||
if password == "" {
|
||||
return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 password 字段", index)
|
||||
}
|
||||
|
||||
token := strings.TrimSpace(rec.Token)
|
||||
if token == "" {
|
||||
token = strings.TrimSpace(rec.AccessTok)
|
||||
}
|
||||
if token == "" {
|
||||
return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 token 字段", index)
|
||||
}
|
||||
|
||||
accountID := strings.TrimSpace(rec.AccountID)
|
||||
// 如果 account_id 为空,自动通过 API 获取
|
||||
if accountID == "" {
|
||||
fetchedID, err := fetchAccountID(token)
|
||||
if err != nil {
|
||||
logger.Warning(fmt.Sprintf("获取 account_id 失败 (%s): %v", email, err), "", "upload")
|
||||
} else {
|
||||
accountID = fetchedID
|
||||
logger.Info(fmt.Sprintf("获取 account_id 成功: %s -> %s", email, accountID), "", "upload")
|
||||
}
|
||||
}
|
||||
|
||||
return database.TeamOwner{
|
||||
Email: email,
|
||||
Password: password,
|
||||
Token: token,
|
||||
AccountID: accountID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// fetchAccountID 通过 token 获取 account_id
|
||||
func fetchAccountID(token string) (string, error) {
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
|
||||
req, err := http.NewRequest("GET", "https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27", nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建请求失败: %v", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("请求失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("API 返回状态码: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("读取响应失败: %v", err)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var result struct {
|
||||
Accounts map[string]struct {
|
||||
Account struct {
|
||||
ID string `json:"id"`
|
||||
PlanType string `json:"plan_type"`
|
||||
} `json:"account"`
|
||||
} `json:"accounts"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return "", fmt.Errorf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
// 优先查找 plan_type 包含 "team" 的账户
|
||||
for key, acc := range result.Accounts {
|
||||
if key == "default" {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(strings.ToLower(acc.Account.PlanType), "team") {
|
||||
return acc.Account.ID, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 否则取第一个非 "default" 的账户 ID
|
||||
for key, acc := range result.Accounts {
|
||||
if key != "default" && acc.Account.ID != "" {
|
||||
return acc.Account.ID, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("未找到有效的 account_id")
|
||||
}
|
||||
|
||||
func parseAccountsFlexible(raw string) ([]accountRecord, error) {
|
||||
raw = strings.TrimSpace(strings.TrimPrefix(raw, "\uFEFF"))
|
||||
if raw == "" {
|
||||
|
||||
@@ -58,6 +58,12 @@ export default function Upload() {
|
||||
const [status, setStatus] = useState<ProcessStatus | null>(null)
|
||||
const [polling, setPolling] = useState(false)
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [importResult, setImportResult] = useState<{
|
||||
imported: number
|
||||
total: number
|
||||
account_id_ok: number
|
||||
account_id_fail: number
|
||||
} | null>(null)
|
||||
|
||||
// 配置
|
||||
const [membersPerTeam, setMembersPerTeam] = useState(4)
|
||||
@@ -117,6 +123,7 @@ export default function Upload() {
|
||||
async (file: File) => {
|
||||
setFileError(null)
|
||||
setValidating(true)
|
||||
setImportResult(null)
|
||||
|
||||
try {
|
||||
const text = await file.text()
|
||||
@@ -128,6 +135,12 @@ export default function Upload() {
|
||||
|
||||
const data = await res.json()
|
||||
if (data.code === 0) {
|
||||
setImportResult({
|
||||
imported: data.data.imported,
|
||||
total: data.data.total,
|
||||
account_id_ok: data.data.account_id_ok || 0,
|
||||
account_id_fail: data.data.account_id_fail || 0,
|
||||
})
|
||||
loadStats()
|
||||
} else {
|
||||
setFileError(data.message || '验证失败')
|
||||
@@ -300,9 +313,42 @@ export default function Upload() {
|
||||
error={fileError}
|
||||
/>
|
||||
{validating && (
|
||||
<div className="mt-3 flex items-center gap-2 text-blue-500 bg-blue-50 dark:bg-blue-900/20 p-2 rounded-lg text-sm">
|
||||
<div className="mt-3 flex items-center gap-2 text-blue-500 bg-blue-50 dark:bg-blue-900/20 p-3 rounded-lg text-sm">
|
||||
<Loader2 className="h-4 w-4 animate-spin" />
|
||||
<span>正在验证账号...</span>
|
||||
<span>正在验证并获取 account_id (20并发)...</span>
|
||||
</div>
|
||||
)}
|
||||
{importResult && !validating && (
|
||||
<div className="mt-3 p-3 bg-green-50 dark:bg-green-900/20 border border-green-200 dark:border-green-800 rounded-lg">
|
||||
<div className="flex items-center gap-2 text-green-600 dark:text-green-400 font-medium mb-2">
|
||||
<CheckCircle className="h-4 w-4" />
|
||||
<span>导入完成</span>
|
||||
</div>
|
||||
<div className="grid grid-cols-2 gap-3 text-sm">
|
||||
<div className="p-2 bg-white dark:bg-slate-800 rounded">
|
||||
<div className="text-slate-500 text-xs">成功导入</div>
|
||||
<div className="font-bold text-green-600">{importResult.imported} / {importResult.total}</div>
|
||||
</div>
|
||||
<div className="p-2 bg-white dark:bg-slate-800 rounded">
|
||||
<div className="text-slate-500 text-xs">Account ID</div>
|
||||
<div className="font-bold">
|
||||
<span className="text-green-600">{importResult.account_id_ok}</span>
|
||||
{importResult.account_id_fail > 0 && (
|
||||
<span className="text-red-500 ml-1">/ {importResult.account_id_fail} 失败</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{importResult.account_id_ok > 0 && (
|
||||
<div className="mt-2">
|
||||
<div className="h-2 bg-slate-200 dark:bg-slate-700 rounded-full overflow-hidden">
|
||||
<div
|
||||
className="h-full bg-green-500 rounded-full transition-all duration-500"
|
||||
style={{ width: `${(importResult.account_id_ok / importResult.total) * 100}%` }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</CardContent>
|
||||
|
||||
Reference in New Issue
Block a user