424 lines
12 KiB
Go
424 lines
12 KiB
Go
package api
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"sync"
|
||
"sync/atomic"
|
||
"time"
|
||
|
||
"codex-pool/internal/config"
|
||
"codex-pool/internal/logger"
|
||
)
|
||
|
||
// ============================================================================
// Bulk RT import module
// Purpose: read a list of refresh tokens, validate each one through the S2A
// API, and create the corresponding accounts.
// ============================================================================
|
||
|
||
// rtImportRequest is the JSON body of a bulk refresh-token import request.
type rtImportRequest struct {
	// Tokens holds the refresh tokens to import.
	Tokens []string `json:"tokens"`
	// Prefix is the account-name prefix: "team" or "free".
	Prefix string `json:"prefix"`
	// Concurrency is the concurrency value assigned to each S2A account.
	Concurrency int `json:"concurrency"`
	// Priority is the priority value assigned to each S2A account.
	Priority int `json:"priority"`
	// GroupIDs lists the S2A group IDs for the created accounts.
	GroupIDs []int `json:"group_ids"`
	// ProxyID is the optional S2A proxy ID; nil means no proxy is sent.
	ProxyID *int `json:"proxy_id"`
	// RateMultiplier is the billing rate multiplier.
	RateMultiplier float64 `json:"rate_multiplier"`
}
|
||
|
||
// rtImportResult records the outcome of importing a single refresh token.
type rtImportResult struct {
	// Index is the position of the token in the request.
	Index int `json:"index"`
	// RT is the masked refresh token (leading part only).
	RT string `json:"rt"`
	// Email is the e-mail address obtained while validating the token.
	Email string `json:"email"`
	// Success reports whether the account was created.
	Success bool `json:"success"`
	// Error carries the failure description; omitted from JSON when empty.
	Error string `json:"error,omitempty"`
	// AcctID is the ID of the account created in S2A; omitted from JSON when zero.
	AcctID int `json:"account_id,omitempty"`
}
|
||
|
||
// rtImportState 导入任务状态
|
||
type rtImportState struct {
|
||
Running bool `json:"running"`
|
||
StartedAt time.Time `json:"started_at"`
|
||
Total int `json:"total"`
|
||
Completed int32 `json:"completed"`
|
||
Success int32 `json:"success"`
|
||
Failed int32 `json:"failed"`
|
||
Results []rtImportResult `json:"results"`
|
||
mu sync.Mutex
|
||
stopCh chan struct{}
|
||
}
|
||
|
||
var rtImportTaskState = &rtImportState{}
|
||
|
||
// isRTImportStopped 检查导入任务是否已被停止
|
||
func isRTImportStopped() bool {
|
||
select {
|
||
case <-rtImportTaskState.stopCh:
|
||
return true
|
||
default:
|
||
return false
|
||
}
|
||
}
|
||
|
||
// maskRT masks a refresh token for logs and API responses by keeping only its
// first 16 characters followed by "...". Tokens of 16 characters or fewer are
// returned unchanged.
//
// Characters are counted as runes and the cut is always made on a rune
// boundary, so a multi-byte UTF-8 sequence is never split in half (a plain
// byte slice could produce invalid UTF-8).
func maskRT(rt string) string {
	const visibleChars = 16

	seen := 0
	// Ranging over a string yields the byte offset at which each rune starts.
	for i := range rt {
		if seen == visibleChars {
			// i is the offset of the 17th rune: everything before it is kept.
			return rt[:i] + "..."
		}
		seen++
	}
	return rt
}
|
||
|
||
// HandleRTImportStart POST /api/rt-import/start - 启动批量 RT 导入
|
||
func HandleRTImportStart(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
Error(w, http.StatusMethodNotAllowed, "仅支持 POST")
|
||
return
|
||
}
|
||
|
||
// 检查是否正在运行
|
||
if rtImportTaskState.Running {
|
||
Error(w, http.StatusConflict, "已有导入任务正在运行")
|
||
return
|
||
}
|
||
|
||
// 检查 S2A 配置
|
||
if config.Global == nil || config.Global.S2AApiBase == "" || config.Global.S2AAdminKey == "" {
|
||
Error(w, http.StatusBadRequest, "S2A 配置未设置,请先配置 S2A API 地址和 Admin Key")
|
||
return
|
||
}
|
||
|
||
var req rtImportRequest
|
||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||
Error(w, http.StatusBadRequest, fmt.Sprintf("请求格式错误: %v", err))
|
||
return
|
||
}
|
||
|
||
// 参数校验
|
||
if len(req.Tokens) == 0 {
|
||
Error(w, http.StatusBadRequest, "没有提供 Refresh Token")
|
||
return
|
||
}
|
||
if req.Prefix != "team" && req.Prefix != "free" {
|
||
Error(w, http.StatusBadRequest, "前缀必须是 team 或 free")
|
||
return
|
||
}
|
||
|
||
// 默认值
|
||
if req.Concurrency <= 0 {
|
||
if config.Global != nil && config.Global.Concurrency > 0 {
|
||
req.Concurrency = config.Global.Concurrency
|
||
} else {
|
||
req.Concurrency = 3
|
||
}
|
||
}
|
||
if req.Priority <= 0 {
|
||
if config.Global != nil && config.Global.Priority > 0 {
|
||
req.Priority = config.Global.Priority
|
||
} else {
|
||
req.Priority = 10
|
||
}
|
||
}
|
||
if req.RateMultiplier <= 0 {
|
||
req.RateMultiplier = 1.0
|
||
}
|
||
if len(req.GroupIDs) == 0 && config.Global != nil && len(config.Global.GroupIDs) > 0 {
|
||
req.GroupIDs = config.Global.GroupIDs
|
||
}
|
||
|
||
// 初始化状态
|
||
rtImportTaskState.Running = true
|
||
rtImportTaskState.stopCh = make(chan struct{})
|
||
rtImportTaskState.StartedAt = time.Now()
|
||
rtImportTaskState.Total = len(req.Tokens)
|
||
rtImportTaskState.Completed = 0
|
||
rtImportTaskState.Success = 0
|
||
rtImportTaskState.Failed = 0
|
||
rtImportTaskState.Results = make([]rtImportResult, 0, len(req.Tokens))
|
||
|
||
// 异步执行
|
||
go runRTImport(req)
|
||
|
||
logger.Info(fmt.Sprintf("RT 导入任务已启动: 共 %d 个 Token, 前缀: %s", len(req.Tokens), req.Prefix), "", "rt-import")
|
||
|
||
Success(w, map[string]interface{}{
|
||
"message": "导入任务已启动",
|
||
"total": len(req.Tokens),
|
||
"prefix": req.Prefix,
|
||
})
|
||
}
|
||
|
||
// HandleRTImportStatus GET /api/rt-import/status - 获取导入状态
|
||
func HandleRTImportStatus(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
Error(w, http.StatusMethodNotAllowed, "仅支持 GET")
|
||
return
|
||
}
|
||
|
||
rtImportTaskState.mu.Lock()
|
||
defer rtImportTaskState.mu.Unlock()
|
||
|
||
elapsed := int64(0)
|
||
if !rtImportTaskState.StartedAt.IsZero() {
|
||
elapsed = time.Since(rtImportTaskState.StartedAt).Milliseconds()
|
||
}
|
||
|
||
Success(w, map[string]interface{}{
|
||
"running": rtImportTaskState.Running,
|
||
"started_at": rtImportTaskState.StartedAt,
|
||
"total": rtImportTaskState.Total,
|
||
"completed": rtImportTaskState.Completed,
|
||
"success": rtImportTaskState.Success,
|
||
"failed": rtImportTaskState.Failed,
|
||
"results": rtImportTaskState.Results,
|
||
"elapsed_ms": elapsed,
|
||
})
|
||
}
|
||
|
||
// HandleRTImportStop POST /api/rt-import/stop - 停止导入
|
||
func HandleRTImportStop(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodPost {
|
||
Error(w, http.StatusMethodNotAllowed, "仅支持 POST")
|
||
return
|
||
}
|
||
|
||
if !rtImportTaskState.Running {
|
||
Error(w, http.StatusBadRequest, "没有正在运行的导入任务")
|
||
return
|
||
}
|
||
|
||
rtImportTaskState.Running = false
|
||
if rtImportTaskState.stopCh != nil {
|
||
select {
|
||
case <-rtImportTaskState.stopCh:
|
||
// 已关闭
|
||
default:
|
||
close(rtImportTaskState.stopCh)
|
||
}
|
||
}
|
||
logger.Warning("RT 导入任务已收到停止信号", "", "rt-import")
|
||
Success(w, map[string]string{"message": "已发送停止信号"})
|
||
}
|
||
|
||
// runRTImport 执行批量导入
|
||
func runRTImport(req rtImportRequest) {
|
||
defer func() {
|
||
rtImportTaskState.Running = false
|
||
completed := atomic.LoadInt32(&rtImportTaskState.Completed)
|
||
success := atomic.LoadInt32(&rtImportTaskState.Success)
|
||
failed := atomic.LoadInt32(&rtImportTaskState.Failed)
|
||
logger.Success(fmt.Sprintf("RT 导入完成: 总数 %d, 成功 %d, 失败 %d",
|
||
completed, success, failed), "", "rt-import")
|
||
}()
|
||
|
||
for i, rt := range req.Tokens {
|
||
// 检查停止信号
|
||
if isRTImportStopped() {
|
||
logger.Warning(fmt.Sprintf("RT 导入已停止,跳过剩余 %d 个 Token", len(req.Tokens)-i), "", "rt-import")
|
||
break
|
||
}
|
||
|
||
result := processOneRT(i, rt, req)
|
||
|
||
rtImportTaskState.mu.Lock()
|
||
rtImportTaskState.Results = append(rtImportTaskState.Results, result)
|
||
rtImportTaskState.mu.Unlock()
|
||
|
||
atomic.AddInt32(&rtImportTaskState.Completed, 1)
|
||
if result.Success {
|
||
atomic.AddInt32(&rtImportTaskState.Success, 1)
|
||
} else {
|
||
atomic.AddInt32(&rtImportTaskState.Failed, 1)
|
||
}
|
||
|
||
// 限速,避免请求过快
|
||
time.Sleep(500 * time.Millisecond)
|
||
}
|
||
}
|
||
|
||
// processOneRT 处理单条 RT: 验证 → 创建账号
|
||
func processOneRT(index int, rt string, req rtImportRequest) rtImportResult {
|
||
result := rtImportResult{
|
||
Index: index + 1,
|
||
RT: maskRT(rt),
|
||
}
|
||
|
||
logger.Info(fmt.Sprintf("[%d/%d] 开始处理 RT: %s", index+1, req.Prefix, maskRT(rt)), "", "rt-import")
|
||
|
||
// Step 1: 通过 S2A 验证 RT
|
||
tokenInfo, err := validateRTViaS2A(rt, req.ProxyID)
|
||
if err != nil {
|
||
result.Error = fmt.Sprintf("验证失败: %v", err)
|
||
logger.Error(fmt.Sprintf("[%d/%d] %s", index+1, len(req.Tokens), result.Error), "", "rt-import")
|
||
return result
|
||
}
|
||
|
||
email, _ := tokenInfo["email"].(string)
|
||
if email == "" {
|
||
email = "unknown"
|
||
}
|
||
result.Email = email
|
||
logger.Info(fmt.Sprintf("[%d/%d] 验证成功: %s", index+1, len(req.Tokens), email), email, "rt-import")
|
||
|
||
// Step 2: 通过 S2A 创建账号
|
||
acctID, err := createAccountViaS2A(tokenInfo, req)
|
||
if err != nil {
|
||
result.Error = fmt.Sprintf("创建账号失败: %v", err)
|
||
logger.Error(fmt.Sprintf("[%d/%d] %s", index+1, len(req.Tokens), result.Error), email, "rt-import")
|
||
return result
|
||
}
|
||
|
||
result.Success = true
|
||
result.AcctID = acctID
|
||
logger.Success(fmt.Sprintf("[%d/%d] 账号创建成功: %s-%s (ID: %d)", index+1, len(req.Tokens), req.Prefix, email, acctID), email, "rt-import")
|
||
return result
|
||
}
|
||
|
||
// validateRTViaS2A 通过 S2A API 验证 Refresh Token
|
||
func validateRTViaS2A(rt string, proxyID *int) (map[string]interface{}, error) {
|
||
payload := map[string]interface{}{
|
||
"refresh_token": rt,
|
||
}
|
||
if proxyID != nil {
|
||
payload["proxy_id"] = *proxyID
|
||
}
|
||
|
||
body, err := json.Marshal(payload)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("序列化请求失败: %v", err)
|
||
}
|
||
|
||
url := config.Global.S2AApiBase + "/api/v1/admin/openai/refresh-token"
|
||
httpReq, err := http.NewRequest("POST", url, bytes.NewReader(body))
|
||
if err != nil {
|
||
return nil, fmt.Errorf("创建请求失败: %v", err)
|
||
}
|
||
setS2AHeaders(httpReq)
|
||
|
||
client := &http.Client{Timeout: 30 * time.Second}
|
||
resp, err := client.Do(httpReq)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("请求失败: %v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
respBody, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("读取响应失败: %v", err)
|
||
}
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
|
||
}
|
||
|
||
var result map[string]interface{}
|
||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||
return nil, fmt.Errorf("解析响应失败: %v", err)
|
||
}
|
||
|
||
// 检查 S2A 包装的 data 字段
|
||
if data, ok := result["data"].(map[string]interface{}); ok {
|
||
return data, nil
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// createAccountViaS2A 通过 S2A API 创建账号
|
||
func createAccountViaS2A(tokenInfo map[string]interface{}, req rtImportRequest) (int, error) {
|
||
email, _ := tokenInfo["email"].(string)
|
||
if email == "" {
|
||
email = "unknown"
|
||
}
|
||
name := fmt.Sprintf("%s-%s", req.Prefix, email)
|
||
|
||
// 构建 credentials
|
||
credentials := map[string]interface{}{}
|
||
for _, key := range []string{"access_token", "refresh_token", "token_type", "expires_in", "expires_at", "scope"} {
|
||
if v, ok := tokenInfo[key]; ok && v != nil {
|
||
credentials[key] = v
|
||
}
|
||
}
|
||
// 可选字段
|
||
for _, key := range []string{"chatgpt_account_id", "chatgpt_user_id", "organization_id"} {
|
||
if v, ok := tokenInfo[key]; ok && v != nil {
|
||
credentials[key] = v
|
||
}
|
||
}
|
||
|
||
// 构建 extra
|
||
extra := map[string]interface{}{}
|
||
if email != "" && email != "unknown" {
|
||
extra["email"] = email
|
||
}
|
||
if nameVal, ok := tokenInfo["name"]; ok && nameVal != nil {
|
||
extra["name"] = nameVal
|
||
}
|
||
|
||
payload := map[string]interface{}{
|
||
"name": name,
|
||
"platform": "openai",
|
||
"type": "oauth",
|
||
"credentials": credentials,
|
||
"concurrency": req.Concurrency,
|
||
"priority": req.Priority,
|
||
"group_ids": req.GroupIDs,
|
||
"rate_multiplier": req.RateMultiplier,
|
||
}
|
||
if len(extra) > 0 {
|
||
payload["extra"] = extra
|
||
}
|
||
if req.ProxyID != nil {
|
||
payload["proxy_id"] = *req.ProxyID
|
||
}
|
||
|
||
body, err := json.Marshal(payload)
|
||
if err != nil {
|
||
return 0, fmt.Errorf("序列化请求失败: %v", err)
|
||
}
|
||
|
||
url := config.Global.S2AApiBase + "/api/v1/admin/accounts"
|
||
httpReq, err := http.NewRequest("POST", url, bytes.NewReader(body))
|
||
if err != nil {
|
||
return 0, fmt.Errorf("创建请求失败: %v", err)
|
||
}
|
||
setS2AHeaders(httpReq)
|
||
|
||
client := &http.Client{Timeout: 30 * time.Second}
|
||
resp, err := client.Do(httpReq)
|
||
if err != nil {
|
||
return 0, fmt.Errorf("请求失败: %v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
respBody, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return 0, fmt.Errorf("读取响应失败: %v", err)
|
||
}
|
||
|
||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||
return 0, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
|
||
}
|
||
|
||
var result map[string]interface{}
|
||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||
return 0, fmt.Errorf("解析响应失败: %v", err)
|
||
}
|
||
|
||
// 从响应中提取账号ID
|
||
if data, ok := result["data"].(map[string]interface{}); ok {
|
||
if id, ok := data["id"].(float64); ok {
|
||
return int(id), nil
|
||
}
|
||
}
|
||
if id, ok := result["id"].(float64); ok {
|
||
return int(id), nil
|
||
}
|
||
|
||
return 0, nil
|
||
}
|