From 5821623aae173435bc2eb1637d35564be3a7bce9 Mon Sep 17 00:00:00 2001 From: taylor Date: Sun, 13 Oct 2024 17:07:42 +0800 Subject: [PATCH] =?UTF-8?q?add:=20=E6=B7=BB=E5=8A=A0ollama=E8=B0=83?= =?UTF-8?q?=E7=94=A8AI=20chat=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- maubot_llmplus/aibot.py | 13 +++++++------ maubot_llmplus/local_paltform.py | 10 +++++----- maubot_llmplus/platforms.py | 4 +++- maubot_llmplus/thrid_platform.py | 8 ++++---- 4 files changed, 19 insertions(+), 16 deletions(-) diff --git a/maubot_llmplus/aibot.py b/maubot_llmplus/aibot.py index 6b2a57e..f76813f 100644 --- a/maubot_llmplus/aibot.py +++ b/maubot_llmplus/aibot.py @@ -15,7 +15,6 @@ from maubot_llmplus.thrid_platform import OpenAi, Anthropic 配置文件加载 """ - class Config(BaseProxyConfig): def do_update(self, helper: ConfigUpdateHelper) -> None: helper.copy("allowed_users") @@ -30,12 +29,14 @@ class Config(BaseProxyConfig): class AiBotPlugin(Plugin): + name: str + async def start(self) -> None: await super().start() # 加载并更新配置 self.config.load_and_update() # 决定当前机器人的名称 - super.name = self.config['name'] or \ + self.name = self.config['name'] or \ await self.client.get_displayname(self.client.mxid) or \ self.client.parse_user_id(self.client.mxid)[0] @@ -136,15 +137,15 @@ class AiBotPlugin(Plugin): if use_platform == 'local_ai': type = self.config['platforms']['local_ai']['type'] if type == 'ollama': - return Ollama(self.config, self.http) + return Ollama(self.config, self.name, self.http) elif type == 'lmstudio': - return LmStudio(self.config, self.http) + return LmStudio(self.config, self.name, self.http) else: raise ValueError(f"not found platform type: {type}") if use_platform == 'openai': - return OpenAi(self.config, self.http) + return OpenAi(self.config, self.name, self.http) if use_platform == 'anthropic': - return Anthropic(self.config, self.http) + return Anthropic(self.config, self.name, self.http) raise ValueError(f"unknown backend type {use_platform}") @classmethod diff --git a/maubot_llmplus/local_paltform.py b/maubot_llmplus/local_paltform.py index b328b88..aa13b6b 100644 --- a/maubot_llmplus/local_paltform.py +++ b/maubot_llmplus/local_paltform.py @@ -13,13 +13,13 @@ from maubot_llmplus.platforms import Platform, ChatCompletion class Ollama(Platform): chat_api: str - def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None: - super().__init__(config, http) + def __init__(self, config: BaseProxyConfig, name: str, http: ClientSession) -> None: + super().__init__(config, name, http) self.chat_api = '/api/chat' async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: full_context = [] - context = await maubot_llmplus.platforms.get_context(plugin, evt) + context = await maubot_llmplus.platforms.get_context(plugin, self, evt) full_context.extend(list(context)) endpoint = f"{self.url}/api/chat" @@ -47,8 +47,8 @@ class Ollama(Platform): class LmStudio(Platform): - def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None: - super().__init__(config, http) + def __init__(self, config: BaseProxyConfig, name: str, http: ClientSession) -> None: + super().__init__(config, name, http) pass async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: diff --git a/maubot_llmplus/platforms.py b/maubot_llmplus/platforms.py index 930296c..11bfe03 100644 --- a/maubot_llmplus/platforms.py +++ b/maubot_llmplus/platforms.py @@ -33,8 +33,9 @@ class Platform: additional_prompt: List[dict] system_prompt: str max_context_messages: int + name: str - def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None: + def __init__(self, config: BaseProxyConfig, name: str, http: ClientSession) -> None: self.http = http self.config = config['platforms'][self.get_type()] self.url = self.config['url'] @@ -44,6 +45,7 @@ class Platform: self.max_context_messages = self.config['max_context_messages'] self.additional_prompt = config['additional_prompt'] self.system_prompt = config['system_prompt'] + self.name = name """a 调用AI对话接口, 响应结果 diff --git a/maubot_llmplus/thrid_platform.py b/maubot_llmplus/thrid_platform.py index 8dd03cb..d7691e9 100644 --- a/maubot_llmplus/thrid_platform.py +++ b/maubot_llmplus/thrid_platform.py @@ -8,8 +8,8 @@ from maubot_llmplus.platforms import Platform, ChatCompletion class OpenAi(Platform): - def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None: - super().__init__(config, http) + def __init__(self, config: BaseProxyConfig, name: str, http: ClientSession) -> None: + super().__init__(config, name, http) async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: # 获取系统提示词 @@ -23,8 +23,8 @@ class OpenAi(Platform): class Anthropic(Platform): - def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None: - super().__init__(config, http) + def __init__(self, config: BaseProxyConfig, name: str, http: ClientSession) -> None: + super().__init__(config, name, http) async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: # 获取系统提示词