Files
maubot-llmplus/maubot_llmplus/aibot.py
2026-03-10 11:01:44 +08:00

396 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import re
from typing import Type
from maubot.handlers import command, event
from maubot import Plugin, MessageEvent
from mautrix.types import Format, TextMessageEventContent, EventType, MessageType, RelationType
from mautrix.util import markdown
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
from maubot_llmplus.local_paltform import Ollama, LmStudio
from maubot_llmplus.platforms import Platform
from maubot_llmplus.plugin import AbsExtraConfigPlugin, Config
from maubot_llmplus.thrid_platform import OpenAi, Anthropic, XAi, Deepseek, Gemini, Qwen
class AiBotPlugin(AbsExtraConfigPlugin):
    """Maubot plugin that relays qualifying room messages to a configurable LLM platform."""

    async def start(self) -> None:
        """Run base plugin startup, then (re)load this plugin's configuration."""
        await super().start()
        # Load and update the configuration.
        self.config.load_and_update()
"""
判断sender是否是allowed_users中的成员
如果是, 则可以发送消息给AI
"""
def is_allow(self, sender: str) -> bool:
# 如果列表中没有元素, 直接返回True
if len(self.config['allowed_users']) <= 0:
return True
for u in self.config['allowed_users']:
self.log.debug(f"bot: {sender} -> {u}")
# 如果sender是allowed_user中的一员, 那么就允许发送消息给AI
if re.match(u, sender):
return True
self.log.debug(f"{sender} doesn't match allowed_users")
pass
def is_allow_command(self, sender: str, command_key: str) -> bool:
    """Return True if *sender* matches a regex pattern configured under
    *command_key*.

    Unlike :meth:`is_allow`, an empty list grants permission to nobody.
    """
    allow_users = self.config[command_key]
    # No users configured: nobody may run this class of commands.
    if not allow_users:
        return False
    for pattern in allow_users:
        if re.match(pattern, sender):
            return True
    self.log.debug(f"{sender} doesn't match {command_key}")
    # Bug fix: the original fell through with ``pass`` and implicitly
    # returned None; return an explicit False.
    return False
def is_allow_update_read_command(self, sender: str) -> bool:
    """Whether *sender* holds read-write (update) command permission."""
    key = "allow_update_read_command_users"
    return self.is_allow_command(sender, key)
def is_allow_readonly_command(self, sender: str) -> bool:
    """Whether *sender* may run read-only commands.

    Read-write permission implies read permission; otherwise fall back
    to the dedicated read-only user list.
    """
    return (self.is_allow_update_read_command(sender)
            or self.is_allow_command(sender, "allow_readonly_command_users"))
