add: 增加切换ai平台命令

This commit is contained in:
taylor
2024-10-14 13:46:32 +08:00
parent 69fa0c0a50
commit 400d628e9f
2 changed files with 46 additions and 29 deletions

View File

@@ -117,20 +117,17 @@ class AiBotPlugin(AbsExtraConfigPlugin):
return None
def get_ai_platform(self) -> Platform:
use_platform = self.config['use_platform']
if use_platform == 'local_ai':
type = self.config['platforms']['local_ai']['type']
if type == 'ollama':
return Ollama(self.config, self.http)
elif type == 'lmstudio':
return LmStudio(self.config, self.http)
else:
raise ValueError(f"not found platform type: {type}")
use_platform = self.config.cur_platform
if use_platform == 'openai':
return OpenAi(self.config, self.http)
if use_platform == 'anthropic':
return Anthropic(self.config, self.http)
raise ValueError(f"unknown backend type {use_platform}")
if use_platform == 'local_ai#ollama':
return Ollama(self.config, self.http)
if use_platform == 'lmstudio':
return LmStudio(self.config, self.http)
else:
raise ValueError(f"not found platform type: {type}")
"""
父命令
@@ -147,29 +144,46 @@ class AiBotPlugin(AbsExtraConfigPlugin):
@ai_command.subcommand(help="")
@command.argument("argus")
@command.argument("argus1")
async def model(self, event: MessageEvent, argus: str, argus1):
async def model(self, event: MessageEvent, argus: str):
# 如果是list表示查看当前可以使用的模型列表
if argus == 'list':
if argus == '#list':
platform = self.get_ai_platform()
models = await platform.list_models()
await event.reply("\n".join(models), markdown=True)
# 如果不是,如果是其他的名称,表示这是一个模型名
# 如果是use为第二命令,则表示要切换模型
if argus.startswith('use'):
arg_elements = argus.strip().split(" ", 2)
# 如果命令小于2的个数就没有写模型名无法切换
if len(arg_elements) < 2:
await event.reply("give me a model name after 'use' command", markdown=True)
platform = self.get_ai_platform()
models = platform.list_models()
if f"- {arg_elements[1]}" in models:
self.log.debug(f"switch model: {arg_elements[1]}")
self.config._cur_model = arg_elements[1]
await event.react("")
@ai_command.subcommand(help="")
@command.argument("argus")
async def use(self, event: MessageEvent, argus: str):
platform = self.get_ai_platform()
# 获取模型列表,判断使用的模型是否存在于列表中
models = platform.list_models()
if f"- {argus}" in models:
self.log.debug(f"switch model: {argus}")
self.config._cur_model = argus
await event.react("")
else:
await event.reply("not found valid model")
@ai_command.subcommand(help="")
@command.argument("argus")
async def switch(self, event: MessageEvent, argus: str):
# 判断是否是本地ai模型如果是还需要解析#后的type
if argus == 'local_ai#ollama' or argus == 'local_ai#lmstudio':
if argus.split('#')[1] == self.config.cur_platform:
event.reply(f"current ai platform has be {argus}")
else:
await event.reply("not found valid model")
self.config.cur_platform = argus
await event.react("")
# 如果是openai或者是claude
elif argus == 'openai' or argus == 'anthropic':
if argus == self.config.cur_platform:
event.reply(f"current ai platform has be {argus}")
else:
self.config.cur_platform = argus
await event.react("")
else:
event.reply(f"nof found ai platform: {argus}")
@classmethod
def get_config_class(cls) -> Type[BaseProxyConfig]: