diff --git a/base-config.yaml b/base-config.yaml
index 2895772..941800c 100644
--- a/base-config.yaml
+++ b/base-config.yaml
@@ -73,6 +73,8 @@ platforms:
     max_words: 1000
     max_tokens: 2000
     max_context_messages: 20
+    # Whether to enable streaming output (when enabled, the message in Element is updated incrementally)
+    streaming: false
   xai:
     url: https://api.x.ai
     api_key:
diff --git a/maubot_llmplus/aibot.py b/maubot_llmplus/aibot.py
index 25f7ba7..917136e 100644
--- a/maubot_llmplus/aibot.py
+++ b/maubot_llmplus/aibot.py
@@ -124,15 +124,17 @@ class AiBotPlugin(AbsExtraConfigPlugin):
         await event.mark_read()
         await self.client.set_typing(event.room_id, timeout=99999)
         platform = self.get_ai_platform()
+
+        if platform.is_streaming_enabled():
+            await self.client.set_typing(event.room_id, timeout=0)
+            await self._handle_streaming(event, platform)
+            return
+
         chat_completion = await platform.create_chat_completion(self, event)
         self.log.debug(
             f"发送结果 {chat_completion.message}, {chat_completion.model}, {chat_completion.finish_reason}")
-        # ai gpt调用
-        # 关闭typing提示
         await self.client.set_typing(event.room_id, timeout=0)
-        # 打开typing提示
         if chat_completion.result:
-            # if hasattr(chat_completion.message, 'content'):
             resp_content = chat_completion.message['content']
             response = TextMessageEventContent(msgtype=MessageType.TEXT, body=resp_content, format=Format.HTML,
                                                formatted_body=markdown.render(resp_content))
@@ -150,6 +152,48 @@ class AiBotPlugin(AbsExtraConfigPlugin):
 
         return None
 
+    async def _handle_streaming(self, evt: MessageEvent, platform) -> None:
+        # Send an initial placeholder message
+        placeholder = TextMessageEventContent(
+            msgtype=MessageType.TEXT, body="▌", format=Format.HTML, formatted_body="▌"
+        )
+        response_event_id = await evt.respond(placeholder, in_thread=self.config['reply_in_thread'])
+
+        accumulated = ""
+        last_edit_len = 0
+        EDIT_THRESHOLD = 50  # edit the message once every 50 accumulated characters
+
+        try:
+            async for chunk in platform.create_chat_completion_stream(self, evt):
+                accumulated += chunk
+                if len(accumulated) - last_edit_len >= EDIT_THRESHOLD:
+                    display = accumulated + " ▌"
+                    new_content = TextMessageEventContent(
+                        msgtype=MessageType.TEXT,
+                        body=display,
+                        format=Format.HTML,
+                        formatted_body=markdown.render(display)
+                    )
+                    new_content.set_edit(response_event_id)
+                    await self.client.send_message(evt.room_id, new_content)
+                    last_edit_len = len(accumulated)
+        except Exception as e:
+            self.log.exception(f"Streaming error: {e}")
+            if not accumulated:
+                accumulated = f"Streaming error: {e}"
+
+        # Send the final, complete content
+        if not accumulated:
+            accumulated = "(无响应)"
+        final_content = TextMessageEventContent(
+            msgtype=MessageType.TEXT,
+            body=accumulated,
+            format=Format.HTML,
+            formatted_body=markdown.render(accumulated)
+        )
+        final_content.set_edit(response_event_id)
+        await self.client.send_message(evt.room_id, final_content)
+
     def get_ai_platform(self) -> Platform:
         use_platform = self.config.cur_platform
         if use_platform == 'openai':
diff --git a/maubot_llmplus/platforms.py b/maubot_llmplus/platforms.py
index 91c8398..bff1501 100644
--- a/maubot_llmplus/platforms.py
+++ b/maubot_llmplus/platforms.py
@@ -1,7 +1,7 @@
 import json
 from collections import deque
 from datetime import datetime
-from typing import Optional, List, Generator
+from typing import Optional, List, Generator, AsyncIterator
 
 from aiohttp import ClientSession
 from maubot import Plugin
@@ -55,6 +55,12 @@ class Platform:
     async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
         raise NotImplementedError()
 
+    async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> AsyncIterator[str]:
+        raise NotImplementedError()
+
+    def is_streaming_enabled(self) -> bool:
+        return False
+
     async def list_models(self) -> List[str]:
         raise NotImplementedError()
 
diff --git a/maubot_llmplus/thrid_platform.py b/maubot_llmplus/thrid_platform.py
index 6c4399f..7f0bddd 100644
--- a/maubot_llmplus/thrid_platform.py
+++ b/maubot_llmplus/thrid_platform.py
@@ -129,10 +129,22 @@ class OpenAi(Platform):
 
 class Anthropic(Platform):
     max_tokens: int
+    streaming: bool
 
     def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
         super().__init__(config, http)
         self.max_tokens = self.config['max_tokens']
+        self.streaming = self.config.get('streaming', False)
+
+    def is_streaming_enabled(self) -> bool:
+        return self.streaming
+
+    def _build_request(self, full_chat_context: list) -> tuple:
+        endpoint = f"{self.url}/v1/messages"
+        headers = {"x-api-key": self.api_key, "anthropic-version": "2023-06-01", "content-type": "application/json"}
+        req_body = {"model": self.model, "max_tokens": self.max_tokens, "system": self.system_prompt,
+                    "messages": full_chat_context}
+        return endpoint, headers, req_body
 
     async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
         full_chat_context = []
@@ -140,10 +152,7 @@ class Anthropic(Platform):
         chat_context = await maubot_llmplus.platforms.get_chat_context(system_context, plugin, self, evt)
         full_chat_context.extend(list(chat_context))
 
-        endpoint = f"{self.url}/v1/messages"
-        headers = {"x-api-key": self.api_key, "anthropic-version": "2023-06-01", "content-type": "application/json"}
-        req_body = {"model": self.model, "max_tokens": self.max_tokens, "system": self.system_prompt,
-                    "messages": full_chat_context}
+        endpoint, headers, req_body = self._build_request(full_chat_context)
 
         async with self.http.post(endpoint, headers=headers, data=json.dumps(req_body)) as response:
             # plugin.log.debug(f"响应内容:{response.status}, {await response.json()}")
@@ -162,7 +171,34 @@ class Anthropic(Platform):
             finish_reason=response_json['stop_reason'],
             model=response_json['model']
         )
-        pass
+
+    async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
+        full_chat_context = []
+        system_context = deque()
+        chat_context = await maubot_llmplus.platforms.get_chat_context(system_context, plugin, self, evt)
+        full_chat_context.extend(list(chat_context))
+
+        endpoint, headers, req_body = self._build_request(full_chat_context)
+        req_body["stream"] = True
+
+        async with self.http.post(endpoint, headers=headers, data=json.dumps(req_body)) as response:
+            if response.status != 200:
+                raise ValueError(f"Error: {await response.text()}")
+            async for line_bytes in response.content:
+                line = line_bytes.decode("utf-8").strip()
+                if not line.startswith("data: "):
+                    continue
+                data_str = line[6:]
+                if data_str == "[DONE]":
+                    break
+                try:
+                    data = json.loads(data_str)
+                    if data.get("type") == "content_block_delta":
+                        delta = data.get("delta", {})
+                        if delta.get("type") == "text_delta":
+                            yield delta.get("text", "")
+                except json.JSONDecodeError:
+                    pass
 
     async def list_models(self) -> List[str]:
         # 调用openai接口获取模型列表
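
Side note, not part of the patch: a minimal self-contained sketch of the SSE filtering that the new `create_chat_completion_stream` applies, run against hypothetical sample lines shaped like Anthropic's Messages API `content_block_delta` streaming events. The payloads below are illustrative, not captured from a real response.

```python
import json

# Hypothetical sample SSE lines, shaped like Anthropic Messages API streaming events.
sample_lines = [
    'event: content_block_delta',
    'data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hel"}}',
    'data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "lo"}}',
    'data: {"type": "message_stop"}',
]

def extract_text(lines):
    # Mirrors the patch: keep only "data: " lines, stop at [DONE], yield text_delta chunks.
    for line in lines:
        line = line.strip()
        if not line.startswith("data: "):
            continue
        data_str = line[6:]
        if data_str == "[DONE]":
            break
        try:
            data = json.loads(data_str)
        except json.JSONDecodeError:
            continue
        if data.get("type") == "content_block_delta":
            delta = data.get("delta", {})
            if delta.get("type") == "text_delta":
                yield delta.get("text", "")

assert "".join(extract_text(sample_lines)) == "Hello"
```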