package db import ( "database/sql" "fmt" "log" "os" "time" _ "github.com/lib/pq" ) var DB *sql.DB // Config 数据库配置 type Config struct { Host string Port int User string Password string DBName string SSLMode string } // DefaultConfig 从环境变量读取默认配置 func DefaultConfig() Config { return Config{ Host: getEnv("DB_HOST", "localhost"), Port: getEnvInt("DB_PORT", 5432), User: getEnv("DB_USER", "postgres"), Password: getEnv("DB_PASSWORD", ""), DBName: getEnv("DB_NAME", "gpt_manager"), SSLMode: getEnv("DB_SSLMODE", "disable"), } } // Connect 连接数据库 func Connect(cfg Config) (*sql.DB, error) { dsn := fmt.Sprintf( "host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.DBName, cfg.SSLMode, ) db, err := sql.Open("postgres", dsn) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } // 配置连接池 db.SetMaxOpenConns(25) db.SetMaxIdleConns(5) db.SetConnMaxLifetime(5 * time.Minute) // 测试连接 if err := db.Ping(); err != nil { return nil, fmt.Errorf("failed to ping database: %w", err) } log.Println("Database connected successfully") return db, nil } // Init 初始化全局数据库连接 func Init() error { cfg := DefaultConfig() db, err := Connect(cfg) if err != nil { return err } DB = db return nil } // Close 关闭数据库连接 func Close() error { if DB != nil { return DB.Close() } return nil } // getEnv 获取环境变量,带默认值 func getEnv(key, defaultVal string) string { if val := os.Getenv(key); val != "" { return val } return defaultVal } // getEnvInt 获取整数类型环境变量 func getEnvInt(key string, defaultVal int) int { if val := os.Getenv(key); val != "" { var i int if _, err := fmt.Sscanf(val, "%d", &i); err == nil { return i } } return defaultVal }