add
This commit is contained in:
@@ -73,6 +73,8 @@ platforms:
|
|||||||
max_words: 1000
|
max_words: 1000
|
||||||
max_tokens: 2000
|
max_tokens: 2000
|
||||||
max_context_messages: 20
|
max_context_messages: 20
|
||||||
|
# 是否开启流式输出(开启后 Element 中消息会逐步更新)
|
||||||
|
streaming: false
|
||||||
xai:
|
xai:
|
||||||
url: https://api.x.ai
|
url: https://api.x.ai
|
||||||
api_key:
|
api_key:
|
||||||
|
|||||||
@@ -124,15 +124,17 @@ class AiBotPlugin(AbsExtraConfigPlugin):
|
|||||||
await event.mark_read()
|
await event.mark_read()
|
||||||
await self.client.set_typing(event.room_id, timeout=99999)
|
await self.client.set_typing(event.room_id, timeout=99999)
|
||||||
platform = self.get_ai_platform()
|
platform = self.get_ai_platform()
|
||||||
|
|
||||||
|
if platform.is_streaming_enabled():
|
||||||
|
await self.client.set_typing(event.room_id, timeout=0)
|
||||||
|
await self._handle_streaming(event, platform)
|
||||||
|
return
|
||||||
|
|
||||||
chat_completion = await platform.create_chat_completion(self, event)
|
chat_completion = await platform.create_chat_completion(self, event)
|
||||||
self.log.debug(
|
self.log.debug(
|
||||||
f"发送结果 {chat_completion.message}, {chat_completion.model}, {chat_completion.finish_reason}")
|
f"发送结果 {chat_completion.message}, {chat_completion.model}, {chat_completion.finish_reason}")
|
||||||
# ai gpt调用
|
|
||||||
# 关闭typing提示
|
|
||||||
await self.client.set_typing(event.room_id, timeout=0)
|
await self.client.set_typing(event.room_id, timeout=0)
|
||||||
# 打开typing提示
|
|
||||||
if chat_completion.result:
|
if chat_completion.result:
|
||||||
# if hasattr(chat_completion.message, 'content'):
|
|
||||||
resp_content = chat_completion.message['content']
|
resp_content = chat_completion.message['content']
|
||||||
response = TextMessageEventContent(msgtype=MessageType.TEXT, body=resp_content, format=Format.HTML,
|
response = TextMessageEventContent(msgtype=MessageType.TEXT, body=resp_content, format=Format.HTML,
|
||||||
formatted_body=markdown.render(resp_content))
|
formatted_body=markdown.render(resp_content))
|
||||||
@@ -150,6 +152,48 @@ class AiBotPlugin(AbsExtraConfigPlugin):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def _handle_streaming(self, evt: MessageEvent, platform) -> None:
|
||||||
|
# 发送初始占位消息
|
||||||
|
placeholder = TextMessageEventContent(
|
||||||
|
msgtype=MessageType.TEXT, body="▌", format=Format.HTML, formatted_body="▌"
|
||||||
|
)
|
||||||
|
response_event_id = await evt.respond(placeholder, in_thread=self.config['reply_in_thread'])
|
||||||
|
|
||||||
|
accumulated = ""
|
||||||
|
last_edit_len = 0
|
||||||
|
EDIT_THRESHOLD = 50 # 每积累50个字符更新一次消息
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for chunk in platform.create_chat_completion_stream(self, evt):
|
||||||
|
accumulated += chunk
|
||||||
|
if len(accumulated) - last_edit_len >= EDIT_THRESHOLD:
|
||||||
|
display = accumulated + " ▌"
|
||||||
|
new_content = TextMessageEventContent(
|
||||||
|
msgtype=MessageType.TEXT,
|
||||||
|
body=display,
|
||||||
|
format=Format.HTML,
|
||||||
|
formatted_body=markdown.render(display)
|
||||||
|
)
|
||||||
|
new_content.set_edit(response_event_id)
|
||||||
|
await self.client.send_message(evt.room_id, new_content)
|
||||||
|
last_edit_len = len(accumulated)
|
||||||
|
except Exception as e:
|
||||||
|
self.log.exception(f"Streaming error: {e}")
|
||||||
|
if not accumulated:
|
||||||
|
accumulated = f"Streaming error: {e}"
|
||||||
|
|
||||||
|
# 输出最终完整内容
|
||||||
|
if not accumulated:
|
||||||
|
accumulated = "(无响应)"
|
||||||
|
final_content = TextMessageEventContent(
|
||||||
|
msgtype=MessageType.TEXT,
|
||||||
|
body=accumulated,
|
||||||
|
format=Format.HTML,
|
||||||
|
formatted_body=markdown.render(accumulated)
|
||||||
|
)
|
||||||
|
final_content.set_edit(response_event_id)
|
||||||
|
await self.client.send_message(evt.room_id, final_content)
|
||||||
|
|
||||||
def get_ai_platform(self) -> Platform:
|
def get_ai_platform(self) -> Platform:
|
||||||
use_platform = self.config.cur_platform
|
use_platform = self.config.cur_platform
|
||||||
if use_platform == 'openai':
|
if use_platform == 'openai':
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, List, Generator
|
from typing import Optional, List, Generator, AsyncIterator
|
||||||
|
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
from maubot import Plugin
|
from maubot import Plugin
|
||||||
@@ -55,6 +55,12 @@ class Platform:
|
|||||||
async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
|
async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> AsyncIterator[str]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def is_streaming_enabled(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
async def list_models(self) -> List[str]:
|
async def list_models(self) -> List[str]:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|||||||
@@ -129,10 +129,22 @@ class OpenAi(Platform):
|
|||||||
|
|
||||||
class Anthropic(Platform):
|
class Anthropic(Platform):
|
||||||
max_tokens: int
|
max_tokens: int
|
||||||
|
streaming: bool
|
||||||
|
|
||||||
def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
|
def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
|
||||||
super().__init__(config, http)
|
super().__init__(config, http)
|
||||||
self.max_tokens = self.config['max_tokens']
|
self.max_tokens = self.config['max_tokens']
|
||||||
|
self.streaming = self.config.get('streaming', False)
|
||||||
|
|
||||||
|
def is_streaming_enabled(self) -> bool:
|
||||||
|
return self.streaming
|
||||||
|
|
||||||
|
def _build_request(self, full_chat_context: list) -> tuple:
|
||||||
|
endpoint = f"{self.url}/v1/messages"
|
||||||
|
headers = {"x-api-key": self.api_key, "anthropic-version": "2023-06-01", "content-type": "application/json"}
|
||||||
|
req_body = {"model": self.model, "max_tokens": self.max_tokens, "system": self.system_prompt,
|
||||||
|
"messages": full_chat_context}
|
||||||
|
return endpoint, headers, req_body
|
||||||
|
|
||||||
async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
|
async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
|
||||||
full_chat_context = []
|
full_chat_context = []
|
||||||
@@ -140,10 +152,7 @@ class Anthropic(Platform):
|
|||||||
chat_context = await maubot_llmplus.platforms.get_chat_context(system_context, plugin, self, evt)
|
chat_context = await maubot_llmplus.platforms.get_chat_context(system_context, plugin, self, evt)
|
||||||
full_chat_context.extend(list(chat_context))
|
full_chat_context.extend(list(chat_context))
|
||||||
|
|
||||||
endpoint = f"{self.url}/v1/messages"
|
endpoint, headers, req_body = self._build_request(full_chat_context)
|
||||||
headers = {"x-api-key": self.api_key, "anthropic-version": "2023-06-01", "content-type": "application/json"}
|
|
||||||
req_body = {"model": self.model, "max_tokens": self.max_tokens, "system": self.system_prompt,
|
|
||||||
"messages": full_chat_context}
|
|
||||||
|
|
||||||
async with self.http.post(endpoint, headers=headers, data=json.dumps(req_body)) as response:
|
async with self.http.post(endpoint, headers=headers, data=json.dumps(req_body)) as response:
|
||||||
# plugin.log.debug(f"响应内容:{response.status}, {await response.json()}")
|
# plugin.log.debug(f"响应内容:{response.status}, {await response.json()}")
|
||||||
@@ -162,7 +171,34 @@ class Anthropic(Platform):
|
|||||||
finish_reason=response_json['stop_reason'],
|
finish_reason=response_json['stop_reason'],
|
||||||
model=response_json['model']
|
model=response_json['model']
|
||||||
)
|
)
|
||||||
pass
|
|
||||||
|
async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
|
||||||
|
full_chat_context = []
|
||||||
|
system_context = deque()
|
||||||
|
chat_context = await maubot_llmplus.platforms.get_chat_context(system_context, plugin, self, evt)
|
||||||
|
full_chat_context.extend(list(chat_context))
|
||||||
|
|
||||||
|
endpoint, headers, req_body = self._build_request(full_chat_context)
|
||||||
|
req_body["stream"] = True
|
||||||
|
|
||||||
|
async with self.http.post(endpoint, headers=headers, data=json.dumps(req_body)) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise ValueError(f"Error: {await response.text()}")
|
||||||
|
async for line_bytes in response.content:
|
||||||
|
line = line_bytes.decode("utf-8").strip()
|
||||||
|
if not line.startswith("data: "):
|
||||||
|
continue
|
||||||
|
data_str = line[6:]
|
||||||
|
if data_str == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
data = json.loads(data_str)
|
||||||
|
if data.get("type") == "content_block_delta":
|
||||||
|
delta = data.get("delta", {})
|
||||||
|
if delta.get("type") == "text_delta":
|
||||||
|
yield delta.get("text", "")
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
async def list_models(self) -> List[str]:
|
async def list_models(self) -> List[str]:
|
||||||
# 调用openai接口获取模型列表
|
# 调用openai接口获取模型列表
|
||||||
|
|||||||
Reference in New Issue
Block a user