change
This commit is contained in:
parent
a5164cf6bc
commit
d65c238668
|
@ -0,0 +1,161 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
"""
|
||||
@Desc : 生成式人工智能
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import openai
|
||||
from openai import OpenAI
|
||||
|
||||
from utils.config import Config, log
|
||||
from utils.model.data import ChatModel
|
||||
|
||||
|
||||
class BaseClientModel:
|
||||
|
||||
def __init__(self, chat_model: ChatModel):
|
||||
self.client: OpenAI = OpenAI(
|
||||
api_key=chat_model.api_key,
|
||||
base_url=chat_model.api_base,
|
||||
)
|
||||
|
||||
if chat_model.flag: # 启动系统代理
|
||||
os.environ['http_proxy'] = chat_model.proxy
|
||||
os.environ['https_proxy'] = chat_model.proxy
|
||||
|
||||
self.prompt = self.__load_prompt
|
||||
|
||||
@property
|
||||
def __load_prompt(self) -> str:
|
||||
file = str(Config.get_path("res", "prompt.key"))
|
||||
with open(file, "r", encoding="utf-8") as file_obj:
|
||||
return file_obj.read()
|
||||
|
||||
def set_prompt(self, msg: str) -> list[dict]:
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"\nYou are ChatGPT, a large language model trained by OpenAI.\nKnowledge cutoff: "
|
||||
f"2021-09\nCurrent model: gpt-4\nCurrent time: {datetime.now()}\n {self.prompt}"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"提取 ```rule {msg} ```"
|
||||
}
|
||||
]
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def generate_frontend_data(keys: list) -> list:
|
||||
result = [
|
||||
{
|
||||
"name": key,
|
||||
"itemStyle": {
|
||||
"color": f'rgb({random.randint(0, 255)},{random.randint(0, 255)},{random.randint(0, 255)})'
|
||||
}
|
||||
} for key in keys
|
||||
]
|
||||
return result
|
||||
|
||||
def complete_data(self, resp: str):
|
||||
""" 数据补全 """
|
||||
result = {"clean": [], 'origin': self.clean_parse_data(resp)}
|
||||
|
||||
if result['origin'] is None:
|
||||
raise RuntimeError(f"数据为空")
|
||||
|
||||
try:
|
||||
# 提取function中所有的第一层键
|
||||
func_keys = list(result['origin']['function'].keys())
|
||||
|
||||
# 基础功能点
|
||||
result['clean'].append(self.generate_frontend_data(func_keys))
|
||||
|
||||
except Exception as err:
|
||||
log.debug(err)
|
||||
raise RuntimeError(err)
|
||||
|
||||
return result
|
||||
|
||||
def clean_parse_data(self, input_str: str):
|
||||
""" 数据清洗 """
|
||||
parsed_content: Optional[dict] = None
|
||||
|
||||
try:
|
||||
try:
|
||||
parsed_content = json.loads(input_str)
|
||||
except json.JSONDecodeError as _:
|
||||
|
||||
input_str = input_str.replace("`", "").strip("\n")
|
||||
try:
|
||||
# 寻找第一个 "{" 的位置
|
||||
first_brace_index = input_str.find("{")
|
||||
|
||||
# 寻找最后一个 "}" 的位置
|
||||
last_brace_index = input_str.rfind("}")
|
||||
input_str = input_str[first_brace_index:last_brace_index + 1]
|
||||
except IndexError or Exception as err:
|
||||
log.warning(err)
|
||||
pass
|
||||
|
||||
# 使用正则表达式提取大括号内的内容
|
||||
matches = re.findall(r'\{(.*)}', input_str)
|
||||
|
||||
if matches:
|
||||
# 获取匹配到的内容
|
||||
extracted = matches[0] # 假设只有一个匹配,你可以根据实际情况调整
|
||||
|
||||
# 尝试将提取的内容反序列化为JSON
|
||||
parsed_content = json.loads("{" + extracted + "}")
|
||||
else:
|
||||
# 如果没有匹配到大括号,抛出异常
|
||||
raise Exception("未找到大括号内的内容")
|
||||
finally:
|
||||
log.warning(parsed_content)
|
||||
if isinstance(parsed_content, dict):
|
||||
self.save_to_file(parsed_content)
|
||||
# 返回反序列化后的JSON
|
||||
return parsed_content
|
||||
|
||||
except Exception as err:
|
||||
log.exception(f"发生异常: {err}")
|
||||
raise RuntimeError(err)
|
||||
|
||||
@staticmethod
|
||||
def save_to_file(parsed_content):
|
||||
try:
|
||||
# 获取当前时间作为文件名
|
||||
date_line = datetime.now().strftime('%Y%m%d%H%M%S')
|
||||
filename = Config.get_path("res", "gen", f"{date_line}.json")
|
||||
|
||||
with open(filename, "w", encoding="utf-8") as file_obj:
|
||||
file_obj.write(json.dumps(parsed_content, ensure_ascii=False))
|
||||
|
||||
log.info(f"文件 '{date_line}.json' 已保存至 res/gen 目录中")
|
||||
|
||||
except Exception as e:
|
||||
# 捕获写入文件时的异常并抛出
|
||||
raise Exception(f"写入文件时发生异常: {str(e)}")
|
||||
|
||||
def send(self, msg: str):
|
||||
try:
|
||||
response = self.client.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
prompt=self.set_prompt(msg),
|
||||
temperature=0.8,
|
||||
presence_penalty=0,
|
||||
top_p=1
|
||||
)
|
||||
except openai.RateLimitError:
|
||||
raise RuntimeError("访问太过频繁了!")
|
||||
else:
|
||||
content = response['choices'][0]['message']['content']
|
||||
result = self.complete_data(content)
|
||||
result['tokens'] = response.usage["total_tokens"]
|
||||
log.warning(f"实际消耗 Tokens: {result['tokens']}")
|
||||
return result
|
|
@ -0,0 +1,4 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
"""
|
||||
自实现 GPT-Agent 代理
|
||||
"""
|
|
@ -0,0 +1,10 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
"""
|
||||
@Desc : 代理执行器
|
||||
"""
|
||||
|
||||
|
||||
class Executor:
|
||||
"""
|
||||
执行器
|
||||
"""
|
|
@ -0,0 +1,10 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
"""
|
||||
@Desc : 解释器
|
||||
"""
|
||||
|
||||
|
||||
class Interpreter:
|
||||
"""
|
||||
解释器
|
||||
"""
|
|
@ -0,0 +1,20 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
"""
|
||||
@Desc : 数据模型
|
||||
|
||||
> 用于校验数据格式
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ChatModel(BaseModel):
|
||||
# 秘钥
|
||||
api_key: str
|
||||
# 中转站
|
||||
api_base: Optional[str]
|
||||
# 请求代理
|
||||
proxy: Optional[str]
|
||||
# 是否代理
|
||||
flag: bool = False
|
Loading…
Reference in New Issue