package chatgpt

import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"
)

const (
	// manualOAuthRedirectURI is sent as redirect_uri in both the authorization
	// request and the token exchange; the two values must be identical. It
	// points at localhost, which is not expected to load — the user copies the
	// resulting URL out of the browser and pastes it back.
	manualOAuthRedirectURI = "http://localhost:1455/auth/callback"

	// manualOAuthSessionTTL is how long a PKCE session created by
	// GenerateAuthURL may be redeemed by ExchangeCallbackURL.
	manualOAuthSessionTTL = 10 * time.Minute
)

// OAuthSession holds PKCE session data for an in-progress OAuth login.
type OAuthSession struct {
	CodeVerifier string    // PKCE verifier whose SHA-256 was sent as the code_challenge
	State        string    // random state value; also the session's key in OAuthManager
	CreatedAt    time.Time // creation time, used for expiry
}

// OAuthResult holds the tokens and account info after a successful OAuth exchange.
// Only AccessToken is guaranteed to be non-empty; the identity fields are filled
// best-effort from the ID token's claims.
type OAuthResult struct {
	AccessToken  string
	RefreshToken string
	Email        string
	AccountID    string
	Name         string
	PlanType     string
}

// OAuthManager manages PKCE sessions for the manual URL-paste OAuth flow.
type OAuthManager struct {
	client   *Client
	sessions map[string]*OAuthSession // state -> session; guarded by mu
	mu       sync.Mutex
}

// NewOAuthManager creates a new OAuth manager. client must be non-nil: its
// httpClient is used for the token exchange.
func NewOAuthManager(client *Client) *OAuthManager {
	return &OAuthManager{
		client:   client,
		sessions: make(map[string]*OAuthSession),
	}
}

// GenerateAuthURL creates an OpenAI OAuth authorization URL with PKCE and
// registers a new session keyed by the returned state.
// The redirect_uri points to localhost which won't load — user copies the URL from browser.
func (m *OAuthManager) GenerateAuthURL() (authURL, state string, err error) {
	// PKCE code verifier: 64 random bytes hex-encoded gives 128 characters,
	// the maximum length RFC 7636 permits, all from the unreserved charset.
	verifierBytes := make([]byte, 64)
	if _, err := rand.Read(verifierBytes); err != nil {
		return "", "", fmt.Errorf("生成 PKCE 失败: %w", err)
	}
	codeVerifier := hex.EncodeToString(verifierBytes)
	hash := sha256.Sum256([]byte(codeVerifier))
	codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:])

	// The state value ties the pasted callback URL back to this session.
	stateBytes := make([]byte, 32)
	if _, err := rand.Read(stateBytes); err != nil {
		return "", "", fmt.Errorf("生成 state 失败: %w", err)
	}
	state = hex.EncodeToString(stateBytes)

	now := time.Now()
	m.mu.Lock()
	// Sessions are otherwise only removed when redeemed, so abandoned logins
	// would accumulate forever; drop any that can no longer be redeemed.
	for key, s := range m.sessions {
		if now.Sub(s.CreatedAt) > manualOAuthSessionTTL {
			delete(m.sessions, key)
		}
	}
	m.sessions[state] = &OAuthSession{
		CodeVerifier: codeVerifier,
		State:        state,
		CreatedAt:    now,
	}
	m.mu.Unlock()

	// Build auth URL — redirect to localhost, user will copy the URL.
	params := url.Values{}
	params.Set("response_type", "code")
	params.Set("client_id", openaiClientID)
	params.Set("redirect_uri", manualOAuthRedirectURI)
	params.Set("scope", "openid profile email offline_access")
	params.Set("code_challenge", codeChallenge)
	params.Set("code_challenge_method", "S256")
	params.Set("state", state)
	params.Set("id_token_add_organizations", "true")
	params.Set("codex_cli_simplified_flow", "true")

	authURL = "https://auth.openai.com/oauth/authorize?" + params.Encode()
	return authURL, state, nil
}

