Add streaming functionality
@@ -1,3 +1,4 @@
import asyncio
import json

from typing import List
@@ -11,12 +12,17 @@ import maubot_llmplus
import maubot_llmplus.platforms
from maubot_llmplus.platforms import Platform, ChatCompletion
from maubot_llmplus.plugin import AbsExtraConfigPlugin
from maubot_llmplus.thrid_platform import _read_openai_sse


class Ollama(Platform):

    def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
        super().__init__(config, http)
        self.streaming = self.config.get('streaming', False)

    def is_streaming_enabled(self) -> bool:
        return self.streaming
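    # Note: the dispatch site is outside this diff, but the API shape suggests
    # callers check is_streaming_enabled() to choose between
    # create_chat_completion() and create_chat_completion_stream().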

    async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
        full_context = []
@@ -27,20 +33,52 @@ class Ollama(Platform):
        req_body = {'model': self.model, 'messages': full_context, 'stream': False}
        headers = {'Content-Type': 'application/json'}
        async with self.http.post(endpoint, headers=headers, json=req_body) as response:
            # plugin.log.debug(f"Response: {response.status}, {await response.json()}")
            if response.status != 200:
                return ChatCompletion(
                    result=False,
                    message={},
                    finish_reason=f"http status {response.status}",
                    model=None
                )
            response_json = await response.json()
            return ChatCompletion(
                result=True,
                message=response_json['message'],
                finish_reason='success',
                model=response_json['model']
            )

    async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
        full_context = []
        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
        full_context.extend(list(context))

        endpoint = f"{self.url}/api/chat"
        req_body = {'model': self.model, 'messages': full_context, 'stream': True}
        headers = {'Content-Type': 'application/json'}
        async with self.http.post(endpoint, headers=headers, json=req_body) as response:
            if response.status != 200:
                raise ValueError(f"Error: http status {response.status}")
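            # Ollama streams NDJSON rather than SSE: one JSON object per line,
            # with a final object whose "done" field is true, so each line can
            # be fed straight to json.loads.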
            while True:
                try:
                    line_bytes = await asyncio.wait_for(response.content.readline(), timeout=60.0)
                except asyncio.TimeoutError:
                    break
                if not line_bytes:
                    break
                line = line_bytes.decode("utf-8").strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                    if data.get("done"):
                        break
                    content = data.get("message", {}).get("content", "")
                    if content:
                        yield content
                except json.JSONDecodeError:
                    pass

    async def list_models(self) -> List[str]:
        full_url = f"{self.url}/api/tags"
        async with self.http.get(full_url) as response:
@@ -53,13 +91,16 @@ class Ollama(Platform):
        return "local_ai"


class LmStudio(Platform) :
class LmStudio(Platform):
    temperature: int

    def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
        super().__init__(config, http)
        self.temperature = self.config['temperature']
        pass
        self.streaming = self.config.get('streaming', False)

    def is_streaming_enabled(self) -> bool:
        return self.streaming

    async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
        full_context = []
@@ -72,9 +113,9 @@ class LmStudio(Platform) :
        async with self.http.post(
            endpoint, headers=headers, data=json.dumps(req_body)
        ) as response:
            # plugin.log.debug(f"Response: {response.status}, {await response.json()}")
            if response.status != 200:
                return ChatCompletion(
                    result=False,
                    message={},
                    finish_reason=f"Error: {await response.text()}",
                    model=None
@@ -82,11 +123,26 @@ class LmStudio(Platform) :
            response_json = await response.json()
            choice = response_json["choices"][0]
            return ChatCompletion(
                result=True,
                message=choice["message"],
                finish_reason=choice["finish_reason"],
                model=choice.get("model", None)
            )

    async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
        full_context = []
        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
        full_context.extend(list(context))

        endpoint = f"{self.url}/v1/chat/completions"
        headers = {"content-type": "application/json"}
        req_body = {"model": self.model, "messages": full_context, "temperature": self.temperature, "stream": True}
        async with self.http.post(endpoint, headers=headers, data=json.dumps(req_body)) as response:
            if response.status != 200:
                raise ValueError(f"Error: {await response.text()}")
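            # LM Studio exposes an OpenAI-compatible /v1/chat/completions API,
            # so the shared _read_openai_sse helper can be reused unchanged.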
            async for chunk in _read_openai_sse(response):
                yield chunk

    async def list_models(self) -> List[str]:
        full_url = f"{self.url}/v1/models"
        async with self.http.get(full_url) as response:

@@ -13,10 +13,40 @@ from maubot_llmplus.platforms import Platform, ChatCompletion
from maubot_llmplus.plugin import AbsExtraConfigPlugin


async def _read_openai_sse(response):
    """Read an OpenAI-compatible SSE stream and yield each delta's content."""
    while True:
        try:
            line_bytes = await asyncio.wait_for(response.content.readline(), timeout=60.0)
        except asyncio.TimeoutError:
            break
        if not line_bytes:
            break
        line = line_bytes.decode("utf-8").strip()
        if not line.startswith("data: "):
            continue
        data_str = line[6:]
        if data_str == "[DONE]":
            break
        try:
            data = json.loads(data_str)
            choices = data.get("choices", [])
            if choices:
                content = choices[0].get("delta", {}).get("content", "")
                if content:
                    yield content
        except json.JSONDecodeError:
            pass


class Deepseek(Platform):

    def __init__(self, config: BaseProxyConfig, http: ClientSession):
        super().__init__(config, http)
        self.streaming = self.config.get('streaming', False)

    def is_streaming_enabled(self) -> bool:
        return self.streaming

    async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
        full_context = []
@@ -52,6 +82,28 @@ class Deepseek(Platform):
            model=response_json.get("model", None)
        )

    async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
        full_context = []
        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
        full_context.extend(list(context))

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }
        data = {
            "model": self.model,
            "messages": full_context,
            "stream": True,
        }

        endpoint = f"{self.url}/chat/completions"
        async with self.http.post(endpoint, headers=headers, data=json.dumps(data)) as response:
            if response.status != 200:
                raise ValueError(f"Error: {await response.text()}")
            async for chunk in _read_openai_sse(response):
                yield chunk
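        # Deepseek's chat API is OpenAI-compatible, so _read_openai_sse applies
        # here too; the bare /chat/completions path assumes self.url already
        # carries any version prefix.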

    async def list_models(self) -> List[str]:
        models = ["deepseek-chat", "deepseek-reasoner"]
        return [f"- {m}" for m in models]
@@ -67,6 +119,10 @@ class OpenAi(Platform):
        super().__init__(config, http)
        self.max_tokens = self.config['max_tokens']
        self.temperature = self.config['temperature']
        self.streaming = self.config.get('streaming', False)

    def is_streaming_enabled(self) -> bool:
        return self.streaming

    async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
        full_context = []
@@ -97,7 +153,6 @@ class OpenAi(Platform):
        async with self.http.post(
            endpoint, headers=headers, data=json.dumps(data)
        ) as response:
            # plugin.log.debug(f"Response: {response.status}, {await response.json()}")
            if response.status != 200:
                return ChatCompletion(
                    result=False,
@@ -114,6 +169,37 @@ class OpenAi(Platform):
            model=choice.get("model", None)
        )

    async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
        full_context = []
        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
        full_context.extend(list(context))

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }
        data = {
            "model": self.model,
            "messages": full_context,
            "stream": True,
        }

        if 'max_tokens' in self.config and self.max_tokens:
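            # Newer OpenAI models (the gpt-5 family here) take
            # "max_completion_tokens" in place of the older "max_tokens" parameter.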
            if 'gpt-5' in self.model:
                data["max_completion_tokens"] = self.max_tokens
            else:
                data["max_tokens"] = self.max_tokens

        if 'temperature' in self.config and self.temperature:
            data["temperature"] = self.temperature

        endpoint = f"{self.url}/v1/chat/completions"
        async with self.http.post(endpoint, headers=headers, data=json.dumps(data)) as response:
            if response.status != 200:
                raise ValueError(f"Error: {await response.text()}")
            async for chunk in _read_openai_sse(response):
                yield chunk

    async def list_models(self) -> List[str]:
        # Call the OpenAI API to fetch the model list
        full_url = f"{self.url}/v1/models"
@@ -156,7 +242,6 @@ class Anthropic(Platform):
        endpoint, headers, req_body = self._build_request(full_chat_context)

        async with self.http.post(endpoint, headers=headers, data=json.dumps(req_body)) as response:
            # plugin.log.debug(f"Response: {response.status}, {await response.json()}")
            if response.status != 200:
                return ChatCompletion(
                    result=False,
@@ -208,7 +293,6 @@ class Anthropic(Platform):
        pass

    async def list_models(self) -> List[str]:
        # Call the OpenAI-style API to fetch the model list
        full_url = f"{self.url}/v1/models"
        headers = {
            'anthropic-version': "2023-06-01",
@@ -232,10 +316,12 @@ class Gemini(Platform):
        super().__init__(config, http)
        self.max_tokens = self.config['max_tokens']
        self.temperature = self.config['temperature']
        self.streaming = self.config.get('streaming', False)

    async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
    def is_streaming_enabled(self) -> bool:
        return self.streaming

    def _build_gemini_request(self, context) -> tuple:
        system_parts = []
        contents = []
        for msg in context:
@@ -262,12 +348,17 @@ class Gemini(Platform):
        if self.temperature:
            request_body["generationConfig"]["temperature"] = self.temperature

        endpoint = f"{self.url}/v1beta/models/{self.model}:generateContent"
        headers = {
            "Content-Type": "application/json",
            "x-goog-api-key": self.api_key
        }
        return request_body, headers
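    # _build_gemini_request is shared by the non-streaming and streaming paths;
    # only the endpoint (generateContent vs. streamGenerateContent) differs.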

    async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
        request_body, headers = self._build_gemini_request(context)

        endpoint = f"{self.url}/v1beta/models/{self.model}:generateContent"
        async with self.http.post(endpoint, headers=headers, data=json.dumps(request_body)) as response:
            if response.status != 200:
                return ChatCompletion(
@@ -286,6 +377,37 @@ class Gemini(Platform):
            model=response_json.get("modelVersion", self.model)
        )

    async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
        request_body, headers = self._build_gemini_request(context)

        # alt=sse is needed for "data:" framing; without it the API streams a JSON array
        endpoint = f"{self.url}/v1beta/models/{self.model}:streamGenerateContent?alt=sse"
        async with self.http.post(endpoint, headers=headers, data=json.dumps(request_body)) as response:
            if response.status != 200:
                raise ValueError(f"Error: {await response.text()}")
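            # Each SSE event carries a GenerateContentResponse chunk; the
            # incremental text lives in candidates[0].content.parts[*].text.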
            while True:
                try:
                    line_bytes = await asyncio.wait_for(response.content.readline(), timeout=60.0)
                except asyncio.TimeoutError:
                    break
                if not line_bytes:
                    break
                line = line_bytes.decode("utf-8").strip()
                if not line.startswith("data: "):
                    continue
                data_str = line[6:]
                try:
                    data = json.loads(data_str)
                    candidates = data.get("candidates", [])
                    if candidates:
                        parts = candidates[0].get("content", {}).get("parts", [])
                        for part in parts:
                            text = part.get("text", "")
                            if text:
                                yield text
                except json.JSONDecodeError:
                    pass

    async def list_models(self) -> List[str]:
        full_url = f"{self.url}/v1beta/models"
        headers = {"x-goog-api-key": self.api_key}
@@ -311,6 +433,10 @@ class XAi(Platform):
        super().__init__(config, http)
        self.temperature = self.config['temperature']
        self.max_tokens = self.config['max_tokens']
        self.streaming = self.config.get('streaming', False)

    def is_streaming_enabled(self) -> bool:
        return self.streaming

    async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
        full_context = []
@@ -335,7 +461,6 @@ class XAi(Platform):

        endpoint = f"{self.url}/v1/chat/completions"
        async with self.http.post(url=endpoint, data=json.dumps(request_body), headers=headers) as response:
            # plugin.log.debug(f"Response: {response.status}, {await response.json()}")
            if response.status != 200:
                return ChatCompletion(
                    result=False,
@@ -352,10 +477,35 @@ class XAi(Platform):
            model=response_json["model"]
        )

        pass
    async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
        full_context = []
        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
        full_context.extend(list(context))

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }
        request_body = {
            "messages": full_context,
            "model": self.model,
            "stream": True,
        }

        if 'max_tokens' in self.config and self.max_tokens:
            request_body["max_tokens"] = self.max_tokens

        if 'temperature' in self.config and self.temperature:
            request_body["temperature"] = self.temperature

        endpoint = f"{self.url}/v1/chat/completions"
        async with self.http.post(url=endpoint, data=json.dumps(request_body), headers=headers) as response:
            if response.status != 200:
                raise ValueError(f"Error: {await response.text()}")
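            # xAI's chat API follows the OpenAI wire format, so the shared SSE
            # reader is reused here as well.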
            async for chunk in _read_openai_sse(response):
                yield chunk

    async def list_models(self) -> List[str]:
        # Call the OpenAI-style API to fetch the model list
        full_url = f"{self.url}/v1/models"
        headers = {'Content-Type': 'application/json', 'Authorization': f"Bearer {self.api_key}"}
        async with self.http.get(full_url, headers=headers) as response:
@@ -363,7 +513,6 @@ class XAi(Platform):
                return []
            response_data = await response.json()
            return [f"- {m['id']}" for m in response_data["data"]]
        pass

    def get_type(self) -> str:
        return "xai"
@@ -381,12 +530,12 @@ class Qwen(Platform):
        self.temperature = self.config['temperature']
        self.top_p = self.config['top_p']
        self.enable_thinking = self.config['enable_thinking']
        self.streaming = self.config.get('streaming', False)

    async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
        full_context = []
        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
        full_context.extend(list(context))
    def is_streaming_enabled(self) -> bool:
        return self.streaming

    def _build_qwen_request(self, full_context: list) -> tuple:
        parameters = {
            "result_format": "message"
        }
@@ -412,6 +561,14 @@ class Qwen(Platform):
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }
        return endpoint, headers, request_body

    async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
        full_context = []
        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
        full_context.extend(list(context))

        endpoint, headers, request_body = self._build_qwen_request(full_context)

        async with self.http.post(endpoint, headers=headers, data=json.dumps(request_body)) as response:
            if response.status != 200:
@@ -430,6 +587,40 @@ class Qwen(Platform):
            model=response_json.get("model", self.model)
        )

    async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
        full_context = []
        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
        full_context.extend(list(context))

        endpoint, headers, request_body = self._build_qwen_request(full_context)
        # DashScope SSE streaming: add the header and the incremental_output
        # parameter (each event then returns only the increment)
        headers["X-DashScope-SSE"] = "enable"
        request_body["parameters"]["incremental_output"] = True

        async with self.http.post(endpoint, headers=headers, data=json.dumps(request_body)) as response:
            if response.status != 200:
                raise ValueError(f"Error: {await response.text()}")
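            # DashScope SSE events also include "id:" and "event:" framing
            # lines; keying on the bare "data:" prefix (no space assumed, hence
            # the strip below) keeps only the payload lines.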
            while True:
                try:
                    line_bytes = await asyncio.wait_for(response.content.readline(), timeout=60.0)
                except asyncio.TimeoutError:
                    break
                if not line_bytes:
                    break
                line = line_bytes.decode("utf-8").strip()
                if not line.startswith("data:"):
                    continue
                data_str = line[5:].strip()
                try:
                    data = json.loads(data_str)
                    choices = data.get("output", {}).get("choices", [])
                    if choices:
                        content = choices[0].get("message", {}).get("content", "")
                        if content:
                            yield content
                except json.JSONDecodeError:
                    pass

    async def list_models(self) -> List[str]:
        models = [
            "qwen-max", "qwen-max-latest",
@@ -442,4 +633,4 @@ class Qwen(Platform):
        return [f"- {m}" for m in models]

    def get_type(self) -> str:
        return "qwen"
        return "qwen"