Files
codexautopool/backend/internal/demote/demote.go

293 lines
6.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package demote
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"strings"
"codex-pool/internal/client"
)
// DemoteRequest is the input accepted by DemoteOwner.
type DemoteRequest struct {
	// AccessToken is either a bare JWT access token or a complete ChatGPT
	// session JSON object.
	AccessToken string `json:"access_token"`
	// AccountID optionally names the team account. When empty, DemoteOwner
	// derives it from the token or session.
	AccountID string `json:"account_id"`
	// Role is the target role: "standard-user" or "account-admin".
	Role string `json:"role"`
	// Proxy is optional and is handed to the TLS client constructor.
	Proxy string `json:"proxy"`
}
// DemoteResult reports the outcome of DemoteOwner. Optional fields that are
// empty are omitted from the JSON encoding.
type DemoteResult struct {
	Success bool   `json:"success"`
	Email   string `json:"email,omitempty"`
	// OriginalRole is declared for API consumers; DemoteOwner does not set it.
	OriginalRole string `json:"original_role,omitempty"`
	NewRole      string `json:"new_role,omitempty"`
	// Message is the human-readable success text.
	Message string `json:"message,omitempty"`
	// Error is the human-readable failure reason.
	Error string `json:"error,omitempty"`
}
// SessionData is the subset of a ChatGPT session JSON object that this
// package reads; any other keys in the JSON are ignored when decoding.
type SessionData struct {
	AccessToken string `json:"accessToken"`

	// User identifies the signed-in user.
	User struct {
		ID    string `json:"id"`
		Email string `json:"email"`
	} `json:"user"`

	// Account identifies the account the session is bound to.
	Account struct {
		ID string `json:"id"`
	} `json:"account"`
}
// JWTPayload holds the OpenAI-specific claims read from the payload segment
// of an access token. Only the claims this package uses are declared.
type JWTPayload struct {
	// Auth is the "https://api.openai.com/auth" claim.
	Auth struct {
		// AccountUserID may have the form "<userID>__<accountID>"; a value
		// without "__" is treated as a bare user ID by extractUserInfo.
		AccountUserID string `json:"chatgpt_account_user_id"`
	} `json:"https://api.openai.com/auth"`

	// Profile is the "https://api.openai.com/profile" claim.
	Profile struct {
		Email string `json:"email"`
	} `json:"https://api.openai.com/profile"`
}
// decodeJWTPayload decodes the claims (second) segment of a JWT into a
// JWTPayload. The signature segment is never checked: the result is only
// used to read identifying claims, not to authenticate anything.
//
// The segment is normally unpadded base64url, so "=" padding is restored
// before decoding; standard base64 is tried as a fallback. Decode and JSON
// errors are wrapped with %w so callers can inspect the underlying cause.
func decodeJWTPayload(token string) (*JWTPayload, error) {
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		return nil, fmt.Errorf("invalid JWT format")
	}

	seg := parts[1]
	// base64.URLEncoding requires padding; JWT segments omit it.
	if rem := len(seg) % 4; rem != 0 {
		seg += strings.Repeat("=", 4-rem)
	}

	raw, err := base64.URLEncoding.DecodeString(seg)
	if err != nil {
		// Some producers emit standard-alphabet base64 ('+', '/').
		raw, err = base64.StdEncoding.DecodeString(seg)
		if err != nil {
			return nil, fmt.Errorf("failed to decode JWT payload: %w", err)
		}
	}

	var claims JWTPayload
	if err := json.Unmarshal(raw, &claims); err != nil {
		return nil, fmt.Errorf("failed to parse JWT payload: %w", err)
	}
	return &claims, nil
}
// extractUserInfo determines the user ID, account ID and email for the
// role-change request.
//
// Values from session (a decoded session JSON; may be nil) take precedence.
// Anything still missing is filled in from the JWT payload of token:
//   - chatgpt_account_user_id is split on its first "__" as
//     "<userID>__<accountID>"; without "__" the whole value is the user ID;
//   - the profile email is used when the session gave no email.
//
// The JWT is consulted whenever any of the three values is still unknown,
// so a session lacking an email no longer yields an empty email. Values
// that cannot be determined are returned as "".
func extractUserInfo(token string, session *SessionData) (userID, accountID, email string) {
	if session != nil {
		userID = session.User.ID
		accountID = session.Account.ID
		email = session.User.Email
		// Everything is already known; skip decoding the JWT.
		if userID != "" && accountID != "" && email != "" {
			return userID, accountID, email
		}
	}

	payload, err := decodeJWTPayload(token)
	if err != nil {
		// Not a decodable JWT: return whatever the session provided.
		return userID, accountID, email
	}

	if accountUserID := payload.Auth.AccountUserID; accountUserID != "" {
		// With no "__", jwtAccount is "" and accountID is left unchanged.
		jwtUser, jwtAccount, _ := strings.Cut(accountUserID, "__")
		if userID == "" {
			userID = jwtUser
		}
		if accountID == "" {
			accountID = jwtAccount
		}
	}
	if email == "" {
		email = payload.Profile.Email
	}
	return userID, accountID, email
}
// sessionInitAttempts is how many fresh TLS clients openWarmSession tries
// before giving up on loading the chatgpt.com front page.
const sessionInitAttempts = 3

// DemoteOwner changes the role of the user who owns req.AccessToken inside a
// ChatGPT team account, by sending
// PATCH https://chatgpt.com/backend-api/accounts/{accountID}/users/{userID}.
//
// req.AccessToken may be a bare JWT access token or a full session JSON
// object (recognised by a leading "{"); the user ID, account ID and email are
// derived by extractUserInfo. A non-empty req.AccountID overrides the account
// ID found in the token. req.Role must be "standard-user" or "account-admin".
//
// The returned result is never nil. On failure Success is false and Error
// carries a user-facing message; Email is filled in once it is known.
func DemoteOwner(req DemoteRequest) *DemoteResult {
	switch req.Role {
	case "standard-user", "account-admin":
		// valid target role
	default:
		return &DemoteResult{
			Success: false,
			Error:   "无效的角色,必须是 standard-user 或 account-admin",
		}
	}

	accessToken, session, err := parseTokenInput(req.AccessToken)
	if err != nil {
		return &DemoteResult{Success: false, Error: err.Error()}
	}

	userID, accountID, email := extractUserInfo(accessToken, session)
	// An explicitly requested account_id takes precedence over the token's.
	if id := strings.TrimSpace(req.AccountID); id != "" {
		accountID = id
	}
	if userID == "" {
		return &DemoteResult{
			Success: false,
			Error:   "无法获取 user_id请提供完整的 session JSON",
		}
	}
	if accountID == "" {
		return &DemoteResult{
			Success: false,
			Error:   "无法获取 account_id请提供完整的 session JSON 或指定 account_id",
		}
	}
	// Both IDs are spliced into the URL path and come from untrusted input;
	// refuse values that could break out of their segment (e.g. "../", "?")
	// and silently target a different endpoint.
	if !isSafePathSegment(userID) || !isSafePathSegment(accountID) {
		return &DemoteResult{
			Success: false,
			Email:   email,
			Error:   "user_id 或 account_id 包含非法字符",
		}
	}

	tlsClient, err := openWarmSession(req.Proxy)
	if err != nil {
		return &DemoteResult{
			Success: false,
			Email:   email,
			Error:   fmt.Sprintf("初始化会话失败已重试%d次: %v", sessionInitAttempts, err),
		}
	}
	defer tlsClient.Close()

	apiURL := "https://chatgpt.com/backend-api/accounts/" + accountID + "/users/" + userID
	// Marshalling a map[string]string cannot fail.
	jsonBody, _ := json.Marshal(map[string]string{"role": req.Role})

	patchReq, err := http.NewRequest(http.MethodPatch, apiURL, bytes.NewReader(jsonBody))
	if err != nil {
		return &DemoteResult{
			Success: false,
			Email:   email,
			Error:   "创建请求失败: " + err.Error(),
		}
	}
	patchReq.Header.Set("Authorization", "Bearer "+accessToken)
	patchReq.Header.Set("Content-Type", "application/json")
	patchReq.Header.Set("Accept", "*/*")
	patchReq.Header.Set("Origin", "https://chatgpt.com")
	patchReq.Header.Set("Referer", "https://chatgpt.com/")

	resp, err := tlsClient.Do(patchReq)
	if err != nil {
		return &DemoteResult{
			Success: false,
			Email:   email,
			Error:   "请求失败: " + err.Error(),
		}
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// The body only enriches the error message, so its read error is
		// deliberately ignored.
		body, _ := client.ReadBodyString(resp)
		return &DemoteResult{
			Success: false,
			Email:   email,
			Error:   fmt.Sprintf("HTTP %d: %s", resp.StatusCode, truncateStr(body, 200)),
		}
	}

	roleDisplay := "普通成员"
	if req.Role == "account-admin" {
		roleDisplay = "管理员"
	}
	return &DemoteResult{
		Success: true,
		Email:   email,
		NewRole: req.Role,
		Message: fmt.Sprintf("成功降级为%s", roleDisplay),
	}
}

