Files
codexautopool/backend/internal/api/upload.go

467 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package api
import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"sort"
	"strings"
	"sync"
	"unicode"

	"codex-pool/internal/client"
	"codex-pool/internal/database"
	"codex-pool/internal/logger"
)
// uploadValidateRequest is the JSON envelope accepted by HandleUploadValidate.
// When Accounts is non-empty it is used and Content is ignored.
type uploadValidateRequest struct {
	// Content is raw account text, parsed by parseAccountsFlexible.
	Content string `json:"content"`
	// Accounts are already-structured account records.
	Accounts []accountRecord `json:"accounts"`
}
// accountRecord is one uploaded account as decoded from the input JSON.
// Some fields are aliases of each other: normalizeOwnerBasic uses Email and
// falls back to Account, and uses Token and falls back to AccessTok.
type accountRecord struct {
	Account   string `json:"account"`      // alternative to Email, used when Email is blank
	Email     string `json:"email"`        // preferred over Account
	Password  string `json:"password"`     // required
	Token     string `json:"token"`        // bearer token; also used to look up account_id
	AccessTok string `json:"access_token"` // alternative to Token, used when Token is blank
	AccountID string `json:"account_id"`   // optional; fetched via the API when blank
}
// HandleUploadValidate handles POST requests that import team-owner accounts.
//
// Accepted request bodies:
//   - {"accounts": [...]}: the records are used as-is;
//   - {"content": "..."}: the text is parsed by parseAccountsFlexible;
//   - a bare single account object ({"email": ..., "password": ..., ...});
//   - anything that is not a JSON object (plain text, JSONL, a JSON array):
//     the whole body is treated as raw content.
//
// Every record is validated. Records without account_id get it fetched
// concurrently, and all owners are written with database.Instance.AddTeamOwners.
// The response reports imported/total counts, owner stats and how many owners
// ended up with or without an account_id.
func HandleUploadValidate(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		Error(w, http.StatusMethodNotAllowed, "仅支持 POST")
		return
	}
	if database.Instance == nil {
		Error(w, http.StatusInternalServerError, "数据库未初始化")
		return
	}

	// Request bodies are capped at 10 MiB.
	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 10<<20))
	if err != nil {
		Error(w, http.StatusBadRequest, "读取请求失败")
		return
	}

	var req uploadValidateRequest
	if err := json.Unmarshal(body, &req); err != nil {
		// Not an envelope: treat the whole body as raw content. Reset req so that
		// fields partially filled by the failed decode are not imported.
		req = uploadValidateRequest{Content: string(body)}
	} else if len(req.Accounts) == 0 && strings.TrimSpace(req.Content) == "" {
		// A JSON object with neither accounts nor content may be a single account
		// posted directly; accept it if it carries any account field.
		var single accountRecord
		if json.Unmarshal(body, &single) == nil && single != (accountRecord{}) {
			req.Accounts = []accountRecord{single}
		}
	}

	var records []accountRecord
	switch {
	case len(req.Accounts) > 0:
		records = req.Accounts
	case strings.TrimSpace(req.Content) != "":
		parsed, parseErr := parseAccountsFlexible(req.Content)
		if parseErr != nil {
			Error(w, http.StatusBadRequest, parseErr.Error())
			return
		}
		records = parsed
	default:
		Error(w, http.StatusBadRequest, "未提供账号内容")
		return
	}

	// Validate every record and count how many still need an account_id.
	owners := make([]database.TeamOwner, 0, len(records))
	needFetch := 0
	for i, rec := range records {
		owner, err := normalizeOwnerBasic(rec, i+1)
		if err != nil {
			Error(w, http.StatusBadRequest, err.Error())
			return
		}
		if owner.AccountID == "" {
			needFetch++
		}
		owners = append(owners, owner)
	}
	if len(owners) == 0 {
		Error(w, http.StatusBadRequest, "未解析到有效账号")
		return
	}

	// Resolve missing account_ids with at most 20 lookups in flight.
	if needFetch > 0 {
		logger.Info(fmt.Sprintf("开始获取 account_id: 需要获取 %d 个", needFetch), "", "upload")
		fetchAccountIDsConcurrent(owners, 20)
		logger.Info("account_id 获取完成", "", "upload")
	}

	inserted, err := database.Instance.AddTeamOwners(owners)
	if err != nil {
		Error(w, http.StatusInternalServerError, fmt.Sprintf("写入数据库失败: %v", err))
		return
	}

	// Owners that still lack an account_id are counted as lookup failures.
	successCount, failCount := 0, 0
	for _, o := range owners {
		if o.AccountID != "" {
			successCount++
		} else {
			failCount++
		}
	}

	logger.Info(fmt.Sprintf("母号导入成功: 成功=%d, 总数=%d, account_id获取成功=%d, 失败=%d",
		inserted, len(owners), successCount, failCount), "", "upload")

	stats := database.Instance.GetOwnerStats()
	Success(w, map[string]interface{}{
		"imported":        inserted,
		"total":           len(owners),
		"stats":           stats,
		"account_id_ok":   successCount,
		"account_id_fail": failCount,
	})
}
// fetchAccountIDsConcurrent fills in AccountID for every owner that lacks one,
// calling fetchAccountID with at most concurrency lookups running at once.
//
// owners is modified in place. Failed lookups are logged and leave AccountID
// empty. The function returns only after every lookup has finished.
// A concurrency below 1 is treated as 1: a zero-capacity semaphore would block
// forever on the first acquire, and a negative capacity would panic.
func fetchAccountIDsConcurrent(owners []database.TeamOwner, concurrency int) {
	if concurrency < 1 {
		concurrency = 1
	}

	// Count the lookups up front so progress can be logged as "[done/total]".
	total := 0
	for i := range owners {
		if owners[i].AccountID == "" {
			total++
		}
	}

	var (
		wg        sync.WaitGroup
		mu        sync.Mutex // guards completed
		completed int
	)
	sem := make(chan struct{}, concurrency)

	for i := range owners {
		if owners[i].AccountID != "" {
			continue
		}
		wg.Add(1)
		sem <- struct{}{} // blocks while `concurrency` lookups are in flight
		go func(idx int) {
			defer wg.Done()
			defer func() { <-sem }()

			email := owners[idx].Email
			accountID, err := fetchAccountID(owners[idx].Token)

			mu.Lock()
			completed++
			progress := completed
			mu.Unlock()

			if err != nil {
				logger.Warning(fmt.Sprintf("[%d/%d] 获取 account_id 失败 (%s): %v", progress, total, email, err), "", "upload")
				return
			}
			// Each goroutine writes only its own slice element, so no lock is needed.
			owners[idx].AccountID = accountID
			logger.Info(fmt.Sprintf("[%d/%d] 获取 account_id 成功: %s -> %s", progress, total, email, accountID), "", "upload")
		}(i)
	}
	wg.Wait()
}
// normalizeOwnerBasic validates one uploaded record and converts it to a
// database.TeamOwner without any network calls. It does not fetch a missing
// account_id.
//
// All values are whitespace-trimmed. The email comes from Email, falling back
// to Account. The token comes from Token, falling back to AccessTok. index is
// the 1-based record number used in error messages. An error is returned when
// the email, password or token is missing, and the checks run in that order.
func normalizeOwnerBasic(rec accountRecord, index int) (database.TeamOwner, error) {
	// pick returns the first candidate that is non-blank after trimming.
	pick := func(candidates ...string) string {
		for _, c := range candidates {
			if v := strings.TrimSpace(c); v != "" {
				return v
			}
		}
		return ""
	}

	owner := database.TeamOwner{
		Email:     pick(rec.Email, rec.Account),
		Password:  pick(rec.Password),
		Token:     pick(rec.Token, rec.AccessTok),
		AccountID: pick(rec.AccountID),
	}

	switch {
	case owner.Email == "":
		return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 account/email 字段", index)
	case owner.Password == "":
		return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 password 字段", index)
	case owner.Token == "":
		return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 token 字段", index)
	}
	return owner, nil
}
// normalizeOwner validates a record exactly like normalizeOwnerBasic and, when
// the record has no account_id, fetches it synchronously with fetchAccountID.
//
// A failed lookup is logged and not treated as an error: the owner is returned
// with an empty AccountID. An error is returned only for validation failures.
func normalizeOwner(rec accountRecord, index int) (database.TeamOwner, error) {
	// Share validation with normalizeOwnerBasic instead of duplicating it.
	owner, err := normalizeOwnerBasic(rec, index)
	if err != nil {
		return database.TeamOwner{}, err
	}
	if owner.AccountID != "" {
		return owner, nil
	}

	fetchedID, fetchErr := fetchAccountID(owner.Token)
	if fetchErr != nil {
		logger.Warning(fmt.Sprintf("获取 account_id 失败 (%s): %v", owner.Email, fetchErr), "", "upload")
		return owner, nil
	}
	owner.AccountID = fetchedID
	logger.Info(fmt.Sprintf("获取 account_id 成功: %s -> %s", owner.Email, owner.AccountID), "", "upload")
	return owner, nil
}
// fetchAccountID resolves the account_id for token by calling the
// chatgpt.com backend-api accounts/check endpoint with the token as a Bearer
// credential.
//
// Entries keyed "default" and entries with an empty id are ignored. Among the
// remaining entries it returns the first whose plan_type contains "team"
// (case-insensitive), otherwise the first one. "First" means lowest map key,
// so the result is deterministic even though Go map iteration order is random.
// An error is returned on transport/HTTP failures, a non-200 status, an
// undecodable body, or when no usable account id is present.
func fetchAccountID(token string) (string, error) {
	tlsClient, err := client.New("")
	if err != nil {
		return "", fmt.Errorf("创建 TLS 客户端失败: %w", err)
	}

	req, err := http.NewRequest(http.MethodGet, "https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27", nil)
	if err != nil {
		return "", fmt.Errorf("创建请求失败: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+token)
	req.Header.Set("Content-Type", "application/json")

	resp, err := tlsClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("请求失败: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("API 返回状态码: %d", resp.StatusCode)
	}

	body, err := client.ReadBody(resp)
	if err != nil {
		return "", fmt.Errorf("读取响应失败: %w", err)
	}

	var result struct {
		Accounts map[string]struct {
			Account struct {
				ID       string `json:"id"`
				PlanType string `json:"plan_type"`
			} `json:"account"`
		} `json:"accounts"`
	}
	if err := json.Unmarshal(body, &result); err != nil {
		return "", fmt.Errorf("解析响应失败: %w", err)
	}

	// Collect usable entries in a stable order. An entry with an empty id is
	// skipped here; previously a "team" entry without an id was returned as
	// ("", nil), which hid the failure.
	keys := make([]string, 0, len(result.Accounts))
	for key, acc := range result.Accounts {
		if key == "default" || acc.Account.ID == "" {
			continue
		}
		keys = append(keys, key)
	}
	sort.Strings(keys)

	// Prefer an account whose plan_type contains "team".
	for _, key := range keys {
		acc := result.Accounts[key]
		if strings.Contains(strings.ToLower(acc.Account.PlanType), "team") {
			return acc.Account.ID, nil
		}
	}
	// Otherwise take the first remaining account.
	if len(keys) > 0 {
		return result.Accounts[keys[0]].Account.ID, nil
	}
	return "", fmt.Errorf("未找到有效的 account_id")
}
// parseAccountsFlexible parses raw account text in any of these formats:
//   - a JSON array of account objects;
//   - a single JSON account object;
//   - JSONL, with one account object per line.
//
// A leading UTF-8 BOM is ignored. For the array and object forms, // and /* */
// comments and trailing commas are tolerated. In JSONL, blank lines and lines
// starting with '#' are skipped, and each line may carry comments and
// trailing commas.
//
// An error is returned when the input is empty, when a JSON array is
// malformed or empty, when a JSONL line fails to decode (the message names the
// 1-based line number), or when no record is found.
func parseAccountsFlexible(raw string) ([]accountRecord, error) {
	raw = strings.TrimSpace(strings.TrimPrefix(raw, "\uFEFF"))
	if raw == "" {
		return nil, fmt.Errorf("内容为空")
	}

	trimmed := strings.TrimSpace(removeTrailingCommas(stripJSONComments(raw)))
	if trimmed == "" {
		return nil, fmt.Errorf("内容为空")
	}

	switch {
	case strings.HasPrefix(trimmed, "["):
		// A JSON array can never decode as a JSONL record. Falling back to line
		// parsing would only replace the real error with a misleading "line 1"
		// error, so report the array error directly.
		var arr []accountRecord
		if err := json.Unmarshal([]byte(trimmed), &arr); err != nil {
			return nil, fmt.Errorf("JSON 数组解析失败: %w", err)
		}
		if len(arr) == 0 {
			return nil, fmt.Errorf("未解析到有效账号")
		}
		return arr, nil
	case strings.HasPrefix(trimmed, "{"):
		// If this fails, the input may be several objects (JSONL); see below.
		var single accountRecord
		if err := json.Unmarshal([]byte(trimmed), &single); err == nil {
			return []accountRecord{single}, nil
		}
	}

	// JSONL fallback. Work on raw (not the whole-text cleaned version) so that
	// reported line numbers match the user's input.
	lines := strings.Split(raw, "\n")
	records := make([]accountRecord, 0, len(lines))
	for i, line := range lines {
		line = strings.TrimSpace(stripJSONComments(line))
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		line = strings.TrimSpace(removeTrailingCommas(line))
		if line == "" {
			continue
		}
		var rec accountRecord
		if err := json.Unmarshal([]byte(line), &rec); err != nil {
			return nil, fmt.Errorf("第 %d 行解析失败: %v", i+1, err)
		}
		records = append(records, rec)
	}
	if len(records) == 0 {
		return nil, fmt.Errorf("未解析到有效账号")
	}
	return records, nil
}
// stripJSONComments removes // line comments and /* */ block comments that
// appear outside JSON string literals.
//
// The newline that ends a line comment is kept. An unterminated block comment
// (or a line comment with no newline after it) swallows the rest of the input.
// String literals, including backslash-escaped quotes inside them, are copied
// unchanged, so comment markers inside strings are preserved.
func stripJSONComments(input string) string {
	var out strings.Builder
	out.Grow(len(input))

	quoted, afterBackslash := false, false
	pos := 0
	for pos < len(input) {
		c := input[pos]

		if quoted {
			out.WriteByte(c)
			switch {
			case afterBackslash:
				afterBackslash = false
			case c == '\\':
				afterBackslash = true
			case c == '"':
				quoted = false
			}
			pos++
			continue
		}

		rest := input[pos:]
		switch {
		case c == '"':
			quoted = true
			out.WriteByte(c)
			pos++
		case strings.HasPrefix(rest, "//"):
			// Jump to the newline so it is emitted on the next iteration.
			nl := strings.IndexByte(rest, '\n')
			if nl < 0 {
				return out.String()
			}
			pos += nl
		case strings.HasPrefix(rest, "/*"):
			// Search for the terminator after the opener, so "/*/" does not close itself.
			end := strings.Index(rest[2:], "*/")
			if end < 0 {
				return out.String()
			}
			pos += 2 + end + 2
		default:
			out.WriteByte(c)
			pos++
		}
	}
	return out.String()
}
// removeTrailingCommas drops every comma that appears outside a JSON string
// literal and is followed, after optional whitespace, by ']' or '}'.
// All other bytes are copied unchanged, including whitespace around a dropped
// comma and the whole contents of string literals (escaped quotes included).
func removeTrailingCommas(input string) string {
	var out strings.Builder
	out.Grow(len(input))

	// nextNonSpace returns the index of the first non-whitespace byte at or
	// after pos, or len(input) if only whitespace remains.
	nextNonSpace := func(pos int) int {
		for pos < len(input) && unicode.IsSpace(rune(input[pos])) {
			pos++
		}
		return pos
	}

	quoted, afterBackslash := false, false
	for pos := 0; pos < len(input); pos++ {
		c := input[pos]
		switch {
		case quoted:
			out.WriteByte(c)
			switch {
			case afterBackslash:
				afterBackslash = false
			case c == '\\':
				afterBackslash = true
			case c == '"':
				quoted = false
			}
		case c == '"':
			quoted = true
			out.WriteByte(c)
		case c == ',':
			if k := nextNonSpace(pos + 1); k < len(input) && (input[k] == ']' || input[k] == '}') {
				continue // trailing comma: skip it
			}
			out.WriteByte(c)
		default:
			out.WriteByte(c)
		}
	}
	return out.String()
}