553 lines
13 KiB
Go
553 lines
13 KiB
Go
package api
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"strings"
|
||
"sync"
|
||
"unicode"
|
||
|
||
"codex-pool/internal/client"
|
||
"codex-pool/internal/config"
|
||
"codex-pool/internal/database"
|
||
"codex-pool/internal/logger"
|
||
)
|
||
|
||
type uploadValidateRequest struct {
|
||
Content string `json:"content"`
|
||
Accounts []accountRecord `json:"accounts"`
|
||
}
|
||
|
||
// accountRecord is one uploaded account entry as it appears on the wire.
// Several fields are alternative spellings of the same datum; the
// normalizeOwner* helpers resolve them (Email before Account, Token before
// AccessTok).
type accountRecord struct {
	// Login identity: Email is preferred, Account is the fallback.
	Account string `json:"account"`
	Email   string `json:"email"`

	// Password is required by the normalizers.
	Password string `json:"password"`

	// Bearer credential: Token is preferred, AccessTok is the fallback.
	Token     string `json:"token"`
	AccessTok string `json:"access_token"`

	// AccountID is optional on input; when blank it is looked up remotely.
	AccountID string `json:"account_id"`
}
|
||
|
||
// HandleUploadValidate 处理上传验证请求
|
||
func HandleUploadValidate(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
Error(w, http.StatusMethodNotAllowed, "仅支持 POST")
|
||
return
|
||
}
|
||
if database.Instance == nil {
|
||
Error(w, http.StatusInternalServerError, "数据库未初始化")
|
||
return
|
||
}
|
||
|
||
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 10<<20))
|
||
if err != nil {
|
||
Error(w, http.StatusBadRequest, "读取请求失败")
|
||
return
|
||
}
|
||
|
||
var req uploadValidateRequest
|
||
if err := json.Unmarshal(body, &req); err != nil {
|
||
// 如果不是 JSON,直接把 body 当作原始内容
|
||
req.Content = string(body)
|
||
}
|
||
|
||
var records []accountRecord
|
||
switch {
|
||
case len(req.Accounts) > 0:
|
||
records = req.Accounts
|
||
case strings.TrimSpace(req.Content) != "":
|
||
parsed, parseErr := parseAccountsFlexible(req.Content)
|
||
if parseErr != nil {
|
||
Error(w, http.StatusBadRequest, parseErr.Error())
|
||
return
|
||
}
|
||
records = parsed
|
||
default:
|
||
Error(w, http.StatusBadRequest, "未提供账号内容")
|
||
return
|
||
}
|
||
|
||
owners := make([]database.TeamOwner, 0, len(records))
|
||
for i, rec := range records {
|
||
owner, err := normalizeOwnerBasic(rec, i+1)
|
||
if err != nil {
|
||
Error(w, http.StatusBadRequest, err.Error())
|
||
return
|
||
}
|
||
owners = append(owners, owner)
|
||
}
|
||
if len(owners) == 0 {
|
||
Error(w, http.StatusBadRequest, "未解析到有效账号")
|
||
return
|
||
}
|
||
|
||
// 统计需要获取 account_id 的数量
|
||
needFetch := 0
|
||
for _, o := range owners {
|
||
if o.AccountID == "" {
|
||
needFetch++
|
||
}
|
||
}
|
||
|
||
// 并发获取 account_id (使用20并发)
|
||
if needFetch > 0 {
|
||
logger.Info(fmt.Sprintf("开始获取 account_id: 需要获取 %d 个", needFetch), "", "upload")
|
||
fetchAccountIDsConcurrent(owners, 20)
|
||
logger.Info("account_id 获取完成", "", "upload")
|
||
}
|
||
|
||
inserted, err := database.Instance.AddTeamOwners(owners)
|
||
if err != nil {
|
||
Error(w, http.StatusInternalServerError, fmt.Sprintf("写入数据库失败: %v", err))
|
||
return
|
||
}
|
||
|
||
// 统计获取结果
|
||
successCount := 0
|
||
failCount := 0
|
||
for _, o := range owners {
|
||
if o.AccountID != "" {
|
||
successCount++
|
||
} else {
|
||
failCount++
|
||
}
|
||
}
|
||
|
||
// 输出导入日志
|
||
logger.Info(fmt.Sprintf("母号导入成功: 成功=%d, 总数=%d, account_id获取成功=%d, 失败=%d",
|
||
inserted, len(owners), successCount, failCount), "", "upload")
|
||
|
||
stats := database.Instance.GetOwnerStats()
|
||
Success(w, map[string]interface{}{
|
||
"imported": inserted,
|
||
"total": len(owners),
|
||
"stats": stats,
|
||
"account_id_ok": successCount,
|
||
"account_id_fail": failCount,
|
||
})
|
||
}
|
||
|
||
// HandleRefetchAccountIDs 重新获取缺少 account_id 的母号
|
||
func HandleRefetchAccountIDs(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
Error(w, http.StatusMethodNotAllowed, "仅支持 POST")
|
||
return
|
||
}
|
||
if database.Instance == nil {
|
||
Error(w, http.StatusInternalServerError, "数据库未初始化")
|
||
return
|
||
}
|
||
|
||
// 获取所有缺少 account_id 的 owners
|
||
owners, err := database.Instance.GetOwnersWithoutAccountID()
|
||
if err != nil {
|
||
Error(w, http.StatusInternalServerError, fmt.Sprintf("查询数据库失败: %v", err))
|
||
return
|
||
}
|
||
|
||
if len(owners) == 0 {
|
||
Success(w, map[string]interface{}{
|
||
"message": "所有母号都已有 account_id",
|
||
"total": 0,
|
||
"success": 0,
|
||
"fail": 0,
|
||
})
|
||
return
|
||
}
|
||
|
||
logger.Info(fmt.Sprintf("开始重新获取 account_id: 共 %d 个", len(owners)), "", "upload")
|
||
|
||
// 并发获取 account_id
|
||
var wg sync.WaitGroup
|
||
sem := make(chan struct{}, 20) // 20 并发
|
||
var mu sync.Mutex
|
||
successCount := 0
|
||
failCount := 0
|
||
|
||
for _, owner := range owners {
|
||
wg.Add(1)
|
||
sem <- struct{}{}
|
||
|
||
go func(o database.TeamOwner) {
|
||
defer wg.Done()
|
||
defer func() { <-sem }()
|
||
|
||
accountID, err := fetchAccountID(o.Token)
|
||
|
||
mu.Lock()
|
||
defer mu.Unlock()
|
||
|
||
if err != nil {
|
||
failCount++
|
||
logger.Warning(fmt.Sprintf("获取 account_id 失败 (%s): %v", o.Email, err), "", "upload")
|
||
} else {
|
||
// 更新数据库
|
||
if updateErr := database.Instance.UpdateOwnerAccountID(o.ID, accountID); updateErr != nil {
|
||
failCount++
|
||
logger.Error(fmt.Sprintf("更新 account_id 失败 (%s): %v", o.Email, updateErr), "", "upload")
|
||
} else {
|
||
successCount++
|
||
logger.Info(fmt.Sprintf("获取 account_id 成功: %s -> %s", o.Email, accountID), "", "upload")
|
||
}
|
||
}
|
||
}(owner)
|
||
}
|
||
|
||
wg.Wait()
|
||
|
||
logger.Info(fmt.Sprintf("重新获取 account_id 完成: 成功=%d, 失败=%d", successCount, failCount), "", "upload")
|
||
|
||
Success(w, map[string]interface{}{
|
||
"message": "重新获取 account_id 完成",
|
||
"total": len(owners),
|
||
"success": successCount,
|
||
"fail": failCount,
|
||
})
|
||
}
|
||
|
||
// fetchAccountIDsConcurrent 并发获取 account_id
|
||
func fetchAccountIDsConcurrent(owners []database.TeamOwner, concurrency int) {
|
||
var wg sync.WaitGroup
|
||
sem := make(chan struct{}, concurrency)
|
||
var mu sync.Mutex
|
||
completed := 0
|
||
total := 0
|
||
|
||
// 统计需要获取的数量
|
||
for _, o := range owners {
|
||
if o.AccountID == "" {
|
||
total++
|
||
}
|
||
}
|
||
|
||
for i := range owners {
|
||
if owners[i].AccountID != "" {
|
||
continue
|
||
}
|
||
|
||
wg.Add(1)
|
||
sem <- struct{}{} // 获取信号量
|
||
|
||
go func(idx int) {
|
||
defer wg.Done()
|
||
defer func() { <-sem }() // 释放信号量
|
||
|
||
email := owners[idx].Email
|
||
token := owners[idx].Token
|
||
|
||
accountID, err := fetchAccountID(token)
|
||
|
||
mu.Lock()
|
||
completed++
|
||
progress := completed
|
||
mu.Unlock()
|
||
|
||
if err != nil {
|
||
logger.Warning(fmt.Sprintf("[%d/%d] 获取 account_id 失败 (%s): %v", progress, total, email, err), "", "upload")
|
||
} else {
|
||
mu.Lock()
|
||
owners[idx].AccountID = accountID
|
||
mu.Unlock()
|
||
logger.Info(fmt.Sprintf("[%d/%d] 获取 account_id 成功: %s -> %s", progress, total, email, accountID), "", "upload")
|
||
}
|
||
}(i)
|
||
}
|
||
|
||
wg.Wait()
|
||
}
|
||
|
||
// normalizeOwnerBasic 基础验证,不获取 account_id
|
||
func normalizeOwnerBasic(rec accountRecord, index int) (database.TeamOwner, error) {
|
||
email := strings.TrimSpace(rec.Email)
|
||
if email == "" {
|
||
email = strings.TrimSpace(rec.Account)
|
||
}
|
||
if email == "" {
|
||
return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 account/email 字段", index)
|
||
}
|
||
|
||
password := strings.TrimSpace(rec.Password)
|
||
if password == "" {
|
||
return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 password 字段", index)
|
||
}
|
||
|
||
token := strings.TrimSpace(rec.Token)
|
||
if token == "" {
|
||
token = strings.TrimSpace(rec.AccessTok)
|
||
}
|
||
if token == "" {
|
||
return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 token 字段", index)
|
||
}
|
||
|
||
accountID := strings.TrimSpace(rec.AccountID)
|
||
|
||
return database.TeamOwner{
|
||
Email: email,
|
||
Password: password,
|
||
Token: token,
|
||
AccountID: accountID,
|
||
}, nil
|
||
}
|
||
|
||
func normalizeOwner(rec accountRecord, index int) (database.TeamOwner, error) {
|
||
email := strings.TrimSpace(rec.Email)
|
||
if email == "" {
|
||
email = strings.TrimSpace(rec.Account)
|
||
}
|
||
if email == "" {
|
||
return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 account/email 字段", index)
|
||
}
|
||
|
||
password := strings.TrimSpace(rec.Password)
|
||
if password == "" {
|
||
return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 password 字段", index)
|
||
}
|
||
|
||
token := strings.TrimSpace(rec.Token)
|
||
if token == "" {
|
||
token = strings.TrimSpace(rec.AccessTok)
|
||
}
|
||
if token == "" {
|
||
return database.TeamOwner{}, fmt.Errorf("第 %d 条记录缺少 token 字段", index)
|
||
}
|
||
|
||
accountID := strings.TrimSpace(rec.AccountID)
|
||
// 如果 account_id 为空,自动通过 API 获取
|
||
if accountID == "" {
|
||
fetchedID, err := fetchAccountID(token)
|
||
if err != nil {
|
||
logger.Warning(fmt.Sprintf("获取 account_id 失败 (%s): %v", email, err), "", "upload")
|
||
} else {
|
||
accountID = fetchedID
|
||
logger.Info(fmt.Sprintf("获取 account_id 成功: %s -> %s", email, accountID), "", "upload")
|
||
}
|
||
}
|
||
|
||
return database.TeamOwner{
|
||
Email: email,
|
||
Password: password,
|
||
Token: token,
|
||
AccountID: accountID,
|
||
}, nil
|
||
}
|
||
|
||
// fetchAccountID 通过 token 获取 account_id
|
||
func fetchAccountID(token string) (string, error) {
|
||
// 使用配置中的代理(如果启用)
|
||
proxy := ""
|
||
if cfg := config.Get(); cfg != nil {
|
||
proxy = cfg.GetProxy()
|
||
}
|
||
|
||
tlsClient, err := client.New(proxy)
|
||
if err != nil {
|
||
return "", fmt.Errorf("创建 TLS 客户端失败: %v", err)
|
||
}
|
||
|
||
req, err := http.NewRequest("GET", "https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27", nil)
|
||
if err != nil {
|
||
return "", fmt.Errorf("创建请求失败: %v", err)
|
||
}
|
||
|
||
req.Header.Set("Authorization", "Bearer "+token)
|
||
req.Header.Set("Content-Type", "application/json")
|
||
|
||
resp, err := tlsClient.Do(req)
|
||
if err != nil {
|
||
return "", fmt.Errorf("请求失败: %v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return "", fmt.Errorf("读取响应失败: %v", err)
|
||
}
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
return "", fmt.Errorf("API 返回状态码: %d, 响应: %s", resp.StatusCode, string(body)[:min(200, len(body))])
|
||
}
|
||
|
||
// 解析响应
|
||
var result struct {
|
||
Accounts map[string]struct {
|
||
Account struct {
|
||
ID string `json:"id"`
|
||
PlanType string `json:"plan_type"`
|
||
} `json:"account"`
|
||
} `json:"accounts"`
|
||
}
|
||
|
||
if err := json.Unmarshal(body, &result); err != nil {
|
||
return "", fmt.Errorf("解析响应失败: %v", err)
|
||
}
|
||
|
||
// 优先查找 plan_type 为 "team" 的账户
|
||
// 注意:account_id 是 map 的 key,而不是 Account.ID 字段
|
||
for accountID, info := range result.Accounts {
|
||
if accountID == "default" {
|
||
continue
|
||
}
|
||
if strings.Contains(strings.ToLower(info.Account.PlanType), "team") {
|
||
return accountID, nil
|
||
}
|
||
}
|
||
|
||
// 否则取第一个非 "default" 的账户
|
||
for accountID := range result.Accounts {
|
||
if accountID != "default" {
|
||
return accountID, nil
|
||
}
|
||
}
|
||
|
||
return "", fmt.Errorf("未找到有效的 account_id")
|
||
}
|
||
|
||
func parseAccountsFlexible(raw string) ([]accountRecord, error) {
|
||
raw = strings.TrimSpace(strings.TrimPrefix(raw, "\uFEFF"))
|
||
if raw == "" {
|
||
return nil, fmt.Errorf("内容为空")
|
||
}
|
||
|
||
cleaned := stripJSONComments(raw)
|
||
cleaned = removeTrailingCommas(cleaned)
|
||
trimmed := strings.TrimSpace(cleaned)
|
||
if trimmed == "" {
|
||
return nil, fmt.Errorf("内容为空")
|
||
}
|
||
|
||
if strings.HasPrefix(trimmed, "[") {
|
||
var arr []accountRecord
|
||
if err := json.Unmarshal([]byte(trimmed), &arr); err == nil {
|
||
return arr, nil
|
||
}
|
||
} else if strings.HasPrefix(trimmed, "{") {
|
||
var single accountRecord
|
||
if err := json.Unmarshal([]byte(trimmed), &single); err == nil {
|
||
return []accountRecord{single}, nil
|
||
}
|
||
}
|
||
|
||
// JSONL 回退
|
||
lines := strings.Split(raw, "\n")
|
||
records := make([]accountRecord, 0, len(lines))
|
||
for i, line := range lines {
|
||
line = strings.TrimSpace(stripJSONComments(line))
|
||
if line == "" {
|
||
continue
|
||
}
|
||
if strings.HasPrefix(line, "#") {
|
||
continue
|
||
}
|
||
line = strings.TrimSpace(removeTrailingCommas(line))
|
||
if line == "" {
|
||
continue
|
||
}
|
||
var rec accountRecord
|
||
if err := json.Unmarshal([]byte(line), &rec); err != nil {
|
||
return nil, fmt.Errorf("第 %d 行解析失败: %v", i+1, err)
|
||
}
|
||
records = append(records, rec)
|
||
}
|
||
if len(records) == 0 {
|
||
return nil, fmt.Errorf("未解析到有效账号")
|
||
}
|
||
return records, nil
|
||
}
|
||
|
||
// stripJSONComments returns input with JavaScript-style comments removed so
// that hand-edited, lenient JSON can be fed to encoding/json.
//
// Both "//" line comments and "/* ... */" block comments are dropped. Text
// inside double-quoted string literals is copied verbatim, honouring
// backslash escapes, so a value such as "http://host" is left intact. The
// newline that ends a line comment is kept. An unterminated block comment
// swallows the rest of the input, and an unterminated string literal is
// copied through to the end.
func stripJSONComments(input string) string {
	n := len(input)

	var out strings.Builder
	out.Grow(n)

	for pos := 0; pos < n; {
		c := input[pos]

		switch {
		case c == '"':
			// Copy the whole string literal in one go.
			end := pos + 1
		literal:
			for end < n {
				switch input[end] {
				case '\\':
					end += 2 // skip the escaped byte as well
				case '"':
					end++
					break literal
				default:
					end++
				}
			}
			if end > n {
				// A trailing backslash pushed the cursor past the input.
				end = n
			}
			out.WriteString(input[pos:end])
			pos = end

		case c == '/' && pos+1 < n && input[pos+1] == '/':
			// Line comment: resume at the newline (which is kept) or stop.
			if nl := strings.IndexByte(input[pos:], '\n'); nl >= 0 {
				pos += nl
			} else {
				pos = n
			}

		case c == '/' && pos+1 < n && input[pos+1] == '*':
			// Block comment: resume just past the closing "*/" or stop.
			if closing := strings.Index(input[pos+2:], "*/"); closing >= 0 {
				pos += 2 + closing + 2
			} else {
				pos = n
			}

		default:
			out.WriteByte(c)
			pos++
		}
	}

	return out.String()
}
|
||
|
||
// removeTrailingCommas returns input with every "dangling" comma removed: a
// comma that is followed, after optional whitespace, by ']' or '}'. This lets
// lenient JSON such as `[1, 2, ]` be parsed by encoding/json.
//
// Commas inside double-quoted string literals are never touched; backslash
// escapes are honoured when locating the end of a literal. The whitespace
// between a removed comma and the closing bracket is kept. Whitespace is
// tested one byte at a time with unicode.IsSpace.
func removeTrailingCommas(input string) string {
	n := len(input)

	var out strings.Builder
	out.Grow(n)

	// literalEnd returns the index just past the string literal that opens
	// at start, or n when the literal is not terminated.
	literalEnd := func(start int) int {
		for k := start + 1; k < n; {
			switch input[k] {
			case '\\':
				k += 2 // skip the escaped byte as well
			case '"':
				return k + 1
			default:
				k++
			}
		}
		return n
	}

	// dangling reports whether the comma at index at is followed only by
	// whitespace and then a closing bracket or brace.
	dangling := func(at int) bool {
		next := at + 1
		for next < n && unicode.IsSpace(rune(input[next])) {
			next++
		}
		return next < n && (input[next] == ']' || input[next] == '}')
	}

	for pos := 0; pos < n; {
		switch c := input[pos]; {
		case c == '"':
			end := literalEnd(pos)
			out.WriteString(input[pos:end])
			pos = end
		case c == ',' && dangling(pos):
			pos++ // drop the comma, keep what follows
		default:
			out.WriteByte(c)
			pos++
		}
	}

	return out.String()
}
|