package database import ( "database/sql" "fmt" "log" "time" _ "github.com/lib/pq" "go-helper/internal/model" ) // DB wraps the sql.DB connection and provides typed query methods. type DB struct { *sql.DB } // New opens a PostgreSQL connection and auto-creates tables. func New(databaseURL string) *DB { conn, err := sql.Open("postgres", databaseURL) if err != nil { log.Fatalf("[DB] 无法连接数据库: %v", err) } if err := conn.Ping(); err != nil { log.Fatalf("[DB] 数据库连接测试失败: %v", err) } conn.SetMaxOpenConns(10) conn.SetMaxIdleConns(5) conn.SetConnMaxLifetime(5 * time.Minute) d := &DB{conn} d.migrate() log.Println("[DB] 数据库初始化完成") return d } func (d *DB) migrate() { queries := []string{ `CREATE TABLE IF NOT EXISTS gpt_accounts ( id BIGSERIAL PRIMARY KEY, email TEXT NOT NULL, token TEXT NOT NULL, refresh_token TEXT DEFAULT '', user_count INT DEFAULT 0, invite_count INT DEFAULT 0, chatgpt_account_id TEXT DEFAULT '', oai_device_id TEXT DEFAULT '', expire_at TEXT DEFAULT '', is_open BOOLEAN DEFAULT TRUE, is_banned BOOLEAN DEFAULT FALSE, created_at TIMESTAMP DEFAULT NOW(), updated_at TIMESTAMP DEFAULT NOW() )`, `CREATE TABLE IF NOT EXISTS redemption_codes ( id BIGSERIAL PRIMARY KEY, code TEXT UNIQUE NOT NULL, is_redeemed BOOLEAN DEFAULT FALSE, redeemed_at TEXT, redeemed_by TEXT, account_email TEXT DEFAULT '', channel TEXT DEFAULT 'common', created_at TIMESTAMP DEFAULT NOW() )`, `CREATE TABLE IF NOT EXISTS telegram_admins ( id BIGSERIAL PRIMARY KEY, user_id BIGINT UNIQUE NOT NULL, added_by BIGINT NOT NULL, created_at TIMESTAMP DEFAULT NOW() )`, } for _, q := range queries { if _, err := d.Exec(q); err != nil { log.Fatalf("[DB] 建表失败: %v", err) } } } // --------------- GptAccount CRUD --------------- // GetAllAccounts returns all accounts ordered by creation time desc. func (d *DB) GetAllAccounts() ([]model.GptAccount, error) { rows, err := d.Query(` SELECT id, email, token, refresh_token, user_count, invite_count, chatgpt_account_id, oai_device_id, expire_at, is_open, is_banned, created_at, updated_at FROM gpt_accounts ORDER BY created_at DESC`) if err != nil { return nil, err } defer rows.Close() return scanAccounts(rows) } // GetOpenAccounts returns non-banned, open accounts that still have capacity. func (d *DB) GetOpenAccounts(capacity int) ([]model.GptAccount, error) { rows, err := d.Query(` SELECT id, email, token, refresh_token, user_count, invite_count, chatgpt_account_id, oai_device_id, expire_at, is_open, is_banned, created_at, updated_at FROM gpt_accounts WHERE is_open = TRUE AND is_banned = FALSE AND (user_count + invite_count) < $1 AND token != '' AND chatgpt_account_id != '' ORDER BY (user_count + invite_count) ASC, RANDOM()`, capacity) if err != nil { return nil, err } defer rows.Close() return scanAccounts(rows) } // GetAccountByID fetches a single account by its ID. func (d *DB) GetAccountByID(id int64) (*model.GptAccount, error) { row := d.QueryRow(` SELECT id, email, token, refresh_token, user_count, invite_count, chatgpt_account_id, oai_device_id, expire_at, is_open, is_banned, created_at, updated_at FROM gpt_accounts WHERE id = $1 LIMIT 1`, id) return scanAccount(row) } // GetAccountByChatGPTAccountID fetches a single account by its OpenAI account ID. func (d *DB) GetAccountByChatGPTAccountID(accountID string) (*model.GptAccount, error) { row := d.QueryRow(` SELECT id, email, token, refresh_token, user_count, invite_count, chatgpt_account_id, oai_device_id, expire_at, is_open, is_banned, created_at, updated_at FROM gpt_accounts WHERE chatgpt_account_id = $1 LIMIT 1`, accountID) return scanAccount(row) } // GetAccountByEmail fetches a single account by email (case-insensitive). func (d *DB) GetAccountByEmail(email string) (*model.GptAccount, error) { row := d.QueryRow(` SELECT id, email, token, refresh_token, user_count, invite_count, chatgpt_account_id, oai_device_id, expire_at, is_open, is_banned, created_at, updated_at FROM gpt_accounts WHERE lower(email) = lower($1) LIMIT 1`, email) return scanAccount(row) } // GetAccountsWithRT returns accounts that have a refresh token and are not banned. func (d *DB) GetAccountsWithRT() ([]model.GptAccount, error) { rows, err := d.Query(` SELECT id, email, token, refresh_token, user_count, invite_count, chatgpt_account_id, oai_device_id, expire_at, is_open, is_banned, created_at, updated_at FROM gpt_accounts WHERE refresh_token != '' AND is_banned = FALSE ORDER BY id`) if err != nil { return nil, err } defer rows.Close() return scanAccounts(rows) } // CreateAccount inserts a new account and returns its ID. func (d *DB) CreateAccount(a *model.GptAccount) (int64, error) { var id int64 err := d.QueryRow(` INSERT INTO gpt_accounts (email, token, refresh_token, user_count, invite_count, chatgpt_account_id, oai_device_id, expire_at, is_open, is_banned) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10) RETURNING id`, a.Email, a.Token, a.RefreshToken, a.UserCount, a.InviteCount, a.ChatgptAccountID, a.OaiDeviceID, a.ExpireAt, a.IsOpen, a.IsBanned, ).Scan(&id) return id, err } // UpdateAccountTokens updates the access token and refresh token. func (d *DB) UpdateAccountTokens(id int64, accessToken, refreshToken string) error { _, err := d.Exec(` UPDATE gpt_accounts SET token = $1, refresh_token = $2, updated_at = NOW() WHERE id = $3`, accessToken, refreshToken, id) return err } // UpdateAccountEmail updates the email of an account and its associated codes. func (d *DB) UpdateAccountEmail(id int64, oldEmail, newEmail string) error { tx, err := d.Begin() if err != nil { return err } defer tx.Rollback() _, err = tx.Exec(` UPDATE gpt_accounts SET email = $1, updated_at = NOW() WHERE id = $2`, newEmail, id) if err != nil { return err } _, err = tx.Exec(` UPDATE redemption_codes SET account_email = $1 WHERE account_email = $2`, newEmail, oldEmail) if err != nil { return err } return tx.Commit() } // UpdateAccountCounts updates user_count and invite_count. func (d *DB) UpdateAccountCounts(id int64, userCount, inviteCount int) error { _, err := d.Exec(` UPDATE gpt_accounts SET user_count = $1, invite_count = $2, updated_at = NOW() WHERE id = $3`, userCount, inviteCount, id) return err } // UpdateAccountInfo updates chatgpt_account_id, expire_at, etc. from API data. func (d *DB) UpdateAccountInfo(id int64, chatgptAccountID, expireAt string) error { _, err := d.Exec(` UPDATE gpt_accounts SET chatgpt_account_id = $1, expire_at = $2, updated_at = NOW() WHERE id = $3`, chatgptAccountID, expireAt, id) return err } // BanAccount marks an account as banned and not open. func (d *DB) BanAccount(id int64) error { _, err := d.Exec(` UPDATE gpt_accounts SET is_banned = TRUE, is_open = FALSE, updated_at = NOW() WHERE id = $1`, id) return err } // --------------- RedemptionCode CRUD --------------- // GetCodeByCode fetches a redemption code by its code string. func (d *DB) GetCodeByCode(code string) (*model.RedemptionCode, error) { row := d.QueryRow(` SELECT id, code, is_redeemed, redeemed_at, redeemed_by, account_email, channel, created_at FROM redemption_codes WHERE code = $1`, code) var rc model.RedemptionCode err := row.Scan(&rc.ID, &rc.Code, &rc.IsRedeemed, &rc.RedeemedAt, &rc.RedeemedBy, &rc.AccountEmail, &rc.Channel, &rc.CreatedAt) if err != nil { return nil, err } return &rc, nil } // CreateCode inserts a batch of redemption codes for a given account email. func (d *DB) CreateCodes(accountEmail string, codes []string) error { tx, err := d.Begin() if err != nil { return err } defer tx.Rollback() stmt, err := tx.Prepare(` INSERT INTO redemption_codes (code, account_email) VALUES ($1, $2) ON CONFLICT (code) DO NOTHING`) if err != nil { return err } defer stmt.Close() for _, c := range codes { if _, err := stmt.Exec(c, accountEmail); err != nil { return err } } return tx.Commit() } // RedeemCode marks a code as redeemed. func (d *DB) RedeemCode(codeID int64, redeemedBy string) error { _, err := d.Exec(` UPDATE redemption_codes SET is_redeemed = TRUE, redeemed_at = NOW()::TEXT, redeemed_by = $1 WHERE id = $2`, redeemedBy, codeID) return err } // CountAvailableCodes returns the number of unredeemed codes. func (d *DB) CountAvailableCodes() (int, error) { var count int err := d.QueryRow(`SELECT COUNT(*) FROM redemption_codes WHERE is_redeemed = FALSE`).Scan(&count) return count, err } // CountAvailableCodesByAccount returns the number of unredeemed codes for a specific account. func (d *DB) CountAvailableCodesByAccount(accountEmail string) (int, error) { var count int err := d.QueryRow(` SELECT COUNT(*) FROM redemption_codes WHERE is_redeemed = FALSE AND account_email = $1`, accountEmail).Scan(&count) return count, err } // GetCodesByAccount returns all codes for an account email. func (d *DB) GetCodesByAccount(accountEmail string) ([]model.RedemptionCode, error) { rows, err := d.Query(` SELECT id, code, is_redeemed, redeemed_at, redeemed_by, account_email, channel, created_at FROM redemption_codes WHERE account_email = $1 ORDER BY id`, accountEmail) if err != nil { return nil, err } defer rows.Close() var list []model.RedemptionCode for rows.Next() { var rc model.RedemptionCode if err := rows.Scan(&rc.ID, &rc.Code, &rc.IsRedeemed, &rc.RedeemedAt, &rc.RedeemedBy, &rc.AccountEmail, &rc.Channel, &rc.CreatedAt); err != nil { return nil, err } list = append(list, rc) } return list, rows.Err() } // GetAllCodes returns all redemption codes. func (d *DB) GetAllCodes() ([]model.RedemptionCode, error) { rows, err := d.Query(` SELECT id, code, is_redeemed, redeemed_at, redeemed_by, account_email, channel, created_at FROM redemption_codes ORDER BY id`) if err != nil { return nil, err } defer rows.Close() var list []model.RedemptionCode for rows.Next() { var rc model.RedemptionCode if err := rows.Scan(&rc.ID, &rc.Code, &rc.IsRedeemed, &rc.RedeemedAt, &rc.RedeemedBy, &rc.AccountEmail, &rc.Channel, &rc.CreatedAt); err != nil { return nil, err } list = append(list, rc) } return list, rows.Err() } // DeleteCode deletes a specific redemption code. func (d *DB) DeleteCode(code string) error { _, err := d.Exec(`DELETE FROM redemption_codes WHERE code = $1`, code) return err } // DeleteUnusedCodes deletes all redemption codes that haven't been redeemed. // Returns the number of codes deleted. func (d *DB) DeleteUnusedCodes() (int64, error) { res, err := d.Exec(`DELETE FROM redemption_codes WHERE is_redeemed = false`) if err != nil { return 0, err } return res.RowsAffected() } // DeleteAccount removes an account and its associated codes. func (d *DB) DeleteAccount(id int64) error { // Get account email first. acct, err := d.GetAccountByID(id) if err != nil { return err } // Delete associated redemption codes (using the current email). _, _ = d.Exec(`DELETE FROM redemption_codes WHERE account_email = $1`, acct.Email) // Delete the account. _, err = d.Exec(`DELETE FROM gpt_accounts WHERE id = $1`, id) // Also clean up any orphaned codes that don't match any existing account email. // This fixes issues where codes were generated under a team name before the email was updated. _, _ = d.Exec(`DELETE FROM redemption_codes WHERE account_email NOT IN (SELECT email FROM gpt_accounts)`) return err } // --------------- TelegramAdmin CRUD --------------- // AddAdmin inserts a new admin into the database. func (d *DB) AddAdmin(userID int64, addedBy int64) error { _, err := d.Exec(` INSERT INTO telegram_admins (user_id, added_by) VALUES ($1, $2) ON CONFLICT (user_id) DO NOTHING`, userID, addedBy) return err } // RemoveAdmin deletes an admin from the database by user ID. func (d *DB) RemoveAdmin(userID int64) error { _, err := d.Exec(`DELETE FROM telegram_admins WHERE user_id = $1`, userID) return err } // GetAllAdmins returns a list of all admins stored in the database. func (d *DB) GetAllAdmins() ([]model.TelegramAdmin, error) { rows, err := d.Query(`SELECT id, user_id, added_by, created_at FROM telegram_admins ORDER BY created_at ASC`) if err != nil { return nil, err } defer rows.Close() var list []model.TelegramAdmin for rows.Next() { var a model.TelegramAdmin if err := rows.Scan(&a.ID, &a.UserID, &a.AddedBy, &a.CreatedAt); err != nil { return nil, err } list = append(list, a) } return list, nil } // IsAdmin checks if a specific user ID exists in the admin table. func (d *DB) IsAdmin(userID int64) (bool, error) { var count int err := d.QueryRow(`SELECT count(*) FROM telegram_admins WHERE user_id = $1`, userID).Scan(&count) if err != nil { return false, err } return count > 0, nil } // --------------- helpers --------------- func scanAccounts(rows *sql.Rows) ([]model.GptAccount, error) { var list []model.GptAccount for rows.Next() { a, err := scanAccountFromRows(rows) if err != nil { return nil, err } list = append(list, *a) } return list, rows.Err() } func scanAccountFromRows(rows *sql.Rows) (*model.GptAccount, error) { var a model.GptAccount err := rows.Scan(&a.ID, &a.Email, &a.Token, &a.RefreshToken, &a.UserCount, &a.InviteCount, &a.ChatgptAccountID, &a.OaiDeviceID, &a.ExpireAt, &a.IsOpen, &a.IsBanned, &a.CreatedAt, &a.UpdatedAt) if err != nil { return nil, err } return &a, nil } func scanAccount(row *sql.Row) (*model.GptAccount, error) { var a model.GptAccount err := row.Scan(&a.ID, &a.Email, &a.Token, &a.RefreshToken, &a.UserCount, &a.InviteCount, &a.ChatgptAccountID, &a.OaiDeviceID, &a.ExpireAt, &a.IsOpen, &a.IsBanned, &a.CreatedAt, &a.UpdatedAt) if err != nil { return nil, fmt.Errorf("账号不存在: %w", err) } return &a, nil }