152 lines
6.1 KiB
Python
152 lines
6.1 KiB
Python
import json
from collections import deque
from datetime import datetime
from typing import Optional, List, Generator, AsyncGenerator

from aiohttp import ClientSession
from maubot import Plugin
from mautrix.types import MessageEvent, EncryptedEvent
from mautrix.util.config import BaseProxyConfig

from maubot_llmplus.plugin import AbsExtraConfigPlugin
|
||
|
||
"""
|
||
AI响应对象
|
||
"""
|
||
|
||
|
||
class ChatCompletion:
|
||
def __init__(self, message: dict, finish_reason: str, model: Optional[str]) -> None:
|
||
self.message = message
|
||
self.finish_reason = finish_reason
|
||
self.model = model
|
||
|
||
def __eq__(self, other) -> bool:
|
||
return self.message == other.message and self.model == other.model
|
||
|
||
|
||
class Platform:
    """Base class for an AI chat platform backend.

    The constructor copies the platform's own section of the plugin config
    (``config['platforms'][<type>]``) plus a few top-level settings onto the
    instance. Subclasses must provide ``get_type``, ``create_chat_completion``
    and ``list_models``; the base versions raise ``NotImplementedError``.
    """

    http: ClientSession
    config: dict
    url: str
    api_key: str
    model: str
    max_words: int
    additional_prompt: List[dict]
    system_prompt: str
    max_context_messages: int

    def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
        self.http = http
        # This platform's own settings block, selected by the subclass type key.
        section = config['platforms'][self.get_type()]
        self.config = section
        self.url = section['url']
        # NOTE(review): relies on a private attribute of the config object,
        # presumably the currently selected model -- confirm where it is set.
        self.model = config._cur_model
        self.max_words = section['max_words']
        self.api_key = section['api_key']
        self.max_context_messages = section['max_context_messages']
        # These two come from the top level of the config, not the platform section.
        self.additional_prompt = config['additional_prompt']
        self.system_prompt = config['system_prompt']

    async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
        """Call the platform's AI chat endpoint for ``evt`` and return its response.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    async def list_models(self) -> List[str]:
        """Return model names for this platform (presumably those it offers).

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def get_type(self) -> str:
        """Return the key of this platform's section under ``config['platforms']``.

        Must be overridden by subclasses.
        """
        raise NotImplementedError
