Compare commits


2 Commits

Author  SHA1        Message                                         Date
taylor  dc5162b662  add: add logic for the `ai model list` command  2024-10-13 18:53:42 +08:00
taylor  7e1af58c84  add: add Ollama AI chat call logic              2024-10-13 18:31:18 +08:00
3 changed files with 40 additions and 11 deletions

View File

@@ -1,7 +1,7 @@
 import re
 from typing import Type
-from maubot.handlers import event
+from maubot.handlers import command, event
 from maubot import Plugin, MessageEvent
 from mautrix.types import Format, TextMessageEventContent, EventType, MessageType, RelationType
 from mautrix.util import markdown
@@ -15,6 +15,7 @@ from maubot_llmplus.thrid_platform import OpenAi, Anthropic
 Configuration file loading
 """
 class Config(BaseProxyConfig):
     def do_update(self, helper: ConfigUpdateHelper) -> None:
         helper.copy("allowed_users")
@@ -28,7 +29,6 @@ class Config(BaseProxyConfig):
 class AiBotPlugin(Plugin):
     name: str

     async def start(self) -> None:
@@ -105,19 +105,18 @@ class AiBotPlugin(Plugin):
         if parent_event.sender == self.client.mxid:
             return True

     @event.on(EventType.ROOM_MESSAGE)
     async def on_message(self, event: MessageEvent) -> None:
         if not await self.should_respond(event):
             return

         try:
-            self.log.debug("start sending the message")
             await event.mark_read()
             await self.client.set_typing(event.room_id, timeout=99999)
             platform = self.get_ai_platform()
             chat_completion = await platform.create_chat_completion(self, event)
-            self.log.debug(f"send result {chat_completion.message}, {chat_completion.model}, {chat_completion.finish_reason}")
+            self.log.debug(
+                f"send result {chat_completion.message}, {chat_completion.model}, {chat_completion.finish_reason}")
             # AI GPT call
             # turn off the typing indicator
             await self.client.set_typing(event.room_id, timeout=0)
@@ -126,7 +125,6 @@ class AiBotPlugin(Plugin):
             response = TextMessageEventContent(msgtype=MessageType.TEXT, body=resp_content, format=Format.HTML,
                                                formatted_body=markdown.render(resp_content))
             await event.respond(response, in_thread=self.config['reply_in_thread'])
-            self.log.debug("sending finished")
         except Exception as e:
             self.log.exception(f"Something went wrong: {e}")
             await event.respond(f"Something went wrong: {e}")
@@ -150,6 +148,30 @@ class AiBotPlugin(Plugin):
                 return Anthropic(self.config, self.name, self.http)
         raise ValueError(f"unknown backend type {use_platform}")

+    """
+    Parent command
+    """
+    @command.new(name="ai", require_subcommand=True)
+    async def ai_command(self, event: MessageEvent) -> None:
+        pass
+
+    """
+    """
+    @ai_command.subcommand(help="")
+    async def info(self, event: MessageEvent) -> None:
+        pass
+
+    @ai_command.subcommand(help="")
+    @command.argument("argus")
+    async def model(self, event: MessageEvent, argus: str):
+        # 'list' shows the models that are currently available
+        if argus == 'list':
+            platform = self.get_ai_platform()
+            # list_models is a coroutine and must be awaited
+            models = await platform.list_models()
+            await event.reply("\n".join(models))
+        # any other value is treated as a model name (branch not implemented yet)
+
     @classmethod
     def get_config_class(cls) -> Type[BaseProxyConfig]:
         return Config
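
The non-"list" branch of the new `model` subcommand is only stubbed out with a comment. A minimal sketch of how that branch might validate and remember a model name, assuming a hypothetical `self.current_model` attribute that the rest of the plugin would have to read (it does not exist in the original code):

    @ai_command.subcommand(help="List available models or switch the active one")
    @command.argument("argus")
    async def model(self, event: MessageEvent, argus: str) -> None:
        platform = self.get_ai_platform()
        models = await platform.list_models()
        if argus == 'list':
            # Show every model the configured backend currently offers
            await event.reply("\n".join(models))
        elif argus in models:
            # Hypothetical: remember the selection for later completions
            self.current_model = argus
            await event.reply(f"Switched model to {argus}")
        else:
            await event.reply(f"Unknown model: {argus}")

With maubot's default prefix, this would be invoked in a room as `!ai model list`.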

View File

@@ -1,4 +1,5 @@
 import json
+from typing import List

 from aiohttp import ClientSession
 from maubot import Plugin
@@ -25,9 +26,6 @@ class Ollama(Platform):
         endpoint = f"{self.url}/api/chat"
         req_body = {'model': self.model, 'messages': full_context, 'stream': False}
         headers = {'Content-Type': 'application/json'}
-        if self.api_key is not None:
-            headers['Authorization'] = self.api_key
-        plugin.log.debug(f"{json.dumps(req_body)}")
         async with self.http.post(endpoint, headers=headers, json=req_body) as response:
             # plugin.log.debug(f"response: {response.status}, {await response.json()}")
             if response.status != 200:
@@ -36,8 +34,6 @@ class Ollama(Platform):
                     finish_reason=f"http status {response.status}",
                     model=None
                 )
-            text = await response.text()
-            plugin.log.debug(f"parsed response body: {text}")
             response_json = await response.json()
             return ChatCompletion(
                 message=response_json['message'],
@@ -45,6 +41,14 @@ class Ollama(Platform):
                 model=response_json['model']
             )

+    async def list_models(self) -> List[str]:
+        full_url = f"{self.url}/api/tags"
+        async with self.http.get(full_url) as response:
+            if response.status != 200:
+                return []
+            # response.json() already deserializes the body; the model list
+            # sits under the 'models' key of the /api/tags payload
+            response_data = await response.json()
+            return [model['name'] for model in response_data['models']]
+
     def get_type(self) -> str:
         return "local_ai"

View File

@@ -54,6 +54,9 @@ class Platform:
     async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion:
         raise NotImplementedError()

+    async def list_models(self) -> List[str]:
+        raise NotImplementedError()
+
     def get_type(self) -> str:
         raise NotImplementedError()
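
The diff only shows the Ollama implementation of the new `list_models` hook. A sketch of how an OpenAI-style backend could satisfy the same interface, assuming it exposes `self.url`, `self.api_key`, and `self.http` the way the Ollama class does:

    async def list_models(self) -> List[str]:
        # OpenAI-compatible servers expose the model catalogue at GET /v1/models
        headers = {'Authorization': f'Bearer {self.api_key}'}
        async with self.http.get(f"{self.url}/v1/models", headers=headers) as response:
            if response.status != 200:
                return []
            response_data = await response.json()
            # Entries live under the 'data' key, each identified by 'id'
            return [model['id'] for model in response_data['data']]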