This commit is contained in:
taylorxie
2026-03-09 22:43:02 +08:00
parent b53a918aaa
commit 9f25fdab12
4 changed files with 98 additions and 10 deletions

View File

@@ -73,6 +73,8 @@ platforms:
max_words: 1000
max_tokens: 2000
max_context_messages: 20
# Whether to enable streaming output (when enabled, the message in Element is updated incrementally)
streaming: false
xai:
url: https://api.x.ai
api_key:

View File

@@ -124,15 +124,17 @@ class AiBotPlugin(AbsExtraConfigPlugin):
await event.mark_read()
await self.client.set_typing(event.room_id, timeout=99999)
platform = self.get_ai_platform()
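# Streaming mode: cancel the long-lived typing indicator and edit a placeholder message incrementally instead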
if platform.is_streaming_enabled():
await self.client.set_typing(event.room_id, timeout=0)
await self._handle_streaming(event, platform)
return
chat_completion = await platform.create_chat_completion(self, event)
self.log.debug(
f"发送结果 {chat_completion.message}, {chat_completion.model}, {chat_completion.finish_reason}")
# AI/GPT call done
# turn off the typing indicator
await self.client.set_typing(event.room_id, timeout=0)
# send the response if the completion succeeded
if chat_completion.result:
# if hasattr(chat_completion.message, 'content'):
resp_content = chat_completion.message['content']
response = TextMessageEventContent(msgtype=MessageType.TEXT, body=resp_content, format=Format.HTML,
formatted_body=markdown.render(resp_content))
@@ -150,6 +152,48 @@ class AiBotPlugin(AbsExtraConfigPlugin):
return None
async def _handle_streaming(self, evt: MessageEvent, platform) -> None:
# Send an initial placeholder message
placeholder = TextMessageEventContent(
msgtype=MessageType.TEXT, body="", format=Format.HTML, formatted_body=""
)
response_event_id = await evt.respond(placeholder, in_thread=self.config['reply_in_thread'])
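# Keep the placeholder's event ID; later chunks are sent as edits (m.replace) of this event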
accumulated = ""
last_edit_len = 0
EDIT_THRESHOLD = 50  # edit the message once every 50 newly accumulated characters
try:
async for chunk in platform.create_chat_completion_stream(self, evt):
accumulated += chunk
if len(accumulated) - last_edit_len >= EDIT_THRESHOLD:
display = accumulated
new_content = TextMessageEventContent(
msgtype=MessageType.TEXT,
body=display,
format=Format.HTML,
formatted_body=markdown.render(display)
)
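# mark this message as an edit of the placeholder so the client replaces it in place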
new_content.set_edit(response_event_id)
await self.client.send_message(evt.room_id, new_content)
last_edit_len = len(accumulated)
except Exception as e:
self.log.exception(f"Streaming error: {e}")
if not accumulated:
accumulated = f"Streaming error: {e}"
# Send the final complete content
if not accumulated:
accumulated = "(无响应)"
final_content = TextMessageEventContent(
msgtype=MessageType.TEXT,
body=accumulated,
format=Format.HTML,
formatted_body=markdown.render(accumulated)
)
final_content.set_edit(response_event_id)
await self.client.send_message(evt.room_id, final_content)
def get_ai_platform(self) -> Platform:
use_platform = self.config.cur_platform
if use_platform == 'openai':

View File

@@ -1,7 +1,7 @@
import json
from collections import deque
from datetime import datetime
from typing import Optional, List, Generator
from typing import Optional, List, Generator, AsyncIterator
from aiohttp import ClientSession
from maubot import Plugin
@@ -55,6 +55,12 @@ class Platform:
async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
raise NotImplementedError()
async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> AsyncIterator[str]:
raise NotImplementedError()
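# Subclasses that support streaming override this to report whether it is enabled in their config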
def is_streaming_enabled(self) -> bool:
return False
async def list_models(self) -> List[str]:
raise NotImplementedError()

View File

@@ -129,10 +129,22 @@ class OpenAi(Platform):
class Anthropic(Platform):
max_tokens: int
streaming: bool
def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
super().__init__(config, http)
self.max_tokens = self.config['max_tokens']
self.streaming = self.config.get('streaming', False)
def is_streaming_enabled(self) -> bool:
return self.streaming
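# Shared request construction so the blocking and streaming paths stay in sync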
def _build_request(self, full_chat_context: list) -> tuple:
endpoint = f"{self.url}/v1/messages"
headers = {"x-api-key": self.api_key, "anthropic-version": "2023-06-01", "content-type": "application/json"}
req_body = {"model": self.model, "max_tokens": self.max_tokens, "system": self.system_prompt,
"messages": full_chat_context}
return endpoint, headers, req_body
async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
full_chat_context = []
@@ -140,10 +152,7 @@ class Anthropic(Platform):
chat_context = await maubot_llmplus.platforms.get_chat_context(system_context, plugin, self, evt)
full_chat_context.extend(list(chat_context))
endpoint = f"{self.url}/v1/messages"
headers = {"x-api-key": self.api_key, "anthropic-version": "2023-06-01", "content-type": "application/json"}
req_body = {"model": self.model, "max_tokens": self.max_tokens, "system": self.system_prompt,
"messages": full_chat_context}
endpoint, headers, req_body = self._build_request(full_chat_context)
async with self.http.post(endpoint, headers=headers, data=json.dumps(req_body)) as response:
# plugin.log.debug(f"响应内容:{response.status}, {await response.json()}")
@@ -162,6 +171,33 @@ class Anthropic(Platform):
finish_reason=response_json['stop_reason'],
model=response_json['model']
)
async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
full_chat_context = []
system_context = deque()
chat_context = await maubot_llmplus.platforms.get_chat_context(system_context, plugin, self, evt)
full_chat_context.extend(list(chat_context))
endpoint, headers, req_body = self._build_request(full_chat_context)
req_body["stream"] = True
async with self.http.post(endpoint, headers=headers, data=json.dumps(req_body)) as response:
if response.status != 200:
raise ValueError(f"Error: {await response.text()}")
async for line_bytes in response.content:
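# each SSE line is either an "event: ..." header or a "data: {...}" payload; only the data lines carry deltas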
line = line_bytes.decode("utf-8").strip()
if not line.startswith("data: "):
continue
data_str = line[6:]
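# defensive check: Anthropic's stream normally ends with a message_stop event rather than an OpenAI-style [DONE]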
if data_str == "[DONE]":
break
try:
data = json.loads(data_str)
if data.get("type") == "content_block_delta":
delta = data.get("delta", {})
if delta.get("type") == "text_delta":
yield delta.get("text", "")
except json.JSONDecodeError:
pass
async def list_models(self) -> List[str]: