Files
codexautopool/backend/cmd/main.go

452 lines
13 KiB
Go

package main
import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"

	"codex-pool/internal/api"
	"codex-pool/internal/config"
	"codex-pool/internal/database"
	"codex-pool/internal/logger"
	"codex-pool/internal/mail"
	"codex-pool/internal/register"
	"codex-pool/internal/web"
)
// main is the process entry point. It prints the startup banner, ensures the
// data directory exists, opens the database, and loads configuration from that
// database. It then initializes mail services, reports the effective settings,
// and hands off to startServer.
func main() {
	// ANSI escape sequences used to colorize console output.
	const (
		ansiReset  = "\033[0m"
		ansiBold   = "\033[1m"
		ansiRed    = "\033[31m"
		ansiGreen  = "\033[32m"
		ansiYellow = "\033[33m"
		ansiCyan   = "\033[36m"
		ansiGray   = "\033[90m"
	)
	// label wraps text in the given color escape followed by a reset.
	label := func(color, text string) string { return color + text + ansiReset }

	rule := "============================================================"
	fmt.Printf("%s\n", label(ansiBold+ansiCyan, rule))
	fmt.Printf("%s\n", label(ansiBold+ansiCyan, " Codex Pool - HTTP API Server"))
	fmt.Printf("%s\n\n", label(ansiBold+ansiCyan, rule))

	// Persistent state lives under ./data. If that directory cannot be
	// created, fall back to the current working directory.
	dataDir := "data"
	if mkErr := os.MkdirAll(dataDir, 0755); mkErr != nil {
		fmt.Printf("%s 创建数据目录失败: %v, 使用当前目录\n", label(ansiYellow, "[WARN]"), mkErr)
		dataDir = "."
	}

	// The database must be initialized before configuration, because the
	// configuration is read from it.
	dbPath := filepath.Join(dataDir, "codex-pool.db")
	if initErr := database.Init(dbPath); initErr != nil {
		fmt.Printf("%s 数据库初始化失败: %v\n", label(ansiRed, "[ERROR]"), initErr)
		os.Exit(1)
	}

	config.SetConfigDB(database.Instance)
	cfg := config.InitFromDB()

	// Mail services are only initialized when at least one is configured.
	if n := len(cfg.MailServices); n > 0 {
		mail.Init(cfg.MailServices)
		fmt.Printf("%s 已加载 %d 个邮箱服务\n", label(ansiGreen, "[邮箱]"), n)
	}

	cfgTag := label(ansiGray, "[配置]")
	fmt.Printf("%s 数据库: %s\n", cfgTag, dbPath)
	fmt.Printf("%s 端口: %d\n", cfgTag, cfg.Port)

	switch {
	case cfg.S2AApiBase == "":
		fmt.Printf("%s S2A API: %s (请在Web界面配置)\n", cfgTag, label(ansiYellow, "未配置"))
	default:
		fmt.Printf("%s S2A API: %s\n", cfgTag, cfg.S2AApiBase)
	}

	switch {
	case cfg.ProxyEnabled:
		fmt.Printf("%s 代理: %s (已启用)\n", cfgTag, cfg.DefaultProxy)
	default:
		fmt.Printf("%s 代理: 已禁用\n", cfgTag)
	}

	if web.IsEmbedded() {
		fmt.Printf("%s 嵌入模式\n", label(ansiGreen, "[前端]"))
	} else {
		fmt.Printf("%s 开发模式 (未嵌入)\n", label(ansiYellow, "[前端]"))
	}
	fmt.Println()

	startServer(cfg)
}
// startServer registers all HTTP routes on a new ServeMux, prints the local
// and (if detectable) external URLs, and serves on cfg.Port. It blocks until
// the listener fails. It then prints the error and exits the process with
// status 1.
//
// Every /api/* handler is wrapped in api.CORS. When the frontend is embedded
// (web.IsEmbedded), "/" serves the static files. Paths without a "." fall back
// to index.html so client-side SPA routes survive a page reload.
func startServer(cfg *config.Config) {
	mux := http.NewServeMux()

	// Basic API
	mux.HandleFunc("/api/health", api.CORS(handleHealth))
	mux.HandleFunc("/api/config", api.CORS(handleConfig))

	// Logs API
	mux.HandleFunc("/api/logs", api.CORS(handleGetLogs))
	mux.HandleFunc("/api/logs/clear", api.CORS(handleClearLogs))

	// S2A proxy API
	mux.HandleFunc("/api/s2a/test", api.CORS(handleS2ATest))
	mux.HandleFunc("/api/s2a/proxy/", api.CORS(handleS2AProxy)) // wildcard proxy (subtree pattern)

	// Mail service API
	mux.HandleFunc("/api/mail/services", api.CORS(handleMailServices))
	mux.HandleFunc("/api/mail/services/test", api.CORS(handleTestMailService))

	// Team Owner API
	mux.HandleFunc("/api/db/owners", api.CORS(handleGetOwners))
	mux.HandleFunc("/api/db/owners/stats", api.CORS(handleGetOwnerStats))
	mux.HandleFunc("/api/db/owners/clear", api.CORS(handleClearOwners))

	// Registration test API
	mux.HandleFunc("/api/register/test", api.CORS(handleRegisterTest))

	// Team batch-processing API
	mux.HandleFunc("/api/team/process", api.CORS(api.HandleTeamProcess))
	mux.HandleFunc("/api/team/status", api.CORS(api.HandleTeamProcessStatus))
	mux.HandleFunc("/api/team/stop", api.CORS(api.HandleTeamProcessStop))

	// Embedded frontend static files.
	if web.IsEmbedded() {
		webFS := web.GetFileSystem()
		fileServer := http.FileServer(webFS)
		mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
			// Unknown API paths must 404 instead of falling through to index.html.
			if strings.HasPrefix(r.URL.Path, "/api/") {
				http.NotFound(w, r)
				return
			}
			// SPA routing: a path without a "." is treated as a client-side
			// route and served index.html. Rewrite a clone, because an
			// http.Handler must not modify the *http.Request it was given.
			if p := r.URL.Path; p != "/" && !strings.Contains(p, ".") {
				r = r.Clone(r.Context())
				r.URL.Path = "/"
				r.URL.RawPath = ""
			}
			fileServer.ServeHTTP(w, r)
		})
	}

	addr := fmt.Sprintf(":%d", cfg.Port)

	// ANSI color codes
	colorReset := "\033[0m"
	colorGreen := "\033[32m"
	colorCyan := "\033[36m"

	// Show the reachable addresses.
	fmt.Printf("%s[服务]%s 启动于:\n", colorGreen, colorReset)
	fmt.Printf(" - 本地: %shttp://localhost:%d%s\n", colorCyan, cfg.Port, colorReset)
	if ip := getOutboundIP(); ip != "" {
		fmt.Printf(" - 外部: %shttp://%s:%d%s\n", colorCyan, ip, cfg.Port, colorReset)
	}
	fmt.Println()

	// http.ListenAndServe applies no timeouts at all. ReadHeaderTimeout caps
	// how long a client may take to send request headers (Slowloris
	// mitigation). Read/Write timeouts are deliberately left unset: some
	// handlers, such as /api/register/test and the S2A proxy, can legitimately
	// run for a long time.
	srv := &http.Server{
		Addr:              addr,
		Handler:           mux,
		ReadHeaderTimeout: 10 * time.Second,
	}
	if err := srv.ListenAndServe(); err != nil {
		fmt.Printf("\033[31m[ERROR]\033[0m 服务启动失败: %v\n", err)
		os.Exit(1)
	}
}
// ==================== API 处理器 ====================
// handleHealth is a liveness endpoint. For any method it replies through
// api.Success with the payload {"status": "ok"}.
func handleHealth(w http.ResponseWriter, r *http.Request) {
	status := map[string]string{"status": "ok"}
	api.Success(w, status)
}
// handleConfig serves /api/config.
//
//   - GET returns the current in-memory configuration (config.Global).
//   - PUT applies a partial update. Only fields present in the JSON body are
//     changed, then the configuration is persisted via config.Update.
//
// Both GET and PUT answer 500 "配置未加载" when config.Global is nil, and 400
// is returned for an undecodable PUT body. Any other method gets 405.
func handleConfig(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		if config.Global == nil {
			api.Error(w, http.StatusInternalServerError, "配置未加载")
			return
		}
		// NOTE(review): the admin key is returned in clear text even though
		// has_admin_key exists. mail_services may likewise carry API tokens
		// (handleMailServices masks them) — verify. Consider masking once the
		// frontend no longer reads these values.
		api.Success(w, map[string]interface{}{
			"port":                config.Global.Port,
			"s2a_api_base":        config.Global.S2AApiBase,
			"s2a_admin_key":       config.Global.S2AAdminKey,
			"has_admin_key":       config.Global.S2AAdminKey != "",
			"concurrency":         config.Global.Concurrency,
			"priority":            config.Global.Priority,
			"group_ids":           config.Global.GroupIDs,
			"proxy_enabled":       config.Global.ProxyEnabled,
			"default_proxy":       config.Global.DefaultProxy,
			"mail_services_count": len(config.Global.MailServices),
			"mail_services":       config.Global.MailServices,
		})
	case http.MethodPut:
		// Same guard as GET. Without it, a PUT before the configuration was
		// loaded dereferenced a nil config.Global and panicked the request.
		if config.Global == nil {
			api.Error(w, http.StatusInternalServerError, "配置未加载")
			return
		}

		// Pointer fields let "absent" (nil, keep the current value) be told
		// apart from an explicit zero value. For GroupIDs, a missing key or
		// JSON null leaves the list unchanged, while [] clears it.
		var req struct {
			S2AApiBase   *string `json:"s2a_api_base"`
			S2AAdminKey  *string `json:"s2a_admin_key"`
			Concurrency  *int    `json:"concurrency"`
			Priority     *int    `json:"priority"`
			GroupIDs     []int   `json:"group_ids"`
			ProxyEnabled *bool   `json:"proxy_enabled"`
			DefaultProxy *string `json:"default_proxy"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			api.Error(w, http.StatusBadRequest, "请求格式错误")
			return
		}

		// Apply only the fields that were supplied.
		if req.S2AApiBase != nil {
			config.Global.S2AApiBase = *req.S2AApiBase
		}
		if req.S2AAdminKey != nil {
			config.Global.S2AAdminKey = *req.S2AAdminKey
		}
		if req.Concurrency != nil {
			config.Global.Concurrency = *req.Concurrency
		}
		if req.Priority != nil {
			config.Global.Priority = *req.Priority
		}
		if req.GroupIDs != nil {
			config.Global.GroupIDs = req.GroupIDs
		}
		if req.ProxyEnabled != nil {
			config.Global.ProxyEnabled = *req.ProxyEnabled
		}
		if req.DefaultProxy != nil {
			config.Global.DefaultProxy = *req.DefaultProxy
		}

		// Persist to the database. The in-memory fields above are already
		// updated, so they stay changed even if saving fails.
		if err := config.Update(config.Global); err != nil {
			api.Error(w, http.StatusInternalServerError, fmt.Sprintf("保存配置失败: %v", err))
			return
		}
		logger.Success("配置已更新并保存到数据库", "", "config")
		api.Success(w, map[string]string{"message": "配置已更新"})
	default:
		api.Error(w, http.StatusMethodNotAllowed, "不支持的方法")
	}
}
func handleGetLogs(w http.ResponseWriter, r *http.Request) {
logs := logger.GetLogs(100)
api.Success(w, logs)
}
// handleClearLogs calls logger.ClearLogs and replies with a confirmation
// message.
// NOTE(review): the HTTP method is not checked, so a GET also clears logs.
func handleClearLogs(w http.ResponseWriter, r *http.Request) {
	logger.ClearLogs()
	reply := map[string]string{"message": "日志已清空"}
	api.Success(w, reply)
}
// handleS2ATest reports whether an S2A API base URL is configured. It answers
// 400 when none is set.
// NOTE(review): no request is sent to S2A, so "connected": true only means
// "configured", not "reachable".
func handleS2ATest(w http.ResponseWriter, r *http.Request) {
	configured := config.Global != nil && config.Global.S2AApiBase != ""
	if !configured {
		api.Error(w, http.StatusBadRequest, "S2A 配置未设置")
		return
	}

	result := map[string]interface{}{
		"message":   "S2A 配置已就绪",
		"connected": true,
	}
	api.Success(w, result)
}
// s2aHopByHopHeaders lists connection-scoped header keys, in canonical form,
// that a proxy must not forward to the next hop (RFC 9110 §7.6.1).
var s2aHopByHopHeaders = map[string]bool{
	"Connection":          true,
	"Keep-Alive":          true,
	"Proxy-Connection":    true,
	"Proxy-Authenticate":  true,
	"Proxy-Authorization": true,
	"Te":                  true,
	"Trailer":             true,
	"Transfer-Encoding":   true,
	"Upgrade":             true,
}

// handleS2AProxy forwards /api/s2a/proxy/<rest> to
// {S2AApiBase}/api/v1/admin/<rest>. The query string is preserved, and the
// stored admin key is injected as a Bearer token. The upstream status,
// headers and body are relayed back to the client.
//
// It answers 400 when the S2A base URL or admin key is not configured, 500
// when the upstream request cannot be built, and 502 when the upstream call
// fails. The upstream call is bound to the client's request context and to a
// 30s client timeout.
func handleS2AProxy(w http.ResponseWriter, r *http.Request) {
	if config.Global == nil || config.Global.S2AApiBase == "" || config.Global.S2AAdminKey == "" {
		api.Error(w, http.StatusBadRequest, "S2A 配置未设置")
		return
	}

	// The escaped path keeps percent-encoded segments (e.g. %2F) intact; the
	// decoded Path would corrupt them. Trailing slashes are trimmed from the
	// base so that "http://host/" does not produce "http://host//api/...".
	subPath := strings.TrimPrefix(r.URL.EscapedPath(), "/api/s2a/proxy")
	targetURL := strings.TrimRight(config.Global.S2AApiBase, "/") + "/api/v1/admin" + subPath
	if r.URL.RawQuery != "" {
		targetURL += "?" + r.URL.RawQuery
	}

	// Bind the upstream call to the client request, so it is cancelled when
	// the client disconnects.
	proxyReq, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL, r.Body)
	if err != nil {
		api.Error(w, http.StatusInternalServerError, "创建请求失败")
		return
	}
	// NewRequest cannot infer the length of an arbitrary body. Carry the
	// client's value so the request is not needlessly sent chunked.
	proxyReq.ContentLength = r.ContentLength

	// Forward end-to-end request headers only.
	for key, values := range r.Header {
		if s2aHopByHopHeaders[key] {
			continue
		}
		for _, value := range values {
			proxyReq.Header.Add(key, value)
		}
	}
	// The stored admin key always replaces any credentials the client sent.
	proxyReq.Header.Set("Authorization", "Bearer "+config.Global.S2AAdminKey)
	proxyReq.Header.Set("Content-Type", "application/json")

	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(proxyReq)
	if err != nil {
		api.Error(w, http.StatusBadGateway, fmt.Sprintf("请求 S2A 失败: %v", err))
		return
	}
	defer resp.Body.Close()

	// NOTE(review): Add appends to headers already set on w (e.g. by
	// api.CORS). Duplicate Access-Control-* values from S2A would make
	// browsers reject the response — verify.
	for key, values := range resp.Header {
		if s2aHopByHopHeaders[key] {
			continue
		}
		for _, value := range values {
			w.Header().Add(key, value)
		}
	}

	w.WriteHeader(resp.StatusCode)
	// The status line has already been sent, so a copy failure (client gone,
	// upstream reset) cannot be reported to the caller any more.
	_, _ = io.Copy(w, resp.Body)
}
// handleMailServices serves /api/mail/services.
//
//   - GET lists the services from mail.GetServices. The API token is never
//     exposed; it is replaced by a has_token flag.
//   - POST is not implemented yet and returns 501.
//   - Any other method returns 405.
func handleMailServices(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		services := mail.GetServices()
		// Non-nil even when empty, so the JSON payload is [] rather than null.
		masked := make([]map[string]interface{}, 0, len(services))
		for _, svc := range services {
			masked = append(masked, map[string]interface{}{
				"name":      svc.Name,
				"api_base":  svc.APIBase,
				"has_token": svc.APIToken != "",
				"domain":    svc.Domain,
			})
		}
		api.Success(w, masked)
	case http.MethodPost:
		api.Error(w, http.StatusNotImplemented, "更新邮箱服务配置暂未实现")
	default:
		api.Error(w, http.StatusMethodNotAllowed, "不支持的方法")
	}
}
// handleTestMailService always replies that the mail service test succeeded.
// NOTE(review): this is a stub. It contacts no mail service and ignores the
// request entirely.
func handleTestMailService(w http.ResponseWriter, r *http.Request) {
	payload := map[string]interface{}{
		"message":   "邮箱服务测试成功",
		"connected": true,
	}
	api.Success(w, payload)
}
// handleGetOwners returns one page of team owners together with the total
// count reported by the database.
// It answers 500 if the database is not initialized or the query fails.
func handleGetOwners(w http.ResponseWriter, r *http.Request) {
	db := database.Instance
	if db == nil {
		api.Error(w, http.StatusInternalServerError, "数据库未初始化")
		return
	}

	// Fixed arguments ("", 50, 0). Presumably: no filter, limit 50,
	// offset 0 — confirm against GetTeamOwners.
	owners, total, queryErr := db.GetTeamOwners("", 50, 0)
	if queryErr != nil {
		api.Error(w, http.StatusInternalServerError, fmt.Sprintf("查询失败: %v", queryErr))
		return
	}

	page := map[string]interface{}{
		"owners": owners,
		"total":  total,
	}
	api.Success(w, page)
}
// handleGetOwnerStats replies with whatever database.Instance.GetOwnerStats
// returns. It answers 500 if the database is not initialized.
func handleGetOwnerStats(w http.ResponseWriter, r *http.Request) {
	db := database.Instance
	if db == nil {
		api.Error(w, http.StatusInternalServerError, "数据库未初始化")
		return
	}
	api.Success(w, db.GetOwnerStats())
}
// handleClearOwners deletes all team owners via ClearTeamOwners. It answers
// 500 if the database is not initialized or the deletion fails.
// NOTE(review): the HTTP method is not checked, so a plain GET (e.g. a link
// prefetch) is enough to wipe the table. Consider requiring POST/DELETE.
func handleClearOwners(w http.ResponseWriter, r *http.Request) {
	db := database.Instance
	if db == nil {
		api.Error(w, http.StatusInternalServerError, "数据库未初始化")
		return
	}

	if clearErr := db.ClearTeamOwners(); clearErr != nil {
		api.Error(w, http.StatusInternalServerError, fmt.Sprintf("清空失败: %v", clearErr))
		return
	}
	api.Success(w, map[string]string{"message": "已清空"})
}
// handleRegisterTest handles POST /api/register/test. It runs one complete
// registration with a freshly generated email, password, name and birthdate,
// then returns the credentials and the resulting access token.
//
// The optional JSON body {"proxy": "..."} overrides config.Global.DefaultProxy.
// An empty body is accepted, while a malformed body yields 400. Non-POST
// methods yield 405, and a failed registration yields 500.
func handleRegisterTest(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		api.Error(w, http.StatusMethodNotAllowed, "仅支持 POST")
		return
	}

	var req struct {
		Proxy string `json:"proxy"`
	}
	// The body is optional, so io.EOF (empty body) is fine. Any other decode
	// error used to be ignored silently, letting a malformed request run a
	// full registration with defaults; it is now rejected.
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil && !errors.Is(err, io.EOF) {
		api.Error(w, http.StatusBadRequest, "请求格式错误")
		return
	}

	// Fall back to the configured default proxy when none was given.
	proxy := req.Proxy
	if proxy == "" && config.Global != nil {
		proxy = config.Global.DefaultProxy
	}

	// Generate the test identity.
	email := mail.GenerateEmail()
	password := register.GeneratePassword()
	name := register.GenerateName()
	birthdate := register.GenerateBirthdate()
	logger.Info(fmt.Sprintf("开始注册测试: %s", email), email, "register")

	// Run the registration.
	reg, err := register.Run(email, password, name, birthdate, proxy)
	if err != nil {
		logger.Error(fmt.Sprintf("注册失败: %v", err), email, "register")
		api.Error(w, http.StatusInternalServerError, fmt.Sprintf("注册失败: %v", err))
		return
	}
	logger.Success(fmt.Sprintf("注册成功: %s", email), email, "register")

	// The plaintext password is returned on purpose: this is a test endpoint.
	api.Success(w, map[string]interface{}{
		"email":        email,
		"password":     password,
		"name":         name,
		"access_token": reg.AccessToken,
	})
}
// getOutboundIP returns this host's preferred outbound IP address as text, or
// "" if none can be determined.
//
// First it "dials" 8.8.8.8:80 over UDP. Connecting a UDP socket sends no
// packets, but makes the OS choose the local address it would route from. If
// that fails (for example, there is no default route), it falls back to the
// first non-loopback IPv4 address among the interface addresses.
func getOutboundIP() string {
	if conn, dialErr := net.Dial("udp", "8.8.8.8:80"); dialErr == nil {
		defer conn.Close()
		return conn.LocalAddr().(*net.UDPAddr).IP.String()
	}

	addrs, err := net.InterfaceAddrs()
	if err != nil {
		return ""
	}
	for _, addr := range addrs {
		ipNet, isIPNet := addr.(*net.IPNet)
		if !isIPNet || ipNet.IP.IsLoopback() || ipNet.IP.To4() == nil {
			continue
		}
		return ipNet.IP.String()
	}
	return ""
}