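"""Local AI platform backends for maubot_llmplus.

Implements two Platform backends: Ollama (native /api/chat API) and LmStudio
(OpenAI-compatible /v1/chat/completions API), each with blocking and streaming
chat completion paths.
"""
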
import asyncio
import json
from typing import List, Optional

from aiohttp import ClientSession
from mautrix.types import MessageEvent
from mautrix.util.config import BaseProxyConfig

import maubot_llmplus.platforms
from maubot_llmplus.platforms import Platform, ChatCompletion
from maubot_llmplus.plugin import AbsExtraConfigPlugin
from maubot_llmplus.thrid_platform import _read_openai_sse


class Ollama(Platform):
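    """Chat completion backend for a local Ollama server.

    Uses Ollama's native HTTP API: /api/chat for chat completions (blocking or
    NDJSON streaming) and /api/tags for listing installed models.
    """
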
    def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
        super().__init__(config, http)
        self.streaming = self.config.get('streaming', False)

    def is_streaming_enabled(self) -> bool:
        return self.streaming

    async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
        # Build the conversation context from the Matrix room history.
        full_context = []
        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
        full_context.extend(list(context))

        endpoint = f"{self.url}/api/chat"
        req_body = {'model': self.model, 'messages': full_context, 'stream': False}
        headers = {'Content-Type': 'application/json'}
        async with self.http.post(endpoint, headers=headers, json=req_body) as response:
            if response.status != 200:
                return ChatCompletion(
                    result=False,
                    message={},
                    finish_reason=f"http status {response.status}",
                    model=None
                )
            response_json = await response.json()
            return ChatCompletion(
                result=True,
                message=response_json['message'],
                finish_reason='success',
                model=response_json['model']
            )

    async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
        full_context = []
        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
        full_context.extend(list(context))

        endpoint = f"{self.url}/api/chat"
        req_body = {'model': self.model, 'messages': full_context, 'stream': True}
        headers = {'Content-Type': 'application/json'}
        async with self.http.post(endpoint, headers=headers, json=req_body) as response:
            if response.status != 200:
                raise ValueError(f"Error: http status {response.status}")
            # Ollama streams newline-delimited JSON objects; read until the
            # server reports "done" or the connection goes quiet.
            while True:
                try:
                    line_bytes = await asyncio.wait_for(response.content.readline(), timeout=60.0)
                except asyncio.TimeoutError:
                    break
                if not line_bytes:
                    break
                line = line_bytes.decode("utf-8").strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                    if data.get("done"):
                        break
                    content = data.get("message", {}).get("content", "")
                    if content:
                        yield content
                except json.JSONDecodeError:
                    # Skip malformed lines instead of aborting the whole stream.
                    pass

    async def list_models(self) -> List[str]:
        full_url = f"{self.url}/api/tags"
        async with self.http.get(full_url) as response:
            if response.status != 200:
                return []
            response_data = await response.json()
            return [f"- {model['model']}" for model in response_data['models']]

    def get_type(self) -> str:
        return "local_ai"


class LmStudio(Platform):
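    """Chat completion backend for a local LM Studio server.

    Uses LM Studio's OpenAI-compatible HTTP API: /v1/chat/completions for chat
    completions (SSE when streaming) and /v1/models for listing loaded models.
    """
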
    temperature: Optional[float]

    def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
        super().__init__(config, http)
        # Temperature is optional; leave it unset (None) when not configured.
        raw_temperature = self.config.get('temperature')
        self.temperature = float(raw_temperature) if raw_temperature is not None else None
        self.streaming = self.config.get('streaming', False)

    def is_streaming_enabled(self) -> bool:
        return self.streaming

    async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
        # Build the conversation context from the Matrix room history.
        full_context = []
        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
        full_context.extend(list(context))

        endpoint = f"{self.url}/v1/chat/completions"
        headers = {"content-type": "application/json"}
        req_body = {"model": self.model, "messages": full_context, "temperature": self.temperature, "stream": False}
        async with self.http.post(
            endpoint, headers=headers, data=json.dumps(req_body)
        ) as response:
            if response.status != 200:
                return ChatCompletion(
                    result=False,
                    message={},
                    finish_reason=f"Error: {await response.text()}",
                    model=None
                )
            response_json = await response.json()
            choice = response_json["choices"][0]
            return ChatCompletion(
                result=True,
                message=choice["message"],
                finish_reason=choice["finish_reason"],
                # In OpenAI-style responses the model name is reported at the
                # top level of the payload, not on the individual choice.
                model=response_json.get("model")
            )

    async def create_chat_completion_stream(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent):
        full_context = []
        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
        full_context.extend(list(context))

        endpoint = f"{self.url}/v1/chat/completions"
        headers = {"content-type": "application/json"}
        req_body = {"model": self.model, "messages": full_context, "temperature": self.temperature, "stream": True}
        async with self.http.post(endpoint, headers=headers, data=json.dumps(req_body)) as response:
            if response.status != 200:
                raise ValueError(f"Error: {await response.text()}")
            # Delegate SSE parsing to the shared OpenAI-style stream reader.
            async for chunk in _read_openai_sse(response):
                yield chunk

    async def list_models(self) -> List[str]:
        full_url = f"{self.url}/v1/models"
        async with self.http.get(full_url) as response:
            if response.status != 200:
                return []
            response_data = await response.json()
            return [f"- {m['id']}" for m in response_data["data"]]

    def get_type(self) -> str:
        return "local_ai"