add: 添加命令ai model list的逻辑

This commit is contained in:
taylor
2024-10-13 18:53:42 +08:00
parent 7e1af58c84
commit dc5162b662
3 changed files with 36 additions and 0 deletions

View File

@@ -148,6 +148,30 @@ class AiBotPlugin(Plugin):
return Anthropic(self.config, self.name, self.http) return Anthropic(self.config, self.name, self.http)
raise ValueError(f"unknown backend type {use_platform}") raise ValueError(f"unknown backend type {use_platform}")
"""
父命令
"""
@command.new(name="ai", require_subcommand=True)
async def ai_command(self, event: MessageEvent) -> None:
pass
"""
"""
@ai_command.subcommand(help="")
async def info(self, event: MessageEvent) -> None:
pass
@ai_command.subcommand(help="")
@command.argument("argus")
async def model(self, event: MessageEvent, argus: str):
# 如果是list表示查看当前可以使用的模型列表
if argus == 'list':
platform = self.get_ai_platform()
models = platform.list_models()
await event.reply("\n".join(models))
# 如果不是,如果是其他的名称,表示这是一个模型名
@classmethod @classmethod
def get_config_class(cls) -> Type[BaseProxyConfig]: def get_config_class(cls) -> Type[BaseProxyConfig]:
return Config return Config

View File

@@ -1,4 +1,5 @@
import json import json
from typing import List
from aiohttp import ClientSession from aiohttp import ClientSession
from maubot import Plugin from maubot import Plugin
@@ -40,6 +41,14 @@ class Ollama(Platform):
model=response_json['model'] model=response_json['model']
) )
async def list_models(self) -> List[str]:
full_url = f"{self.url}/api/tags"
async with self.http.get(full_url) as response:
if response.status != 200:
return []
response_data = json.loads(await response.json())
return [model['name'] for model in response_data]
def get_type(self) -> str: def get_type(self) -> str:
return "local_ai" return "local_ai"

View File

@@ -54,6 +54,9 @@ class Platform:
async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion:
raise NotImplementedError() raise NotImplementedError()
async def list_models(self) -> List[str]:
raise NotImplementedError()
def get_type(self) -> str: def get_type(self) -> str:
raise NotImplementedError() raise NotImplementedError()