Files
codexautopool/backend/internal/api/team_reg_exec.go

555 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package api
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"sort"
"strings"
"sync"
"time"
"codex-pool/internal/database"
"codex-pool/internal/logger"
)
// TeamRegConfig carries the user-supplied parameters of one registration run.
// It is decoded from the start request body and echoed back by the status
// endpoint.
type TeamRegConfig struct {
	// Count is the number of accounts to register.
	Count int `json:"count"`
	// Concurrency is the number of concurrent worker threads.
	Concurrency int `json:"concurrency"`
	// Proxy is the proxy address passed to the registration program.
	Proxy string `json:"proxy"`
	// AutoImport requests an automatic import once the run has completed.
	AutoImport bool `json:"auto_import"`
}
// TeamRegState is the runtime state of the registration task. The exported
// fields are what the status endpoint reports; the unexported fields are the
// handles used to control the child process. Code in this file accesses the
// fields while holding mu.
type TeamRegState struct {
	// Running reports whether a registration task is currently marked active.
	Running bool `json:"running"`
	// StartedAt is the time the current (or most recent) task was started.
	StartedAt time.Time `json:"started_at"`
	// Config is the configuration the task was started with.
	Config TeamRegConfig `json:"config"`
	// Logs is the in-memory log buffer of the task.
	Logs []string `json:"logs"`
	// OutputFile is the path of the JSON file produced by the run.
	OutputFile string `json:"output_file"`
	// Imported is the number of accounts imported so far.
	Imported int `json:"imported"`

	// mu serializes access to the fields of this struct.
	mu sync.Mutex
	// cmd is the child process of the current run.
	cmd *exec.Cmd
	// cancel cancels the context the child process was started with.
	cancel context.CancelFunc
	// stdin is the pipe connected to the child's standard input.
	stdin io.WriteCloser
}
// teamRegState is the single package-wide task state shared by all handlers.
// Logs starts as an empty, non-nil slice.
var teamRegState = &TeamRegState{Logs: []string{}}
// HandleTeamRegStart handles POST /api/team-reg/start and launches a
// registration run in the background.
//
// The request body is a JSON-encoded TeamRegConfig. Count is clamped to
// [1, 100] and Concurrency to [1, 10]. If a task is already running the
// handler answers with success=false and does not start another one.
func HandleTeamRegStart(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// A failed write only means the client has gone away; there is nothing
	// useful left to do with the error.
	reply := func(success bool, message string) {
		_ = json.NewEncoder(w).Encode(map[string]interface{}{
			"success": success,
			"message": message,
		})
	}
	const busyMessage = "已有注册任务在运行中"

	// Quick check first so a busy service answers without reading the body.
	teamRegState.mu.Lock()
	busy := teamRegState.Running
	teamRegState.mu.Unlock()
	if busy {
		reply(false, busyMessage)
		return
	}

	// Decode WITHOUT holding the mutex: reading the body is network I/O, and
	// holding the lock across it would let one slow client block the status
	// and log endpoints as well as the goroutines appending process output.
	var config TeamRegConfig
	if err := json.NewDecoder(r.Body).Decode(&config); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	// Clamp the parameters to their supported ranges.
	switch {
	case config.Count < 1:
		config.Count = 1
	case config.Count > 100:
		config.Count = 100
	}
	switch {
	case config.Concurrency < 1:
		config.Concurrency = 1
	case config.Concurrency > 10:
		config.Concurrency = 10
	}

	// Re-check under the lock: another request may have started a task while
	// the body was being read.
	teamRegState.mu.Lock()
	if teamRegState.Running {
		teamRegState.mu.Unlock()
		reply(false, busyMessage)
		return
	}
	// Reset the state for the new run.
	teamRegState.Running = true
	teamRegState.StartedAt = time.Now()
	teamRegState.Config = config
	teamRegState.Logs = make([]string, 0)
	teamRegState.OutputFile = ""
	teamRegState.Imported = 0
	teamRegState.mu.Unlock()

	// Run the child process in the background.
	go runTeamRegProcess(config)

	reply(true, "注册任务已启动")
}
// HandleTeamRegStop handles POST /api/team-reg/stop and terminates the running
// registration process.
//
// If no task is running it answers with success=false. Otherwise it marks the
// task as stopped, cancels the process context, kills the process and appends
// a log entry.
func HandleTeamRegStop(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	teamRegState.mu.Lock()
	if !teamRegState.Running {
		teamRegState.mu.Unlock()
		_ = json.NewEncoder(w).Encode(map[string]interface{}{
			"success": false,
			"message": "没有正在运行的任务",
		})
		return
	}
	// Take what is needed from the shared state and release the mutex before
	// doing anything else. addTeamRegLog locks the same mutex, and sync.Mutex
	// is not reentrant, so logging while still holding it would block this
	// handler (and every other user of the state) forever.
	cancel := teamRegState.cancel
	cmd := teamRegState.cmd
	teamRegState.Running = false
	teamRegState.mu.Unlock()

	// Cancelling the context makes os/exec kill the process that was started
	// with exec.CommandContext.
	if cancel != nil {
		cancel()
	}
	// Force-kill in case the process is still alive. An error here only means
	// it has already exited, so it is deliberately ignored.
	if cmd != nil && cmd.Process != nil {
		_ = cmd.Process.Kill()
	}

	addTeamRegLog("[系统] 任务已被用户停止")

	_ = json.NewEncoder(w).Encode(map[string]interface{}{
		"success": true,
		"message": "任务已停止",
	})
}
// HandleTeamRegStatus handles GET /api/team-reg/status and returns the current
// task state, including the buffered logs, as a JSON object.
func HandleTeamRegStatus(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// Take a snapshot under the lock; encoding happens after it is released.
	snapshot := func() map[string]interface{} {
		teamRegState.mu.Lock()
		defer teamRegState.mu.Unlock()
		return map[string]interface{}{
			"running":     teamRegState.Running,
			"started_at":  teamRegState.StartedAt,
			"config":      teamRegState.Config,
			"logs":        teamRegState.Logs,
			"output_file": teamRegState.OutputFile,
			"imported":    teamRegState.Imported,
		}
	}()

	json.NewEncoder(w).Encode(snapshot)
}
// HandleTeamRegLogs handles GET /api/team-reg/logs and streams log lines to
// the client as Server-Sent Events.
//
// Every 500ms the handler sends the log entries it has not streamed yet, one
// "data:" event per entry. Once the task is no longer running it sends a final
// "done" event and returns. It also returns when the request context is done.
func HandleTeamRegLogs(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("Access-Control-Allow-Origin", "*")

	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "Streaming not supported", http.StatusInternalServerError)
		return
	}

	// sent is how many entries of the previous snapshot were streamed and
	// lastSent is the text of the last one. Both are needed because the log
	// buffer is capped: once it is full its length stops growing and older
	// entries are dropped from the front, so a bare index would never see
	// anything new again.
	sent := 0
	lastSent := ""

	ticker := time.NewTicker(500 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-r.Context().Done():
			return
		case <-ticker.C:
			teamRegState.mu.Lock()
			running := teamRegState.Running
			logs := teamRegState.Logs
			teamRegState.mu.Unlock()

			// Stream whatever is new.
			if next := teamRegLogResumeIndex(logs, sent, lastSent); next < len(logs) {
				for _, line := range logs[next:] {
					fmt.Fprintf(w, "data: %s\n\n", line)
				}
				sent = len(logs)
				lastSent = logs[sent-1]
				flusher.Flush()
			}

			// Tell the client the task has finished and end the stream.
			if !running {
				fmt.Fprintf(w, "event: done\ndata: finished\n\n")
				flusher.Flush()
				return
			}
		}
	}
}

