feat: 初始化 ChatGPT Team 管理机器人
核心功能: - 实现基于 Telegram Inline Button 交互的后台面板与用户端 - 支持通过账密登录和 RT (Refresh Token) 方式添加 ChatGPT Team 账号 - 支持管理、拉取和删除待处理邀请,支持一键清空多余邀请 - 支持按剩余容量自动生成邀请兑换码,支持分页查看与一键清空未使用兑换码 - 随机邀请功能:成功拉人后自动核销兑换码 - 定时检测 Token 状态,实现自动续订/刷新并拦截封禁账号 (处理 401/402 错误) 系统与配置: - 使用 PostgreSQL 数据库管理账号、邀请和兑换记录 - 支持在端内动态添加、移除管理员 - 完善 Docker 部署配置与 .gitignore 规则
This commit is contained in:
111
internal/chatgpt/account.go
Normal file
111
internal/chatgpt/account.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package chatgpt
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"go-helper/internal/model"
|
||||
)
|
||||
|
||||
// FetchAccountInfo queries the ChatGPT accounts check API and returns team accounts
|
||||
// with subscription expiry information.
|
||||
func (c *Client) FetchAccountInfo(accessToken string) ([]model.TeamAccountInfo, error) {
|
||||
token := strings.TrimPrefix(strings.TrimSpace(accessToken), "Bearer ")
|
||||
if token == "" {
|
||||
return nil, fmt.Errorf("缺少 access token")
|
||||
}
|
||||
|
||||
apiURL := "https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27"
|
||||
|
||||
req, err := http.NewRequest("GET", apiURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", "*/*")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
req.Header.Set("Oai-Client-Version", oaiClientVersion)
|
||||
req.Header.Set("Oai-Language", "zh-CN")
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求 ChatGPT API 失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode == 401 || resp.StatusCode == 402 {
|
||||
return nil, fmt.Errorf("Token 已过期或被封禁")
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("ChatGPT API 错误 %d: %s", resp.StatusCode, truncate(string(body), 300))
|
||||
}
|
||||
|
||||
var data struct {
|
||||
Accounts map[string]json.RawMessage `json:"accounts"`
|
||||
AccountOrdering []string `json:"account_ordering"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
var results []model.TeamAccountInfo
|
||||
|
||||
// Determine order.
|
||||
seen := make(map[string]bool)
|
||||
var orderedIDs []string
|
||||
for _, id := range data.AccountOrdering {
|
||||
if _, ok := data.Accounts[id]; ok && !seen[id] {
|
||||
orderedIDs = append(orderedIDs, id)
|
||||
seen[id] = true
|
||||
}
|
||||
}
|
||||
for id := range data.Accounts {
|
||||
if !seen[id] && id != "default" {
|
||||
orderedIDs = append(orderedIDs, id)
|
||||
}
|
||||
}
|
||||
|
||||
for _, id := range orderedIDs {
|
||||
raw := data.Accounts[id]
|
||||
var acc struct {
|
||||
Account struct {
|
||||
Name string `json:"name"`
|
||||
PlanType string `json:"plan_type"`
|
||||
} `json:"account"`
|
||||
Entitlement struct {
|
||||
ExpiresAt string `json:"expires_at"`
|
||||
HasActiveSubscription bool `json:"has_active_subscription"`
|
||||
} `json:"entitlement"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &acc); err != nil {
|
||||
continue
|
||||
}
|
||||
if acc.Account.PlanType != "team" {
|
||||
continue
|
||||
}
|
||||
results = append(results, model.TeamAccountInfo{
|
||||
AccountID: id,
|
||||
Name: acc.Account.Name,
|
||||
PlanType: acc.Account.PlanType,
|
||||
ExpiresAt: acc.Entitlement.ExpiresAt,
|
||||
HasActiveSubscription: acc.Entitlement.HasActiveSubscription,
|
||||
})
|
||||
}
|
||||
|
||||
if len(results) == 0 {
|
||||
return nil, fmt.Errorf("未找到 Team 类型的账号")
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
return s[:max] + "..."
|
||||
}
|
||||
88
internal/chatgpt/client.go
Normal file
88
internal/chatgpt/client.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package chatgpt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
const (
|
||||
oaiClientVersion = "prod-eddc2f6ff65fee2d0d6439e379eab94fe3047f72"
|
||||
userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/142.0.0.0 Safari/537.36"
|
||||
openaiClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
)
|
||||
|
||||
// Client wraps HTTP operations with common headers and optional proxy.
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
proxyURL string
|
||||
}
|
||||
|
||||
// NewClient creates a ChatGPT API client with optional proxy support.
|
||||
func NewClient(proxyURL string) *Client {
|
||||
c := &Client{proxyURL: proxyURL}
|
||||
c.httpClient = c.buildHTTPClient()
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Client) buildHTTPClient() *http.Client {
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12},
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
}
|
||||
|
||||
if c.proxyURL != "" {
|
||||
parsed, err := url.Parse(c.proxyURL)
|
||||
if err == nil {
|
||||
scheme := parsed.Scheme
|
||||
if scheme == "socks5" || scheme == "socks5h" {
|
||||
var auth *proxy.Auth
|
||||
if parsed.User != nil {
|
||||
pass, _ := parsed.User.Password()
|
||||
auth = &proxy.Auth{
|
||||
User: parsed.User.Username(),
|
||||
Password: pass,
|
||||
}
|
||||
}
|
||||
dialer, dErr := proxy.SOCKS5("tcp", parsed.Host, auth, proxy.Direct)
|
||||
if dErr == nil {
|
||||
if ctxDialer, ok := dialer.(proxy.ContextDialer); ok {
|
||||
transport.DialContext = ctxDialer.DialContext
|
||||
} else {
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
transport.Proxy = http.ProxyURL(parsed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 60 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// buildHeaders returns common headers for ChatGPT backend API calls.
|
||||
func buildHeaders(token, chatgptAccountID string) http.Header {
|
||||
h := http.Header{}
|
||||
h.Set("Accept", "*/*")
|
||||
h.Set("Accept-Language", "zh-CN,zh;q=0.9")
|
||||
h.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
h.Set("Chatgpt-Account-Id", chatgptAccountID)
|
||||
h.Set("Oai-Client-Version", oaiClientVersion)
|
||||
h.Set("Oai-Language", "zh-CN")
|
||||
h.Set("User-Agent", userAgent)
|
||||
h.Set("Origin", "https://chatgpt.com")
|
||||
h.Set("Referer", "https://chatgpt.com/admin/members")
|
||||
return h
|
||||
}
|
||||
93
internal/chatgpt/invite.go
Normal file
93
internal/chatgpt/invite.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package chatgpt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go-helper/internal/model"
|
||||
)
|
||||
|
||||
const (
|
||||
maxInviteAttempts = 3
|
||||
retryBaseDelay = 800 * time.Millisecond
|
||||
retryMaxDelay = 5 * time.Second
|
||||
)
|
||||
|
||||
// InviteUser sends an invitation to the given email on the specified team account.
|
||||
func (c *Client) InviteUser(email string, account *model.GptAccount) error {
|
||||
email = strings.TrimSpace(strings.ToLower(email))
|
||||
if email == "" {
|
||||
return fmt.Errorf("缺少邀请邮箱")
|
||||
}
|
||||
if account.Token == "" || account.ChatgptAccountID == "" {
|
||||
return fmt.Errorf("账号配置不完整")
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("https://chatgpt.com/backend-api/accounts/%s/invites", account.ChatgptAccountID)
|
||||
payload := map[string]interface{}{
|
||||
"email_addresses": []string{email},
|
||||
"role": "standard-user",
|
||||
"resend_emails": true,
|
||||
}
|
||||
bodyBytes, _ := json.Marshal(payload)
|
||||
|
||||
var lastErr error
|
||||
for attempt := 1; attempt <= maxInviteAttempts; attempt++ {
|
||||
req, err := http.NewRequest("POST", apiURL, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
headers := buildHeaders(account.Token, account.ChatgptAccountID)
|
||||
req.Header = headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if account.OaiDeviceID != "" {
|
||||
req.Header.Set("Oai-Device-Id", account.OaiDeviceID)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("网络错误: %w", err)
|
||||
log.Printf("[Invite] 尝试 %d/%d 网络错误: %v", attempt, maxInviteAttempts, err)
|
||||
if attempt < maxInviteAttempts {
|
||||
sleepRetry(attempt)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
log.Printf("[Invite] 成功邀请 %s 到账号 %s", email, account.ChatgptAccountID)
|
||||
return nil
|
||||
}
|
||||
|
||||
lastErr = fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncate(string(respBody), 300))
|
||||
log.Printf("[Invite] 尝试 %d/%d 失败: %v", attempt, maxInviteAttempts, lastErr)
|
||||
|
||||
if !isRetryableStatus(resp.StatusCode) || attempt >= maxInviteAttempts {
|
||||
break
|
||||
}
|
||||
sleepRetry(attempt)
|
||||
}
|
||||
return fmt.Errorf("邀请失败(已尝试 %d 次): %v", maxInviteAttempts, lastErr)
|
||||
}
|
||||
|
||||
func sleepRetry(attempt int) {
|
||||
delay := float64(retryBaseDelay) * math.Pow(2, float64(attempt-1))
|
||||
if delay > float64(retryMaxDelay) {
|
||||
delay = float64(retryMaxDelay)
|
||||
}
|
||||
time.Sleep(time.Duration(delay))
|
||||
}
|
||||
|
||||
func isRetryableStatus(status int) bool {
|
||||
return status == 408 || status == 429 || (status >= 500 && status <= 599)
|
||||
}
|
||||
195
internal/chatgpt/member.go
Normal file
195
internal/chatgpt/member.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package chatgpt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"go-helper/internal/model"
|
||||
)
|
||||
|
||||
// GetUsers fetches the list of members for a team account.
|
||||
func (c *Client) GetUsers(account *model.GptAccount) (int, []model.ChatGPTUser, error) {
|
||||
apiURL := fmt.Sprintf("https://chatgpt.com/backend-api/accounts/%s/users?offset=0&limit=100&query=",
|
||||
account.ChatgptAccountID)
|
||||
|
||||
req, err := http.NewRequest("GET", apiURL, nil)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
req.Header = buildHeaders(account.Token, account.ChatgptAccountID)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("获取成员列表网络错误: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if err := checkAPIError(resp.StatusCode, body, "获取成员"); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
var data struct {
|
||||
Total int `json:"total"`
|
||||
Items []struct {
|
||||
ID string `json:"id"`
|
||||
AccountUserID string `json:"account_user_id"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
Name string `json:"name"`
|
||||
} `json:"items"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
return 0, nil, fmt.Errorf("解析成员数据失败: %w", err)
|
||||
}
|
||||
|
||||
var users []model.ChatGPTUser
|
||||
for _, item := range data.Items {
|
||||
users = append(users, model.ChatGPTUser{
|
||||
ID: item.ID,
|
||||
AccountUserID: item.AccountUserID,
|
||||
Email: item.Email,
|
||||
Role: item.Role,
|
||||
Name: item.Name,
|
||||
})
|
||||
}
|
||||
return data.Total, users, nil
|
||||
}
|
||||
|
||||
// DeleteUser removes a user from the team account.
|
||||
func (c *Client) DeleteUser(account *model.GptAccount, userID string) error {
|
||||
normalizedID := userID
|
||||
if !strings.HasPrefix(normalizedID, "user-") {
|
||||
normalizedID = "user-" + normalizedID
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("https://chatgpt.com/backend-api/accounts/%s/users/%s",
|
||||
account.ChatgptAccountID, normalizedID)
|
||||
|
||||
req, err := http.NewRequest("DELETE", apiURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header = buildHeaders(account.Token, account.ChatgptAccountID)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除用户网络错误: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if err := checkAPIError(resp.StatusCode, body, "删除用户"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("[Member] 成功删除用户 %s (账号 %s)", normalizedID, account.ChatgptAccountID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetInvites fetches pending invitations for a team account.
|
||||
func (c *Client) GetInvites(account *model.GptAccount) (int, []model.ChatGPTInvite, error) {
|
||||
apiURL := fmt.Sprintf("https://chatgpt.com/backend-api/accounts/%s/invites?offset=0&limit=100&query=",
|
||||
account.ChatgptAccountID)
|
||||
|
||||
req, err := http.NewRequest("GET", apiURL, nil)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
req.Header = buildHeaders(account.Token, account.ChatgptAccountID)
|
||||
req.Header.Set("Referer", "https://chatgpt.com/admin/members?tab=invites")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("获取邀请列表网络错误: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if err := checkAPIError(resp.StatusCode, body, "获取邀请列表"); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
var data struct {
|
||||
Total int `json:"total"`
|
||||
Items []struct {
|
||||
ID string `json:"id"`
|
||||
EmailAddress string `json:"email_address"`
|
||||
Role string `json:"role"`
|
||||
} `json:"items"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
return 0, nil, fmt.Errorf("解析邀请数据失败: %w", err)
|
||||
}
|
||||
|
||||
var invites []model.ChatGPTInvite
|
||||
for _, item := range data.Items {
|
||||
invites = append(invites, model.ChatGPTInvite{
|
||||
ID: item.ID,
|
||||
EmailAddress: item.EmailAddress,
|
||||
Role: item.Role,
|
||||
})
|
||||
}
|
||||
return data.Total, invites, nil
|
||||
}
|
||||
|
||||
// DeleteInvite cancels a pending invitation by email on the team account.
|
||||
func (c *Client) DeleteInvite(account *model.GptAccount, email string) error {
|
||||
apiURL := fmt.Sprintf("https://chatgpt.com/backend-api/accounts/%s/invites",
|
||||
account.ChatgptAccountID)
|
||||
|
||||
payload := map[string]string{"email_address": strings.TrimSpace(strings.ToLower(email))}
|
||||
bodyBytes, _ := json.Marshal(payload)
|
||||
|
||||
req, err := http.NewRequest("DELETE", apiURL, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header = buildHeaders(account.Token, account.ChatgptAccountID)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Referer", "https://chatgpt.com/admin/members?tab=invites")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除邀请网络错误: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if err := checkAPIError(resp.StatusCode, body, "删除邀请"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("[Member] 成功删除邀请 %s (账号 %s)", email, account.ChatgptAccountID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkAPIError inspects the HTTP status code and returns a descriptive error.
|
||||
func checkAPIError(statusCode int, body []byte, label string) error {
|
||||
if statusCode >= 200 && statusCode < 300 {
|
||||
return nil
|
||||
}
|
||||
|
||||
bodyStr := truncate(string(body), 500)
|
||||
|
||||
// Check for account_deactivated.
|
||||
if strings.Contains(bodyStr, "account_deactivated") {
|
||||
return fmt.Errorf("账号已停用 (account_deactivated)")
|
||||
}
|
||||
|
||||
switch statusCode {
|
||||
case 401, 402:
|
||||
return fmt.Errorf("Token 已过期或被封禁")
|
||||
case 403:
|
||||
return fmt.Errorf("资源不存在或无权访问")
|
||||
case 429:
|
||||
return fmt.Errorf("API 请求过于频繁,请稍后重试")
|
||||
default:
|
||||
return fmt.Errorf("%s失败 (HTTP %d): %s", label, statusCode, bodyStr)
|
||||
}
|
||||
}
|
||||
233
internal/chatgpt/oauth.go
Normal file
233
internal/chatgpt/oauth.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package chatgpt
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OAuthSession holds PKCE session data for an in-progress OAuth login.
|
||||
type OAuthSession struct {
|
||||
CodeVerifier string
|
||||
State string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// OAuthResult holds the tokens and account info after a successful OAuth exchange.
|
||||
type OAuthResult struct {
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
Email string
|
||||
AccountID string
|
||||
Name string
|
||||
PlanType string
|
||||
}
|
||||
|
||||
// OAuthManager manages PKCE sessions for the manual URL-paste OAuth flow.
|
||||
type OAuthManager struct {
|
||||
client *Client
|
||||
sessions map[string]*OAuthSession // state -> session
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewOAuthManager creates a new OAuth manager.
|
||||
func NewOAuthManager(client *Client) *OAuthManager {
|
||||
return &OAuthManager{
|
||||
client: client,
|
||||
sessions: make(map[string]*OAuthSession),
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateAuthURL creates an OpenAI OAuth authorization URL with PKCE.
|
||||
// The redirect_uri points to localhost which won't load — user copies the URL from browser.
|
||||
func (m *OAuthManager) GenerateAuthURL() (authURL, state string, err error) {
|
||||
// Generate PKCE code verifier.
|
||||
verifierBytes := make([]byte, 64)
|
||||
if _, err := rand.Read(verifierBytes); err != nil {
|
||||
return "", "", fmt.Errorf("生成 PKCE 失败: %w", err)
|
||||
}
|
||||
codeVerifier := hex.EncodeToString(verifierBytes)
|
||||
|
||||
hash := sha256.Sum256([]byte(codeVerifier))
|
||||
codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
|
||||
// Generate state.
|
||||
stateBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(stateBytes); err != nil {
|
||||
return "", "", fmt.Errorf("生成 state 失败: %w", err)
|
||||
}
|
||||
state = hex.EncodeToString(stateBytes)
|
||||
|
||||
// Store session.
|
||||
m.mu.Lock()
|
||||
m.sessions[state] = &OAuthSession{
|
||||
CodeVerifier: codeVerifier,
|
||||
State: state,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
// Build auth URL — redirect to localhost, user will copy the URL.
|
||||
redirectURI := "http://localhost:1455/auth/callback"
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", openaiClientID)
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("scope", "openid profile email offline_access")
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
params.Set("state", state)
|
||||
params.Set("id_token_add_organizations", "true")
|
||||
params.Set("codex_cli_simplified_flow", "true")
|
||||
|
||||
authURL = fmt.Sprintf("https://auth.openai.com/oauth/authorize?%s", params.Encode())
|
||||
return authURL, state, nil
|
||||
}
|
||||
|
||||
// ExchangeCallbackURL parses the pasted callback URL to extract code and state,
|
||||
// then exchanges the authorization code for tokens.
|
||||
func (m *OAuthManager) ExchangeCallbackURL(callbackURL string) (*OAuthResult, error) {
|
||||
callbackURL = strings.TrimSpace(callbackURL)
|
||||
|
||||
parsed, err := url.Parse(callbackURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("URL 格式错误: %w", err)
|
||||
}
|
||||
|
||||
code := parsed.Query().Get("code")
|
||||
state := parsed.Query().Get("state")
|
||||
|
||||
if code == "" {
|
||||
errMsg := parsed.Query().Get("error_description")
|
||||
if errMsg == "" {
|
||||
errMsg = parsed.Query().Get("error")
|
||||
}
|
||||
if errMsg == "" {
|
||||
errMsg = "回调 URL 中未找到 code 参数"
|
||||
}
|
||||
return nil, fmt.Errorf("%s", errMsg)
|
||||
}
|
||||
|
||||
if state == "" {
|
||||
return nil, fmt.Errorf("回调 URL 中未找到 state 参数")
|
||||
}
|
||||
|
||||
// Look up the session.
|
||||
m.mu.Lock()
|
||||
session, ok := m.sessions[state]
|
||||
if ok {
|
||||
delete(m.sessions, state)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("会话已过期或无效,请重新使用 /login")
|
||||
}
|
||||
|
||||
// Check if session is expired (10 minutes).
|
||||
if time.Since(session.CreatedAt) > 10*time.Minute {
|
||||
return nil, fmt.Errorf("登录会话已过期(超过10分钟),请重新使用 /login")
|
||||
}
|
||||
|
||||
// Exchange code for tokens.
|
||||
return m.exchangeCode(code, session.CodeVerifier)
|
||||
}
|
||||
|
||||
func (m *OAuthManager) exchangeCode(code, codeVerifier string) (*OAuthResult, error) {
|
||||
redirectURI := "http://localhost:1455/auth/callback"
|
||||
|
||||
form := url.Values{}
|
||||
form.Set("grant_type", "authorization_code")
|
||||
form.Set("code", code)
|
||||
form.Set("redirect_uri", redirectURI)
|
||||
form.Set("client_id", openaiClientID)
|
||||
form.Set("code_verifier", codeVerifier)
|
||||
|
||||
req, err := http.NewRequest("POST", "https://auth.openai.com/oauth/token", strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := m.client.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("网络错误: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("交换授权码失败 (HTTP %d): %s", resp.StatusCode, truncate(string(body), 300))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("解析 token 响应失败: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken == "" {
|
||||
return nil, fmt.Errorf("未返回有效的 access token")
|
||||
}
|
||||
|
||||
result := &OAuthResult{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
}
|
||||
|
||||
// Decode ID token for user info.
|
||||
if tokenResp.IDToken != "" {
|
||||
claims, err := decodeJWTPayload(tokenResp.IDToken)
|
||||
if err == nil {
|
||||
result.Email, _ = claims["email"].(string)
|
||||
result.Name, _ = claims["name"].(string)
|
||||
if authClaims, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
|
||||
result.AccountID, _ = authClaims["chatgpt_account_id"].(string)
|
||||
result.PlanType, _ = authClaims["chatgpt_plan_type"].(string)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func decodeJWTPayload(token string) (map[string]interface{}, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT format")
|
||||
}
|
||||
|
||||
payload := parts[1]
|
||||
switch len(payload) % 4 {
|
||||
case 2:
|
||||
payload += "=="
|
||||
case 3:
|
||||
payload += "="
|
||||
}
|
||||
|
||||
decoded, err := base64.URLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
decoded, err = base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
109
internal/chatgpt/token.go
Normal file
109
internal/chatgpt/token.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package chatgpt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// TokenResult holds the result of a token refresh.
|
||||
type TokenResult struct {
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
IDToken string
|
||||
ExpiresIn int
|
||||
}
|
||||
|
||||
// GetEmail extracts the user's email from the ID token or Access token.
|
||||
func (tr *TokenResult) GetEmail() string {
|
||||
if tr.IDToken != "" {
|
||||
claims, err := decodeJWTPayload(tr.IDToken)
|
||||
if err == nil {
|
||||
if email, ok := claims["email"].(string); ok && email != "" {
|
||||
return email
|
||||
}
|
||||
}
|
||||
}
|
||||
if tr.AccessToken != "" {
|
||||
claims, err := decodeJWTPayload(tr.AccessToken)
|
||||
if err == nil {
|
||||
if email, ok := claims["email"].(string); ok && email != "" {
|
||||
return email
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// RefreshAccessToken exchanges a refresh token for a new access token.
|
||||
func (c *Client) RefreshAccessToken(refreshToken string) (*TokenResult, error) {
|
||||
rt := strings.TrimSpace(refreshToken)
|
||||
if rt == "" {
|
||||
return nil, fmt.Errorf("refresh token 为空")
|
||||
}
|
||||
|
||||
form := url.Values{}
|
||||
form.Set("grant_type", "refresh_token")
|
||||
form.Set("client_id", openaiClientID)
|
||||
form.Set("refresh_token", rt)
|
||||
form.Set("scope", "openid profile email")
|
||||
|
||||
req, err := http.NewRequest("POST", "https://auth.openai.com/oauth/token", strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("刷新 token 网络错误: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
var errData struct {
|
||||
Error string `json:"error"`
|
||||
ErrorDescription string `json:"error_description"`
|
||||
}
|
||||
_ = json.Unmarshal(body, &errData)
|
||||
msg := errData.ErrorDescription
|
||||
if msg == "" {
|
||||
msg = errData.Error
|
||||
}
|
||||
if msg == "" {
|
||||
msg = fmt.Sprintf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
return nil, fmt.Errorf("刷新 token 失败: %s", msg)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("解析刷新结果失败: %w", err)
|
||||
}
|
||||
if result.AccessToken == "" {
|
||||
return nil, fmt.Errorf("刷新 token 失败: 未返回有效凭证")
|
||||
}
|
||||
|
||||
newRT := result.RefreshToken
|
||||
if newRT == "" {
|
||||
newRT = rt
|
||||
}
|
||||
|
||||
return &TokenResult{
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: newRT,
|
||||
IDToken: result.IDToken,
|
||||
ExpiresIn: result.ExpiresIn,
|
||||
}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user