Files
mini_code/ai_main/session_manager.py

204 lines
5.6 KiB
Python
Raw Normal View History

2026-01-09 17:51:09 +08:00
"""会话管理器"""
from datetime import datetime
from typing import List, Dict, Any, Optional

from database import DatabaseManager, Session, Message
class SessionManager:
    """Session manager: stores chat sessions and their messages.

    Wraps the ``Session`` and ``Message`` models imported from ``database``.
    Every public method identifies a session by its ``name``.
    """

    def __init__(self, db_path: Optional[str] = None):
        """Initialise the session manager.

        Args:
            db_path: Database path; ``'chat_sessions.db'`` when omitted or empty.
        """
        # NOTE(review): DatabaseManager is presumably what binds the
        # Session/Message models to this database file — confirm in database.py.
        self.db_manager = DatabaseManager(db_path or 'chat_sessions.db')
        self.Session = Session
        self.Message = Message

    def __enter__(self) -> 'SessionManager':
        """Allow ``with SessionManager(...) as mgr:`` usage."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the database on leaving the ``with`` block; exceptions propagate."""
        self.close()

    def get_or_create_session(self, name: str, **kwargs) -> Session:
        """Return the session called *name*, creating it if it does not exist.

        Args:
            name: Session name, used as the lookup key.
            **kwargs: Only used when a new session is created:
                ``model`` (default ``'deepseek-chat'``), ``system_prompt``
                (default is a generic helpful-assistant prompt) and
                ``metadata`` (default ``{}``). Ignored for existing sessions.

        Returns:
            The existing or newly created Session object.
        """
        session = self.Session.get_or_none(self.Session.name == name)
        if session is not None:
            return session

        defaults = {
            'model': kwargs.get('model', 'deepseek-chat'),
            'system_prompt': kwargs.get('system_prompt', '你是一个乐于助人的助手。'),
            'metadata': kwargs.get('metadata', {}),
        }
        session = self.Session.create(name=name, **defaults)

        # Store the system prompt as the first message so it appears in the
        # history. Timestamp it explicitly, like add_message does, so ordering
        # by timestamp is consistent across all messages.
        if defaults['system_prompt']:
            self.Message.create(
                session=session,
                role='system',
                content=defaults['system_prompt'],
                timestamp=datetime.now(),
            )
        return session

    def add_message(self, session_name: str, role: str, content: str,
                    tokens: int = 0, metadata: Optional[Dict] = None) -> str:
        """Append a message to a session, creating the session if needed.

        Args:
            session_name: Session name.
            role: Message role (user / assistant / system).
            content: Message text.
            tokens: Token count recorded with the message.
            metadata: Extra metadata; ``{}`` when omitted.

        Returns:
            The message content, unchanged.
        """
        session = self.get_or_create_session(session_name)
        self.Message.create(
            session=session,
            role=role,
            content=content,
            tokens=tokens,
            timestamp=datetime.now(),
            metadata=metadata or {},
        )
        # Bump the session's last-updated timestamp.
        session.update_timestamp()
        return content

    def get_session_history(self, session_name: str,
                            limit: Optional[int] = 20) -> List[Dict]:
        """Return a session's most recent messages in chronological order.

        System messages are always placed first, so the system prompt is not
        lost once the conversation grows beyond *limit*. The remaining slots
        hold the newest non-system messages.

        Args:
            session_name: Session name.
            limit: Maximum number of messages returned. ``None`` returns the
                whole history in chronological order.

        Returns:
            A list of ``{'role': ..., 'content': ...}`` dicts, oldest first.
            Empty if the session does not exist or *limit* <= 0.
        """
        session = self.Session.get_or_none(self.Session.name == session_name)
        if session is None:
            return []

        msg = self.Message
        base = msg.select().where(msg.session == session)

        if limit is None:
            rows = list(base.order_by(msg.timestamp.asc()))
        elif limit <= 0:
            return []
        else:
            pinned = list(
                base.where(msg.role == 'system').order_by(msg.timestamp.asc())
            )[:limit]
            remaining = limit - len(pinned)
            recent = []
            if remaining > 0:
                # Fetch newest-first so LIMIT keeps the latest messages,
                # then restore chronological order.
                recent = list(
                    base.where(msg.role != 'system')
                    .order_by(msg.timestamp.desc())
                    .limit(remaining)
                )
                recent.reverse()
            rows = pinned + recent

        return [{'role': m.role, 'content': m.content} for m in rows]

    def get_session_info(self, session_name: str) -> Dict:
        """Return summary information about a session.

        Args:
            session_name: Session name.

        Returns:
            A dict with name, model, system prompt, message counts (total,
            user, assistant), ISO-formatted created/updated timestamps and
            metadata. Empty if the session does not exist.
        """
        session = self.Session.get_or_none(self.Session.name == session_name)
        if session is None:
            return {}

        messages = session.messages
        message_count = messages.count()
        user_messages = messages.where(self.Message.role == 'user').count()
        assistant_messages = messages.where(self.Message.role == 'assistant').count()

        return {
            'name': session.name,
            'model': session.model,
            'system_prompt': session.system_prompt,
            'message_count': message_count,
            'user_messages': user_messages,
            'assistant_messages': assistant_messages,
            'created_at': session.created_at.isoformat(),
            'updated_at': session.updated_at.isoformat(),
            'metadata': session.metadata,
        }

    def list_sessions(self) -> List[Dict]:
        """List all sessions, most recently updated first.

        Returns:
            A list of dicts with each session's name, message count and
            ISO-formatted ``updated_at``.
        """
        # NOTE(review): issues one COUNT query per session; fine for small
        # session counts.
        sessions = self.Session.select().order_by(self.Session.updated_at.desc())
        return [
            {
                'name': session.name,
                'message_count': session.messages.count(),
                'updated_at': session.updated_at.isoformat(),
            }
            for session in sessions
        ]

    def clear_session_history(self, session_name: str, keep_system: bool = True) -> bool:
        """Delete a session's messages, optionally keeping the system prompt.

        Args:
            session_name: Session name.
            keep_system: If true, system messages are kept. If the session has
                a system prompt but no system message remains, the prompt is
                re-added.

        Returns:
            True on success, False if the session does not exist.
        """
        session = self.Session.get_or_none(self.Session.name == session_name)
        if session is None:
            return False

        # Build the DELETE from the model class. The backref select query
        # (session.messages) does not offer .delete().
        delete_query = self.Message.delete().where(self.Message.session == session)
        if keep_system:
            delete_query = delete_query.where(self.Message.role != 'system')
        delete_query.execute()

        # Re-add the system prompt if none is stored.
        if keep_system and session.system_prompt:
            system_messages = session.messages.where(self.Message.role == 'system')
            if not system_messages.exists():
                self.add_message(
                    session_name=session_name,
                    role='system',
                    content=session.system_prompt,
                )
        return True

    def close(self):
        """Close the database connection."""
        self.db_manager.close()