diff --git a/base-config.yaml b/base-config.yaml index 0ecf25e..fba423e 100644 --- a/base-config.yaml +++ b/base-config.yaml @@ -40,6 +40,11 @@ platforms: max_words: 1000 max_tokens: 2000 max_context_messages: 20 + xai: + url: curl https://api.x.ai + api_key: + model: grok-beta + temperature: 1 # additional prompt additional_prompt: diff --git a/maubot_llmplus/aibot.py b/maubot_llmplus/aibot.py index 8e397ac..a69cbc5 100644 --- a/maubot_llmplus/aibot.py +++ b/maubot_llmplus/aibot.py @@ -10,7 +10,7 @@ from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper from maubot_llmplus.local_paltform import Ollama, LmStudio from maubot_llmplus.platforms import Platform from maubot_llmplus.plugin import AbsExtraConfigPlugin, Config -from maubot_llmplus.thrid_platform import OpenAi, Anthropic +from maubot_llmplus.thrid_platform import OpenAi, Anthropic, XAi class AiBotPlugin(AbsExtraConfigPlugin): @@ -122,6 +122,8 @@ class AiBotPlugin(AbsExtraConfigPlugin): return OpenAi(self.config, self.http) if use_platform == 'anthropic': return Anthropic(self.config, self.http) + if use_platform == 'xai': + return XAi(self.config, self.http) if use_platform == 'local_ai#ollama': return Ollama(self.config, self.http) if use_platform == 'local_ai#lmstudio': diff --git a/maubot_llmplus/thrid_platform.py b/maubot_llmplus/thrid_platform.py index ffc3ced..7e6def3 100644 --- a/maubot_llmplus/thrid_platform.py +++ b/maubot_llmplus/thrid_platform.py @@ -13,7 +13,6 @@ from maubot_llmplus.plugin import AbsExtraConfigPlugin class OpenAi(Platform): - max_tokens: int temperature: int @@ -90,7 +89,8 @@ class Anthropic(Platform): endpoint = f"{self.url}/v1/messages" headers = {"x-api-key": self.api_key, "anthropic-version": "2023-06-01", "content-type": "application/json"} - req_body = {"model": self.model, "max_tokens": self.max_tokens, "system": self.system_prompt, "messages": full_chat_context} + req_body = {"model": self.model, "max_tokens": self.max_tokens, "system": self.system_prompt, + "messages": full_chat_context} async with self.http.post(endpoint, headers=headers, data=json.dumps(req_body)) as response: # plugin.log.debug(f"响应内容:{response.status}, {await response.json()}") @@ -111,8 +111,65 @@ class Anthropic(Platform): async def list_models(self) -> List[str]: # 由于没有列出所有支持的模型的api,所有只能写死在代码中 - models = ["claude-3-5-sonnet-20240620", "claude-3-opus-20240229 ", "claude-3-sonnet-20240229", "claude-3-haiku-20240307"] + models = ["claude-3-5-sonnet-20240620", "claude-3-opus-20240229 ", "claude-3-sonnet-20240229", + "claude-3-haiku-20240307"] return [f"- {m}" for m in models] def get_type(self) -> str: return "anthropic" + + +class XAi(Platform): + + def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None: + super().__init__(config, http) + + def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion: + full_context = [] + context = await maubot_llmplus.platforms.get_context(plugin, self, evt) + full_context.extend(list(context)) + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}" + } + request_body = { + "message": full_context, + "model": self.model, + "stream": False + } + + if 'temperature' in self.config and self.temperature: + request_body["temperature"] = self.temperature + + endpoint = f"{self.url}/v1/chat/completions" + with self.http.post(url=endpoint, data=json.dumps(request_body), headers=headers) as resp: + # plugin.log.debug(f"响应内容:{response.status}, {await response.json()}") + if response.status != 200: + return ChatCompletion( + message={}, + finish_reason=f"Error: {await response.text()}", + model=None + ) + response_json = await response.json() + choice = response_json["choices"][0] + return ChatCompletion( + message=choice["message"], + finish_reason=choice["finish_reason"], + model=response_json["model"] + ) + + pass + + def list_models(self) -> List[str]: + # 调用openai接口获取模型列表 + full_url = f"{self.url}/v1/models" + async with self.http.get(full_url) as response: + if response.status != 200: + return [] + response_data = await response.json() + return [f"- {m['id']}" for m in response_data["models"]] + pass + + def get_type(self) -> str: + return "xai" \ No newline at end of file