add: add Ollama AI chat invocation logic

taylor
2024-10-13 17:07:42 +08:00
parent bab547198f
commit 5821623aae
4 changed files with 19 additions and 16 deletions
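
Summary: this commit threads the bot's resolved name through the platform layer. AiBotPlugin.start() now stores the name on self.name (fixing the previous super.name, which would have failed at runtime since super is the builtin type, not an attribute target), resolving it from the config value, the Matrix display name, or the user-ID localpart, in that order. Every Platform subclass (Ollama, LmStudio, OpenAi, Anthropic) takes the name as a new constructor argument, and get_context() now also receives the platform instance, presumably so the context builder can use platform.name when assembling the chat history.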

View File

@@ -15,7 +15,6 @@ from maubot_llmplus.thrid_platform import OpenAi, Anthropic
 Load the configuration file
 """
-
 class Config(BaseProxyConfig):
     def do_update(self, helper: ConfigUpdateHelper) -> None:
         helper.copy("allowed_users")
@@ -30,12 +29,14 @@ class Config(BaseProxyConfig):
 class AiBotPlugin(Plugin):
+    name: str
+
     async def start(self) -> None:
         await super().start()
         # Load and update the configuration
         self.config.load_and_update()
         # Decide the current bot's name
-        super.name = self.config['name'] or \
+        self.name = self.config['name'] or \
             await self.client.get_displayname(self.client.mxid) or \
             self.client.parse_user_id(self.client.mxid)[0]
@@ -136,15 +137,15 @@ class AiBotPlugin(Plugin):
         if use_platform == 'local_ai':
             type = self.config['platforms']['local_ai']['type']
             if type == 'ollama':
-                return Ollama(self.config, self.http)
+                return Ollama(self.config, self.name, self.http)
             elif type == 'lmstudio':
-                return LmStudio(self.config, self.http)
+                return LmStudio(self.config, self.name, self.http)
             else:
                 raise ValueError(f"not found platform type: {type}")
         if use_platform == 'openai':
-            return OpenAi(self.config, self.http)
+            return OpenAi(self.config, self.name, self.http)
         if use_platform == 'anthropic':
-            return Anthropic(self.config, self.http)
+            return Anthropic(self.config, self.name, self.http)
         raise ValueError(f"unknown backend type {use_platform}")

     @classmethod
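
The dispatch above implies a config shape like the following. This is a sketch reconstructed only from the lookups visible in this diff; key names not read here, and all values, are illustrative assumptions.

# Sketch of the config structure implied by the self.config[...] lookups above.
# Only 'name', 'platforms', 'local_ai', 'type', 'url', 'max_context_messages',
# 'additional_prompt', and 'system_prompt' appear in this diff; the values and
# any remaining fields are assumptions.
config = {
    "name": "llm-bot",                        # optional; falls back to display name
    "additional_prompt": [],                  # read by Platform.__init__ below
    "system_prompt": "You are a helpful assistant.",
    "platforms": {
        "local_ai": {
            "type": "ollama",                 # or "lmstudio"
            "url": "http://localhost:11434",  # illustrative Ollama default
            "max_context_messages": 20,
        },
        "openai": {"url": "https://api.openai.com"},
        "anthropic": {"url": "https://api.anthropic.com"},
    },
}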

View File

@@ -13,13 +13,13 @@ from maubot_llmplus.platforms import Platform, ChatCompletion
 class Ollama(Platform):
     chat_api: str

-    def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
-        super().__init__(config, http)
+    def __init__(self, config: BaseProxyConfig, name: str, http: ClientSession) -> None:
+        super().__init__(config, name, http)
         self.chat_api = '/api/chat'

     async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion:
         full_context = []
-        context = await maubot_llmplus.platforms.get_context(plugin, evt)
+        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
         full_context.extend(list(context))
         endpoint = f"{self.url}/api/chat"
@@ -47,8 +47,8 @@ class Ollama(Platform):
 class LmStudio(Platform):
-    def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
-        super().__init__(config, http)
+    def __init__(self, config: BaseProxyConfig, name: str, http: ClientSession) -> None:
+        super().__init__(config, name, http)
         pass

     async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion:
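
Both create_chat_completion bodies are cut off in this view. For the Ollama path, the request/response handling presumably continues roughly as below, following Ollama's documented /api/chat endpoint; self.model and the ChatCompletion constructor arguments are assumptions, not shown in this diff.

# Hedged sketch of the truncated remainder of Ollama.create_chat_completion.
# `self.model` and the ChatCompletion fields are assumed, not from the diff.
payload = {
    "model": self.model,         # model name, presumably from the platform config
    "messages": full_context,    # [{"role": ..., "content": ...}, ...]
    "stream": False,             # single JSON response instead of a stream
}
async with self.http.post(endpoint, json=payload) as resp:
    resp.raise_for_status()
    data = await resp.json()
# Ollama replies with {"message": {"role": "assistant", "content": ...}, "done": true, ...}
return ChatCompletion(
    message=data["message"],
    model=data.get("model"),
    finish_reason="stop" if data.get("done") else None,
)

LmStudio's version would differ mainly in the endpoint: LM Studio serves an OpenAI-compatible API, so it would POST the same messages list to {self.url}/v1/chat/completions and read choices[0].message instead.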

View File

@@ -33,8 +33,9 @@ class Platform:
     additional_prompt: List[dict]
     system_prompt: str
     max_context_messages: int
+    name: str

-    def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
+    def __init__(self, config: BaseProxyConfig, name: str, http: ClientSession) -> None:
         self.http = http
         self.config = config['platforms'][self.get_type()]
         self.url = self.config['url']
@@ -44,6 +45,7 @@ class Platform:
         self.max_context_messages = self.config['max_context_messages']
         self.additional_prompt = config['additional_prompt']
         self.system_prompt = config['system_prompt']
+        self.name = name

     """
     Call the AI chat API and return the response
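
The new name field, together with the get_context(plugin, self, evt) change above, suggests the context builder uses the platform's name while assembling chat history. A hypothetical sketch follows; the real get_context body, and any helper it calls, are not in this diff.

# Hypothetical: fetch_recent_messages and the role/strip logic are assumptions.
async def get_context(plugin: Plugin, platform: Platform, evt: MessageEvent) -> List[dict]:
    messages = []
    async for msg in fetch_recent_messages(plugin, evt.room_id):  # hypothetical helper
        # Messages sent by the bot itself become assistant turns.
        role = "assistant" if msg.sender == plugin.client.mxid else "user"
        body = msg.content.body
        # Strip a leading "<name>:" mention so the model does not see its own name.
        if body.startswith(f"{platform.name}:"):
            body = body[len(platform.name) + 1:].lstrip()
        messages.append({"role": role, "content": body})
    # Honor the per-platform context window limit read in __init__ above.
    return messages[-platform.max_context_messages:]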

View File

@@ -8,8 +8,8 @@ from maubot_llmplus.platforms import Platform, ChatCompletion
 class OpenAi(Platform):
-    def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
-        super().__init__(config, http)
+    def __init__(self, config: BaseProxyConfig, name: str, http: ClientSession) -> None:
+        super().__init__(config, name, http)

     async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion:
         # Fetch the system prompt

@@ -23,8 +23,8 @@ class OpenAi(Platform):
 class Anthropic(Platform):
-    def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
-        super().__init__(config, http)
+    def __init__(self, config: BaseProxyConfig, name: str, http: ClientSession) -> None:
+        super().__init__(config, name, http)

     async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion:
         # Fetch the system prompt
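
Only the constructor change here is part of this commit; both request bodies are truncated. Worth noting for the Anthropic path: its public API differs from the OpenAI-style endpoints above, so the truncated code presumably looks more like this hedged sketch (header names per Anthropic's documented Messages API; self.api_key, self.model, and full_context are assumed):

# Hedged sketch only; not taken from this diff.
headers = {
    "x-api-key": self.api_key,            # assumed attribute
    "anthropic-version": "2023-06-01",    # documented required header
}
payload = {
    "model": self.model,                  # assumed attribute
    "system": self.system_prompt,         # Anthropic takes the system prompt top-level
    "messages": full_context,             # user/assistant turns only
    "max_tokens": 1024,
}
async with self.http.post(f"{self.url}/v1/messages", json=payload, headers=headers) as resp:
    data = await resp.json()
# The reply text lives in content[0]["text"]; stop_reason is e.g. "end_turn".
return ChatCompletion(
    message={"role": "assistant", "content": data["content"][0]["text"]},
    model=data.get("model"),
    finish_reason=data.get("stop_reason"),
)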