"""
判断是否应该让AI进行回应
回应条件:
前提条件:
消息发送者不是机器人本身 or 不是编辑消息 or 不是消息类型
1. @AI机器人时
2. 消息中呼唤 name变量的值的时候
3. 回复机器人消息时
4. 当聊天室中只有两个人, 并且其中一个是机器人时
5. 在thread中
"""
async def should_respond(self, event: MessageEvent) -> bool:
# 发送者是机器人本身, 返回False
if event.sender == self.client.mxid:
return False
# 如果发送的消息中,第一个字符是感叹号,不进行回复
if event.content.body[0] == '!':
return False
# 判断这个用户是否在允许列表中, 不存在返回False
# 如果列表为空, 继续往下执行
if not self.is_allow(event.sender):
return False
# 不是编辑消息 or 不是消息类型, 返回false
if (event.content['msgtype'] != MessageType.TEXT or
event.content.relates_to['rel_type'] == RelationType.REPLACE):
return False
# 检查是否发送消息中有带上机器人的别名
if re.search("(^|\\s)(@)?" + self.get_bot_name() + "([ :,.!?]|$)", event.content.body, re.IGNORECASE):
return True
# 当聊天室只有两个人并且其中一个是机器人时
if len(await self.client.get_joined_members(event.room_id)) == 2:
return True
# 在thread中时
if self.config['reply_in_thread'] and event.content.relates_to.rel_type == RelationType.THREAD:
parent_event = await self.client.get_event(room_id=event.room_id,
event_id=event.content.get_thread_parent())
return await self.should_respond(parent_event)
# 如果是回复消息
if event.content.relates_to.in_reply_to:
parent_event = await self.client.get_event(room_id=event.room_id, event_id=event.content.get_reply_to())
if parent_event.sender == self.client.mxid:
return True
@event.on(EventType.ROOM_MESSAGE)
async def on_message(self, event: MessageEvent) -> None:
    """Room-message handler: route qualifying messages to the configured
    AI platform and post the (streamed or one-shot) answer back."""
    if not await self.should_respond(event):
        return
    try:
        await event.mark_read()
        # Keep the typing indicator on while the model works.
        await self.client.set_typing(event.room_id, timeout=99999)
        platform = self.get_ai_platform()
        if platform.is_streaming_enabled():
            # Streaming path manages its own typing indicator and edits.
            await self._handle_streaming(event, platform)
            return
        chat_completion = await platform.create_chat_completion(self, event)
        self.log.debug(
            f"发送结果 {chat_completion.message}, {chat_completion.model}, {chat_completion.finish_reason}")
        await self.client.set_typing(event.room_id, timeout=0)
        # Single rendering path for both success and failure replies
        # (the original duplicated the content construction).
        if chat_completion.result:
            resp_content = chat_completion.message['content']
        else:
            resp_content = "调用失败,请检查: " + chat_completion.finish_reason
        response = TextMessageEventContent(msgtype=MessageType.TEXT, body=resp_content, format=Format.HTML,
                                           formatted_body=markdown.render(resp_content))
        await event.respond(response, in_thread=self.config['reply_in_thread'])
    except Exception as e:
        self.log.exception(f"Something went wrong: {e}")
        # Robustness: also stop the typing indicator on failure.
        await self.client.set_typing(event.room_id, timeout=0)
        await event.respond(f"Something went wrong: {e}")
    return None
async def _handle_streaming(self, evt: MessageEvent, platform) -> None:
    """Stream a chat completion into the room by repeatedly editing one
    placeholder message until the full response has arrived."""
    # Send an initial placeholder message; typing stays on so the user
    # knows we are working.
    placeholder = TextMessageEventContent(
        msgtype=MessageType.TEXT, body="", format=Format.HTML, formatted_body=""
    )
    response_event_id = await evt.respond(placeholder, in_thread=self.config['reply_in_thread'])
    self.log.debug("Streaming: placeholder sent")
    accumulated = ""          # full response text received so far
    last_edit_len = 0         # length of `accumulated` at the last edit
    first_chunk = True        # typing indicator still on until first chunk
    EDIT_THRESHOLD = 100  # edit the message once per ~100 accumulated characters

    async def send_edit(content: TextMessageEventContent) -> None:
        """Send one edit sequentially; shield send_message from being
        cancelled to protect mautrix's internal lock."""
        send_task = asyncio.ensure_future(self.client.send_message(evt.room_id, content))
        try:
            # shield keeps wait_for's timeout from cancelling send_task
            # itself, which would leave a stale mautrix lock behind
            await asyncio.wait_for(asyncio.shield(send_task), timeout=8.0)
        except asyncio.TimeoutError:
            self.log.debug("Streaming: edit wait_for timed out, awaiting task completion")
            await send_task  # send_task is still running; wait for it to finish naturally
        except Exception as e:
            self.log.warning(f"Streaming: edit error: {e}")
            if not send_task.done():
                await send_task
    try:
        async for chunk in platform.create_chat_completion_stream(self, evt):
            if first_chunk:
                # Only turn typing off once the first chunk arrives; until
                # then the indicator masks a high time-to-first-token.
                await self.client.set_typing(evt.room_id, timeout=0)
                first_chunk = False
            accumulated += chunk
            if len(accumulated) - last_edit_len >= EDIT_THRESHOLD:
                # NOTE(review): the appended literal may have contained a
                # hidden cursor character stripped by the viewer — verify.
                display = accumulated + ""
                new_content = TextMessageEventContent(
                    msgtype=MessageType.TEXT,
                    body=display,
                    format=Format.HTML,
                    formatted_body=markdown.render(display)
                )
                new_content.set_edit(response_event_id)
                self.log.debug(f"Streaming: mid-edit, accumulated={len(accumulated)}")
                await send_edit(new_content)
                last_edit_len = len(accumulated)
    except Exception as e:
        self.log.exception(f"Streaming error: {e}")
        if not accumulated:
            accumulated = f"Streaming error: {e}"
    finally:
        # Make sure typing is switched off no matter what (including the
        # case where no chunk ever arrived).
        if first_chunk:
            await self.client.set_typing(evt.room_id, timeout=0)
    self.log.debug(f"Streaming: loop done, total={len(accumulated)}")
    # Emit the final, complete content as one last edit.
    if not accumulated:
        accumulated = "(无响应)"
    final_content = TextMessageEventContent(
        msgtype=MessageType.TEXT,
        body=accumulated,
        format=Format.HTML,
        formatted_body=markdown.render(accumulated)
    )
    final_content.set_edit(response_event_id)
    self.log.debug("Streaming: sending final edit")
    await send_edit(final_content)
    self.log.debug("Streaming: final edit done")
