Compare commits
2 Commits
78b44a08fc
...
dc5162b662
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc5162b662 | ||
|
|
7e1af58c84 |
@@ -1,7 +1,7 @@
|
||||
import re
|
||||
|
||||
from typing import Type
|
||||
from maubot.handlers import event
|
||||
from maubot.handlers import command, event
|
||||
from maubot import Plugin, MessageEvent
|
||||
from mautrix.types import Format, TextMessageEventContent, EventType, MessageType, RelationType
|
||||
from mautrix.util import markdown
|
||||
@@ -15,6 +15,7 @@ from maubot_llmplus.thrid_platform import OpenAi, Anthropic
|
||||
配置文件加载
|
||||
"""
|
||||
|
||||
|
||||
class Config(BaseProxyConfig):
|
||||
def do_update(self, helper: ConfigUpdateHelper) -> None:
|
||||
helper.copy("allowed_users")
|
||||
@@ -28,7 +29,6 @@ class Config(BaseProxyConfig):
|
||||
|
||||
|
||||
class AiBotPlugin(Plugin):
|
||||
|
||||
name: str
|
||||
|
||||
async def start(self) -> None:
|
||||
@@ -105,19 +105,18 @@ class AiBotPlugin(Plugin):
|
||||
if parent_event.sender == self.client.mxid:
|
||||
return True
|
||||
|
||||
|
||||
@event.on(EventType.ROOM_MESSAGE)
|
||||
async def on_message(self, event: MessageEvent) -> None:
|
||||
if not await self.should_respond(event):
|
||||
return
|
||||
|
||||
try:
|
||||
self.log.debug("开始发送消息")
|
||||
await event.mark_read()
|
||||
await self.client.set_typing(event.room_id, timeout=99999)
|
||||
platform = self.get_ai_platform()
|
||||
chat_completion = await platform.create_chat_completion(self, event)
|
||||
self.log.debug(f"发送结果 {chat_completion.message}, {chat_completion.model}, {chat_completion.finish_reason}")
|
||||
self.log.debug(
|
||||
f"发送结果 {chat_completion.message}, {chat_completion.model}, {chat_completion.finish_reason}")
|
||||
# ai gpt调用
|
||||
# 关闭typing提示
|
||||
await self.client.set_typing(event.room_id, timeout=0)
|
||||
@@ -126,7 +125,6 @@ class AiBotPlugin(Plugin):
|
||||
response = TextMessageEventContent(msgtype=MessageType.TEXT, body=resp_content, format=Format.HTML,
|
||||
formatted_body=markdown.render(resp_content))
|
||||
await event.respond(response, in_thread=self.config['reply_in_thread'])
|
||||
self.log.debug("发送结束")
|
||||
except Exception as e:
|
||||
self.log.exception(f"Something went wrong: {e}")
|
||||
await event.respond(f"Something went wrong: {e}")
|
||||
@@ -150,6 +148,30 @@ class AiBotPlugin(Plugin):
|
||||
return Anthropic(self.config, self.name, self.http)
|
||||
raise ValueError(f"unknown backend type {use_platform}")
|
||||
|
||||
"""
|
||||
父命令
|
||||
"""
|
||||
@command.new(name="ai", require_subcommand=True)
|
||||
async def ai_command(self, event: MessageEvent) -> None:
|
||||
pass
|
||||
|
||||
"""
|
||||
"""
|
||||
@ai_command.subcommand(help="")
|
||||
async def info(self, event: MessageEvent) -> None:
|
||||
pass
|
||||
|
||||
@ai_command.subcommand(help="")
|
||||
@command.argument("argus")
|
||||
async def model(self, event: MessageEvent, argus: str):
|
||||
# 如果是list表示查看当前可以使用的模型列表
|
||||
if argus == 'list':
|
||||
platform = self.get_ai_platform()
|
||||
models = platform.list_models()
|
||||
await event.reply("\n".join(models))
|
||||
|
||||
# 如果不是,如果是其他的名称,表示这是一个模型名
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> Type[BaseProxyConfig]:
|
||||
return Config
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from maubot import Plugin
|
||||
@@ -25,9 +26,6 @@ class Ollama(Platform):
|
||||
endpoint = f"{self.url}/api/chat"
|
||||
req_body = {'model': self.model, 'messages': full_context, 'stream': False}
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
if self.api_key is not None:
|
||||
headers['Authorization'] = self.api_key
|
||||
plugin.log.debug(f"{json.dumps(req_body)}")
|
||||
async with self.http.post(endpoint, headers=headers, json=req_body) as response:
|
||||
# plugin.log.debug(f"响应内容:{response.status}, {await response.json()}")
|
||||
if response.status != 200:
|
||||
@@ -36,8 +34,6 @@ class Ollama(Platform):
|
||||
finish_reason=f"http status {response.status}",
|
||||
model=None
|
||||
)
|
||||
text = await response.text()
|
||||
plugin.log.debug(f"解析后的响应内容: {text}")
|
||||
response_json = await response.json()
|
||||
return ChatCompletion(
|
||||
message=response_json['message'],
|
||||
@@ -45,6 +41,14 @@ class Ollama(Platform):
|
||||
model=response_json['model']
|
||||
)
|
||||
|
||||
async def list_models(self) -> List[str]:
|
||||
full_url = f"{self.url}/api/tags"
|
||||
async with self.http.get(full_url) as response:
|
||||
if response.status != 200:
|
||||
return []
|
||||
response_data = json.loads(await response.json())
|
||||
return [model['name'] for model in response_data]
|
||||
|
||||
def get_type(self) -> str:
|
||||
return "local_ai"
|
||||
|
||||
|
||||
@@ -54,6 +54,9 @@ class Platform:
|
||||
async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def list_models(self) -> List[str]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_type(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user