Compare commits
91 Commits
78b44a08fc
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
249f225045 | ||
|
|
448a95134f | ||
|
|
6b2fc9ea07 | ||
|
|
96373a9c14 | ||
|
|
70ea0a6916 | ||
|
|
98a4dba820 | ||
|
|
a5e43190f4 | ||
|
|
d5d634bf14 | ||
|
|
bf4d2a444c | ||
|
|
89160ce482 | ||
|
|
11e37a157d | ||
|
|
caddfb61f1 | ||
|
|
1070cf517f | ||
|
|
87d9ab789c | ||
|
|
300a7fbfd6 | ||
|
|
9f25fdab12 | ||
|
|
b53a918aaa | ||
|
|
22cb30bde0 | ||
|
|
d5e818b334 | ||
|
|
04077c7f12 | ||
|
|
94f16c4f8b | ||
|
|
66881b0d91 | ||
|
|
17c18b48dc | ||
|
|
ba962297f1 | ||
|
|
b831b7d440 | ||
|
|
51693af83d | ||
|
|
254280d63d | ||
|
|
4dc2f7646c | ||
|
|
f606978ad9 | ||
|
|
09c2ecef29 | ||
|
|
9440cda7b0 | ||
|
|
b65a00dabc | ||
|
|
11a1e86774 | ||
|
|
4ae0c25356 | ||
|
|
fd79ebd99e | ||
|
|
094033fb76 | ||
|
|
978bb9051d | ||
|
|
331709411a | ||
|
|
8805ac6413 | ||
|
|
0a917f1ba0 | ||
|
|
6196c07188 | ||
|
|
d738d498ce | ||
|
|
b7cb51da4d | ||
|
|
53ad9708bd | ||
|
|
89359c40e2 | ||
|
|
8db87b1eca | ||
|
|
4f9282e49d | ||
|
|
dbe39c2477 | ||
|
|
8874bef006 | ||
|
|
31624b3059 | ||
|
|
9dbfa2c8de | ||
|
|
50508cc9e9 | ||
|
|
f2b76f531d | ||
|
|
5506fb83bb | ||
|
|
2fd5394773 | ||
|
|
c03be10fc1 | ||
|
|
3ce7b4efe7 | ||
|
|
0082ff55af | ||
|
|
b7dd0ff347 | ||
|
|
e9dc178e83 | ||
|
|
1b8e028b83 | ||
|
|
584ffcc9c4 | ||
|
|
2c927d6659 | ||
|
|
400d628e9f | ||
|
|
69fa0c0a50 | ||
|
|
1c5e3b0038 | ||
|
|
6bec1d070f | ||
|
|
d4ba2a7c34 | ||
|
|
24b094842f | ||
|
|
3451e3591f | ||
|
|
2081e78308 | ||
|
|
25f120f18f | ||
|
|
79f667aaaa | ||
|
|
2d04636cfd | ||
|
|
e2d6acb92f | ||
|
|
0f7a2c4b33 | ||
|
|
8c16faa34a | ||
|
|
9f9e87eb5c | ||
|
|
8e82c01cab | ||
|
|
825860b6f9 | ||
|
|
296aa56c26 | ||
|
|
c5b8566d83 | ||
|
|
cae9dc5c78 | ||
|
|
5327e9b572 | ||
|
|
b934cd399b | ||
|
|
9417a3c75c | ||
|
|
6823b8a6d5 | ||
|
|
6f957b155e | ||
|
|
3e8bf75f05 | ||
|
|
dc5162b662 | ||
|
|
7e1af58c84 |
24
README.md
24
README.md
@@ -1,3 +1,25 @@
|
|||||||
# maubot-llmplus
|
# maubot-llmplus
|
||||||
-------
|
-------
|
||||||
maubot plugin: llm plus
|
maubot plugin: llm plus
|
||||||
|
|
||||||
|
order:
|
||||||
|
- !ai info
|
||||||
|
> View the configuration information currently in official use.
|
||||||
|
- !ai platform list
|
||||||
|
> list platforms.
|
||||||
|
- !ai platform current
|
||||||
|
> query current platform in use.
|
||||||
|
- !ai model list
|
||||||
|
> list models on current platform.
|
||||||
|
- !ai model current
|
||||||
|
> query current model in use.
|
||||||
|
- !ai use [model_name]
|
||||||
|
> switch model in platform, you can use `!ai model list` command query model list.
|
||||||
|
- !ai switch [platform_name]
|
||||||
|
> switch platform
|
||||||
|
> support platforms:
|
||||||
|
> - local_ai#ollama
|
||||||
|
> - local_ai#lmstudio
|
||||||
|
> - openai
|
||||||
|
> - anthropic
|
||||||
|
|
||||||
|
|||||||
100
base-config.yaml
100
base-config.yaml
@@ -1,39 +1,105 @@
|
|||||||
|
# allow users
|
||||||
allowed_users: []
|
allowed_users: []
|
||||||
|
|
||||||
|
# allow update and read permission users
|
||||||
|
allow_update_read_command_users: []
|
||||||
|
|
||||||
|
# allow readonly permission users
|
||||||
|
allow_readonly_command_users: []
|
||||||
|
|
||||||
|
# current use platform
|
||||||
use_platform: local_ai
|
use_platform: local_ai
|
||||||
|
|
||||||
name:
|
# bot name
|
||||||
|
name: "ai bot"
|
||||||
|
|
||||||
reply_in_thread:
|
reply_in_thread: true
|
||||||
|
|
||||||
enable_multi_user:
|
enable_multi_user: true
|
||||||
|
|
||||||
system_prompt: ""
|
# system prompt
|
||||||
|
system_prompt: "response in chinese"
|
||||||
|
|
||||||
|
# platform config
|
||||||
platforms:
|
platforms:
|
||||||
local_ai:
|
local_ai:
|
||||||
type: ollama
|
type: ollama
|
||||||
url: http://localhost:11434
|
url: http://192.168.32.162:11434
|
||||||
api_key:
|
api_key:
|
||||||
model: llama3.2
|
model: llama3.2
|
||||||
|
temperature: 1
|
||||||
|
max_tokens: 2000
|
||||||
max_words: 1000
|
max_words: 1000
|
||||||
max_context_messages: 100
|
max_context_messages: 20
|
||||||
|
# 是否开启流式输出(开启后 Element 中消息会逐步更新)
|
||||||
|
streaming: false
|
||||||
|
qwen:
|
||||||
|
# 国内: https://dashscope.aliyuncs.com
|
||||||
|
# 海外: https://dashscope-intl.aliyuncs.com
|
||||||
|
url: https://dashscope.aliyuncs.com
|
||||||
|
api_key:
|
||||||
|
model: qwen-plus
|
||||||
|
temperature: 0.7
|
||||||
|
top_p: 0.8
|
||||||
|
max_tokens: 2000
|
||||||
|
max_words: 1000
|
||||||
|
max_context_messages: 20
|
||||||
|
# 是否开启深度思考模式(仅 qwq 系列支持)
|
||||||
|
enable_thinking: false
|
||||||
|
# 是否开启流式输出(开启后 Element 中消息会逐步更新)
|
||||||
|
streaming: false
|
||||||
|
deepseek:
|
||||||
|
url: https://api.deepseek.com
|
||||||
|
api_key:
|
||||||
|
model:
|
||||||
|
max_words: 1000
|
||||||
|
max_context_messages: 20
|
||||||
|
# 是否开启流式输出(开启后 Element 中消息会逐步更新)
|
||||||
|
streaming: false
|
||||||
|
gemini:
|
||||||
|
url: https://generativelanguage.googleapis.com
|
||||||
|
api_key:
|
||||||
|
model: gemini-2.0-flash
|
||||||
|
temperature: 1
|
||||||
|
max_tokens: 2000
|
||||||
|
max_words: 1000
|
||||||
|
max_context_messages: 20
|
||||||
|
# 是否开启流式输出(开启后 Element 中消息会逐步更新)
|
||||||
|
streaming: false
|
||||||
openai:
|
openai:
|
||||||
url:
|
url: https://api.openai.com
|
||||||
api_key:
|
api_key:
|
||||||
model:
|
model: gpt-4o-mini
|
||||||
max_tokens:
|
max_tokens: 2000
|
||||||
max_words:
|
max_words: 1000
|
||||||
temperature:
|
max_context_messages: 20
|
||||||
|
temperature: 1
|
||||||
|
# 是否开启流式输出(开启后 Element 中消息会逐步更新)
|
||||||
|
streaming: false
|
||||||
anthropic:
|
anthropic:
|
||||||
url:
|
url: https://api.anthropic.com
|
||||||
api_key:
|
api_key:
|
||||||
max_tokens:
|
model: claude-3-5-sonnet-20240620
|
||||||
model:
|
max_words: 1000
|
||||||
max_words:
|
max_tokens: 2000
|
||||||
|
max_context_messages: 20
|
||||||
|
# 是否开启流式输出(开启后 Element 中消息会逐步更新)
|
||||||
|
streaming: false
|
||||||
|
xai:
|
||||||
|
url: https://api.x.ai
|
||||||
|
api_key:
|
||||||
|
model: grok-beta
|
||||||
|
temperature: 1
|
||||||
|
max_tokens: 1000
|
||||||
|
max_words: 2000
|
||||||
|
max_context_messages: 20
|
||||||
|
# 是否开启流式输出(开启后 Element 中消息会逐步更新)
|
||||||
|
streaming: false
|
||||||
|
|
||||||
|
|
||||||
|
# additional prompt
|
||||||
additional_prompt:
|
additional_prompt:
|
||||||
- role: user
|
- role: user
|
||||||
content: xxx
|
content: "What model is currently in use?"
|
||||||
- role: system
|
- role: system
|
||||||
content: xxx
|
content: "you can response text contain user name"
|
||||||
@@ -1,7 +1,8 @@
|
|||||||
|
import asyncio
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from typing import Type
|
from typing import Type
|
||||||
from maubot.handlers import event
|
from maubot.handlers import command, event
|
||||||
from maubot import Plugin, MessageEvent
|
from maubot import Plugin, MessageEvent
|
||||||
from mautrix.types import Format, TextMessageEventContent, EventType, MessageType, RelationType
|
from mautrix.types import Format, TextMessageEventContent, EventType, MessageType, RelationType
|
||||||
from mautrix.util import markdown
|
from mautrix.util import markdown
|
||||||
@@ -9,36 +10,16 @@ from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
|
|||||||
|
|
||||||
from maubot_llmplus.local_paltform import Ollama, LmStudio
|
from maubot_llmplus.local_paltform import Ollama, LmStudio
|
||||||
from maubot_llmplus.platforms import Platform
|
from maubot_llmplus.platforms import Platform
|
||||||
from maubot_llmplus.thrid_platform import OpenAi, Anthropic
|
from maubot_llmplus.plugin import AbsExtraConfigPlugin, Config
|
||||||
|
from maubot_llmplus.thrid_platform import OpenAi, Anthropic, XAi, Deepseek, Gemini, Qwen
|
||||||
"""
|
|
||||||
配置文件加载
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Config(BaseProxyConfig):
|
|
||||||
def do_update(self, helper: ConfigUpdateHelper) -> None:
|
|
||||||
helper.copy("allowed_users")
|
|
||||||
helper.copy("use_platform")
|
|
||||||
helper.copy("name")
|
|
||||||
helper.copy("reply_in_thread")
|
|
||||||
helper.copy("enable_multi_user")
|
|
||||||
helper.copy("system_prompt")
|
|
||||||
helper.copy("platforms")
|
|
||||||
helper.copy("additional_prompt")
|
|
||||||
|
|
||||||
|
|
||||||
class AiBotPlugin(Plugin):
|
class AiBotPlugin(AbsExtraConfigPlugin):
|
||||||
|
|
||||||
name: str
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
await super().start()
|
await super().start()
|
||||||
# 加载并更新配置
|
# 加载并更新配置
|
||||||
self.config.load_and_update()
|
self.config.load_and_update()
|
||||||
# 决定当前机器人的名称
|
|
||||||
self.name = self.config['name'] or \
|
|
||||||
await self.client.get_displayname(self.client.mxid) or \
|
|
||||||
self.client.parse_user_id(self.client.mxid)[0]
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
判断sender是否是allowed_users中的成员
|
判断sender是否是allowed_users中的成员
|
||||||
@@ -58,6 +39,31 @@ class AiBotPlugin(Plugin):
|
|||||||
self.log.debug(f"{sender} doesn't match allowed_users")
|
self.log.debug(f"{sender} doesn't match allowed_users")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def is_allow_command(self, sender: str, command_key: str) -> bool:
|
||||||
|
allow_users = self.config[command_key]
|
||||||
|
# 如果一个都没有配置,都没有权限可以执行更新命令
|
||||||
|
if len(allow_users) <= 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# sender是否是配置中的一员, 如果是就允许进行命令修改执行,否则只有可读命令的执行
|
||||||
|
for u in allow_users:
|
||||||
|
if re.match(u, sender):
|
||||||
|
return True
|
||||||
|
self.log.debug(f"{sender} doesn't match {command_key}")
|
||||||
|
pass
|
||||||
|
|
||||||
|
def is_allow_update_read_command(self, sender: str) -> bool:
|
||||||
|
return self.is_allow_command(sender, "allow_update_read_command_users")
|
||||||
|
|
||||||
|
def is_allow_readonly_command(self, sender: str) -> bool:
|
||||||
|
is_update_read = self.is_allow_update_read_command(sender)
|
||||||
|
# 如果读写都有权限,就一定会有读权限,返回True
|
||||||
|
if is_update_read:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 如果没有读写权限,需要判断是否有只读权限
|
||||||
|
return self.is_allow_command(sender, "allow_readonly_command_users")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
判断是否应该让AI进行回应
|
判断是否应该让AI进行回应
|
||||||
回应条件:
|
回应条件:
|
||||||
@@ -75,6 +81,10 @@ class AiBotPlugin(Plugin):
|
|||||||
if event.sender == self.client.mxid:
|
if event.sender == self.client.mxid:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# 如果发送的消息中,第一个字符是感叹号,不进行回复
|
||||||
|
if event.content.body[0] == '!':
|
||||||
|
return False
|
||||||
|
|
||||||
# 判断这个用户是否在允许列表中, 不存在返回False
|
# 判断这个用户是否在允许列表中, 不存在返回False
|
||||||
# 如果列表为空, 继续往下执行
|
# 如果列表为空, 继续往下执行
|
||||||
if not self.is_allow(event.sender):
|
if not self.is_allow(event.sender):
|
||||||
@@ -86,7 +96,7 @@ class AiBotPlugin(Plugin):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# 检查是否发送消息中有带上机器人的别名
|
# 检查是否发送消息中有带上机器人的别名
|
||||||
if re.search("(^|\s)(@)?" + self.name + "([ :,.!?]|$)", event.content.body, re.IGNORECASE):
|
if re.search("(^|\\s)(@)?" + self.get_bot_name() + "([ :,.!?]|$)", event.content.body, re.IGNORECASE):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# 当聊天室只有两个人并且其中一个是机器人时
|
# 当聊天室只有两个人并且其中一个是机器人时
|
||||||
@@ -112,21 +122,29 @@ class AiBotPlugin(Plugin):
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.log.debug("开始发送消息")
|
|
||||||
await event.mark_read()
|
await event.mark_read()
|
||||||
await self.client.set_typing(event.room_id, timeout=99999)
|
await self.client.set_typing(event.room_id, timeout=99999)
|
||||||
platform = self.get_ai_platform()
|
platform = self.get_ai_platform()
|
||||||
|
|
||||||
|
if platform.is_streaming_enabled():
|
||||||
|
await self._handle_streaming(event, platform)
|
||||||
|
return
|
||||||
|
|
||||||
chat_completion = await platform.create_chat_completion(self, event)
|
chat_completion = await platform.create_chat_completion(self, event)
|
||||||
self.log.debug(f"发送结果 {chat_completion.message}, {chat_completion.model}, {chat_completion.finish_reason}")
|
self.log.debug(
|
||||||
# ai gpt调用
|
f"发送结果 {chat_completion.message}, {chat_completion.model}, {chat_completion.finish_reason}")
|
||||||
# 关闭typing提示
|
|
||||||
await self.client.set_typing(event.room_id, timeout=0)
|
await self.client.set_typing(event.room_id, timeout=0)
|
||||||
# 打开typing提示
|
if chat_completion.result:
|
||||||
resp_content = chat_completion.message['content']
|
resp_content = chat_completion.message['content']
|
||||||
response = TextMessageEventContent(msgtype=MessageType.TEXT, body=resp_content, format=Format.HTML,
|
response = TextMessageEventContent(msgtype=MessageType.TEXT, body=resp_content, format=Format.HTML,
|
||||||
formatted_body=markdown.render(resp_content))
|
formatted_body=markdown.render(resp_content))
|
||||||
await event.respond(response, in_thread=self.config['reply_in_thread'])
|
await event.respond(response, in_thread=self.config['reply_in_thread'])
|
||||||
self.log.debug("发送结束")
|
else:
|
||||||
|
resp_content = "调用失败,请检查: " + chat_completion.finish_reason
|
||||||
|
response = TextMessageEventContent(msgtype=MessageType.TEXT, body=resp_content, format=Format.HTML,
|
||||||
|
formatted_body=markdown.render(resp_content))
|
||||||
|
await event.respond(response, in_thread=self.config['reply_in_thread'])
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log.exception(f"Something went wrong: {e}")
|
self.log.exception(f"Something went wrong: {e}")
|
||||||
await event.respond(f"Something went wrong: {e}")
|
await event.respond(f"Something went wrong: {e}")
|
||||||
@@ -134,21 +152,240 @@ class AiBotPlugin(Plugin):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def _handle_streaming(self, evt: MessageEvent, platform) -> None:
|
||||||
|
# 发送初始占位消息;on_message 已设 typing=on,等收到第一个 chunk 再关掉
|
||||||
|
placeholder = TextMessageEventContent(
|
||||||
|
msgtype=MessageType.TEXT, body="▌", format=Format.HTML, formatted_body="▌"
|
||||||
|
)
|
||||||
|
response_event_id = await evt.respond(placeholder, in_thread=self.config['reply_in_thread'])
|
||||||
|
self.log.debug("Streaming: placeholder sent")
|
||||||
|
|
||||||
|
accumulated = ""
|
||||||
|
last_edit_len = 0
|
||||||
|
first_chunk = True
|
||||||
|
EDIT_THRESHOLD = 100
|
||||||
|
|
||||||
|
async def send_edit(content: TextMessageEventContent) -> None:
|
||||||
|
# shield 防止 wait_for 超时时 cancel send_task,保护 mautrix 内部锁不残留
|
||||||
|
send_task = asyncio.ensure_future(self.client.send_message(evt.room_id, content))
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(asyncio.shield(send_task), timeout=8.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self.log.debug("Streaming: edit timed out, waiting naturally")
|
||||||
|
await send_task
|
||||||
|
except Exception as e:
|
||||||
|
self.log.warning(f"Streaming: edit error: {e}")
|
||||||
|
if not send_task.done():
|
||||||
|
await send_task
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for chunk in platform.create_chat_completion_stream(self, evt):
|
||||||
|
if first_chunk:
|
||||||
|
# 收到第一个 chunk 才关掉 typing,等待期间用户可见 typing 指示器
|
||||||
|
await self.client.set_typing(evt.room_id, timeout=0)
|
||||||
|
first_chunk = False
|
||||||
|
accumulated += chunk
|
||||||
|
if len(accumulated) - last_edit_len >= EDIT_THRESHOLD:
|
||||||
|
display = accumulated + " ▌"
|
||||||
|
new_content = TextMessageEventContent(
|
||||||
|
msgtype=MessageType.TEXT,
|
||||||
|
body=display,
|
||||||
|
format=Format.HTML,
|
||||||
|
formatted_body=markdown.render(display)
|
||||||
|
)
|
||||||
|
new_content.set_edit(response_event_id)
|
||||||
|
self.log.debug(f"Streaming: mid-edit, accumulated={len(accumulated)}")
|
||||||
|
await send_edit(new_content)
|
||||||
|
last_edit_len = len(accumulated)
|
||||||
|
except Exception as e:
|
||||||
|
self.log.exception(f"Streaming error: {e}")
|
||||||
|
if not accumulated:
|
||||||
|
accumulated = f"Streaming error: {e}"
|
||||||
|
finally:
|
||||||
|
if first_chunk:
|
||||||
|
await self.client.set_typing(evt.room_id, timeout=0)
|
||||||
|
|
||||||
|
self.log.debug(f"Streaming: loop done, total={len(accumulated)}")
|
||||||
|
|
||||||
|
if not accumulated:
|
||||||
|
accumulated = "(无响应)"
|
||||||
|
final_content = TextMessageEventContent(
|
||||||
|
msgtype=MessageType.TEXT,
|
||||||
|
body=accumulated,
|
||||||
|
format=Format.HTML,
|
||||||
|
formatted_body=markdown.render(accumulated)
|
||||||
|
)
|
||||||
|
final_content.set_edit(response_event_id)
|
||||||
|
self.log.debug("Streaming: sending final edit")
|
||||||
|
await send_edit(final_content)
|
||||||
|
self.log.debug("Streaming: final edit done")
|
||||||
|
|
||||||
def get_ai_platform(self) -> Platform:
|
def get_ai_platform(self) -> Platform:
|
||||||
use_platform = self.config['use_platform']
|
use_platform = self.config.cur_platform
|
||||||
if use_platform == 'local_ai':
|
|
||||||
type = self.config['platforms']['local_ai']['type']
|
|
||||||
if type == 'ollama':
|
|
||||||
return Ollama(self.config, self.name, self.http)
|
|
||||||
elif type == 'lmstudio':
|
|
||||||
return LmStudio(self.config, self.name, self.http)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"not found platform type: {type}")
|
|
||||||
if use_platform == 'openai':
|
if use_platform == 'openai':
|
||||||
return OpenAi(self.config, self.name, self.http)
|
return OpenAi(self.config, self.http)
|
||||||
if use_platform == 'anthropic':
|
if use_platform == 'anthropic':
|
||||||
return Anthropic(self.config, self.name, self.http)
|
return Anthropic(self.config, self.http)
|
||||||
raise ValueError(f"unknown backend type {use_platform}")
|
if use_platform == 'xai':
|
||||||
|
return XAi(self.config, self.http)
|
||||||
|
if use_platform == 'deepseek':
|
||||||
|
return Deepseek(self.config, self.http)
|
||||||
|
if use_platform == 'gemini':
|
||||||
|
return Gemini(self.config, self.http)
|
||||||
|
if use_platform == 'qwen':
|
||||||
|
return Qwen(self.config, self.http)
|
||||||
|
if use_platform == 'local_ai#ollama':
|
||||||
|
return Ollama(self.config, self.http)
|
||||||
|
if use_platform == 'local_ai#lmstudio':
|
||||||
|
return LmStudio(self.config, self.http)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"not found platform type: {type}")
|
||||||
|
|
||||||
|
"""
|
||||||
|
父命令
|
||||||
|
"""
|
||||||
|
@command.new(name="ai", require_subcommand=True)
|
||||||
|
async def ai_command(self, event: MessageEvent) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
@ai_command.subcommand(help="View the configuration information currently in official use")
|
||||||
|
async def info(self, event: MessageEvent) -> None:
|
||||||
|
# 判断是否有更新命令权限,如果没有就返回没有权限的提示
|
||||||
|
is_allow = self.is_allow_readonly_command(event.sender)
|
||||||
|
if not is_allow:
|
||||||
|
await event.reply(f"{event.sender} have not read permission")
|
||||||
|
return
|
||||||
|
|
||||||
|
show_infos = []
|
||||||
|
# 当前机器人名称
|
||||||
|
show_infos.append(f"bot name: {self.get_bot_name()}\n\n")
|
||||||
|
# 查询当前使用的ai平台
|
||||||
|
show_infos.append(f"platform: {self.get_cur_platform()}\n\n")
|
||||||
|
show_infos.append("platform detail: \n\n")
|
||||||
|
# 查询当前ai平台的配置信息
|
||||||
|
p_m_dict = self.get_config()
|
||||||
|
for k, v in p_m_dict.items():
|
||||||
|
show_infos.append(f"- {k}: {v}\n")
|
||||||
|
# 当前使用的model
|
||||||
|
show_infos.append(f"\nmodel: {self.config.cur_model}\n")
|
||||||
|
# TODO 列出model信息
|
||||||
|
await event.reply("".join(show_infos), markdown=True)
|
||||||
|
pass
|
||||||
|
|
||||||
|
"""
|
||||||
|
获取配置信息
|
||||||
|
"""
|
||||||
|
def get_config(self) -> dict:
|
||||||
|
platform_config_dict = dict(self.config['platforms'][self.get_cur_platform()])
|
||||||
|
# 移除敏感配置
|
||||||
|
platform_config_dict.pop('api_key')
|
||||||
|
platform_config_dict.pop('url')
|
||||||
|
return platform_config_dict
|
||||||
|
|
||||||
|
"""
|
||||||
|
获取实际平台名称
|
||||||
|
"""
|
||||||
|
def get_cur_platform(self) -> str:
|
||||||
|
platform_model = self.config.cur_platform
|
||||||
|
return platform_model.split('#')[0]
|
||||||
|
|
||||||
|
@ai_command.subcommand(help="List platforms or query current platform in use")
|
||||||
|
@command.argument("argus")
|
||||||
|
async def platform(self, event: MessageEvent, argus: str):
|
||||||
|
# 判断是否有更新命令权限,如果没有就返回没有权限的提示
|
||||||
|
is_allow = self.is_allow_readonly_command(event.sender)
|
||||||
|
if not is_allow:
|
||||||
|
await event.reply(f"{event.sender} have not read permission")
|
||||||
|
return
|
||||||
|
|
||||||
|
if argus == 'list':
|
||||||
|
p_dict = dict(self.config['platforms'])
|
||||||
|
platforms = [f"- {platform}" for platform in set(p_dict.keys())]
|
||||||
|
await event.reply("\n".join(platforms))
|
||||||
|
pass
|
||||||
|
if argus == 'current':
|
||||||
|
await event.reply(f"current use platform is {self.config.cur_platform}")
|
||||||
|
pass
|
||||||
|
|
||||||
|
@ai_command.subcommand(help="List models or query current model in use")
|
||||||
|
@command.argument("argus")
|
||||||
|
async def model(self, event: MessageEvent, argus: str):
|
||||||
|
# 判断是否有更新命令权限,如果没有就返回没有权限的提示
|
||||||
|
is_allow = self.is_allow_readonly_command(event.sender)
|
||||||
|
if not is_allow:
|
||||||
|
await event.reply(f"{event.sender} have not read permission")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 如果是list表示查看当前可以使用的模型列表
|
||||||
|
if argus == 'list':
|
||||||
|
platform = self.get_ai_platform()
|
||||||
|
models = await platform.list_models()
|
||||||
|
await event.reply("\n".join(models), markdown=True)
|
||||||
|
pass
|
||||||
|
# 如果是current,显示出当前的使用模型
|
||||||
|
if argus == 'current':
|
||||||
|
await event.reply(f"current use model is {self.config.cur_model}")
|
||||||
|
pass
|
||||||
|
|
||||||
|
@ai_command.subcommand(help="switch model in platform")
|
||||||
|
@command.argument("argus")
|
||||||
|
async def use(self, event: MessageEvent, argus: str):
|
||||||
|
# 判断是否有更新命令权限,如果没有就返回没有权限的提示
|
||||||
|
is_allow = self.is_allow_update_read_command(event.sender)
|
||||||
|
if not is_allow:
|
||||||
|
await event.reply(f"{event.sender} have not update permission")
|
||||||
|
return
|
||||||
|
|
||||||
|
platform = self.get_ai_platform()
|
||||||
|
# 获取模型列表,判断使用的模型是否存在于列表中
|
||||||
|
models = await platform.list_models()
|
||||||
|
if f"- {argus}" in models:
|
||||||
|
self.log.debug(f"switch model: {argus}")
|
||||||
|
self.config.cur_model = argus
|
||||||
|
await event.react("✅")
|
||||||
|
else:
|
||||||
|
await event.reply("not found valid model")
|
||||||
|
|
||||||
|
@ai_command.subcommand(help="switch platform")
|
||||||
|
@command.argument("argus")
|
||||||
|
async def switch(self, event: MessageEvent, argus: str):
|
||||||
|
# 判断是否有更新命令权限,如果没有就返回没有权限的提示
|
||||||
|
is_allow = self.is_allow_update_read_command(event.sender)
|
||||||
|
if not is_allow:
|
||||||
|
await event.reply(f"{event.sender} have not update permission")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 判断是否是本地ai模型,如果是还需要解析#后的type
|
||||||
|
if argus == 'local_ai':
|
||||||
|
await event.reply("local ai platform has ollama and lmstudio. "
|
||||||
|
"you can type `!ai use local_ai#{type}`. "
|
||||||
|
"Example: local_ai#ollama")
|
||||||
|
pass
|
||||||
|
if argus == 'local_ai#ollama' or argus == 'local_ai#lmstudio':
|
||||||
|
if argus == self.config.cur_platform:
|
||||||
|
await event.reply(f"current ai platform has be {argus}")
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
self.config.cur_platform = argus
|
||||||
|
self.config.cur_model = self.config['platforms'][argus.split("#")[0]]['model']
|
||||||
|
await event.react("✅")
|
||||||
|
# 如果是openai或者是claude
|
||||||
|
elif argus == 'openai' or argus == 'anthropic' or argus == 'xai' or argus == 'deepseek' or argus == 'gemini' or argus == 'qwen':
|
||||||
|
if argus == self.config.cur_platform:
|
||||||
|
await event.reply(f"current ai platform has be {argus}")
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
self.config.cur_platform = argus
|
||||||
|
# 使用配置的默认模型
|
||||||
|
self.config.cur_model = self.config['platforms'][argus]['model']
|
||||||
|
await event.react("✅")
|
||||||
|
else:
|
||||||
|
await event.reply(f"nof found ai platform: {argus}")
|
||||||
|
pass
|
||||||
|
self.log.debug(f"switch platform: {self.config.cur_platform}")
|
||||||
|
self.log.debug(f"use default config model: {self.config.cur_model}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config_class(cls) -> Type[BaseProxyConfig]:
|
def get_config_class(cls) -> Type[BaseProxyConfig]:
|
||||||
|
|||||||
@@ -1,23 +1,30 @@
|
|||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
from maubot import Plugin
|
|
||||||
from mautrix.types import MessageEvent
|
from mautrix.types import MessageEvent
|
||||||
from mautrix.util.config import BaseProxyConfig
|
from mautrix.util.config import BaseProxyConfig
|
||||||
|
|
||||||
import maubot_llmplus
|
import maubot_llmplus
|
||||||
import maubot_llmplus.platforms
|
import maubot_llmplus.platforms
|
||||||
from maubot_llmplus.platforms import Platform, ChatCompletion
|
from maubot_llmplus.platforms import Platform, ChatCompletion
|
||||||
|
from maubot_llmplus.plugin import AbsExtraConfigPlugin
|
||||||
|
from maubot_llmplus.thrid_platform import _read_openai_sse
|
||||||
|
|
||||||
|
|
||||||
class Ollama(Platform):
|
class Ollama(Platform):
|
||||||
chat_api: str
|
|
||||||
|
|
||||||
def __init__(self, config: BaseProxyConfig, name: str, http: ClientSession) -> None:
|
def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
|
||||||
super().__init__(config, name, http)
|
super().__init__(config, http)
|
||||||
self.chat_api = '/api/chat'
|
self.streaming = self.config.get('streaming', False)
|
||||||
|
|
||||||
async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion:
|
def is_streaming_enabled(self) -> bool:
|
||||||
|
return self.streaming
|
||||||
|
|
||||||
|
async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
|
||||||
full_context = []
|
full_context = []
|
||||||
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||||
full_context.extend(list(context))
|
full_context.extend(list(context))
|
||||||
@@ -25,35 +32,124 @@ class Ollama(Platform):
|
|||||||
endpoint = f"{self.url}/api/chat"
|
endpoint = f"{self.url}/api/chat"
|
||||||
req_body = {'model': self.model, 'messages': full_context, 'stream': False}
|
req_body = {'model': self.model, 'messages': full_context, 'stream': False}
|
||||||
headers = {'Content-Type': 'application/json'}
|
headers = {'Content-Type': 'application/json'}
|
||||||
if self.api_key is not None:
|
|
||||||
headers['Authorization'] = self.api_key
|
|
||||||
plugin.log.debug(f"{json.dumps(req_body)}")
|
|
||||||
async with self.http.post(endpoint, headers=headers, json=req_body) as response:
|
async with self.http.post(endpoint, headers=headers, json=req_body) as response:
|
||||||
# plugin.log.debug(f"响应内容:{response.status}, {await response.json()}")
|
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
return ChatCompletion(
|
return ChatCompletion(
|
||||||
|
result=False,
|
||||||
message={},
|
message={},
|
||||||
finish_reason=f"http status {response.status}",
|
finish_reason=f"http status {response.status}",
|
||||||
model=None
|
model=None
|
||||||
)
|
)
|
||||||
text = await response.text()
|
|
||||||
plugin.log.debug(f"解析后的响应内容: {text}")
|
|
||||||
response_json = await response.json()
|
response_json = await response.json()
|
||||||
return ChatCompletion(
|
return ChatCompletion(
|
||||||
|
result=True,
|
||||||
message=response_json['message'],
|
message=response_json['message'],
|
||||||
finish_reason='success',
|
finish_reason='success',
|
||||||
model=response_json['model']
|
model=response_json['model']
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
|
||||||
|
full_context = []
|
||||||
|
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||||
|
full_context.extend(list(context))
|
||||||
|
|
||||||
|
endpoint = f"{self.url}/api/chat"
|
||||||
|
req_body = {'model': self.model, 'messages': full_context, 'stream': True}
|
||||||
|
headers = {'Content-Type': 'application/json'}
|
||||||
|
async with self.http.post(endpoint, headers=headers, json=req_body) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise ValueError(f"Error: http status {response.status}")
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
line_bytes = await asyncio.wait_for(response.content.readline(), timeout=60.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
break
|
||||||
|
if not line_bytes:
|
||||||
|
break
|
||||||
|
line = line_bytes.decode("utf-8").strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
data = json.loads(line)
|
||||||
|
if data.get("done"):
|
||||||
|
break
|
||||||
|
content = data.get("message", {}).get("content", "")
|
||||||
|
if content:
|
||||||
|
yield content
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def list_models(self) -> List[str]:
|
||||||
|
full_url = f"{self.url}/api/tags"
|
||||||
|
async with self.http.get(full_url) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
return []
|
||||||
|
response_data = await response.json()
|
||||||
|
return [f"- {model['model']}" for model in response_data['models']]
|
||||||
|
|
||||||
def get_type(self) -> str:
|
def get_type(self) -> str:
|
||||||
return "local_ai"
|
return "local_ai"
|
||||||
|
|
||||||
|
|
||||||
class LmStudio(Platform):
|
class LmStudio(Platform):
|
||||||
|
temperature: int
|
||||||
|
|
||||||
def __init__(self, config: BaseProxyConfig, name: str, http: ClientSession) -> None:
|
def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
|
||||||
super().__init__(config, name, http)
|
super().__init__(config, http)
|
||||||
pass
|
self.temperature = float(self.config['temperature']) if self.config.get('temperature') is not None else None
|
||||||
|
self.streaming = self.config.get('streaming', False)
|
||||||
|
|
||||||
async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion:
|
def is_streaming_enabled(self) -> bool:
|
||||||
pass
|
return self.streaming
|
||||||
|
|
||||||
|
async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
|
||||||
|
full_context = []
|
||||||
|
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||||
|
full_context.extend(list(context))
|
||||||
|
|
||||||
|
endpoint = f"{self.url}/v1/chat/completions"
|
||||||
|
headers = {"content-type": "application/json"}
|
||||||
|
req_body = {"model": self.model, "messages": full_context, "temperature": self.temperature, "stream": False}
|
||||||
|
async with self.http.post(
|
||||||
|
endpoint, headers=headers, data=json.dumps(req_body)
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
return ChatCompletion(
|
||||||
|
result=False,
|
||||||
|
message={},
|
||||||
|
finish_reason=f"Error: {await response.text()}",
|
||||||
|
model=None
|
||||||
|
)
|
||||||
|
response_json = await response.json()
|
||||||
|
choice = response_json["choices"][0]
|
||||||
|
return ChatCompletion(
|
||||||
|
result=True,
|
||||||
|
message=choice["message"],
|
||||||
|
finish_reason=choice["finish_reason"],
|
||||||
|
model=choice.get("model", None)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
|
||||||
|
full_context = []
|
||||||
|
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||||
|
full_context.extend(list(context))
|
||||||
|
|
||||||
|
endpoint = f"{self.url}/v1/chat/completions"
|
||||||
|
headers = {"content-type": "application/json"}
|
||||||
|
req_body = {"model": self.model, "messages": full_context, "temperature": self.temperature, "stream": True}
|
||||||
|
async with self.http.post(endpoint, headers=headers, data=json.dumps(req_body)) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise ValueError(f"Error: {await response.text()}")
|
||||||
|
async for chunk in _read_openai_sse(response):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def list_models(self) -> List[str]:
|
||||||
|
full_url = f"{self.url}/v1/models"
|
||||||
|
async with self.http.get(full_url) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
return []
|
||||||
|
response_data = await response.json()
|
||||||
|
return [f"- {m['id']}" for m in response_data["data"]]
|
||||||
|
|
||||||
|
def get_type(self) -> str:
|
||||||
|
return "local_ai"
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
import json
|
import json
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, List, Generator
|
from typing import Optional, List, Generator, AsyncIterator
|
||||||
|
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
from maubot import Plugin
|
from maubot import Plugin
|
||||||
from mautrix.types import MessageEvent, EncryptedEvent
|
from mautrix.types import MessageEvent, EncryptedEvent
|
||||||
from mautrix.util.config import BaseProxyConfig
|
|
||||||
|
from maubot_llmplus.plugin import AbsExtraConfigPlugin, Config
|
||||||
|
|
||||||
"""
|
"""
|
||||||
AI响应对象
|
AI响应对象
|
||||||
@@ -14,7 +15,8 @@ from mautrix.util.config import BaseProxyConfig
|
|||||||
|
|
||||||
|
|
||||||
class ChatCompletion:
|
class ChatCompletion:
|
||||||
def __init__(self, message: dict, finish_reason: str, model: Optional[str]) -> None:
|
def __init__(self, result: bool, message: dict, finish_reason: str, model: Optional[str]) -> None:
|
||||||
|
self.result = result
|
||||||
self.message = message
|
self.message = message
|
||||||
self.finish_reason = finish_reason
|
self.finish_reason = finish_reason
|
||||||
self.model = model
|
self.model = model
|
||||||
@@ -33,48 +35,57 @@ class Platform:
|
|||||||
additional_prompt: List[dict]
|
additional_prompt: List[dict]
|
||||||
system_prompt: str
|
system_prompt: str
|
||||||
max_context_messages: int
|
max_context_messages: int
|
||||||
name: str
|
|
||||||
|
|
||||||
def __init__(self, config: BaseProxyConfig, name: str, http: ClientSession) -> None:
|
def __init__(self, config: Config, http: ClientSession) -> None:
|
||||||
self.http = http
|
self.http = http
|
||||||
self.config = config['platforms'][self.get_type()]
|
self.config = config['platforms'][self.get_type()]
|
||||||
self.url = self.config['url']
|
self.url = self.config['url']
|
||||||
self.model = self.config['model']
|
# 设置当前的使用模型,这里不直接使用config对象下的配置值,而是加入了与命令决定后的使用模型名称
|
||||||
|
self.model = config.cur_model
|
||||||
self.max_words = self.config['max_words']
|
self.max_words = self.config['max_words']
|
||||||
self.api_key = self.config['api_key']
|
self.api_key = self.config['api_key']
|
||||||
self.max_context_messages = self.config['max_context_messages']
|
self.max_context_messages = self.config['max_context_messages']
|
||||||
self.additional_prompt = config['additional_prompt']
|
self.additional_prompt = config['additional_prompt']
|
||||||
self.system_prompt = config['system_prompt']
|
self.system_prompt = config['system_prompt']
|
||||||
self.name = name
|
|
||||||
|
|
||||||
"""a
|
"""a
|
||||||
调用AI对话接口, 响应结果
|
调用AI对话接口, 响应结果
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion:
|
async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> AsyncIterator[str]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def is_streaming_enabled(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def list_models(self) -> List[str]:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def get_type(self) -> str:
|
def get_type(self) -> str:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
"""
|
||||||
|
获取系统提示上下文
|
||||||
async def get_context(plugin: Plugin, platform: Platform, evt: MessageEvent) -> deque:
|
"""
|
||||||
|
async def get_system_context(plugin: AbsExtraConfigPlugin, platform: Platform, evt: MessageEvent) -> deque:
|
||||||
# 创建系统提示词上下文
|
# 创建系统提示词上下文
|
||||||
system_context = deque()
|
system_context = deque()
|
||||||
# 生成当前时间
|
# 生成当前时间
|
||||||
timestamp = datetime.today().strftime('%Y-%m-%d %H:%M:%S')
|
timestamp = datetime.today().strftime('%Y-%m-%d %H:%M:%S')
|
||||||
# 加入系统提示词
|
# 加入系统提示词
|
||||||
system_prompt = {"role": "system",
|
system_prompt = {"role": "system",
|
||||||
"content": plugin.config['system_prompt'].format(name=plugin.name, timestamp=timestamp)}
|
"content": plugin.config['system_prompt'].format(name=plugin.get_bot_name(), timestamp=timestamp)}
|
||||||
if plugin.config['enable_multi_user']:
|
if plugin.config['enable_multi_user']:
|
||||||
system_prompt["content"] += """
|
system_prompt["content"] += """
|
||||||
User messages are in the context of multiperson chatrooms.
|
User messages are in the context of multiperson chatrooms.
|
||||||
Each message indicates its sender by prefixing the message with the sender's name followed by a colon, for example:
|
Each message indicates its sender by prefixing the message with the sender's name followed by a colon, for example:
|
||||||
"username: hello world."
|
"username: hello world."
|
||||||
In this case, the user called "username" sent the message "hello world.". You should not follow this convention in your responses.
|
In this case, the user called "username" sent the message "hello world.". You should not follow this convention in your responses.
|
||||||
your response instead could be "hello username!" without including any colons, because you are the only one sending your responses there is no need to prefix them.
|
your response instead could be "hello username!" without including any colons, because you are the only one sending your responses there is no need to prefix them.
|
||||||
"""
|
"""
|
||||||
if len(system_prompt["content"]) > 0:
|
if len(system_prompt["content"]) > 0:
|
||||||
system_context.append(system_prompt)
|
system_context.append(system_prompt)
|
||||||
|
|
||||||
@@ -86,9 +97,14 @@ async def get_context(plugin: Plugin, platform: Platform, evt: MessageEvent) ->
|
|||||||
# 如果 消息长度已经超过了配置的消息条数,那么就抛出错误
|
# 如果 消息长度已经超过了配置的消息条数,那么就抛出错误
|
||||||
if len(additional_context) > platform.max_context_messages - 1:
|
if len(additional_context) > platform.max_context_messages - 1:
|
||||||
raise ValueError(f"sorry, my configuration has too many additional prompts "
|
raise ValueError(f"sorry, my configuration has too many additional prompts "
|
||||||
f"({platform.max_context_messages}) and i'll never see your message. "
|
f"({platform.max_context_messages}) and i'll never see your message. "
|
||||||
f"Update my config to have fewer messages and i'll be able to answer your questions!")
|
f"Update my config to have fewer messages and i'll be able to answer your questions!")
|
||||||
|
return system_context
|
||||||
|
|
||||||
|
"""
|
||||||
|
获取聊天信息上下文
|
||||||
|
"""
|
||||||
|
async def get_chat_context(system_context: deque, plugin: AbsExtraConfigPlugin, platform: Platform, evt: MessageEvent, hasAssistant: bool=True) -> deque:
|
||||||
# 用户历史聊天上下文
|
# 用户历史聊天上下文
|
||||||
chat_context = deque()
|
chat_context = deque()
|
||||||
# 计算系统提示词单词数
|
# 计算系统提示词单词数
|
||||||
@@ -102,8 +118,16 @@ async def get_context(plugin: Plugin, platform: Platform, evt: MessageEvent) ->
|
|||||||
except (KeyError, AttributeError):
|
except (KeyError, AttributeError):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# 如果没有assistant的角色,那么如果当前的对话消息是机器人的,忽略不要
|
||||||
|
if not hasAssistant:
|
||||||
|
if plugin.client.mxid == next_event.sender:
|
||||||
|
continue
|
||||||
|
else :
|
||||||
|
role = 'user'
|
||||||
|
else :
|
||||||
|
role = 'assistant' if plugin.client.mxid == next_event.sender else 'user'
|
||||||
|
|
||||||
# 如果当前的这条历史消息是机器人自己的,那么角色就要设置为assistant
|
# 如果当前的这条历史消息是机器人自己的,那么角色就要设置为assistant
|
||||||
role = 'assistant' if plugin.client.mxid == next_event.sender else 'user'
|
|
||||||
message = next_event['content']['body']
|
message = next_event['content']['body']
|
||||||
user = ''
|
user = ''
|
||||||
# 如果是允许多用户使用,那么就需要在每个历史消息前加上用户名
|
# 如果是允许多用户使用,那么就需要在每个历史消息前加上用户名
|
||||||
@@ -118,11 +142,16 @@ async def get_context(plugin: Plugin, platform: Platform, evt: MessageEvent) ->
|
|||||||
break
|
break
|
||||||
chat_context.appendleft({"role": role, "content": user + message})
|
chat_context.appendleft({"role": role, "content": user + message})
|
||||||
|
|
||||||
|
return chat_context
|
||||||
|
|
||||||
|
"""
|
||||||
|
获取总消息上下文
|
||||||
|
"""
|
||||||
|
async def get_context(plugin: AbsExtraConfigPlugin, platform: Platform, evt: MessageEvent, hasAssistant: bool=True) -> deque:
|
||||||
|
system_context = await get_system_context(plugin, platform, evt)
|
||||||
|
chat_context = await get_chat_context(system_context, plugin, platform, evt, hasAssistant)
|
||||||
return system_context + chat_context
|
return system_context + chat_context
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_context_messages(plugin: Plugin, platform: Platform, evt: MessageEvent) -> Generator[MessageEvent, None, None]:
|
async def generate_context_messages(plugin: Plugin, platform: Platform, evt: MessageEvent) -> Generator[MessageEvent, None, None]:
|
||||||
yield evt
|
yield evt
|
||||||
if plugin.config['reply_in_thread']:
|
if plugin.config['reply_in_thread']:
|
||||||
|
|||||||
43
maubot_llmplus/plugin.py
Normal file
43
maubot_llmplus/plugin.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
from maubot import Plugin
|
||||||
|
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
|
||||||
|
|
||||||
|
|
||||||
|
class AbsExtraConfigPlugin(Plugin):
|
||||||
|
default_username: str
|
||||||
|
user_id: str
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
await super().start()
|
||||||
|
self.default_username = await self.client.get_displayname(self.client.mxid)
|
||||||
|
self.user_id = self.client.parse_user_id(self.client.mxid)[0]
|
||||||
|
|
||||||
|
def get_bot_name(self) -> str:
|
||||||
|
return self.config['name'] or \
|
||||||
|
self.default_username or \
|
||||||
|
self.user_id
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
配置文件加载
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Config(BaseProxyConfig):
|
||||||
|
cur_model: str
|
||||||
|
cur_platform: str
|
||||||
|
|
||||||
|
def do_update(self, helper: ConfigUpdateHelper) -> None:
|
||||||
|
helper.copy("allowed_users")
|
||||||
|
helper.copy("use_platform")
|
||||||
|
helper.copy("name")
|
||||||
|
helper.copy("reply_in_thread")
|
||||||
|
helper.copy("enable_multi_user")
|
||||||
|
helper.copy("system_prompt")
|
||||||
|
helper.copy("platforms")
|
||||||
|
helper.copy("additional_prompt")
|
||||||
|
helper.copy("allow_update_read_command_users")
|
||||||
|
helper.copy("allow_readonly_command_users")
|
||||||
|
|
||||||
|
self.cur_platform = helper.base['use_platform'] if helper.base['use_platform'] != 'local_ai' else \
|
||||||
|
f"{helper.base['use_platform']}#{helper.base['platforms']['local_ai']['type']}"
|
||||||
|
self.cur_model = helper.base['platforms'][helper.base['use_platform']]['model']
|
||||||
@@ -1,36 +1,648 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
from maubot import Plugin
|
|
||||||
from mautrix.types import MessageEvent
|
from mautrix.types import MessageEvent
|
||||||
from mautrix.util.config import BaseProxyConfig
|
from mautrix.util.config import BaseProxyConfig
|
||||||
|
|
||||||
|
import maubot_llmplus.platforms
|
||||||
from maubot_llmplus.platforms import Platform, ChatCompletion
|
from maubot_llmplus.platforms import Platform, ChatCompletion
|
||||||
|
from maubot_llmplus.plugin import AbsExtraConfigPlugin
|
||||||
|
|
||||||
|
|
||||||
|
async def _read_openai_sse(response):
|
||||||
|
"""读取 OpenAI 兼容格式的 SSE 流,yield 每个 delta content"""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
line_bytes = await asyncio.wait_for(response.content.readline(), timeout=60.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
break
|
||||||
|
if not line_bytes:
|
||||||
|
break
|
||||||
|
line = line_bytes.decode("utf-8").strip()
|
||||||
|
if not line.startswith("data: "):
|
||||||
|
continue
|
||||||
|
data_str = line[6:]
|
||||||
|
if data_str == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
data = json.loads(data_str)
|
||||||
|
choices = data.get("choices", [])
|
||||||
|
if choices:
|
||||||
|
content = choices[0].get("delta", {}).get("content", "")
|
||||||
|
if content:
|
||||||
|
yield content
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Deepseek(Platform):
|
||||||
|
|
||||||
|
def __init__(self, config: BaseProxyConfig, http: ClientSession):
|
||||||
|
super().__init__(config, http)
|
||||||
|
self.streaming = self.config.get('streaming', False)
|
||||||
|
|
||||||
|
def is_streaming_enabled(self) -> bool:
|
||||||
|
return self.streaming
|
||||||
|
|
||||||
|
async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
|
||||||
|
full_context = []
|
||||||
|
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||||
|
full_context.extend(list(context))
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {self.api_key}"
|
||||||
|
}
|
||||||
|
data = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": full_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint = f"{self.url}/chat/completions"
|
||||||
|
async with self.http.post(
|
||||||
|
endpoint, headers=headers, data=json.dumps(data)
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
return ChatCompletion(
|
||||||
|
result=False,
|
||||||
|
message={},
|
||||||
|
finish_reason=f"Error: {await response.text()}",
|
||||||
|
model=None
|
||||||
|
)
|
||||||
|
response_json = await response.json()
|
||||||
|
choice = response_json["choices"][0]
|
||||||
|
return ChatCompletion(
|
||||||
|
result=True,
|
||||||
|
message=choice["message"],
|
||||||
|
finish_reason=choice["finish_reason"],
|
||||||
|
model=response_json.get("model", None)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
|
||||||
|
full_context = []
|
||||||
|
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||||
|
full_context.extend(list(context))
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {self.api_key}"
|
||||||
|
}
|
||||||
|
data = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": full_context,
|
||||||
|
"stream": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint = f"{self.url}/chat/completions"
|
||||||
|
async with self.http.post(endpoint, headers=headers, data=json.dumps(data)) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise ValueError(f"Error: {await response.text()}")
|
||||||
|
async for chunk in _read_openai_sse(response):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def list_models(self) -> List[str]:
|
||||||
|
models = ["deepseek-chat", "deepseek-reasoner"]
|
||||||
|
return [f"- {m}" for m in models]
|
||||||
|
|
||||||
|
def get_type(self) -> str:
|
||||||
|
return "deepseek"
|
||||||
|
|
||||||
class OpenAi(Platform):
|
class OpenAi(Platform):
|
||||||
|
max_tokens: int
|
||||||
|
temperature: float
|
||||||
|
|
||||||
def __init__(self, config: BaseProxyConfig, name: str, http: ClientSession) -> None:
|
def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
|
||||||
super().__init__(config, name, http)
|
super().__init__(config, http)
|
||||||
|
self.max_tokens = int(self.config['max_tokens']) if self.config.get('max_tokens') else None
|
||||||
|
self.temperature = float(self.config['temperature']) if self.config.get('temperature') is not None else None
|
||||||
|
self.streaming = self.config.get('streaming', False)
|
||||||
|
|
||||||
async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion:
|
def is_streaming_enabled(self) -> bool:
|
||||||
# 获取系统提示词
|
return self.streaming
|
||||||
# 获取额外的其他角色的提示词: role: user role: system
|
|
||||||
|
|
||||||
pass
|
async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
|
||||||
|
full_context = []
|
||||||
|
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||||
|
full_context.extend(list(context))
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {self.api_key}"
|
||||||
|
}
|
||||||
|
data = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": full_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
if 'max_tokens' in self.config and self.max_tokens:
|
||||||
|
# 如果是gpt5的,就用max_completion_tokens
|
||||||
|
if 'gpt-5' in self.model:
|
||||||
|
data["max_completion_tokens"] = self.max_tokens
|
||||||
|
else:
|
||||||
|
# 如果是gpt4之前的,就是用max_tokens
|
||||||
|
data["max_tokens"] = self.max_tokens
|
||||||
|
|
||||||
|
if 'temperature' in self.config and self.temperature:
|
||||||
|
data["temperature"] = self.temperature
|
||||||
|
|
||||||
|
endpoint = f"{self.url}/v1/chat/completions"
|
||||||
|
async with self.http.post(
|
||||||
|
endpoint, headers=headers, data=json.dumps(data)
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
return ChatCompletion(
|
||||||
|
result=False,
|
||||||
|
message={},
|
||||||
|
finish_reason=f"Error: {await response.text()}",
|
||||||
|
model=None
|
||||||
|
)
|
||||||
|
response_json = await response.json()
|
||||||
|
choice = response_json["choices"][0]
|
||||||
|
return ChatCompletion(
|
||||||
|
result=True,
|
||||||
|
message=choice["message"],
|
||||||
|
finish_reason=choice["finish_reason"],
|
||||||
|
model=choice.get("model", None)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
|
||||||
|
full_context = []
|
||||||
|
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||||
|
full_context.extend(list(context))
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {self.api_key}"
|
||||||
|
}
|
||||||
|
data = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": full_context,
|
||||||
|
"stream": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
if 'max_tokens' in self.config and self.max_tokens:
|
||||||
|
if 'gpt-5' in self.model:
|
||||||
|
data["max_completion_tokens"] = self.max_tokens
|
||||||
|
else:
|
||||||
|
data["max_tokens"] = self.max_tokens
|
||||||
|
|
||||||
|
if 'temperature' in self.config and self.temperature:
|
||||||
|
data["temperature"] = self.temperature
|
||||||
|
|
||||||
|
endpoint = f"{self.url}/v1/chat/completions"
|
||||||
|
async with self.http.post(endpoint, headers=headers, data=json.dumps(data)) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise ValueError(f"Error: {await response.text()}")
|
||||||
|
async for chunk in _read_openai_sse(response):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def list_models(self) -> List[str]:
|
||||||
|
# 调用openai接口获取模型列表
|
||||||
|
full_url = f"{self.url}/v1/models"
|
||||||
|
headers = {'Authorization': f"Bearer {self.api_key}"}
|
||||||
|
async with self.http.get(full_url, headers=headers) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
return []
|
||||||
|
response_data = await response.json()
|
||||||
|
return [f"- {m['id']}" for m in response_data["data"]]
|
||||||
|
|
||||||
def get_type(self) -> str:
|
def get_type(self) -> str:
|
||||||
return "openai"
|
return "openai"
|
||||||
|
|
||||||
|
|
||||||
class Anthropic(Platform):
|
class Anthropic(Platform):
|
||||||
|
max_tokens: int
|
||||||
|
streaming: bool
|
||||||
|
|
||||||
def __init__(self, config: BaseProxyConfig, name: str, http: ClientSession) -> None:
|
def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
|
||||||
super().__init__(config, name, http)
|
super().__init__(config, http)
|
||||||
|
self.max_tokens = int(self.config['max_tokens']) if self.config.get('max_tokens') else None
|
||||||
|
self.streaming = self.config.get('streaming', False)
|
||||||
|
|
||||||
async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion:
|
def is_streaming_enabled(self) -> bool:
|
||||||
# 获取系统提示词
|
return self.streaming
|
||||||
# 获取额外的其他角色的提示词: role: user role: system
|
|
||||||
|
|
||||||
pass
|
def _build_request(self, full_chat_context: list) -> tuple:
|
||||||
|
endpoint = f"{self.url}/v1/messages"
|
||||||
|
headers = {"x-api-key": self.api_key, "anthropic-version": "2023-06-01", "content-type": "application/json"}
|
||||||
|
req_body = {"model": self.model, "max_tokens": self.max_tokens, "system": self.system_prompt,
|
||||||
|
"messages": full_chat_context}
|
||||||
|
return endpoint, headers, req_body
|
||||||
|
|
||||||
|
async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
|
||||||
|
full_chat_context = []
|
||||||
|
system_context = deque()
|
||||||
|
chat_context = await maubot_llmplus.platforms.get_chat_context(system_context, plugin, self, evt)
|
||||||
|
full_chat_context.extend(list(chat_context))
|
||||||
|
|
||||||
|
endpoint, headers, req_body = self._build_request(full_chat_context)
|
||||||
|
|
||||||
|
async with self.http.post(endpoint, headers=headers, data=json.dumps(req_body)) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
return ChatCompletion(
|
||||||
|
result=False,
|
||||||
|
message={},
|
||||||
|
finish_reason=f"Error: {await response.text()}",
|
||||||
|
model=None
|
||||||
|
)
|
||||||
|
response_json = await response.json()
|
||||||
|
text = "\n\n".join(c["text"] for c in response_json["content"])
|
||||||
|
return ChatCompletion(
|
||||||
|
result=True,
|
||||||
|
message=dict(role="assistant", content=text),
|
||||||
|
finish_reason=response_json['stop_reason'],
|
||||||
|
model=response_json['model']
|
||||||
|
)
|
||||||
|
|
||||||
|
async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
|
||||||
|
full_chat_context = []
|
||||||
|
system_context = deque()
|
||||||
|
chat_context = await maubot_llmplus.platforms.get_chat_context(system_context, plugin, self, evt)
|
||||||
|
full_chat_context.extend(list(chat_context))
|
||||||
|
|
||||||
|
endpoint, headers, req_body = self._build_request(full_chat_context)
|
||||||
|
req_body["stream"] = True
|
||||||
|
|
||||||
|
async with self.http.post(endpoint, headers=headers, data=json.dumps(req_body)) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise ValueError(f"Error: {await response.text()}")
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
line_bytes = await asyncio.wait_for(response.content.readline(), timeout=60.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
break
|
||||||
|
if not line_bytes:
|
||||||
|
break
|
||||||
|
line = line_bytes.decode("utf-8").strip()
|
||||||
|
if not line.startswith("data: "):
|
||||||
|
continue
|
||||||
|
data_str = line[6:]
|
||||||
|
try:
|
||||||
|
data = json.loads(data_str)
|
||||||
|
if data.get("type") == "message_stop":
|
||||||
|
break
|
||||||
|
if data.get("type") == "content_block_delta":
|
||||||
|
delta = data.get("delta", {})
|
||||||
|
if delta.get("type") == "text_delta":
|
||||||
|
yield delta.get("text", "")
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def list_models(self) -> List[str]:
|
||||||
|
full_url = f"{self.url}/v1/models"
|
||||||
|
headers = {
|
||||||
|
'anthropic-version': "2023-06-01",
|
||||||
|
'X-Api-Key': f"{self.api_key}"
|
||||||
|
}
|
||||||
|
async with self.http.get(full_url, headers=headers) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
return []
|
||||||
|
response_data = await response.json()
|
||||||
|
return [f"- {m['id']}" for m in response_data['data']]
|
||||||
|
|
||||||
def get_type(self) -> str:
|
def get_type(self) -> str:
|
||||||
return "anthropic"
|
return "anthropic"
|
||||||
|
|
||||||
|
|
||||||
|
class Gemini(Platform):
|
||||||
|
max_tokens: int
|
||||||
|
temperature: float
|
||||||
|
|
||||||
|
def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
|
||||||
|
super().__init__(config, http)
|
||||||
|
self.max_tokens = int(self.config['max_tokens']) if self.config.get('max_tokens') else None
|
||||||
|
self.temperature = float(self.config['temperature']) if self.config.get('temperature') is not None else None
|
||||||
|
self.streaming = self.config.get('streaming', False)
|
||||||
|
|
||||||
|
def is_streaming_enabled(self) -> bool:
|
||||||
|
return self.streaming
|
||||||
|
|
||||||
|
def _build_gemini_request(self, context) -> tuple:
|
||||||
|
system_parts = []
|
||||||
|
contents = []
|
||||||
|
for msg in context:
|
||||||
|
role = msg['role']
|
||||||
|
content = msg['content']
|
||||||
|
if role == 'system':
|
||||||
|
system_parts.append({"text": content})
|
||||||
|
elif role == 'assistant':
|
||||||
|
contents.append({"role": "model", "parts": [{"text": content}]})
|
||||||
|
else:
|
||||||
|
contents.append({"role": "user", "parts": [{"text": content}]})
|
||||||
|
|
||||||
|
request_body = {
|
||||||
|
"contents": contents,
|
||||||
|
"generationConfig": {}
|
||||||
|
}
|
||||||
|
|
||||||
|
if system_parts:
|
||||||
|
request_body["system_instruction"] = {"parts": system_parts}
|
||||||
|
|
||||||
|
if self.max_tokens:
|
||||||
|
request_body["generationConfig"]["maxOutputTokens"] = self.max_tokens
|
||||||
|
|
||||||
|
if self.temperature:
|
||||||
|
request_body["generationConfig"]["temperature"] = self.temperature
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"x-goog-api-key": self.api_key
|
||||||
|
}
|
||||||
|
return request_body, headers
|
||||||
|
|
||||||
|
async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
|
||||||
|
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||||
|
request_body, headers = self._build_gemini_request(context)
|
||||||
|
|
||||||
|
endpoint = f"{self.url}/v1beta/models/{self.model}:generateContent"
|
||||||
|
async with self.http.post(endpoint, headers=headers, data=json.dumps(request_body)) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
return ChatCompletion(
|
||||||
|
result=False,
|
||||||
|
message={},
|
||||||
|
finish_reason=f"Error: {await response.text()}",
|
||||||
|
model=None
|
||||||
|
)
|
||||||
|
response_json = await response.json()
|
||||||
|
candidate = response_json["candidates"][0]
|
||||||
|
text = "".join(part["text"] for part in candidate["content"]["parts"])
|
||||||
|
return ChatCompletion(
|
||||||
|
result=True,
|
||||||
|
message={"role": "assistant", "content": text},
|
||||||
|
finish_reason=candidate.get("finishReason", "STOP"),
|
||||||
|
model=response_json.get("modelVersion", self.model)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
|
||||||
|
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||||
|
request_body, headers = self._build_gemini_request(context)
|
||||||
|
|
||||||
|
endpoint = f"{self.url}/v1beta/models/{self.model}:streamGenerateContent?alt=sse"
|
||||||
|
async with self.http.post(endpoint, headers=headers, data=json.dumps(request_body)) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise ValueError(f"Error: {await response.text()}")
|
||||||
|
# 与 Anthropic 保持一致:内联 while 循环,避免双层异步生成器代理
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
line_bytes = await asyncio.wait_for(response.content.readline(), timeout=60.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
break
|
||||||
|
if not line_bytes:
|
||||||
|
break
|
||||||
|
line = line_bytes.decode("utf-8").strip()
|
||||||
|
if not line.startswith("data: "):
|
||||||
|
continue
|
||||||
|
data_str = line[6:]
|
||||||
|
try:
|
||||||
|
data = json.loads(data_str)
|
||||||
|
candidates = data.get("candidates", [])
|
||||||
|
if not candidates:
|
||||||
|
continue
|
||||||
|
candidate = candidates[0]
|
||||||
|
# 先 yield 文本,再判断是否结束(对齐 OpenAI [DONE] 逻辑)
|
||||||
|
parts = candidate.get("content", {}).get("parts", [])
|
||||||
|
for part in parts:
|
||||||
|
text = part.get("text", "")
|
||||||
|
if text:
|
||||||
|
yield text
|
||||||
|
finish_reason = candidate.get("finishReason")
|
||||||
|
if finish_reason:
|
||||||
|
logging.getLogger("instance/aibot").debug(
|
||||||
|
f"Gemini stream finished: finishReason={finish_reason}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def list_models(self) -> List[str]:
|
||||||
|
full_url = f"{self.url}/v1beta/models"
|
||||||
|
headers = {"x-goog-api-key": self.api_key}
|
||||||
|
async with self.http.get(full_url, headers=headers) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
return []
|
||||||
|
response_data = await response.json()
|
||||||
|
return [
|
||||||
|
f"- {m['name'].replace('models/', '')}"
|
||||||
|
for m in response_data.get("models", [])
|
||||||
|
if "generateContent" in m.get("supportedGenerationMethods", [])
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_type(self) -> str:
|
||||||
|
return "gemini"
|
||||||
|
|
||||||
|
|
||||||
|
class XAi(Platform):
    """Platform adapter for the xAI OpenAI-compatible chat-completions API."""

    # Optional request overrides read from config; None means the key was not
    # configured and the server default applies.
    max_tokens: int
    temperature: float  # was annotated `int`, but it is stored as float/None

    def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
        super().__init__(config, http)
        self.temperature = float(self.config['temperature']) if self.config.get('temperature') is not None else None
        self.max_tokens = int(self.config['max_tokens']) if self.config.get('max_tokens') else None
        self.streaming = self.config.get('streaming', False)

    def is_streaming_enabled(self) -> bool:
        """Whether responses should be streamed back chunk-by-chunk."""
        return self.streaming

    def _build_request(self, full_context: list, stream: bool) -> tuple:
        """Build (endpoint, headers, request_body) for one chat-completion call.

        Shared by the blocking and streaming paths so the two stay consistent
        (previously the request construction was duplicated verbatim).
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }
        request_body = {
            "messages": full_context,
            "model": self.model,
            "stream": stream,
        }
        if self.max_tokens:
            request_body["max_tokens"] = self.max_tokens
        # `is not None` so an explicitly configured temperature of 0.0 is still
        # sent; the previous truthiness check silently dropped it.
        if self.temperature is not None:
            request_body["temperature"] = self.temperature
        endpoint = f"{self.url}/v1/chat/completions"
        return endpoint, headers, request_body

    async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
        """Perform a single non-streaming chat completion for *evt*.

        Returns a ChatCompletion; on a non-200 response, ``result`` is False
        and ``finish_reason`` carries the error text.
        """
        full_context = []
        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
        full_context.extend(list(context))

        endpoint, headers, request_body = self._build_request(full_context, stream=False)
        async with self.http.post(url=endpoint, data=json.dumps(request_body), headers=headers) as response:
            if response.status != 200:
                return ChatCompletion(
                    result=False,
                    message={},
                    finish_reason=f"Error: {await response.text()}",
                    model=None
                )
            response_json = await response.json()
            choice = response_json["choices"][0]
            return ChatCompletion(
                result=True,
                message=choice["message"],
                finish_reason=choice["finish_reason"],
                model=response_json["model"]
            )

    async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
        """Yield response text chunks from a streaming chat completion.

        Raises:
            ValueError: if the API responds with a non-200 status.
        """
        full_context = []
        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
        full_context.extend(list(context))

        endpoint, headers, request_body = self._build_request(full_context, stream=True)
        async with self.http.post(url=endpoint, data=json.dumps(request_body), headers=headers) as response:
            if response.status != 200:
                raise ValueError(f"Error: {await response.text()}")
            async for chunk in _read_openai_sse(response):
                yield chunk

    async def list_models(self) -> List[str]:
        """List model ids visible to this API key, one markdown bullet each.

        Returns an empty list on any non-200 response.
        """
        full_url = f"{self.url}/v1/models"
        headers = {'Content-Type': 'application/json', 'Authorization': f"Bearer {self.api_key}"}
        async with self.http.get(full_url, headers=headers) as response:
            if response.status != 200:
                return []
            response_data = await response.json()
            return [f"- {m['id']}" for m in response_data["data"]]

    def get_type(self) -> str:
        """Return the platform type identifier ("xai") used in configuration."""
        return "xai"
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen(Platform):
    """Platform adapter for the Alibaba DashScope (Qwen) text-generation API."""

    # Optional request overrides read from config; None/False when the key was
    # not configured (server defaults apply).
    max_tokens: int
    temperature: float
    top_p: float
    enable_thinking: bool

    def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
        super().__init__(config, http)
        self.max_tokens = int(self.config['max_tokens']) if self.config.get('max_tokens') else None
        self.temperature = float(self.config['temperature']) if self.config.get('temperature') is not None else None
        self.top_p = float(self.config['top_p']) if self.config.get('top_p') is not None else None
        # .get() with a default: direct indexing raised KeyError when the key
        # was absent, unlike every other optional setting in this class.
        self.enable_thinking = self.config.get('enable_thinking', False)
        self.streaming = self.config.get('streaming', False)

    def is_streaming_enabled(self) -> bool:
        """Whether responses should be streamed back chunk-by-chunk."""
        return self.streaming

    def _build_qwen_request(self, full_context: list) -> tuple:
        """Build (endpoint, headers, request_body) for one DashScope call.

        Only options that were actually configured are included, so server
        defaults apply for the rest.
        """
        parameters = {
            "result_format": "message"
        }
        if self.max_tokens:
            parameters["max_tokens"] = self.max_tokens
        if self.temperature is not None:
            parameters["temperature"] = self.temperature
        if self.top_p is not None:
            parameters["top_p"] = self.top_p
        if self.enable_thinking:
            parameters["enable_thinking"] = True

        request_body = {
            "model": self.model,
            "input": {
                "messages": full_context
            },
            "parameters": parameters
        }

        endpoint = f"{self.url}/api/v1/services/aigc/text-generation/generation"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }
        return endpoint, headers, request_body

    async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
        """Perform a single non-streaming chat completion for *evt*.

        Returns a ChatCompletion; on a non-200 response, ``result`` is False
        and ``finish_reason`` carries the error text.
        """
        full_context = []
        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
        full_context.extend(list(context))

        endpoint, headers, request_body = self._build_qwen_request(full_context)

        async with self.http.post(endpoint, headers=headers, data=json.dumps(request_body)) as response:
            if response.status != 200:
                return ChatCompletion(
                    result=False,
                    message={},
                    finish_reason=f"Error: {await response.text()}",
                    model=None
                )
            response_json = await response.json()
            # DashScope nests choices under "output" (unlike OpenAI-style APIs).
            choice = response_json["output"]["choices"][0]
            return ChatCompletion(
                result=True,
                message=choice["message"],
                finish_reason=choice.get("finish_reason", "stop"),
                model=response_json.get("model", self.model)
            )

    async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
        """Yield incremental response text chunks via DashScope SSE streaming.

        Raises:
            ValueError: if the API responds with a non-200 status.
        """
        full_context = []
        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
        full_context.extend(list(context))

        endpoint, headers, request_body = self._build_qwen_request(full_context)
        # DashScope SSE streaming: enable the SSE header and incremental_output
        # so each event carries only the delta, not the full text so far.
        headers["X-DashScope-SSE"] = "enable"
        request_body["parameters"]["incremental_output"] = True

        async with self.http.post(endpoint, headers=headers, data=json.dumps(request_body)) as response:
            if response.status != 200:
                raise ValueError(f"Error: {await response.text()}")
            while True:
                try:
                    # Bail out rather than hang forever on a stalled stream.
                    line_bytes = await asyncio.wait_for(response.content.readline(), timeout=60.0)
                except asyncio.TimeoutError:
                    break
                if not line_bytes:
                    # Empty read means the server closed the stream.
                    break
                line = line_bytes.decode("utf-8").strip()
                # Only SSE data lines carry payloads; skip ids/events/keepalives.
                if not line.startswith("data:"):
                    continue
                data_str = line[5:].strip()
                try:
                    data = json.loads(data_str)
                    choices = data.get("output", {}).get("choices", [])
                    if choices:
                        content = choices[0].get("message", {}).get("content", "")
                        if content:
                            yield content
                except json.JSONDecodeError:
                    # Tolerate malformed events (e.g. non-JSON sentinels).
                    pass

    async def list_models(self) -> List[str]:
        """Return a static markdown-bullet list of common Qwen model names.

        DashScope has no model-listing endpoint comparable to the other
        platforms, so this is a hard-coded selection.
        """
        models = [
            "qwen-max", "qwen-max-latest",
            "qwen-plus", "qwen-plus-latest",
            "qwen-turbo", "qwen-turbo-latest",
            "qwen-long",
            "qwen3-235b-a22b", "qwen3-30b-a3b",
            "qwq-plus", "qwq-plus-latest",
        ]
        return [f"- {m}" for m in models]

    def get_type(self) -> str:
        """Return the platform type identifier ("qwen") used in configuration."""
        return "qwen"
|
||||||
|
|||||||
Reference in New Issue
Block a user