feat: Initialize the core backend API server, frontend application structure, and implement batch RT import functionality.
This commit is contained in:
@@ -186,6 +186,11 @@ func startServer(cfg *config.Config) {
|
||||
// Owner 降级 API
|
||||
mux.HandleFunc("/api/demote/owner", api.CORS(api.HandleDemoteOwner))
|
||||
|
||||
// 批量 RT 导入 API
|
||||
mux.HandleFunc("/api/rt-import/start", api.CORS(api.HandleRTImportStart))
|
||||
mux.HandleFunc("/api/rt-import/status", api.CORS(api.HandleRTImportStatus))
|
||||
mux.HandleFunc("/api/rt-import/stop", api.CORS(api.HandleRTImportStop))
|
||||
|
||||
// 嵌入的前端静态文件
|
||||
if web.IsEmbedded() {
|
||||
webFS := web.GetFileSystem()
|
||||
|
||||
423
backend/internal/api/batch_rt_import.go
Normal file
423
backend/internal/api/batch_rt_import.go
Normal file
@@ -0,0 +1,423 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"codex-pool/internal/config"
|
||||
"codex-pool/internal/logger"
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// 批量 RT 导入模块
|
||||
// 功能: 读取 Refresh Token 列表,通过 S2A API 验证并创建账号
|
||||
// ============================================================================
|
||||
|
||||
// rtImportRequest 导入请求
|
||||
type rtImportRequest struct {
|
||||
Tokens []string `json:"tokens"`
|
||||
Prefix string `json:"prefix"` // "team" 或 "free"
|
||||
Concurrency int `json:"concurrency"` // S2A 账号并发数
|
||||
Priority int `json:"priority"` // S2A 账号优先级
|
||||
GroupIDs []int `json:"group_ids"` // S2A 分组ID
|
||||
ProxyID *int `json:"proxy_id"` // S2A 代理ID
|
||||
RateMultiplier float64 `json:"rate_multiplier"` // 计费倍率
|
||||
}
|
||||
|
||||
// rtImportResult 单条导入结果
|
||||
type rtImportResult struct {
|
||||
Index int `json:"index"`
|
||||
RT string `json:"rt"` // 脱敏后的 RT 前缀
|
||||
Email string `json:"email"` // 验证得到的邮箱
|
||||
Success bool `json:"success"`
|
||||
Error string `json:"error,omitempty"`
|
||||
AcctID int `json:"account_id,omitempty"` // S2A 创建的账号ID
|
||||
}
|
||||
|
||||
// rtImportState 导入任务状态
|
||||
type rtImportState struct {
|
||||
Running bool `json:"running"`
|
||||
StartedAt time.Time `json:"started_at"`
|
||||
Total int `json:"total"`
|
||||
Completed int32 `json:"completed"`
|
||||
Success int32 `json:"success"`
|
||||
Failed int32 `json:"failed"`
|
||||
Results []rtImportResult `json:"results"`
|
||||
mu sync.Mutex
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
var rtImportTaskState = &rtImportState{}
|
||||
|
||||
// isRTImportStopped 检查导入任务是否已被停止
|
||||
func isRTImportStopped() bool {
|
||||
select {
|
||||
case <-rtImportTaskState.stopCh:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// maskRT 脱敏处理 RT,只显示前16个字符
|
||||
func maskRT(rt string) string {
|
||||
if len(rt) <= 16 {
|
||||
return rt
|
||||
}
|
||||
return rt[:16] + "..."
|
||||
}
|
||||
|
||||
// HandleRTImportStart POST /api/rt-import/start - 启动批量 RT 导入
|
||||
func HandleRTImportStart(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
Error(w, http.StatusMethodNotAllowed, "仅支持 POST")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否正在运行
|
||||
if rtImportTaskState.Running {
|
||||
Error(w, http.StatusConflict, "已有导入任务正在运行")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查 S2A 配置
|
||||
if config.Global == nil || config.Global.S2AApiBase == "" || config.Global.S2AAdminKey == "" {
|
||||
Error(w, http.StatusBadRequest, "S2A 配置未设置,请先配置 S2A API 地址和 Admin Key")
|
||||
return
|
||||
}
|
||||
|
||||
var req rtImportRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
Error(w, http.StatusBadRequest, fmt.Sprintf("请求格式错误: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 参数校验
|
||||
if len(req.Tokens) == 0 {
|
||||
Error(w, http.StatusBadRequest, "没有提供 Refresh Token")
|
||||
return
|
||||
}
|
||||
if req.Prefix != "team" && req.Prefix != "free" {
|
||||
Error(w, http.StatusBadRequest, "前缀必须是 team 或 free")
|
||||
return
|
||||
}
|
||||
|
||||
// 默认值
|
||||
if req.Concurrency <= 0 {
|
||||
if config.Global != nil && config.Global.Concurrency > 0 {
|
||||
req.Concurrency = config.Global.Concurrency
|
||||
} else {
|
||||
req.Concurrency = 3
|
||||
}
|
||||
}
|
||||
if req.Priority <= 0 {
|
||||
if config.Global != nil && config.Global.Priority > 0 {
|
||||
req.Priority = config.Global.Priority
|
||||
} else {
|
||||
req.Priority = 10
|
||||
}
|
||||
}
|
||||
if req.RateMultiplier <= 0 {
|
||||
req.RateMultiplier = 1.0
|
||||
}
|
||||
if len(req.GroupIDs) == 0 && config.Global != nil && len(config.Global.GroupIDs) > 0 {
|
||||
req.GroupIDs = config.Global.GroupIDs
|
||||
}
|
||||
|
||||
// 初始化状态
|
||||
rtImportTaskState.Running = true
|
||||
rtImportTaskState.stopCh = make(chan struct{})
|
||||
rtImportTaskState.StartedAt = time.Now()
|
||||
rtImportTaskState.Total = len(req.Tokens)
|
||||
rtImportTaskState.Completed = 0
|
||||
rtImportTaskState.Success = 0
|
||||
rtImportTaskState.Failed = 0
|
||||
rtImportTaskState.Results = make([]rtImportResult, 0, len(req.Tokens))
|
||||
|
||||
// 异步执行
|
||||
go runRTImport(req)
|
||||
|
||||
logger.Info(fmt.Sprintf("RT 导入任务已启动: 共 %d 个 Token, 前缀: %s", len(req.Tokens), req.Prefix), "", "rt-import")
|
||||
|
||||
Success(w, map[string]interface{}{
|
||||
"message": "导入任务已启动",
|
||||
"total": len(req.Tokens),
|
||||
"prefix": req.Prefix,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleRTImportStatus GET /api/rt-import/status - 获取导入状态
|
||||
func HandleRTImportStatus(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
Error(w, http.StatusMethodNotAllowed, "仅支持 GET")
|
||||
return
|
||||
}
|
||||
|
||||
rtImportTaskState.mu.Lock()
|
||||
defer rtImportTaskState.mu.Unlock()
|
||||
|
||||
elapsed := int64(0)
|
||||
if !rtImportTaskState.StartedAt.IsZero() {
|
||||
elapsed = time.Since(rtImportTaskState.StartedAt).Milliseconds()
|
||||
}
|
||||
|
||||
Success(w, map[string]interface{}{
|
||||
"running": rtImportTaskState.Running,
|
||||
"started_at": rtImportTaskState.StartedAt,
|
||||
"total": rtImportTaskState.Total,
|
||||
"completed": rtImportTaskState.Completed,
|
||||
"success": rtImportTaskState.Success,
|
||||
"failed": rtImportTaskState.Failed,
|
||||
"results": rtImportTaskState.Results,
|
||||
"elapsed_ms": elapsed,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleRTImportStop POST /api/rt-import/stop - 停止导入
|
||||
func HandleRTImportStop(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
Error(w, http.StatusMethodNotAllowed, "仅支持 POST")
|
||||
return
|
||||
}
|
||||
|
||||
if !rtImportTaskState.Running {
|
||||
Error(w, http.StatusBadRequest, "没有正在运行的导入任务")
|
||||
return
|
||||
}
|
||||
|
||||
rtImportTaskState.Running = false
|
||||
if rtImportTaskState.stopCh != nil {
|
||||
select {
|
||||
case <-rtImportTaskState.stopCh:
|
||||
// 已关闭
|
||||
default:
|
||||
close(rtImportTaskState.stopCh)
|
||||
}
|
||||
}
|
||||
logger.Warning("RT 导入任务已收到停止信号", "", "rt-import")
|
||||
Success(w, map[string]string{"message": "已发送停止信号"})
|
||||
}
|
||||
|
||||
// runRTImport 执行批量导入
|
||||
func runRTImport(req rtImportRequest) {
|
||||
defer func() {
|
||||
rtImportTaskState.Running = false
|
||||
completed := atomic.LoadInt32(&rtImportTaskState.Completed)
|
||||
success := atomic.LoadInt32(&rtImportTaskState.Success)
|
||||
failed := atomic.LoadInt32(&rtImportTaskState.Failed)
|
||||
logger.Success(fmt.Sprintf("RT 导入完成: 总数 %d, 成功 %d, 失败 %d",
|
||||
completed, success, failed), "", "rt-import")
|
||||
}()
|
||||
|
||||
for i, rt := range req.Tokens {
|
||||
// 检查停止信号
|
||||
if isRTImportStopped() {
|
||||
logger.Warning(fmt.Sprintf("RT 导入已停止,跳过剩余 %d 个 Token", len(req.Tokens)-i), "", "rt-import")
|
||||
break
|
||||
}
|
||||
|
||||
result := processOneRT(i, rt, req)
|
||||
|
||||
rtImportTaskState.mu.Lock()
|
||||
rtImportTaskState.Results = append(rtImportTaskState.Results, result)
|
||||
rtImportTaskState.mu.Unlock()
|
||||
|
||||
atomic.AddInt32(&rtImportTaskState.Completed, 1)
|
||||
if result.Success {
|
||||
atomic.AddInt32(&rtImportTaskState.Success, 1)
|
||||
} else {
|
||||
atomic.AddInt32(&rtImportTaskState.Failed, 1)
|
||||
}
|
||||
|
||||
// 限速,避免请求过快
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
// processOneRT 处理单条 RT: 验证 → 创建账号
|
||||
func processOneRT(index int, rt string, req rtImportRequest) rtImportResult {
|
||||
result := rtImportResult{
|
||||
Index: index + 1,
|
||||
RT: maskRT(rt),
|
||||
}
|
||||
|
||||
logger.Info(fmt.Sprintf("[%d/%d] 开始处理 RT: %s", index+1, req.Prefix, maskRT(rt)), "", "rt-import")
|
||||
|
||||
// Step 1: 通过 S2A 验证 RT
|
||||
tokenInfo, err := validateRTViaS2A(rt, req.ProxyID)
|
||||
if err != nil {
|
||||
result.Error = fmt.Sprintf("验证失败: %v", err)
|
||||
logger.Error(fmt.Sprintf("[%d/%d] %s", index+1, len(req.Tokens), result.Error), "", "rt-import")
|
||||
return result
|
||||
}
|
||||
|
||||
email, _ := tokenInfo["email"].(string)
|
||||
if email == "" {
|
||||
email = "unknown"
|
||||
}
|
||||
result.Email = email
|
||||
logger.Info(fmt.Sprintf("[%d/%d] 验证成功: %s", index+1, len(req.Tokens), email), email, "rt-import")
|
||||
|
||||
// Step 2: 通过 S2A 创建账号
|
||||
acctID, err := createAccountViaS2A(tokenInfo, req)
|
||||
if err != nil {
|
||||
result.Error = fmt.Sprintf("创建账号失败: %v", err)
|
||||
logger.Error(fmt.Sprintf("[%d/%d] %s", index+1, len(req.Tokens), result.Error), email, "rt-import")
|
||||
return result
|
||||
}
|
||||
|
||||
result.Success = true
|
||||
result.AcctID = acctID
|
||||
logger.Success(fmt.Sprintf("[%d/%d] 账号创建成功: %s-%s (ID: %d)", index+1, len(req.Tokens), req.Prefix, email, acctID), email, "rt-import")
|
||||
return result
|
||||
}
|
||||
|
||||
// validateRTViaS2A 通过 S2A API 验证 Refresh Token
|
||||
func validateRTViaS2A(rt string, proxyID *int) (map[string]interface{}, error) {
|
||||
payload := map[string]interface{}{
|
||||
"refresh_token": rt,
|
||||
}
|
||||
if proxyID != nil {
|
||||
payload["proxy_id"] = *proxyID
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化请求失败: %v", err)
|
||||
}
|
||||
|
||||
url := config.Global.S2AApiBase + "/api/v1/admin/openai/refresh-token"
|
||||
httpReq, err := http.NewRequest("POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %v", err)
|
||||
}
|
||||
setS2AHeaders(httpReq)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
// 检查 S2A 包装的 data 字段
|
||||
if data, ok := result["data"].(map[string]interface{}); ok {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// createAccountViaS2A 通过 S2A API 创建账号
|
||||
func createAccountViaS2A(tokenInfo map[string]interface{}, req rtImportRequest) (int, error) {
|
||||
email, _ := tokenInfo["email"].(string)
|
||||
if email == "" {
|
||||
email = "unknown"
|
||||
}
|
||||
name := fmt.Sprintf("%s-%s", req.Prefix, email)
|
||||
|
||||
// 构建 credentials
|
||||
credentials := map[string]interface{}{}
|
||||
for _, key := range []string{"access_token", "refresh_token", "token_type", "expires_in", "expires_at", "scope"} {
|
||||
if v, ok := tokenInfo[key]; ok && v != nil {
|
||||
credentials[key] = v
|
||||
}
|
||||
}
|
||||
// 可选字段
|
||||
for _, key := range []string{"chatgpt_account_id", "chatgpt_user_id", "organization_id"} {
|
||||
if v, ok := tokenInfo[key]; ok && v != nil {
|
||||
credentials[key] = v
|
||||
}
|
||||
}
|
||||
|
||||
// 构建 extra
|
||||
extra := map[string]interface{}{}
|
||||
if email != "" && email != "unknown" {
|
||||
extra["email"] = email
|
||||
}
|
||||
if nameVal, ok := tokenInfo["name"]; ok && nameVal != nil {
|
||||
extra["name"] = nameVal
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"name": name,
|
||||
"platform": "openai",
|
||||
"type": "oauth",
|
||||
"credentials": credentials,
|
||||
"concurrency": req.Concurrency,
|
||||
"priority": req.Priority,
|
||||
"group_ids": req.GroupIDs,
|
||||
"rate_multiplier": req.RateMultiplier,
|
||||
}
|
||||
if len(extra) > 0 {
|
||||
payload["extra"] = extra
|
||||
}
|
||||
if req.ProxyID != nil {
|
||||
payload["proxy_id"] = *req.ProxyID
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("序列化请求失败: %v", err)
|
||||
}
|
||||
|
||||
url := config.Global.S2AApiBase + "/api/v1/admin/accounts"
|
||||
httpReq, err := http.NewRequest("POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("创建请求失败: %v", err)
|
||||
}
|
||||
setS2AHeaders(httpReq)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("请求失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("读取响应失败: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
return 0, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return 0, fmt.Errorf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
// 从响应中提取账号ID
|
||||
if data, ok := result["data"].(map[string]interface{}); ok {
|
||||
if id, ok := data["id"].(float64); ok {
|
||||
return int(id), nil
|
||||
}
|
||||
}
|
||||
if id, ok := result["id"].(float64); ok {
|
||||
return int(id), nil
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
Reference in New Issue
Block a user