|
||
|
||
|
||
|
||
async def get_context(plugin: AbsExtraConfigPlugin, platform: Platform, evt: MessageEvent) -> deque:
    """Build the list of chat messages sent to the AI platform when answering ``evt``.

    The result is the system context followed by chat history, oldest first:

    * the system prompt (``plugin.config['system_prompt']`` formatted with the
      bot name and current timestamp, plus multi-user instructions when
      ``enable_multi_user`` is set), included only when non-empty;
    * the configured additional prompts;
    * previous text messages from ``generate_context_messages``, collected
      until adding one more would reach ``platform.max_words`` words or make
      the total message count reach ``platform.max_context_messages``.

    Args:
        plugin: the running plugin; supplies config, the Matrix client and bot name.
        platform: the platform whose ``max_words`` / ``max_context_messages`` apply.
        evt: the message event being answered.

    Returns:
        A deque of ``{"role": ..., "content": ...}`` dicts.

    Raises:
        ValueError: when additional prompts are configured and the system
            context alone leaves no room for even one chat message.
    """
    system_context = deque()
    timestamp = datetime.today().strftime('%Y-%m-%d %H:%M:%S')

    # System prompt with its {name} and {timestamp} placeholders filled in.
    system_prompt = {"role": "system",
                     "content": plugin.config['system_prompt'].format(name=plugin.get_bot_name(), timestamp=timestamp)}
    if plugin.config['enable_multi_user']:
        system_prompt["content"] += """
User messages are in the context of multiperson chatrooms.
Each message indicates its sender by prefixing the message with the sender's name followed by a colon, for example:
"username: hello world."
In this case, the user called "username" sent the message "hello world.". You should not follow this convention in your responses.
your response instead could be "hello username!" without including any colons, because you are the only one sending your responses there is no need to prefix them.
"""
    if len(system_prompt["content"]) > 0:
        system_context.append(system_prompt)

    # Additional system/user prompts. The JSON round-trip deep-copies them so the
    # returned dicts are independent of the objects held by the config.
    additional_context = json.loads(json.dumps(plugin.config['additional_prompt']))
    if additional_context:
        system_context.extend(additional_context)
        # The history loop below keeps a message only while
        # len(system_context) + kept_messages <= max_context_messages, so once the
        # system context alone reaches the limit the user's own message can never fit.
        if len(system_context) >= platform.max_context_messages:
            raise ValueError(f"sorry, my configuration has too many additional prompts "
                             f"({platform.max_context_messages}) and i'll never see your message. "
                             f"Update my config to have fewer messages and i'll be able to answer your questions!")

    # Chat history; events arrive newest first, so each one is prepended.
    chat_context = deque()
    word_count = sum(len(m["content"].split()) for m in system_context)
    # Starting at len - 1 makes `message_count >= max_context_messages` fire exactly
    # when the total number of context messages would exceed the limit.
    message_count = len(system_context) - 1
    async for next_event in generate_context_messages(plugin, platform, evt):
        # Skip anything that is not a text message.
        try:
            if not next_event.content.msgtype.is_text:
                continue
        except (KeyError, AttributeError):
            continue

        # The bot's own earlier messages are sent back with the assistant role.
        role = 'assistant' if plugin.client.mxid == next_event.sender else 'user'
        message = next_event['content']['body']
        user = ''
        # In multi-user mode prefix each message with the sender's display name,
        # falling back to the first part of the parsed user ID.
        if plugin.config['enable_multi_user']:
            user = (await plugin.client.get_displayname(next_event.sender) or
                    plugin.client.parse_user_id(next_event.sender)[0]) + ": "

        word_count += len(message.split())
        message_count += 1
        if word_count >= platform.max_words or message_count >= platform.max_context_messages:
            break

        chat_context.appendleft({"role": role, "content": user + message})

    return system_context + chat_context
|
||
|
||
|
||
|
||
|
||
async def generate_context_messages(plugin: Plugin, platform: Platform,
                                    evt: MessageEvent) -> AsyncGenerator[MessageEvent, None]:
    """Asynchronously yield ``evt`` followed by earlier events that give it context.

    * With ``reply_in_thread`` enabled, the reply chain is followed: each step
      fetches the event the current one replies to, until an event is not a reply.
    * Otherwise the room's event context around ``evt`` is requested
      (``limit = 2 * platform.max_context_messages``) and its ``events_before``
      are yielded in the order returned. The Matrix /context endpoint documents
      that order as reverse-chronological.

    Events are fetched lazily, so a consumer that stops iterating stops the requests.

    Raises:
        ValueError: if re-fetching an encrypted event (for decryption) returns nothing.
    """
    yield evt

    if plugin.config['reply_in_thread']:
        current = evt
        while current.content.relates_to.in_reply_to:
            current = await plugin.client.get_event(room_id=current.room_id,
                                                    event_id=current.content.get_reply_to())
            yield current
    else:
        event_context = await plugin.client.get_event_context(room_id=evt.room_id, event_id=evt.event_id,
                                                              limit=platform.max_context_messages * 2)
        # Lazy %-formatting: the context repr is only built when debug logging is on.
        plugin.log.debug("event_context: %s", event_context)

        for previous in event_context.events_before:
            # We already have the event, but get_event_context doesn't decrypt
            # events, so encrypted ones are fetched again via get_event.
            if isinstance(previous, EncryptedEvent) and plugin.client.crypto:
                previous = await plugin.client.get_event(event_id=previous.event_id, room_id=previous.room_id)
                if not previous:
                    raise ValueError("Decryption error!")

            yield previous
|
||
|
||
|
||
|