Files
codexautopool/backend/cmd/main.go

477 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"

	"codex-pool/internal/api"
	"codex-pool/internal/config"
	"codex-pool/internal/database"
	"codex-pool/internal/logger"
	"codex-pool/internal/mail"
	"codex-pool/internal/register"
	"codex-pool/internal/web"
)
// main is the process entry point. It prints the startup banner, prepares the
// data directory, initialises the SQLite database (which must happen before
// configuration, since config.InitFromDB reads from it), sets up the mail
// services, prints a configuration summary and finally hands off to
// startServer, which blocks.
func main() {
	// ANSI escape sequences used to colour console output.
	const (
		ansiReset  = "\033[0m"
		ansiRed    = "\033[31m"
		ansiGreen  = "\033[32m"
		ansiYellow = "\033[33m"
		ansiCyan   = "\033[36m"
		ansiGray   = "\033[90m"
		ansiBold   = "\033[1m"
	)

	// Startup banner.
	fmt.Printf("%s%s============================================================%s\n", ansiBold, ansiCyan, ansiReset)
	fmt.Printf("%s%s Codex Pool - HTTP API Server%s\n", ansiBold, ansiCyan, ansiReset)
	fmt.Printf("%s%s============================================================%s\n\n", ansiBold, ansiCyan, ansiReset)

	// Persistent data lives under ./data; if that directory cannot be
	// created, fall back to the current working directory.
	dataDir := "data"
	if mkErr := os.MkdirAll(dataDir, 0755); mkErr != nil {
		fmt.Printf("%s[WARN]%s 创建数据目录失败: %v, 使用当前目录\n", ansiYellow, ansiReset, mkErr)
		dataDir = "."
	}

	// The database is opened first because configuration is stored in it.
	dbPath := filepath.Join(dataDir, "codex-pool.db")
	if initErr := database.Init(dbPath); initErr != nil {
		fmt.Printf("%s[ERROR]%s 数据库初始化失败: %v\n", ansiRed, ansiReset, initErr)
		os.Exit(1)
	}

	// Point the config package at the database, then load the configuration.
	config.SetConfigDB(database.Instance)
	cfg := config.InitFromDB()

	// Mail services are only initialised when at least one is configured.
	if n := len(cfg.MailServices); n > 0 {
		mail.Init(cfg.MailServices)
		fmt.Printf("%s[邮箱]%s 已加载 %d 个邮箱服务\n", ansiGreen, ansiReset, n)
	}

	// Configuration summary.
	fmt.Printf("%s[配置]%s 数据库: %s\n", ansiGray, ansiReset, dbPath)
	fmt.Printf("%s[配置]%s 端口: %d\n", ansiGray, ansiReset, cfg.Port)

	switch {
	case cfg.S2AApiBase == "":
		fmt.Printf("%s[配置]%s S2A API: %s未配置%s (请在Web界面配置)\n", ansiGray, ansiReset, ansiYellow, ansiReset)
	default:
		fmt.Printf("%s[配置]%s S2A API: %s\n", ansiGray, ansiReset, cfg.S2AApiBase)
	}

	switch {
	case !cfg.ProxyEnabled:
		fmt.Printf("%s[配置]%s 代理: 已禁用\n", ansiGray, ansiReset)
	default:
		fmt.Printf("%s[配置]%s 代理: %s (已启用)\n", ansiGray, ansiReset, cfg.DefaultProxy)
	}

	switch {
	case !web.IsEmbedded():
		fmt.Printf("%s[前端]%s 开发模式 (未嵌入)\n", ansiYellow, ansiReset)
	default:
		fmt.Printf("%s[前端]%s 嵌入模式\n", ansiGreen, ansiReset)
	}
	fmt.Println()

	// Blocks until the HTTP server fails.
	startServer(cfg)
}
// startServer registers every HTTP route on a fresh ServeMux and serves it on
// cfg.Port. It blocks until the listener returns an error, at which point the
// process exits with status 1.
//
// All /api/* routes are wrapped with api.CORS. When the frontend is embedded
// (web.IsEmbedded), "/" serves the bundled static files with an SPA fallback
// so client-side routes resolve to index.html.
func startServer(cfg *config.Config) {
	mux := http.NewServeMux()

	// Basic API
	mux.HandleFunc("/api/health", api.CORS(handleHealth))
	mux.HandleFunc("/api/config", api.CORS(handleConfig))

	// Log API
	mux.HandleFunc("/api/logs", api.CORS(handleGetLogs))
	mux.HandleFunc("/api/logs/clear", api.CORS(handleClearLogs))

	// S2A proxy API
	mux.HandleFunc("/api/s2a/test", api.CORS(handleS2ATest))
	mux.HandleFunc("/api/s2a/proxy/", api.CORS(handleS2AProxy)) // wildcard proxy

	// Mail service API
	mux.HandleFunc("/api/mail/services", api.CORS(handleMailServices))
	mux.HandleFunc("/api/mail/services/test", api.CORS(handleTestMailService))

	// Team Owner API
	mux.HandleFunc("/api/db/owners", api.CORS(handleGetOwners))
	mux.HandleFunc("/api/db/owners/stats", api.CORS(handleGetOwnerStats))
	mux.HandleFunc("/api/db/owners/clear", api.CORS(handleClearOwners))
	mux.HandleFunc("/api/upload/validate", api.CORS(api.HandleUploadValidate))

	// Registration test API
	mux.HandleFunc("/api/register/test", api.CORS(handleRegisterTest))

	// Team batch-processing API
	mux.HandleFunc("/api/team/process", api.CORS(api.HandleTeamProcess))
	mux.HandleFunc("/api/team/status", api.CORS(api.HandleTeamProcessStatus))
	mux.HandleFunc("/api/team/stop", api.CORS(api.HandleTeamProcessStop))

	// Embedded frontend static files.
	if web.IsEmbedded() {
		fileServer := http.FileServer(web.GetFileSystem())
		mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
			// Unknown /api/ paths must 404 rather than fall through to the SPA.
			if strings.HasPrefix(r.URL.Path, "/api/") {
				http.NotFound(w, r)
				return
			}
			// SPA routing: a path without a dot is treated as a client-side
			// route and served index.html via "/". Work on a clone because a
			// handler must not modify the *http.Request it was given.
			if p := r.URL.Path; p != "/" && !strings.Contains(p, ".") {
				r = r.Clone(r.Context())
				r.URL.Path = "/"
				r.URL.RawPath = ""
			}
			fileServer.ServeHTTP(w, r)
		})
	}

	addr := fmt.Sprintf(":%d", cfg.Port)

	// ANSI colour codes
	colorReset := "\033[0m"
	colorGreen := "\033[32m"
	colorCyan := "\033[36m"

	// Print the URLs the server can be reached at.
	fmt.Printf("%s[服务]%s 启动于:\n", colorGreen, colorReset)
	fmt.Printf(" - 本地: %shttp://localhost:%d%s\n", colorCyan, cfg.Port, colorReset)
	if ip := getOutboundIP(); ip != "" {
		fmt.Printf(" - 外部: %shttp://%s:%d%s\n", colorCyan, ip, cfg.Port, colorReset)
	}
	fmt.Println()

	server := &http.Server{
		Addr:    addr,
		Handler: mux,
		// Bound how long a client may take to send request headers
		// (slow-loris protection). No overall Read/WriteTimeout is set
		// because some handlers (e.g. /api/register/test, which calls
		// register.Run synchronously) may legitimately run for a long time.
		ReadHeaderTimeout: 10 * time.Second,
		// Reclaim idle keep-alive connections.
		IdleTimeout: 120 * time.Second,
	}
	if err := server.ListenAndServe(); err != nil {
		fmt.Printf("\033[31m[ERROR]\033[0m 服务启动失败: %v\n", err)
		os.Exit(1)
	}
}
// ==================== API 处理器 ====================
// handleHealth is a liveness probe. Regardless of method or input it replies
// through api.Success with the payload {"status": "ok"}.
func handleHealth(w http.ResponseWriter, r *http.Request) {
	payload := map[string]string{"status": "ok"}
	api.Success(w, payload)
}
// handleConfig serves /api/config.
//
//	GET – returns the current in-memory configuration (config.Global).
//	PUT – partial update: only fields present in the JSON body are changed
//	      on config.Global, which is then persisted via config.Update.
//
// Both methods answer 500 if the configuration has not been loaded. Any other
// method yields 405.
func handleConfig(w http.ResponseWriter, r *http.Request) {
	// Upper bound on the PUT body. The payload is a handful of scalar fields,
	// so anything larger is rejected as a malformed request.
	const maxConfigBodyBytes = 1 << 20

	switch r.Method {
	case http.MethodGet:
		cfg := config.Global
		if cfg == nil {
			api.Error(w, http.StatusInternalServerError, "配置未加载")
			return
		}
		// NOTE(review): this returns the raw S2A admin key and the full mail
		// service list (handleMailServices deliberately hides tokens, this
		// does not). Kept for frontend compatibility; consider masking.
		api.Success(w, map[string]interface{}{
			"port":                cfg.Port,
			"s2a_api_base":        cfg.S2AApiBase,
			"s2a_admin_key":       cfg.S2AAdminKey,
			"has_admin_key":       cfg.S2AAdminKey != "",
			"concurrency":         cfg.Concurrency,
			"priority":            cfg.Priority,
			"group_ids":           cfg.GroupIDs,
			"proxy_enabled":       cfg.ProxyEnabled,
			"default_proxy":       cfg.DefaultProxy,
			"mail_services_count": len(cfg.MailServices),
			"mail_services":       cfg.MailServices,
		})

	case http.MethodPut:
		// Without this check a PUT before config load would dereference nil.
		cfg := config.Global
		if cfg == nil {
			api.Error(w, http.StatusInternalServerError, "配置未加载")
			return
		}

		// Pointer fields distinguish "absent" from "set to zero value".
		var req struct {
			S2AApiBase   *string `json:"s2a_api_base"`
			S2AAdminKey  *string `json:"s2a_admin_key"`
			Concurrency  *int    `json:"concurrency"`
			Priority     *int    `json:"priority"`
			GroupIDs     []int   `json:"group_ids"`
			ProxyEnabled *bool   `json:"proxy_enabled"`
			DefaultProxy *string `json:"default_proxy"`
		}
		body := http.MaxBytesReader(w, r.Body, maxConfigBodyBytes)
		if err := json.NewDecoder(body).Decode(&req); err != nil {
			api.Error(w, http.StatusBadRequest, "请求格式错误")
			return
		}

		// Apply only the fields that were supplied. This mutates the shared
		// config.Global in place (no locking is visible here).
		if req.S2AApiBase != nil {
			cfg.S2AApiBase = *req.S2AApiBase
		}
		if req.S2AAdminKey != nil {
			cfg.S2AAdminKey = *req.S2AAdminKey
		}
		if req.Concurrency != nil {
			cfg.Concurrency = *req.Concurrency
		}
		if req.Priority != nil {
			cfg.Priority = *req.Priority
		}
		if req.GroupIDs != nil {
			cfg.GroupIDs = req.GroupIDs
		}
		if req.ProxyEnabled != nil {
			cfg.ProxyEnabled = *req.ProxyEnabled
		}
		if req.DefaultProxy != nil {
			cfg.DefaultProxy = *req.DefaultProxy
		}

		// Persist to the database (takes effect immediately).
		if err := config.Update(cfg); err != nil {
			api.Error(w, http.StatusInternalServerError, fmt.Sprintf("保存配置失败: %v", err))
			return
		}
		logger.Success("配置已更新并保存到数据库", "", "config")
		api.Success(w, map[string]string{"message": "配置已更新"})

	default:
		api.Error(w, http.StatusMethodNotAllowed, "不支持的方法")
	}
}
func handleGetLogs(w http.ResponseWriter, r *http.Request) {
logs := logger.GetLogs(100)
api.Success(w, logs)
}
// handleClearLogs empties the in-process log buffer via logger.ClearLogs and
// acknowledges with a confirmation message.
//
// NOTE(review): any HTTP method (including GET) triggers the clear.
func handleClearLogs(w http.ResponseWriter, r *http.Request) {
	logger.ClearLogs()
	ack := map[string]string{"message": "日志已清空"}
	api.Success(w, ack)
}
// handleS2ATest reports whether an S2A API base URL is configured.
// It does not contact the S2A service; "connected" is always true when the
// base URL is set. Responds 400 when the config or base URL is missing.
func handleS2ATest(w http.ResponseWriter, r *http.Request) {
	cfg := config.Global
	configured := cfg != nil && cfg.S2AApiBase != ""
	if !configured {
		api.Error(w, http.StatusBadRequest, "S2A 配置未设置")
		return
	}

	result := map[string]interface{}{
		"connected": true,
		"message":   "S2A 配置已就绪",
	}
	api.Success(w, result)
}
// s2aProxyClient is shared by all proxied S2A requests so keep-alive
// connections to the upstream are reused, rather than building a new client
// (and connection pool) on every request.
var s2aProxyClient = &http.Client{Timeout: 30 * time.Second}

