diff --git a/maubot_llmplus/aibot.py b/maubot_llmplus/aibot.py index 34dab4d..08bb2b1 100644 --- a/maubot_llmplus/aibot.py +++ b/maubot_llmplus/aibot.py @@ -116,7 +116,7 @@ class AiBotPlugin(Plugin): await event.mark_read() await self.client.set_typing(event.room_id, timeout=99999) platform = self.get_ai_platform() - chat_completion = await platform.create_chat_completion(event) + chat_completion = await platform.create_chat_completion(self, event) # ai gpt调用 # 关闭typing提示 await self.client.set_typing(event.room_id, timeout=0) diff --git a/maubot_llmplus/local_paltform.py b/maubot_llmplus/local_paltform.py index 8302cd4..81d12ad 100644 --- a/maubot_llmplus/local_paltform.py +++ b/maubot_llmplus/local_paltform.py @@ -1,6 +1,7 @@ import json from aiohttp import ClientSession +from maubot import Plugin from mautrix.types import MessageEvent from mautrix.util.config import BaseProxyConfig @@ -16,9 +17,9 @@ class Ollama(Platform): super().__init__(config, http) self.chat_api = '/api/chat' - async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion: + async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: full_context = [] - context = maubot_llmplus.platforms.get_context(evt) + context = maubot_llmplus.platforms.get_context(plugin, evt) full_context.extend(list(context)) endpoint = f"{self.url}/api/chat" @@ -50,5 +51,5 @@ class LmStudio(Platform): super().__init__(config, http) pass - async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion: + async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: pass diff --git a/maubot_llmplus/platforms.py b/maubot_llmplus/platforms.py index 719f1f2..85bb2f5 100644 --- a/maubot_llmplus/platforms.py +++ b/maubot_llmplus/platforms.py @@ -49,7 +49,7 @@ class Platform: 调用AI对话接口, 响应结果 """ - async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion: + async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: raise NotImplementedError() def get_type(self) -> str: diff --git a/maubot_llmplus/thrid_platform.py b/maubot_llmplus/thrid_platform.py index 5b7c39e..8dd03cb 100644 --- a/maubot_llmplus/thrid_platform.py +++ b/maubot_llmplus/thrid_platform.py @@ -1,4 +1,5 @@ from aiohttp import ClientSession +from maubot import Plugin from mautrix.types import MessageEvent from mautrix.util.config import BaseProxyConfig @@ -10,7 +11,7 @@ class OpenAi(Platform): def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None: super().__init__(config, http) - async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion: + async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: # 获取系统提示词 # 获取额外的其他角色的提示词: role: user role: system @@ -25,7 +26,7 @@ class Anthropic(Platform): def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None: super().__init__(config, http) - async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion: + async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: # 获取系统提示词 # 获取额外的其他角色的提示词: role: user role: system