From 9199aacc69f2af9052a363be1dcc734a84e71c86 Mon Sep 17 00:00:00 2001 From: taylor Date: Sun, 13 Oct 2024 16:46:13 +0800 Subject: [PATCH] =?UTF-8?q?add:=20=E6=B7=BB=E5=8A=A0ollama=E8=B0=83?= =?UTF-8?q?=E7=94=A8AI=20chat=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- maubot_llmplus/aibot.py | 2 +- maubot_llmplus/local_paltform.py | 7 ++++--- maubot_llmplus/platforms.py | 2 +- maubot_llmplus/thrid_platform.py | 5 +++-- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/maubot_llmplus/aibot.py b/maubot_llmplus/aibot.py index 34dab4d..08bb2b1 100644 --- a/maubot_llmplus/aibot.py +++ b/maubot_llmplus/aibot.py @@ -116,7 +116,7 @@ class AiBotPlugin(Plugin): await event.mark_read() await self.client.set_typing(event.room_id, timeout=99999) platform = self.get_ai_platform() - chat_completion = await platform.create_chat_completion(event) + chat_completion = await platform.create_chat_completion(self, event) # ai gpt调用 # 关闭typing提示 await self.client.set_typing(event.room_id, timeout=0) diff --git a/maubot_llmplus/local_paltform.py b/maubot_llmplus/local_paltform.py index 8302cd4..81d12ad 100644 --- a/maubot_llmplus/local_paltform.py +++ b/maubot_llmplus/local_paltform.py @@ -1,6 +1,7 @@ import json from aiohttp import ClientSession +from maubot import Plugin from mautrix.types import MessageEvent from mautrix.util.config import BaseProxyConfig @@ -16,9 +17,9 @@ class Ollama(Platform): super().__init__(config, http) self.chat_api = '/api/chat' - async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion: + async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: full_context = [] - context = maubot_llmplus.platforms.get_context(evt) + context = maubot_llmplus.platforms.get_context(plugin, evt) full_context.extend(list(context)) endpoint = f"{self.url}/api/chat" @@ -50,5 +51,5 @@ class LmStudio(Platform): super().__init__(config, http) pass - async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion: + async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: pass diff --git a/maubot_llmplus/platforms.py b/maubot_llmplus/platforms.py index 719f1f2..85bb2f5 100644 --- a/maubot_llmplus/platforms.py +++ b/maubot_llmplus/platforms.py @@ -49,7 +49,7 @@ class Platform: 调用AI对话接口, 响应结果 """ - async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion: + async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: raise NotImplementedError() def get_type(self) -> str: diff --git a/maubot_llmplus/thrid_platform.py b/maubot_llmplus/thrid_platform.py index 5b7c39e..8dd03cb 100644 --- a/maubot_llmplus/thrid_platform.py +++ b/maubot_llmplus/thrid_platform.py @@ -1,4 +1,5 @@ from aiohttp import ClientSession +from maubot import Plugin from mautrix.types import MessageEvent from mautrix.util.config import BaseProxyConfig @@ -10,7 +11,7 @@ class OpenAi(Platform): def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None: super().__init__(config, http) - async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion: + async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: # 获取系统提示词 # 获取额外的其他角色的提示词: role: user role: system @@ -25,7 +26,7 @@ class Anthropic(Platform): def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None: super().__init__(config, http) - async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion: + async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: # 获取系统提示词 # 获取额外的其他角色的提示词: role: user role: system