Files
codexautopool/backend/internal/api/upload.go

553 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package api
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"unicode"
"codex-pool/internal/client"
"codex-pool/internal/config"
"codex-pool/internal/database"
"codex-pool/internal/logger"
)
// uploadValidateRequest is the JSON body decoded by HandleUploadValidate.
// A non-empty Accounts list takes precedence; Content is only consulted when
// Accounts is empty.
type uploadValidateRequest struct {
// Content is free-form text that is parsed with parseAccountsFlexible.
// When the request body cannot be decoded as JSON, the handler stores the
// whole raw body here.
Content string `json:"content"`
// Accounts is an already-structured list of records to import.
Accounts []accountRecord `json:"accounts"`
}
// accountRecord is one uploaded account entry as it appears in the input JSON.
// Several fields are aliases of each other; normalizeOwnerBasic resolves them.
type accountRecord struct {
// Account is used as the e-mail address when Email is blank.
Account string `json:"account"`
// Email is the preferred source of the owner's e-mail address.
Email string `json:"email"`
// Password is required; records with a blank password are rejected.
Password string `json:"password"`
// Token is the preferred source of the owner's token.
Token string `json:"token"`
// AccessTok is used as the token when Token is blank.
AccessTok string `json:"access_token"`
// AccountID is optional; when blank it is looked up via fetchAccountID.
AccountID string `json:"account_id"`
}
// HandleUploadValidate handles the owner-account upload endpoint (POST only).
//
// The request body is limited to 10 MiB. It is first decoded as an
// uploadValidateRequest; if that decoding fails the whole body is treated as
// raw content. A non-empty "accounts" list wins over "content", which is
// otherwise parsed with parseAccountsFlexible.
//
// Every record must pass normalizeOwnerBasic; the first invalid record aborts
// the request with 400. Owners without an account_id are then looked up with
// fetchAccountIDsConcurrent (20 at a time). Lookup failures do not stop the
// import: all owners are passed to AddTeamOwners regardless.
//
// On success the response carries the inserted count, the total number of
// owners, the current owner stats and how many owners ended up with / without
// an account_id.
func HandleUploadValidate(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		Error(w, http.StatusMethodNotAllowed, "仅支持 POST")
		return
	}
	if database.Instance == nil {
		Error(w, http.StatusInternalServerError, "数据库未初始化")
		return
	}

	payload, readErr := io.ReadAll(http.MaxBytesReader(w, r.Body, 10<<20))
	if readErr != nil {
		Error(w, http.StatusBadRequest, "读取请求失败")
		return
	}

	var req uploadValidateRequest
	if json.Unmarshal(payload, &req) != nil {
		// Not decodable as the JSON request shape: use the body as raw content.
		req.Content = string(payload)
	}

	// Structured accounts take priority; fall back to parsing the raw content.
	records := req.Accounts
	if len(records) == 0 {
		if strings.TrimSpace(req.Content) == "" {
			Error(w, http.StatusBadRequest, "未提供账号内容")
			return
		}
		parsed, parseErr := parseAccountsFlexible(req.Content)
		if parseErr != nil {
			Error(w, http.StatusBadRequest, parseErr.Error())
			return
		}
		records = parsed
	}

	// Validate every record and, in the same pass, count how many still need
	// an account_id lookup. Record numbers in error messages are 1-based.
	owners := make([]database.TeamOwner, 0, len(records))
	missingIDs := 0
	for pos := range records {
		owner, normErr := normalizeOwnerBasic(records[pos], pos+1)
		if normErr != nil {
			Error(w, http.StatusBadRequest, normErr.Error())
			return
		}
		if owner.AccountID == "" {
			missingIDs++
		}
		owners = append(owners, owner)
	}
	if len(owners) == 0 {
		Error(w, http.StatusBadRequest, "未解析到有效账号")
		return
	}

	// Resolve missing account IDs with 20 concurrent lookups.
	if missingIDs > 0 {
		logger.Info(fmt.Sprintf("开始获取 account_id: 需要获取 %d 个", missingIDs), "", "upload")
		fetchAccountIDsConcurrent(owners, 20)
		logger.Info("account_id 获取完成", "", "upload")
	}

	inserted, dbErr := database.Instance.AddTeamOwners(owners)
	if dbErr != nil {
		Error(w, http.StatusInternalServerError, fmt.Sprintf("写入数据库失败: %v", dbErr))
		return
	}

	// Tally how many owners have an account_id after the lookup phase.
	withID := 0
	for i := range owners {
		if owners[i].AccountID != "" {
			withID++
		}
	}
	withoutID := len(owners) - withID

	logger.Info(fmt.Sprintf("母号导入成功: 成功=%d, 总数=%d, account_id获取成功=%d, 失败=%d",
		inserted, len(owners), withID, withoutID), "", "upload")

	stats := database.Instance.GetOwnerStats()
	Success(w, map[string]interface{}{
		"imported":        inserted,
		"total":           len(owners),
		"stats":           stats,
		"account_id_ok":   withID,
		"account_id_fail": withoutID,
	})
}
// HandleRefetchAccountIDs retries the account_id lookup for every stored owner
// that does not have one yet (POST only).
//
// Owners are read with GetOwnersWithoutAccountID and processed by up to 20
// goroutines at a time. For each owner the account_id is fetched with
// fetchAccountID and, on success, persisted with UpdateOwnerAccountID. A fetch
// error or an update error counts as a failure for that owner and is logged;
// it never aborts the batch.
//
// The response reports "total", "success" and "fail" counts plus a message.
func HandleRefetchAccountIDs(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		Error(w, http.StatusMethodNotAllowed, "仅支持 POST")
		return
	}
	if database.Instance == nil {
		Error(w, http.StatusInternalServerError, "数据库未初始化")
		return
	}

	pending, queryErr := database.Instance.GetOwnersWithoutAccountID()
	if queryErr != nil {
		Error(w, http.StatusInternalServerError, fmt.Sprintf("查询数据库失败: %v", queryErr))
		return
	}
	if len(pending) == 0 {
		Success(w, map[string]interface{}{
			"message": "所有母号都已有 account_id",
			"total":   0,
			"success": 0,
			"fail":    0,
		})
		return
	}

	logger.Info(fmt.Sprintf("开始重新获取 account_id: 共 %d 个", len(pending)), "", "upload")

	var (
		workers  sync.WaitGroup
		guard    sync.Mutex // protects the two counters below
		okCount  int
		badCount int
	)
	slots := make(chan struct{}, 20) // at most 20 lookups in flight

	for _, item := range pending {
		workers.Add(1)
		slots <- struct{}{}
		go func(o database.TeamOwner) {
			defer workers.Done()
			defer func() { <-slots }()

			accountID, fetchErr := fetchAccountID(o.Token)

			// The network call happens outside the lock; the bookkeeping and
			// the database update below run while holding it.
			guard.Lock()
			defer guard.Unlock()

			if fetchErr != nil {
				badCount++
				logger.Warning(fmt.Sprintf("获取 account_id 失败 (%s): %v", o.Email, fetchErr), "", "upload")
				return
			}
			if updateErr := database.Instance.UpdateOwnerAccountID(o.ID, accountID); updateErr != nil {
				badCount++
				logger.Error(fmt.Sprintf("更新 account_id 失败 (%s): %v", o.Email, updateErr), "", "upload")
				return
			}
			okCount++
			logger.Info(fmt.Sprintf("获取 account_id 成功: %s -> %s", o.Email, accountID), "", "upload")
		}(item)
	}
	workers.Wait()

	logger.Info(fmt.Sprintf("重新获取 account_id 完成: 成功=%d, 失败=%d", okCount, badCount), "", "upload")
	Success(w, map[string]interface{}{
		"message": "重新获取 account_id 完成",
		"total":   len(pending),
		"success": okCount,
		"fail":    badCount,
	})
}
// fetchAccountIDsConcurrent fills in the AccountID of every element of owners
// whose AccountID is empty, calling fetchAccountID with the owner's token.
//
// owners is modified in place. concurrency is the maximum number of lookups
// running at the same time; values below 1 are treated as 1. A failed lookup
// is logged as a warning and leaves that owner's AccountID empty. The function
// returns once every lookup has finished.
func fetchAccountIDsConcurrent(owners []database.TeamOwner, concurrency int) {
	// A zero-capacity semaphore would make the first acquire block forever
	// (nothing is receiving yet) and a negative capacity makes make() panic,
	// so always allow at least one worker.
	if concurrency < 1 {
		concurrency = 1
	}

	// Number of owners that need a lookup; only used for progress logging.
	total := 0
	for i := range owners {
		if owners[i].AccountID == "" {
			total++
		}
	}
	if total == 0 {
		return
	}

	var (
		wg        sync.WaitGroup
		mu        sync.Mutex
		completed int
	)
	sem := make(chan struct{}, concurrency)

	for i := range owners {
		if owners[i].AccountID != "" {
			continue
		}
		wg.Add(1)
		sem <- struct{}{} // acquire a slot; blocks while `concurrency` lookups are running
		go func(idx int) {
			defer wg.Done()
			defer func() { <-sem }() // release the slot

			email := owners[idx].Email
			token := owners[idx].Token
			accountID, err := fetchAccountID(token)

			// Take a snapshot of the progress counter for the log line.
			mu.Lock()
			completed++
			progress := completed
			mu.Unlock()

			if err != nil {
				logger.Warning(fmt.Sprintf("[%d/%d] 获取 account_id 失败 (%s): %v", progress, total, email, err), "", "upload")
				return
			}

			mu.Lock()
			owners[idx].AccountID = accountID
			mu.Unlock()
			logger.Info(fmt.Sprintf("[%d/%d] 获取 account_id 成功: %s -> %s", progress, total, email, accountID), "", "upload")
		}(i)
	}
	wg.Wait()
}
// normalizeOwnerBasic validates one uploaded record and converts it into a
// database.TeamOwner without performing any account_id lookup.
//
// All fields are whitespace-trimmed. The e-mail is taken from rec.Email,
// falling back to rec.Account; the token from rec.Token, falling back to
// rec.AccessTok. E-mail, password and token are required; AccountID is copied
// as-is and may be empty.
//
// index is the 1-based position of the record and is only used in error
// messages. On error the returned TeamOwner is the zero value.
func normalizeOwnerBasic(rec accountRecord, index int) (database.TeamOwner, error) {
	email := strings.TrimSpace(rec.Email)
	if email == "" {
		email = strings.TrimSpace(rec.Account)
	}
	if email == "" {
		return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 account/email 字段", index)
	}

	password := strings.TrimSpace(rec.Password)
	if password == "" {
		return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 password 字段", index)
	}

	token := strings.TrimSpace(rec.Token)
	if token == "" {
		token = strings.TrimSpace(rec.AccessTok)
	}
	if token == "" {
		return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 token 字段", index)
	}

	return database.TeamOwner{
		Email:     email,
		Password:  password,
		Token:     token,
		AccountID: strings.TrimSpace(rec.AccountID),
	}, nil
}

