diff --git a/maubot_llmplus/aibot.py b/maubot_llmplus/aibot.py index cad42f1..23a7758 100644 --- a/maubot_llmplus/aibot.py +++ b/maubot_llmplus/aibot.py @@ -9,6 +9,7 @@ from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper from maubot_llmplus.local_paltform import Ollama, LmStudio from maubot_llmplus.platforms import Platform +from maubot_llmplus.plugin import AbsExtraConfigPlugin from maubot_llmplus.thrid_platform import OpenAi, Anthropic """ @@ -27,24 +28,13 @@ class Config(BaseProxyConfig): helper.copy("platforms") helper.copy("additional_prompt") -class AbsAiBotPlugin(Plugin): - default_username: str - user_id: str - def get_bot_name(self) -> str: - return self.config['name'] or \ - self.default_username or \ - self.user_id - - -class AiBotPlugin(AbsAiBotPlugin): +class AiBotPlugin(AbsExtraConfigPlugin): async def start(self) -> None: await super().start() # 加载并更新配置 self.config.load_and_update() - self.default_username = await self.client.get_displayname(self.client.mxid) - self.user_id = self.client.parse_user_id(self.client.mxid)[0] """ 判断sender是否是allowed_users中的成员 @@ -115,6 +105,7 @@ class AiBotPlugin(AbsAiBotPlugin): if parent_event.sender == self.client.mxid: return True + @event.on(EventType.ROOM_MESSAGE) async def on_message(self, event: MessageEvent) -> None: if not await self.should_respond(event): diff --git a/maubot_llmplus/platforms.py b/maubot_llmplus/platforms.py index 55f0ba3..89291c4 100644 --- a/maubot_llmplus/platforms.py +++ b/maubot_llmplus/platforms.py @@ -8,7 +8,7 @@ from maubot import Plugin from mautrix.types import MessageEvent, EncryptedEvent from mautrix.util.config import BaseProxyConfig -from maubot_llmplus.aibot import AbsAiBotPlugin +from maubot_llmplus.plugin import AbsExtraConfigPlugin """ AI响应对象 @@ -62,7 +62,7 @@ class Platform: -async def get_context(plugin: AbsAiBotPlugin, platform: Platform, evt: MessageEvent) -> deque: +async def get_context(plugin: AbsExtraConfigPlugin, platform: Platform, evt: MessageEvent) -> deque: # 创建系统提示词上下文 system_context = deque() # 生成当前时间 diff --git a/maubot_llmplus/plugin.py b/maubot_llmplus/plugin.py new file mode 100644 index 0000000..3b92ad7 --- /dev/null +++ b/maubot_llmplus/plugin.py @@ -0,0 +1,16 @@ +from maubot import Plugin + + +class AbsExtraConfigPlugin(Plugin): + default_username: str + user_id: str + + async def start(self) -> None: + await super().start() + self.default_username = await self.client.get_displayname(self.client.mxid) + self.user_id = self.client.parse_user_id(self.client.mxid)[0] + + def get_bot_name(self) -> str: + return self.config['name'] or \ + self.default_username or \ + self.user_id \ No newline at end of file