// parseTokenInput trims raw and, when it looks like a session JSON object
// (leading "{"), decodes it and returns its accessToken together with the
// decoded session. Otherwise the trimmed raw value is the token and session
// is nil. Returned error texts are user-facing.
func parseTokenInput(raw string) (string, *SessionData, error) {
	token := strings.TrimSpace(raw)
	if !strings.HasPrefix(token, "{") {
		return token, nil, nil
	}
	var session SessionData
	if err := json.Unmarshal([]byte(token), &session); err != nil {
		return "", nil, fmt.Errorf("无法解析 session JSON: %w", err)
	}
	if session.AccessToken == "" {
		return "", nil, fmt.Errorf("session JSON 中没有 accessToken")
	}
	return session.AccessToken, &session, nil
}

// openWarmSession creates a TLS client and warms it up by fetching
// https://chatgpt.com (to get past Cloudflare), using a fresh client for each
// of up to sessionInitAttempts attempts. On success the caller owns the
// returned client and must Close it; on failure the last error is returned.
func openWarmSession(proxy string) (*client.TLSClient, error) {
	var lastErr error
	for attempt := 0; attempt < sessionInitAttempts; attempt++ {
		c, err := client.New(proxy)
		if err != nil {
			lastErr = err
			continue
		}
		resp, err := c.Get("https://chatgpt.com")
		if err != nil {
			lastErr = err
			c.Close()
			continue
		}
		resp.Body.Close()
		switch resp.StatusCode {
		case http.StatusOK:
			return c, nil
		case http.StatusForbidden:
			lastErr = fmt.Errorf("Cloudflare 403")
		default:
			lastErr = fmt.Errorf("HTTP %d", resp.StatusCode)
		}
		c.Close()
	}
	return nil, lastErr
}

// isSafePathSegment reports whether id can be placed verbatim into a URL path
// as a single segment: non-empty, not "." or "..", and free of path
// separators, query/fragment markers, percent signs and whitespace.
func isSafePathSegment(id string) bool {
	if id == "" || id == "." || id == ".." {
		return false
	}
	return !strings.ContainsAny(id, "/\\?#% \t\r\n")
}
// truncateStr shortens s to at most max bytes, appending "..." when anything
// was cut. The cut point is moved back to the nearest UTF-8 rune boundary so
// multi-byte characters (e.g. CJK text in an HTTP error body) are never split
// into invalid UTF-8. A negative max is treated as 0 instead of panicking.
func truncateStr(s string, max int) string {
	if max < 0 {
		max = 0
	}
	if len(s) <= max {
		return s
	}
	cut := max
	// Bytes of the form 10xxxxxx are UTF-8 continuation bytes; step back until
	// s[cut] begins a rune so s[:cut] ends on a complete character.
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "..."
}