# ==================== Telegram Bot main program ====================
# Remotely control the OpenAI Team batch-registration tasks via Telegram.

import asyncio
import contextlib
import html
import json
import os
import re
import signal
import sys
import tempfile
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from functools import wraps
from pathlib import Path
from typing import Optional

from telegram import Update, Bot
from telegram.ext import (
    Application,
    CommandHandler,
    MessageHandler,
    filters,
    ContextTypes,
)

from config import (
    TELEGRAM_BOT_TOKEN,
    TELEGRAM_ADMIN_CHAT_IDS,
    TELEGRAM_ENABLED,
    TEAMS,
    AUTH_PROVIDER,
    TEAM_JSON_FILE,
    TELEGRAM_CHECK_INTERVAL,
    TELEGRAM_LOW_STOCK_THRESHOLD,
)
from utils import load_team_tracker
from bot_notifier import BotNotifier, set_notifier, progress_finish
from s2a_service import s2a_get_dashboard_stats, format_dashboard_stats
from logger import log

# ANSI colour/control escape sequences, stripped from log lines before sending.
_ANSI_ESCAPE_RE = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')

# Keep /logs replies below Telegram's 4096-character message limit.
_MAX_LOG_CHARS = 4000

# Default and maximum number of lines returned by /logs.
_DEFAULT_LOG_LINES = 10
_MAX_LOG_LINES = 50

# Short tags shown per account in /team; other statuses map to FAIL or "...".
_STATUS_ICONS = {"completed": "OK", "authorized": "AUTH", "registered": "REG"}


def _status_of(account: dict, default: str = "") -> str:
    """Return ``account["status"]`` as a string, or ``default`` when missing/None."""
    status = account.get("status")
    return default if status is None else str(status)


def _set_run_stop_flag(value: bool) -> None:
    """Best-effort: set ``run._shutdown_requested``, the stop flag used by /stop.

    NOTE(review): the worker in ``run`` is assumed to poll this flag between
    accounts. Nothing in this file cleared it after a /stop, so it is reset
    when a task finishes; otherwise the next task would likely stop at once.
    """
    try:
        import run
    except Exception as e:
        log.warning(f"Could not import run to set stop flag: {e}")
        return
    run._shutdown_requested = value


