"""会话管理器""" from datetime import datetime from typing import List, Dict, Any from database import DatabaseManager, Session, Message class SessionManager: """会话管理器""" def __init__(self, db_path: str = None): """ 初始化会话管理器 Args: db_path: 数据库路径,默认为 chat_sessions.db """ self.db_manager = DatabaseManager(db_path or 'chat_sessions.db') self.Session = Session self.Message = Message def get_or_create_session(self, name: str, **kwargs) -> Session: """ 获取或创建会话 Args: name: 会话名称 **kwargs: 会话参数(model, system_prompt, metadata) Returns: Session对象 """ # 尝试获取现有会话 session = self.Session.get_or_none(self.Session.name == name) if session: return session # 创建新会话 defaults = { 'model': kwargs.get('model', 'deepseek-chat'), 'system_prompt': kwargs.get('system_prompt', '你是一个乐于助人的助手。'), 'metadata': kwargs.get('metadata', {}) } session = self.Session.create( name=name, **defaults ) # 添加系统提示作为第一条消息 if defaults['system_prompt']: self.Message.create( session=session, role='system', content=defaults['system_prompt'] ) return session def add_message(self, session_name: str, role: str, content: str, tokens: int = 0, metadata: Dict = None) -> str: """ 添加消息到会话 Args: session_name: 会话名称 role: 消息角色(user/assistant/system) content: 消息内容 tokens: token数量 metadata: 元数据 Returns: 消息内容 """ session = self.get_or_create_session(session_name) self.Message.create( session=session, role=role, content=content, tokens=tokens, timestamp=datetime.now(), metadata=metadata or {} ) # 更新会话时间戳 session.update_timestamp() return content def get_session_history(self, session_name: str, limit: int = 20) -> List[Dict]: """ 获取会话历史 Args: session_name: 会话名称 limit: 返回消息数量限制 Returns: 消息历史列表 """ session = self.Session.get_or_none(self.Session.name == session_name) if not session: return [] messages = (self.Message .select() .where(self.Message.session == session) .order_by(self.Message.timestamp.asc()) .limit(limit)) return [ {'role': msg.role, 'content': msg.content} for msg in messages ] def get_session_info(self, session_name: str) -> Dict: """ 获取会话信息 Args: session_name: 会话名称 Returns: 会话信息字典 """ session = self.Session.get_or_none(self.Session.name == session_name) if not session: return {} messages = session.messages message_count = messages.count() user_messages = messages.where(self.Message.role == 'user').count() assistant_messages = messages.where(self.Message.role == 'assistant').count() return { 'name': session.name, 'model': session.model, 'system_prompt': session.system_prompt, 'message_count': message_count, 'user_messages': user_messages, 'assistant_messages': assistant_messages, 'created_at': session.created_at.isoformat(), 'updated_at': session.updated_at.isoformat(), 'metadata': session.metadata } def list_sessions(self) -> List[Dict]: """ 列出所有会话 Returns: 会话列表 """ sessions = self.Session.select().order_by(self.Session.updated_at.desc()) return [ { 'name': session.name, 'message_count': session.messages.count(), 'updated_at': session.updated_at.isoformat() } for session in sessions ] def clear_session_history(self, session_name: str, keep_system: bool = True) -> bool: """ 清空会话历史(保留系统提示) Args: session_name: 会话名称 keep_system: 是否保留系统提示 Returns: 是否成功 """ session = self.Session.get_or_none(self.Session.name == session_name) if not session: return False # 删除消息 query = session.messages if keep_system: query = query.where(self.Message.role != 'system') query.delete().execute() # 重新添加系统提示(如果需要) if keep_system and session.system_prompt: messages = session.messages.where(self.Message.role == 'system') if not messages.exists(): self.add_message( session_name=session_name, role='system', content=session.system_prompt ) return True def close(self): """关闭数据库连接""" self.db_manager.close()