feat: Initialize the core backend API server, frontend application structure, and implement batch RT import functionality.

This commit is contained in:
2026-02-08 18:35:01 +08:00
parent 571322ffcb
commit 76de666560
7 changed files with 952 additions and 1 deletions

View File

@@ -0,0 +1,423 @@
package api
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"sync"
"sync/atomic"
"time"
"codex-pool/internal/config"
"codex-pool/internal/logger"
)
// ============================================================================
// 批量 RT 导入模块
// 功能: 读取 Refresh Token 列表,通过 S2A API 验证并创建账号
// ============================================================================
// rtImportRequest 导入请求
type rtImportRequest struct {
Tokens []string `json:"tokens"`
Prefix string `json:"prefix"` // "team" 或 "free"
Concurrency int `json:"concurrency"` // S2A 账号并发数
Priority int `json:"priority"` // S2A 账号优先级
GroupIDs []int `json:"group_ids"` // S2A 分组ID
ProxyID *int `json:"proxy_id"` // S2A 代理ID
RateMultiplier float64 `json:"rate_multiplier"` // 计费倍率
}
// rtImportResult 单条导入结果
type rtImportResult struct {
Index int `json:"index"`
RT string `json:"rt"` // 脱敏后的 RT 前缀
Email string `json:"email"` // 验证得到的邮箱
Success bool `json:"success"`
Error string `json:"error,omitempty"`
AcctID int `json:"account_id,omitempty"` // S2A 创建的账号ID
}
// rtImportState 导入任务状态
type rtImportState struct {
Running bool `json:"running"`
StartedAt time.Time `json:"started_at"`
Total int `json:"total"`
Completed int32 `json:"completed"`
Success int32 `json:"success"`
Failed int32 `json:"failed"`
Results []rtImportResult `json:"results"`
mu sync.Mutex
stopCh chan struct{}
}
var rtImportTaskState = &rtImportState{}
// isRTImportStopped 检查导入任务是否已被停止
func isRTImportStopped() bool {
select {
case <-rtImportTaskState.stopCh:
return true
default:
return false
}
}
// maskRT 脱敏处理 RT只显示前16个字符
func maskRT(rt string) string {
if len(rt) <= 16 {
return rt
}
return rt[:16] + "..."
}
// HandleRTImportStart POST /api/rt-import/start - 启动批量 RT 导入
func HandleRTImportStart(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
Error(w, http.StatusMethodNotAllowed, "仅支持 POST")
return
}
// 检查是否正在运行
if rtImportTaskState.Running {
Error(w, http.StatusConflict, "已有导入任务正在运行")
return
}
// 检查 S2A 配置
if config.Global == nil || config.Global.S2AApiBase == "" || config.Global.S2AAdminKey == "" {
Error(w, http.StatusBadRequest, "S2A 配置未设置,请先配置 S2A API 地址和 Admin Key")
return
}
var req rtImportRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
Error(w, http.StatusBadRequest, fmt.Sprintf("请求格式错误: %v", err))
return
}
// 参数校验
if len(req.Tokens) == 0 {
Error(w, http.StatusBadRequest, "没有提供 Refresh Token")
return
}
if req.Prefix != "team" && req.Prefix != "free" {
Error(w, http.StatusBadRequest, "前缀必须是 team 或 free")
return
}
// 默认值
if req.Concurrency <= 0 {
if config.Global != nil && config.Global.Concurrency > 0 {
req.Concurrency = config.Global.Concurrency
} else {
req.Concurrency = 3
}
}
if req.Priority <= 0 {
if config.Global != nil && config.Global.Priority > 0 {
req.Priority = config.Global.Priority
} else {
req.Priority = 10
}
}
if req.RateMultiplier <= 0 {
req.RateMultiplier = 1.0
}
if len(req.GroupIDs) == 0 && config.Global != nil && len(config.Global.GroupIDs) > 0 {
req.GroupIDs = config.Global.GroupIDs
}
// 初始化状态
rtImportTaskState.Running = true
rtImportTaskState.stopCh = make(chan struct{})
rtImportTaskState.StartedAt = time.Now()
rtImportTaskState.Total = len(req.Tokens)
rtImportTaskState.Completed = 0
rtImportTaskState.Success = 0
rtImportTaskState.Failed = 0
rtImportTaskState.Results = make([]rtImportResult, 0, len(req.Tokens))
// 异步执行
go runRTImport(req)
logger.Info(fmt.Sprintf("RT 导入任务已启动: 共 %d 个 Token, 前缀: %s", len(req.Tokens), req.Prefix), "", "rt-import")
Success(w, map[string]interface{}{
"message": "导入任务已启动",
"total": len(req.Tokens),
"prefix": req.Prefix,
})
}
// HandleRTImportStatus GET /api/rt-import/status - 获取导入状态
func HandleRTImportStatus(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
Error(w, http.StatusMethodNotAllowed, "仅支持 GET")
return
}
rtImportTaskState.mu.Lock()
defer rtImportTaskState.mu.Unlock()
elapsed := int64(0)
if !rtImportTaskState.StartedAt.IsZero() {
elapsed = time.Since(rtImportTaskState.StartedAt).Milliseconds()
}
Success(w, map[string]interface{}{
"running": rtImportTaskState.Running,
"started_at": rtImportTaskState.StartedAt,
"total": rtImportTaskState.Total,
"completed": rtImportTaskState.Completed,
"success": rtImportTaskState.Success,
"failed": rtImportTaskState.Failed,
"results": rtImportTaskState.Results,
"elapsed_ms": elapsed,
})
}
// HandleRTImportStop POST /api/rt-import/stop - 停止导入
func HandleRTImportStop(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
Error(w, http.StatusMethodNotAllowed, "仅支持 POST")
return
}
if !rtImportTaskState.Running {
Error(w, http.StatusBadRequest, "没有正在运行的导入任务")
return
}
rtImportTaskState.Running = false
if rtImportTaskState.stopCh != nil {
select {
case <-rtImportTaskState.stopCh:
// 已关闭
default:
close(rtImportTaskState.stopCh)
}
}
logger.Warning("RT 导入任务已收到停止信号", "", "rt-import")
Success(w, map[string]string{"message": "已发送停止信号"})
}
// runRTImport 执行批量导入
func runRTImport(req rtImportRequest) {
defer func() {
rtImportTaskState.Running = false
completed := atomic.LoadInt32(&rtImportTaskState.Completed)
success := atomic.LoadInt32(&rtImportTaskState.Success)
failed := atomic.LoadInt32(&rtImportTaskState.Failed)
logger.Success(fmt.Sprintf("RT 导入完成: 总数 %d, 成功 %d, 失败 %d",
completed, success, failed), "", "rt-import")
}()
for i, rt := range req.Tokens {
// 检查停止信号
if isRTImportStopped() {
logger.Warning(fmt.Sprintf("RT 导入已停止,跳过剩余 %d 个 Token", len(req.Tokens)-i), "", "rt-import")
break
}
result := processOneRT(i, rt, req)
rtImportTaskState.mu.Lock()
rtImportTaskState.Results = append(rtImportTaskState.Results, result)
rtImportTaskState.mu.Unlock()
atomic.AddInt32(&rtImportTaskState.Completed, 1)
if result.Success {
atomic.AddInt32(&rtImportTaskState.Success, 1)
} else {
atomic.AddInt32(&rtImportTaskState.Failed, 1)
}
// 限速,避免请求过快
time.Sleep(500 * time.Millisecond)
}
}
// processOneRT 处理单条 RT: 验证 → 创建账号
func processOneRT(index int, rt string, req rtImportRequest) rtImportResult {
result := rtImportResult{
Index: index + 1,
RT: maskRT(rt),
}
logger.Info(fmt.Sprintf("[%d/%d] 开始处理 RT: %s", index+1, req.Prefix, maskRT(rt)), "", "rt-import")
// Step 1: 通过 S2A 验证 RT
tokenInfo, err := validateRTViaS2A(rt, req.ProxyID)
if err != nil {
result.Error = fmt.Sprintf("验证失败: %v", err)
logger.Error(fmt.Sprintf("[%d/%d] %s", index+1, len(req.Tokens), result.Error), "", "rt-import")
return result
}
email, _ := tokenInfo["email"].(string)
if email == "" {
email = "unknown"
}
result.Email = email
logger.Info(fmt.Sprintf("[%d/%d] 验证成功: %s", index+1, len(req.Tokens), email), email, "rt-import")
// Step 2: 通过 S2A 创建账号
acctID, err := createAccountViaS2A(tokenInfo, req)
if err != nil {
result.Error = fmt.Sprintf("创建账号失败: %v", err)
logger.Error(fmt.Sprintf("[%d/%d] %s", index+1, len(req.Tokens), result.Error), email, "rt-import")
return result
}
result.Success = true
result.AcctID = acctID
logger.Success(fmt.Sprintf("[%d/%d] 账号创建成功: %s-%s (ID: %d)", index+1, len(req.Tokens), req.Prefix, email, acctID), email, "rt-import")
return result
}
// validateRTViaS2A 通过 S2A API 验证 Refresh Token
func validateRTViaS2A(rt string, proxyID *int) (map[string]interface{}, error) {
payload := map[string]interface{}{
"refresh_token": rt,
}
if proxyID != nil {
payload["proxy_id"] = *proxyID
}
body, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("序列化请求失败: %v", err)
}
url := config.Global.S2AApiBase + "/api/v1/admin/openai/refresh-token"
httpReq, err := http.NewRequest("POST", url, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("创建请求失败: %v", err)
}
setS2AHeaders(httpReq)
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("请求失败: %v", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应失败: %v", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
}
var result map[string]interface{}
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("解析响应失败: %v", err)
}
// 检查 S2A 包装的 data 字段
if data, ok := result["data"].(map[string]interface{}); ok {
return data, nil
}
return result, nil
}
// createAccountViaS2A 通过 S2A API 创建账号
func createAccountViaS2A(tokenInfo map[string]interface{}, req rtImportRequest) (int, error) {
email, _ := tokenInfo["email"].(string)
if email == "" {
email = "unknown"
}
name := fmt.Sprintf("%s-%s", req.Prefix, email)
// 构建 credentials
credentials := map[string]interface{}{}
for _, key := range []string{"access_token", "refresh_token", "token_type", "expires_in", "expires_at", "scope"} {
if v, ok := tokenInfo[key]; ok && v != nil {
credentials[key] = v
}
}
// 可选字段
for _, key := range []string{"chatgpt_account_id", "chatgpt_user_id", "organization_id"} {
if v, ok := tokenInfo[key]; ok && v != nil {
credentials[key] = v
}
}
// 构建 extra
extra := map[string]interface{}{}
if email != "" && email != "unknown" {
extra["email"] = email
}
if nameVal, ok := tokenInfo["name"]; ok && nameVal != nil {
extra["name"] = nameVal
}
payload := map[string]interface{}{
"name": name,
"platform": "openai",
"type": "oauth",
"credentials": credentials,
"concurrency": req.Concurrency,
"priority": req.Priority,
"group_ids": req.GroupIDs,
"rate_multiplier": req.RateMultiplier,
}
if len(extra) > 0 {
payload["extra"] = extra
}
if req.ProxyID != nil {
payload["proxy_id"] = *req.ProxyID
}
body, err := json.Marshal(payload)
if err != nil {
return 0, fmt.Errorf("序列化请求失败: %v", err)
}
url := config.Global.S2AApiBase + "/api/v1/admin/accounts"
httpReq, err := http.NewRequest("POST", url, bytes.NewReader(body))
if err != nil {
return 0, fmt.Errorf("创建请求失败: %v", err)
}
setS2AHeaders(httpReq)
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(httpReq)
if err != nil {
return 0, fmt.Errorf("请求失败: %v", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return 0, fmt.Errorf("读取响应失败: %v", err)
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
return 0, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
}
var result map[string]interface{}
if err := json.Unmarshal(respBody, &result); err != nil {
return 0, fmt.Errorf("解析响应失败: %v", err)
}
// 从响应中提取账号ID
if data, ok := result["data"].(map[string]interface{}); ok {
if id, ok := data["id"].(float64); ok {
return int(id), nil
}
}
if id, ok := result["id"].(float64); ok {
return int(id), nil
}
return 0, nil
}