Files
codexautopool/backend/internal/api/batch_rt_import.go

424 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package api
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"sync"
"sync/atomic"
"time"
"codex-pool/internal/config"
"codex-pool/internal/logger"
)
// ============================================================================
// 批量 RT 导入模块
// 功能: 读取 Refresh Token 列表,通过 S2A API 验证并创建账号
// ============================================================================
// rtImportRequest is the JSON body accepted by POST /api/rt-import/start.
// Concurrency, Priority, RateMultiplier and GroupIDs are optional: when left
// at zero/empty they are defaulted by HandleRTImportStart.
type rtImportRequest struct {
	// Tokens are the refresh tokens to import; they are processed in order.
	Tokens []string `json:"tokens"`
	// Prefix is prepended to each created account name ("<prefix>-<email>");
	// it must be "team" or "free".
	Prefix string `json:"prefix"`
	// Concurrency is the S2A account concurrency; <= 0 falls back to the
	// global config value, then to 3.
	Concurrency int `json:"concurrency"`
	// Priority is the S2A account priority; <= 0 falls back to the global
	// config value, then to 10.
	Priority int `json:"priority"`
	// GroupIDs are the S2A group IDs; empty falls back to the global config.
	GroupIDs []int `json:"group_ids"`
	// ProxyID is the optional S2A proxy ID, forwarded both when validating the
	// token and when creating the account; nil means no proxy_id is sent.
	ProxyID *int `json:"proxy_id"`
	// RateMultiplier is the billing rate multiplier; <= 0 becomes 1.0.
	RateMultiplier float64 `json:"rate_multiplier"`
}
// rtImportResult is the outcome of importing a single refresh token.
type rtImportResult struct {
	// Index is the 1-based position of the token in the request.
	Index int `json:"index"`
	// RT is the token as shortened by maskRT, safe to show in the UI.
	RT string `json:"rt"`
	// Email is the email reported by S2A validation ("unknown" when absent);
	// empty if validation itself failed.
	Email   string `json:"email"`
	Success bool   `json:"success"`
	// Error holds the failure reason when Success is false.
	Error string `json:"error,omitempty"`
	// AcctID is the S2A account ID on success (0 if S2A did not return one).
	AcctID int `json:"account_id,omitempty"`
}

// rtImportState tracks the single RT import task that may run at a time.
//
// Synchronization contract:
//   - Running, StartedAt, Total, Results and stopCh must only be accessed
//     with mu held.
//   - Completed, Success and Failed are counters that must only be accessed
//     through sync/atomic operations.
type rtImportState struct {
	Running   bool             `json:"running"`
	StartedAt time.Time        `json:"started_at"`
	Total     int              `json:"total"`
	Completed int32            `json:"completed"`
	Success   int32            `json:"success"`
	Failed    int32            `json:"failed"`
	Results   []rtImportResult `json:"results"`

	mu     sync.Mutex
	stopCh chan struct{} // closed to ask the running worker to stop
}

// rtImportTaskState is the process-wide state of the RT import task.
var rtImportTaskState = &rtImportState{}

