fewfefffwefweef
This commit is contained in:
183
ai_main/client.py
Normal file
183
ai_main/client.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""DeepSeek聊天客户端"""
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from session_manager import SessionManager
|
||||
|
||||
|
||||
class DeepSeekChatClient:
|
||||
"""简化的DeepSeek聊天客户端"""
|
||||
|
||||
def __init__(self, api_key: str, db_path: str = None):
|
||||
"""
|
||||
初始化客户端
|
||||
|
||||
Args:
|
||||
api_key: DeepSeek API密钥
|
||||
db_path: 数据库路径,默认为chat_sessions.db
|
||||
"""
|
||||
self.client = OpenAI(api_key=api_key, base_url='https://api.deepseek.com')
|
||||
self.session_manager = SessionManager(db_path)
|
||||
|
||||
def chat(self, user_input: str, name: str, model: str = None,
|
||||
system_prompt: str = None, **kwargs) -> str:
|
||||
"""
|
||||
发送聊天消息
|
||||
|
||||
Args:
|
||||
user_input: 用户输入的问题
|
||||
name: 会话名称
|
||||
model: 模型名称(仅在创建新会话时使用)
|
||||
system_prompt: 系统提示(仅在创建新会话时使用)
|
||||
**kwargs: 其他OpenAI API参数
|
||||
|
||||
Returns:
|
||||
AI的回复内容
|
||||
"""
|
||||
# 获取或创建会话
|
||||
session = self.session_manager.get_or_create_session(
|
||||
name=name,
|
||||
model=model or 'deepseek-chat',
|
||||
system_prompt=system_prompt
|
||||
)
|
||||
|
||||
# 保存用户消息
|
||||
self.session_manager.add_message(
|
||||
session_name=name,
|
||||
role='user',
|
||||
content=user_input,
|
||||
tokens=int(len(user_input) * 0.8)
|
||||
)
|
||||
|
||||
# 构建消息列表
|
||||
messages = self._build_messages(session, user_input)
|
||||
|
||||
# 调用API
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=session.model,
|
||||
messages=messages,
|
||||
max_tokens=kwargs.get('max_tokens', 2000),
|
||||
temperature=kwargs.get('temperature', 0.7),
|
||||
stream=kwargs.get('stream', False)
|
||||
)
|
||||
|
||||
# 处理响应
|
||||
ai_reply = response.choices[0].message.content
|
||||
|
||||
# 保存AI回复
|
||||
self.session_manager.add_message(
|
||||
session_name=name,
|
||||
role='assistant',
|
||||
content=ai_reply,
|
||||
tokens=int(len(ai_reply) * 0.8)
|
||||
)
|
||||
|
||||
return ai_reply
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"API调用错误: {str(e)}"
|
||||
print(error_msg)
|
||||
|
||||
# 保存错误信息
|
||||
self.session_manager.add_message(
|
||||
session_name=name,
|
||||
role='system',
|
||||
content=error_msg,
|
||||
metadata={'error': True, 'error_type': type(e).__name__}
|
||||
)
|
||||
|
||||
return ""
|
||||
|
||||
def _build_messages(self, session, user_input: str) -> List[Dict]:
|
||||
"""
|
||||
构建消息列表
|
||||
|
||||
Args:
|
||||
session: 会话对象
|
||||
user_input: 用户输入
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
# 获取会话历史
|
||||
history = self.session_manager.get_session_history(name=session.name, limit=10)
|
||||
|
||||
messages = []
|
||||
|
||||
# 添加系统提示(如果历史中没有)
|
||||
if not any(msg['role'] == 'system' for msg in history):
|
||||
messages.append({
|
||||
'role': 'system',
|
||||
'content': session.system_prompt
|
||||
})
|
||||
|
||||
# 添加上下文历史
|
||||
messages.extend(history)
|
||||
|
||||
# 添加当前用户消息
|
||||
messages.append({
|
||||
'role': 'user',
|
||||
'content': user_input
|
||||
})
|
||||
|
||||
return messages
|
||||
|
||||
def get_session_info(self, name: str) -> Dict:
|
||||
"""
|
||||
获取会话信息
|
||||
|
||||
Args:
|
||||
name: 会话名称
|
||||
|
||||
Returns:
|
||||
会话信息字典
|
||||
"""
|
||||
return self.session_manager.get_session_info(name)
|
||||
|
||||
def list_sessions(self) -> List[Dict]:
|
||||
"""
|
||||
列出所有会话
|
||||
|
||||
Returns:
|
||||
会话列表
|
||||
"""
|
||||
return self.session_manager.list_sessions()
|
||||
|
||||
def export_session(self, name: str) -> Dict:
|
||||
"""
|
||||
导出会话数据
|
||||
|
||||
Args:
|
||||
name: 会话名称
|
||||
|
||||
Returns:
|
||||
导出的会话数据
|
||||
"""
|
||||
info = self.get_session_info(name)
|
||||
history = self.session_manager.get_session_history(name, limit=100)
|
||||
|
||||
return {
|
||||
'session_info': info,
|
||||
'history': history,
|
||||
'exported_at': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
def clear_session_history(self, name: str, keep_system: bool = True) -> bool:
|
||||
"""
|
||||
清空会话历史(保留系统提示)
|
||||
|
||||
Args:
|
||||
name: 会话名称
|
||||
keep_system: 是否保留系统提示
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
return self.session_manager.clear_session_history(name, keep_system)
|
||||
|
||||
def close(self):
|
||||
"""关闭客户端"""
|
||||
self.session_manager.close()
|
||||
108
ai_main/database.py
Normal file
108
ai_main/database.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""数据库模型和管理器"""
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict
|
||||
|
||||
from peewee import Model, CharField, TextField, IntegerField, DateTimeField, BooleanField, ForeignKeyField
|
||||
from playhouse.sqlite_ext import SqliteExtDatabase
|
||||
|
||||
|
||||
class BaseModel(Model):
|
||||
"""基础模型类"""
|
||||
pass
|
||||
|
||||
|
||||
class Session(BaseModel):
|
||||
"""聊天会话模型"""
|
||||
name = CharField(max_length=255, unique=True)
|
||||
model = CharField(max_length=50, default='deepseek-chat')
|
||||
system_prompt = TextField(default='你是一个乐于助人的助手。')
|
||||
created_at = DateTimeField(default=datetime.now)
|
||||
updated_at = DateTimeField(default=datetime.now)
|
||||
is_active = BooleanField(default=True)
|
||||
metadata_json = TextField(default='{}')
|
||||
|
||||
@property
|
||||
def metadata(self) -> Dict:
|
||||
"""获取元数据"""
|
||||
return json.loads(self.metadata_json) if self.metadata_json else {}
|
||||
|
||||
@metadata.setter
|
||||
def metadata(self, value: Dict):
|
||||
"""设置元数据"""
|
||||
self.metadata_json = json.dumps(value, ensure_ascii=False)
|
||||
|
||||
def update_timestamp(self):
|
||||
"""更新时间戳"""
|
||||
self.updated_at = datetime.now()
|
||||
self.save()
|
||||
|
||||
class Meta:
|
||||
table_name = 'session'
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""聊天消息模型"""
|
||||
session = ForeignKeyField(Session, backref='messages', on_delete='CASCADE')
|
||||
role = CharField(max_length=20, index=True)
|
||||
content = TextField()
|
||||
tokens = IntegerField(default=0)
|
||||
timestamp = DateTimeField(default=datetime.now, index=True)
|
||||
metadata_json = TextField(default='{}')
|
||||
|
||||
@property
|
||||
def metadata(self) -> Dict:
|
||||
"""获取元数据"""
|
||||
return json.loads(self.metadata_json) if self.metadata_json else {}
|
||||
|
||||
@metadata.setter
|
||||
def metadata(self, value: Dict):
|
||||
"""设置元数据"""
|
||||
self.metadata_json = json.dumps(value, ensure_ascii=False)
|
||||
|
||||
class Meta:
|
||||
table_name = 'message'
|
||||
indexes = (
|
||||
(('session', 'timestamp'), False),
|
||||
)
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""数据库管理器"""
|
||||
|
||||
def __init__(self, db_path: str = 'chat_sessions.db'):
|
||||
"""
|
||||
初始化数据库管理器
|
||||
|
||||
Args:
|
||||
db_path: 数据库文件路径
|
||||
"""
|
||||
self.db = SqliteExtDatabase(
|
||||
db_path,
|
||||
pragmas={
|
||||
'journal_mode': 'wal',
|
||||
'cache_size': -1024 * 64,
|
||||
'foreign_keys': 1,
|
||||
'ignore_check_constraints': 0,
|
||||
'synchronous': 1
|
||||
}
|
||||
)
|
||||
self._init_database()
|
||||
|
||||
def _init_database(self):
|
||||
"""初始化数据库连接和表"""
|
||||
# 设置数据库
|
||||
BaseModel._meta.database = self.db
|
||||
Session._meta.database = self.db
|
||||
Message._meta.database = self.db
|
||||
|
||||
# 连接数据库
|
||||
self.db.connect()
|
||||
|
||||
# 创建表
|
||||
self.db.create_tables([Session, Message], safe=True)
|
||||
|
||||
def close(self):
|
||||
"""关闭数据库连接"""
|
||||
if not self.db.is_closed():
|
||||
self.db.close()
|
||||
203
ai_main/session_manager.py
Normal file
203
ai_main/session_manager.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""会话管理器"""
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from database import DatabaseManager, Session, Message
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""会话管理器"""
|
||||
|
||||
def __init__(self, db_path: str = None):
|
||||
"""
|
||||
初始化会话管理器
|
||||
|
||||
Args:
|
||||
db_path: 数据库路径,默认为 chat_sessions.db
|
||||
"""
|
||||
self.db_manager = DatabaseManager(db_path or 'chat_sessions.db')
|
||||
self.Session = Session
|
||||
self.Message = Message
|
||||
|
||||
def get_or_create_session(self, name: str, **kwargs) -> Session:
|
||||
"""
|
||||
获取或创建会话
|
||||
|
||||
Args:
|
||||
name: 会话名称
|
||||
**kwargs: 会话参数(model, system_prompt, metadata)
|
||||
|
||||
Returns:
|
||||
Session对象
|
||||
"""
|
||||
# 尝试获取现有会话
|
||||
session = self.Session.get_or_none(self.Session.name == name)
|
||||
|
||||
if session:
|
||||
return session
|
||||
|
||||
# 创建新会话
|
||||
defaults = {
|
||||
'model': kwargs.get('model', 'deepseek-chat'),
|
||||
'system_prompt': kwargs.get('system_prompt', '你是一个乐于助人的助手。'),
|
||||
'metadata': kwargs.get('metadata', {})
|
||||
}
|
||||
|
||||
session = self.Session.create(
|
||||
name=name,
|
||||
**defaults
|
||||
)
|
||||
|
||||
# 添加系统提示作为第一条消息
|
||||
if defaults['system_prompt']:
|
||||
self.Message.create(
|
||||
session=session,
|
||||
role='system',
|
||||
content=defaults['system_prompt']
|
||||
)
|
||||
|
||||
return session
|
||||
|
||||
def add_message(self, session_name: str, role: str, content: str,
|
||||
tokens: int = 0, metadata: Dict = None) -> str:
|
||||
"""
|
||||
添加消息到会话
|
||||
|
||||
Args:
|
||||
session_name: 会话名称
|
||||
role: 消息角色(user/assistant/system)
|
||||
content: 消息内容
|
||||
tokens: token数量
|
||||
metadata: 元数据
|
||||
|
||||
Returns:
|
||||
消息内容
|
||||
"""
|
||||
session = self.get_or_create_session(session_name)
|
||||
|
||||
self.Message.create(
|
||||
session=session,
|
||||
role=role,
|
||||
content=content,
|
||||
tokens=tokens,
|
||||
timestamp=datetime.now(),
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
# 更新会话时间戳
|
||||
session.update_timestamp()
|
||||
|
||||
return content
|
||||
|
||||
def get_session_history(self, session_name: str, limit: int = 20) -> List[Dict]:
|
||||
"""
|
||||
获取会话历史
|
||||
|
||||
Args:
|
||||
session_name: 会话名称
|
||||
limit: 返回消息数量限制
|
||||
|
||||
Returns:
|
||||
消息历史列表
|
||||
"""
|
||||
session = self.Session.get_or_none(self.Session.name == session_name)
|
||||
if not session:
|
||||
return []
|
||||
|
||||
messages = (self.Message
|
||||
.select()
|
||||
.where(self.Message.session == session)
|
||||
.order_by(self.Message.timestamp.asc())
|
||||
.limit(limit))
|
||||
|
||||
return [
|
||||
{'role': msg.role, 'content': msg.content}
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
def get_session_info(self, session_name: str) -> Dict:
|
||||
"""
|
||||
获取会话信息
|
||||
|
||||
Args:
|
||||
session_name: 会话名称
|
||||
|
||||
Returns:
|
||||
会话信息字典
|
||||
"""
|
||||
session = self.Session.get_or_none(self.Session.name == session_name)
|
||||
if not session:
|
||||
return {}
|
||||
|
||||
messages = session.messages
|
||||
message_count = messages.count()
|
||||
user_messages = messages.where(self.Message.role == 'user').count()
|
||||
assistant_messages = messages.where(self.Message.role == 'assistant').count()
|
||||
|
||||
return {
|
||||
'name': session.name,
|
||||
'model': session.model,
|
||||
'system_prompt': session.system_prompt,
|
||||
'message_count': message_count,
|
||||
'user_messages': user_messages,
|
||||
'assistant_messages': assistant_messages,
|
||||
'created_at': session.created_at.isoformat(),
|
||||
'updated_at': session.updated_at.isoformat(),
|
||||
'metadata': session.metadata
|
||||
}
|
||||
|
||||
def list_sessions(self) -> List[Dict]:
|
||||
"""
|
||||
列出所有会话
|
||||
|
||||
Returns:
|
||||
会话列表
|
||||
"""
|
||||
sessions = self.Session.select().order_by(self.Session.updated_at.desc())
|
||||
|
||||
return [
|
||||
{
|
||||
'name': session.name,
|
||||
'message_count': session.messages.count(),
|
||||
'updated_at': session.updated_at.isoformat()
|
||||
}
|
||||
for session in sessions
|
||||
]
|
||||
|
||||
def clear_session_history(self, session_name: str, keep_system: bool = True) -> bool:
|
||||
"""
|
||||
清空会话历史(保留系统提示)
|
||||
|
||||
Args:
|
||||
session_name: 会话名称
|
||||
keep_system: 是否保留系统提示
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
session = self.Session.get_or_none(self.Session.name == session_name)
|
||||
if not session:
|
||||
return False
|
||||
|
||||
# 删除消息
|
||||
query = session.messages
|
||||
if keep_system:
|
||||
query = query.where(self.Message.role != 'system')
|
||||
|
||||
query.delete().execute()
|
||||
|
||||
# 重新添加系统提示(如果需要)
|
||||
if keep_system and session.system_prompt:
|
||||
messages = session.messages.where(self.Message.role == 'system')
|
||||
if not messages.exists():
|
||||
self.add_message(
|
||||
session_name=session_name,
|
||||
role='system',
|
||||
content=session.system_prompt
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def close(self):
|
||||
"""关闭数据库连接"""
|
||||
self.db_manager.close()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
sessions/14386887066.session-journal
Normal file
BIN
sessions/14386887066.session-journal
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
sessions/14434436234.session-journal
Normal file
BIN
sessions/14434436234.session-journal
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
sessions/15482999443.session-journal
Normal file
BIN
sessions/15482999443.session-journal
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user