From dc5162b6625cf301d8429a8ee24f91bb519320de Mon Sep 17 00:00:00 2001 From: taylor Date: Sun, 13 Oct 2024 18:53:42 +0800 Subject: [PATCH] =?UTF-8?q?add:=20=E6=B7=BB=E5=8A=A0=E5=91=BD=E4=BB=A4ai?= =?UTF-8?q?=20model=20list=E7=9A=84=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- maubot_llmplus/aibot.py | 24 ++++++++++++++++++++++++ maubot_llmplus/local_paltform.py | 9 +++++++++ maubot_llmplus/platforms.py | 3 +++ 3 files changed, 36 insertions(+) diff --git a/maubot_llmplus/aibot.py b/maubot_llmplus/aibot.py index 01b8db3..e112d6a 100644 --- a/maubot_llmplus/aibot.py +++ b/maubot_llmplus/aibot.py @@ -148,6 +148,30 @@ class AiBotPlugin(Plugin): return Anthropic(self.config, self.name, self.http) raise ValueError(f"unknown backend type {use_platform}") + """ + 父命令 + """ + @command.new(name="ai", require_subcommand=True) + async def ai_command(self, event: MessageEvent) -> None: + pass + + """ + """ + @ai_command.subcommand(help="") + async def info(self, event: MessageEvent) -> None: + pass + + @ai_command.subcommand(help="") + @command.argument("argus") + async def model(self, event: MessageEvent, argus: str): + # 如果是list表示查看当前可以使用的模型列表 + if argus == 'list': + platform = self.get_ai_platform() + models = platform.list_models() + await event.reply("\n".join(models)) + + # 如果不是,如果是其他的名称,表示这是一个模型名 + @classmethod def get_config_class(cls) -> Type[BaseProxyConfig]: return Config diff --git a/maubot_llmplus/local_paltform.py b/maubot_llmplus/local_paltform.py index dd4d760..200b7d0 100644 --- a/maubot_llmplus/local_paltform.py +++ b/maubot_llmplus/local_paltform.py @@ -1,4 +1,5 @@ import json +from typing import List from aiohttp import ClientSession from maubot import Plugin @@ -40,6 +41,14 @@ class Ollama(Platform): model=response_json['model'] ) + async def list_models(self) -> List[str]: + full_url = f"{self.url}/api/tags" + async with self.http.get(full_url) as response: + if response.status != 200: + return [] + response_data = json.loads(await response.json()) + return [model['name'] for model in response_data] + def get_type(self) -> str: return "local_ai" diff --git a/maubot_llmplus/platforms.py b/maubot_llmplus/platforms.py index 63dbebf..9e4b2df 100644 --- a/maubot_llmplus/platforms.py +++ b/maubot_llmplus/platforms.py @@ -54,6 +54,9 @@ class Platform: async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: raise NotImplementedError() + async def list_models(self) -> List[str]: + raise NotImplementedError() + def get_type(self) -> str: raise NotImplementedError()