// isRTImportStopped reports whether the current import task has been asked to
// stop, i.e. whether its stop channel has been closed. It returns false when
// no task has been started yet (nil channel).
func isRTImportStopped() bool {
	// stopCh is replaced by HandleRTImportStart, so it must be read under mu;
	// the non-blocking receive itself happens after unlocking.
	rtImportTaskState.mu.Lock()
	stopCh := rtImportTaskState.stopCh
	rtImportTaskState.mu.Unlock()

	select {
	case <-stopCh:
		return true
	default:
		return false
	}
}
// maskRT shortens a refresh token for display and logging. Tokens longer than
// 16 bytes are cut to their first 16 bytes followed by "..."; tokens of 16
// bytes or fewer are returned unchanged.
func maskRT(rt string) string {
	const visible = 16
	if len(rt) > visible {
		return rt[:visible] + "..."
	}
	return rt
}
// HandleRTImportStart handles POST /api/rt-import/start and launches a batch
// refresh-token import in a background goroutine.
//
// The body is an rtImportRequest. Concurrency and Priority <= 0 fall back to
// the global config (then 3 / 10), RateMultiplier <= 0 becomes 1.0 and an
// empty GroupIDs takes the global config's group IDs. Responds 409 when an
// import is already running, and 400 when S2A is not configured or the
// request is invalid.
func HandleRTImportStart(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		Error(w, http.StatusMethodNotAllowed, "仅支持 POST")
		return
	}

	st := rtImportTaskState

	// Fast-fail before parsing the body. The authoritative check happens
	// again, atomically with claiming the slot, further down.
	st.mu.Lock()
	running := st.Running
	st.mu.Unlock()
	if running {
		Error(w, http.StatusConflict, "已有导入任务正在运行")
		return
	}

	cfg := config.Global
	if cfg == nil || cfg.S2AApiBase == "" || cfg.S2AAdminKey == "" {
		Error(w, http.StatusBadRequest, "S2A 配置未设置,请先配置 S2A API 地址和 Admin Key")
		return
	}

	var req rtImportRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		Error(w, http.StatusBadRequest, fmt.Sprintf("请求格式错误: %v", err))
		return
	}
	if len(req.Tokens) == 0 {
		Error(w, http.StatusBadRequest, "没有提供 Refresh Token")
		return
	}
	if req.Prefix != "team" && req.Prefix != "free" {
		Error(w, http.StatusBadRequest, "前缀必须是 team 或 free")
		return
	}

	// Defaults: request value, then global config, then built-in fallback.
	// cfg is known to be non-nil here.
	if req.Concurrency <= 0 {
		if cfg.Concurrency > 0 {
			req.Concurrency = cfg.Concurrency
		} else {
			req.Concurrency = 3
		}
	}
	if req.Priority <= 0 {
		if cfg.Priority > 0 {
			req.Priority = cfg.Priority
		} else {
			req.Priority = 10
		}
	}
	if req.RateMultiplier <= 0 {
		req.RateMultiplier = 1.0
	}
	if len(req.GroupIDs) == 0 && len(cfg.GroupIDs) > 0 {
		req.GroupIDs = cfg.GroupIDs
	}

	// Claim the single task slot. Checking and setting Running under the same
	// lock prevents two concurrent requests from both starting a worker.
	st.mu.Lock()
	if st.Running {
		st.mu.Unlock()
		Error(w, http.StatusConflict, "已有导入任务正在运行")
		return
	}
	st.Running = true
	st.stopCh = make(chan struct{})
	st.StartedAt = time.Now()
	st.Total = len(req.Tokens)
	st.Results = make([]rtImportResult, 0, len(req.Tokens))
	// The worker updates these counters atomically; reset them the same way.
	atomic.StoreInt32(&st.Completed, 0)
	atomic.StoreInt32(&st.Success, 0)
	atomic.StoreInt32(&st.Failed, 0)
	st.mu.Unlock()

	go runRTImport(req)

	logger.Info(fmt.Sprintf("RT 导入任务已启动: 共 %d 个 Token, 前缀: %s", len(req.Tokens), req.Prefix), "", "rt-import")
	Success(w, map[string]interface{}{
		"message": "导入任务已启动",
		"total":   len(req.Tokens),
		"prefix":  req.Prefix,
	})
}
// HandleRTImportStatus handles GET /api/rt-import/status and returns a
// snapshot of the current (or most recent) import task: running flag, start
// time, total, counters, per-token results and milliseconds elapsed since the
// task started (0 if no task has ever started).
func HandleRTImportStatus(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		Error(w, http.StatusMethodNotAllowed, "仅支持 GET")
		return
	}

	st := rtImportTaskState

	// Copy the mutex-guarded fields, then release the lock before writing the
	// response so a slow client cannot stall the import worker.
	st.mu.Lock()
	running := st.Running
	startedAt := st.StartedAt
	total := st.Total
	// s[:0:0] preserves nil-ness, so the JSON shape ("null" vs "[]") is unchanged.
	results := append(st.Results[:0:0], st.Results...)
	st.mu.Unlock()

	elapsed := int64(0)
	if !startedAt.IsZero() {
		elapsed = time.Since(startedAt).Milliseconds()
	}

	// The counters are written with sync/atomic by the worker, so they must be
	// read atomically as well.
	Success(w, map[string]interface{}{
		"running":    running,
		"started_at": startedAt,
		"total":      total,
		"completed":  atomic.LoadInt32(&st.Completed),
		"success":    atomic.LoadInt32(&st.Success),
		"failed":     atomic.LoadInt32(&st.Failed),
		"results":    results,
		"elapsed_ms": elapsed,
	})
}
// HandleRTImportStop handles POST /api/rt-import/stop and asks the running
// import to stop by closing its stop channel. The worker checks the signal
// between tokens, so the token currently in flight is allowed to finish.
// Responds 400 when no import is running. Calling it again while the worker
// is still winding down is harmless: the channel is closed only once.
func HandleRTImportStop(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		Error(w, http.StatusMethodNotAllowed, "仅支持 POST")
		return
	}

	st := rtImportTaskState
	st.mu.Lock()
	if !st.Running {
		st.mu.Unlock()
		Error(w, http.StatusBadRequest, "没有正在运行的导入任务")
		return
	}
	// Running is deliberately left true: the worker clears it when it actually
	// exits. Clearing it here would let a new import start (resetting the
	// shared state and replacing stopCh) while the old worker is still
	// finishing its current token.
	if st.stopCh != nil {
		select {
		case <-st.stopCh:
			// already closed
		default:
			close(st.stopCh)
		}
	}
	st.mu.Unlock()

	logger.Warning("RT 导入任务已收到停止信号", "", "rt-import")
	Success(w, map[string]string{"message": "已发送停止信号"})
}
// runRTImport processes req.Tokens sequentially, appending each result to
// rtImportTaskState.Results and updating its counters. It stops early once
// the task's stop channel is closed and pauses 500ms between tokens to
// throttle requests to S2A; the pause is cut short by a stop request. When it
// returns it clears Running and logs a summary.
func runRTImport(req rtImportRequest) {
	const throttle = 500 * time.Millisecond

	st := rtImportTaskState

	// Bind to this task's stop channel once, rather than re-reading the
	// global field each iteration.
	st.mu.Lock()
	stopCh := st.stopCh
	st.mu.Unlock()

	defer func() {
		st.mu.Lock()
		st.Running = false
		st.mu.Unlock()

		completed := atomic.LoadInt32(&st.Completed)
		success := atomic.LoadInt32(&st.Success)
		failed := atomic.LoadInt32(&st.Failed)
		logger.Success(fmt.Sprintf("RT 导入完成: 总数 %d, 成功 %d, 失败 %d",
			completed, success, failed), "", "rt-import")
	}()

	last := len(req.Tokens) - 1
	for i, rt := range req.Tokens {
		select {
		case <-stopCh:
			logger.Warning(fmt.Sprintf("RT 导入已停止,跳过剩余 %d 个 Token", len(req.Tokens)-i), "", "rt-import")
			return
		default:
		}

		result := processOneRT(i, rt, req)

		st.mu.Lock()
		st.Results = append(st.Results, result)
		st.mu.Unlock()
		atomic.AddInt32(&st.Completed, 1)
		if result.Success {
			atomic.AddInt32(&st.Success, 1)
		} else {
			atomic.AddInt32(&st.Failed, 1)
		}

		if i == last {
			break // nothing left to throttle for
		}
		// Throttle, but wake immediately on stop. The check at the top of the
		// next iteration then logs the skipped count and exits.
		select {
		case <-stopCh:
		case <-time.After(throttle):
		}
	}
}
// processOneRT imports one refresh token in two steps: it validates rt through
// S2A and, if that succeeds, creates an S2A account from the returned token
// info. Failures are not returned as errors; they are recorded in the result's
// Error field (with Success false) and logged.
//
// index is the 0-based position of rt in req.Tokens. Every log line is tagged
// "[position/total]".
func processOneRT(index int, rt string, req rtImportRequest) rtImportResult {
	pos, total := index+1, len(req.Tokens)
	result := rtImportResult{
		Index: pos,
		RT:    maskRT(rt),
	}
	// Previously the second %d received req.Prefix (a string) and printed
	// "%!d(string=...)"; it now receives the token count like the other lines.
	logger.Info(fmt.Sprintf("[%d/%d] 开始处理 RT: %s", pos, total, result.RT), "", "rt-import")

	// Step 1: validate the RT via S2A.
	tokenInfo, err := validateRTViaS2A(rt, req.ProxyID)
	if err != nil {
		result.Error = fmt.Sprintf("验证失败: %v", err)
		logger.Error(fmt.Sprintf("[%d/%d] %s", pos, total, result.Error), "", "rt-import")
		return result
	}

	email, _ := tokenInfo["email"].(string) // missing or non-string -> ""
	if email == "" {
		email = "unknown"
	}
	result.Email = email
	logger.Info(fmt.Sprintf("[%d/%d] 验证成功: %s", pos, total, email), email, "rt-import")

	// Step 2: create the account via S2A.
	acctID, err := createAccountViaS2A(tokenInfo, req)
	if err != nil {
		result.Error = fmt.Sprintf("创建账号失败: %v", err)
		logger.Error(fmt.Sprintf("[%d/%d] %s", pos, total, result.Error), email, "rt-import")
		return result
	}

	result.Success = true
	result.AcctID = acctID
	logger.Success(fmt.Sprintf("[%d/%d] 账号创建成功: %s-%s (ID: %d)", pos, total, req.Prefix, email, acctID), email, "rt-import")
	return result
}
// validateRTViaS2A validates a refresh token by POSTing it to the S2A admin
// endpoint {S2AApiBase}/api/v1/admin/openai/refresh-token.
//
// proxyID, when non-nil, is sent as "proxy_id". On HTTP 200 the JSON body is
// decoded. If it has an object under "data", that object is returned;
// otherwise the whole decoded body is returned. Transport, non-200 status and
// decode failures are returned as errors. Underlying causes are wrapped with
// %w, so callers can inspect them with errors.Is / errors.As.
//
// Assumes config.Global is non-nil (checked by HandleRTImportStart).
func validateRTViaS2A(rt string, proxyID *int) (map[string]interface{}, error) {
	payload := map[string]interface{}{
		"refresh_token": rt,
	}
	if proxyID != nil {
		payload["proxy_id"] = *proxyID
	}
	body, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("序列化请求失败: %w", err)
	}

	endpoint := config.Global.S2AApiBase + "/api/v1/admin/openai/refresh-token"
	httpReq, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	setS2AHeaders(httpReq)

	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("读取响应失败: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
	}

	var result map[string]interface{}
	if err := json.Unmarshal(respBody, &result); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}
	// S2A may wrap the payload in a "data" envelope.
	if data, ok := result["data"].(map[string]interface{}); ok {
		return data, nil
	}
	return result, nil
}
// createAccountViaS2A creates an OpenAI OAuth account in S2A via
// POST {S2AApiBase}/api/v1/admin/accounts, using the token info returned by
// validateRTViaS2A and the account settings from req.
//
// The account is named "<prefix>-<email>" ("unknown" when tokenInfo has no
// email). HTTP 200 and 201 are both treated as success. The new account's ID
// is taken from "data.id" or a top-level "id". If neither is present it
// returns (0, nil): the request succeeded, so this is not reported as a
// failure. Errors wrap their cause with %w.
//
// Assumes config.Global is non-nil (checked by HandleRTImportStart).
func createAccountViaS2A(tokenInfo map[string]interface{}, req rtImportRequest) (int, error) {
	email, _ := tokenInfo["email"].(string) // missing or non-string -> ""
	if email == "" {
		email = "unknown"
	}
	name := fmt.Sprintf("%s-%s", req.Prefix, email)

	// Credentials: the token fields plus optional ChatGPT/organization IDs,
	// copied only when present and non-null in the validation response.
	credentials := map[string]interface{}{}
	for _, key := range []string{
		"access_token", "refresh_token", "token_type", "expires_in", "expires_at", "scope",
		"chatgpt_account_id", "chatgpt_user_id", "organization_id",
	} {
		if v, ok := tokenInfo[key]; ok && v != nil {
			credentials[key] = v
		}
	}

	// Extra metadata: only a real email (email is never "" at this point).
	extra := map[string]interface{}{}
	if email != "unknown" {
		extra["email"] = email
	}
	if nameVal, ok := tokenInfo["name"]; ok && nameVal != nil {
		extra["name"] = nameVal
	}

	payload := map[string]interface{}{
		"name":            name,
		"platform":        "openai",
		"type":            "oauth",
		"credentials":     credentials,
		"concurrency":     req.Concurrency,
		"priority":        req.Priority,
		"group_ids":       req.GroupIDs,
		"rate_multiplier": req.RateMultiplier,
	}
	if len(extra) > 0 {
		payload["extra"] = extra
	}
	if req.ProxyID != nil {
		payload["proxy_id"] = *req.ProxyID
	}

	body, err := json.Marshal(payload)
	if err != nil {
		return 0, fmt.Errorf("序列化请求失败: %w", err)
	}

	endpoint := config.Global.S2AApiBase + "/api/v1/admin/accounts"
	httpReq, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return 0, fmt.Errorf("创建请求失败: %w", err)
	}
	setS2AHeaders(httpReq)

	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(httpReq)
	if err != nil {
		return 0, fmt.Errorf("请求失败: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return 0, fmt.Errorf("读取响应失败: %w", err)
	}
	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
		return 0, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
	}

	var result map[string]interface{}
	if err := json.Unmarshal(respBody, &result); err != nil {
		return 0, fmt.Errorf("解析响应失败: %w", err)
	}

	// JSON numbers decode as float64 into interface{}.
	if data, ok := result["data"].(map[string]interface{}); ok {
		if id, ok := data["id"].(float64); ok {
			return int(id), nil
		}
	}
	if id, ok := result["id"].(float64); ok {
		return int(id), nil
	}
	return 0, nil
}