add: 增加切换ai平台命令

This commit is contained in:
taylor
2024-10-14 13:46:32 +08:00
parent 69fa0c0a50
commit 400d628e9f
2 changed files with 46 additions and 29 deletions

View File

@@ -117,20 +117,17 @@ class AiBotPlugin(AbsExtraConfigPlugin):
return None return None
def get_ai_platform(self) -> Platform:
    """Return the chat backend selected by the current platform setting.

    Reads ``self.config.cur_platform`` (maintained by the ``switch``
    command and by the config update hook) and maps it onto a concrete
    ``Platform`` implementation.

    Returns:
        Platform: one of OpenAi / Anthropic / Ollama / LmStudio.

    Raises:
        ValueError: when the configured platform name is unknown.
    """
    use_platform = self.config.cur_platform
    if use_platform == 'openai':
        return OpenAi(self.config, self.http)
    if use_platform == 'anthropic':
        return Anthropic(self.config, self.http)
    if use_platform == 'local_ai#ollama':
        return Ollama(self.config, self.http)
    # BUG FIX: the switch command stores 'local_ai#lmstudio', but this
    # branch previously compared against plain 'lmstudio' and could never
    # match; accept both spellings for backward compatibility.
    if use_platform in ('local_ai#lmstudio', 'lmstudio'):
        return LmStudio(self.config, self.http)
    # BUG FIX: the old message interpolated ``type`` — after this commit
    # that name is only the builtin, not a local — report the unmatched
    # platform name instead.
    raise ValueError(f"not found platform type: {use_platform}")
""" """
父命令 父命令
@@ -147,29 +144,46 @@ class AiBotPlugin(AbsExtraConfigPlugin):
@ai_command.subcommand(help="")
@command.argument("argus")
async def model(self, event: MessageEvent, argus: str):
    """Subcommand: on ``#list``, reply with the models offered by the
    currently selected AI platform (markdown, one per line)."""
    if argus != '#list':
        return
    backend = self.get_ai_platform()
    names = await backend.list_models()
    await event.reply("\n".join(names), markdown=True)
@ai_command.subcommand(help="")
@command.argument("argus")
async def use(self, event: MessageEvent, argus: str):
    """Subcommand: switch the active model to *argus*.

    Only switches when the platform's model list contains the name;
    otherwise replies with an error message.
    """
    platform = self.get_ai_platform()
    # BUG FIX: list_models() is a coroutine (it is awaited in the
    # ``model`` command); without ``await`` the membership test below
    # would run against a coroutine object and raise TypeError.
    models = await platform.list_models()
    # The model listing renders entries as "- <name>"; match that form.
    if f"- {argus}" in models:
        self.log.debug(f"switch model: {argus}")
        # BUG FIX: this commit renames the Config field
        # ``_cur_model`` -> ``cur_model``; write the new name.
        self.config.cur_model = argus
        await event.react("")
    else:
        await event.reply("not found valid model")
@ai_command.subcommand(help="")
@command.argument("argus")
async def switch(self, event: MessageEvent, argus: str):
    """Subcommand: switch the active AI platform.

    Accepted names: 'openai', 'anthropic', 'local_ai#ollama',
    'local_ai#lmstudio' (local platforms carry the backend type after
    '#'). Replies when the platform is already active or unknown;
    reacts on a successful switch.
    """
    known = ('openai', 'anthropic', 'local_ai#ollama', 'local_ai#lmstudio')
    if argus not in known:
        # BUG FIX: reply() is a coroutine and was not awaited; also
        # fixes the 'nof found' typo in the user-facing message.
        await event.reply(f"not found ai platform: {argus}")
        return
    # BUG FIX: the local_ai branch compared only the part after '#'
    # (e.g. 'ollama') against cur_platform, which stores the full
    # 'local_ai#<type>' string — "already selected" was never detected.
    # Comparing the full name works uniformly for all platforms, so the
    # duplicated openai/anthropic branch is folded in here as well.
    if argus == self.config.cur_platform:
        await event.reply(f"current ai platform has been {argus}")
    else:
        self.config.cur_platform = argus
        await event.react("")
@classmethod @classmethod
def get_config_class(cls) -> Type[BaseProxyConfig]: def get_config_class(cls) -> Type[BaseProxyConfig]:

View File

@@ -23,7 +23,8 @@ class AbsExtraConfigPlugin(Plugin):
class Config(BaseProxyConfig): class Config(BaseProxyConfig):
_cur_model: str cur_model: str
cur_platform: str
def do_update(self, helper: ConfigUpdateHelper) -> None: def do_update(self, helper: ConfigUpdateHelper) -> None:
helper.copy("allowed_users") helper.copy("allowed_users")
@@ -35,4 +36,6 @@ class Config(BaseProxyConfig):
helper.copy("platforms") helper.copy("platforms")
helper.copy("additional_prompt") helper.copy("additional_prompt")
self._cur_model = helper.base['platforms'][helper.base['use_platform']]['model'] self.cur_platform = helper.base['use_platform'] if helper.base['use_platform'] != 'local_ai' else \
f"{helper.base['use_platform']}#{helper.base['platforms'][helper.base['local_ai']['type']]}"
self.cur_model = helper.base['platforms'][helper.base['use_platform']]['model']