def _write_json_atomic(path: Path, data) -> None:
    """Write ``data`` as pretty-printed JSON to ``path`` atomically.

    The data goes to a temp file in the same directory, which then replaces
    ``path`` via ``os.replace``. A crash or full disk mid-write therefore
    leaves the previous file intact rather than a truncated one.
    NOTE(review): the new file gets mkstemp's owner-only permissions.
    """
    fd, tmp_name = tempfile.mkstemp(
        dir=str(path.parent), prefix=f".{path.name}.", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        os.replace(tmp_name, path)
    except BaseException:
        with contextlib.suppress(OSError):
            os.unlink(tmp_name)
        raise


def admin_only(func):
    """Decorator: only let users listed in TELEGRAM_ADMIN_CHAT_IDS run a handler.

    Other users, and updates with no user at all, get an "Unauthorized" reply
    (when there is a message to reply to). The wrapped handler is not called.
    """

    @wraps(func)
    async def wrapper(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
        user = update.effective_user
        if user is None or user.id not in TELEGRAM_ADMIN_CHAT_IDS:
            # effective_message also covers edited messages, where update.message is None.
            message = update.effective_message
            if message is not None:
                await message.reply_text("Unauthorized. Your ID is not in admin list.")
            return None
        return await func(self, update, context)

    return wrapper


class ProvisionerBot:
    """OpenAI Team Provisioner Telegram Bot"""

    def __init__(self):
        # A single worker thread: at most one provisioning task runs at a time.
        self.executor = ThreadPoolExecutor(max_workers=1)
        self.current_task: Optional[asyncio.Task] = None
        self.current_team: Optional[str] = None
        self.app: Optional[Application] = None
        self.notifier: Optional[BotNotifier] = None
        self._shutdown_event = asyncio.Event()

    # ------------------------------------------------------------------ lifecycle

    async def start(self):
        """Start the bot and block until request_shutdown() is called.

        Whatever was started is torn down in ``finally``, so a failure during
        startup or polling still stops the updater, the application and the notifier.
        """
        if not TELEGRAM_BOT_TOKEN:
            log.error("Telegram Bot Token not configured")
            return

        self.app = Application.builder().token(TELEGRAM_BOT_TOKEN).build()

        # Create the notifier and register it with bot_notifier.set_notifier.
        self.notifier = BotNotifier(self.app.bot, TELEGRAM_ADMIN_CHAT_IDS)
        set_notifier(self.notifier)

        self._register_handlers()
        self._schedule_stock_check()

        notifier_started = False
        try:
            # Initialize the application (and its Bot) before anything sends messages.
            await self.app.initialize()
            await self.app.start()

            await self.notifier.start()
            notifier_started = True

            log.success("Telegram Bot started")
            log.info(f"Admin Chat IDs: {TELEGRAM_ADMIN_CHAT_IDS}")
            await self.notifier.notify("Bot Started\nReady for commands. Send /help for usage.")

            await self.app.updater.start_polling(drop_pending_updates=True)

            # Block until request_shutdown() sets the event.
            await self._shutdown_event.wait()
        finally:
            await self._shutdown_app(notifier_started)

    def _register_handlers(self):
        """Register command handlers and the JSON-upload document handler."""
        handlers = [
            ("start", self.cmd_help),
            ("help", self.cmd_help),
            ("status", self.cmd_status),
            ("team", self.cmd_team),
            ("run", self.cmd_run),
            ("run_all", self.cmd_run_all),
            ("stop", self.cmd_stop),
            ("logs", self.cmd_logs),
            ("dashboard", self.cmd_dashboard),
            ("import", self.cmd_import),
            ("stock", self.cmd_stock),
        ]
        for cmd, handler in handlers:
            self.app.add_handler(CommandHandler(cmd, handler))

        # Clients do not always report application/json, so also match on extension.
        json_documents = (
            filters.Document.MimeType("application/json")
            | filters.Document.FileExtension("json")
        )
        self.app.add_handler(MessageHandler(json_documents, self.handle_json_file))

    def _schedule_stock_check(self):
        """Schedule the periodic low-stock check (S2A provider only)."""
        if TELEGRAM_CHECK_INTERVAL <= 0 or AUTH_PROVIDER != "s2a":
            return
        job_queue = self.app.job_queue
        if job_queue is None:
            # PTB returns None when installed without the job-queue extra.
            log.warning("JobQueue unavailable; scheduled stock check disabled")
            return
        job_queue.run_repeating(
            self.scheduled_stock_check,
            interval=TELEGRAM_CHECK_INTERVAL,
            first=60,  # first run one minute after startup
            name="stock_check",
        )
        log.info(f"Stock check scheduled every {TELEGRAM_CHECK_INTERVAL}s")

    async def _shutdown_app(self, notifier_started: bool):
        """Stop polling, the notifier, the application and the worker pool.

        Each step is skipped if its component never started.
        """
        app = self.app
        if app.updater is not None and app.updater.running:
            await app.updater.stop()
        # Stop the notifier while the bot's HTTP client is still open, in case it
        # flushes pending messages on stop. NOTE(review): confirm in BotNotifier.
        if notifier_started:
            try:
                await self.notifier.stop()
            except Exception as e:
                log.warning(f"Notifier stop failed: {e}")
        if app.running:
            await app.stop()
        await app.shutdown()
        # The worker thread cannot be cancelled. Ask it to stop after the
        # current account, and do not block shutdown waiting for it.
        if self._task_running():
            _set_run_stop_flag(True)
        self.executor.shutdown(wait=False)

    def request_shutdown(self):
        """Ask start() to return; must be called on the event-loop thread."""
        self._shutdown_event.set()

    # ------------------------------------------------------------------ helpers

    def _task_running(self) -> bool:
        """True while a provisioning task launched by /run or /run_all is in flight."""
        return self.current_task is not None and not self.current_task.done()

    async def _parse_team_index(self, message, args, command: str) -> Optional[int]:
        """Parse and validate a team index argument.

        Replies with a usage or error message and returns None when the
        argument is missing or invalid.
        """
        if not args:
            await message.reply_text(f"Usage: /{command} <n>\nExample: /{command} 0")
            return None
        try:
            team_idx = int(args[0])
        except ValueError:
            await message.reply_text("Invalid team index. Must be a number.")
            return None
        if not TEAMS:
            await message.reply_text("No teams configured.")
            return None
        if not 0 <= team_idx < len(TEAMS):
            await message.reply_text(f"Team index out of range. Valid: 0-{len(TEAMS)-1}")
            return None
        return team_idx

    @staticmethod
    async def _fetch_dashboard_stats():
        """Call s2a_get_dashboard_stats off the event loop.

        The function is synchronous: it is called without ``await``. Running it
        in the default executor keeps a slow S2A call from stalling the bot.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, s2a_get_dashboard_stats)

    # ------------------------------------------------------------------ commands

    @admin_only
    async def cmd_help(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
        """Show the command reference.

        The text is sent with parse_mode="HTML", so placeholders are written as
        &lt;n&gt;. Telegram rejects a bare "<n>" as an unsupported tag.
        """
        help_text = (
            "OpenAI Team Provisioner Bot\n"
            "\n"
            "Commands:\n"
            "/status - View all teams status\n"
            "/team &lt;n&gt; - View team N details\n"
            "/run &lt;n&gt; - Start processing team N\n"
            "/run_all - Start processing all teams\n"
            "/stop - Stop current task\n"
            "/logs [n] - View recent n logs (default 10)\n"
            "/dashboard - View S2A dashboard stats\n"
            "/stock - Check account stock\n"
            "/import - Upload accounts to team.json\n"
            "/help - Show this help\n"
            "\n"
            "Upload Accounts:\n"
            "Send a JSON file or use /import with JSON data:\n"
            '[{"account":"email","password":"pwd","token":"jwt"},...]\n'
            "Then use /run to process them.\n"
            "\n"
            "Examples:\n"
            "/run 0 - Process first team\n"
            "/team 1 - View second team status\n"
            "/logs 20 - View last 20 logs"
        )
        await update.effective_message.reply_text(help_text, parse_mode="HTML")

    @admin_only
    async def cmd_status(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
        """Summarise every team in the tracker: completed / failed / pending counts."""
        message = update.effective_message
        teams_data = load_team_tracker().get("teams", {})
        if not teams_data:
            await message.reply_text("No data yet. Run tasks first.")
            return

        lines = ["Teams Status\n"]
        for team_name, accounts in teams_data.items():
            statuses = [_status_of(acc) for acc in accounts]
            total = len(statuses)
            completed = statuses.count("completed")
            failed = sum(1 for status in statuses if "fail" in status.lower())
            pending = total - completed - failed
            if completed == total:
                status_icon = "OK"
            elif failed > 0:
                status_icon = "FAIL"
            else:
                status_icon = "..."
            lines.append(
                f"[{status_icon}] {html.escape(str(team_name))}: {completed}/{total} "
                f"(F:{failed} P:{pending})"
            )

        if self._task_running():
            lines.append(f"\nRunning: {html.escape(self.current_team or 'Unknown')}")

        await message.reply_text("\n".join(lines), parse_mode="HTML")

    @admin_only
    async def cmd_team(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
        """Show one configured team and the per-account status from the tracker."""
        message = update.effective_message
        team_idx = await self._parse_team_index(message, context.args, "team")
        if team_idx is None:
            return

        team = TEAMS[team_idx]
        team_name = team.get("name", f"Team{team_idx}")
        accounts = load_team_tracker().get("teams", {}).get(team_name, [])

        lines = [
            f"Team {team_idx}: {html.escape(str(team_name))}\n",
            f"Owner: {html.escape(str(team.get('owner_email', 'N/A')))}",
            f"Accounts: {len(accounts)}\n",
        ]
        if not accounts:
            lines.append("No accounts processed yet.")
        for acc in accounts:
            status = _status_of(acc, "unknown")
            icon = _STATUS_ICONS.get(status, "FAIL" if "fail" in status.lower() else "...")
            role_tag = " [O]" if acc.get("role", "member") == "owner" else ""
            lines.append(f"[{icon}] {html.escape(str(acc.get('email', '')))}{role_tag}")

        await message.reply_text("\n".join(lines), parse_mode="HTML")

    @admin_only
    async def cmd_run(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
        """Start processing a single team in the background worker."""
        message = update.effective_message
        if self._task_running():
            await message.reply_text(
                f"Task already running: {self.current_team}\nUse /stop to cancel."
            )
            return

        team_idx = await self._parse_team_index(message, context.args, "run")
        if team_idx is None:
            return

        team_name = TEAMS[team_idx].get("name", f"Team{team_idx}")
        self.current_team = team_name
        await message.reply_text(f"Starting task for Team {team_idx}: {team_name}...")
        self._launch(team_name, self._run_team_task, team_idx)

    @admin_only
    async def cmd_run_all(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
        """Start processing all teams in the background worker."""
        message = update.effective_message
        if self._task_running():
            await message.reply_text(
                f"Task already running: {self.current_team}\nUse /stop to cancel."
            )
            return

        self.current_team = "ALL"
        await message.reply_text(f"Starting task for ALL teams ({len(TEAMS)} teams)...")
        self._launch("ALL", self._run_all_teams_task)

    def _launch(self, team_name: str, func, *args):
        """Run ``func(*args)`` on the worker thread and track it as current_task."""
        loop = asyncio.get_running_loop()
        future = loop.run_in_executor(self.executor, func, *args)
        # The wrapper reports completion or failure and clears current_team.
        self.current_task = loop.create_task(self._wrap_task(future, team_name))

    async def _wrap_task(self, task, team_name: str):
        """Await a worker future and send a completion or error notification."""
        try:
            result = await task
            items = list(result or [])
            success = sum(
                1 for r in items if isinstance(r, dict) and r.get("status") == "completed"
            )
            failed = len(items) - success
            await self.notifier.notify_task_completed(team_name, success, failed)
        except Exception as e:
            log.error(f"Task failed: {team_name}: {e}")
            try:
                await self.notifier.notify_error(f"Task failed: {team_name}", str(e))
            except Exception as notify_exc:
                log.warning(f"Failed to send error notification: {notify_exc}")
        finally:
            self.current_team = None
            # Clear a /stop request so it does not also stop the next task.
            _set_run_stop_flag(False)
            # Clean up progress tracking.
            progress_finish()

    def _run_team_task(self, team_idx: int):
        """Run one team's task (executes in the worker thread)."""
        # Deferred import to avoid a circular dependency.
        from run import run_single_team
        return run_single_team(team_idx)

    def _run_all_teams_task(self):
        """Run every team's task (executes in the worker thread)."""
        from run import run_all_teams
        return run_all_teams()

    @admin_only
    async def cmd_stop(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
        """Request the running task to stop.

        The task runs in a worker thread and cannot be cancelled directly. This
        only sets the stop flag in the ``run`` module.
        """
        message = update.effective_message
        if not self._task_running():
            await message.reply_text("No task is running.")
            return

        await message.reply_text(
            f"Requesting stop for: {self.current_team}\n"
            "Note: Current account processing will complete before stopping."
        )
        _set_run_stop_flag(True)

    @admin_only
    async def cmd_logs(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
        """Send the last n lines of logs/app.log (default 10, clamped to 1..50)."""
        message = update.effective_message
        try:
            n = int(context.args[0]) if context.args else _DEFAULT_LOG_LINES
        except ValueError:
            n = _DEFAULT_LOG_LINES
        # Clamp to at least 1: with n <= 0, lines[-n:] would slice from the
        # front of the file and dump almost all of it.
        n = max(1, min(n, _MAX_LOG_LINES))

        try:
            from config import BASE_DIR
            log_file = BASE_DIR / "logs" / "app.log"

            if not log_file.exists():
                await message.reply_text("No log file found.")
                return

            # deque(maxlen=n) streams the file and keeps only the last n lines.
            with open(log_file, "r", encoding="utf-8", errors="ignore") as f:
                recent = deque(f, maxlen=n)

            # Strip ANSI colour codes.
            clean_lines = [_ANSI_ESCAPE_RE.sub('', line.strip()) for line in recent]
            log_text = "\n".join(clean_lines)

            # Telegram rejects empty/whitespace-only messages.
            if not log_text.strip():
                await message.reply_text("Log file is empty.")
                return

            if len(log_text) > _MAX_LOG_CHARS:
                log_text = log_text[-_MAX_LOG_CHARS:]

            # Sent as HTML, so '<', '>' and '&' in log lines must be escaped.
            await message.reply_text(html.escape(log_text), parse_mode="HTML")
        except Exception as e:
            await message.reply_text(f"Error reading logs: {e}")

    @admin_only
    async def cmd_dashboard(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
        """Show S2A dashboard statistics."""
        message = update.effective_message
        if AUTH_PROVIDER != "s2a":
            await message.reply_text(
                f"Dashboard only available for S2A provider.\n"
                f"Current provider: {AUTH_PROVIDER}"
            )
            return

        await message.reply_text("Fetching dashboard stats...")

        try:
            stats = await self._fetch_dashboard_stats()
            if stats:
                text = format_dashboard_stats(stats)
                await message.reply_text(text, parse_mode="HTML")
            else:
                await message.reply_text(
                    "Failed to fetch dashboard stats.\n"
                    "Check S2A configuration and API connection."
                )
        except Exception as e:
            await message.reply_text(f"Error: {e}")

    @admin_only
    async def cmd_stock(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
        """Show account stock (S2A provider only)."""
        message = update.effective_message
        if AUTH_PROVIDER != "s2a":
            await message.reply_text(
                f"Stock check only available for S2A provider.\n"
                f"Current provider: {AUTH_PROVIDER}"
            )
            return

        try:
            stats = await self._fetch_dashboard_stats()
        except Exception as e:
            await message.reply_text(f"Error: {e}")
            return

        if not stats:
            await message.reply_text("Failed to fetch stock info.")
            return

        text = self._format_stock_message(stats)
        await message.reply_text(text, parse_mode="HTML")

    async def scheduled_stock_check(self, context: ContextTypes.DEFAULT_TYPE):
        """Periodic job: alert every admin when normal accounts fall to the threshold."""
        try:
            stats = await self._fetch_dashboard_stats()
            if not stats:
                return

            # Only notify when stock is low.
            normal = stats.get("normal_accounts", 0)
            if normal > TELEGRAM_LOW_STOCK_THRESHOLD:
                return

            text = self._format_stock_message(stats, is_alert=True)
            for chat_id in TELEGRAM_ADMIN_CHAT_IDS:
                try:
                    await context.bot.send_message(
                        chat_id=chat_id,
                        text=text,
                        parse_mode="HTML",
                    )
                except Exception as e:
                    log.warning(f"Failed to send stock alert to {chat_id}: {e}")
        except Exception as e:
            log.warning(f"Stock check failed: {e}")

    def _format_stock_message(self, stats: dict, is_alert: bool = False) -> str:
        """Format the stock summary; numeric fields default to 0 when missing."""
        total = stats.get("total_accounts", 0)
        normal = stats.get("normal_accounts", 0)
        error = stats.get("error_accounts", 0)
        ratelimit = stats.get("ratelimit_accounts", 0)

        # Health = share of normal accounts (0 when there are no accounts).
        health_pct = (normal / total * 100) if total > 0 else 0

        if normal <= TELEGRAM_LOW_STOCK_THRESHOLD:
            status_line = "LOW STOCK"
        elif health_pct >= 80:
            status_line = "OK"
        elif health_pct >= 50:
            status_line = "WARN"
        else:
            status_line = "CRITICAL"

        title = "LOW STOCK ALERT" if is_alert else "Account Stock"

        lines = [
            f"{title}",
            "",
            f"Status: {status_line}",
            f"Health: {health_pct:.1f}%",
            "",
            f"Normal: {normal}",
            f"Error: {error}",
            f"RateLimit: {ratelimit}",
            f"Total: {total}",
        ]

        if is_alert:
            lines.append("")
            lines.append(f"Threshold: {TELEGRAM_LOW_STOCK_THRESHOLD}")

        return "\n".join(lines)

    # ------------------------------------------------------------------ import

    @admin_only
    async def cmd_import(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
        """Import accounts passed inline after /import into team.json."""
        message = update.effective_message
        if not context.args:
            await message.reply_text(
                "Upload Accounts to team.json\n\n"
                "Usage:\n"
                "1. Send a JSON file directly\n"
                "2. /import followed by JSON data\n\n"
                "JSON format:\n"
                "[{\"account\":\"email\",\"password\":\"pwd\",\"token\":\"jwt\"},...]\n\n"
                "After upload, use /run to start processing.",
                parse_mode="HTML"
            )
            return

        # Take the raw text after the command: context.args is whitespace-split,
        # so re-joining it would collapse spaces/newlines inside JSON strings.
        parts = (message.text or "").split(maxsplit=1)
        json_text = parts[1] if len(parts) > 1 else " ".join(context.args)
        await self._process_import_json(update, json_text)

    @admin_only
    async def handle_json_file(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
        """Download an uploaded JSON document and import its accounts."""
        message = update.effective_message
        document = message.document if message is not None else None
        if not document:
            return

        await message.reply_text("Processing JSON file...")

        try:
            file = await document.get_file()
            file_bytes = await file.download_as_bytearray()
            # utf-8-sig also accepts files saved with a BOM, which json.loads rejects.
            json_text = file_bytes.decode("utf-8-sig")
            await self._process_import_json(update, json_text)
        except Exception as e:
            await message.reply_text(f"Error reading file: {e}")

    @staticmethod
    def _normalize_import_accounts(raw_accounts: list) -> list:
        """Keep entries that are dicts with a string account/email and a token.

        Each kept entry is normalised to {"account", "password", "token"}.
        """
        valid_accounts = []
        for acc in raw_accounts:
            if not isinstance(acc, dict):
                continue
            # Either "account" or "email" may carry the address.
            email = acc.get("account") or acc.get("email", "")
            token = acc.get("token", "")
            if isinstance(email, str) and email and token:
                valid_accounts.append({
                    "account": email,
                    "password": acc.get("password", ""),
                    "token": token,
                })
        return valid_accounts

    @staticmethod
    def _load_existing_accounts(path: Path) -> list:
        """Read team.json as a list; a missing or blank file counts as empty.

        Raises OSError / ValueError (including json.JSONDecodeError) when the
        file exists but cannot be read or parsed.
        """
        if not path.exists():
            return []
        text = path.read_text(encoding="utf-8")
        if not text.strip():
            return []
        data = json.loads(text)
        return data if isinstance(data, list) else [data]

    @staticmethod
    def _existing_email(acc) -> str:
        """Lower-cased email of a team.json entry ("" if none can be found)."""
        if not isinstance(acc, dict):
            return ""
        email = acc.get("account")
        if not email:
            user = acc.get("user")
            email = user.get("email", "") if isinstance(user, dict) else ""
        return email.lower() if isinstance(email, str) else ""

    async def _process_import_json(self, update: Update, json_text: str):
        """Merge accounts from ``json_text`` into team.json, skipping duplicate emails."""
        message = update.effective_message
        try:
            new_accounts = json.loads(json_text)
        except json.JSONDecodeError as e:
            await message.reply_text(f"Invalid JSON format: {e}")
            return

        if not isinstance(new_accounts, list):
            # A single object is treated as a one-element list.
            new_accounts = [new_accounts]

        if not new_accounts:
            await message.reply_text("No accounts in JSON data")
            return

        valid_accounts = self._normalize_import_accounts(new_accounts)
        if not valid_accounts:
            await message.reply_text("No valid accounts found (need account/email and token)")
            return

        team_json_path = Path(TEAM_JSON_FILE)
        try:
            existing_accounts = self._load_existing_accounts(team_json_path)
        except (OSError, ValueError) as e:
            # Abort instead of treating the file as empty: writing back would
            # replace team.json and drop every account it currently holds.
            await message.reply_text(f"Error reading team.json, import aborted: {e}")
            return

        existing_emails = {
            email for email in map(self._existing_email, existing_accounts) if email
        }

        added = 0
        skipped = 0
        for acc in valid_accounts:
            email = acc["account"].lower()
            if email in existing_emails:
                skipped += 1
                continue
            existing_accounts.append(acc)
            existing_emails.add(email)  # also de-duplicates within this upload
            added += 1

        try:
            _write_json_atomic(team_json_path, existing_accounts)
        except Exception as e:
            await message.reply_text(f"Error saving to team.json: {e}")
            return

        await message.reply_text(
            f"Upload Complete\n\n"
            f"Added: {added}\n"
            f"Skipped (duplicate): {skipped}\n"
            f"Total in team.json: {len(existing_accounts)}\n\n"
            f"Use /run_all or /run &lt;n&gt; to start processing.",
            parse_mode="HTML"
        )


async def main():
    """Validate the Telegram config, install signal handlers and run the bot."""
    if not TELEGRAM_ENABLED:
        print("Telegram Bot is disabled. Set telegram.enabled = true in config.toml")
        sys.exit(1)

    if not TELEGRAM_BOT_TOKEN:
        print("Telegram Bot Token not configured. Set telegram.bot_token in config.toml")
        sys.exit(1)

    if not TELEGRAM_ADMIN_CHAT_IDS:
        print("No admin chat IDs configured. Set telegram.admin_chat_ids in config.toml")
        sys.exit(1)

    bot = ProvisionerBot()
    loop = asyncio.get_running_loop()

    def _on_shutdown():
        log.info("Shutting down...")
        bot.request_shutdown()

    def signal_handler(sig, frame):
        # Hand off via call_soon_threadsafe: it wakes the loop's selector, whereas
        # setting the Event straight from a signal handler may go unnoticed
        # until other I/O arrives. It also keeps logging out of signal context.
        loop.call_soon_threadsafe(_on_shutdown)

    # Handle Ctrl+C and SIGTERM.
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    await bot.start()


if __name__ == "__main__":
    asyncio.run(main())