feat: 初始化 ChatGPT Team 管理机器人
核心功能: - 实现基于 Telegram Inline Button 交互的后台面板与用户端 - 支持通过账密登录和 RT (Refresh Token) 方式添加 ChatGPT Team 账号 - 支持管理、拉取和删除待处理邀请,支持一键清空多余邀请 - 支持按剩余容量自动生成邀请兑换码,支持分页查看与一键清空未使用兑换码 - 随机邀请功能:成功拉人后自动核销兑换码 - 定时检测 Token 状态,实现自动续订/刷新并拦截封禁账号 (处理 401/402 错误) 系统与配置: - 使用 PostgreSQL 数据库管理账号、邀请和兑换记录 - 支持在端内动态添加、移除管理员 - 完善 Docker 部署配置与 .gitignore 规则
This commit is contained in:
17
.env.example
Normal file
17
.env.example
Normal file
@@ -0,0 +1,17 @@
|
||||
# PostgreSQL 连接串
|
||||
DATABASE_URL=postgres://postgres:postgres@localhost:5432/teamhelper?sslmode=disable
|
||||
|
||||
# Telegram Bot Token(从 @BotFather 获取)
|
||||
TELEGRAM_BOT_TOKEN=your-bot-token-here
|
||||
|
||||
# 管理员 Telegram 用户 ID(逗号分隔)
|
||||
TELEGRAM_ADMIN_IDS=123456789
|
||||
|
||||
# 可选:代理地址(支持 http/socks5)
|
||||
# PROXY_URL=socks5://127.0.0.1:1080
|
||||
|
||||
# Token 定时检测间隔(分钟)
|
||||
TOKEN_CHECK_INTERVAL=30
|
||||
|
||||
# Team 容量上限
|
||||
TEAM_CAPACITY=5
|
||||
52
.gitignore
vendored
Normal file
52
.gitignore
vendored
Normal file
@@ -0,0 +1,52 @@
|
||||
# Binaries for programs and plugins
|
||||
*.exe
|
||||
*.exe~
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
*.test
|
||||
*.out
|
||||
bin/
|
||||
|
||||
# Test binary
|
||||
*.test
|
||||
|
||||
# Output of the go coverage tool, specifically when used with LiteIDE
|
||||
*.out
|
||||
|
||||
# Dependency directories (remove the comment below to include it)
|
||||
# vendor/
|
||||
|
||||
# Go workspace file
|
||||
go.work
|
||||
go.work.sum
|
||||
|
||||
# env file
|
||||
.env
|
||||
.env.*
|
||||
!.env.example
|
||||
|
||||
# environment variables
|
||||
*.env
|
||||
|
||||
# sqlite data
|
||||
data/
|
||||
*.db
|
||||
*.db-shm
|
||||
*.db-wal
|
||||
*.sqlite
|
||||
*.sqlite3
|
||||
|
||||
# IDE and editor configurations
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
.DS_Store
|
||||
|
||||
# Temporary files or config files specific to this app
|
||||
config.yaml
|
||||
config.json
|
||||
logs/
|
||||
.log
|
||||
20
Dockerfile
Normal file
20
Dockerfile
Normal file
@@ -0,0 +1,20 @@
|
||||
# Build stage
|
||||
FROM golang:1.21-alpine AS builder
|
||||
|
||||
WORKDIR /app
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o /app/go-helper ./cmd
|
||||
|
||||
# Run stage
|
||||
FROM alpine:3.19
|
||||
|
||||
RUN apk add --no-cache ca-certificates tzdata
|
||||
ENV TZ=Asia/Shanghai
|
||||
|
||||
WORKDIR /app
|
||||
COPY --from=builder /app/go-helper .
|
||||
COPY .env.example .env
|
||||
|
||||
ENTRYPOINT ["./go-helper"]
|
||||
35
cmd/main.go
Normal file
35
cmd/main.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"go-helper/internal/bot"
|
||||
"go-helper/internal/chatgpt"
|
||||
"go-helper/internal/config"
|
||||
"go-helper/internal/database"
|
||||
"go-helper/internal/scheduler"
|
||||
)
|
||||
|
||||
func main() {
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
|
||||
// Load configuration.
|
||||
cfg := config.Load()
|
||||
|
||||
// Initialise database.
|
||||
db := database.New(cfg.DatabaseURL)
|
||||
defer db.Close()
|
||||
|
||||
// Create ChatGPT API client.
|
||||
client := chatgpt.NewClient(cfg.ProxyURL)
|
||||
|
||||
// Create OAuth manager (no server needed, uses URL-paste flow).
|
||||
oauth := chatgpt.NewOAuthManager(client)
|
||||
|
||||
// Start scheduled token checker.
|
||||
scheduler.StartTokenChecker(db, client, cfg.TokenCheckInterval)
|
||||
|
||||
// Start Telegram bot (blocking).
|
||||
log.Println("[Main] 启动 Telegram Bot...")
|
||||
bot.Start(db, cfg, client, oauth)
|
||||
}
|
||||
10
go.mod
Normal file
10
go.mod
Normal file
@@ -0,0 +1,10 @@
|
||||
module go-helper
|
||||
|
||||
go 1.25.0
|
||||
|
||||
require (
|
||||
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/lib/pq v1.11.2
|
||||
golang.org/x/net v0.51.0
|
||||
)
|
||||
8
go.sum
Normal file
8
go.sum
Normal file
@@ -0,0 +1,8 @@
|
||||
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 h1:wG8n/XJQ07TmjbITcGiUaOtXxdrINDz1b0J1w0SzqDc=
|
||||
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1/go.mod h1:A2S0CWkNylc2phvKXWBBdD3K0iGnDBGbzRpISP2zBl8=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/lib/pq v1.11.2 h1:x6gxUeu39V0BHZiugWe8LXZYZ+Utk7hSJGThs8sdzfs=
|
||||
github.com/lib/pq v1.11.2/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA=
|
||||
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
|
||||
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
|
||||
2077
internal/bot/telegram.go
Normal file
2077
internal/bot/telegram.go
Normal file
File diff suppressed because it is too large
Load Diff
111
internal/chatgpt/account.go
Normal file
111
internal/chatgpt/account.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package chatgpt
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"go-helper/internal/model"
|
||||
)
|
||||
|
||||
// FetchAccountInfo queries the ChatGPT accounts check API and returns team accounts
|
||||
// with subscription expiry information.
|
||||
func (c *Client) FetchAccountInfo(accessToken string) ([]model.TeamAccountInfo, error) {
|
||||
token := strings.TrimPrefix(strings.TrimSpace(accessToken), "Bearer ")
|
||||
if token == "" {
|
||||
return nil, fmt.Errorf("缺少 access token")
|
||||
}
|
||||
|
||||
apiURL := "https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27"
|
||||
|
||||
req, err := http.NewRequest("GET", apiURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", "*/*")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
req.Header.Set("Oai-Client-Version", oaiClientVersion)
|
||||
req.Header.Set("Oai-Language", "zh-CN")
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求 ChatGPT API 失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode == 401 || resp.StatusCode == 402 {
|
||||
return nil, fmt.Errorf("Token 已过期或被封禁")
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("ChatGPT API 错误 %d: %s", resp.StatusCode, truncate(string(body), 300))
|
||||
}
|
||||
|
||||
var data struct {
|
||||
Accounts map[string]json.RawMessage `json:"accounts"`
|
||||
AccountOrdering []string `json:"account_ordering"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
var results []model.TeamAccountInfo
|
||||
|
||||
// Determine order.
|
||||
seen := make(map[string]bool)
|
||||
var orderedIDs []string
|
||||
for _, id := range data.AccountOrdering {
|
||||
if _, ok := data.Accounts[id]; ok && !seen[id] {
|
||||
orderedIDs = append(orderedIDs, id)
|
||||
seen[id] = true
|
||||
}
|
||||
}
|
||||
for id := range data.Accounts {
|
||||
if !seen[id] && id != "default" {
|
||||
orderedIDs = append(orderedIDs, id)
|
||||
}
|
||||
}
|
||||
|
||||
for _, id := range orderedIDs {
|
||||
raw := data.Accounts[id]
|
||||
var acc struct {
|
||||
Account struct {
|
||||
Name string `json:"name"`
|
||||
PlanType string `json:"plan_type"`
|
||||
} `json:"account"`
|
||||
Entitlement struct {
|
||||
ExpiresAt string `json:"expires_at"`
|
||||
HasActiveSubscription bool `json:"has_active_subscription"`
|
||||
} `json:"entitlement"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &acc); err != nil {
|
||||
continue
|
||||
}
|
||||
if acc.Account.PlanType != "team" {
|
||||
continue
|
||||
}
|
||||
results = append(results, model.TeamAccountInfo{
|
||||
AccountID: id,
|
||||
Name: acc.Account.Name,
|
||||
PlanType: acc.Account.PlanType,
|
||||
ExpiresAt: acc.Entitlement.ExpiresAt,
|
||||
HasActiveSubscription: acc.Entitlement.HasActiveSubscription,
|
||||
})
|
||||
}
|
||||
|
||||
if len(results) == 0 {
|
||||
return nil, fmt.Errorf("未找到 Team 类型的账号")
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
return s[:max] + "..."
|
||||
}
|
||||
88
internal/chatgpt/client.go
Normal file
88
internal/chatgpt/client.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package chatgpt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
const (
|
||||
oaiClientVersion = "prod-eddc2f6ff65fee2d0d6439e379eab94fe3047f72"
|
||||
userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/142.0.0.0 Safari/537.36"
|
||||
openaiClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
)
|
||||
|
||||
// Client wraps HTTP operations with common headers and optional proxy.
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
proxyURL string
|
||||
}
|
||||
|
||||
// NewClient creates a ChatGPT API client with optional proxy support.
|
||||
func NewClient(proxyURL string) *Client {
|
||||
c := &Client{proxyURL: proxyURL}
|
||||
c.httpClient = c.buildHTTPClient()
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Client) buildHTTPClient() *http.Client {
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12},
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
}
|
||||
|
||||
if c.proxyURL != "" {
|
||||
parsed, err := url.Parse(c.proxyURL)
|
||||
if err == nil {
|
||||
scheme := parsed.Scheme
|
||||
if scheme == "socks5" || scheme == "socks5h" {
|
||||
var auth *proxy.Auth
|
||||
if parsed.User != nil {
|
||||
pass, _ := parsed.User.Password()
|
||||
auth = &proxy.Auth{
|
||||
User: parsed.User.Username(),
|
||||
Password: pass,
|
||||
}
|
||||
}
|
||||
dialer, dErr := proxy.SOCKS5("tcp", parsed.Host, auth, proxy.Direct)
|
||||
if dErr == nil {
|
||||
if ctxDialer, ok := dialer.(proxy.ContextDialer); ok {
|
||||
transport.DialContext = ctxDialer.DialContext
|
||||
} else {
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
transport.Proxy = http.ProxyURL(parsed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 60 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// buildHeaders returns common headers for ChatGPT backend API calls.
|
||||
func buildHeaders(token, chatgptAccountID string) http.Header {
|
||||
h := http.Header{}
|
||||
h.Set("Accept", "*/*")
|
||||
h.Set("Accept-Language", "zh-CN,zh;q=0.9")
|
||||
h.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
h.Set("Chatgpt-Account-Id", chatgptAccountID)
|
||||
h.Set("Oai-Client-Version", oaiClientVersion)
|
||||
h.Set("Oai-Language", "zh-CN")
|
||||
h.Set("User-Agent", userAgent)
|
||||
h.Set("Origin", "https://chatgpt.com")
|
||||
h.Set("Referer", "https://chatgpt.com/admin/members")
|
||||
return h
|
||||
}
|
||||
93
internal/chatgpt/invite.go
Normal file
93
internal/chatgpt/invite.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package chatgpt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go-helper/internal/model"
|
||||
)
|
||||
|
||||
const (
|
||||
maxInviteAttempts = 3
|
||||
retryBaseDelay = 800 * time.Millisecond
|
||||
retryMaxDelay = 5 * time.Second
|
||||
)
|
||||
|
||||
// InviteUser sends an invitation to the given email on the specified team account.
|
||||
func (c *Client) InviteUser(email string, account *model.GptAccount) error {
|
||||
email = strings.TrimSpace(strings.ToLower(email))
|
||||
if email == "" {
|
||||
return fmt.Errorf("缺少邀请邮箱")
|
||||
}
|
||||
if account.Token == "" || account.ChatgptAccountID == "" {
|
||||
return fmt.Errorf("账号配置不完整")
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("https://chatgpt.com/backend-api/accounts/%s/invites", account.ChatgptAccountID)
|
||||
payload := map[string]interface{}{
|
||||
"email_addresses": []string{email},
|
||||
"role": "standard-user",
|
||||
"resend_emails": true,
|
||||
}
|
||||
bodyBytes, _ := json.Marshal(payload)
|
||||
|
||||
var lastErr error
|
||||
for attempt := 1; attempt <= maxInviteAttempts; attempt++ {
|
||||
req, err := http.NewRequest("POST", apiURL, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
headers := buildHeaders(account.Token, account.ChatgptAccountID)
|
||||
req.Header = headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if account.OaiDeviceID != "" {
|
||||
req.Header.Set("Oai-Device-Id", account.OaiDeviceID)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("网络错误: %w", err)
|
||||
log.Printf("[Invite] 尝试 %d/%d 网络错误: %v", attempt, maxInviteAttempts, err)
|
||||
if attempt < maxInviteAttempts {
|
||||
sleepRetry(attempt)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
log.Printf("[Invite] 成功邀请 %s 到账号 %s", email, account.ChatgptAccountID)
|
||||
return nil
|
||||
}
|
||||
|
||||
lastErr = fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncate(string(respBody), 300))
|
||||
log.Printf("[Invite] 尝试 %d/%d 失败: %v", attempt, maxInviteAttempts, lastErr)
|
||||
|
||||
if !isRetryableStatus(resp.StatusCode) || attempt >= maxInviteAttempts {
|
||||
break
|
||||
}
|
||||
sleepRetry(attempt)
|
||||
}
|
||||
return fmt.Errorf("邀请失败(已尝试 %d 次): %v", maxInviteAttempts, lastErr)
|
||||
}
|
||||
|
||||
func sleepRetry(attempt int) {
|
||||
delay := float64(retryBaseDelay) * math.Pow(2, float64(attempt-1))
|
||||
if delay > float64(retryMaxDelay) {
|
||||
delay = float64(retryMaxDelay)
|
||||
}
|
||||
time.Sleep(time.Duration(delay))
|
||||
}
|
||||
|
||||
func isRetryableStatus(status int) bool {
|
||||
return status == 408 || status == 429 || (status >= 500 && status <= 599)
|
||||
}
|
||||
195
internal/chatgpt/member.go
Normal file
195
internal/chatgpt/member.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package chatgpt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"go-helper/internal/model"
|
||||
)
|
||||
|
||||
// GetUsers fetches the list of members for a team account.
|
||||
func (c *Client) GetUsers(account *model.GptAccount) (int, []model.ChatGPTUser, error) {
|
||||
apiURL := fmt.Sprintf("https://chatgpt.com/backend-api/accounts/%s/users?offset=0&limit=100&query=",
|
||||
account.ChatgptAccountID)
|
||||
|
||||
req, err := http.NewRequest("GET", apiURL, nil)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
req.Header = buildHeaders(account.Token, account.ChatgptAccountID)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("获取成员列表网络错误: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if err := checkAPIError(resp.StatusCode, body, "获取成员"); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
var data struct {
|
||||
Total int `json:"total"`
|
||||
Items []struct {
|
||||
ID string `json:"id"`
|
||||
AccountUserID string `json:"account_user_id"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
Name string `json:"name"`
|
||||
} `json:"items"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
return 0, nil, fmt.Errorf("解析成员数据失败: %w", err)
|
||||
}
|
||||
|
||||
var users []model.ChatGPTUser
|
||||
for _, item := range data.Items {
|
||||
users = append(users, model.ChatGPTUser{
|
||||
ID: item.ID,
|
||||
AccountUserID: item.AccountUserID,
|
||||
Email: item.Email,
|
||||
Role: item.Role,
|
||||
Name: item.Name,
|
||||
})
|
||||
}
|
||||
return data.Total, users, nil
|
||||
}
|
||||
|
||||
// DeleteUser removes a user from the team account.
|
||||
func (c *Client) DeleteUser(account *model.GptAccount, userID string) error {
|
||||
normalizedID := userID
|
||||
if !strings.HasPrefix(normalizedID, "user-") {
|
||||
normalizedID = "user-" + normalizedID
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("https://chatgpt.com/backend-api/accounts/%s/users/%s",
|
||||
account.ChatgptAccountID, normalizedID)
|
||||
|
||||
req, err := http.NewRequest("DELETE", apiURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header = buildHeaders(account.Token, account.ChatgptAccountID)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除用户网络错误: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if err := checkAPIError(resp.StatusCode, body, "删除用户"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("[Member] 成功删除用户 %s (账号 %s)", normalizedID, account.ChatgptAccountID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetInvites fetches pending invitations for a team account.
|
||||
func (c *Client) GetInvites(account *model.GptAccount) (int, []model.ChatGPTInvite, error) {
|
||||
apiURL := fmt.Sprintf("https://chatgpt.com/backend-api/accounts/%s/invites?offset=0&limit=100&query=",
|
||||
account.ChatgptAccountID)
|
||||
|
||||
req, err := http.NewRequest("GET", apiURL, nil)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
req.Header = buildHeaders(account.Token, account.ChatgptAccountID)
|
||||
req.Header.Set("Referer", "https://chatgpt.com/admin/members?tab=invites")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("获取邀请列表网络错误: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if err := checkAPIError(resp.StatusCode, body, "获取邀请列表"); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
var data struct {
|
||||
Total int `json:"total"`
|
||||
Items []struct {
|
||||
ID string `json:"id"`
|
||||
EmailAddress string `json:"email_address"`
|
||||
Role string `json:"role"`
|
||||
} `json:"items"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
return 0, nil, fmt.Errorf("解析邀请数据失败: %w", err)
|
||||
}
|
||||
|
||||
var invites []model.ChatGPTInvite
|
||||
for _, item := range data.Items {
|
||||
invites = append(invites, model.ChatGPTInvite{
|
||||
ID: item.ID,
|
||||
EmailAddress: item.EmailAddress,
|
||||
Role: item.Role,
|
||||
})
|
||||
}
|
||||
return data.Total, invites, nil
|
||||
}
|
||||
|
||||
// DeleteInvite cancels a pending invitation by email on the team account.
|
||||
func (c *Client) DeleteInvite(account *model.GptAccount, email string) error {
|
||||
apiURL := fmt.Sprintf("https://chatgpt.com/backend-api/accounts/%s/invites",
|
||||
account.ChatgptAccountID)
|
||||
|
||||
payload := map[string]string{"email_address": strings.TrimSpace(strings.ToLower(email))}
|
||||
bodyBytes, _ := json.Marshal(payload)
|
||||
|
||||
req, err := http.NewRequest("DELETE", apiURL, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header = buildHeaders(account.Token, account.ChatgptAccountID)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Referer", "https://chatgpt.com/admin/members?tab=invites")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除邀请网络错误: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if err := checkAPIError(resp.StatusCode, body, "删除邀请"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("[Member] 成功删除邀请 %s (账号 %s)", email, account.ChatgptAccountID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkAPIError inspects the HTTP status code and returns a descriptive error.
|
||||
func checkAPIError(statusCode int, body []byte, label string) error {
|
||||
if statusCode >= 200 && statusCode < 300 {
|
||||
return nil
|
||||
}
|
||||
|
||||
bodyStr := truncate(string(body), 500)
|
||||
|
||||
// Check for account_deactivated.
|
||||
if strings.Contains(bodyStr, "account_deactivated") {
|
||||
return fmt.Errorf("账号已停用 (account_deactivated)")
|
||||
}
|
||||
|
||||
switch statusCode {
|
||||
case 401, 402:
|
||||
return fmt.Errorf("Token 已过期或被封禁")
|
||||
case 403:
|
||||
return fmt.Errorf("资源不存在或无权访问")
|
||||
case 429:
|
||||
return fmt.Errorf("API 请求过于频繁,请稍后重试")
|
||||
default:
|
||||
return fmt.Errorf("%s失败 (HTTP %d): %s", label, statusCode, bodyStr)
|
||||
}
|
||||
}
|
||||
233
internal/chatgpt/oauth.go
Normal file
233
internal/chatgpt/oauth.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package chatgpt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OAuthSession holds PKCE session data for an in-progress OAuth login.
|
||||
type OAuthSession struct {
|
||||
CodeVerifier string
|
||||
State string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// OAuthResult holds the tokens and account info after a successful OAuth exchange.
|
||||
type OAuthResult struct {
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
Email string
|
||||
AccountID string
|
||||
Name string
|
||||
PlanType string
|
||||
}
|
||||
|
||||
// OAuthManager manages PKCE sessions for the manual URL-paste OAuth flow.
|
||||
type OAuthManager struct {
|
||||
client *Client
|
||||
sessions map[string]*OAuthSession // state -> session
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewOAuthManager creates a new OAuth manager.
|
||||
func NewOAuthManager(client *Client) *OAuthManager {
|
||||
return &OAuthManager{
|
||||
client: client,
|
||||
sessions: make(map[string]*OAuthSession),
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateAuthURL creates an OpenAI OAuth authorization URL with PKCE.
|
||||
// The redirect_uri points to localhost which won't load — user copies the URL from browser.
|
||||
func (m *OAuthManager) GenerateAuthURL() (authURL, state string, err error) {
|
||||
// Generate PKCE code verifier.
|
||||
verifierBytes := make([]byte, 64)
|
||||
if _, err := rand.Read(verifierBytes); err != nil {
|
||||
return "", "", fmt.Errorf("生成 PKCE 失败: %w", err)
|
||||
}
|
||||
codeVerifier := hex.EncodeToString(verifierBytes)
|
||||
|
||||
hash := sha256.Sum256([]byte(codeVerifier))
|
||||
codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
|
||||
// Generate state.
|
||||
stateBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(stateBytes); err != nil {
|
||||
return "", "", fmt.Errorf("生成 state 失败: %w", err)
|
||||
}
|
||||
state = hex.EncodeToString(stateBytes)
|
||||
|
||||
// Store session.
|
||||
m.mu.Lock()
|
||||
m.sessions[state] = &OAuthSession{
|
||||
CodeVerifier: codeVerifier,
|
||||
State: state,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
// Build auth URL — redirect to localhost, user will copy the URL.
|
||||
redirectURI := "http://localhost:1455/auth/callback"
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", openaiClientID)
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("scope", "openid profile email offline_access")
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
params.Set("state", state)
|
||||
params.Set("id_token_add_organizations", "true")
|
||||
params.Set("codex_cli_simplified_flow", "true")
|
||||
|
||||
authURL = fmt.Sprintf("https://auth.openai.com/oauth/authorize?%s", params.Encode())
|
||||
return authURL, state, nil
|
||||
}
|
||||
|
||||
// ExchangeCallbackURL parses the pasted callback URL to extract code and state,
|
||||
// then exchanges the authorization code for tokens.
|
||||
func (m *OAuthManager) ExchangeCallbackURL(callbackURL string) (*OAuthResult, error) {
|
||||
callbackURL = strings.TrimSpace(callbackURL)
|
||||
|
||||
parsed, err := url.Parse(callbackURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("URL 格式错误: %w", err)
|
||||
}
|
||||
|
||||
code := parsed.Query().Get("code")
|
||||
state := parsed.Query().Get("state")
|
||||
|
||||
if code == "" {
|
||||
errMsg := parsed.Query().Get("error_description")
|
||||
if errMsg == "" {
|
||||
errMsg = parsed.Query().Get("error")
|
||||
}
|
||||
if errMsg == "" {
|
||||
errMsg = "回调 URL 中未找到 code 参数"
|
||||
}
|
||||
return nil, fmt.Errorf("%s", errMsg)
|
||||
}
|
||||
|
||||
if state == "" {
|
||||
return nil, fmt.Errorf("回调 URL 中未找到 state 参数")
|
||||
}
|
||||
|
||||
// Look up the session.
|
||||
m.mu.Lock()
|
||||
session, ok := m.sessions[state]
|
||||
if ok {
|
||||
delete(m.sessions, state)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("会话已过期或无效,请重新使用 /login")
|
||||
}
|
||||
|
||||
// Check if session is expired (10 minutes).
|
||||
if time.Since(session.CreatedAt) > 10*time.Minute {
|
||||
return nil, fmt.Errorf("登录会话已过期(超过10分钟),请重新使用 /login")
|
||||
}
|
||||
|
||||
// Exchange code for tokens.
|
||||
return m.exchangeCode(code, session.CodeVerifier)
|
||||
}
|
||||
|
||||
func (m *OAuthManager) exchangeCode(code, codeVerifier string) (*OAuthResult, error) {
|
||||
redirectURI := "http://localhost:1455/auth/callback"
|
||||
|
||||
form := url.Values{}
|
||||
form.Set("grant_type", "authorization_code")
|
||||
form.Set("code", code)
|
||||
form.Set("redirect_uri", redirectURI)
|
||||
form.Set("client_id", openaiClientID)
|
||||
form.Set("code_verifier", codeVerifier)
|
||||
|
||||
req, err := http.NewRequest("POST", "https://auth.openai.com/oauth/token", strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := m.client.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("网络错误: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("交换授权码失败 (HTTP %d): %s", resp.StatusCode, truncate(string(body), 300))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("解析 token 响应失败: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken == "" {
|
||||
return nil, fmt.Errorf("未返回有效的 access token")
|
||||
}
|
||||
|
||||
result := &OAuthResult{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
}
|
||||
|
||||
// Decode ID token for user info.
|
||||
if tokenResp.IDToken != "" {
|
||||
claims, err := decodeJWTPayload(tokenResp.IDToken)
|
||||
if err == nil {
|
||||
result.Email, _ = claims["email"].(string)
|
||||
result.Name, _ = claims["name"].(string)
|
||||
if authClaims, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
|
||||
result.AccountID, _ = authClaims["chatgpt_account_id"].(string)
|
||||
result.PlanType, _ = authClaims["chatgpt_plan_type"].(string)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func decodeJWTPayload(token string) (map[string]interface{}, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT format")
|
||||
}
|
||||
|
||||
payload := parts[1]
|
||||
switch len(payload) % 4 {
|
||||
case 2:
|
||||
payload += "=="
|
||||
case 3:
|
||||
payload += "="
|
||||
}
|
||||
|
||||
decoded, err := base64.URLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
decoded, err = base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
109
internal/chatgpt/token.go
Normal file
109
internal/chatgpt/token.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package chatgpt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// TokenResult holds the result of a token refresh.
|
||||
type TokenResult struct {
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
IDToken string
|
||||
ExpiresIn int
|
||||
}
|
||||
|
||||
// GetEmail extracts the user's email from the ID token or Access token.
|
||||
func (tr *TokenResult) GetEmail() string {
|
||||
if tr.IDToken != "" {
|
||||
claims, err := decodeJWTPayload(tr.IDToken)
|
||||
if err == nil {
|
||||
if email, ok := claims["email"].(string); ok && email != "" {
|
||||
return email
|
||||
}
|
||||
}
|
||||
}
|
||||
if tr.AccessToken != "" {
|
||||
claims, err := decodeJWTPayload(tr.AccessToken)
|
||||
if err == nil {
|
||||
if email, ok := claims["email"].(string); ok && email != "" {
|
||||
return email
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// RefreshAccessToken exchanges a refresh token for a new access token.
|
||||
func (c *Client) RefreshAccessToken(refreshToken string) (*TokenResult, error) {
|
||||
rt := strings.TrimSpace(refreshToken)
|
||||
if rt == "" {
|
||||
return nil, fmt.Errorf("refresh token 为空")
|
||||
}
|
||||
|
||||
form := url.Values{}
|
||||
form.Set("grant_type", "refresh_token")
|
||||
form.Set("client_id", openaiClientID)
|
||||
form.Set("refresh_token", rt)
|
||||
form.Set("scope", "openid profile email")
|
||||
|
||||
req, err := http.NewRequest("POST", "https://auth.openai.com/oauth/token", strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("刷新 token 网络错误: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
var errData struct {
|
||||
Error string `json:"error"`
|
||||
ErrorDescription string `json:"error_description"`
|
||||
}
|
||||
_ = json.Unmarshal(body, &errData)
|
||||
msg := errData.ErrorDescription
|
||||
if msg == "" {
|
||||
msg = errData.Error
|
||||
}
|
||||
if msg == "" {
|
||||
msg = fmt.Sprintf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
return nil, fmt.Errorf("刷新 token 失败: %s", msg)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("解析刷新结果失败: %w", err)
|
||||
}
|
||||
if result.AccessToken == "" {
|
||||
return nil, fmt.Errorf("刷新 token 失败: 未返回有效凭证")
|
||||
}
|
||||
|
||||
newRT := result.RefreshToken
|
||||
if newRT == "" {
|
||||
newRT = rt
|
||||
}
|
||||
|
||||
return &TokenResult{
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: newRT,
|
||||
IDToken: result.IDToken,
|
||||
ExpiresIn: result.ExpiresIn,
|
||||
}, nil
|
||||
}
|
||||
86
internal/config/config.go
Normal file
86
internal/config/config.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
// Config holds all application configuration.
|
||||
type Config struct {
|
||||
DatabaseURL string
|
||||
TelegramBotToken string
|
||||
TelegramAdminIDs []int64
|
||||
ProxyURL string
|
||||
TokenCheckInterval int // minutes
|
||||
TeamCapacity int
|
||||
OAuthCallbackPort int
|
||||
}
|
||||
|
||||
// Load reads configuration from environment variables / .env file.
|
||||
func Load() *Config {
|
||||
_ = godotenv.Load()
|
||||
|
||||
cfg := &Config{
|
||||
DatabaseURL: getEnv("DATABASE_URL", "postgres://postgres:postgres@localhost:5432/teamhelper?sslmode=disable"),
|
||||
TelegramBotToken: getEnv("TELEGRAM_BOT_TOKEN", ""),
|
||||
ProxyURL: getEnv("PROXY_URL", ""),
|
||||
TokenCheckInterval: getEnvInt("TOKEN_CHECK_INTERVAL", 30),
|
||||
TeamCapacity: getEnvInt("TEAM_CAPACITY", 6),
|
||||
OAuthCallbackPort: getEnvInt("OAUTH_CALLBACK_PORT", 1455),
|
||||
}
|
||||
|
||||
raw := getEnv("TELEGRAM_ADMIN_IDS", "")
|
||||
if raw != "" {
|
||||
for _, s := range strings.Split(raw, ",") {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
continue
|
||||
}
|
||||
id, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
log.Printf("[Config] 无法解析管理员ID '%s': %v", s, err)
|
||||
continue
|
||||
}
|
||||
cfg.TelegramAdminIDs = append(cfg.TelegramAdminIDs, id)
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.TelegramBotToken == "" {
|
||||
log.Fatal("[Config] TELEGRAM_BOT_TOKEN 未配置")
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
// IsAdmin checks if the given Telegram user ID is an admin.
|
||||
func (c *Config) IsAdmin(userID int64) bool {
|
||||
for _, id := range c.TelegramAdminIDs {
|
||||
if id == userID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getEnv(key, fallback string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func getEnvInt(key string, fallback int) int {
|
||||
v := os.Getenv(key)
|
||||
if v == "" {
|
||||
return fallback
|
||||
}
|
||||
i, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
return fallback
|
||||
}
|
||||
return i
|
||||
}
|
||||
459
internal/database/db.go
Normal file
459
internal/database/db.go
Normal file
@@ -0,0 +1,459 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
|
||||
"go-helper/internal/model"
|
||||
)
|
||||
|
||||
// DB wraps the sql.DB connection and provides typed query methods.
|
||||
type DB struct {
|
||||
*sql.DB
|
||||
}
|
||||
|
||||
// New opens a PostgreSQL connection and auto-creates tables.
|
||||
func New(databaseURL string) *DB {
|
||||
conn, err := sql.Open("postgres", databaseURL)
|
||||
if err != nil {
|
||||
log.Fatalf("[DB] 无法连接数据库: %v", err)
|
||||
}
|
||||
if err := conn.Ping(); err != nil {
|
||||
log.Fatalf("[DB] 数据库连接测试失败: %v", err)
|
||||
}
|
||||
|
||||
conn.SetMaxOpenConns(10)
|
||||
conn.SetMaxIdleConns(5)
|
||||
conn.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
d := &DB{conn}
|
||||
d.migrate()
|
||||
log.Println("[DB] 数据库初始化完成")
|
||||
return d
|
||||
}
|
||||
|
||||
func (d *DB) migrate() {
|
||||
queries := []string{
|
||||
`CREATE TABLE IF NOT EXISTS gpt_accounts (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
email TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
refresh_token TEXT DEFAULT '',
|
||||
user_count INT DEFAULT 0,
|
||||
invite_count INT DEFAULT 0,
|
||||
chatgpt_account_id TEXT DEFAULT '',
|
||||
oai_device_id TEXT DEFAULT '',
|
||||
expire_at TEXT DEFAULT '',
|
||||
is_open BOOLEAN DEFAULT TRUE,
|
||||
is_banned BOOLEAN DEFAULT FALSE,
|
||||
created_at TIMESTAMP DEFAULT NOW(),
|
||||
updated_at TIMESTAMP DEFAULT NOW()
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS redemption_codes (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
code TEXT UNIQUE NOT NULL,
|
||||
is_redeemed BOOLEAN DEFAULT FALSE,
|
||||
redeemed_at TEXT,
|
||||
redeemed_by TEXT,
|
||||
account_email TEXT DEFAULT '',
|
||||
channel TEXT DEFAULT 'common',
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS telegram_admins (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
user_id BIGINT UNIQUE NOT NULL,
|
||||
added_by BIGINT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
)`,
|
||||
}
|
||||
for _, q := range queries {
|
||||
if _, err := d.Exec(q); err != nil {
|
||||
log.Fatalf("[DB] 建表失败: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --------------- GptAccount CRUD ---------------
|
||||
|
||||
// GetAllAccounts returns all accounts ordered by creation time desc.
|
||||
func (d *DB) GetAllAccounts() ([]model.GptAccount, error) {
|
||||
rows, err := d.Query(`
|
||||
SELECT id, email, token, refresh_token, user_count, invite_count,
|
||||
chatgpt_account_id, oai_device_id, expire_at, is_open, is_banned,
|
||||
created_at, updated_at
|
||||
FROM gpt_accounts ORDER BY created_at DESC`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanAccounts(rows)
|
||||
}
|
||||
|
||||
// GetOpenAccounts returns non-banned, open accounts that still have capacity.
|
||||
func (d *DB) GetOpenAccounts(capacity int) ([]model.GptAccount, error) {
|
||||
rows, err := d.Query(`
|
||||
SELECT id, email, token, refresh_token, user_count, invite_count,
|
||||
chatgpt_account_id, oai_device_id, expire_at, is_open, is_banned,
|
||||
created_at, updated_at
|
||||
FROM gpt_accounts
|
||||
WHERE is_open = TRUE AND is_banned = FALSE
|
||||
AND (user_count + invite_count) < $1
|
||||
AND token != '' AND chatgpt_account_id != ''
|
||||
ORDER BY (user_count + invite_count) ASC, RANDOM()`, capacity)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanAccounts(rows)
|
||||
}
|
||||
|
||||
// GetAccountByID fetches a single account by its ID.
|
||||
func (d *DB) GetAccountByID(id int64) (*model.GptAccount, error) {
|
||||
row := d.QueryRow(`
|
||||
SELECT id, email, token, refresh_token, user_count, invite_count,
|
||||
chatgpt_account_id, oai_device_id, expire_at, is_open, is_banned,
|
||||
created_at, updated_at
|
||||
FROM gpt_accounts WHERE id = $1 LIMIT 1`, id)
|
||||
return scanAccount(row)
|
||||
}
|
||||
|
||||
// GetAccountByChatGPTAccountID fetches a single account by its OpenAI account ID.
|
||||
func (d *DB) GetAccountByChatGPTAccountID(accountID string) (*model.GptAccount, error) {
|
||||
row := d.QueryRow(`
|
||||
SELECT id, email, token, refresh_token, user_count, invite_count,
|
||||
chatgpt_account_id, oai_device_id, expire_at, is_open, is_banned,
|
||||
created_at, updated_at
|
||||
FROM gpt_accounts WHERE chatgpt_account_id = $1 LIMIT 1`, accountID)
|
||||
return scanAccount(row)
|
||||
}
|
||||
|
||||
// GetAccountByEmail fetches a single account by email (case-insensitive).
|
||||
func (d *DB) GetAccountByEmail(email string) (*model.GptAccount, error) {
|
||||
row := d.QueryRow(`
|
||||
SELECT id, email, token, refresh_token, user_count, invite_count,
|
||||
chatgpt_account_id, oai_device_id, expire_at, is_open, is_banned,
|
||||
created_at, updated_at
|
||||
FROM gpt_accounts WHERE lower(email) = lower($1) LIMIT 1`, email)
|
||||
return scanAccount(row)
|
||||
}
|
||||
|
||||
// GetAccountsWithRT returns accounts that have a refresh token and are not banned.
|
||||
func (d *DB) GetAccountsWithRT() ([]model.GptAccount, error) {
|
||||
rows, err := d.Query(`
|
||||
SELECT id, email, token, refresh_token, user_count, invite_count,
|
||||
chatgpt_account_id, oai_device_id, expire_at, is_open, is_banned,
|
||||
created_at, updated_at
|
||||
FROM gpt_accounts
|
||||
WHERE refresh_token != '' AND is_banned = FALSE
|
||||
ORDER BY id`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanAccounts(rows)
|
||||
}
|
||||
|
||||
// CreateAccount inserts a new account and returns its ID.
|
||||
func (d *DB) CreateAccount(a *model.GptAccount) (int64, error) {
|
||||
var id int64
|
||||
err := d.QueryRow(`
|
||||
INSERT INTO gpt_accounts (email, token, refresh_token, user_count, invite_count,
|
||||
chatgpt_account_id, oai_device_id, expire_at, is_open, is_banned)
|
||||
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10) RETURNING id`,
|
||||
a.Email, a.Token, a.RefreshToken, a.UserCount, a.InviteCount,
|
||||
a.ChatgptAccountID, a.OaiDeviceID, a.ExpireAt, a.IsOpen, a.IsBanned,
|
||||
).Scan(&id)
|
||||
return id, err
|
||||
}
|
||||
|
||||
// UpdateAccountTokens updates the access token and refresh token.
|
||||
func (d *DB) UpdateAccountTokens(id int64, accessToken, refreshToken string) error {
|
||||
_, err := d.Exec(`
|
||||
UPDATE gpt_accounts
|
||||
SET token = $1, refresh_token = $2, updated_at = NOW()
|
||||
WHERE id = $3`, accessToken, refreshToken, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateAccountEmail updates the email of an account and its associated codes.
|
||||
func (d *DB) UpdateAccountEmail(id int64, oldEmail, newEmail string) error {
|
||||
tx, err := d.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
_, err = tx.Exec(`
|
||||
UPDATE gpt_accounts
|
||||
SET email = $1, updated_at = NOW()
|
||||
WHERE id = $2`, newEmail, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec(`
|
||||
UPDATE redemption_codes
|
||||
SET account_email = $1
|
||||
WHERE account_email = $2`, newEmail, oldEmail)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// UpdateAccountCounts updates user_count and invite_count.
|
||||
func (d *DB) UpdateAccountCounts(id int64, userCount, inviteCount int) error {
|
||||
_, err := d.Exec(`
|
||||
UPDATE gpt_accounts
|
||||
SET user_count = $1, invite_count = $2, updated_at = NOW()
|
||||
WHERE id = $3`, userCount, inviteCount, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateAccountInfo updates chatgpt_account_id, expire_at, etc. from API data.
|
||||
func (d *DB) UpdateAccountInfo(id int64, chatgptAccountID, expireAt string) error {
|
||||
_, err := d.Exec(`
|
||||
UPDATE gpt_accounts
|
||||
SET chatgpt_account_id = $1, expire_at = $2, updated_at = NOW()
|
||||
WHERE id = $3`, chatgptAccountID, expireAt, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// BanAccount marks an account as banned and not open.
|
||||
func (d *DB) BanAccount(id int64) error {
|
||||
_, err := d.Exec(`
|
||||
UPDATE gpt_accounts
|
||||
SET is_banned = TRUE, is_open = FALSE, updated_at = NOW()
|
||||
WHERE id = $1`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// --------------- RedemptionCode CRUD ---------------
|
||||
|
||||
// GetCodeByCode fetches a redemption code by its code string.
|
||||
func (d *DB) GetCodeByCode(code string) (*model.RedemptionCode, error) {
|
||||
row := d.QueryRow(`
|
||||
SELECT id, code, is_redeemed, redeemed_at, redeemed_by, account_email, channel, created_at
|
||||
FROM redemption_codes WHERE code = $1`, code)
|
||||
var rc model.RedemptionCode
|
||||
err := row.Scan(&rc.ID, &rc.Code, &rc.IsRedeemed, &rc.RedeemedAt, &rc.RedeemedBy,
|
||||
&rc.AccountEmail, &rc.Channel, &rc.CreatedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &rc, nil
|
||||
}
|
||||
|
||||
// CreateCode inserts a batch of redemption codes for a given account email.
|
||||
func (d *DB) CreateCodes(accountEmail string, codes []string) error {
|
||||
tx, err := d.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
stmt, err := tx.Prepare(`
|
||||
INSERT INTO redemption_codes (code, account_email) VALUES ($1, $2)
|
||||
ON CONFLICT (code) DO NOTHING`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for _, c := range codes {
|
||||
if _, err := stmt.Exec(c, accountEmail); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// RedeemCode marks a code as redeemed.
|
||||
func (d *DB) RedeemCode(codeID int64, redeemedBy string) error {
|
||||
_, err := d.Exec(`
|
||||
UPDATE redemption_codes
|
||||
SET is_redeemed = TRUE, redeemed_at = NOW()::TEXT, redeemed_by = $1
|
||||
WHERE id = $2`, redeemedBy, codeID)
|
||||
return err
|
||||
}
|
||||
|
||||
// CountAvailableCodes returns the number of unredeemed codes.
|
||||
func (d *DB) CountAvailableCodes() (int, error) {
|
||||
var count int
|
||||
err := d.QueryRow(`SELECT COUNT(*) FROM redemption_codes WHERE is_redeemed = FALSE`).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
// CountAvailableCodesByAccount returns the number of unredeemed codes for a specific account.
|
||||
func (d *DB) CountAvailableCodesByAccount(accountEmail string) (int, error) {
|
||||
var count int
|
||||
err := d.QueryRow(`
|
||||
SELECT COUNT(*) FROM redemption_codes
|
||||
WHERE is_redeemed = FALSE AND account_email = $1`, accountEmail).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetCodesByAccount returns all codes for an account email.
|
||||
func (d *DB) GetCodesByAccount(accountEmail string) ([]model.RedemptionCode, error) {
|
||||
rows, err := d.Query(`
|
||||
SELECT id, code, is_redeemed, redeemed_at, redeemed_by, account_email, channel, created_at
|
||||
FROM redemption_codes WHERE account_email = $1 ORDER BY id`, accountEmail)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var list []model.RedemptionCode
|
||||
for rows.Next() {
|
||||
var rc model.RedemptionCode
|
||||
if err := rows.Scan(&rc.ID, &rc.Code, &rc.IsRedeemed, &rc.RedeemedAt, &rc.RedeemedBy,
|
||||
&rc.AccountEmail, &rc.Channel, &rc.CreatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, rc)
|
||||
}
|
||||
return list, rows.Err()
|
||||
}
|
||||
|
||||
// GetAllCodes returns all redemption codes.
|
||||
func (d *DB) GetAllCodes() ([]model.RedemptionCode, error) {
|
||||
rows, err := d.Query(`
|
||||
SELECT id, code, is_redeemed, redeemed_at, redeemed_by, account_email, channel, created_at
|
||||
FROM redemption_codes ORDER BY id`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var list []model.RedemptionCode
|
||||
for rows.Next() {
|
||||
var rc model.RedemptionCode
|
||||
if err := rows.Scan(&rc.ID, &rc.Code, &rc.IsRedeemed, &rc.RedeemedAt, &rc.RedeemedBy,
|
||||
&rc.AccountEmail, &rc.Channel, &rc.CreatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, rc)
|
||||
}
|
||||
return list, rows.Err()
|
||||
}
|
||||
|
||||
// DeleteCode deletes a specific redemption code.
|
||||
func (d *DB) DeleteCode(code string) error {
|
||||
_, err := d.Exec(`DELETE FROM redemption_codes WHERE code = $1`, code)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteUnusedCodes deletes all redemption codes that haven't been redeemed.
|
||||
// Returns the number of codes deleted.
|
||||
func (d *DB) DeleteUnusedCodes() (int64, error) {
|
||||
res, err := d.Exec(`DELETE FROM redemption_codes WHERE is_redeemed = false`)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
// DeleteAccount removes an account and its associated codes.
|
||||
func (d *DB) DeleteAccount(id int64) error {
|
||||
// Get account email first.
|
||||
acct, err := d.GetAccountByID(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Delete associated redemption codes (using the current email).
|
||||
_, _ = d.Exec(`DELETE FROM redemption_codes WHERE account_email = $1`, acct.Email)
|
||||
|
||||
// Delete the account.
|
||||
_, err = d.Exec(`DELETE FROM gpt_accounts WHERE id = $1`, id)
|
||||
|
||||
// Also clean up any orphaned codes that don't match any existing account email.
|
||||
// This fixes issues where codes were generated under a team name before the email was updated.
|
||||
_, _ = d.Exec(`DELETE FROM redemption_codes WHERE account_email NOT IN (SELECT email FROM gpt_accounts)`)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// --------------- TelegramAdmin CRUD ---------------
|
||||
|
||||
// AddAdmin inserts a new admin into the database.
|
||||
func (d *DB) AddAdmin(userID int64, addedBy int64) error {
|
||||
_, err := d.Exec(`
|
||||
INSERT INTO telegram_admins (user_id, added_by)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT (user_id) DO NOTHING`, userID, addedBy)
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveAdmin deletes an admin from the database by user ID.
|
||||
func (d *DB) RemoveAdmin(userID int64) error {
|
||||
_, err := d.Exec(`DELETE FROM telegram_admins WHERE user_id = $1`, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetAllAdmins returns a list of all admins stored in the database.
|
||||
func (d *DB) GetAllAdmins() ([]model.TelegramAdmin, error) {
|
||||
rows, err := d.Query(`SELECT id, user_id, added_by, created_at FROM telegram_admins ORDER BY created_at ASC`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var list []model.TelegramAdmin
|
||||
for rows.Next() {
|
||||
var a model.TelegramAdmin
|
||||
if err := rows.Scan(&a.ID, &a.UserID, &a.AddedBy, &a.CreatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, a)
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
|
||||
// IsAdmin checks if a specific user ID exists in the admin table.
|
||||
func (d *DB) IsAdmin(userID int64) (bool, error) {
|
||||
var count int
|
||||
err := d.QueryRow(`SELECT count(*) FROM telegram_admins WHERE user_id = $1`, userID).Scan(&count)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// --------------- helpers ---------------
|
||||
|
||||
func scanAccounts(rows *sql.Rows) ([]model.GptAccount, error) {
|
||||
var list []model.GptAccount
|
||||
for rows.Next() {
|
||||
a, err := scanAccountFromRows(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, *a)
|
||||
}
|
||||
return list, rows.Err()
|
||||
}
|
||||
|
||||
func scanAccountFromRows(rows *sql.Rows) (*model.GptAccount, error) {
|
||||
var a model.GptAccount
|
||||
err := rows.Scan(&a.ID, &a.Email, &a.Token, &a.RefreshToken,
|
||||
&a.UserCount, &a.InviteCount, &a.ChatgptAccountID, &a.OaiDeviceID,
|
||||
&a.ExpireAt, &a.IsOpen, &a.IsBanned, &a.CreatedAt, &a.UpdatedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &a, nil
|
||||
}
|
||||
|
||||
func scanAccount(row *sql.Row) (*model.GptAccount, error) {
|
||||
var a model.GptAccount
|
||||
err := row.Scan(&a.ID, &a.Email, &a.Token, &a.RefreshToken,
|
||||
&a.UserCount, &a.InviteCount, &a.ChatgptAccountID, &a.OaiDeviceID,
|
||||
&a.ExpireAt, &a.IsOpen, &a.IsBanned, &a.CreatedAt, &a.UpdatedAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("账号不存在: %w", err)
|
||||
}
|
||||
return &a, nil
|
||||
}
|
||||
65
internal/model/model.go
Normal file
65
internal/model/model.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// GptAccount represents a ChatGPT Team account stored in the database.
|
||||
type GptAccount struct {
|
||||
ID int64 `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Token string `json:"token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
UserCount int `json:"user_count"`
|
||||
InviteCount int `json:"invite_count"`
|
||||
ChatgptAccountID string `json:"chatgpt_account_id"`
|
||||
OaiDeviceID string `json:"oai_device_id"`
|
||||
ExpireAt string `json:"expire_at"`
|
||||
IsOpen bool `json:"is_open"`
|
||||
IsBanned bool `json:"is_banned"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// RedemptionCode represents a redemption code stored in the database.
|
||||
type RedemptionCode struct {
|
||||
ID int64 `json:"id"`
|
||||
Code string `json:"code"`
|
||||
IsRedeemed bool `json:"is_redeemed"`
|
||||
RedeemedAt *string `json:"redeemed_at"`
|
||||
RedeemedBy *string `json:"redeemed_by"`
|
||||
AccountEmail string `json:"account_email"`
|
||||
Channel string `json:"channel"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// TeamAccountInfo is the data returned from the ChatGPT accounts/check API.
|
||||
type TeamAccountInfo struct {
|
||||
AccountID string `json:"account_id"`
|
||||
Name string `json:"name"`
|
||||
PlanType string `json:"plan_type"`
|
||||
ExpiresAt string `json:"expires_at"`
|
||||
HasActiveSubscription bool `json:"has_active_subscription"`
|
||||
}
|
||||
|
||||
// ChatGPTUser represents a team member returned from the users API.
|
||||
type ChatGPTUser struct {
|
||||
ID string `json:"id"`
|
||||
AccountUserID string `json:"account_user_id"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// ChatGPTInvite represents a pending invite returned from the invites API.
|
||||
type ChatGPTInvite struct {
|
||||
ID string `json:"id"`
|
||||
EmailAddress string `json:"email_address"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
// TelegramAdmin represents a Telegram admin user stored in the database.
|
||||
type TelegramAdmin struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
AddedBy int64 `json:"added_by"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
139
internal/redeem/redeem.go
Normal file
139
internal/redeem/redeem.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package redeem
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"go-helper/internal/chatgpt"
|
||||
"go-helper/internal/database"
|
||||
"go-helper/internal/model"
|
||||
)
|
||||
|
||||
var (
|
||||
emailRegex = regexp.MustCompile(`^[^\s@]+@[^\s@]+\.[^\s@]+$`)
|
||||
codeRegex = regexp.MustCompile(`^[A-Z0-9]{4}-[A-Z0-9]{4}-[A-Z0-9]{4}$`)
|
||||
codeChars = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789") // exclude confusable chars
|
||||
)
|
||||
|
||||
// RedeemResult contains information about a successful redemption.
|
||||
type RedeemResult struct {
|
||||
AccountEmail string
|
||||
InviteOK bool
|
||||
Message string
|
||||
}
|
||||
|
||||
// Redeem validates the code, finds an available account, and sends an invite.
|
||||
func Redeem(db *database.DB, client *chatgpt.Client, code, email string, capacity int) (*RedeemResult, error) {
|
||||
email = strings.TrimSpace(strings.ToLower(email))
|
||||
code = strings.TrimSpace(strings.ToUpper(code))
|
||||
|
||||
if email == "" {
|
||||
return nil, fmt.Errorf("请输入邮箱地址")
|
||||
}
|
||||
if !emailRegex.MatchString(email) {
|
||||
return nil, fmt.Errorf("邮箱格式不正确")
|
||||
}
|
||||
if code == "" {
|
||||
return nil, fmt.Errorf("请输入兑换码")
|
||||
}
|
||||
if !codeRegex.MatchString(code) {
|
||||
return nil, fmt.Errorf("兑换码格式不正确(格式:XXXX-XXXX-XXXX)")
|
||||
}
|
||||
|
||||
// 1. Look up the code.
|
||||
rc, err := db.GetCodeByCode(code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("兑换码不存在或已失效")
|
||||
}
|
||||
if rc.IsRedeemed {
|
||||
return nil, fmt.Errorf("该兑换码已被使用")
|
||||
}
|
||||
|
||||
// 2. Find a usable account.
|
||||
var account *model.GptAccount
|
||||
|
||||
if rc.AccountEmail != "" {
|
||||
// Code is bound to a specific account.
|
||||
accounts, err := db.GetOpenAccounts(capacity + 100) // get all open
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查找账号失败: %v", err)
|
||||
}
|
||||
for i := range accounts {
|
||||
if strings.EqualFold(accounts[i].Email, rc.AccountEmail) {
|
||||
if accounts[i].UserCount+accounts[i].InviteCount < capacity {
|
||||
account = &accounts[i]
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if account == nil {
|
||||
return nil, fmt.Errorf("该兑换码绑定的账号不可用或已满")
|
||||
}
|
||||
} else {
|
||||
// Find any open account with capacity.
|
||||
accounts, err := db.GetOpenAccounts(capacity)
|
||||
if err != nil || len(accounts) == 0 {
|
||||
return nil, fmt.Errorf("暂无可用账号,请稍后再试")
|
||||
}
|
||||
account = &accounts[0]
|
||||
}
|
||||
|
||||
// 3. Send invite.
|
||||
inviteErr := client.InviteUser(email, account)
|
||||
|
||||
// 4. Mark code as redeemed regardless of invite outcome.
|
||||
if err := db.RedeemCode(rc.ID, email); err != nil {
|
||||
return nil, fmt.Errorf("更新兑换码状态失败: %v", err)
|
||||
}
|
||||
|
||||
// 5. Sync counts.
|
||||
syncCounts(db, client, account)
|
||||
|
||||
result := &RedeemResult{AccountEmail: account.Email}
|
||||
if inviteErr != nil {
|
||||
result.InviteOK = false
|
||||
result.Message = fmt.Sprintf("兑换成功,但邀请发送失败: %v\n请联系管理员手动添加", inviteErr)
|
||||
} else {
|
||||
result.InviteOK = true
|
||||
result.Message = "兑换成功!邀请邮件已发送到您的邮箱,请查收。"
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateCode creates a random code in XXXX-XXXX-XXXX format.
|
||||
func GenerateCode() string {
|
||||
parts := make([]byte, 12)
|
||||
for i := range parts {
|
||||
parts[i] = codeChars[rand.Intn(len(codeChars))]
|
||||
}
|
||||
return fmt.Sprintf("%s-%s-%s", string(parts[0:4]), string(parts[4:8]), string(parts[8:12]))
|
||||
}
|
||||
|
||||
// GenerateCodes creates n unique codes.
|
||||
func GenerateCodes(n int) []string {
|
||||
seen := make(map[string]bool)
|
||||
var codes []string
|
||||
for len(codes) < n {
|
||||
c := GenerateCode()
|
||||
if !seen[c] {
|
||||
seen[c] = true
|
||||
codes = append(codes, c)
|
||||
}
|
||||
}
|
||||
return codes
|
||||
}
|
||||
|
||||
// syncCounts updates user_count and invite_count from the ChatGPT API.
|
||||
func syncCounts(db *database.DB, client *chatgpt.Client, account *model.GptAccount) {
|
||||
userTotal, _, err := client.GetUsers(account)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
inviteTotal, _, err := client.GetInvites(account)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = db.UpdateAccountCounts(account.ID, userTotal, inviteTotal)
|
||||
}
|
||||
96
internal/scheduler/scheduler.go
Normal file
96
internal/scheduler/scheduler.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go-helper/internal/chatgpt"
|
||||
"go-helper/internal/database"
|
||||
)
|
||||
|
||||
// StartTokenChecker runs a periodic loop that refreshes access tokens
|
||||
// for all accounts that have a refresh token.
|
||||
func StartTokenChecker(db *database.DB, client *chatgpt.Client, intervalMinutes int) {
|
||||
if intervalMinutes <= 0 {
|
||||
intervalMinutes = 30
|
||||
}
|
||||
interval := time.Duration(intervalMinutes) * time.Minute
|
||||
log.Printf("[Scheduler] Token 定时检测已启动,间隔 %d 分钟", intervalMinutes)
|
||||
|
||||
go func() {
|
||||
// Run once immediately at startup.
|
||||
checkAndRefreshAll(db, client)
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
checkAndRefreshAll(db, client)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func checkAndRefreshAll(db *database.DB, client *chatgpt.Client) {
|
||||
accounts, err := db.GetAccountsWithRT()
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] 获取账号列表失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(accounts) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[Scheduler] 开始检查 %d 个账号的 Token 状态", len(accounts))
|
||||
|
||||
refreshed, failed, banned := 0, 0, 0
|
||||
for _, acc := range accounts {
|
||||
result, err := client.RefreshAccessToken(acc.RefreshToken)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] 刷新失败 [ID=%d %s]: %v", acc.ID, acc.Email, err)
|
||||
failed++
|
||||
|
||||
// Check if the error indicates token is completely invalid.
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
if strings.Contains(errMsg, "invalid_grant") ||
|
||||
strings.Contains(errMsg, "token has been revoked") ||
|
||||
strings.Contains(errMsg, "unauthorized") {
|
||||
log.Printf("[Scheduler] RT 无效,标记封号 [ID=%d %s]", acc.ID, acc.Email)
|
||||
_ = db.BanAccount(acc.ID)
|
||||
banned++
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err := db.UpdateAccountTokens(acc.ID, result.AccessToken, result.RefreshToken); err != nil {
|
||||
log.Printf("[Scheduler] 更新 Token 失败 [ID=%d]: %v", acc.ID, err)
|
||||
failed++
|
||||
continue
|
||||
}
|
||||
|
||||
// Also try to fetch account info to update subscription expiry.
|
||||
infos, err := client.FetchAccountInfo(result.AccessToken)
|
||||
if err != nil {
|
||||
// Token works but account info fetch failed — might be account_deactivated.
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
if strings.Contains(errMsg, "account_deactivated") || strings.Contains(errMsg, "已停用") {
|
||||
log.Printf("[Scheduler] 账号已停用,标记封号 [ID=%d %s]", acc.ID, acc.Email)
|
||||
_ = db.BanAccount(acc.ID)
|
||||
banned++
|
||||
continue
|
||||
}
|
||||
} else if len(infos) > 0 {
|
||||
// Update expire_at from subscription info.
|
||||
for _, info := range infos {
|
||||
if info.AccountID == acc.ChatgptAccountID && info.ExpiresAt != "" {
|
||||
_ = db.UpdateAccountInfo(acc.ID, acc.ChatgptAccountID, info.ExpiresAt)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
refreshed++
|
||||
}
|
||||
|
||||
log.Printf("[Scheduler] 检查完成: 刷新成功 %d, 失败 %d, 封号 %d", refreshed, failed, banned)
|
||||
}
|
||||
Reference in New Issue
Block a user