// ExchangeCallbackURL parses the pasted callback URL to extract code and state,
// then exchanges the authorization code for tokens.
//
// The matching session is consumed on lookup, so each state can be redeemed at
// most once — even if the subsequent token exchange fails. If the callback
// carries no code, the provider's error_description (or error) is returned.
func (m *OAuthManager) ExchangeCallbackURL(callbackURL string) (*OAuthResult, error) {
	parsed, err := url.Parse(strings.TrimSpace(callbackURL))
	if err != nil {
		return nil, fmt.Errorf("URL 格式错误: %w", err)
	}

	query := parsed.Query()
	code := query.Get("code")
	state := query.Get("state")

	if code == "" {
		errMsg := query.Get("error_description")
		if errMsg == "" {
			errMsg = query.Get("error")
		}
		if errMsg == "" {
			errMsg = "回调 URL 中未找到 code 参数"
		}
		return nil, errors.New(errMsg)
	}
	if state == "" {
		return nil, fmt.Errorf("回调 URL 中未找到 state 参数")
	}

	// Look up and consume the session in one critical section.
	m.mu.Lock()
	session, ok := m.sessions[state]
	delete(m.sessions, state) // no-op when absent
	m.mu.Unlock()

	if !ok {
		return nil, fmt.Errorf("会话已过期或无效,请重新使用 /login")
	}

	// NOTE: the message below hard-codes "10分钟"; keep it in sync with
	// manualOAuthSessionTTL.
	if time.Since(session.CreatedAt) > manualOAuthSessionTTL {
		return nil, fmt.Errorf("登录会话已过期(超过10分钟),请重新使用 /login")
	}

	return m.exchangeCode(code, session.CodeVerifier)
}

// exchangeCode redeems an authorization code (plus its PKCE verifier) at the
// OpenAI token endpoint. It returns an error for transport failures, non-200
// responses, unparsable JSON, or a missing access token. Identity fields in the
// result are filled from the ID token when it can be decoded.
func (m *OAuthManager) exchangeCode(code, codeVerifier string) (*OAuthResult, error) {
	// Upper bound on the token response we are willing to buffer.
	const maxTokenResponseBytes = 1 << 20

	form := url.Values{}
	form.Set("grant_type", "authorization_code")
	form.Set("code", code)
	form.Set("redirect_uri", manualOAuthRedirectURI)
	form.Set("client_id", openaiClientID)
	form.Set("code_verifier", codeVerifier)

	req, err := http.NewRequest(http.MethodPost, "https://auth.openai.com/oauth/token", strings.NewReader(form.Encode()))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	resp, err := m.client.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("网络错误: %w", err)
	}
	defer resp.Body.Close()

	body, readErr := io.ReadAll(io.LimitReader(resp.Body, maxTokenResponseBytes))
	if resp.StatusCode != http.StatusOK {
		// The body is only diagnostic here, so whatever was read is good enough.
		return nil, fmt.Errorf("交换授权码失败 (HTTP %d): %s", resp.StatusCode, truncate(string(body), 300))
	}
	if readErr != nil {
		return nil, fmt.Errorf("读取 token 响应失败: %w", readErr)
	}

	var tokenResp struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
		IDToken      string `json:"id_token"`
	}
	if err := json.Unmarshal(body, &tokenResp); err != nil {
		return nil, fmt.Errorf("解析 token 响应失败: %w", err)
	}
	if tokenResp.AccessToken == "" {
		return nil, fmt.Errorf("未返回有效的 access token")
	}

	result := &OAuthResult{
		AccessToken:  tokenResp.AccessToken,
		RefreshToken: tokenResp.RefreshToken,
	}

	// Decode ID token for user info. The signature is not verified; the token
	// was received directly from the token endpoint in this response. Decoding
	// failures are deliberately ignored: the identity fields are best-effort
	// and the tokens alone are a usable result.
	if tokenResp.IDToken != "" {
		if claims, err := decodeJWTPayload(tokenResp.IDToken); err == nil {
			result.Email, _ = claims["email"].(string)
			result.Name, _ = claims["name"].(string)
			if authClaims, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
				result.AccountID, _ = authClaims["chatgpt_account_id"].(string)
				result.PlanType, _ = authClaims["chatgpt_plan_type"].(string)
			}
		}
	}

	return result, nil
}

// decodeJWTPayload returns the JSON claims from the payload (second) segment of
// a three-segment JWT. It does NOT verify the signature. Both unpadded (the JWT
// norm) and padded base64url payloads are accepted.
func decodeJWTPayload(token string) (map[string]interface{}, error) {
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		return nil, errors.New("invalid JWT format")
	}

	// Strip any stray padding so a single unpadded decoder handles both forms.
	decoded, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(parts[1], "="))
	if err != nil {
		return nil, err
	}

	var claims map[string]interface{}
	if err := json.Unmarshal(decoded, &claims); err != nil {
		return nil, err
	}
	return claims, nil
}