// handleS2AProxy forwards requests under /api/s2a/proxy/ to the configured
// S2A API, authenticating with the configured admin key.
//
// Path mapping (after stripping "/api/s2a/proxy"):
//   - paths already starting with /api/ are forwarded unchanged
//     (e.g. /api/pool/polling);
//   - everything else is prefixed with /api/v1/admin (e.g. dashboard stats).
//
// The query string is forwarded verbatim. The upstream status, headers and
// body are copied back to the client. Responds 400 if S2A is not configured
// and 502 if the upstream cannot be reached or its body cannot be read.
func handleS2AProxy(w http.ResponseWriter, r *http.Request) {
	// Snapshot the config pointer so all reads below are consistent.
	cfg := config.Global
	if cfg == nil || cfg.S2AApiBase == "" || cfg.S2AAdminKey == "" {
		api.Error(w, http.StatusBadRequest, "S2A 配置未设置")
		return
	}

	path := strings.TrimPrefix(r.URL.Path, "/api/s2a/proxy")
	targetPath := path
	if !strings.HasPrefix(path, "/api/") {
		targetPath = "/api/v1/admin" + path
	}

	// Trim a trailing slash on the base so we never produce "//api/...".
	targetURL := strings.TrimRight(cfg.S2AApiBase, "/") + targetPath
	if r.URL.RawQuery != "" {
		targetURL += "?" + r.URL.RawQuery
	}

	logger.Info(fmt.Sprintf("S2A Proxy: %s -> %s", r.URL.Path, targetURL), "", "proxy")

	// Tie the upstream call to the client's request so a disconnect cancels it.
	proxyReq, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL, r.Body)
	if err != nil {
		api.Error(w, http.StatusInternalServerError, "创建请求失败")
		return
	}
	// Preserve the inbound length so bodies are not re-sent chunked.
	proxyReq.ContentLength = r.ContentLength

	// The exact auth header S2A expects is not known here, so the key is
	// sent in several common forms.
	adminKey := cfg.S2AAdminKey
	// Log only the key length; never write any part of the secret to logs.
	logger.Info(fmt.Sprintf("Using admin key (len=%d)", len(adminKey)), "", "proxy")
	proxyReq.Header.Set("Authorization", "Bearer "+adminKey)
	proxyReq.Header.Set("X-API-Key", adminKey)
	proxyReq.Header.Set("X-Admin-Key", adminKey)
	proxyReq.Header.Set("Content-Type", "application/json")
	proxyReq.Header.Set("Accept", "application/json")

	resp, err := s2aProxyClient.Do(proxyReq)
	if err != nil {
		logger.Error(fmt.Sprintf("S2A 请求失败: %v", err), "", "proxy")
		api.Error(w, http.StatusBadGateway, fmt.Sprintf("请求 S2A 失败: %v", err))
		return
	}
	defer resp.Body.Close()

	// Buffer the body so a summary can be logged before relaying it.
	bodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		logger.Error(fmt.Sprintf("读取响应失败: %v", err), "", "proxy")
		api.Error(w, http.StatusBadGateway, "读取响应失败")
		return
	}

	logger.Info(fmt.Sprintf("S2A 响应: status=%d, len=%d, body=%s",
		resp.StatusCode, len(bodyBytes), string(bodyBytes[:min(200, len(bodyBytes))])), "", "proxy")

	// Copy upstream headers.
	// NOTE(review): if the upstream sends its own Access-Control-* headers they
	// are added alongside whatever api.CORS set — confirm no duplicates.
	for key, values := range resp.Header {
		for _, value := range values {
			w.Header().Add(key, value)
		}
	}

	w.WriteHeader(resp.StatusCode)
	w.Write(bodyBytes)
}
// handleMailServices serves /api/mail/services.
//
//	GET  – lists configured mail services without exposing tokens; each entry
//	       reports only whether a token is present ("has_token").
//	POST – not implemented yet (501).
//
// Any other method yields 405.
func handleMailServices(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		services := mail.GetServices()
		listing := make([]map[string]interface{}, 0, len(services))
		for _, svc := range services {
			listing = append(listing, map[string]interface{}{
				"name":      svc.Name,
				"api_base":  svc.APIBase,
				"has_token": svc.APIToken != "",
				"domain":    svc.Domain,
			})
		}
		api.Success(w, listing)
	case http.MethodPost:
		api.Error(w, http.StatusNotImplemented, "更新邮箱服务配置暂未实现")
	default:
		api.Error(w, http.StatusMethodNotAllowed, "不支持的方法")
	}
}
// handleTestMailService is a placeholder: it performs no actual check and
// always reports the mail service as connected.
func handleTestMailService(w http.ResponseWriter, r *http.Request) {
	result := map[string]interface{}{
		"connected": true,
		"message":   "邮箱服务测试成功",
	}
	api.Success(w, result)
}
// handleGetOwners returns team owners from the database together with the
// total count reported by GetTeamOwners.
//
// The arguments ("", 50, 0) are fixed; presumably they are a filter, a page
// size and an offset — confirm against database.GetTeamOwners.
// Responds 500 if the database is not initialised or the query fails.
func handleGetOwners(w http.ResponseWriter, r *http.Request) {
	db := database.Instance
	if db == nil {
		api.Error(w, http.StatusInternalServerError, "数据库未初始化")
		return
	}

	owners, total, queryErr := db.GetTeamOwners("", 50, 0)
	if queryErr != nil {
		api.Error(w, http.StatusInternalServerError, fmt.Sprintf("查询失败: %v", queryErr))
		return
	}

	payload := map[string]interface{}{
		"owners": owners,
		"total":  total,
	}
	api.Success(w, payload)
}
// handleGetOwnerStats responds with the value of GetOwnerStats from the
// database, or 500 if the database is not initialised.
func handleGetOwnerStats(w http.ResponseWriter, r *http.Request) {
	db := database.Instance
	if db == nil {
		api.Error(w, http.StatusInternalServerError, "数据库未初始化")
		return
	}
	api.Success(w, db.GetOwnerStats())
}
// handleClearOwners deletes team owners via ClearTeamOwners. It responds 500
// if the database is not initialised or the clear fails.
//
// NOTE(review): this is destructive and accepts any HTTP method, including GET.
func handleClearOwners(w http.ResponseWriter, r *http.Request) {
	db := database.Instance
	if db == nil {
		api.Error(w, http.StatusInternalServerError, "数据库未初始化")
		return
	}

	clearErr := db.ClearTeamOwners()
	if clearErr != nil {
		api.Error(w, http.StatusInternalServerError, fmt.Sprintf("清空失败: %v", clearErr))
		return
	}

	ack := map[string]string{"message": "已清空"}
	api.Success(w, ack)
}
// handleRegisterTest handles POST /api/register/test. It runs a single
// registration with freshly generated email, password, name and birthdate,
// and returns the resulting credentials and access token.
//
// The request body is optional: {"proxy": "..."}. An empty body or empty
// proxy falls back to config.Global.DefaultProxy. Malformed JSON returns 400,
// a non-POST method returns 405, and a registration failure returns 500.
func handleRegisterTest(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		api.Error(w, http.StatusMethodNotAllowed, "仅支持 POST")
		return
	}

	var req struct {
		Proxy string `json:"proxy"`
	}
	// An empty body (io.EOF) is allowed. Malformed JSON is rejected rather
	// than silently ignored, which would otherwise run with the default proxy.
	body := http.MaxBytesReader(w, r.Body, 1<<20)
	if err := json.NewDecoder(body).Decode(&req); err != nil && !errors.Is(err, io.EOF) {
		api.Error(w, http.StatusBadRequest, "请求格式错误")
		return
	}

	// Fall back to the configured default proxy.
	proxy := req.Proxy
	if proxy == "" && config.Global != nil {
		proxy = config.Global.DefaultProxy
	}

	// Generate test identity data.
	email := mail.GenerateEmail()
	password := register.GeneratePassword()
	name := register.GenerateName()
	birthdate := register.GenerateBirthdate()

	logger.Info(fmt.Sprintf("开始注册测试: %s", email), email, "register")

	// Runs synchronously; the HTTP response waits for the whole flow.
	reg, err := register.Run(email, password, name, birthdate, proxy)
	if err != nil {
		logger.Error(fmt.Sprintf("注册失败: %v", err), email, "register")
		api.Error(w, http.StatusInternalServerError, fmt.Sprintf("注册失败: %v", err))
		return
	}

	logger.Success(fmt.Sprintf("注册成功: %s", email), email, "register")

	// NOTE(review): the plaintext password and access token are returned to
	// the caller; acceptable for a test endpoint, but keep it non-public.
	api.Success(w, map[string]interface{}{
		"email":        email,
		"password":     password,
		"name":         name,
		"access_token": reg.AccessToken,
	})
}
// getOutboundIP returns a best-effort guess at this host's outward-facing IP
// address. It is used only to print an "external" access URL at startup.
// It returns "" when nothing suitable is found.
//
// Strategy 1: dial UDP to a public address. A UDP dial sends no packets; it
// only makes the OS choose the outgoing route, whose local address is read
// back from the socket.
// Strategy 2: scan interface addresses for the first IPv4 address that is
// neither loopback nor link-local (169.254.x.x would be a misleading URL).
func getOutboundIP() string {
	if conn, err := net.Dial("udp", "8.8.8.8:80"); err == nil {
		udpAddr, ok := conn.LocalAddr().(*net.UDPAddr)
		conn.Close()
		if ok && len(udpAddr.IP) > 0 && !udpAddr.IP.IsUnspecified() {
			return udpAddr.IP.String()
		}
	}

	addrs, err := net.InterfaceAddrs()
	if err != nil {
		return ""
	}
	for _, addr := range addrs {
		ipnet, ok := addr.(*net.IPNet)
		if !ok {
			continue
		}
		ip := ipnet.IP
		if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.To4() == nil {
			continue
		}
		return ip.String()
	}
	return ""
}