Compare commits

...

2 Commits

Author SHA1 Message Date
taylor
dc5162b662 add: 添加命令ai model list的逻辑 2024-10-13 18:53:42 +08:00
taylor
7e1af58c84 add: 添加ollama调用AI chat逻辑 2024-10-13 18:31:18 +08:00
3 changed files with 40 additions and 11 deletions

View File

@@ -1,7 +1,7 @@
import re
from typing import Type
from maubot.handlers import event
from maubot.handlers import command, event
from maubot import Plugin, MessageEvent
from mautrix.types import Format, TextMessageEventContent, EventType, MessageType, RelationType
from mautrix.util import markdown
@@ -15,6 +15,7 @@ from maubot_llmplus.thrid_platform import OpenAi, Anthropic
配置文件加载
"""
class Config(BaseProxyConfig):
def do_update(self, helper: ConfigUpdateHelper) -> None:
helper.copy("allowed_users")
@@ -28,7 +29,6 @@ class Config(BaseProxyConfig):
class AiBotPlugin(Plugin):
name: str
async def start(self) -> None:
@@ -105,19 +105,18 @@ class AiBotPlugin(Plugin):
if parent_event.sender == self.client.mxid:
return True
@event.on(EventType.ROOM_MESSAGE)
async def on_message(self, event: MessageEvent) -> None:
if not await self.should_respond(event):
return
try:
self.log.debug("开始发送消息")
await event.mark_read()
await self.client.set_typing(event.room_id, timeout=99999)
platform = self.get_ai_platform()
chat_completion = await platform.create_chat_completion(self, event)
self.log.debug(f"发送结果 {chat_completion.message}, {chat_completion.model}, {chat_completion.finish_reason}")
self.log.debug(
f"发送结果 {chat_completion.message}, {chat_completion.model}, {chat_completion.finish_reason}")
# ai gpt调用
# 关闭typing提示
await self.client.set_typing(event.room_id, timeout=0)
@@ -126,7 +125,6 @@ class AiBotPlugin(Plugin):
response = TextMessageEventContent(msgtype=MessageType.TEXT, body=resp_content, format=Format.HTML,
formatted_body=markdown.render(resp_content))
await event.respond(response, in_thread=self.config['reply_in_thread'])
self.log.debug("发送结束")
except Exception as e:
self.log.exception(f"Something went wrong: {e}")
await event.respond(f"Something went wrong: {e}")
@@ -150,6 +148,30 @@ class AiBotPlugin(Plugin):
return Anthropic(self.config, self.name, self.http)
raise ValueError(f"unknown backend type {use_platform}")
"""
父命令
"""
@command.new(name="ai", require_subcommand=True)
async def ai_command(self, event: MessageEvent) -> None:
pass
"""
"""
@ai_command.subcommand(help="")
async def info(self, event: MessageEvent) -> None:
pass
@ai_command.subcommand(help="")
@command.argument("argus")
async def model(self, event: MessageEvent, argus: str):
# 如果是list表示查看当前可以使用的模型列表
if argus == 'list':
platform = self.get_ai_platform()
models = platform.list_models()
await event.reply("\n".join(models))
# 如果不是,如果是其他的名称,表示这是一个模型名
@classmethod
def get_config_class(cls) -> Type[BaseProxyConfig]:
return Config

View File

@@ -1,4 +1,5 @@
import json
from typing import List
from aiohttp import ClientSession
from maubot import Plugin
@@ -25,9 +26,6 @@ class Ollama(Platform):
endpoint = f"{self.url}/api/chat"
req_body = {'model': self.model, 'messages': full_context, 'stream': False}
headers = {'Content-Type': 'application/json'}
if self.api_key is not None:
headers['Authorization'] = self.api_key
plugin.log.debug(f"{json.dumps(req_body)}")
async with self.http.post(endpoint, headers=headers, json=req_body) as response:
# plugin.log.debug(f"响应内容:{response.status}, {await response.json()}")
if response.status != 200:
@@ -36,8 +34,6 @@ class Ollama(Platform):
finish_reason=f"http status {response.status}",
model=None
)
text = await response.text()
plugin.log.debug(f"解析后的响应内容: {text}")
response_json = await response.json()
return ChatCompletion(
message=response_json['message'],
@@ -45,6 +41,14 @@ class Ollama(Platform):
model=response_json['model']
)
async def list_models(self) -> List[str]:
full_url = f"{self.url}/api/tags"
async with self.http.get(full_url) as response:
if response.status != 200:
return []
response_data = json.loads(await response.json())
return [model['name'] for model in response_data]
def get_type(self) -> str:
return "local_ai"

View File

@@ -54,6 +54,9 @@ class Platform:
async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion:
raise NotImplementedError()
async def list_models(self) -> List[str]:
raise NotImplementedError()
def get_type(self) -> str:
raise NotImplementedError()