// teamRegLogResumeIndex returns the index in logs of the first entry that has
// not been streamed yet, given that sent entries of the previous snapshot were
// streamed and the last of them was lastSent.
//
// If the buffer was trimmed or replaced since the previous snapshot, the last
// streamed line is searched for at or before its old position. When it cannot
// be found the whole buffer is treated as new. Matching is by text, so runs of
// identical lines can make the result imprecise.
func teamRegLogResumeIndex(logs []string, sent int, lastSent string) int {
	if sent <= 0 {
		return 0
	}
	// Common case: nothing was dropped from the front of the buffer.
	if sent <= len(logs) && logs[sent-1] == lastSent {
		return sent
	}
	// Trimming only moves entries towards the front, so look backwards from
	// the old position.
	from := sent
	if from > len(logs) {
		from = len(logs)
	}
	for i := from - 1; i >= 0; i-- {
		if logs[i] == lastSent {
			return i + 1
		}
	}
	return 0
}
// HandleTeamRegImport handles POST /api/team-reg/import and imports the JSON
// file recorded as the output of the last run into the database. The number of
// imported accounts is stored in the task state and returned to the client.
func HandleTeamRegImport(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	send := func(payload map[string]interface{}) {
		json.NewEncoder(w).Encode(payload)
	}

	teamRegState.mu.Lock()
	path := teamRegState.OutputFile
	teamRegState.mu.Unlock()

	// Nothing has been produced yet.
	if path == "" {
		send(map[string]interface{}{
			"success": false,
			"message": "没有可导入的文件",
		})
		return
	}

	imported, err := importAccountsFromJSON(path)
	if err != nil {
		send(map[string]interface{}{
			"success": false,
			"message": fmt.Sprintf("导入失败: %v", err),
		})
		return
	}

	teamRegState.mu.Lock()
	teamRegState.Imported = imported
	teamRegState.mu.Unlock()

	send(map[string]interface{}{
		"success": true,
		"message": fmt.Sprintf("成功导入 %d 个账号", imported),
		"count":   imported,
	})
}
// runTeamRegProcess runs the team-reg executable for one registration task and
// records its progress in teamRegState.
//
// It locates the executable, starts it in the executable's own directory,
// answers its prompts (count, concurrency, proxy) through stdin and forwards
// everything the process prints to the task log. After the process has exited
// it looks for the newest output file in the working directory and, when
// config.AutoImport is set, imports it and deletes the file on success.
// Running is cleared when the function returns.
func runTeamRegProcess(config TeamRegConfig) {
	defer func() {
		teamRegState.mu.Lock()
		teamRegState.Running = false
		teamRegState.mu.Unlock()
	}()

	// Locate the team-reg executable.
	execPath := findTeamRegExecutable()
	if execPath == "" {
		addTeamRegLog("[错误] 找不到 team-reg 可执行文件")
		addTeamRegLog("[提示] 请确保 team-reg 文件位于 backend 目录下")
		return
	}
	addTeamRegLog(fmt.Sprintf("[系统] 找到可执行文件: %s", execPath))

	// On Linux/macOS make sure the file is executable.
	if runtime.GOOS != "windows" {
		if err := os.Chmod(execPath, 0755); err != nil {
			addTeamRegLog(fmt.Sprintf("[警告] 设置执行权限失败: %v", err))
		} else {
			addTeamRegLog("[系统] 已设置执行权限 (chmod +x)")
		}
	}

	addTeamRegLog(fmt.Sprintf("[系统] 配置: 数量=%d, 并发=%d, 代理=%s",
		config.Count, config.Concurrency, config.Proxy))

	// The context lets the stop handler terminate the process. It is always
	// cancelled on return so its resources are released on every path.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	teamRegState.mu.Lock()
	teamRegState.cancel = cancel
	teamRegState.mu.Unlock()

	cmd := exec.CommandContext(ctx, execPath)

	// The process runs in the executable's directory; its output file is
	// looked up there afterwards.
	workDir := filepath.Dir(execPath)
	cmd.Dir = workDir

	stdin, err := cmd.StdinPipe()
	if err != nil {
		addTeamRegLog(fmt.Sprintf("[错误] 无法获取 stdin: %v", err))
		return
	}
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		addTeamRegLog(fmt.Sprintf("[错误] 无法获取 stdout: %v", err))
		return
	}
	stderr, err := cmd.StderrPipe()
	if err != nil {
		addTeamRegLog(fmt.Sprintf("[错误] 无法获取 stderr: %v", err))
		return
	}

	teamRegState.mu.Lock()
	teamRegState.cmd = cmd
	teamRegState.stdin = stdin
	teamRegState.mu.Unlock()

	addTeamRegLog("[系统] 启动 team-reg 进程...")
	if err := cmd.Start(); err != nil {
		addTeamRegLog(fmt.Sprintf("[错误] 启动失败: %v", err))
		return
	}

	// Forward stdout and stderr to the task log. The readers are tracked so
	// that cmd.Wait is only called after both have finished: Wait closes the
	// pipes, and calling it while reads are still in flight can lose output.
	var readers sync.WaitGroup
	readers.Add(2)
	go func() {
		defer readers.Done()
		readOutput(stdout, workDir, config)
	}()
	go func() {
		defer readers.Done()
		readOutput(stderr, workDir, config)
	}()

	// sendLine writes one answer line to the process and records a warning if
	// the write fails (for example because the process has already exited).
	sendLine := func(value interface{}) {
		if _, err := fmt.Fprintln(stdin, value); err != nil {
			addTeamRegLog(fmt.Sprintf("[警告] 写入 stdin 失败: %v", err))
		}
	}

	// Give the program a moment to start before answering its prompts.
	time.Sleep(500 * time.Millisecond)

	addTeamRegLog(fmt.Sprintf("[输入] 注册数量: %d", config.Count))
	sendLine(config.Count)
	time.Sleep(200 * time.Millisecond)

	addTeamRegLog(fmt.Sprintf("[输入] 并发线程数: %d", config.Concurrency))
	sendLine(config.Concurrency)
	time.Sleep(200 * time.Millisecond)

	addTeamRegLog(fmt.Sprintf("[输入] 代理地址: %s", config.Proxy))
	sendLine(config.Proxy)
	time.Sleep(200 * time.Millisecond)

	// Queue the final Enter now. It has to be written BEFORE waiting: once
	// Wait has returned the process is gone and the pipe is closed, so a later
	// write can never reach it. The line stays buffered in the pipe until the
	// program reads it. A failure is ignored because the process may simply
	// have exited already.
	// NOTE(review): presumably team-reg waits for Enter before exiting; confirm.
	_, _ = fmt.Fprint(stdin, "\n")

	// Wait for all output to be consumed, then for the process itself.
	readers.Wait()
	err = cmd.Wait()
	if err != nil {
		if ctx.Err() == context.Canceled {
			addTeamRegLog("[系统] 进程已被取消")
		} else {
			addTeamRegLog(fmt.Sprintf("[系统] 进程退出: %v", err))
		}
	} else {
		addTeamRegLog("[系统] 进程正常完成")
	}

	// Look for the output file produced by this run.
	outputFile := findLatestOutputFile(workDir)
	if outputFile == "" {
		return
	}
	teamRegState.mu.Lock()
	teamRegState.OutputFile = outputFile
	teamRegState.mu.Unlock()
	addTeamRegLog(fmt.Sprintf("[系统] 输出文件: %s", filepath.Base(outputFile)))

	if !config.AutoImport {
		return
	}

	// Automatic import.
	addTeamRegLog("[系统] 自动导入账号到数据库...")
	count, err := importAccountsFromJSON(outputFile)
	if err != nil {
		addTeamRegLog(fmt.Sprintf("[错误] 导入失败: %v", err))
		return
	}
	teamRegState.mu.Lock()
	teamRegState.Imported = count
	teamRegState.mu.Unlock()
	addTeamRegLog(fmt.Sprintf("[系统] 成功导入 %d 个账号", count))

	// The JSON file is only a hand-over artifact; remove it after a
	// successful import.
	if err := os.Remove(outputFile); err != nil {
		addTeamRegLog(fmt.Sprintf("[警告] 删除临时文件失败: %v", err))
	} else {
		addTeamRegLog(fmt.Sprintf("[系统] 已清理临时文件: %s", filepath.Base(outputFile)))
	}
}
// readOutput copies the output of the child process from reader to the task
// log, one entry per non-blank line, with surrounding whitespace removed.
//
// workDir and config are not used by this function; they are kept so existing
// callers keep compiling.
func readOutput(reader io.Reader, workDir string, config TeamRegConfig) {
	scanner := bufio.NewScanner(reader)
	// The default Scanner limit is 64 KiB per line. Allow up to 1 MiB so an
	// unusually long line does not stop the reader.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)

	for scanner.Scan() {
		// Skip empty and whitespace-only lines.
		if line := strings.TrimSpace(scanner.Text()); line != "" {
			addTeamRegLog(line)
		}
	}

	err := scanner.Err()
	if err == nil {
		return
	}
	// A read on a pipe that has already been closed is the normal end of the
	// stream when the pipe is closed after the process has exited; it is not
	// worth a warning.
	if pathErr, ok := err.(*os.PathError); ok && pathErr.Err == os.ErrClosed {
		return
	}
	addTeamRegLog(fmt.Sprintf("[警告] 读取进程输出失败: %v", err))
	// Keep draining so the child never blocks writing to a full pipe; the
	// remaining output is discarded.
	_, _ = io.Copy(io.Discard, reader)
}
// addTeamRegLog appends a message to the in-memory task log, prefixed with the
// current time (HH:MM:SS), and forwards the bare message to the system logger.
// Only the most recent 1000 entries are kept. The function acquires
// teamRegState.mu, so it must not be called while that mutex is held.
func addTeamRegLog(log string) {
	const maxEntries = 1000

	teamRegState.mu.Lock()
	defer teamRegState.mu.Unlock()

	entry := fmt.Sprintf("[%s] %s", time.Now().Format("15:04:05"), log)
	teamRegState.Logs = append(teamRegState.Logs, entry)

	// Drop the oldest entries once the buffer is over its limit.
	if excess := len(teamRegState.Logs) - maxEntries; excess > 0 {
		teamRegState.Logs = teamRegState.Logs[excess:]
	}

	// Mirror the message to the system log.
	logger.Info("[TeamReg] "+log, "", "team-reg")
}
// findTeamRegExecutable searches a fixed list of directories for the team-reg
// program and returns the absolute path of the first match, or "" when none is
// found.
//
// Both "team-reg" and "team-reg.exe" are tried in every directory; the name
// native to the current OS is tried first. Directories are skipped, so a
// folder that happens to be called "team-reg" does not shadow the real binary
// further down the list.
func findTeamRegExecutable() string {
	// Candidate file names, native one first.
	names := []string{"team-reg", "team-reg.exe"}
	if runtime.GOOS == "windows" {
		names = []string{"team-reg.exe", "team-reg"}
	}

	// Both lookups are best effort: on failure the value stays empty and the
	// derived entries below degrade to paths relative to the current
	// directory, which are searched anyway.
	cwd, _ := os.Getwd()
	execPath, _ := os.Executable()
	// Resolve symlinks so the directory is that of the real binary.
	if resolved, err := filepath.EvalSymlinks(execPath); err == nil {
		execPath = resolved
	}
	execDir := filepath.Dir(execPath)

	// Search locations in priority order.
	searchDirs := []string{
		execDir,                              // directory of this binary (team-reg is expected next to the backend)
		cwd,                                  // current working directory
		filepath.Join(cwd, "backend"),        // cwd/backend
		filepath.Join(execDir, ".."),         // parent of the binary's directory
		filepath.Join(execDir, "backend"),    // execDir/backend
		".",                                  // relative: current directory
		"backend",                            // relative: backend
		"..",                                 // relative: parent directory
		filepath.Join("..", "backend"),       // ../backend
		filepath.Join("..", ".."),            // two levels up
		filepath.Join("..", "..", "backend"), // ../../backend
	}

	for _, dir := range searchDirs {
		for _, name := range names {
			candidate, err := filepath.Abs(filepath.Join(dir, name))
			if err != nil {
				continue
			}
			info, err := os.Stat(candidate)
			if err != nil || info.IsDir() {
				continue
			}
			return candidate
		}
	}
	return ""
}
// findLatestOutputFile returns the most recently modified file in dir whose
// name matches "accounts-*.json". It returns "" when there is no such file or
// when the newest one was modified more than five minutes ago.
func findLatestOutputFile(dir string) string {
	matches, err := filepath.Glob(filepath.Join(dir, "accounts-*.json"))
	if err != nil || len(matches) == 0 {
		return ""
	}

	// Stat every match exactly once. Entries that cannot be inspected (for
	// example because they vanished in the meantime) and directories are
	// skipped, so the comparison below always works on valid data.
	type candidate struct {
		path    string
		modTime time.Time
	}
	candidates := make([]candidate, 0, len(matches))
	for _, path := range matches {
		info, err := os.Stat(path)
		if err != nil || info.IsDir() {
			continue
		}
		candidates = append(candidates, candidate{path: path, modTime: info.ModTime()})
	}
	if len(candidates) == 0 {
		return ""
	}

	// Newest first. The stable sort keeps Glob's order for equal timestamps.
	sort.SliceStable(candidates, func(i, j int) bool {
		return candidates[i].modTime.After(candidates[j].modTime)
	})

	// Only accept a file written within the last five minutes.
	newest := candidates[0]
	if time.Since(newest.modTime) > 5*time.Minute {
		return ""
	}
	return newest.path
}
// TeamRegAccount is one account entry as it appears in the JSON file written
// by team-reg.
type TeamRegAccount struct {
	// Account is the account identifier; it is imported as the owner's email.
	Account string `json:"account"`
	// Password is the account password.
	Password string `json:"password"`
	// Token is the token issued for the account.
	Token string `json:"token"`
	// AccountID is the account ID, possibly carrying an "org-" prefix.
	AccountID string `json:"account_id"`
	// PlanType is the plan type reported for the account.
	PlanType string `json:"plan_type"`
}
// importAccountsFromJSON reads the team-reg output file at filePath and adds
// the accounts it contains to the database as team owners.
//
// Entries without an account name or password are skipped, and a leading
// "org-" is removed from the account ID. It returns the count reported by the
// database, or 0 and an error when the database is not initialised, the file
// cannot be read or parsed, or the insert fails.
func importAccountsFromJSON(filePath string) (int, error) {
	if database.Instance == nil {
		return 0, fmt.Errorf("数据库未初始化")
	}

	raw, err := os.ReadFile(filePath)
	if err != nil {
		return 0, err
	}

	var parsed []TeamRegAccount
	if err = json.Unmarshal(raw, &parsed); err != nil {
		return 0, err
	}

	// Convert the usable entries into database.TeamOwner records.
	var owners []database.TeamOwner
	for i := range parsed {
		entry := parsed[i]
		if entry.Account != "" && entry.Password != "" {
			owners = append(owners, database.TeamOwner{
				Email:    entry.Account,
				Password: entry.Password,
				Token:    entry.Token,
				// TrimPrefix leaves IDs without the prefix untouched.
				AccountID: strings.TrimPrefix(entry.AccountID, "org-"),
			})
		}
	}

	// Insert everything in one batch.
	imported, err := database.Instance.AddTeamOwners(owners)
	if err != nil {
		return 0, err
	}
	return imported, nil
}