def get_ai_platform(self) -> "Platform":
    """Instantiate the Platform implementation selected by
    ``config.cur_platform``.

    Raises:
        ValueError: if the configured platform name is unknown.
    """
    use_platform = self.config.cur_platform
    if use_platform == 'openai':
        return OpenAi(self.config, self.http)
    if use_platform == 'anthropic':
        return Anthropic(self.config, self.http)
    if use_platform == 'xai':
        return XAi(self.config, self.http)
    if use_platform == 'deepseek':
        return Deepseek(self.config, self.http)
    if use_platform == 'gemini':
        return Gemini(self.config, self.http)
    if use_platform == 'qwen':
        return Qwen(self.config, self.http)
    if use_platform == 'local_ai#ollama':
        return Ollama(self.config, self.http)
    if use_platform == 'local_ai#lmstudio':
        return LmStudio(self.config, self.http)
    # Bug fix: the original interpolated the builtin ``type`` instead of
    # the unknown platform name, producing a useless error message.
    raise ValueError(f"not found platform type: {use_platform}")
"""
父命令
"""
@command.new(name="ai", require_subcommand=True)
async def ai_command(self, event: MessageEvent) -> None:
pass
"""
"""
@ai_command.subcommand(help="View the configuration information currently in official use")
async def info(self, event: MessageEvent) -> None:
# 判断是否有更新命令权限,如果没有就返回没有权限的提示
is_allow = self.is_allow_readonly_command(event.sender)
if not is_allow:
await event.reply(f"{event.sender} have not read permission")
return
show_infos = []
# 当前机器人名称
show_infos.append(f"bot name: {self.get_bot_name()}\n\n")
# 查询当前使用的ai平台
show_infos.append(f"platform: {self.get_cur_platform()}\n\n")
show_infos.append("platform detail: \n\n")
# 查询当前ai平台的配置信息
p_m_dict = self.get_config()
for k, v in p_m_dict.items():
show_infos.append(f"- {k}: {v}\n")
# 当前使用的model
show_infos.append(f"\nmodel: {self.config.cur_model}\n")
# TODO 列出model信息
await event.reply("".join(show_infos), markdown=True)
pass
"""
获取配置信息
"""
def get_config(self) -> dict:
platform_config_dict = dict(self.config['platforms'][self.get_cur_platform()])
# 移除敏感配置
platform_config_dict.pop('api_key')
platform_config_dict.pop('url')
return platform_config_dict
"""
获取实际平台名称
"""
def get_cur_platform(self) -> str:
platform_model = self.config.cur_platform
return platform_model.split('#')[0]
@ai_command.subcommand(help="List platforms or query current platform in use")
@command.argument("argus")
async def platform(self, event: MessageEvent, argus: str):
    """``!ai platform list|current`` — read-only platform queries."""
    # Requires at least read permission.
    if not self.is_allow_readonly_command(event.sender):
        await event.reply(f"{event.sender} have not read permission")
        return
    if argus == 'list':
        p_dict = dict(self.config['platforms'])
        entries = [f"- {name}" for name in set(p_dict.keys())]
        await event.reply("\n".join(entries))
    elif argus == 'current':
        await event.reply(f"current use platform is {self.config.cur_platform}")
