diff --git a/base-config.yaml b/base-config.yaml index 554a448..21a4a4a 100644 --- a/base-config.yaml +++ b/base-config.yaml @@ -19,12 +19,13 @@ platforms: max_words: 1000 max_context_messages: 20 openai: - url: + url: https://api.openai.com api_key: - model: - max_tokens: - max_words: - temperature: + model: gpt-4o + max_tokens: 2000 + max_words: 1000 + max_context_messages: 20 + temperature: 1 anthropic: url: api_key: diff --git a/maubot_llmplus/local_paltform.py b/maubot_llmplus/local_paltform.py index fb14e21..c0deac2 100644 --- a/maubot_llmplus/local_paltform.py +++ b/maubot_llmplus/local_paltform.py @@ -1,14 +1,15 @@ -import json + from typing import List from aiohttp import ClientSession -from maubot import Plugin + from mautrix.types import MessageEvent from mautrix.util.config import BaseProxyConfig import maubot_llmplus import maubot_llmplus.platforms from maubot_llmplus.platforms import Platform, ChatCompletion +from maubot_llmplus.plugin import AbsExtraConfigPlugin class Ollama(Platform): @@ -18,7 +19,7 @@ class Ollama(Platform): super().__init__(config, http) self.chat_api = '/api/chat' - async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: + async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion: full_context = [] context = await maubot_llmplus.platforms.get_context(plugin, self, evt) full_context.extend(list(context)) @@ -59,5 +60,5 @@ class LmStudio(Platform): super().__init__(config, http) pass - async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: + async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion: pass diff --git a/maubot_llmplus/platforms.py b/maubot_llmplus/platforms.py index 89291c4..e907a33 100644 --- a/maubot_llmplus/platforms.py +++ b/maubot_llmplus/platforms.py @@ -51,7 +51,7 @@ class Platform: 调用AI对话接口, 响应结果 """ - async def create_chat_completion(self, 
plugin: Plugin, evt: MessageEvent) -> ChatCompletion: + async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion: raise NotImplementedError() async def list_models(self) -> List[str]: diff --git a/maubot_llmplus/thrid_platform.py b/maubot_llmplus/thrid_platform.py index 8dd03cb..f6cd84d 100644 --- a/maubot_llmplus/thrid_platform.py +++ b/maubot_llmplus/thrid_platform.py @@ -1,21 +1,62 @@ +import json + from aiohttp import ClientSession -from maubot import Plugin from mautrix.types import MessageEvent from mautrix.util.config import BaseProxyConfig +import maubot_llmplus.platforms from maubot_llmplus.platforms import Platform, ChatCompletion +from maubot_llmplus.plugin import AbsExtraConfigPlugin class OpenAi(Platform): + max_tokens: int + temperature: int + def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None: super().__init__(config, http) + self.max_tokens = self.config['max_tokens'] + self.temperature = self.config['temperature'] - async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: - # 获取系统提示词 - # 获取额外的其他角色的提示词: role: user role: system + async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion: + full_context = [] + context = await maubot_llmplus.platforms.get_context(plugin, self, evt) + full_context.extend(list(context)) - pass + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.config['gpt_api_key']}" + } + data = { + "model": self.model, + "messages": full_context, + } + + if 'max_tokens' in self.config and self.max_tokens: + data["max_tokens"] = self.max_tokens + + if 'temperature' in self.config and self.temperature: + data["temperature"] = self.temperature + + endpoint = f"{self.url}/v1/chat/completions" + async with self.http.post( + endpoint, headers=headers, data=json.dumps(data) + ) as response: + # plugin.log.debug(f"响应内容:{response.status}, {await 
response.json()}") + if response.status != 200: + return ChatCompletion( + message={}, + finish_reason=f"Error: {await response.text()}", + model=None + ) + response_json = await response.json() + choice = response_json["choices"][0] + return ChatCompletion( + message=choice["message"], + finish_reason=choice["finish_reason"], + model=choice.get("model", None) + ) def get_type(self) -> str: return "openai" @@ -26,7 +67,7 @@ class Anthropic(Platform): def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None: super().__init__(config, http) - async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: + async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion: # 获取系统提示词 # 获取额外的其他角色的提示词: role: user role: system