Compare commits
89 Commits
dc5162b662
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
249f225045 | ||
|
|
448a95134f | ||
|
|
6b2fc9ea07 | ||
|
|
96373a9c14 | ||
|
|
70ea0a6916 | ||
|
|
98a4dba820 | ||
|
|
a5e43190f4 | ||
|
|
d5d634bf14 | ||
|
|
bf4d2a444c | ||
|
|
89160ce482 | ||
|
|
11e37a157d | ||
|
|
caddfb61f1 | ||
|
|
1070cf517f | ||
|
|
87d9ab789c | ||
|
|
300a7fbfd6 | ||
|
|
9f25fdab12 | ||
|
|
b53a918aaa | ||
|
|
22cb30bde0 | ||
|
|
d5e818b334 | ||
|
|
04077c7f12 | ||
|
|
94f16c4f8b | ||
|
|
66881b0d91 | ||
|
|
17c18b48dc | ||
|
|
ba962297f1 | ||
|
|
b831b7d440 | ||
|
|
51693af83d | ||
|
|
254280d63d | ||
|
|
4dc2f7646c | ||
|
|
f606978ad9 | ||
|
|
09c2ecef29 | ||
|
|
9440cda7b0 | ||
|
|
b65a00dabc | ||
|
|
11a1e86774 | ||
|
|
4ae0c25356 | ||
|
|
fd79ebd99e | ||
|
|
094033fb76 | ||
|
|
978bb9051d | ||
|
|
331709411a | ||
|
|
8805ac6413 | ||
|
|
0a917f1ba0 | ||
|
|
6196c07188 | ||
|
|
d738d498ce | ||
|
|
b7cb51da4d | ||
|
|
53ad9708bd | ||
|
|
89359c40e2 | ||
|
|
8db87b1eca | ||
|
|
4f9282e49d | ||
|
|
dbe39c2477 | ||
|
|
8874bef006 | ||
|
|
31624b3059 | ||
|
|
9dbfa2c8de | ||
|
|
50508cc9e9 | ||
|
|
f2b76f531d | ||
|
|
5506fb83bb | ||
|
|
2fd5394773 | ||
|
|
c03be10fc1 | ||
|
|
3ce7b4efe7 | ||
|
|
0082ff55af | ||
|
|
b7dd0ff347 | ||
|
|
e9dc178e83 | ||
|
|
1b8e028b83 | ||
|
|
584ffcc9c4 | ||
|
|
2c927d6659 | ||
|
|
400d628e9f | ||
|
|
69fa0c0a50 | ||
|
|
1c5e3b0038 | ||
|
|
6bec1d070f | ||
|
|
d4ba2a7c34 | ||
|
|
24b094842f | ||
|
|
3451e3591f | ||
|
|
2081e78308 | ||
|
|
25f120f18f | ||
|
|
79f667aaaa | ||
|
|
2d04636cfd | ||
|
|
e2d6acb92f | ||
|
|
0f7a2c4b33 | ||
|
|
8c16faa34a | ||
|
|
9f9e87eb5c | ||
|
|
8e82c01cab | ||
|
|
825860b6f9 | ||
|
|
296aa56c26 | ||
|
|
c5b8566d83 | ||
|
|
cae9dc5c78 | ||
|
|
5327e9b572 | ||
|
|
b934cd399b | ||
|
|
9417a3c75c | ||
|
|
6823b8a6d5 | ||
|
|
6f957b155e | ||
|
|
3e8bf75f05 |
22
README.md
22
README.md
@@ -1,3 +1,25 @@
|
||||
# maubot-llmplus
|
||||
-------
|
||||
maubot plugin: llm plus
|
||||
|
||||
order:
|
||||
- !ai info
|
||||
> View the configuration information currently in official use.
|
||||
- !ai platform list
|
||||
> list platforms.
|
||||
- !ai platform current
|
||||
> query current platform in use.
|
||||
- !ai model list
|
||||
> list models on current platform.
|
||||
- !ai model current
|
||||
> query current model in use.
|
||||
- !ai use [model_name]
|
||||
> switch model in platform, you can use `!ai model list` command query model list.
|
||||
- !ai switch [platform_name]
|
||||
> switch platform
|
||||
> support platforms:
|
||||
> - local_ai#ollama
|
||||
> - local_ai#lmstudio
|
||||
> - openai
|
||||
> - anthropic
|
||||
|
||||
|
||||
100
base-config.yaml
100
base-config.yaml
@@ -1,39 +1,105 @@
|
||||
# allow users
|
||||
allowed_users: []
|
||||
|
||||
# allow update and read permission users
|
||||
allow_update_read_command_users: []
|
||||
|
||||
# allow readonly permission users
|
||||
allow_readonly_command_users: []
|
||||
|
||||
# current use platform
|
||||
use_platform: local_ai
|
||||
|
||||
name:
|
||||
# bot name
|
||||
name: "ai bot"
|
||||
|
||||
reply_in_thread:
|
||||
reply_in_thread: true
|
||||
|
||||
enable_multi_user:
|
||||
enable_multi_user: true
|
||||
|
||||
system_prompt: ""
|
||||
# system prompt
|
||||
system_prompt: "response in chinese"
|
||||
|
||||
# platform config
|
||||
platforms:
|
||||
local_ai:
|
||||
type: ollama
|
||||
url: http://localhost:11434
|
||||
url: http://192.168.32.162:11434
|
||||
api_key:
|
||||
model: llama3.2
|
||||
temperature: 1
|
||||
max_tokens: 2000
|
||||
max_words: 1000
|
||||
max_context_messages: 100
|
||||
max_context_messages: 20
|
||||
# 是否开启流式输出(开启后 Element 中消息会逐步更新)
|
||||
streaming: false
|
||||
qwen:
|
||||
# 国内: https://dashscope.aliyuncs.com
|
||||
# 海外: https://dashscope-intl.aliyuncs.com
|
||||
url: https://dashscope.aliyuncs.com
|
||||
api_key:
|
||||
model: qwen-plus
|
||||
temperature: 0.7
|
||||
top_p: 0.8
|
||||
max_tokens: 2000
|
||||
max_words: 1000
|
||||
max_context_messages: 20
|
||||
# 是否开启深度思考模式(仅 qwq 系列支持)
|
||||
enable_thinking: false
|
||||
# 是否开启流式输出(开启后 Element 中消息会逐步更新)
|
||||
streaming: false
|
||||
deepseek:
|
||||
url: https://api.deepseek.com
|
||||
api_key:
|
||||
model:
|
||||
max_words: 1000
|
||||
max_context_messages: 20
|
||||
# 是否开启流式输出(开启后 Element 中消息会逐步更新)
|
||||
streaming: false
|
||||
gemini:
|
||||
url: https://generativelanguage.googleapis.com
|
||||
api_key:
|
||||
model: gemini-2.0-flash
|
||||
temperature: 1
|
||||
max_tokens: 2000
|
||||
max_words: 1000
|
||||
max_context_messages: 20
|
||||
# 是否开启流式输出(开启后 Element 中消息会逐步更新)
|
||||
streaming: false
|
||||
openai:
|
||||
url:
|
||||
url: https://api.openai.com
|
||||
api_key:
|
||||
model:
|
||||
max_tokens:
|
||||
max_words:
|
||||
temperature:
|
||||
model: gpt-4o-mini
|
||||
max_tokens: 2000
|
||||
max_words: 1000
|
||||
max_context_messages: 20
|
||||
temperature: 1
|
||||
# 是否开启流式输出(开启后 Element 中消息会逐步更新)
|
||||
streaming: false
|
||||
anthropic:
|
||||
url:
|
||||
url: https://api.anthropic.com
|
||||
api_key:
|
||||
max_tokens:
|
||||
model:
|
||||
max_words:
|
||||
model: claude-3-5-sonnet-20240620
|
||||
max_words: 1000
|
||||
max_tokens: 2000
|
||||
max_context_messages: 20
|
||||
# 是否开启流式输出(开启后 Element 中消息会逐步更新)
|
||||
streaming: false
|
||||
xai:
|
||||
url: https://api.x.ai
|
||||
api_key:
|
||||
model: grok-beta
|
||||
temperature: 1
|
||||
max_tokens: 1000
|
||||
max_words: 2000
|
||||
max_context_messages: 20
|
||||
# 是否开启流式输出(开启后 Element 中消息会逐步更新)
|
||||
streaming: false
|
||||
|
||||
|
||||
# additional prompt
|
||||
additional_prompt:
|
||||
- role: user
|
||||
content: xxx
|
||||
content: "What model is currently in use?"
|
||||
- role: system
|
||||
content: xxx
|
||||
content: "you can response text contain user name"
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from typing import Type
|
||||
@@ -9,36 +10,16 @@ from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
|
||||
|
||||
from maubot_llmplus.local_paltform import Ollama, LmStudio
|
||||
from maubot_llmplus.platforms import Platform
|
||||
from maubot_llmplus.thrid_platform import OpenAi, Anthropic
|
||||
|
||||
"""
|
||||
配置文件加载
|
||||
"""
|
||||
from maubot_llmplus.plugin import AbsExtraConfigPlugin, Config
|
||||
from maubot_llmplus.thrid_platform import OpenAi, Anthropic, XAi, Deepseek, Gemini, Qwen
|
||||
|
||||
|
||||
class Config(BaseProxyConfig):
|
||||
def do_update(self, helper: ConfigUpdateHelper) -> None:
|
||||
helper.copy("allowed_users")
|
||||
helper.copy("use_platform")
|
||||
helper.copy("name")
|
||||
helper.copy("reply_in_thread")
|
||||
helper.copy("enable_multi_user")
|
||||
helper.copy("system_prompt")
|
||||
helper.copy("platforms")
|
||||
helper.copy("additional_prompt")
|
||||
|
||||
|
||||
class AiBotPlugin(Plugin):
|
||||
name: str
|
||||
class AiBotPlugin(AbsExtraConfigPlugin):
|
||||
|
||||
async def start(self) -> None:
|
||||
await super().start()
|
||||
# 加载并更新配置
|
||||
self.config.load_and_update()
|
||||
# 决定当前机器人的名称
|
||||
self.name = self.config['name'] or \
|
||||
await self.client.get_displayname(self.client.mxid) or \
|
||||
self.client.parse_user_id(self.client.mxid)[0]
|
||||
|
||||
"""
|
||||
判断sender是否是allowed_users中的成员
|
||||
@@ -58,6 +39,31 @@ class AiBotPlugin(Plugin):
|
||||
self.log.debug(f"{sender} doesn't match allowed_users")
|
||||
pass
|
||||
|
||||
def is_allow_command(self, sender: str, command_key: str) -> bool:
|
||||
allow_users = self.config[command_key]
|
||||
# 如果一个都没有配置,都没有权限可以执行更新命令
|
||||
if len(allow_users) <= 0:
|
||||
return False
|
||||
|
||||
# sender是否是配置中的一员, 如果是就允许进行命令修改执行,否则只有可读命令的执行
|
||||
for u in allow_users:
|
||||
if re.match(u, sender):
|
||||
return True
|
||||
self.log.debug(f"{sender} doesn't match {command_key}")
|
||||
pass
|
||||
|
||||
def is_allow_update_read_command(self, sender: str) -> bool:
|
||||
return self.is_allow_command(sender, "allow_update_read_command_users")
|
||||
|
||||
def is_allow_readonly_command(self, sender: str) -> bool:
|
||||
is_update_read = self.is_allow_update_read_command(sender)
|
||||
# 如果读写都有权限,就一定会有读权限,返回True
|
||||
if is_update_read:
|
||||
return True
|
||||
|
||||
# 如果没有读写权限,需要判断是否有只读权限
|
||||
return self.is_allow_command(sender, "allow_readonly_command_users")
|
||||
|
||||
"""
|
||||
判断是否应该让AI进行回应
|
||||
回应条件:
|
||||
@@ -75,6 +81,10 @@ class AiBotPlugin(Plugin):
|
||||
if event.sender == self.client.mxid:
|
||||
return False
|
||||
|
||||
# 如果发送的消息中,第一个字符是感叹号,不进行回复
|
||||
if event.content.body[0] == '!':
|
||||
return False
|
||||
|
||||
# 判断这个用户是否在允许列表中, 不存在返回False
|
||||
# 如果列表为空, 继续往下执行
|
||||
if not self.is_allow(event.sender):
|
||||
@@ -86,7 +96,7 @@ class AiBotPlugin(Plugin):
|
||||
return False
|
||||
|
||||
# 检查是否发送消息中有带上机器人的别名
|
||||
if re.search("(^|\s)(@)?" + self.name + "([ :,.!?]|$)", event.content.body, re.IGNORECASE):
|
||||
if re.search("(^|\\s)(@)?" + self.get_bot_name() + "([ :,.!?]|$)", event.content.body, re.IGNORECASE):
|
||||
return True
|
||||
|
||||
# 当聊天室只有两个人并且其中一个是机器人时
|
||||
@@ -105,6 +115,7 @@ class AiBotPlugin(Plugin):
|
||||
if parent_event.sender == self.client.mxid:
|
||||
return True
|
||||
|
||||
|
||||
@event.on(EventType.ROOM_MESSAGE)
|
||||
async def on_message(self, event: MessageEvent) -> None:
|
||||
if not await self.should_respond(event):
|
||||
@@ -114,17 +125,26 @@ class AiBotPlugin(Plugin):
|
||||
await event.mark_read()
|
||||
await self.client.set_typing(event.room_id, timeout=99999)
|
||||
platform = self.get_ai_platform()
|
||||
|
||||
if platform.is_streaming_enabled():
|
||||
await self._handle_streaming(event, platform)
|
||||
return
|
||||
|
||||
chat_completion = await platform.create_chat_completion(self, event)
|
||||
self.log.debug(
|
||||
f"发送结果 {chat_completion.message}, {chat_completion.model}, {chat_completion.finish_reason}")
|
||||
# ai gpt调用
|
||||
# 关闭typing提示
|
||||
await self.client.set_typing(event.room_id, timeout=0)
|
||||
# 打开typing提示
|
||||
if chat_completion.result:
|
||||
resp_content = chat_completion.message['content']
|
||||
response = TextMessageEventContent(msgtype=MessageType.TEXT, body=resp_content, format=Format.HTML,
|
||||
formatted_body=markdown.render(resp_content))
|
||||
await event.respond(response, in_thread=self.config['reply_in_thread'])
|
||||
else:
|
||||
resp_content = "调用失败,请检查: " + chat_completion.finish_reason
|
||||
response = TextMessageEventContent(msgtype=MessageType.TEXT, body=resp_content, format=Format.HTML,
|
||||
formatted_body=markdown.render(resp_content))
|
||||
await event.respond(response, in_thread=self.config['reply_in_thread'])
|
||||
|
||||
except Exception as e:
|
||||
self.log.exception(f"Something went wrong: {e}")
|
||||
await event.respond(f"Something went wrong: {e}")
|
||||
@@ -132,21 +152,94 @@ class AiBotPlugin(Plugin):
|
||||
|
||||
return None
|
||||
|
||||
async def _handle_streaming(self, evt: MessageEvent, platform) -> None:
|
||||
# 发送初始占位消息;on_message 已设 typing=on,等收到第一个 chunk 再关掉
|
||||
placeholder = TextMessageEventContent(
|
||||
msgtype=MessageType.TEXT, body="▌", format=Format.HTML, formatted_body="▌"
|
||||
)
|
||||
response_event_id = await evt.respond(placeholder, in_thread=self.config['reply_in_thread'])
|
||||
self.log.debug("Streaming: placeholder sent")
|
||||
|
||||
accumulated = ""
|
||||
last_edit_len = 0
|
||||
first_chunk = True
|
||||
EDIT_THRESHOLD = 100
|
||||
|
||||
async def send_edit(content: TextMessageEventContent) -> None:
|
||||
# shield 防止 wait_for 超时时 cancel send_task,保护 mautrix 内部锁不残留
|
||||
send_task = asyncio.ensure_future(self.client.send_message(evt.room_id, content))
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(send_task), timeout=8.0)
|
||||
except asyncio.TimeoutError:
|
||||
self.log.debug("Streaming: edit timed out, waiting naturally")
|
||||
await send_task
|
||||
except Exception as e:
|
||||
self.log.warning(f"Streaming: edit error: {e}")
|
||||
if not send_task.done():
|
||||
await send_task
|
||||
|
||||
try:
|
||||
async for chunk in platform.create_chat_completion_stream(self, evt):
|
||||
if first_chunk:
|
||||
# 收到第一个 chunk 才关掉 typing,等待期间用户可见 typing 指示器
|
||||
await self.client.set_typing(evt.room_id, timeout=0)
|
||||
first_chunk = False
|
||||
accumulated += chunk
|
||||
if len(accumulated) - last_edit_len >= EDIT_THRESHOLD:
|
||||
display = accumulated + " ▌"
|
||||
new_content = TextMessageEventContent(
|
||||
msgtype=MessageType.TEXT,
|
||||
body=display,
|
||||
format=Format.HTML,
|
||||
formatted_body=markdown.render(display)
|
||||
)
|
||||
new_content.set_edit(response_event_id)
|
||||
self.log.debug(f"Streaming: mid-edit, accumulated={len(accumulated)}")
|
||||
await send_edit(new_content)
|
||||
last_edit_len = len(accumulated)
|
||||
except Exception as e:
|
||||
self.log.exception(f"Streaming error: {e}")
|
||||
if not accumulated:
|
||||
accumulated = f"Streaming error: {e}"
|
||||
finally:
|
||||
if first_chunk:
|
||||
await self.client.set_typing(evt.room_id, timeout=0)
|
||||
|
||||
self.log.debug(f"Streaming: loop done, total={len(accumulated)}")
|
||||
|
||||
if not accumulated:
|
||||
accumulated = "(无响应)"
|
||||
final_content = TextMessageEventContent(
|
||||
msgtype=MessageType.TEXT,
|
||||
body=accumulated,
|
||||
format=Format.HTML,
|
||||
formatted_body=markdown.render(accumulated)
|
||||
)
|
||||
final_content.set_edit(response_event_id)
|
||||
self.log.debug("Streaming: sending final edit")
|
||||
await send_edit(final_content)
|
||||
self.log.debug("Streaming: final edit done")
|
||||
|
||||
def get_ai_platform(self) -> Platform:
|
||||
use_platform = self.config['use_platform']
|
||||
if use_platform == 'local_ai':
|
||||
type = self.config['platforms']['local_ai']['type']
|
||||
if type == 'ollama':
|
||||
return Ollama(self.config, self.name, self.http)
|
||||
elif type == 'lmstudio':
|
||||
return LmStudio(self.config, self.name, self.http)
|
||||
use_platform = self.config.cur_platform
|
||||
if use_platform == 'openai':
|
||||
return OpenAi(self.config, self.http)
|
||||
if use_platform == 'anthropic':
|
||||
return Anthropic(self.config, self.http)
|
||||
if use_platform == 'xai':
|
||||
return XAi(self.config, self.http)
|
||||
if use_platform == 'deepseek':
|
||||
return Deepseek(self.config, self.http)
|
||||
if use_platform == 'gemini':
|
||||
return Gemini(self.config, self.http)
|
||||
if use_platform == 'qwen':
|
||||
return Qwen(self.config, self.http)
|
||||
if use_platform == 'local_ai#ollama':
|
||||
return Ollama(self.config, self.http)
|
||||
if use_platform == 'local_ai#lmstudio':
|
||||
return LmStudio(self.config, self.http)
|
||||
else:
|
||||
raise ValueError(f"not found platform type: {type}")
|
||||
if use_platform == 'openai':
|
||||
return OpenAi(self.config, self.name, self.http)
|
||||
if use_platform == 'anthropic':
|
||||
return Anthropic(self.config, self.name, self.http)
|
||||
raise ValueError(f"unknown backend type {use_platform}")
|
||||
|
||||
"""
|
||||
父命令
|
||||
@@ -157,20 +250,142 @@ class AiBotPlugin(Plugin):
|
||||
|
||||
"""
|
||||
"""
|
||||
@ai_command.subcommand(help="")
|
||||
@ai_command.subcommand(help="View the configuration information currently in official use")
|
||||
async def info(self, event: MessageEvent) -> None:
|
||||
# 判断是否有更新命令权限,如果没有就返回没有权限的提示
|
||||
is_allow = self.is_allow_readonly_command(event.sender)
|
||||
if not is_allow:
|
||||
await event.reply(f"{event.sender} have not read permission")
|
||||
return
|
||||
|
||||
show_infos = []
|
||||
# 当前机器人名称
|
||||
show_infos.append(f"bot name: {self.get_bot_name()}\n\n")
|
||||
# 查询当前使用的ai平台
|
||||
show_infos.append(f"platform: {self.get_cur_platform()}\n\n")
|
||||
show_infos.append("platform detail: \n\n")
|
||||
# 查询当前ai平台的配置信息
|
||||
p_m_dict = self.get_config()
|
||||
for k, v in p_m_dict.items():
|
||||
show_infos.append(f"- {k}: {v}\n")
|
||||
# 当前使用的model
|
||||
show_infos.append(f"\nmodel: {self.config.cur_model}\n")
|
||||
# TODO 列出model信息
|
||||
await event.reply("".join(show_infos), markdown=True)
|
||||
pass
|
||||
|
||||
@ai_command.subcommand(help="")
|
||||
"""
|
||||
获取配置信息
|
||||
"""
|
||||
def get_config(self) -> dict:
|
||||
platform_config_dict = dict(self.config['platforms'][self.get_cur_platform()])
|
||||
# 移除敏感配置
|
||||
platform_config_dict.pop('api_key')
|
||||
platform_config_dict.pop('url')
|
||||
return platform_config_dict
|
||||
|
||||
"""
|
||||
获取实际平台名称
|
||||
"""
|
||||
def get_cur_platform(self) -> str:
|
||||
platform_model = self.config.cur_platform
|
||||
return platform_model.split('#')[0]
|
||||
|
||||
@ai_command.subcommand(help="List platforms or query current platform in use")
|
||||
@command.argument("argus")
|
||||
async def platform(self, event: MessageEvent, argus: str):
|
||||
# 判断是否有更新命令权限,如果没有就返回没有权限的提示
|
||||
is_allow = self.is_allow_readonly_command(event.sender)
|
||||
if not is_allow:
|
||||
await event.reply(f"{event.sender} have not read permission")
|
||||
return
|
||||
|
||||
if argus == 'list':
|
||||
p_dict = dict(self.config['platforms'])
|
||||
platforms = [f"- {platform}" for platform in set(p_dict.keys())]
|
||||
await event.reply("\n".join(platforms))
|
||||
pass
|
||||
if argus == 'current':
|
||||
await event.reply(f"current use platform is {self.config.cur_platform}")
|
||||
pass
|
||||
|
||||
@ai_command.subcommand(help="List models or query current model in use")
|
||||
@command.argument("argus")
|
||||
async def model(self, event: MessageEvent, argus: str):
|
||||
# 判断是否有更新命令权限,如果没有就返回没有权限的提示
|
||||
is_allow = self.is_allow_readonly_command(event.sender)
|
||||
if not is_allow:
|
||||
await event.reply(f"{event.sender} have not read permission")
|
||||
return
|
||||
|
||||
# 如果是list表示查看当前可以使用的模型列表
|
||||
if argus == 'list':
|
||||
platform = self.get_ai_platform()
|
||||
models = platform.list_models()
|
||||
await event.reply("\n".join(models))
|
||||
models = await platform.list_models()
|
||||
await event.reply("\n".join(models), markdown=True)
|
||||
pass
|
||||
# 如果是current,显示出当前的使用模型
|
||||
if argus == 'current':
|
||||
await event.reply(f"current use model is {self.config.cur_model}")
|
||||
pass
|
||||
|
||||
# 如果不是,如果是其他的名称,表示这是一个模型名
|
||||
@ai_command.subcommand(help="switch model in platform")
|
||||
@command.argument("argus")
|
||||
async def use(self, event: MessageEvent, argus: str):
|
||||
# 判断是否有更新命令权限,如果没有就返回没有权限的提示
|
||||
is_allow = self.is_allow_update_read_command(event.sender)
|
||||
if not is_allow:
|
||||
await event.reply(f"{event.sender} have not update permission")
|
||||
return
|
||||
|
||||
platform = self.get_ai_platform()
|
||||
# 获取模型列表,判断使用的模型是否存在于列表中
|
||||
models = await platform.list_models()
|
||||
if f"- {argus}" in models:
|
||||
self.log.debug(f"switch model: {argus}")
|
||||
self.config.cur_model = argus
|
||||
await event.react("✅")
|
||||
else:
|
||||
await event.reply("not found valid model")
|
||||
|
||||
@ai_command.subcommand(help="switch platform")
|
||||
@command.argument("argus")
|
||||
async def switch(self, event: MessageEvent, argus: str):
|
||||
# 判断是否有更新命令权限,如果没有就返回没有权限的提示
|
||||
is_allow = self.is_allow_update_read_command(event.sender)
|
||||
if not is_allow:
|
||||
await event.reply(f"{event.sender} have not update permission")
|
||||
return
|
||||
|
||||
# 判断是否是本地ai模型,如果是还需要解析#后的type
|
||||
if argus == 'local_ai':
|
||||
await event.reply("local ai platform has ollama and lmstudio. "
|
||||
"you can type `!ai use local_ai#{type}`. "
|
||||
"Example: local_ai#ollama")
|
||||
pass
|
||||
if argus == 'local_ai#ollama' or argus == 'local_ai#lmstudio':
|
||||
if argus == self.config.cur_platform:
|
||||
await event.reply(f"current ai platform has be {argus}")
|
||||
pass
|
||||
else:
|
||||
self.config.cur_platform = argus
|
||||
self.config.cur_model = self.config['platforms'][argus.split("#")[0]]['model']
|
||||
await event.react("✅")
|
||||
# 如果是openai或者是claude
|
||||
elif argus == 'openai' or argus == 'anthropic' or argus == 'xai' or argus == 'deepseek' or argus == 'gemini' or argus == 'qwen':
|
||||
if argus == self.config.cur_platform:
|
||||
await event.reply(f"current ai platform has be {argus}")
|
||||
pass
|
||||
else:
|
||||
self.config.cur_platform = argus
|
||||
# 使用配置的默认模型
|
||||
self.config.cur_model = self.config['platforms'][argus]['model']
|
||||
await event.react("✅")
|
||||
else:
|
||||
await event.reply(f"nof found ai platform: {argus}")
|
||||
pass
|
||||
self.log.debug(f"switch platform: {self.config.cur_platform}")
|
||||
self.log.debug(f"use default config model: {self.config.cur_model}")
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> Type[BaseProxyConfig]:
|
||||
|
||||
@@ -1,24 +1,30 @@
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from typing import List
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from maubot import Plugin
|
||||
|
||||
from mautrix.types import MessageEvent
|
||||
from mautrix.util.config import BaseProxyConfig
|
||||
|
||||
import maubot_llmplus
|
||||
import maubot_llmplus.platforms
|
||||
from maubot_llmplus.platforms import Platform, ChatCompletion
|
||||
from maubot_llmplus.plugin import AbsExtraConfigPlugin
|
||||
from maubot_llmplus.thrid_platform import _read_openai_sse
|
||||
|
||||
|
||||
class Ollama(Platform):
|
||||
chat_api: str
|
||||
|
||||
def __init__(self, config: BaseProxyConfig, name: str, http: ClientSession) -> None:
|
||||
super().__init__(config, name, http)
|
||||
self.chat_api = '/api/chat'
|
||||
def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
|
||||
super().__init__(config, http)
|
||||
self.streaming = self.config.get('streaming', False)
|
||||
|
||||
async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion:
|
||||
def is_streaming_enabled(self) -> bool:
|
||||
return self.streaming
|
||||
|
||||
async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
|
||||
full_context = []
|
||||
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||
full_context.extend(list(context))
|
||||
@@ -27,37 +33,123 @@ class Ollama(Platform):
|
||||
req_body = {'model': self.model, 'messages': full_context, 'stream': False}
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
async with self.http.post(endpoint, headers=headers, json=req_body) as response:
|
||||
# plugin.log.debug(f"响应内容:{response.status}, {await response.json()}")
|
||||
if response.status != 200:
|
||||
return ChatCompletion(
|
||||
result=False,
|
||||
message={},
|
||||
finish_reason=f"http status {response.status}",
|
||||
model=None
|
||||
)
|
||||
response_json = await response.json()
|
||||
return ChatCompletion(
|
||||
result=True,
|
||||
message=response_json['message'],
|
||||
finish_reason='success',
|
||||
model=response_json['model']
|
||||
)
|
||||
|
||||
async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
|
||||
full_context = []
|
||||
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||
full_context.extend(list(context))
|
||||
|
||||
endpoint = f"{self.url}/api/chat"
|
||||
req_body = {'model': self.model, 'messages': full_context, 'stream': True}
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
async with self.http.post(endpoint, headers=headers, json=req_body) as response:
|
||||
if response.status != 200:
|
||||
raise ValueError(f"Error: http status {response.status}")
|
||||
while True:
|
||||
try:
|
||||
line_bytes = await asyncio.wait_for(response.content.readline(), timeout=60.0)
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
if not line_bytes:
|
||||
break
|
||||
line = line_bytes.decode("utf-8").strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
data = json.loads(line)
|
||||
if data.get("done"):
|
||||
break
|
||||
content = data.get("message", {}).get("content", "")
|
||||
if content:
|
||||
yield content
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
async def list_models(self) -> List[str]:
|
||||
full_url = f"{self.url}/api/tags"
|
||||
async with self.http.get(full_url) as response:
|
||||
if response.status != 200:
|
||||
return []
|
||||
response_data = json.loads(await response.json())
|
||||
return [model['name'] for model in response_data]
|
||||
response_data = await response.json()
|
||||
return [f"- {model['model']}" for model in response_data['models']]
|
||||
|
||||
def get_type(self) -> str:
|
||||
return "local_ai"
|
||||
|
||||
|
||||
class LmStudio(Platform):
|
||||
temperature: int
|
||||
|
||||
def __init__(self, config: BaseProxyConfig, name: str, http: ClientSession) -> None:
|
||||
super().__init__(config, name, http)
|
||||
pass
|
||||
def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
|
||||
super().__init__(config, http)
|
||||
self.temperature = float(self.config['temperature']) if self.config.get('temperature') is not None else None
|
||||
self.streaming = self.config.get('streaming', False)
|
||||
|
||||
async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion:
|
||||
pass
|
||||
def is_streaming_enabled(self) -> bool:
|
||||
return self.streaming
|
||||
|
||||
async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
|
||||
full_context = []
|
||||
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||
full_context.extend(list(context))
|
||||
|
||||
endpoint = f"{self.url}/v1/chat/completions"
|
||||
headers = {"content-type": "application/json"}
|
||||
req_body = {"model": self.model, "messages": full_context, "temperature": self.temperature, "stream": False}
|
||||
async with self.http.post(
|
||||
endpoint, headers=headers, data=json.dumps(req_body)
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
return ChatCompletion(
|
||||
result=False,
|
||||
message={},
|
||||
finish_reason=f"Error: {await response.text()}",
|
||||
model=None
|
||||
)
|
||||
response_json = await response.json()
|
||||
choice = response_json["choices"][0]
|
||||
return ChatCompletion(
|
||||
result=True,
|
||||
message=choice["message"],
|
||||
finish_reason=choice["finish_reason"],
|
||||
model=choice.get("model", None)
|
||||
)
|
||||
|
||||
async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
|
||||
full_context = []
|
||||
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||
full_context.extend(list(context))
|
||||
|
||||
endpoint = f"{self.url}/v1/chat/completions"
|
||||
headers = {"content-type": "application/json"}
|
||||
req_body = {"model": self.model, "messages": full_context, "temperature": self.temperature, "stream": True}
|
||||
async with self.http.post(endpoint, headers=headers, data=json.dumps(req_body)) as response:
|
||||
if response.status != 200:
|
||||
raise ValueError(f"Error: {await response.text()}")
|
||||
async for chunk in _read_openai_sse(response):
|
||||
yield chunk
|
||||
|
||||
async def list_models(self) -> List[str]:
|
||||
full_url = f"{self.url}/v1/models"
|
||||
async with self.http.get(full_url) as response:
|
||||
if response.status != 200:
|
||||
return []
|
||||
response_data = await response.json()
|
||||
return [f"- {m['id']}" for m in response_data["data"]]
|
||||
|
||||
def get_type(self) -> str:
|
||||
return "local_ai"
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import json
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Generator
|
||||
from typing import Optional, List, Generator, AsyncIterator
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from maubot import Plugin
|
||||
from mautrix.types import MessageEvent, EncryptedEvent
|
||||
from mautrix.util.config import BaseProxyConfig
|
||||
|
||||
from maubot_llmplus.plugin import AbsExtraConfigPlugin, Config
|
||||
|
||||
"""
|
||||
AI响应对象
|
||||
@@ -14,7 +15,8 @@ from mautrix.util.config import BaseProxyConfig
|
||||
|
||||
|
||||
class ChatCompletion:
|
||||
def __init__(self, message: dict, finish_reason: str, model: Optional[str]) -> None:
|
||||
def __init__(self, result: bool, message: dict, finish_reason: str, model: Optional[str]) -> None:
|
||||
self.result = result
|
||||
self.message = message
|
||||
self.finish_reason = finish_reason
|
||||
self.model = model
|
||||
@@ -33,43 +35,49 @@ class Platform:
|
||||
additional_prompt: List[dict]
|
||||
system_prompt: str
|
||||
max_context_messages: int
|
||||
name: str
|
||||
|
||||
def __init__(self, config: BaseProxyConfig, name: str, http: ClientSession) -> None:
|
||||
def __init__(self, config: Config, http: ClientSession) -> None:
|
||||
self.http = http
|
||||
self.config = config['platforms'][self.get_type()]
|
||||
self.url = self.config['url']
|
||||
self.model = self.config['model']
|
||||
# 设置当前的使用模型,这里不直接使用config对象下的配置值,而是加入了与命令决定后的使用模型名称
|
||||
self.model = config.cur_model
|
||||
self.max_words = self.config['max_words']
|
||||
self.api_key = self.config['api_key']
|
||||
self.max_context_messages = self.config['max_context_messages']
|
||||
self.additional_prompt = config['additional_prompt']
|
||||
self.system_prompt = config['system_prompt']
|
||||
self.name = name
|
||||
|
||||
"""a
|
||||
调用AI对话接口, 响应结果
|
||||
"""
|
||||
|
||||
async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion:
|
||||
async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> AsyncIterator[str]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def is_streaming_enabled(self) -> bool:
|
||||
return False
|
||||
|
||||
async def list_models(self) -> List[str]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_type(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
|
||||
async def get_context(plugin: Plugin, platform: Platform, evt: MessageEvent) -> deque:
|
||||
"""
|
||||
获取系统提示上下文
|
||||
"""
|
||||
async def get_system_context(plugin: AbsExtraConfigPlugin, platform: Platform, evt: MessageEvent) -> deque:
|
||||
# 创建系统提示词上下文
|
||||
system_context = deque()
|
||||
# 生成当前时间
|
||||
timestamp = datetime.today().strftime('%Y-%m-%d %H:%M:%S')
|
||||
# 加入系统提示词
|
||||
system_prompt = {"role": "system",
|
||||
"content": plugin.config['system_prompt'].format(name=plugin.name, timestamp=timestamp)}
|
||||
"content": plugin.config['system_prompt'].format(name=plugin.get_bot_name(), timestamp=timestamp)}
|
||||
if plugin.config['enable_multi_user']:
|
||||
system_prompt["content"] += """
|
||||
User messages are in the context of multiperson chatrooms.
|
||||
@@ -91,7 +99,12 @@ async def get_context(plugin: Plugin, platform: Platform, evt: MessageEvent) ->
|
||||
raise ValueError(f"sorry, my configuration has too many additional prompts "
|
||||
f"({platform.max_context_messages}) and i'll never see your message. "
|
||||
f"Update my config to have fewer messages and i'll be able to answer your questions!")
|
||||
return system_context
|
||||
|
||||
"""
|
||||
获取聊天信息上下文
|
||||
"""
|
||||
async def get_chat_context(system_context: deque, plugin: AbsExtraConfigPlugin, platform: Platform, evt: MessageEvent, hasAssistant: bool=True) -> deque:
|
||||
# 用户历史聊天上下文
|
||||
chat_context = deque()
|
||||
# 计算系统提示词单词数
|
||||
@@ -105,8 +118,16 @@ async def get_context(plugin: Plugin, platform: Platform, evt: MessageEvent) ->
|
||||
except (KeyError, AttributeError):
|
||||
continue
|
||||
|
||||
# 如果当前的这条历史消息是机器人自己的,那么角色就要设置为assistant
|
||||
# 如果没有assistant的角色,那么如果当前的对话消息是机器人的,忽略不要
|
||||
if not hasAssistant:
|
||||
if plugin.client.mxid == next_event.sender:
|
||||
continue
|
||||
else :
|
||||
role = 'user'
|
||||
else :
|
||||
role = 'assistant' if plugin.client.mxid == next_event.sender else 'user'
|
||||
|
||||
# 如果当前的这条历史消息是机器人自己的,那么角色就要设置为assistant
|
||||
message = next_event['content']['body']
|
||||
user = ''
|
||||
# 如果是允许多用户使用,那么就需要在每个历史消息前加上用户名
|
||||
@@ -121,11 +142,16 @@ async def get_context(plugin: Plugin, platform: Platform, evt: MessageEvent) ->
|
||||
break
|
||||
chat_context.appendleft({"role": role, "content": user + message})
|
||||
|
||||
return chat_context
|
||||
|
||||
"""
|
||||
获取总消息上下文
|
||||
"""
|
||||
async def get_context(plugin: AbsExtraConfigPlugin, platform: Platform, evt: MessageEvent, hasAssistant: bool=True) -> deque:
|
||||
system_context = await get_system_context(plugin, platform, evt)
|
||||
chat_context = await get_chat_context(system_context, plugin, platform, evt, hasAssistant)
|
||||
return system_context + chat_context
|
||||
|
||||
|
||||
|
||||
|
||||
async def generate_context_messages(plugin: Plugin, platform: Platform, evt: MessageEvent) -> Generator[MessageEvent, None, None]:
|
||||
yield evt
|
||||
if plugin.config['reply_in_thread']:
|
||||
|
||||
43
maubot_llmplus/plugin.py
Normal file
43
maubot_llmplus/plugin.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from maubot import Plugin
|
||||
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
|
||||
|
||||
|
||||
class AbsExtraConfigPlugin(Plugin):
|
||||
default_username: str
|
||||
user_id: str
|
||||
|
||||
async def start(self) -> None:
|
||||
await super().start()
|
||||
self.default_username = await self.client.get_displayname(self.client.mxid)
|
||||
self.user_id = self.client.parse_user_id(self.client.mxid)[0]
|
||||
|
||||
def get_bot_name(self) -> str:
|
||||
return self.config['name'] or \
|
||||
self.default_username or \
|
||||
self.user_id
|
||||
|
||||
|
||||
"""
|
||||
配置文件加载
|
||||
"""
|
||||
|
||||
|
||||
class Config(BaseProxyConfig):
|
||||
cur_model: str
|
||||
cur_platform: str
|
||||
|
||||
def do_update(self, helper: ConfigUpdateHelper) -> None:
|
||||
helper.copy("allowed_users")
|
||||
helper.copy("use_platform")
|
||||
helper.copy("name")
|
||||
helper.copy("reply_in_thread")
|
||||
helper.copy("enable_multi_user")
|
||||
helper.copy("system_prompt")
|
||||
helper.copy("platforms")
|
||||
helper.copy("additional_prompt")
|
||||
helper.copy("allow_update_read_command_users")
|
||||
helper.copy("allow_readonly_command_users")
|
||||
|
||||
self.cur_platform = helper.base['use_platform'] if helper.base['use_platform'] != 'local_ai' else \
|
||||
f"{helper.base['use_platform']}#{helper.base['platforms']['local_ai']['type']}"
|
||||
self.cur_model = helper.base['platforms'][helper.base['use_platform']]['model']
|
||||
@@ -1,36 +1,648 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections import deque
|
||||
|
||||
from typing import List
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from maubot import Plugin
|
||||
from mautrix.types import MessageEvent
|
||||
from mautrix.util.config import BaseProxyConfig
|
||||
|
||||
import maubot_llmplus.platforms
|
||||
from maubot_llmplus.platforms import Platform, ChatCompletion
|
||||
from maubot_llmplus.plugin import AbsExtraConfigPlugin
|
||||
|
||||
|
||||
async def _read_openai_sse(response):
|
||||
"""读取 OpenAI 兼容格式的 SSE 流,yield 每个 delta content"""
|
||||
while True:
|
||||
try:
|
||||
line_bytes = await asyncio.wait_for(response.content.readline(), timeout=60.0)
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
if not line_bytes:
|
||||
break
|
||||
line = line_bytes.decode("utf-8").strip()
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
data_str = line[6:]
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
choices = data.get("choices", [])
|
||||
if choices:
|
||||
content = choices[0].get("delta", {}).get("content", "")
|
||||
if content:
|
||||
yield content
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class Deepseek(Platform):
|
||||
|
||||
def __init__(self, config: BaseProxyConfig, http: ClientSession):
|
||||
super().__init__(config, http)
|
||||
self.streaming = self.config.get('streaming', False)
|
||||
|
||||
def is_streaming_enabled(self) -> bool:
|
||||
return self.streaming
|
||||
|
||||
async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
|
||||
full_context = []
|
||||
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||
full_context.extend(list(context))
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
data = {
|
||||
"model": self.model,
|
||||
"messages": full_context,
|
||||
}
|
||||
|
||||
endpoint = f"{self.url}/chat/completions"
|
||||
async with self.http.post(
|
||||
endpoint, headers=headers, data=json.dumps(data)
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
return ChatCompletion(
|
||||
result=False,
|
||||
message={},
|
||||
finish_reason=f"Error: {await response.text()}",
|
||||
model=None
|
||||
)
|
||||
response_json = await response.json()
|
||||
choice = response_json["choices"][0]
|
||||
return ChatCompletion(
|
||||
result=True,
|
||||
message=choice["message"],
|
||||
finish_reason=choice["finish_reason"],
|
||||
model=response_json.get("model", None)
|
||||
)
|
||||
|
||||
async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
|
||||
full_context = []
|
||||
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||
full_context.extend(list(context))
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
data = {
|
||||
"model": self.model,
|
||||
"messages": full_context,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
endpoint = f"{self.url}/chat/completions"
|
||||
async with self.http.post(endpoint, headers=headers, data=json.dumps(data)) as response:
|
||||
if response.status != 200:
|
||||
raise ValueError(f"Error: {await response.text()}")
|
||||
async for chunk in _read_openai_sse(response):
|
||||
yield chunk
|
||||
|
||||
async def list_models(self) -> List[str]:
|
||||
models = ["deepseek-chat", "deepseek-reasoner"]
|
||||
return [f"- {m}" for m in models]
|
||||
|
||||
def get_type(self) -> str:
|
||||
return "deepseek"
|
||||
|
||||
class OpenAi(Platform):
|
||||
max_tokens: int
|
||||
temperature: float
|
||||
|
||||
def __init__(self, config: BaseProxyConfig, name: str, http: ClientSession) -> None:
|
||||
super().__init__(config, name, http)
|
||||
def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
|
||||
super().__init__(config, http)
|
||||
self.max_tokens = int(self.config['max_tokens']) if self.config.get('max_tokens') else None
|
||||
self.temperature = float(self.config['temperature']) if self.config.get('temperature') is not None else None
|
||||
self.streaming = self.config.get('streaming', False)
|
||||
|
||||
async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion:
|
||||
# 获取系统提示词
|
||||
# 获取额外的其他角色的提示词: role: user role: system
|
||||
def is_streaming_enabled(self) -> bool:
|
||||
return self.streaming
|
||||
|
||||
pass
|
||||
async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
|
||||
full_context = []
|
||||
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||
full_context.extend(list(context))
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
data = {
|
||||
"model": self.model,
|
||||
"messages": full_context,
|
||||
}
|
||||
|
||||
if 'max_tokens' in self.config and self.max_tokens:
|
||||
# 如果是gpt5的,就用max_completion_tokens
|
||||
if 'gpt-5' in self.model:
|
||||
data["max_completion_tokens"] = self.max_tokens
|
||||
else:
|
||||
# 如果是gpt4之前的,就是用max_tokens
|
||||
data["max_tokens"] = self.max_tokens
|
||||
|
||||
if 'temperature' in self.config and self.temperature:
|
||||
data["temperature"] = self.temperature
|
||||
|
||||
endpoint = f"{self.url}/v1/chat/completions"
|
||||
async with self.http.post(
|
||||
endpoint, headers=headers, data=json.dumps(data)
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
return ChatCompletion(
|
||||
result=False,
|
||||
message={},
|
||||
finish_reason=f"Error: {await response.text()}",
|
||||
model=None
|
||||
)
|
||||
response_json = await response.json()
|
||||
choice = response_json["choices"][0]
|
||||
return ChatCompletion(
|
||||
result=True,
|
||||
message=choice["message"],
|
||||
finish_reason=choice["finish_reason"],
|
||||
model=choice.get("model", None)
|
||||
)
|
||||
|
||||
async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
|
||||
full_context = []
|
||||
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||
full_context.extend(list(context))
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
data = {
|
||||
"model": self.model,
|
||||
"messages": full_context,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
if 'max_tokens' in self.config and self.max_tokens:
|
||||
if 'gpt-5' in self.model:
|
||||
data["max_completion_tokens"] = self.max_tokens
|
||||
else:
|
||||
data["max_tokens"] = self.max_tokens
|
||||
|
||||
if 'temperature' in self.config and self.temperature:
|
||||
data["temperature"] = self.temperature
|
||||
|
||||
endpoint = f"{self.url}/v1/chat/completions"
|
||||
async with self.http.post(endpoint, headers=headers, data=json.dumps(data)) as response:
|
||||
if response.status != 200:
|
||||
raise ValueError(f"Error: {await response.text()}")
|
||||
async for chunk in _read_openai_sse(response):
|
||||
yield chunk
|
||||
|
||||
async def list_models(self) -> List[str]:
|
||||
# 调用openai接口获取模型列表
|
||||
full_url = f"{self.url}/v1/models"
|
||||
headers = {'Authorization': f"Bearer {self.api_key}"}
|
||||
async with self.http.get(full_url, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
return []
|
||||
response_data = await response.json()
|
||||
return [f"- {m['id']}" for m in response_data["data"]]
|
||||
|
||||
def get_type(self) -> str:
|
||||
return "openai"
|
||||
|
||||
|
||||
class Anthropic(Platform):
|
||||
max_tokens: int
|
||||
streaming: bool
|
||||
|
||||
def __init__(self, config: BaseProxyConfig, name: str, http: ClientSession) -> None:
|
||||
super().__init__(config, name, http)
|
||||
def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
|
||||
super().__init__(config, http)
|
||||
self.max_tokens = int(self.config['max_tokens']) if self.config.get('max_tokens') else None
|
||||
self.streaming = self.config.get('streaming', False)
|
||||
|
||||
async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion:
|
||||
# 获取系统提示词
|
||||
# 获取额外的其他角色的提示词: role: user role: system
|
||||
def is_streaming_enabled(self) -> bool:
|
||||
return self.streaming
|
||||
|
||||
def _build_request(self, full_chat_context: list) -> tuple:
|
||||
endpoint = f"{self.url}/v1/messages"
|
||||
headers = {"x-api-key": self.api_key, "anthropic-version": "2023-06-01", "content-type": "application/json"}
|
||||
req_body = {"model": self.model, "max_tokens": self.max_tokens, "system": self.system_prompt,
|
||||
"messages": full_chat_context}
|
||||
return endpoint, headers, req_body
|
||||
|
||||
async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
|
||||
full_chat_context = []
|
||||
system_context = deque()
|
||||
chat_context = await maubot_llmplus.platforms.get_chat_context(system_context, plugin, self, evt)
|
||||
full_chat_context.extend(list(chat_context))
|
||||
|
||||
endpoint, headers, req_body = self._build_request(full_chat_context)
|
||||
|
||||
async with self.http.post(endpoint, headers=headers, data=json.dumps(req_body)) as response:
|
||||
if response.status != 200:
|
||||
return ChatCompletion(
|
||||
result=False,
|
||||
message={},
|
||||
finish_reason=f"Error: {await response.text()}",
|
||||
model=None
|
||||
)
|
||||
response_json = await response.json()
|
||||
text = "\n\n".join(c["text"] for c in response_json["content"])
|
||||
return ChatCompletion(
|
||||
result=True,
|
||||
message=dict(role="assistant", content=text),
|
||||
finish_reason=response_json['stop_reason'],
|
||||
model=response_json['model']
|
||||
)
|
||||
|
||||
async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
|
||||
full_chat_context = []
|
||||
system_context = deque()
|
||||
chat_context = await maubot_llmplus.platforms.get_chat_context(system_context, plugin, self, evt)
|
||||
full_chat_context.extend(list(chat_context))
|
||||
|
||||
endpoint, headers, req_body = self._build_request(full_chat_context)
|
||||
req_body["stream"] = True
|
||||
|
||||
async with self.http.post(endpoint, headers=headers, data=json.dumps(req_body)) as response:
|
||||
if response.status != 200:
|
||||
raise ValueError(f"Error: {await response.text()}")
|
||||
while True:
|
||||
try:
|
||||
line_bytes = await asyncio.wait_for(response.content.readline(), timeout=60.0)
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
if not line_bytes:
|
||||
break
|
||||
line = line_bytes.decode("utf-8").strip()
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
data_str = line[6:]
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
if data.get("type") == "message_stop":
|
||||
break
|
||||
if data.get("type") == "content_block_delta":
|
||||
delta = data.get("delta", {})
|
||||
if delta.get("type") == "text_delta":
|
||||
yield delta.get("text", "")
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
async def list_models(self) -> List[str]:
|
||||
full_url = f"{self.url}/v1/models"
|
||||
headers = {
|
||||
'anthropic-version': "2023-06-01",
|
||||
'X-Api-Key': f"{self.api_key}"
|
||||
}
|
||||
async with self.http.get(full_url, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
return []
|
||||
response_data = await response.json()
|
||||
return [f"- {m['id']}" for m in response_data['data']]
|
||||
|
||||
def get_type(self) -> str:
|
||||
return "anthropic"
|
||||
|
||||
|
||||
class Gemini(Platform):
|
||||
max_tokens: int
|
||||
temperature: float
|
||||
|
||||
def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
|
||||
super().__init__(config, http)
|
||||
self.max_tokens = int(self.config['max_tokens']) if self.config.get('max_tokens') else None
|
||||
self.temperature = float(self.config['temperature']) if self.config.get('temperature') is not None else None
|
||||
self.streaming = self.config.get('streaming', False)
|
||||
|
||||
def is_streaming_enabled(self) -> bool:
|
||||
return self.streaming
|
||||
|
||||
def _build_gemini_request(self, context) -> tuple:
|
||||
system_parts = []
|
||||
contents = []
|
||||
for msg in context:
|
||||
role = msg['role']
|
||||
content = msg['content']
|
||||
if role == 'system':
|
||||
system_parts.append({"text": content})
|
||||
elif role == 'assistant':
|
||||
contents.append({"role": "model", "parts": [{"text": content}]})
|
||||
else:
|
||||
contents.append({"role": "user", "parts": [{"text": content}]})
|
||||
|
||||
request_body = {
|
||||
"contents": contents,
|
||||
"generationConfig": {}
|
||||
}
|
||||
|
||||
if system_parts:
|
||||
request_body["system_instruction"] = {"parts": system_parts}
|
||||
|
||||
if self.max_tokens:
|
||||
request_body["generationConfig"]["maxOutputTokens"] = self.max_tokens
|
||||
|
||||
if self.temperature:
|
||||
request_body["generationConfig"]["temperature"] = self.temperature
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-goog-api-key": self.api_key
|
||||
}
|
||||
return request_body, headers
|
||||
|
||||
async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
|
||||
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||
request_body, headers = self._build_gemini_request(context)
|
||||
|
||||
endpoint = f"{self.url}/v1beta/models/{self.model}:generateContent"
|
||||
async with self.http.post(endpoint, headers=headers, data=json.dumps(request_body)) as response:
|
||||
if response.status != 200:
|
||||
return ChatCompletion(
|
||||
result=False,
|
||||
message={},
|
||||
finish_reason=f"Error: {await response.text()}",
|
||||
model=None
|
||||
)
|
||||
response_json = await response.json()
|
||||
candidate = response_json["candidates"][0]
|
||||
text = "".join(part["text"] for part in candidate["content"]["parts"])
|
||||
return ChatCompletion(
|
||||
result=True,
|
||||
message={"role": "assistant", "content": text},
|
||||
finish_reason=candidate.get("finishReason", "STOP"),
|
||||
model=response_json.get("modelVersion", self.model)
|
||||
)
|
||||
|
||||
async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
|
||||
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||
request_body, headers = self._build_gemini_request(context)
|
||||
|
||||
endpoint = f"{self.url}/v1beta/models/{self.model}:streamGenerateContent?alt=sse"
|
||||
async with self.http.post(endpoint, headers=headers, data=json.dumps(request_body)) as response:
|
||||
if response.status != 200:
|
||||
raise ValueError(f"Error: {await response.text()}")
|
||||
# 与 Anthropic 保持一致:内联 while 循环,避免双层异步生成器代理
|
||||
while True:
|
||||
try:
|
||||
line_bytes = await asyncio.wait_for(response.content.readline(), timeout=60.0)
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
if not line_bytes:
|
||||
break
|
||||
line = line_bytes.decode("utf-8").strip()
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
data_str = line[6:]
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
candidates = data.get("candidates", [])
|
||||
if not candidates:
|
||||
continue
|
||||
candidate = candidates[0]
|
||||
# 先 yield 文本,再判断是否结束(对齐 OpenAI [DONE] 逻辑)
|
||||
parts = candidate.get("content", {}).get("parts", [])
|
||||
for part in parts:
|
||||
text = part.get("text", "")
|
||||
if text:
|
||||
yield text
|
||||
finish_reason = candidate.get("finishReason")
|
||||
if finish_reason:
|
||||
logging.getLogger("instance/aibot").debug(
|
||||
f"Gemini stream finished: finishReason={finish_reason}"
|
||||
)
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
async def list_models(self) -> List[str]:
|
||||
full_url = f"{self.url}/v1beta/models"
|
||||
headers = {"x-goog-api-key": self.api_key}
|
||||
async with self.http.get(full_url, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
return []
|
||||
response_data = await response.json()
|
||||
return [
|
||||
f"- {m['name'].replace('models/', '')}"
|
||||
for m in response_data.get("models", [])
|
||||
if "generateContent" in m.get("supportedGenerationMethods", [])
|
||||
]
|
||||
|
||||
def get_type(self) -> str:
|
||||
return "gemini"
|
||||
|
||||
|
||||
class XAi(Platform):
|
||||
max_tokens: int
|
||||
temperature: int
|
||||
|
||||
def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
|
||||
super().__init__(config, http)
|
||||
self.temperature = float(self.config['temperature']) if self.config.get('temperature') is not None else None
|
||||
self.max_tokens = int(self.config['max_tokens']) if self.config.get('max_tokens') else None
|
||||
self.streaming = self.config.get('streaming', False)
|
||||
|
||||
def is_streaming_enabled(self) -> bool:
|
||||
return self.streaming
|
||||
|
||||
async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
|
||||
full_context = []
|
||||
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||
full_context.extend(list(context))
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
request_body = {
|
||||
"messages": full_context,
|
||||
"model": self.model,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
if 'max_tokens' in self.config and self.max_tokens:
|
||||
request_body["max_tokens"] = self.max_tokens
|
||||
|
||||
if 'temperature' in self.config and self.temperature:
|
||||
request_body["temperature"] = self.temperature
|
||||
|
||||
endpoint = f"{self.url}/v1/chat/completions"
|
||||
async with self.http.post(url=endpoint, data=json.dumps(request_body), headers=headers) as response:
|
||||
if response.status != 200:
|
||||
return ChatCompletion(
|
||||
result=False,
|
||||
message={},
|
||||
finish_reason=f"Error: {await response.text()}",
|
||||
model=None
|
||||
)
|
||||
response_json = await response.json()
|
||||
choice = response_json["choices"][0]
|
||||
return ChatCompletion(
|
||||
result=True,
|
||||
message=choice["message"],
|
||||
finish_reason=choice["finish_reason"],
|
||||
model=response_json["model"]
|
||||
)
|
||||
|
||||
async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
|
||||
full_context = []
|
||||
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||
full_context.extend(list(context))
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
request_body = {
|
||||
"messages": full_context,
|
||||
"model": self.model,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
if 'max_tokens' in self.config and self.max_tokens:
|
||||
request_body["max_tokens"] = self.max_tokens
|
||||
|
||||
if 'temperature' in self.config and self.temperature:
|
||||
request_body["temperature"] = self.temperature
|
||||
|
||||
endpoint = f"{self.url}/v1/chat/completions"
|
||||
async with self.http.post(url=endpoint, data=json.dumps(request_body), headers=headers) as response:
|
||||
if response.status != 200:
|
||||
raise ValueError(f"Error: {await response.text()}")
|
||||
async for chunk in _read_openai_sse(response):
|
||||
yield chunk
|
||||
|
||||
async def list_models(self) -> List[str]:
|
||||
full_url = f"{self.url}/v1/models"
|
||||
headers = {'Content-Type': 'application/json', 'Authorization': f"Bearer {self.api_key}"}
|
||||
async with self.http.get(full_url, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
return []
|
||||
response_data = await response.json()
|
||||
return [f"- {m['id']}" for m in response_data["data"]]
|
||||
|
||||
def get_type(self) -> str:
|
||||
return "xai"
|
||||
|
||||
|
||||
class Qwen(Platform):
|
||||
max_tokens: int
|
||||
temperature: float
|
||||
top_p: float
|
||||
enable_thinking: bool
|
||||
|
||||
def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
|
||||
super().__init__(config, http)
|
||||
self.max_tokens = int(self.config['max_tokens']) if self.config.get('max_tokens') else None
|
||||
self.temperature = float(self.config['temperature']) if self.config.get('temperature') is not None else None
|
||||
self.top_p = float(self.config['top_p']) if self.config.get('top_p') is not None else None
|
||||
self.enable_thinking = self.config['enable_thinking']
|
||||
self.streaming = self.config.get('streaming', False)
|
||||
|
||||
def is_streaming_enabled(self) -> bool:
|
||||
return self.streaming
|
||||
|
||||
def _build_qwen_request(self, full_context: list) -> tuple:
|
||||
parameters = {
|
||||
"result_format": "message"
|
||||
}
|
||||
if self.max_tokens:
|
||||
parameters["max_tokens"] = self.max_tokens
|
||||
if self.temperature is not None:
|
||||
parameters["temperature"] = self.temperature
|
||||
if self.top_p is not None:
|
||||
parameters["top_p"] = self.top_p
|
||||
if self.enable_thinking:
|
||||
parameters["enable_thinking"] = True
|
||||
|
||||
request_body = {
|
||||
"model": self.model,
|
||||
"input": {
|
||||
"messages": full_context
|
||||
},
|
||||
"parameters": parameters
|
||||
}
|
||||
|
||||
endpoint = f"{self.url}/api/v1/services/aigc/text-generation/generation"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
return endpoint, headers, request_body
|
||||
|
||||
async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
|
||||
full_context = []
|
||||
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||
full_context.extend(list(context))
|
||||
|
||||
endpoint, headers, request_body = self._build_qwen_request(full_context)
|
||||
|
||||
async with self.http.post(endpoint, headers=headers, data=json.dumps(request_body)) as response:
|
||||
if response.status != 200:
|
||||
return ChatCompletion(
|
||||
result=False,
|
||||
message={},
|
||||
finish_reason=f"Error: {await response.text()}",
|
||||
model=None
|
||||
)
|
||||
response_json = await response.json()
|
||||
choice = response_json["output"]["choices"][0]
|
||||
return ChatCompletion(
|
||||
result=True,
|
||||
message=choice["message"],
|
||||
finish_reason=choice.get("finish_reason", "stop"),
|
||||
model=response_json.get("model", self.model)
|
||||
)
|
||||
|
||||
async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
|
||||
full_context = []
|
||||
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||
full_context.extend(list(context))
|
||||
|
||||
endpoint, headers, request_body = self._build_qwen_request(full_context)
|
||||
# DashScope SSE 流式:增加 header 和 incremental_output 参数(每次只返回增量)
|
||||
headers["X-DashScope-SSE"] = "enable"
|
||||
request_body["parameters"]["incremental_output"] = True
|
||||
|
||||
async with self.http.post(endpoint, headers=headers, data=json.dumps(request_body)) as response:
|
||||
if response.status != 200:
|
||||
raise ValueError(f"Error: {await response.text()}")
|
||||
while True:
|
||||
try:
|
||||
line_bytes = await asyncio.wait_for(response.content.readline(), timeout=60.0)
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
if not line_bytes:
|
||||
break
|
||||
line = line_bytes.decode("utf-8").strip()
|
||||
if not line.startswith("data:"):
|
||||
continue
|
||||
data_str = line[5:].strip()
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
choices = data.get("output", {}).get("choices", [])
|
||||
if choices:
|
||||
content = choices[0].get("message", {}).get("content", "")
|
||||
if content:
|
||||
yield content
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
async def list_models(self) -> List[str]:
|
||||
models = [
|
||||
"qwen-max", "qwen-max-latest",
|
||||
"qwen-plus", "qwen-plus-latest",
|
||||
"qwen-turbo", "qwen-turbo-latest",
|
||||
"qwen-long",
|
||||
"qwen3-235b-a22b", "qwen3-30b-a3b",
|
||||
"qwq-plus", "qwq-plus-latest",
|
||||
]
|
||||
return [f"- {m}" for m in models]
|
||||
|
||||
def get_type(self) -> str:
|
||||
return "qwen"
|
||||
|
||||
Reference in New Issue
Block a user