add: 添加ollama调用AI chat逻辑
This commit is contained in:
2
.idea/maubot-llmplus.iml
generated
2
.idea/maubot-llmplus.iml
generated
@@ -4,7 +4,7 @@
|
|||||||
<content url="file://$MODULE_DIR$">
|
<content url="file://$MODULE_DIR$">
|
||||||
<excludeFolder url="file://$MODULE_DIR$/venv" />
|
<excludeFolder url="file://$MODULE_DIR$/venv" />
|
||||||
</content>
|
</content>
|
||||||
<orderEntry type="jdk" jdkName="maubot-llm-conda" jdkType="Python SDK" />
|
<orderEntry type="jdk" jdkName="Python 3.11 (maubot-llmplus)" jdkType="Python SDK" />
|
||||||
<orderEntry type="sourceFolder" forTests="false" />
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
</component>
|
</component>
|
||||||
</module>
|
</module>
|
||||||
2
.idea/misc.xml
generated
2
.idea/misc.xml
generated
@@ -1,4 +1,4 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<project version="4">
|
<project version="4">
|
||||||
<component name="ProjectRootManager" version="2" project-jdk-name="maubot-llm-conda" project-jdk-type="Python SDK" />
|
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11 (maubot-llmplus)" project-jdk-type="Python SDK" />
|
||||||
</project>
|
</project>
|
||||||
@@ -11,7 +11,7 @@ enable_multi_user:
|
|||||||
system_prompt:
|
system_prompt:
|
||||||
|
|
||||||
platforms:
|
platforms:
|
||||||
local:
|
local_ai:
|
||||||
type: ollama
|
type: ollama
|
||||||
url: http://localhost:11434
|
url: http://localhost:11434
|
||||||
api_key:
|
api_key:
|
||||||
@@ -30,3 +30,9 @@ platforms:
|
|||||||
max_tokens:
|
max_tokens:
|
||||||
model:
|
model:
|
||||||
max_words:
|
max_words:
|
||||||
|
|
||||||
|
additional_prompt:
|
||||||
|
- role: user
|
||||||
|
content: xxx
|
||||||
|
- role: system
|
||||||
|
content: xxx
|
||||||
@@ -11,10 +11,15 @@ from maubot.handlers import command, event
|
|||||||
from maubot import Plugin, MessageEvent
|
from maubot import Plugin, MessageEvent
|
||||||
from mautrix.errors import MNotFound, MatrixRequestError
|
from mautrix.errors import MNotFound, MatrixRequestError
|
||||||
from mautrix.types import Format, TextMessageEventContent, EventType, RoomID, UserID, MessageType, RelationType, \
|
from mautrix.types import Format, TextMessageEventContent, EventType, RoomID, UserID, MessageType, RelationType, \
|
||||||
EncryptedEvent
|
EncryptedEvent, MediaMessageEventContent, ImageInfo, EncryptedFile
|
||||||
from mautrix.util import markdown
|
from mautrix.util import markdown
|
||||||
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
|
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
|
||||||
|
|
||||||
|
from maubot_llmplus.llm import platforms
|
||||||
|
from maubot_llmplus.llm.local_paltform import Ollama, LmStudio
|
||||||
|
from maubot_llmplus.llm.platforms import Platform
|
||||||
|
from maubot_llmplus.llm.thrid_platform import OpenAi, Anthropic
|
||||||
|
|
||||||
"""
|
"""
|
||||||
配置文件加载
|
配置文件加载
|
||||||
"""
|
"""
|
||||||
@@ -29,6 +34,7 @@ class Config(BaseProxyConfig):
|
|||||||
helper.copy("enable_multi_user")
|
helper.copy("enable_multi_user")
|
||||||
helper.copy("system_prompt")
|
helper.copy("system_prompt")
|
||||||
helper.copy("platforms")
|
helper.copy("platforms")
|
||||||
|
helper.copy("additional_prompt")
|
||||||
|
|
||||||
|
|
||||||
class AiBotPlugin(Plugin):
|
class AiBotPlugin(Plugin):
|
||||||
@@ -109,11 +115,6 @@ class AiBotPlugin(Plugin):
|
|||||||
if parent_event.sender == self.client.mxid:
|
if parent_event.sender == self.client.mxid:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def get_context(self, event: MessageEvent):
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _ai_call(self, prompt):
|
|
||||||
return None
|
|
||||||
|
|
||||||
@event.on(EventType.ROOM_MESSAGE)
|
@event.on(EventType.ROOM_MESSAGE)
|
||||||
async def on_message(self, event: MessageEvent) -> None:
|
async def on_message(self, event: MessageEvent) -> None:
|
||||||
@@ -123,20 +124,39 @@ class AiBotPlugin(Plugin):
|
|||||||
try:
|
try:
|
||||||
await event.mark_read()
|
await event.mark_read()
|
||||||
await self.client.set_typing(event.room_id, timeout=99999)
|
await self.client.set_typing(event.room_id, timeout=99999)
|
||||||
resp_content = "response test"
|
platform = self.get_platform()
|
||||||
|
chat_completion = platform.create_chat_completion(event)
|
||||||
# ai gpt调用
|
# ai gpt调用
|
||||||
# 关闭typing提示
|
# 关闭typing提示
|
||||||
await self.client.set_typing(event.room_id, timeout=0)
|
await self.client.set_typing(event.room_id, timeout=0)
|
||||||
|
|
||||||
# 打开typing提示
|
# 打开typing提示
|
||||||
response = TextMessageEventContent(msgtype=MessageType.NOTICE, body=resp_content, format=Format.HTML,
|
resp_content = chat_completion.message['content']
|
||||||
|
response = TextMessageEventContent(msgtype=MessageType.IMAGE, body=resp_content, format=Format.HTML,
|
||||||
formatted_body=markdown.render(resp_content))
|
formatted_body=markdown.render(resp_content))
|
||||||
await event.respond(response, in_thread=self.config['reply_in_thread'])
|
await event.respond(response, in_thread=self.config['reply_in_thread'])
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
self.log.exception(f"Something went wrong: {e}")
|
||||||
|
await event.respond(f"Something went wrong: {e}")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return None;
|
return None
|
||||||
|
|
||||||
|
async def get_platform(self) -> Platform:
|
||||||
|
use_platform = self.config['use_platform']
|
||||||
|
if use_platform == 'local_ai':
|
||||||
|
type = self.config['platforms']['local_ai']['type']
|
||||||
|
if type == 'ollama':
|
||||||
|
return Ollama(self.config)
|
||||||
|
elif type == 'lmstudio':
|
||||||
|
return LmStudio(self.config)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"not found platform type: {type}")
|
||||||
|
if use_platform == 'openai':
|
||||||
|
return OpenAi(self.config)
|
||||||
|
if use_platform == 'anthropic':
|
||||||
|
return Anthropic(self.config)
|
||||||
|
raise ValueError(f"unknown backend type {use_platform}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config_class(cls) -> Type[BaseProxyConfig]:
|
def get_config_class(cls) -> Type[BaseProxyConfig]:
|
||||||
|
|||||||
0
maubot_llmplus/llm/__init__.py
Normal file
0
maubot_llmplus/llm/__init__.py
Normal file
51
maubot_llmplus/llm/local_paltform.py
Normal file
51
maubot_llmplus/llm/local_paltform.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
import json
|
||||||
|
import platform
|
||||||
|
from collections import deque
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from maubot import Plugin
|
||||||
|
from mautrix.types import MessageEvent
|
||||||
|
from mautrix.util.config import BaseProxyConfig
|
||||||
|
|
||||||
|
from maubot_llmplus import AiBotPlugin
|
||||||
|
from maubot_llmplus.llm import platforms
|
||||||
|
from maubot_llmplus.llm.platforms import Platform, ChatCompletion
|
||||||
|
|
||||||
|
|
||||||
|
class Ollama(Platform):
|
||||||
|
chat_api: str
|
||||||
|
|
||||||
|
def __init__(self, config: BaseProxyConfig) -> None:
|
||||||
|
super().__init__(config)
|
||||||
|
self.chat_api = '/api/chat'
|
||||||
|
|
||||||
|
async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion:
|
||||||
|
full_context = []
|
||||||
|
context = platforms.get_context(evt)
|
||||||
|
full_context.extend(list(context))
|
||||||
|
|
||||||
|
endpoint = f"{self.url}/api/chat"
|
||||||
|
req_body = {'model': self.model, 'message': full_context, 'steam': False}
|
||||||
|
headers = {}
|
||||||
|
if self.api_key is not None:
|
||||||
|
headers['Authorization'] = self.api_key
|
||||||
|
async with self.http.post(endpoint, headers=headers, data=json.dumps(req_body)) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
return ChatCompletion(
|
||||||
|
message={},
|
||||||
|
finish_reason=f"http status {response.status}",
|
||||||
|
model=None
|
||||||
|
)
|
||||||
|
response_json = await response.json()
|
||||||
|
return ChatCompletion(
|
||||||
|
message=response_json['message'],
|
||||||
|
finish_reason='success',
|
||||||
|
model=response_json.get('model', None)
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_type(self) -> str:
|
||||||
|
return "local_ai"
|
||||||
|
|
||||||
|
|
||||||
|
class LmStudio(Platform):
|
||||||
|
pass
|
||||||
146
maubot_llmplus/llm/platforms.py
Normal file
146
maubot_llmplus/llm/platforms.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
import json
|
||||||
|
from collections import deque
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, List, Generator
|
||||||
|
|
||||||
|
from aiohttp import ClientSession
|
||||||
|
from maubot import Plugin
|
||||||
|
from mautrix.types import MessageEvent, EncryptedEvent
|
||||||
|
from mautrix.util.config import BaseProxyConfig
|
||||||
|
|
||||||
|
from maubot_llmplus import AiBotPlugin
|
||||||
|
|
||||||
|
"""
|
||||||
|
AI响应对象
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletion:
|
||||||
|
def __init__(self, message: dict, finish_reason: str, model: Optional[str]) -> None:
|
||||||
|
self.message = message
|
||||||
|
self.finish_reason = finish_reason
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
def __eq__(self, other) -> bool:
|
||||||
|
return self.message == other.message and self.model == other.model
|
||||||
|
|
||||||
|
|
||||||
|
class Platform:
|
||||||
|
http: ClientSession
|
||||||
|
config: dict
|
||||||
|
url: str
|
||||||
|
api_key: str
|
||||||
|
model: str
|
||||||
|
max_words: int
|
||||||
|
additional_prompt: List[dict]
|
||||||
|
system_prompt: str
|
||||||
|
max_context_messages: int
|
||||||
|
|
||||||
|
def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
|
||||||
|
self.http = http
|
||||||
|
self.config = config['platforms'][self.get_type()]
|
||||||
|
self.url = self.config['url']
|
||||||
|
self.model = self.config['model']
|
||||||
|
self.max_words = self.config['max_words']
|
||||||
|
self.api_key = self.config['api_key']
|
||||||
|
self.max_context_messages = self.config['max_context_messages']
|
||||||
|
self.additional_prompt = config['additional_prompt']
|
||||||
|
self.system_prompt = config['system_prompt']
|
||||||
|
|
||||||
|
"""a
|
||||||
|
调用AI对话接口, 响应结果
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_type(self) -> str:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def get_context(plugin: AiBotPlugin, evt: MessageEvent) -> deque:
|
||||||
|
# 创建系统提示词上下文
|
||||||
|
system_context = deque()
|
||||||
|
# 生成当前时间
|
||||||
|
timestamp = datetime.today().strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
# 加入系统提示词
|
||||||
|
system_prompt = {"role": "system",
|
||||||
|
"content": plugin.config['system_prompt'].format(name=plugin.name, timestamp=timestamp)}
|
||||||
|
if plugin.config['enable_multi_user']:
|
||||||
|
system_prompt["content"] += """
|
||||||
|
User messages are in the context of multiperson chatrooms.
|
||||||
|
Each message indicates its sender by prefixing the message with the sender's name followed by a colon, for example:
|
||||||
|
"username: hello world."
|
||||||
|
In this case, the user called "username" sent the message "hello world.". You should not follow this convention in your responses.
|
||||||
|
your response instead could be "hello username!" without including any colons, because you are the only one sending your responses there is no need to prefix them.
|
||||||
|
"""
|
||||||
|
system_context.append(system_prompt)
|
||||||
|
|
||||||
|
# 添加额外的系统提示词和用户提示词
|
||||||
|
additional_context = json.loads(json.dumps(plugin.config['additional_prompt']))
|
||||||
|
if additional_context:
|
||||||
|
for item in additional_context:
|
||||||
|
system_context.append(item)
|
||||||
|
# 如果 消息长度已经超过了配置的消息条数,那么就抛出错误
|
||||||
|
if len(additional_context) > plugin.config['max_context_messages'] - 1:
|
||||||
|
raise ValueError(f"sorry, my configuration has too many additional prompts "
|
||||||
|
f"({plugin.config['max_context_messages']}) and i'll never see your message. "
|
||||||
|
f"Update my config to have fewer messages and i'll be able to answer your questions!")
|
||||||
|
|
||||||
|
# 用户历史聊天上下文
|
||||||
|
chat_context = deque()
|
||||||
|
# 计算系统提示词单词数
|
||||||
|
word_count = sum([len(m["content"].split()) for m in system_context])
|
||||||
|
message_count = len(system_context) - 1
|
||||||
|
async for next_event in generate_context_messages(plugin, evt):
|
||||||
|
# 如果不是文本类型,就跳过
|
||||||
|
try:
|
||||||
|
if not next_event.content.msgtype.is_text:
|
||||||
|
continue
|
||||||
|
except (KeyError, AttributeError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 如果当前的这条历史消息是机器人自己的,那么角色就要设置为assistant
|
||||||
|
role = 'assistant' if plugin.client.mxid == next_event.sender else 'user'
|
||||||
|
message = next_event['content']['body']
|
||||||
|
user = ''
|
||||||
|
# 如果是允许多用户使用,那么就需要在每个历史消息前加上用户名
|
||||||
|
if plugin.config['enable_multi_user']:
|
||||||
|
user = (await plugin.client.get_displayname(next_event.sender) or
|
||||||
|
plugin.client.parse_user_id(next_event.sender)[0]) + ": "
|
||||||
|
|
||||||
|
# 计算单词量和消息数
|
||||||
|
word_count += len(message.split())
|
||||||
|
message_count += 1
|
||||||
|
if word_count >= plugin.config['max_words'] or message_count >= plugin.config['max_context_messages']:
|
||||||
|
break
|
||||||
|
chat_context.appendleft({"role": role, "content": user + message})
|
||||||
|
|
||||||
|
return system_context + chat_context
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_context_messages(plugin: AiBotPlugin, evt: MessageEvent) -> Generator[MessageEvent, None, None]:
|
||||||
|
yield evt
|
||||||
|
if plugin.config['reply_in_thread']:
|
||||||
|
while evt.content.relates_to.in_reply_to:
|
||||||
|
evt = await plugin.client.get_event(room_id=evt.room_id, event_id=evt.content.get_reply_to())
|
||||||
|
yield evt
|
||||||
|
else:
|
||||||
|
event_context = await plugin.client.get_event_context(room_id=evt.room_id, event_id=evt.event_id,
|
||||||
|
limit=plugin.config["max_context_messages"] * 2)
|
||||||
|
previous_messages = iter(event_context.events_before)
|
||||||
|
for evt in previous_messages:
|
||||||
|
|
||||||
|
# We already have the event, but currently, get_event_context doesn't automatically decrypt events
|
||||||
|
if isinstance(evt, EncryptedEvent) and plugin.client.crypto:
|
||||||
|
evt = await plugin.client.get_event(event_id=evt.event_id, room_id=evt.room_id)
|
||||||
|
if not evt:
|
||||||
|
raise ValueError("Decryption error!")
|
||||||
|
|
||||||
|
yield evt
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
39
maubot_llmplus/llm/thrid_platform.py
Normal file
39
maubot_llmplus/llm/thrid_platform.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
from collections import deque
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from maubot import Plugin
|
||||||
|
from mautrix.types import MessageEvent
|
||||||
|
from mautrix.util.config import BaseProxyConfig
|
||||||
|
|
||||||
|
from maubot_llmplus import AiBotPlugin
|
||||||
|
from maubot_llmplus.llm.platforms import Platform, ChatCompletion
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAi(Platform):
|
||||||
|
|
||||||
|
def __init__(self, config: BaseProxyConfig) -> None:
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion:
|
||||||
|
# 获取系统提示词
|
||||||
|
# 获取额外的其他角色的提示词: role: user role: system
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_type(self) -> str:
|
||||||
|
return "openai"
|
||||||
|
|
||||||
|
|
||||||
|
class Anthropic(Platform):
|
||||||
|
|
||||||
|
def __init__(self, config: BaseProxyConfig) -> None:
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion:
|
||||||
|
# 获取系统提示词
|
||||||
|
# 获取额外的其他角色的提示词: role: user role: system
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_type(self) -> str:
|
||||||
|
return "anthropic"
|
||||||
Reference in New Issue
Block a user