From e5d4e52bc082d23c54b6be992d807a74a31677f0 Mon Sep 17 00:00:00 2001 From: taylor Date: Sun, 13 Oct 2024 15:37:12 +0800 Subject: [PATCH] =?UTF-8?q?add:=20=E6=B7=BB=E5=8A=A0ollama=E8=B0=83?= =?UTF-8?q?=E7=94=A8AI=20chat=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .idea/maubot-llmplus.iml | 2 +- .idea/misc.xml | 2 +- base-config.yaml | 8 +- maubot_llmplus/aibot.py | 40 ++++++-- maubot_llmplus/llm/__init__.py | 0 maubot_llmplus/llm/local_paltform.py | 51 ++++++++++ maubot_llmplus/llm/platforms.py | 146 +++++++++++++++++++++++++++ maubot_llmplus/llm/thrid_platform.py | 39 +++++++ 8 files changed, 275 insertions(+), 13 deletions(-) create mode 100644 maubot_llmplus/llm/__init__.py create mode 100644 maubot_llmplus/llm/local_paltform.py create mode 100644 maubot_llmplus/llm/platforms.py create mode 100644 maubot_llmplus/llm/thrid_platform.py diff --git a/.idea/maubot-llmplus.iml b/.idea/maubot-llmplus.iml index 9d99f9e..6d99fbe 100644 --- a/.idea/maubot-llmplus.iml +++ b/.idea/maubot-llmplus.iml @@ -4,7 +4,7 @@ - + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml index 74d11cf..90553d3 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -1,4 +1,4 @@ - + \ No newline at end of file diff --git a/base-config.yaml b/base-config.yaml index 82a3bf2..87b546c 100644 --- a/base-config.yaml +++ b/base-config.yaml @@ -11,7 +11,7 @@ enable_multi_user: system_prompt: platforms: - local: + local_ai: type: ollama url: http://localhost:11434 api_key: @@ -30,3 +30,9 @@ platforms: max_tokens: model: max_words: + +additional_prompt: + - role: user + content: xxx + - role: system + content: xxx \ No newline at end of file diff --git a/maubot_llmplus/aibot.py b/maubot_llmplus/aibot.py index d0be60f..b0d7182 100644 --- a/maubot_llmplus/aibot.py +++ b/maubot_llmplus/aibot.py @@ -11,10 +11,15 @@ from maubot.handlers import command, event from maubot import Plugin, MessageEvent from mautrix.errors import MNotFound, MatrixRequestError from mautrix.types import Format, TextMessageEventContent, EventType, RoomID, UserID, MessageType, RelationType, \ - EncryptedEvent + EncryptedEvent, MediaMessageEventContent, ImageInfo, EncryptedFile from mautrix.util import markdown from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper +from maubot_llmplus.llm import platforms +from maubot_llmplus.llm.local_paltform import Ollama, LmStudio +from maubot_llmplus.llm.platforms import Platform +from maubot_llmplus.llm.thrid_platform import OpenAi, Anthropic + """ 配置文件加载 """ @@ -29,6 +34,7 @@ class Config(BaseProxyConfig): helper.copy("enable_multi_user") helper.copy("system_prompt") helper.copy("platforms") + helper.copy("additional_prompt") class AiBotPlugin(Plugin): @@ -109,11 +115,6 @@ class AiBotPlugin(Plugin): if parent_event.sender == self.client.mxid: return True - async def get_context(self, event: MessageEvent): - return None - - async def _ai_call(self, prompt): - return None @event.on(EventType.ROOM_MESSAGE) async def on_message(self, event: MessageEvent) -> None: @@ -123,20 +124,39 @@ class AiBotPlugin(Plugin): try: await event.mark_read() await self.client.set_typing(event.room_id, timeout=99999) - resp_content = "response test" + platform = self.get_platform() + chat_completion = platform.create_chat_completion(event) # ai gpt调用 # 关闭typing提示 await self.client.set_typing(event.room_id, timeout=0) - # 打开typing提示 - response = TextMessageEventContent(msgtype=MessageType.NOTICE, body=resp_content, format=Format.HTML, + resp_content = chat_completion.message['content'] + response = TextMessageEventContent(msgtype=MessageType.IMAGE, body=resp_content, format=Format.HTML, formatted_body=markdown.render(resp_content)) await event.respond(response, in_thread=self.config['reply_in_thread']) except Exception as e: + self.log.exception(f"Something went wrong: {e}") + await event.respond(f"Something went wrong: {e}") pass - return None; + return None + + async def get_platform(self) -> Platform: + use_platform = self.config['use_platform'] + if use_platform == 'local_ai': + type = self.config['platforms']['local_ai']['type'] + if type == 'ollama': + return Ollama(self.config) + elif type == 'lmstudio': + return LmStudio(self.config) + else: + raise ValueError(f"not found platform type: {type}") + if use_platform == 'openai': + return OpenAi(self.config) + if use_platform == 'anthropic': + return Anthropic(self.config) + raise ValueError(f"unknown backend type {use_platform}") @classmethod def get_config_class(cls) -> Type[BaseProxyConfig]: diff --git a/maubot_llmplus/llm/__init__.py b/maubot_llmplus/llm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/maubot_llmplus/llm/local_paltform.py b/maubot_llmplus/llm/local_paltform.py new file mode 100644 index 0000000..658b54b --- /dev/null +++ b/maubot_llmplus/llm/local_paltform.py @@ -0,0 +1,51 @@ +import json +import platform +from collections import deque +from typing import List + +from maubot import Plugin +from mautrix.types import MessageEvent +from mautrix.util.config import BaseProxyConfig + +from maubot_llmplus import AiBotPlugin +from maubot_llmplus.llm import platforms +from maubot_llmplus.llm.platforms import Platform, ChatCompletion + + +class Ollama(Platform): + chat_api: str + + def __init__(self, config: BaseProxyConfig) -> None: + super().__init__(config) + self.chat_api = '/api/chat' + + async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion: + full_context = [] + context = platforms.get_context(evt) + full_context.extend(list(context)) + + endpoint = f"{self.url}/api/chat" + req_body = {'model': self.model, 'message': full_context, 'steam': False} + headers = {} + if self.api_key is not None: + headers['Authorization'] = self.api_key + async with self.http.post(endpoint, headers=headers, data=json.dumps(req_body)) as response: + if response.status != 200: + return ChatCompletion( + message={}, + finish_reason=f"http status {response.status}", + model=None + ) + response_json = await response.json() + return ChatCompletion( + message=response_json['message'], + finish_reason='success', + model=response_json.get('model', None) + ) + + def get_type(self) -> str: + return "local_ai" + + +class LmStudio(Platform): + pass diff --git a/maubot_llmplus/llm/platforms.py b/maubot_llmplus/llm/platforms.py new file mode 100644 index 0000000..85827a5 --- /dev/null +++ b/maubot_llmplus/llm/platforms.py @@ -0,0 +1,146 @@ +import json +from collections import deque +from datetime import datetime +from typing import Optional, List, Generator + +from aiohttp import ClientSession +from maubot import Plugin +from mautrix.types import MessageEvent, EncryptedEvent +from mautrix.util.config import BaseProxyConfig + +from maubot_llmplus import AiBotPlugin + +""" + AI响应对象 +""" + + +class ChatCompletion: + def __init__(self, message: dict, finish_reason: str, model: Optional[str]) -> None: + self.message = message + self.finish_reason = finish_reason + self.model = model + + def __eq__(self, other) -> bool: + return self.message == other.message and self.model == other.model + + +class Platform: + http: ClientSession + config: dict + url: str + api_key: str + model: str + max_words: int + additional_prompt: List[dict] + system_prompt: str + max_context_messages: int + + def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None: + self.http = http + self.config = config['platforms'][self.get_type()] + self.url = self.config['url'] + self.model = self.config['model'] + self.max_words = self.config['max_words'] + self.api_key = self.config['api_key'] + self.max_context_messages = self.config['max_context_messages'] + self.additional_prompt = config['additional_prompt'] + self.system_prompt = config['system_prompt'] + + """a + 调用AI对话接口, 响应结果 + """ + + async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion: + raise NotImplementedError() + + def get_type(self) -> str: + raise NotImplementedError() + + + +async def get_context(plugin: AiBotPlugin, evt: MessageEvent) -> deque: + # 创建系统提示词上下文 + system_context = deque() + # 生成当前时间 + timestamp = datetime.today().strftime('%Y-%m-%d %H:%M:%S') + # 加入系统提示词 + system_prompt = {"role": "system", + "content": plugin.config['system_prompt'].format(name=plugin.name, timestamp=timestamp)} + if plugin.config['enable_multi_user']: + system_prompt["content"] += """ + User messages are in the context of multiperson chatrooms. + Each message indicates its sender by prefixing the message with the sender's name followed by a colon, for example: + "username: hello world." + In this case, the user called "username" sent the message "hello world.". You should not follow this convention in your responses. + your response instead could be "hello username!" without including any colons, because you are the only one sending your responses there is no need to prefix them. + """ + system_context.append(system_prompt) + + # 添加额外的系统提示词和用户提示词 + additional_context = json.loads(json.dumps(plugin.config['additional_prompt'])) + if additional_context: + for item in additional_context: + system_context.append(item) + # 如果 消息长度已经超过了配置的消息条数,那么就抛出错误 + if len(additional_context) > plugin.config['max_context_messages'] - 1: + raise ValueError(f"sorry, my configuration has too many additional prompts " + f"({plugin.config['max_context_messages']}) and i'll never see your message. " + f"Update my config to have fewer messages and i'll be able to answer your questions!") + + # 用户历史聊天上下文 + chat_context = deque() + # 计算系统提示词单词数 + word_count = sum([len(m["content"].split()) for m in system_context]) + message_count = len(system_context) - 1 + async for next_event in generate_context_messages(plugin, evt): + # 如果不是文本类型,就跳过 + try: + if not next_event.content.msgtype.is_text: + continue + except (KeyError, AttributeError): + continue + + # 如果当前的这条历史消息是机器人自己的,那么角色就要设置为assistant + role = 'assistant' if plugin.client.mxid == next_event.sender else 'user' + message = next_event['content']['body'] + user = '' + # 如果是允许多用户使用,那么就需要在每个历史消息前加上用户名 + if plugin.config['enable_multi_user']: + user = (await plugin.client.get_displayname(next_event.sender) or + plugin.client.parse_user_id(next_event.sender)[0]) + ": " + + # 计算单词量和消息数 + word_count += len(message.split()) + message_count += 1 + if word_count >= plugin.config['max_words'] or message_count >= plugin.config['max_context_messages']: + break + chat_context.appendleft({"role": role, "content": user + message}) + + return system_context + chat_context + + + + +async def generate_context_messages(plugin: AiBotPlugin, evt: MessageEvent) -> Generator[MessageEvent, None, None]: + yield evt + if plugin.config['reply_in_thread']: + while evt.content.relates_to.in_reply_to: + evt = await plugin.client.get_event(room_id=evt.room_id, event_id=evt.content.get_reply_to()) + yield evt + else: + event_context = await plugin.client.get_event_context(room_id=evt.room_id, event_id=evt.event_id, + limit=plugin.config["max_context_messages"] * 2) + previous_messages = iter(event_context.events_before) + for evt in previous_messages: + + # We already have the event, but currently, get_event_context doesn't automatically decrypt events + if isinstance(evt, EncryptedEvent) and plugin.client.crypto: + evt = await plugin.client.get_event(event_id=evt.event_id, room_id=evt.room_id) + if not evt: + raise ValueError("Decryption error!") + + yield evt + + + diff --git a/maubot_llmplus/llm/thrid_platform.py b/maubot_llmplus/llm/thrid_platform.py new file mode 100644 index 0000000..dfc6b42 --- /dev/null +++ b/maubot_llmplus/llm/thrid_platform.py @@ -0,0 +1,39 @@ +from collections import deque +from typing import List + +from maubot import Plugin +from mautrix.types import MessageEvent +from mautrix.util.config import BaseProxyConfig + +from maubot_llmplus import AiBotPlugin +from maubot_llmplus.llm.platforms import Platform, ChatCompletion + + +class OpenAi(Platform): + + def __init__(self, config: BaseProxyConfig) -> None: + super().__init__(config) + + async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion: + # 获取系统提示词 + # 获取额外的其他角色的提示词: role: user role: system + + pass + + def get_type(self) -> str: + return "openai" + + +class Anthropic(Platform): + + def __init__(self, config: BaseProxyConfig) -> None: + super().__init__(config) + + async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion: + # 获取系统提示词 + # 获取额外的其他角色的提示词: role: user role: system + + pass + + def get_type(self) -> str: + return "anthropic"