diff --git a/maubot_llmplus/aibot.py b/maubot_llmplus/aibot.py index 26ffef4..143bb85 100644 --- a/maubot_llmplus/aibot.py +++ b/maubot_llmplus/aibot.py @@ -117,20 +117,17 @@ class AiBotPlugin(AbsExtraConfigPlugin): return None def get_ai_platform(self) -> Platform: - use_platform = self.config['use_platform'] - if use_platform == 'local_ai': - type = self.config['platforms']['local_ai']['type'] - if type == 'ollama': - return Ollama(self.config, self.http) - elif type == 'lmstudio': - return LmStudio(self.config, self.http) - else: - raise ValueError(f"not found platform type: {type}") + use_platform = self.config.cur_platform if use_platform == 'openai': return OpenAi(self.config, self.http) if use_platform == 'anthropic': return Anthropic(self.config, self.http) - raise ValueError(f"unknown backend type {use_platform}") + if use_platform == 'local_ai#ollama': + return Ollama(self.config, self.http) + if use_platform == 'lmstudio': + return LmStudio(self.config, self.http) + else: + raise ValueError(f"not found platform type: {type}") """ 父命令 @@ -147,29 +144,46 @@ class AiBotPlugin(AbsExtraConfigPlugin): @ai_command.subcommand(help="") @command.argument("argus") - @command.argument("argus1") - async def model(self, event: MessageEvent, argus: str, argus1): + async def model(self, event: MessageEvent, argus: str): # 如果是list表示查看当前可以使用的模型列表 - if argus == 'list': + if argus == '#list': platform = self.get_ai_platform() models = await platform.list_models() await event.reply("\n".join(models), markdown=True) - # 如果不是,如果是其他的名称,表示这是一个模型名 - # 如果是use为第二命令,则表示要切换模型 - if argus.startswith('use'): - arg_elements = argus.strip().split(" ", 2) - # 如果命令小于2的个数,就没有写模型名,无法切换 - if len(arg_elements) < 2: - await event.reply("give me a model name after 'use' command", markdown=True) - platform = self.get_ai_platform() - models = platform.list_models() - if f"- {arg_elements[1]}" in models: - self.log.debug(f"switch model: {arg_elements[1]}") - self.config._cur_model = arg_elements[1] - await event.react("✅") + + @ai_command.subcommand(help="") + @command.argument("argus") + async def use(self, event: MessageEvent, argus: str): + platform = self.get_ai_platform() + # 获取模型列表,判断使用的模型是否存在于列表中 + models = platform.list_models() + if f"- {argus}" in models: + self.log.debug(f"switch model: {argus}") + self.config._cur_model = argus + await event.react("✅") + else: + await event.reply("not found valid model") + + @ai_command.subcommand(help="") + @command.argument("argus") + async def switch(self, event: MessageEvent, argus: str): + # 判断是否是本地ai模型,如果是还需要解析#后的type + if argus == 'local_ai#ollama' or argus == 'local_ai#lmstudio': + if argus.split('#')[1] == self.config.cur_platform: + event.reply(f"current ai platform has be {argus}") else: - await event.reply("not found valid model") + self.config.cur_platform = argus + await event.react("✅") + # 如果是openai或者是claude + elif argus == 'openai' or argus == 'anthropic': + if argus == self.config.cur_platform: + event.reply(f"current ai platform has be {argus}") + else: + self.config.cur_platform = argus + await event.react("✅") + else: + event.reply(f"nof found ai platform: {argus}") @classmethod def get_config_class(cls) -> Type[BaseProxyConfig]: diff --git a/maubot_llmplus/plugin.py b/maubot_llmplus/plugin.py index 494a902..22e17c9 100644 --- a/maubot_llmplus/plugin.py +++ b/maubot_llmplus/plugin.py @@ -23,7 +23,8 @@ class AbsExtraConfigPlugin(Plugin): class Config(BaseProxyConfig): - _cur_model: str + cur_model: str + cur_platform: str def do_update(self, helper: ConfigUpdateHelper) -> None: helper.copy("allowed_users") @@ -35,4 +36,6 @@ class Config(BaseProxyConfig): helper.copy("platforms") helper.copy("additional_prompt") - self._cur_model = helper.base['platforms'][helper.base['use_platform']]['model'] + self.cur_platform = helper.base['use_platform'] if helper.base['use_platform'] != 'local_ai' else \ + f"{helper.base['use_platform']}#{helper.base['platforms'][helper.base['local_ai']['type']]}" + self.cur_model = helper.base['platforms'][helper.base['use_platform']]['model']