Compare commits
7 Commits
d5d634bf14
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
249f225045 | ||
|
|
448a95134f | ||
|
|
6b2fc9ea07 | ||
|
|
96373a9c14 | ||
|
|
70ea0a6916 | ||
|
|
98a4dba820 | ||
|
|
a5e43190f4 |
@@ -127,7 +127,6 @@ class AiBotPlugin(AbsExtraConfigPlugin):
|
|||||||
platform = self.get_ai_platform()
|
platform = self.get_ai_platform()
|
||||||
|
|
||||||
if platform.is_streaming_enabled():
|
if platform.is_streaming_enabled():
|
||||||
await self.client.set_typing(event.room_id, timeout=0)
|
|
||||||
await self._handle_streaming(event, platform)
|
await self._handle_streaming(event, platform)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -154,7 +153,7 @@ class AiBotPlugin(AbsExtraConfigPlugin):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def _handle_streaming(self, evt: MessageEvent, platform) -> None:
|
async def _handle_streaming(self, evt: MessageEvent, platform) -> None:
|
||||||
# 发送初始占位消息
|
# 发送初始占位消息;on_message 已设 typing=on,等收到第一个 chunk 再关掉
|
||||||
placeholder = TextMessageEventContent(
|
placeholder = TextMessageEventContent(
|
||||||
msgtype=MessageType.TEXT, body="▌", format=Format.HTML, formatted_body="▌"
|
msgtype=MessageType.TEXT, body="▌", format=Format.HTML, formatted_body="▌"
|
||||||
)
|
)
|
||||||
@@ -163,17 +162,17 @@ class AiBotPlugin(AbsExtraConfigPlugin):
|
|||||||
|
|
||||||
accumulated = ""
|
accumulated = ""
|
||||||
last_edit_len = 0
|
last_edit_len = 0
|
||||||
EDIT_THRESHOLD = 100 # 每积累100个字符更新一次消息
|
first_chunk = True
|
||||||
|
EDIT_THRESHOLD = 100
|
||||||
|
|
||||||
async def send_edit(content: TextMessageEventContent) -> None:
|
async def send_edit(content: TextMessageEventContent) -> None:
|
||||||
"""顺序发送编辑消息:用shield确保send_message不被cancel,保护mautrix内部锁"""
|
# shield 防止 wait_for 超时时 cancel send_task,保护 mautrix 内部锁不残留
|
||||||
send_task = asyncio.ensure_future(self.client.send_message(evt.room_id, content))
|
send_task = asyncio.ensure_future(self.client.send_message(evt.room_id, content))
|
||||||
try:
|
try:
|
||||||
# shield防止wait_for超时cancel send_task本身,避免mautrix锁残留
|
|
||||||
await asyncio.wait_for(asyncio.shield(send_task), timeout=8.0)
|
await asyncio.wait_for(asyncio.shield(send_task), timeout=8.0)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
self.log.debug("Streaming: edit wait_for timed out, awaiting task completion")
|
self.log.debug("Streaming: edit timed out, waiting naturally")
|
||||||
await send_task # send_task仍在运行,等待自然完成
|
await send_task
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log.warning(f"Streaming: edit error: {e}")
|
self.log.warning(f"Streaming: edit error: {e}")
|
||||||
if not send_task.done():
|
if not send_task.done():
|
||||||
@@ -181,6 +180,10 @@ class AiBotPlugin(AbsExtraConfigPlugin):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
async for chunk in platform.create_chat_completion_stream(self, evt):
|
async for chunk in platform.create_chat_completion_stream(self, evt):
|
||||||
|
if first_chunk:
|
||||||
|
# 收到第一个 chunk 才关掉 typing,等待期间用户可见 typing 指示器
|
||||||
|
await self.client.set_typing(evt.room_id, timeout=0)
|
||||||
|
first_chunk = False
|
||||||
accumulated += chunk
|
accumulated += chunk
|
||||||
if len(accumulated) - last_edit_len >= EDIT_THRESHOLD:
|
if len(accumulated) - last_edit_len >= EDIT_THRESHOLD:
|
||||||
display = accumulated + " ▌"
|
display = accumulated + " ▌"
|
||||||
@@ -198,10 +201,12 @@ class AiBotPlugin(AbsExtraConfigPlugin):
|
|||||||
self.log.exception(f"Streaming error: {e}")
|
self.log.exception(f"Streaming error: {e}")
|
||||||
if not accumulated:
|
if not accumulated:
|
||||||
accumulated = f"Streaming error: {e}"
|
accumulated = f"Streaming error: {e}"
|
||||||
|
finally:
|
||||||
|
if first_chunk:
|
||||||
|
await self.client.set_typing(evt.room_id, timeout=0)
|
||||||
|
|
||||||
self.log.debug(f"Streaming: loop done, total={len(accumulated)}")
|
self.log.debug(f"Streaming: loop done, total={len(accumulated)}")
|
||||||
|
|
||||||
# 输出最终完整内容
|
|
||||||
if not accumulated:
|
if not accumulated:
|
||||||
accumulated = "(无响应)"
|
accumulated = "(无响应)"
|
||||||
final_content = TextMessageEventContent(
|
final_content = TextMessageEventContent(
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
@@ -39,6 +40,7 @@ async def _read_openai_sse(response):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Deepseek(Platform):
|
class Deepseek(Platform):
|
||||||
|
|
||||||
def __init__(self, config: BaseProxyConfig, http: ClientSession):
|
def __init__(self, config: BaseProxyConfig, http: ClientSession):
|
||||||
@@ -381,10 +383,11 @@ class Gemini(Platform):
|
|||||||
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
|
||||||
request_body, headers = self._build_gemini_request(context)
|
request_body, headers = self._build_gemini_request(context)
|
||||||
|
|
||||||
endpoint = f"{self.url}/v1beta/models/{self.model}:streamGenerateContent"
|
endpoint = f"{self.url}/v1beta/models/{self.model}:streamGenerateContent?alt=sse"
|
||||||
async with self.http.post(endpoint, headers=headers, data=json.dumps(request_body)) as response:
|
async with self.http.post(endpoint, headers=headers, data=json.dumps(request_body)) as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
raise ValueError(f"Error: {await response.text()}")
|
raise ValueError(f"Error: {await response.text()}")
|
||||||
|
# 与 Anthropic 保持一致:内联 while 循环,避免双层异步生成器代理
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
line_bytes = await asyncio.wait_for(response.content.readline(), timeout=60.0)
|
line_bytes = await asyncio.wait_for(response.content.readline(), timeout=60.0)
|
||||||
@@ -399,12 +402,21 @@ class Gemini(Platform):
|
|||||||
try:
|
try:
|
||||||
data = json.loads(data_str)
|
data = json.loads(data_str)
|
||||||
candidates = data.get("candidates", [])
|
candidates = data.get("candidates", [])
|
||||||
if candidates:
|
if not candidates:
|
||||||
parts = candidates[0].get("content", {}).get("parts", [])
|
continue
|
||||||
for part in parts:
|
candidate = candidates[0]
|
||||||
text = part.get("text", "")
|
# 先 yield 文本,再判断是否结束(对齐 OpenAI [DONE] 逻辑)
|
||||||
if text:
|
parts = candidate.get("content", {}).get("parts", [])
|
||||||
yield text
|
for part in parts:
|
||||||
|
text = part.get("text", "")
|
||||||
|
if text:
|
||||||
|
yield text
|
||||||
|
finish_reason = candidate.get("finishReason")
|
||||||
|
if finish_reason:
|
||||||
|
logging.getLogger("instance/aibot").debug(
|
||||||
|
f"Gemini stream finished: finishReason={finish_reason}"
|
||||||
|
)
|
||||||
|
break
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user