diff --git a/.idea/maubot-llmplus.iml b/.idea/maubot-llmplus.iml
index 9d99f9e..6d99fbe 100644
--- a/.idea/maubot-llmplus.iml
+++ b/.idea/maubot-llmplus.iml
@@ -4,7 +4,7 @@
-
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
index 74d11cf..90553d3 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -1,4 +1,4 @@
-
+
\ No newline at end of file
diff --git a/base-config.yaml b/base-config.yaml
index 82a3bf2..87b546c 100644
--- a/base-config.yaml
+++ b/base-config.yaml
@@ -11,7 +11,7 @@ enable_multi_user:
system_prompt:
platforms:
- local:
+ local_ai:
type: ollama
url: http://localhost:11434
api_key:
@@ -30,3 +30,9 @@ platforms:
max_tokens:
model:
max_words:
+
+additional_prompt:
+ - role: user
+ content: xxx
+ - role: system
+ content: xxx
\ No newline at end of file
diff --git a/maubot_llmplus/aibot.py b/maubot_llmplus/aibot.py
index d0be60f..b0d7182 100644
--- a/maubot_llmplus/aibot.py
+++ b/maubot_llmplus/aibot.py
@@ -11,10 +11,15 @@ from maubot.handlers import command, event
from maubot import Plugin, MessageEvent
from mautrix.errors import MNotFound, MatrixRequestError
from mautrix.types import Format, TextMessageEventContent, EventType, RoomID, UserID, MessageType, RelationType, \
- EncryptedEvent
+ EncryptedEvent, MediaMessageEventContent, ImageInfo, EncryptedFile
from mautrix.util import markdown
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
+from maubot_llmplus.llm import platforms
+from maubot_llmplus.llm.local_paltform import Ollama, LmStudio
+from maubot_llmplus.llm.platforms import Platform
+from maubot_llmplus.llm.thrid_platform import OpenAi, Anthropic
+
"""
配置文件加载
"""
@@ -29,6 +34,7 @@ class Config(BaseProxyConfig):
helper.copy("enable_multi_user")
helper.copy("system_prompt")
helper.copy("platforms")
+ helper.copy("additional_prompt")
class AiBotPlugin(Plugin):
@@ -109,11 +115,6 @@ class AiBotPlugin(Plugin):
if parent_event.sender == self.client.mxid:
return True
- async def get_context(self, event: MessageEvent):
- return None
-
- async def _ai_call(self, prompt):
- return None
@event.on(EventType.ROOM_MESSAGE)
async def on_message(self, event: MessageEvent) -> None:
@@ -123,20 +124,39 @@ class AiBotPlugin(Plugin):
try:
await event.mark_read()
await self.client.set_typing(event.room_id, timeout=99999)
- resp_content = "response test"
+ platform = self.get_platform()
+ chat_completion = platform.create_chat_completion(event)
# ai gpt调用
# 关闭typing提示
await self.client.set_typing(event.room_id, timeout=0)
-
# 打开typing提示
- response = TextMessageEventContent(msgtype=MessageType.NOTICE, body=resp_content, format=Format.HTML,
+ resp_content = chat_completion.message['content']
+ response = TextMessageEventContent(msgtype=MessageType.IMAGE, body=resp_content, format=Format.HTML,
formatted_body=markdown.render(resp_content))
await event.respond(response, in_thread=self.config['reply_in_thread'])
except Exception as e:
+ self.log.exception(f"Something went wrong: {e}")
+ await event.respond(f"Something went wrong: {e}")
pass
- return None;
+ return None
+
+ async def get_platform(self) -> Platform:
+ use_platform = self.config['use_platform']
+ if use_platform == 'local_ai':
+ type = self.config['platforms']['local_ai']['type']
+ if type == 'ollama':
+ return Ollama(self.config)
+ elif type == 'lmstudio':
+ return LmStudio(self.config)
+ else:
+ raise ValueError(f"not found platform type: {type}")
+ if use_platform == 'openai':
+ return OpenAi(self.config)
+ if use_platform == 'anthropic':
+ return Anthropic(self.config)
+ raise ValueError(f"unknown backend type {use_platform}")
@classmethod
def get_config_class(cls) -> Type[BaseProxyConfig]:
diff --git a/maubot_llmplus/llm/__init__.py b/maubot_llmplus/llm/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/maubot_llmplus/llm/local_paltform.py b/maubot_llmplus/llm/local_paltform.py
new file mode 100644
index 0000000..658b54b
--- /dev/null
+++ b/maubot_llmplus/llm/local_paltform.py
@@ -0,0 +1,51 @@
+import json
+import platform
+from collections import deque
+from typing import List
+
+from maubot import Plugin
+from mautrix.types import MessageEvent
+from mautrix.util.config import BaseProxyConfig
+
+from maubot_llmplus import AiBotPlugin
+from maubot_llmplus.llm import platforms
+from maubot_llmplus.llm.platforms import Platform, ChatCompletion
+
+
+class Ollama(Platform):
+ chat_api: str
+
+ def __init__(self, config: BaseProxyConfig) -> None:
+ super().__init__(config)
+ self.chat_api = '/api/chat'
+
+ async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion:
+ full_context = []
+ context = platforms.get_context(evt)
+ full_context.extend(list(context))
+
+ endpoint = f"{self.url}/api/chat"
+ req_body = {'model': self.model, 'message': full_context, 'steam': False}
+ headers = {}
+ if self.api_key is not None:
+ headers['Authorization'] = self.api_key
+ async with self.http.post(endpoint, headers=headers, data=json.dumps(req_body)) as response:
+ if response.status != 200:
+ return ChatCompletion(
+ message={},
+ finish_reason=f"http status {response.status}",
+ model=None
+ )
+ response_json = await response.json()
+ return ChatCompletion(
+ message=response_json['message'],
+ finish_reason='success',
+ model=response_json.get('model', None)
+ )
+
+ def get_type(self) -> str:
+ return "local_ai"
+
+
+class LmStudio(Platform):
+ pass
diff --git a/maubot_llmplus/llm/platforms.py b/maubot_llmplus/llm/platforms.py
new file mode 100644
index 0000000..85827a5
--- /dev/null
+++ b/maubot_llmplus/llm/platforms.py
@@ -0,0 +1,146 @@
+import json
+from collections import deque
+from datetime import datetime
+from typing import Optional, List, Generator
+
+from aiohttp import ClientSession
+from maubot import Plugin
+from mautrix.types import MessageEvent, EncryptedEvent
+from mautrix.util.config import BaseProxyConfig
+
+from maubot_llmplus import AiBotPlugin
+
+"""
+ AI响应对象
+"""
+
+
+class ChatCompletion:
+ def __init__(self, message: dict, finish_reason: str, model: Optional[str]) -> None:
+ self.message = message
+ self.finish_reason = finish_reason
+ self.model = model
+
+ def __eq__(self, other) -> bool:
+ return self.message == other.message and self.model == other.model
+
+
+class Platform:
+ http: ClientSession
+ config: dict
+ url: str
+ api_key: str
+ model: str
+ max_words: int
+ additional_prompt: List[dict]
+ system_prompt: str
+ max_context_messages: int
+
+ def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
+ self.http = http
+ self.config = config['platforms'][self.get_type()]
+ self.url = self.config['url']
+ self.model = self.config['model']
+ self.max_words = self.config['max_words']
+ self.api_key = self.config['api_key']
+ self.max_context_messages = self.config['max_context_messages']
+ self.additional_prompt = config['additional_prompt']
+ self.system_prompt = config['system_prompt']
+
+ """a
+ 调用AI对话接口, 响应结果
+ """
+
+ async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion:
+ raise NotImplementedError()
+
+ def get_type(self) -> str:
+ raise NotImplementedError()
+
+
+
+async def get_context(plugin: AiBotPlugin, evt: MessageEvent) -> deque:
+ # 创建系统提示词上下文
+ system_context = deque()
+ # 生成当前时间
+ timestamp = datetime.today().strftime('%Y-%m-%d %H:%M:%S')
+ # 加入系统提示词
+ system_prompt = {"role": "system",
+ "content": plugin.config['system_prompt'].format(name=plugin.name, timestamp=timestamp)}
+ if plugin.config['enable_multi_user']:
+ system_prompt["content"] += """
+ User messages are in the context of multiperson chatrooms.
+ Each message indicates its sender by prefixing the message with the sender's name followed by a colon, for example:
+ "username: hello world."
+ In this case, the user called "username" sent the message "hello world.". You should not follow this convention in your responses.
+ your response instead could be "hello username!" without including any colons, because you are the only one sending your responses there is no need to prefix them.
+ """
+ system_context.append(system_prompt)
+
+ # 添加额外的系统提示词和用户提示词
+ additional_context = json.loads(json.dumps(plugin.config['additional_prompt']))
+ if additional_context:
+ for item in additional_context:
+ system_context.append(item)
+ # 如果 消息长度已经超过了配置的消息条数,那么就抛出错误
+ if len(additional_context) > plugin.config['max_context_messages'] - 1:
+ raise ValueError(f"sorry, my configuration has too many additional prompts "
+ f"({plugin.config['max_context_messages']}) and i'll never see your message. "
+ f"Update my config to have fewer messages and i'll be able to answer your questions!")
+
+ # 用户历史聊天上下文
+ chat_context = deque()
+ # 计算系统提示词单词数
+ word_count = sum([len(m["content"].split()) for m in system_context])
+ message_count = len(system_context) - 1
+ async for next_event in generate_context_messages(plugin, evt):
+ # 如果不是文本类型,就跳过
+ try:
+ if not next_event.content.msgtype.is_text:
+ continue
+ except (KeyError, AttributeError):
+ continue
+
+ # 如果当前的这条历史消息是机器人自己的,那么角色就要设置为assistant
+ role = 'assistant' if plugin.client.mxid == next_event.sender else 'user'
+ message = next_event['content']['body']
+ user = ''
+ # 如果是允许多用户使用,那么就需要在每个历史消息前加上用户名
+ if plugin.config['enable_multi_user']:
+ user = (await plugin.client.get_displayname(next_event.sender) or
+ plugin.client.parse_user_id(next_event.sender)[0]) + ": "
+
+ # 计算单词量和消息数
+ word_count += len(message.split())
+ message_count += 1
+ if word_count >= plugin.config['max_words'] or message_count >= plugin.config['max_context_messages']:
+ break
+ chat_context.appendleft({"role": role, "content": user + message})
+
+ return system_context + chat_context
+
+
+
+
+async def generate_context_messages(plugin: AiBotPlugin, evt: MessageEvent) -> Generator[MessageEvent, None, None]:
+ yield evt
+ if plugin.config['reply_in_thread']:
+ while evt.content.relates_to.in_reply_to:
+ evt = await plugin.client.get_event(room_id=evt.room_id, event_id=evt.content.get_reply_to())
+ yield evt
+ else:
+ event_context = await plugin.client.get_event_context(room_id=evt.room_id, event_id=evt.event_id,
+ limit=plugin.config["max_context_messages"] * 2)
+ previous_messages = iter(event_context.events_before)
+ for evt in previous_messages:
+
+ # We already have the event, but currently, get_event_context doesn't automatically decrypt events
+ if isinstance(evt, EncryptedEvent) and plugin.client.crypto:
+ evt = await plugin.client.get_event(event_id=evt.event_id, room_id=evt.room_id)
+ if not evt:
+ raise ValueError("Decryption error!")
+
+ yield evt
+
+
+
diff --git a/maubot_llmplus/llm/thrid_platform.py b/maubot_llmplus/llm/thrid_platform.py
new file mode 100644
index 0000000..dfc6b42
--- /dev/null
+++ b/maubot_llmplus/llm/thrid_platform.py
@@ -0,0 +1,39 @@
+from collections import deque
+from typing import List
+
+from maubot import Plugin
+from mautrix.types import MessageEvent
+from mautrix.util.config import BaseProxyConfig
+
+from maubot_llmplus import AiBotPlugin
+from maubot_llmplus.llm.platforms import Platform, ChatCompletion
+
+
+class OpenAi(Platform):
+
+ def __init__(self, config: BaseProxyConfig) -> None:
+ super().__init__(config)
+
+ async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion:
+ # 获取系统提示词
+ # 获取额外的其他角色的提示词: role: user role: system
+
+ pass
+
+ def get_type(self) -> str:
+ return "openai"
+
+
+class Anthropic(Platform):
+
+ def __init__(self, config: BaseProxyConfig) -> None:
+ super().__init__(config)
+
+ async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion:
+ # 获取系统提示词
+ # 获取额外的其他角色的提示词: role: user role: system
+
+ pass
+
+ def get_type(self) -> str:
+ return "anthropic"