add
This commit is contained in:
@@ -10,7 +10,7 @@ from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
|
||||
from maubot_llmplus.local_paltform import Ollama, LmStudio
|
||||
from maubot_llmplus.platforms import Platform
|
||||
from maubot_llmplus.plugin import AbsExtraConfigPlugin, Config
|
||||
from maubot_llmplus.thrid_platform import OpenAi, Anthropic, XAi, Deepseek
|
||||
from maubot_llmplus.thrid_platform import OpenAi, Anthropic, XAi, Deepseek, Gemini
|
||||
|
||||
|
||||
class AiBotPlugin(AbsExtraConfigPlugin):
|
||||
@@ -160,6 +160,8 @@ class AiBotPlugin(AbsExtraConfigPlugin):
|
||||
return XAi(self.config, self.http)
|
||||
if use_platform == 'deepseek':
|
||||
return Deepseek(self.config, self.http)
|
||||
if use_platform == 'gemini':
|
||||
return Gemini(self.config, self.http)
|
||||
if use_platform == 'local_ai#ollama':
|
||||
return Ollama(self.config, self.http)
|
||||
if use_platform == 'local_ai#lmstudio':
|
||||
@@ -298,7 +300,7 @@ class AiBotPlugin(AbsExtraConfigPlugin):
|
||||
self.config.cur_model = self.config['platforms'][argus.split("#")[0]]['model']
|
||||
await event.react("✅")
|
||||
# If the platform is openai or claude (a third-party API platform)
|
||||
elif argus == 'openai' or argus == 'anthropic' or argus == 'xai' or argus == 'deepseek':
|
||||
elif argus == 'openai' or argus == 'anthropic' or argus == 'xai' or argus == 'deepseek' or argus == 'gemini':
|
||||
if argus == self.config.cur_platform:
|
||||
await event.reply(f"current ai platform has be {argus}")
|
||||
pass
|
||||
|
||||
@@ -181,6 +181,85 @@ class Anthropic(Platform):
|
||||
return "anthropic"
|
||||
|
||||
|
||||
class Gemini(Platform):
    """Google Gemini chat platform backed by the generativelanguage REST API.

    Talks to ``{url}/v1beta/models/{model}:generateContent`` with the
    ``x-goog-api-key`` header. Configuration (``max_tokens``,
    ``temperature``) is read from the plugin config at construction time.
    """

    # Generation limits pulled from config; a falsy value (0 / None) means
    # "do not send this field" — see create_chat_completion.
    max_tokens: int
    temperature: float

    def __init__(self, config: BaseProxyConfig, http: ClientSession) -> None:
        super().__init__(config, http)
        self.max_tokens = self.config['max_tokens']
        self.temperature = self.config['temperature']

    async def create_chat_completion(self, plugin: AbsExtraConfigPlugin, evt: MessageEvent) -> ChatCompletion:
        """Send the conversation context to Gemini and return its reply.

        Never raises on a bad response: any HTTP error or unexpected
        response shape is reported as a ChatCompletion with result=False,
        matching the other platform implementations in this module.
        """
        context = await maubot_llmplus.platforms.get_context(plugin, self, evt)

        # Gemini's schema differs from the OpenAI-style context: system
        # messages go into a separate "system_instruction" object, and the
        # assistant role is called "model".
        system_parts = []
        contents = []
        for msg in context:
            role = msg['role']
            content = msg['content']
            if role == 'system':
                system_parts.append({"text": content})
            elif role == 'assistant':
                contents.append({"role": "model", "parts": [{"text": content}]})
            else:
                contents.append({"role": "user", "parts": [{"text": content}]})

        request_body = {
            "contents": contents,
            "generationConfig": {}
        }

        if system_parts:
            request_body["system_instruction"] = {"parts": system_parts}

        # Only send limits that are actually configured (non-zero / non-None).
        if self.max_tokens:
            request_body["generationConfig"]["maxOutputTokens"] = self.max_tokens

        if self.temperature:
            request_body["generationConfig"]["temperature"] = self.temperature

        endpoint = f"{self.url}/v1beta/models/{self.model}:generateContent"
        headers = {
            "Content-Type": "application/json",
            "x-goog-api-key": self.api_key
        }

        async with self.http.post(endpoint, headers=headers, data=json.dumps(request_body)) as response:
            if response.status != 200:
                return ChatCompletion(
                    result=False,
                    message={},
                    finish_reason=f"Error: {await response.text()}",
                    model=None
                )
            response_json = await response.json()
            # A prompt blocked by safety filters can return HTTP 200 with no
            # "candidates" at all (only "promptFeedback"); indexing blindly
            # here used to raise KeyError/IndexError instead of failing soft.
            candidates = response_json.get("candidates") or []
            if not candidates:
                feedback = response_json.get("promptFeedback", {})
                return ChatCompletion(
                    result=False,
                    message={},
                    finish_reason=f"Error: no candidates returned ({feedback})",
                    model=response_json.get("modelVersion", self.model)
                )
            candidate = candidates[0]
            # A candidate stopped for safety/recitation may lack content/parts.
            parts = candidate.get("content", {}).get("parts", [])
            text = "".join(part.get("text", "") for part in parts)
            return ChatCompletion(
                result=True,
                message={"role": "assistant", "content": text},
                finish_reason=candidate.get("finishReason", "STOP"),
                model=response_json.get("modelVersion", self.model)
            )

    async def list_models(self) -> List[str]:
        """List chat-capable model names, each formatted as "- {name}".

        Only models advertising "generateContent" in their
        supportedGenerationMethods are included; the "models/" prefix is
        stripped. Returns an empty list on any HTTP error.
        """
        full_url = f"{self.url}/v1beta/models"
        headers = {"x-goog-api-key": self.api_key}
        async with self.http.get(full_url, headers=headers) as response:
            if response.status != 200:
                return []
            response_data = await response.json()
            return [
                f"- {m['name'].replace('models/', '')}"
                for m in response_data.get("models", [])
                if "generateContent" in m.get("supportedGenerationMethods", [])
            ]

    def get_type(self) -> str:
        """Platform discriminator string used in the plugin config."""
        return "gemini"
|
||||
|
||||
|
||||
class XAi(Platform):
|
||||
max_tokens: int
|
||||
temperature: int
|
||||
|
||||
Reference in New Issue
Block a user