204 lines
5.6 KiB
Python
204 lines
5.6 KiB
Python
|
|
"""会话管理器"""
|
|||
|
|
from datetime import datetime
|
|||
|
|
from typing import List, Dict, Any
|
|||
|
|
|
|||
|
|
from database import DatabaseManager, Session, Message
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SessionManager:
    """Session manager: create, read and clear chat sessions and their messages."""

    def __init__(self, db_path: str = None):
        """
        Initialize the session manager.

        Args:
            db_path: Database path; falls back to 'chat_sessions.db' when not
                given (or falsy).
        """
        self.db_manager = DatabaseManager(db_path or 'chat_sessions.db')
        # Model classes are exposed on the instance; every query below goes
        # through these attributes.
        self.Session = Session
        self.Message = Message

    def get_or_create_session(self, name: str, **kwargs) -> Session:
        """
        Return the session named ``name``, creating it if it does not exist.

        Args:
            name: Session name, used as the lookup key.
            **kwargs: Only used when a new session is created:
                ``model`` (default 'deepseek-chat'), ``system_prompt``
                (default: a built-in helper prompt) and ``metadata``
                (default: a fresh empty dict). Ignored for an existing session.

        Returns:
            The existing or newly created Session object.
        """
        # Try to fetch an existing session first.
        session = self.Session.get_or_none(self.Session.name == name)
        if session is not None:
            return session

        # Create a new session with the supplied (or default) settings.
        defaults = {
            'model': kwargs.get('model', 'deepseek-chat'),
            'system_prompt': kwargs.get('system_prompt', '你是一个乐于助人的助手。'),
            'metadata': kwargs.get('metadata', {}),
        }
        session = self.Session.create(name=name, **defaults)

        # Seed the conversation with the system prompt as its first message.
        # The timestamp is set explicitly, exactly as add_message does, so the
        # timestamp-based ordering in get_session_history does not depend on
        # whatever default the Message model's column may (or may not) have.
        if defaults['system_prompt']:
            self.Message.create(
                session=session,
                role='system',
                content=defaults['system_prompt'],
                timestamp=datetime.now(),
            )

        return session

    def add_message(self, session_name: str, role: str, content: str,
                    tokens: int = 0, metadata: Dict = None) -> str:
        """
        Append a message to a session, creating the session if necessary.

        Args:
            session_name: Session name.
            role: Message role (user/assistant/system).
            content: Message text.
            tokens: Token count stored with the message.
            metadata: Extra metadata; a fresh empty dict is stored when omitted.

        Returns:
            The message content that was stored.
        """
        session = self.get_or_create_session(session_name)

        self.Message.create(
            session=session,
            role=role,
            content=content,
            tokens=tokens,
            timestamp=datetime.now(),
            metadata=metadata or {},
        )

        # Refresh the session's timestamp (helper defined on the Session model).
        session.update_timestamp()

        return content

    def get_session_history(self, session_name: str, limit: int = 20) -> List[Dict]:
        """
        Return the most recent messages of a session, oldest first.

        Args:
            session_name: Session name.
            limit: Maximum number of messages to return; the newest ``limit``
                messages are kept.

        Returns:
            A list of ``{'role': ..., 'content': ...}`` dicts in chronological
            order, or an empty list if the session does not exist.
        """
        session = self.Session.get_or_none(self.Session.name == session_name)
        if session is None:
            return []

        # Select the newest `limit` rows, then flip them back to chronological
        # order. Sorting ascending *before* LIMIT would return the oldest
        # messages, so once a conversation outgrew `limit` its latest turns
        # were silently dropped from the history.
        # NOTE(review): for long conversations the seeded system prompt can fall
        # outside this window; callers that always need it should add it back.
        newest_first = (self.Message
                        .select()
                        .where(self.Message.session == session)
                        .order_by(self.Message.timestamp.desc())
                        .limit(limit))

        return [
            {'role': msg.role, 'content': msg.content}
            for msg in reversed(list(newest_first))
        ]

    def get_session_info(self, session_name: str) -> Dict[str, Any]:
        """
        Return summary information about a session.

        Args:
            session_name: Session name.

        Returns:
            A dict with the session's name, model, system prompt, message
            counts (total / user / assistant), ISO-formatted created/updated
            timestamps and metadata; an empty dict if the session does not exist.
        """
        session = self.Session.get_or_none(self.Session.name == session_name)
        if session is None:
            return {}

        # `where` returns a new query, so `messages` itself stays unfiltered
        # for each of the counts below.
        messages = session.messages
        message_count = messages.count()
        user_messages = messages.where(self.Message.role == 'user').count()
        assistant_messages = messages.where(self.Message.role == 'assistant').count()

        return {
            'name': session.name,
            'model': session.model,
            'system_prompt': session.system_prompt,
            'message_count': message_count,
            'user_messages': user_messages,
            'assistant_messages': assistant_messages,
            'created_at': session.created_at.isoformat(),
            'updated_at': session.updated_at.isoformat(),
            'metadata': session.metadata,
        }

    def list_sessions(self) -> List[Dict[str, Any]]:
        """
        List all sessions, most recently updated first.

        Returns:
            A list of dicts with each session's name, message count and
            ISO-formatted update time.
        """
        sessions = self.Session.select().order_by(self.Session.updated_at.desc())

        # NOTE(review): this issues one COUNT query per session; a grouped
        # join would avoid the N+1 pattern if the session list grows large.
        return [
            {
                'name': session.name,
                'message_count': session.messages.count(),
                'updated_at': session.updated_at.isoformat(),
            }
            for session in sessions
        ]

    def clear_session_history(self, session_name: str, keep_system: bool = True) -> bool:
        """
        Delete a session's messages, optionally keeping the system prompt.

        Args:
            session_name: Session name.
            keep_system: When True, system-role messages are kept, and the
                session's system prompt is re-added if no system message exists.
                When False, every message is deleted.

        Returns:
            True if the session existed and was cleared, False otherwise.
        """
        session = self.Session.get_or_none(self.Session.name == session_name)
        if session is None:
            return False

        # Issue a real DELETE on the Message model. The `session.messages`
        # back-reference is a SELECT query, and peewee select queries do not
        # provide a delete() method.
        conditions = [self.Message.session == session]
        if keep_system:
            conditions.append(self.Message.role != 'system')
        self.Message.delete().where(*conditions).execute()

        # Re-add the system prompt if the session has one but no system
        # message is present any more.
        if keep_system and session.system_prompt:
            system_messages = session.messages.where(self.Message.role == 'system')
            if not system_messages.exists():
                self.add_message(
                    session_name=session_name,
                    role='system',
                    content=session.system_prompt,
                )

        return True

    def close(self):
        """Close the database connection via the database manager."""
        self.db_manager.close()
|