add: 添加chatgpt chat api逻辑

This commit is contained in:
taylor
2024-10-14 00:16:19 +08:00
parent 8c16faa34a
commit 0f7a2c4b33
4 changed files with 59 additions and 16 deletions

View File

@@ -19,12 +19,13 @@ platforms:
max_words: 1000 max_words: 1000
max_context_messages: 20 max_context_messages: 20
openai: openai:
url: url: https://api.openai.com
api_key: api_key:
model: model: gpt-4o
max_tokens: max_tokens: 2000
max_words: max_words: 1000
temperature: max_context_messages: 20
temperature: 1
anthropic: anthropic:
url: url:
api_key: api_key:

View File

@@ -1,14 +1,15 @@
import json
from typing import List from typing import List
from aiohttp import ClientSession from aiohttp import ClientSession
from maubot import Plugin
from mautrix.types import MessageEvent from mautrix.types import MessageEvent
from mautrix.util.config import BaseProxyConfig from mautrix.util.config import BaseProxyConfig
import maubot_llmplus import maubot_llmplus
import maubot_llmplus.platforms import maubot_llmplus.platforms
from maubot_llmplus.platforms import Platform, ChatCompletion from maubot_llmplus.platforms import Platform, ChatCompletion
from maubot_llmplus.plugin import AbsExtraConfigPlugin
class Ollama(Platform): class Ollama(Platform):
@@ -18,7 +19,7 @@ class Ollama(Platform):
super().__init__(config, http) super().__init__(config, http)
self.chat_api = '/api/chat' self.chat_api = '/api/chat'
async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
full_context = [] full_context = []
context = await maubot_llmplus.platforms.get_context(plugin, self, evt) context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
full_context.extend(list(context)) full_context.extend(list(context))
@@ -59,5 +60,5 @@ class LmStudio(Platform):
super().__init__(config, http) super().__init__(config, http)
pass pass
async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
pass pass

View File

@@ -51,7 +51,7 @@ class Platform:
调用AI对话接口, 响应结果 调用AI对话接口, 响应结果
""" """
async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
raise NotImplementedError() raise NotImplementedError()
async def list_models(self) -> List[str]: async def list_models(self) -> List[str]:

View File

@@ -1,21 +1,62 @@
import json
from aiohttp import ClientSession from aiohttp import ClientSession
from maubot import Plugin
from mautrix.types import MessageEvent from mautrix.types import MessageEvent
from mautrix.util.config import BaseProxyConfig from mautrix.util.config import BaseProxyConfig
import maubot_llmplus.platforms
from maubot_llmplus.platforms import Platform, ChatCompletion from maubot_llmplus.platforms import Platform, ChatCompletion
from maubot_llmplus.plugin import AbsExtraConfigPlugin
class OpenAi(Platform):
    """Chat platform backed by the OpenAI chat-completions HTTP API.

    Reads `max_tokens` and `temperature` from the plugin config and POSTs the
    accumulated conversation context to `{url}/v1/chat/completions`.
    """

    # Max tokens to request per completion; taken from config `max_tokens`.
    max_tokens: int
    # Sampling temperature; taken from config `temperature` (config default is 1).
    # NOTE(review): OpenAI accepts fractional temperatures — float would be the
    # more accurate annotation, but config currently supplies an int.
    temperature: int

    def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
        super().__init__(config, http)
        self.max_tokens = self.config['max_tokens']
        self.temperature = self.config['temperature']

    async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
        """Call OpenAI's chat-completions endpoint and wrap the first choice.

        :param plugin: plugin instance used to resolve the conversation context.
        :param evt: the Matrix message event being replied to.
        :return: a ChatCompletion; on a non-200 response, `message` is empty and
                 `finish_reason` carries the error text.
        """
        # Gather prior conversation turns (system prompt + user/assistant history).
        full_context = []
        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)
        full_context.extend(list(context))

        headers = {
            "Content-Type": "application/json",
            # BUG FIX: the config defines this key as `api_key` (see the openai
            # section of base-config), not `gpt_api_key` — the old lookup raised
            # a KeyError / returned nothing on every request.
            "Authorization": f"Bearer {self.config['api_key']}",
        }
        data = {
            "model": self.model,
            "messages": full_context,
        }
        # Only send tunables that are actually configured. Use explicit None
        # checks so a deliberate temperature of 0 is still forwarded.
        if 'max_tokens' in self.config and self.max_tokens is not None:
            data["max_tokens"] = self.max_tokens
        if 'temperature' in self.config and self.temperature is not None:
            data["temperature"] = self.temperature

        endpoint = f"{self.url}/v1/chat/completions"
        async with self.http.post(
            endpoint, headers=headers, data=json.dumps(data)
        ) as response:
            if response.status != 200:
                # Surface the API error to the caller instead of raising.
                return ChatCompletion(
                    message={},
                    finish_reason=f"Error: {await response.text()}",
                    model=None
                )
            response_json = await response.json()
            choice = response_json["choices"][0]
            return ChatCompletion(
                message=choice["message"],
                finish_reason=choice["finish_reason"],
                # BUG FIX: the model name lives at the top level of the response,
                # not inside each choice — `choice.get("model")` was always None.
                model=response_json.get("model", None)
            )

    def get_type(self) -> str:
        """Identify this platform by its config-section name."""
        return "openai"
@@ -26,7 +67,7 @@ class Anthropic(Platform):
def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None: def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
super().__init__(config, http) super().__init__(config, http)
async def create_chat_completion(self, plugin: Plugin, evt: MessageEvent) -> ChatCompletion: async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
# 获取系统提示词 # 获取系统提示词
# 获取额外的其他角色的提示词: role: user role: system # 获取额外的其他角色的提示词: role: user role: system