// normalizeOwner validates a record exactly like normalizeOwnerBasic and, when
// the record carries no account_id, additionally tries to obtain one with
// fetchAccountID.
//
// The lookup is best-effort: a failure is logged as a warning and the owner is
// still returned with an empty AccountID. Only validation failures produce an
// error.
func normalizeOwner(rec accountRecord, index int) (database.TeamOwner, error) {
	// Reuse the shared validation so both entry points apply the same rules.
	owner, err := normalizeOwnerBasic(rec, index)
	if err != nil {
		return database.TeamOwner{}, err
	}
	if owner.AccountID != "" {
		return owner, nil
	}

	fetchedID, fetchErr := fetchAccountID(owner.Token)
	if fetchErr != nil {
		logger.Warning(fmt.Sprintf("获取 account_id 失败 (%s): %v", owner.Email, fetchErr), "", "upload")
		return owner, nil
	}
	owner.AccountID = fetchedID
	logger.Info(fmt.Sprintf("获取 account_id 成功: %s -> %s", owner.Email, owner.AccountID), "", "upload")
	return owner, nil
}
// fetchAccountID queries the ChatGPT accounts-check endpoint with the given
// bearer token and returns the account ID to use for that token.
//
// The account ID is the KEY of the "accounts" object in the response, not the
// nested account.id field. Selection rules:
//   - the "default" entry and any empty key are ignored;
//   - an account whose plan_type contains "team" (case-insensitive) is preferred;
//   - otherwise any remaining account is used;
//   - when several accounts qualify, the lexicographically smallest ID is
//     chosen so the result does not depend on map iteration order.
//
// The request goes through a client created by client.New with the proxy
// returned by the current config (empty string when there is no config).
//
// An error is returned when the client or request cannot be created, the
// request fails, the body cannot be read, the status is not 200 (the first 200
// bytes of the body are included), the body is not valid JSON, or no usable
// account is present. Underlying errors are wrapped and can be inspected with
// errors.Is / errors.As.
func fetchAccountID(token string) (string, error) {
	proxy := ""
	if cfg := config.Get(); cfg != nil {
		proxy = cfg.GetProxy()
	}

	tlsClient, err := client.New(proxy)
	if err != nil {
		return "", fmt.Errorf("创建 TLS 客户端失败: %w", err)
	}

	req, err := http.NewRequest(http.MethodGet, "https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27", nil)
	if err != nil {
		return "", fmt.Errorf("创建请求失败: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+token)
	req.Header.Set("Content-Type", "application/json")

	resp, err := tlsClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("请求失败: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("读取响应失败: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		// Keep the error short: include at most the first 200 bytes of the body.
		snippet := body
		if len(snippet) > 200 {
			snippet = snippet[:200]
		}
		return "", fmt.Errorf("API 返回状态码: %d, 响应: %s", resp.StatusCode, string(snippet))
	}

	var result struct {
		Accounts map[string]struct {
			Account struct {
				ID       string `json:"id"`
				PlanType string `json:"plan_type"`
			} `json:"account"`
		} `json:"accounts"`
	}
	if err := json.Unmarshal(body, &result); err != nil {
		return "", fmt.Errorf("解析响应失败: %w", err)
	}

	// Single pass over the map, remembering the smallest team ID and the
	// smallest non-team ID. Empty keys are skipped, so "" doubles as "unset".
	var teamID, otherID string
	for accountID, info := range result.Accounts {
		if accountID == "" || accountID == "default" {
			continue
		}
		if strings.Contains(strings.ToLower(info.Account.PlanType), "team") {
			if teamID == "" || accountID < teamID {
				teamID = accountID
			}
			continue
		}
		if otherID == "" || accountID < otherID {
			otherID = accountID
		}
	}

	switch {
	case teamID != "":
		return teamID, nil
	case otherID != "":
		return otherID, nil
	}
	return "", fmt.Errorf("未找到有效的 account_id")
}
// parseAccountsFlexible converts loosely formatted text into account records.
//
// A leading UTF-8 BOM and surrounding whitespace are removed first. Then two
// strategies are tried in order:
//
//  1. Whole-document JSON: after stripping // and /* */ comments and trailing
//     commas, content starting with '[' is decoded as an array of records and
//     content starting with '{' as a single record.
//  2. JSON Lines fallback: the content is split on newlines and each line is
//     cleaned the same way and decoded as one record. Blank lines and lines
//     starting with '#' are skipped.
//
// An error is returned when the content is empty, when a line in the fallback
// cannot be decoded (the message carries the 1-based line number), or when the
// fallback yields no records.
func parseAccountsFlexible(raw string) ([]accountRecord, error) {
	raw = strings.TrimSpace(strings.TrimPrefix(raw, "\uFEFF"))
	if raw == "" {
		return nil, fmt.Errorf("内容为空")
	}

	whole := strings.TrimSpace(removeTrailingCommas(stripJSONComments(raw)))
	if whole == "" {
		return nil, fmt.Errorf("内容为空")
	}

	// Try the whole document first; a decode failure falls through to JSONL.
	switch whole[0] {
	case '[':
		var list []accountRecord
		if json.Unmarshal([]byte(whole), &list) == nil {
			return list, nil
		}
	case '{':
		var one accountRecord
		if json.Unmarshal([]byte(whole), &one) == nil {
			return []accountRecord{one}, nil
		}
	}

	// JSON Lines fallback, working on the original (uncleaned) text line by line.
	rows := strings.Split(raw, "\n")
	records := make([]accountRecord, 0, len(rows))
	for n, row := range rows {
		text := strings.TrimSpace(stripJSONComments(row))
		if text == "" || strings.HasPrefix(text, "#") {
			continue
		}
		text = strings.TrimSpace(removeTrailingCommas(text))
		if text == "" {
			continue
		}

		var rec accountRecord
		if err := json.Unmarshal([]byte(text), &rec); err != nil {
			return nil, fmt.Errorf("第 %d 行解析失败: %v", n+1, err)
		}
		records = append(records, rec)
	}

	if len(records) == 0 {
		return nil, fmt.Errorf("未解析到有效账号")
	}
	return records, nil
}
// stripJSONComments returns input with JavaScript-style comments removed.
//
// Line comments (// ...) are removed up to, but not including, the next
// newline. Block comments (/* ... */) are removed entirely; an unterminated
// block comment swallows the rest of the input. Text inside double-quoted
// string literals is copied verbatim, with backslash escapes honoured, so
// sequences such as "http://" inside a string are left alone.
func stripJSONComments(input string) string {
	var out strings.Builder
	out.Grow(len(input))

	n := len(input)
	pos := 0
	for pos < n {
		c := input[pos]
		switch {
		case c == '"':
			// Copy the whole string literal, including both quotes. A backslash
			// always protects the byte that follows it.
			end := pos + 1
			for end < n {
				b := input[end]
				end++
				if b == '"' {
					break
				}
				if b == '\\' && end < n {
					end++
				}
			}
			out.WriteString(input[pos:end])
			pos = end

		case c == '/' && pos+1 < n && input[pos+1] == '/':
			// Line comment: jump to the newline (which is kept) or to the end.
			if nl := strings.IndexByte(input[pos:], '\n'); nl >= 0 {
				pos += nl
			} else {
				pos = n
			}

		case c == '/' && pos+1 < n && input[pos+1] == '*':
			// Block comment: jump past the closing "*/" or to the end.
			if end := strings.Index(input[pos+2:], "*/"); end >= 0 {
				pos += 2 + end + 2
			} else {
				pos = n
			}

		default:
			out.WriteByte(c)
			pos++
		}
	}
	return out.String()
}
// removeTrailingCommas returns input with every comma dropped that is
// followed, after optional whitespace, by a closing ']' or '}'.
//
// Only the comma itself is removed; the whitespace after it is kept. Text
// inside double-quoted string literals is copied verbatim, with backslash
// escapes honoured, so commas inside strings are never touched.
func removeTrailingCommas(input string) string {
	var out strings.Builder
	out.Grow(len(input))

	n := len(input)
	for pos := 0; pos < n; {
		switch c := input[pos]; c {
		case '"':
			// Copy the whole string literal, including both quotes. A backslash
			// always protects the byte that follows it.
			end := pos + 1
			for end < n {
				b := input[end]
				end++
				if b == '"' {
					break
				}
				if b == '\\' && end < n {
					end++
				}
			}
			out.WriteString(input[pos:end])
			pos = end

		case ',':
			// Look past any whitespace to see what the comma precedes.
			next := pos + 1
			for next < n && unicode.IsSpace(rune(input[next])) {
				next++
			}
			closes := next < n && (input[next] == ']' || input[next] == '}')
			if !closes {
				out.WriteByte(c)
			}
			pos++

		default:
			out.WriteByte(c)
			pos++
		}
	}
	return out.String()
}