add: 添加ollama调用AI chat逻辑

This commit is contained in:
taylor
2024-10-13 15:37:12 +08:00
parent 78acb679ee
commit e5d4e52bc0
8 changed files with 275 additions and 13 deletions

View File

@@ -4,7 +4,7 @@
<content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/venv" />
</content>
<orderEntry type="jdk" jdkName="maubot-llm-conda" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="Python 3.11 (maubot-llmplus)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

2
.idea/misc.xml generated
View File

@@ -1,4 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="maubot-llm-conda" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11 (maubot-llmplus)" project-jdk-type="Python SDK" />
</project>

View File

@@ -11,7 +11,7 @@ enable_multi_user:
system_prompt:
platforms:
local:
local_ai:
type: ollama
url: http://localhost:11434
api_key:
@@ -30,3 +30,9 @@ platforms:
max_tokens:
model:
max_words:
additional_prompt:
- role: user
content: xxx
- role: system
content: xxx

View File

@@ -11,10 +11,15 @@ from maubot.handlers import command, event
from maubot import Plugin, MessageEvent
from mautrix.errors import MNotFound, MatrixRequestError
from mautrix.types import Format, TextMessageEventContent, EventType, RoomID, UserID, MessageType, RelationType, \
EncryptedEvent
EncryptedEvent, MediaMessageEventContent, ImageInfo, EncryptedFile
from mautrix.util import markdown
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
from maubot_llmplus.llm import platforms
from maubot_llmplus.llm.local_paltform import Ollama, LmStudio
from maubot_llmplus.llm.platforms import Platform
from maubot_llmplus.llm.thrid_platform import OpenAi, Anthropic
"""
配置文件加载
"""
@@ -29,6 +34,7 @@ class Config(BaseProxyConfig):
helper.copy("enable_multi_user")
helper.copy("system_prompt")
helper.copy("platforms")
helper.copy("additional_prompt")
class AiBotPlugin(Plugin):
@@ -109,11 +115,6 @@ class AiBotPlugin(Plugin):
if parent_event.sender == self.client.mxid:
return True
async def get_context(self, event: MessageEvent):
return None
async def _ai_call(self, prompt):
return None
@event.on(EventType.ROOM_MESSAGE)
async def on_message(self, event: MessageEvent) -> None:
@@ -123,20 +124,39 @@ class AiBotPlugin(Plugin):
try:
await event.mark_read()
await self.client.set_typing(event.room_id, timeout=99999)
resp_content = "response test"
platform = self.get_platform()
chat_completion = platform.create_chat_completion(event)
# ai gpt调用
# 关闭typing提示
await self.client.set_typing(event.room_id, timeout=0)
# 打开typing提示
response = TextMessageEventContent(msgtype=MessageType.NOTICE, body=resp_content, format=Format.HTML,
resp_content = chat_completion.message['content']
response = TextMessageEventContent(msgtype=MessageType.IMAGE, body=resp_content, format=Format.HTML,
formatted_body=markdown.render(resp_content))
await event.respond(response, in_thread=self.config['reply_in_thread'])
except Exception as e:
self.log.exception(f"Something went wrong: {e}")
await event.respond(f"Something went wrong: {e}")
pass
return None;
return None
async def get_platform(self) -> Platform:
    """Resolve the configured backend into a concrete Platform instance.

    Reads ``use_platform`` from the plugin config; the ``local_ai`` backend
    is further specialised by its ``type`` sub-key.

    Raises:
        ValueError: if the backend (or the local_ai type) is unknown.

    NOTE(review): this method is declared ``async`` but performs no awaits,
    and the visible caller invokes it without ``await`` — confirm whether it
    should simply be synchronous.
    """
    use_platform = self.config['use_platform']
    if use_platform == 'local_ai':
        # Renamed from `type` to avoid shadowing the builtin.
        local_type = self.config['platforms']['local_ai']['type']
        if local_type == 'ollama':
            return Ollama(self.config)
        if local_type == 'lmstudio':
            return LmStudio(self.config)
        raise ValueError(f"not found platform type: {local_type}")
    if use_platform == 'openai':
        return OpenAi(self.config)
    if use_platform == 'anthropic':
        return Anthropic(self.config)
    raise ValueError(f"unknown backend type {use_platform}")
@classmethod
def get_config_class(cls) -> Type[BaseProxyConfig]:

View File

View File

@@ -0,0 +1,51 @@
import json
import platform
from collections import deque
from typing import List
from maubot import Plugin
from mautrix.types import MessageEvent
from mautrix.util.config import BaseProxyConfig
from maubot_llmplus import AiBotPlugin
from maubot_llmplus.llm import platforms
from maubot_llmplus.llm.platforms import Platform, ChatCompletion
class Ollama(Platform):
    """Chat backend that talks to an Ollama server's /api/chat endpoint."""

    # Path of Ollama's chat endpoint, relative to the configured base URL.
    chat_api: str

    def __init__(self, config: BaseProxyConfig) -> None:
        # NOTE(review): Platform.__init__ is declared as (config, http);
        # this call passes only config — confirm the intended signature.
        super().__init__(config)
        self.chat_api = '/api/chat'

    async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion:
        """POST the conversation context to Ollama and wrap its reply.

        Returns a ChatCompletion with the assistant message on success, or
        an empty message and an "http status N" finish_reason on failure.
        """
        full_context = []
        # get_context is a coroutine, so it must be awaited before its
        # result can be iterated.
        # TODO(review): get_context is defined as (plugin, evt); this call
        # is missing the plugin argument — confirm against the definition.
        context = await platforms.get_context(evt)
        full_context.extend(list(context))
        # Reuse the configured endpoint path instead of re-hardcoding it.
        endpoint = f"{self.url}{self.chat_api}"
        # Ollama's /api/chat expects the keys "messages" and "stream";
        # the previous "message"/"steam" spellings are not part of the API.
        req_body = {'model': self.model, 'messages': full_context, 'stream': False}
        headers = {}
        if self.api_key is not None:
            # Standard bearer-token form for an authenticating proxy.
            headers['Authorization'] = f"Bearer {self.api_key}"
        async with self.http.post(endpoint, headers=headers, data=json.dumps(req_body)) as response:
            if response.status != 200:
                return ChatCompletion(
                    message={},
                    finish_reason=f"http status {response.status}",
                    model=None
                )
            response_json = await response.json()
            return ChatCompletion(
                message=response_json['message'],
                finish_reason='success',
                model=response_json.get('model', None)
            )

    def get_type(self) -> str:
        # Ollama is configured under the "local_ai" platforms section.
        return "local_ai"
class LmStudio(Platform):
    # TODO: LM Studio backend is not implemented yet; an instance currently
    # inherits Platform's methods, which raise NotImplementedError.
    pass

View File

@@ -0,0 +1,146 @@
import json
from collections import deque
from datetime import datetime
from typing import Optional, List, Generator
from aiohttp import ClientSession
from maubot import Plugin
from mautrix.types import MessageEvent, EncryptedEvent
from mautrix.util.config import BaseProxyConfig
from maubot_llmplus import AiBotPlugin
"""
AI响应对象
"""
class ChatCompletion:
def __init__(self, message: dict, finish_reason: str, model: Optional[str]) -> None:
self.message = message
self.finish_reason = finish_reason
self.model = model
def __eq__(self, other) -> bool:
return self.message == other.message and self.model == other.model
class Platform:
    """Base class for an LLM chat backend (Ollama, LM Studio, OpenAI, ...).

    Concrete subclasses implement create_chat_completion() and get_type();
    get_type() must return the key of the platform's section under the
    "platforms" mapping of the plugin config.
    """

    http: ClientSession            # shared aiohttp session for API calls
    config: dict                   # this platform's own config section
    url: str                       # base URL of the backend API
    api_key: str                   # API key; may be None for local backends
    model: str                     # model name to request
    max_words: int                 # word budget for the prompt context
    additional_prompt: List[dict]  # extra {"role", "content"} messages
    system_prompt: str             # system prompt template
    max_context_messages: int      # message-count budget for the context

    def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
        self.http = http
        # Keep only the section of the config that belongs to this platform.
        self.config = config['platforms'][self.get_type()]
        self.url = self.config['url']
        self.model = self.config['model']
        self.max_words = self.config['max_words']
        self.api_key = self.config['api_key']
        self.max_context_messages = self.config['max_context_messages']
        # These two are top-level settings shared by all platforms.
        self.additional_prompt = config['additional_prompt']
        self.system_prompt = config['system_prompt']

    async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion:
        """Call the backend's chat API for *evt* and return the completion.

        (Replaces a stray no-op class-level string that previously floated
        between methods and carried this description.)
        """
        raise NotImplementedError()

    def get_type(self) -> str:
        """Return this platform's key in the "platforms" config mapping."""
        raise NotImplementedError()
async def get_context(plugin: AiBotPlugin, evt: MessageEvent) -> deque:
    """Assemble the chat context (system + history messages) for *evt*.

    Returns a deque of {"role": ..., "content": ...} dicts: the formatted
    system prompt, any configured additional prompts, then as much room
    history as fits within the configured word/message budgets.

    Raises:
        ValueError: if the additional prompts alone exceed the message budget.
    """
    # Context holding the system prompt(s).
    system_context = deque()
    # Current timestamp, substituted into the system prompt template.
    timestamp = datetime.today().strftime('%Y-%m-%d %H:%M:%S')
    # Base system prompt, rendered from the config template.
    system_prompt = {"role": "system",
                     "content": plugin.config['system_prompt'].format(name=plugin.name, timestamp=timestamp)}
    if plugin.config['enable_multi_user']:
        system_prompt["content"] += """
User messages are in the context of multiperson chatrooms.
Each message indicates its sender by prefixing the message with the sender's name followed by a colon, for example:
"username: hello world."
In this case, the user called "username" sent the message "hello world.". You should not follow this convention in your responses.
your response instead could be "hello username!" without including any colons, because you are the only one sending your responses there is no need to prefix them.
"""
    system_context.append(system_prompt)
    # Append the extra configured system/user prompts (deep-copied via a
    # JSON round-trip so the config object itself is never mutated).
    additional_context = json.loads(json.dumps(plugin.config['additional_prompt']))
    if additional_context:
        for item in additional_context:
            system_context.append(item)
    # Fail loudly if the additional prompts already use up the budget and the
    # user's own message could never be included.
    if len(additional_context) > plugin.config['max_context_messages'] - 1:
        raise ValueError(f"sorry, my configuration has too many additional prompts "
                         f"({plugin.config['max_context_messages']}) and i'll never see your message. "
                         f"Update my config to have fewer messages and i'll be able to answer your questions!")
    # Rolling chat history (newest events arrive first, hence appendleft).
    chat_context = deque()
    # Word budget already consumed by the system context.
    word_count = sum([len(m["content"].split()) for m in system_context])
    message_count = len(system_context) - 1
    async for next_event in generate_context_messages(plugin, evt):
        # Skip anything that is not a plain text message.
        try:
            if not next_event.content.msgtype.is_text:
                continue
        except (KeyError, AttributeError):
            continue
        # History sent by the bot itself is replayed as the assistant role.
        role = 'assistant' if plugin.client.mxid == next_event.sender else 'user'
        message = next_event['content']['body']
        user = ''
        # In multi-user rooms, prefix each message with the sender's name.
        if plugin.config['enable_multi_user']:
            user = (await plugin.client.get_displayname(next_event.sender) or
                    plugin.client.parse_user_id(next_event.sender)[0]) + ": "
        # Stop once the word or message budget is exhausted.
        # NOTE(review): the budget check runs before the append, so the event
        # that crosses the limit is dropped entirely — confirm intended.
        word_count += len(message.split())
        message_count += 1
        if word_count >= plugin.config['max_words'] or message_count >= plugin.config['max_context_messages']:
            break
        chat_context.appendleft({"role": role, "content": user + message})
    return system_context + chat_context
async def generate_context_messages(plugin: AiBotPlugin, evt: MessageEvent) -> Generator[MessageEvent, None, None]:
    """Yield *evt* first, then earlier room events, newest first.

    In reply-thread mode this walks up the reply chain; otherwise it pages
    back through room history via get_event_context.
    """
    yield evt
    if plugin.config['reply_in_thread']:
        # Follow the reply chain upwards until there is no parent.
        # NOTE(review): assumes evt.content.relates_to.in_reply_to is always
        # present; a non-reply event would raise AttributeError — confirm.
        while evt.content.relates_to.in_reply_to:
            evt = await plugin.client.get_event(room_id=evt.room_id, event_id=evt.content.get_reply_to())
            yield evt
    else:
        # Fetch twice the message budget so non-text events filtered out by
        # the caller still leave enough usable context.
        event_context = await plugin.client.get_event_context(room_id=evt.room_id, event_id=evt.event_id,
                                                              limit=plugin.config["max_context_messages"] * 2)
        previous_messages = iter(event_context.events_before)
        for evt in previous_messages:
            # We already have the event, but currently, get_event_context doesn't automatically decrypt events
            if isinstance(evt, EncryptedEvent) and plugin.client.crypto:
                evt = await plugin.client.get_event(event_id=evt.event_id, room_id=evt.room_id)
                if not evt:
                    raise ValueError("Decryption error!")
            yield evt

View File

@@ -0,0 +1,39 @@
from collections import deque
from typing import List
from maubot import Plugin
from mautrix.types import MessageEvent
from mautrix.util.config import BaseProxyConfig
from maubot_llmplus import AiBotPlugin
from maubot_llmplus.llm.platforms import Platform, ChatCompletion
class OpenAi(Platform):
    """OpenAI chat backend (stub — the API call is not implemented yet)."""

    def __init__(self, config: BaseProxyConfig) -> None:
        # NOTE(review): Platform.__init__ is declared as (config, http);
        # this call passes only config — confirm the intended signature.
        super().__init__(config)

    async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion:
        # TODO: build the request from the system prompt plus the configured
        # additional user/system prompts and call the OpenAI chat API.
        # Raise (matching the Platform base contract) instead of silently
        # returning None, which violates the declared return type and makes
        # callers fail obscurely on .message access.
        raise NotImplementedError("OpenAI chat completion is not implemented yet")

    def get_type(self) -> str:
        # Key of this platform's section under "platforms" in the config.
        return "openai"
class Anthropic(Platform):
    """Anthropic chat backend (stub — the API call is not implemented yet)."""

    def __init__(self, config: BaseProxyConfig) -> None:
        # NOTE(review): Platform.__init__ is declared as (config, http);
        # this call passes only config — confirm the intended signature.
        super().__init__(config)

    async def create_chat_completion(self, evt: MessageEvent) -> ChatCompletion:
        # TODO: build the request from the system prompt plus the configured
        # additional user/system prompts and call the Anthropic messages API.
        # Raise (matching the Platform base contract) instead of silently
        # returning None, which violates the declared return type.
        raise NotImplementedError("Anthropic chat completion is not implemented yet")

    def get_type(self) -> str:
        # Key of this platform's section under "platforms" in the config.
        return "anthropic"