Add streaming functionality

taylorxie
2026-03-09 23:44:39 +08:00
parent 89160ce482
commit bf4d2a444c
3 changed files with 278 additions and 19 deletions


@@ -1,3 +1,4 @@
import asyncio
import json
from typing import List
@@ -11,12 +12,17 @@ import maubot_llmplus
import maubot_llmplus.platforms
from maubot_llmplus.platforms import Platform, ChatCompletion
from maubot_llmplus.plugin import AbsExtraConfigPlugin
from maubot_llmplus.thrid_platform import _read_openai_sse

class Ollama(Platform):
    def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
        super().__init__(config, http)
        self.streaming = self.config.get('streaming', False)

    def is_streaming_enabled(self) -> bool:
        return self.streaming

    async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
        full_context = []
@@ -27,20 +33,52 @@ class Ollama(Platform):
        req_body = {'model': self.model, 'messages': full_context, 'stream': False}
        headers = {'Content-Type': 'application/json'}
        async with self.http.post(endpoint, headers=headers, json=req_body) as response:
            # plugin.log.debug(f"Response: {response.status}, {await response.json()}")
            if response.status != 200:
                return ChatCompletion(
                    result=False,
                    message={},
                    finish_reason=f"http status {response.status}",
                    model=None
                )
            response_json = await response.json()
            return ChatCompletion(
                result=True,
                message=response_json['message'],
                finish_reason='success',
                model=response_json['model']
            )
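
Note that the diff adds a second entry point instead of changing the existing one; the calling plugin presumably switches on is_streaming_enabled(). A minimal dispatch sketch, where handle_message and the reply calls are hypothetical and not part of this commit:

async def handle_message(plugin: AbsExtraConfigPlugin, platform: Platform, evt: MessageEvent) -> None:
    # Hypothetical caller-side dispatch between the two entry points.
    if platform.is_streaming_enabled():
        parts = []
        async for chunk in platform.create_chat_completion_stream(plugin, evt):
            parts.append(chunk)  # a real caller might edit the Matrix reply as chunks arrive
        await evt.reply("".join(parts))
    else:
        completion = await platform.create_chat_completion(plugin, evt)
        if completion.result:
            await evt.reply(completion.message["content"])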
    async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
        full_context = []
        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
        full_context.extend(list(context))
        endpoint = f"{self.url}/api/chat"
        req_body = {'model': self.model, 'messages': full_context, 'stream': True}
        headers = {'Content-Type': 'application/json'}
        async with self.http.post(endpoint, headers=headers, json=req_body) as response:
            if response.status != 200:
                raise ValueError(f"Error: http status {response.status}")
            # Ollama streams NDJSON: one JSON object per line until "done" is true.
            while True:
                try:
                    # Abort if the stream stalls for more than 60 seconds.
                    line_bytes = await asyncio.wait_for(response.content.readline(), timeout=60.0)
                except asyncio.TimeoutError:
                    break
                if not line_bytes:
                    break
                line = line_bytes.decode("utf-8").strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                    if data.get("done"):
                        break
                    content = data.get("message", {}).get("content", "")
                    if content:
                        yield content
                except json.JSONDecodeError:
                    # Skip lines that are not valid JSON.
                    pass
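
For context: with 'stream': True, Ollama's /api/chat replies with newline-delimited JSON, one object per message fragment plus a final object whose "done" field is true, which is exactly what the readline loop above consumes. A self-contained sketch of the same wire format, assuming a default local Ollama and a hypothetical model name:

import asyncio
import json
import aiohttp

async def demo() -> None:
    # Read the NDJSON stream directly; each line is one JSON object.
    body = {"model": "llama3", "messages": [{"role": "user", "content": "hi"}], "stream": True}
    async with aiohttp.ClientSession() as http:
        async with http.post("http://localhost:11434/api/chat", json=body) as resp:
            async for raw in resp.content:  # aiohttp yields the body line by line
                data = json.loads(raw)
                if data.get("done"):
                    break
                print(data["message"]["content"], end="", flush=True)

asyncio.run(demo())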
    async def list_models(self) -> List[str]:
        full_url = f"{self.url}/api/tags"
        async with self.http.get(full_url) as response:
@@ -53,13 +91,16 @@ class Ollama(Platform):
return "local_ai"
class LmStudio(Platform):
    temperature: int

    def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
        super().__init__(config, http)
        self.temperature = self.config['temperature']
        self.streaming = self.config.get('streaming', False)

    def is_streaming_enabled(self) -> bool:
        return self.streaming
    async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
        full_context = []
@@ -72,9 +113,9 @@ class LmStudio(Platform) :
        async with self.http.post(
            endpoint, headers=headers, data=json.dumps(req_body)
        ) as response:
            # plugin.log.debug(f"Response: {response.status}, {await response.json()}")
            if response.status != 200:
                return ChatCompletion(
                    result=False,
                    message={},
                    finish_reason=f"Error: {await response.text()}",
                    model=None
@@ -82,11 +123,26 @@ class LmStudio(Platform) :
            response_json = await response.json()
            choice = response_json["choices"][0]
            return ChatCompletion(
                result=True,
                message=choice["message"],
                finish_reason=choice["finish_reason"],
                model=choice.get("model", None)
            )
    async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
        full_context = []
        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
        full_context.extend(list(context))
        endpoint = f"{self.url}/v1/chat/completions"
        headers = {"content-type": "application/json"}
        req_body = {"model": self.model, "messages": full_context, "temperature": self.temperature, "stream": True}
        async with self.http.post(endpoint, headers=headers, data=json.dumps(req_body)) as response:
            if response.status != 200:
                raise ValueError(f"Error: {await response.text()}")
            # Delegate SSE parsing of the OpenAI-compatible stream to the shared helper.
            async for chunk in _read_openai_sse(response):
                yield chunk
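
The _read_openai_sse helper itself is imported from maubot_llmplus.thrid_platform and is not shown in this diff; for an OpenAI-compatible endpoint such as LM Studio it presumably parses Server-Sent Events along these lines (a sketch under that assumption, not the actual helper):

async def read_openai_sse(response):
    # OpenAI-style streaming sends "data: {...}" SSE lines and ends with "data: [DONE]".
    async for raw in response.content:
        line = raw.decode("utf-8").strip()
        if not line.startswith("data:"):
            continue  # skip blank keep-alive lines
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":
            break
        delta = json.loads(payload)["choices"][0].get("delta", {})
        content = delta.get("content")
        if content:
            yield content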
    async def list_models(self) -> List[str]:
        full_url = f"{self.url}/v1/models"
        async with self.http.get(full_url) as response: