diff --git a/telegram_bot.py b/telegram_bot.py index c1592e4..b63b52a 100644 --- a/telegram_bot.py +++ b/telegram_bot.py @@ -78,6 +78,18 @@ class ProvisionerBot: self.app: Optional[Application] = None self.notifier: Optional[BotNotifier] = None self._shutdown_event = asyncio.Event() + # JSON 导入批量进度跟踪 + self._import_progress_message = None # 进度消息对象 + self._import_progress_lock = asyncio.Lock() # 并发锁 + self._import_batch_stats = { # 批量统计 + "total_files": 0, + "processed_files": 0, + "total_added": 0, + "total_skipped": 0, + "current_file": "", + "errors": [] + } + self._import_batch_timeout_task = None # 超时任务 async def start(self): """启动 Bot""" @@ -1423,9 +1435,105 @@ class ProvisionerBot: json_text = " ".join(context.args) await self._process_import_json(update, json_text) + def _reset_import_batch_stats(self): + """重置批量导入统计""" + self._import_batch_stats = { + "total_files": 0, + "processed_files": 0, + "total_added": 0, + "total_skipped": 0, + "current_file": "", + "errors": [], + "team_json_total": 0 + } + + def _get_import_progress_text(self, is_processing: bool = True) -> str: + """生成导入进度消息文本""" + stats = self._import_batch_stats + + if is_processing: + lines = [ + "⏳ 正在处理 JSON 文件...", + "", + f"📁 文件: {stats['processed_files']}/{stats['total_files']}", + ] + if stats['current_file']: + lines.append(f"📄 当前: {stats['current_file']}") + lines.extend([ + "", + f"新增: {stats['total_added']}", + f"跳过 (重复): {stats['total_skipped']}", + ]) + else: + # 完成状态 + lines = [ + "✅ 导入完成", + "", + f"📁 处理文件: {stats['processed_files']} 个", + f"📄 已更新 team.json", + f"新增: {stats['total_added']}", + f"跳过 (重复): {stats['total_skipped']}", + f"team.json 总数: {stats['team_json_total']}", + ] + if stats['errors']: + lines.append("") + lines.append(f"⚠️ 错误 ({len(stats['errors'])} 个):") + for err in stats['errors'][:3]: # 最多显示3个错误 + lines.append(f" • {err}") + if len(stats['errors']) > 3: + lines.append(f" ... 还有 {len(stats['errors']) - 3} 个错误") + lines.extend([ + "", + "✅ 配置已自动刷新", + "使用 /run_all 或 /run <n> 开始处理" + ]) + + return "\n".join(lines) + + async def _update_import_progress(self, chat_id: int, is_final: bool = False): + """更新导入进度消息""" + text = self._get_import_progress_text(is_processing=not is_final) + + try: + if self._import_progress_message: + await self.app.bot.edit_message_text( + chat_id=chat_id, + message_id=self._import_progress_message.message_id, + text=text, + parse_mode="HTML" + ) + except Exception: + pass # 忽略编辑失败 + + async def _finalize_import_batch(self, chat_id: int): + """完成批量导入,发送最终结果""" + async with self._import_progress_lock: + if self._import_progress_message is None: + return + + # 取消超时任务 + if self._import_batch_timeout_task: + self._import_batch_timeout_task.cancel() + self._import_batch_timeout_task = None + + # 更新最终进度 + await self._update_import_progress(chat_id, is_final=True) + + # 重置状态 + self._import_progress_message = None + self._reset_import_batch_stats() + + async def _import_batch_timeout(self, chat_id: int, delay: float = 2.0): + """批量导入超时处理 - 在一定时间后自动完成批次""" + try: + await asyncio.sleep(delay) + await self._finalize_import_batch(chat_id) + except asyncio.CancelledError: + pass + @admin_only async def handle_json_file(self, update: Update, context: ContextTypes.DEFAULT_TYPE): - """处理上传的 JSON 文件""" + """处理上传的 JSON 文件 - 支持批量导入进度更新""" # 检查是否是管理员 user_id = update.effective_user.id if user_id not in TELEGRAM_ADMIN_CHAT_IDS: @@ -1436,7 +1544,32 @@ class ProvisionerBot: if not document: return - await update.message.reply_text("⏳ 正在处理 JSON 文件...") + chat_id = update.effective_chat.id + file_name = document.file_name or "unknown.json" + + async with self._import_progress_lock: + # 取消之前的超时任务(如果有) + if self._import_batch_timeout_task: + self._import_batch_timeout_task.cancel() + self._import_batch_timeout_task = None + + # 更新统计 + self._import_batch_stats["total_files"] += 1 + self._import_batch_stats["current_file"] = file_name + + # 如果是新批次,发送初始进度消息 + if self._import_progress_message is None: + self._reset_import_batch_stats() + self._import_batch_stats["total_files"] = 1 + self._import_batch_stats["current_file"] = file_name + + self._import_progress_message = await update.message.reply_text( + self._get_import_progress_text(is_processing=True), + parse_mode="HTML" + ) + else: + # 更新进度消息 + await self._update_import_progress(chat_id) try: # 下载文件 @@ -1444,10 +1577,132 @@ class ProvisionerBot: file_bytes = await file.download_as_bytearray() json_text = file_bytes.decode("utf-8") - await self._process_import_json(update, json_text) + # 处理导入并获取结果 + result = await self._process_import_json_batch(json_text) + + async with self._import_progress_lock: + self._import_batch_stats["processed_files"] += 1 + self._import_batch_stats["total_added"] += result.get("added", 0) + self._import_batch_stats["total_skipped"] += result.get("skipped", 0) + self._import_batch_stats["team_json_total"] = result.get("total", 0) + self._import_batch_stats["current_file"] = "" + + if result.get("error"): + self._import_batch_stats["errors"].append(f"{file_name}: {result['error']}") + + # 更新进度 + await self._update_import_progress(chat_id) + + # 设置超时任务(2秒后如果没有新文件则完成批次) + self._import_batch_timeout_task = asyncio.create_task( + self._import_batch_timeout(chat_id, delay=2.0) + ) except Exception as e: - await update.message.reply_text(f"❌ 读取文件失败: {e}") + async with self._import_progress_lock: + self._import_batch_stats["processed_files"] += 1 + self._import_batch_stats["errors"].append(f"{file_name}: {str(e)}") + self._import_batch_stats["current_file"] = "" + + await self._update_import_progress(chat_id) + + # 设置超时任务 + self._import_batch_timeout_task = asyncio.create_task( + self._import_batch_timeout(chat_id, delay=2.0) + ) + + async def _process_import_json_batch(self, json_text: str) -> dict: + """处理导入的 JSON 数据,保存到 team.json (批量版本,返回结果) + + Returns: + dict: {"added": int, "skipped": int, "total": int, "error": str|None} + """ + import json + from pathlib import Path + + result = {"added": 0, "skipped": 0, "total": 0, "error": None} + + try: + new_accounts = json.loads(json_text) + except json.JSONDecodeError as e: + result["error"] = f"JSON 格式错误: {e}" + return result + + if not isinstance(new_accounts, list): + new_accounts = [new_accounts] + + if not new_accounts: + result["error"] = "JSON 数据中没有账号" + return result + + # 验证格式 + valid_accounts = [] + for acc in new_accounts: + if not isinstance(acc, dict): + continue + email = acc.get("account") or acc.get("email", "") + token = acc.get("token", "") + password = acc.get("password", "") + + if email and token: + valid_accounts.append({ + "account": email, + "password": password, + "token": token + }) + + if not valid_accounts: + result["error"] = "未找到有效账号" + return result + + # 读取现有 team.json + team_json_path = Path(TEAM_JSON_FILE) + existing_accounts = [] + + if team_json_path.exists(): + try: + with open(team_json_path, "r", encoding="utf-8") as f: + existing_accounts = json.load(f) + if not isinstance(existing_accounts, list): + existing_accounts = [existing_accounts] + except Exception: + existing_accounts = [] + + # 检查重复 + existing_emails = set() + for acc in existing_accounts: + email = acc.get("account") or acc.get("user", {}).get("email", "") + if email: + existing_emails.add(email.lower()) + + added = 0 + skipped = 0 + for acc in valid_accounts: + email = acc.get("account", "").lower() + if email in existing_emails: + skipped += 1 + else: + existing_accounts.append(acc) + existing_emails.add(email) + added += 1 + + # 保存到 team.json + try: + team_json_path.parent.mkdir(parents=True, exist_ok=True) + with open(team_json_path, "w", encoding="utf-8") as f: + json.dump(existing_accounts, f, ensure_ascii=False, indent=2) + + # 重载配置 + reload_config() + + result["added"] = added + result["skipped"] = skipped + result["total"] = len(existing_accounts) + + except Exception as e: + result["error"] = f"保存失败: {e}" + + return result async def _process_import_json(self, update: Update, json_text: str): """处理导入的 JSON 数据,保存到 team.json"""