@ai_command.subcommand(help="List models or query current model in use")
@command.argument("argus")
async def model(self, event: MessageEvent, argus: str):
    """``!ai model list|current`` — read-only model queries."""
    # Requires at least read permission.
    if not self.is_allow_readonly_command(event.sender):
        await event.reply(f"{event.sender} have not read permission")
        return
    if argus == 'list':
        # Ask the active platform for its available models.
        models = await self.get_ai_platform().list_models()
        await event.reply("\n".join(models), markdown=True)
    elif argus == 'current':
        await event.reply(f"current use model is {self.config.cur_model}")
@ai_command.subcommand(help="switch model in platform")
@command.argument("argus")
async def use(self, event: MessageEvent, argus: str):
    """``!ai use <model>`` — switch the active model (needs read-write
    command permission)."""
    if not self.is_allow_update_read_command(event.sender):
        await event.reply(f"{event.sender} have not update permission")
        return
    # list_models() yields "- <name>" entries, hence the prefixed lookup.
    models = await self.get_ai_platform().list_models()
    if f"- {argus}" not in models:
        await event.reply("not found valid model")
        return
    self.log.debug(f"switch model: {argus}")
    self.config.cur_model = argus
    # NOTE(review): the reaction literal may have contained a hidden
    # emoji stripped by the viewer — verify against the original file.
    await event.react("")
@ai_command.subcommand(help="switch platform")
@command.argument("argus")
async def switch(self, event: MessageEvent, argus: str):
    """``!ai switch <platform>`` — switch the active AI platform and reset
    the model to that platform's configured default (needs read-write
    command permission)."""
    if not self.is_allow_update_read_command(event.sender):
        await event.reply(f"{event.sender} have not update permission")
        return
    # A bare "local_ai" needs a #variant suffix; explain and stop here.
    # Bug fix: the original only had ``pass`` and fell through to the
    # final else, additionally replying "nof found ai platform".
    if argus == 'local_ai':
        await event.reply("local ai platform has ollama and lmstudio. "
                          "you can type `!ai use local_ai#{type}`. "
                          "Example: local_ai#ollama")
        return
    if argus == 'local_ai#ollama' or argus == 'local_ai#lmstudio':
        if argus == self.config.cur_platform:
            await event.reply(f"current ai platform has be {argus}")
        else:
            self.config.cur_platform = argus
            # local_ai#<variant> shares the plain "local_ai" config section.
            self.config.cur_model = self.config['platforms'][argus.split("#")[0]]['model']
            await event.react("")
    elif argus in ('openai', 'anthropic', 'xai', 'deepseek', 'gemini', 'qwen'):
        if argus == self.config.cur_platform:
            await event.reply(f"current ai platform has be {argus}")
        else:
            self.config.cur_platform = argus
            # Fall back to the platform's configured default model.
            self.config.cur_model = self.config['platforms'][argus]['model']
            await event.react("")
    else:
        # Typo fix in the user-facing message: was "nof found".
        await event.reply(f"not found ai platform: {argus}")
    self.log.debug(f"switch platform: {self.config.cur_platform}")
    self.log.debug(f"use default config model: {self.config.cur_model}")
@classmethod
def get_config_class(cls) -> Type[BaseProxyConfig]:
    """Tell maubot which config class backs this plugin."""
    return Config