Files
codexautopool/backend/internal/api/upload.go

587 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package api
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"unicode"
"codex-pool/internal/client"
"codex-pool/internal/config"
"codex-pool/internal/database"
"codex-pool/internal/logger"
)
type (
	// uploadValidateRequest is the optional JSON wrapper accepted by
	// HandleUploadValidate: raw account text in Content (handed to
	// parseAccountsFlexible) or already-structured records in Accounts.
	uploadValidateRequest struct {
		Content  string          `json:"content"`
		Accounts []accountRecord `json:"accounts"`
	}

	// accountRecord is one uploaded account exactly as it appears in JSON.
	// Account is used when Email is empty, and AccessTok when Token is empty
	// (see normalizeOwnerBasic). AccountID may be left empty.
	accountRecord struct {
		Account   string `json:"account"`
		Email     string `json:"email"`
		Password  string `json:"password"`
		Token     string `json:"token"`
		AccessTok string `json:"access_token"`
		AccountID string `json:"account_id"`
	}
)
// HandleUploadValidate 处理上传验证请求
func HandleUploadValidate(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
Error(w, http.StatusMethodNotAllowed, "仅支持 POST")
return
}
if database.Instance == nil {
Error(w, http.StatusInternalServerError, "数据库未初始化")
return
}
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 10<<20))
if err != nil {
Error(w, http.StatusBadRequest, "读取请求失败")
return
}
var req uploadValidateRequest
if err := json.Unmarshal(body, &req); err != nil {
// 如果不是 JSON直接把 body 当作原始内容
req.Content = string(body)
}
var records []accountRecord
switch {
case len(req.Accounts) > 0:
records = req.Accounts
case strings.TrimSpace(req.Content) != "":
parsed, parseErr := parseAccountsFlexible(req.Content)
if parseErr != nil {
Error(w, http.StatusBadRequest, parseErr.Error())
return
}
records = parsed
default:
Error(w, http.StatusBadRequest, "未提供账号内容")
return
}
owners := make([]database.TeamOwner, 0, len(records))
for i, rec := range records {
owner, err := normalizeOwnerBasic(rec, i+1)
if err != nil {
Error(w, http.StatusBadRequest, err.Error())
return
}
owners = append(owners, owner)
}
if len(owners) == 0 {
Error(w, http.StatusBadRequest, "未解析到有效账号")
return
}
// 并发验证账号并获取 account_id只保留 team 账号)
logger.Status(fmt.Sprintf("验证账号中: 共 %d 个,只导入 plan 为 team 的账号", len(owners)), "", "upload")
validOwners, teamCount, nonTeamCount, failCount := validateAndFetchAccountIDs(owners, 20)
logger.Success(fmt.Sprintf("验证完成: team=%d, 非team=%d, 失败=%d", teamCount, nonTeamCount, failCount), "", "upload")
if len(validOwners) == 0 {
Error(w, http.StatusBadRequest, fmt.Sprintf("没有有效的 team 账号(共 %d 个非team: %d失败: %d", len(owners), nonTeamCount, failCount))
return
}
inserted, err := database.Instance.AddTeamOwners(validOwners)
if err != nil {
Error(w, http.StatusInternalServerError, fmt.Sprintf("写入数据库失败: %v", err))
return
}
// 输出导入日志
logger.Info(fmt.Sprintf("母号导入成功: 导入=%d, 总数=%d, team=%d, 非team跳过=%d, 失败=%d",
inserted, len(owners), teamCount, nonTeamCount, failCount), "", "upload")
stats := database.Instance.GetOwnerStats()
Success(w, map[string]interface{}{
"imported": inserted,
"total": len(owners),
"stats": stats,
"team_count": teamCount,
"non_team_count": nonTeamCount,
"fail_count": failCount,
})
}
// HandleRefetchAccountIDs retries the account_id lookup for every stored
// owner that has none (POST only).
//
// Lookups run with up to 20 in flight. Each result is persisted with
// UpdateOwnerAccountID. Counting, logging and the database update happen
// while holding a mutex, so the updates are serialized. The response
// reports total / success / fail counts.
func HandleRefetchAccountIDs(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		Error(w, http.StatusMethodNotAllowed, "仅支持 POST")
		return
	}
	if database.Instance == nil {
		Error(w, http.StatusInternalServerError, "数据库未初始化")
		return
	}

	pending, err := database.Instance.GetOwnersWithoutAccountID()
	if err != nil {
		Error(w, http.StatusInternalServerError, fmt.Sprintf("查询数据库失败: %v", err))
		return
	}
	if len(pending) == 0 {
		Success(w, map[string]interface{}{
			"message": "所有母号都已有 account_id",
			"total":   0,
			"success": 0,
			"fail":    0,
		})
		return
	}

	logger.Status(fmt.Sprintf("重新获取 account_id 中: 共 %d 个", len(pending)), "", "upload")

	const maxInFlight = 20
	var (
		wg        sync.WaitGroup
		mu        sync.Mutex
		succeeded int
		failed    int
	)
	slots := make(chan struct{}, maxInFlight)

	// refetchOne looks up and stores one owner's account_id and records the outcome.
	refetchOne := func(o database.TeamOwner) {
		accountID, fetchErr := fetchAccountID(o.Token)
		mu.Lock()
		defer mu.Unlock()
		if fetchErr != nil {
			failed++
			logger.Warning(fmt.Sprintf("获取 account_id 失败 (%s): %v", o.Email, fetchErr), "", "upload")
			return
		}
		if updateErr := database.Instance.UpdateOwnerAccountID(o.ID, accountID); updateErr != nil {
			failed++
			logger.Error(fmt.Sprintf("更新 account_id 失败 (%s): %v", o.Email, updateErr), "", "upload")
			return
		}
		succeeded++
		logger.Info(fmt.Sprintf("获取 account_id 成功: %s -> %s", o.Email, accountID), "", "upload")
	}

	for _, owner := range pending {
		wg.Add(1)
		slots <- struct{}{}
		go func(o database.TeamOwner) {
			defer wg.Done()
			defer func() { <-slots }()
			refetchOne(o)
		}(owner)
	}
	wg.Wait()

	logger.Info(fmt.Sprintf("重新获取 account_id 完成: 成功=%d, 失败=%d", succeeded, failed), "", "upload")
	Success(w, map[string]interface{}{
		"message": "重新获取 account_id 完成",
		"total":   len(pending),
		"success": succeeded,
		"fail":    failed,
	})
}
// validateAndFetchAccountIDs checks every owner's token with fetchAccountID,
// running at most concurrency lookups at a time, and keeps only team
// accounts.
//
// For each accepted owner, AccountID is filled in, both in the returned
// slice and in owners[i] itself. An error whose text contains "非 team 账户"
// counts as a non-team skip; any other error counts as a failure.
// concurrency values below 1 are treated as 1. A zero-capacity semaphore
// would block forever, and a negative one would panic in make.
//
// Returns the accepted owners in input order, then the team, non-team and
// failure counts.
func validateAndFetchAccountIDs(owners []database.TeamOwner, concurrency int) ([]database.TeamOwner, int, int, int) {
	if concurrency < 1 {
		concurrency = 1
	}
	total := len(owners)
	// accepted[i] marks owners[i] as a validated team account. Each goroutine
	// writes only its own index, so the final slice can be built in input
	// order rather than completion order.
	accepted := make([]bool, total)

	var (
		wg           sync.WaitGroup
		mu           sync.Mutex // guards the counters and serializes progress logs
		completed    int
		teamCount    int
		nonTeamCount int
		failCount    int
	)
	sem := make(chan struct{}, concurrency)

	for i := range owners {
		wg.Add(1)
		sem <- struct{}{}
		go func(idx int) {
			defer wg.Done()
			defer func() { <-sem }()
			email := owners[idx].Email
			accountID, err := fetchAccountID(owners[idx].Token)

			mu.Lock()
			defer mu.Unlock()
			completed++
			progress := completed
			switch {
			case err == nil:
				owners[idx].AccountID = accountID
				accepted[idx] = true
				teamCount++
				logger.Info(fmt.Sprintf("[%d/%d] team 账号验证通过: %s -> %s", progress, total, email, accountID), "", "upload")
			case strings.Contains(err.Error(), "非 team 账户"):
				nonTeamCount++
				logger.Warning(fmt.Sprintf("[%d/%d] 跳过非 team 账号: %s", progress, total, email), "", "upload")
			default:
				failCount++
				logger.Warning(fmt.Sprintf("[%d/%d] 验证失败 (%s): %v", progress, total, email, err), "", "upload")
			}
		}(i)
	}
	wg.Wait()

	validOwners := make([]database.TeamOwner, 0, teamCount)
	for i, ok := range accepted {
		if ok {
			validOwners = append(validOwners, owners[i])
		}
	}
	return validOwners, teamCount, nonTeamCount, failCount
}
// fetchAccountIDsConcurrent fills in AccountID, in place, for every owner
// whose AccountID is empty. At most concurrency fetchAccountID calls run at a
// time. A failed lookup is logged and leaves that owner's AccountID empty.
//
// concurrency values below 1 are treated as 1. A zero-capacity semaphore
// would block forever on the first send, and a negative one would panic.
func fetchAccountIDsConcurrent(owners []database.TeamOwner, concurrency int) {
	if concurrency < 1 {
		concurrency = 1
	}

	total := 0
	for i := range owners {
		if owners[i].AccountID == "" {
			total++
		}
	}

	var (
		wg        sync.WaitGroup
		mu        sync.Mutex // guards completed
		completed int
	)
	sem := make(chan struct{}, concurrency)

	for i := range owners {
		if owners[i].AccountID != "" {
			continue
		}
		wg.Add(1)
		sem <- struct{}{}
		go func(idx int) {
			defer wg.Done()
			defer func() { <-sem }()
			email := owners[idx].Email
			accountID, err := fetchAccountID(owners[idx].Token)

			mu.Lock()
			completed++
			progress := completed
			mu.Unlock()

			if err != nil {
				logger.Warning(fmt.Sprintf("[%d/%d] 获取 account_id 失败 (%s): %v", progress, total, email, err), "", "upload")
				return
			}
			// Each goroutine writes only its own element, so no lock is needed here.
			owners[idx].AccountID = accountID
			logger.Info(fmt.Sprintf("[%d/%d] 获取 account_id 成功: %s -> %s", progress, total, email, accountID), "", "upload")
		}(i)
	}
	wg.Wait()
}
// normalizeOwnerBasic converts one uploaded record into a database.TeamOwner
// after checking the required fields. It makes no network calls and does
// not look up account_id.
//
// index is the 1-based record position, used only in error messages. All
// values are whitespace-trimmed. Email falls back to Account, and Token
// falls back to AccessTok. AccountID is copied as given, possibly empty.
// Missing fields are reported in the order email, password, token.
func normalizeOwnerBasic(rec accountRecord, index int) (database.TeamOwner, error) {
	// firstNonBlank returns the first candidate that is non-empty after trimming.
	firstNonBlank := func(candidates ...string) string {
		for _, c := range candidates {
			if v := strings.TrimSpace(c); v != "" {
				return v
			}
		}
		return ""
	}

	owner := database.TeamOwner{
		Email:     firstNonBlank(rec.Email, rec.Account),
		Password:  strings.TrimSpace(rec.Password),
		Token:     firstNonBlank(rec.Token, rec.AccessTok),
		AccountID: strings.TrimSpace(rec.AccountID),
	}

	switch {
	case owner.Email == "":
		return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 account/email 字段", index)
	case owner.Password == "":
		return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 password 字段", index)
	case owner.Token == "":
		return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 token 字段", index)
	}
	return owner, nil
}
// normalizeOwner validates rec exactly like normalizeOwnerBasic. When the
// record carries no account_id, it also tries to fetch one with
// fetchAccountID. A failed fetch is logged as a warning and leaves
// AccountID empty; it does not reject the record.
func normalizeOwner(rec accountRecord, index int) (database.TeamOwner, error) {
	// Delegate the field checks so the two normalizers cannot drift apart.
	owner, err := normalizeOwnerBasic(rec, index)
	if err != nil {
		return database.TeamOwner{}, err
	}
	if owner.AccountID != "" {
		return owner, nil
	}

	fetchedID, err := fetchAccountID(owner.Token)
	if err != nil {
		logger.Warning(fmt.Sprintf("获取 account_id 失败 (%s): %v", owner.Email, err), "", "upload")
		return owner, nil
	}
	owner.AccountID = fetchedID
	logger.Info(fmt.Sprintf("获取 account_id 成功: %s -> %s", owner.Email, fetchedID), "", "upload")
	return owner, nil
}
// fetchAccountID calls the ChatGPT accounts/check endpoint with token and
// returns the account_id of a team account. An account is team when its
// plan_type contains "team", case-insensitively.
//
// The request uses the configured proxy, if any, and a client created per
// call via client.New. When the token only has non-team accounts, the error
// text contains "非 team 账户"; validateAndFetchAccountIDs matches on that
// phrase.
func fetchAccountID(token string) (string, error) {
	proxy := ""
	if cfg := config.Get(); cfg != nil {
		proxy = cfg.GetProxy()
	}
	tlsClient, err := client.New(proxy)
	if err != nil {
		return "", fmt.Errorf("创建 TLS 客户端失败: %w", err)
	}

	req, err := http.NewRequest(http.MethodGet, "https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27", nil)
	if err != nil {
		return "", fmt.Errorf("创建请求失败: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+token)
	req.Header.Set("Content-Type", "application/json")

	resp, err := tlsClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("请求失败: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("读取响应失败: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		// Truncate to 200 bytes, then drop any multi-byte character cut in half.
		snippet := strings.ToValidUTF8(string(body[:min(200, len(body))]), "")
		return "", fmt.Errorf("API 返回状态码: %d, 响应: %s", resp.StatusCode, snippet)
	}
	return parseTeamAccountID(body)
}

// parseTeamAccountID picks the team account_id out of an accounts/check
// response body.
//
// Entries in "accounts" are keyed by account_id, and the "default" key is
// skipped. Every entry is examined because Go map iteration order is
// random. A team account wins over non-team ones. Among several candidates,
// the lexicographically smallest key is chosen so the result is stable.
func parseTeamAccountID(body []byte) (string, error) {
	var result struct {
		Accounts map[string]struct {
			Account struct {
				ID       string `json:"id"`
				PlanType string `json:"plan_type"`
			} `json:"account"`
		} `json:"accounts"`
	}
	if err := json.Unmarshal(body, &result); err != nil {
		return "", fmt.Errorf("解析响应失败: %w", err)
	}

	var (
		teamID, otherPlan, otherID string
		haveTeam, haveOther        bool
	)
	for accountID, info := range result.Accounts {
		if accountID == "default" {
			continue
		}
		if strings.Contains(strings.ToLower(info.Account.PlanType), "team") {
			if !haveTeam || accountID < teamID {
				teamID, haveTeam = accountID, true
			}
			continue
		}
		if !haveOther || accountID < otherID {
			otherID, otherPlan, haveOther = accountID, info.Account.PlanType, true
		}
	}

	switch {
	case haveTeam:
		return teamID, nil
	case haveOther:
		return "", fmt.Errorf("账户 plan 为 %s, 非 team 账户", otherPlan)
	default:
		return "", fmt.Errorf("未找到有效的 account_id")
	}
}
// parseAccountsFlexible parses account records from loosely formatted text.
//
// A leading BOM and surrounding whitespace are removed, and then the
// following forms are tried in order:
//   - a JSON array of records, or a single JSON object, after // and /* */
//     comments and trailing commas before ']' or '}' are removed;
//   - JSONL: one record per line. Lines starting with '#' are skipped, and a
//     separator comma after a record ("{...},") is ignored, so objects
//     pasted without enclosing brackets also parse.
//
// Errors name the 1-based source line for JSONL input.
func parseAccountsFlexible(raw string) ([]accountRecord, error) {
	raw = strings.TrimSpace(strings.TrimPrefix(raw, "\uFEFF"))
	if raw == "" {
		return nil, fmt.Errorf("内容为空")
	}
	trimmed := strings.TrimSpace(removeTrailingCommas(stripJSONComments(raw)))
	if trimmed == "" {
		return nil, fmt.Errorf("内容为空")
	}

	switch {
	case strings.HasPrefix(trimmed, "["):
		var arr []accountRecord
		if err := json.Unmarshal([]byte(trimmed), &arr); err != nil {
			// Its first line starts with '[' and can never decode into a single
			// record, so the JSONL fallback cannot succeed and would only report
			// a misleading "line 1" error. Surface the real array error instead.
			return nil, fmt.Errorf("JSON 数组解析失败: %v", err)
		}
		return arr, nil
	case strings.HasPrefix(trimmed, "{"):
		var single accountRecord
		if err := json.Unmarshal([]byte(trimmed), &single); err == nil {
			return []accountRecord{single}, nil
		}
		// Multiple objects: fall through to JSONL.
	}

	// JSONL fallback. Split the uncleaned text so error line numbers match the input.
	lines := strings.Split(raw, "\n")
	records := make([]accountRecord, 0, len(lines))
	for i, line := range lines {
		line = strings.TrimSpace(stripJSONComments(line))
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		line = strings.TrimSpace(removeTrailingCommas(line))
		// removeTrailingCommas only drops commas that precede ']' or '}', so a
		// separator comma at the end of the line has to be removed here.
		line = strings.TrimSpace(strings.TrimSuffix(line, ","))
		if line == "" {
			continue
		}
		var rec accountRecord
		if err := json.Unmarshal([]byte(line), &rec); err != nil {
			return nil, fmt.Errorf("第 %d 行解析失败: %v", i+1, err)
		}
		records = append(records, rec)
	}
	if len(records) == 0 {
		return nil, fmt.Errorf("未解析到有效账号")
	}
	return records, nil
}
// stripJSONComments removes `// line` and `/* block */` comments that appear
// outside double-quoted strings. Everything else is copied through
// unchanged, including the newline that ends a line comment. String
// contents, including backslash escapes, are never treated as comments. An
// unterminated block comment swallows the rest of the input.
func stripJSONComments(input string) string {
	var out strings.Builder
	out.Grow(len(input))

	n := len(input)
	inQuote, afterBackslash := false, false
	pos := 0
	for pos < n {
		c := input[pos]
		switch {
		case inQuote:
			out.WriteByte(c)
			switch {
			case afterBackslash:
				afterBackslash = false
			case c == '\\':
				afterBackslash = true
			case c == '"':
				inQuote = false
			}
			pos++
		case c == '"':
			inQuote = true
			out.WriteByte(c)
			pos++
		case c == '/' && pos+1 < n && input[pos+1] == '/':
			// Jump to the newline (not past it) so line structure is preserved.
			if nl := strings.IndexByte(input[pos:], '\n'); nl >= 0 {
				pos += nl
			} else {
				pos = n
			}
		case c == '/' && pos+1 < n && input[pos+1] == '*':
			// Search for the closer only after the opening "/*".
			if end := strings.Index(input[pos+2:], "*/"); end >= 0 {
				pos += 2 + end + 2
			} else {
				pos = n
			}
		default:
			out.WriteByte(c)
			pos++
		}
	}
	return out.String()
}
// removeTrailingCommas drops every comma outside double-quoted strings whose
// next non-whitespace byte is ']' or '}'. Whitespace is tested per byte with
// unicode.IsSpace. All other bytes, including the skipped whitespace, are
// kept. A comma at the very end of the input is kept.
func removeTrailingCommas(input string) string {
	var out strings.Builder
	out.Grow(len(input))

	// nextSignificant returns the first non-whitespace byte at or after from,
	// or 0 when only whitespace remains.
	nextSignificant := func(from int) byte {
		for k := from; k < len(input); k++ {
			if !unicode.IsSpace(rune(input[k])) {
				return input[k]
			}
		}
		return 0
	}

	quoted, escapeNext := false, false
	for idx := 0; idx < len(input); idx++ {
		c := input[idx]
		if quoted {
			out.WriteByte(c)
			switch {
			case escapeNext:
				escapeNext = false
			case c == '\\':
				escapeNext = true
			case c == '"':
				quoted = false
			}
			continue
		}
		switch c {
		case '"':
			quoted = true
		case ',':
			if next := nextSignificant(idx + 1); next == ']' || next == '}' {
				continue
			}
		}
		out.WriteByte(c)
	}
	return out.String()
}