"""Session manager: persistence helpers for named chat sessions and their messages."""

from datetime import datetime
from typing import Any, Dict, List, Optional

from database import DatabaseManager, Session, Message


class SessionManager:
    """Manage named chat sessions and the messages stored under them.

    The class is a thin layer over the ``Session`` and ``Message`` models
    imported from the project's ``database`` module.  The query calls used
    here (``get_or_none``, ``select().where()``, back-reference ``messages``)
    follow the peewee ORM API -- NOTE(review): confirm the models are peewee
    models.
    """

    def __init__(self, db_path: Optional[str] = None):
        """Open the database and keep handles to the model classes.

        Args:
            db_path: Path of the database file.  Falls back to
                ``chat_sessions.db`` when omitted or empty.
        """
        self.db_manager = DatabaseManager(db_path or 'chat_sessions.db')
        # The models are stored on the instance so every method reaches them
        # through ``self``.
        self.Session = Session
        self.Message = Message

    def _find_session(self, name: str):
        """Return the session called ``name``, or ``None`` if there is none."""
        return self.Session.get_or_none(self.Session.name == name)

    def get_or_create_session(self, name: str, **kwargs) -> "Session":
        """Return the session called ``name``, creating it when missing.

        Args:
            name: Session name used as the lookup key.
            **kwargs: Optional settings applied only when a new session is
                created: ``model``, ``system_prompt`` and ``metadata``.
                Other keys are ignored.

        Returns:
            The existing or newly created session object.
        """
        existing = self._find_session(name)
        if existing:
            return existing

        defaults = {
            'model': kwargs.get('model', 'deepseek-chat'),
            'system_prompt': kwargs.get('system_prompt', '你是一个乐于助人的助手。'),
            'metadata': kwargs.get('metadata', {}),
        }
        session = self.Session.create(name=name, **defaults)

        # Store the system prompt as the first message of the new session.
        if defaults['system_prompt']:
            self.Message.create(
                session=session,
                role='system',
                content=defaults['system_prompt'],
            )

        return session

    def add_message(self, session_name: str, role: str, content: str,
                    tokens: int = 0, metadata: Optional[Dict] = None) -> str:
        """Append a message to a session, creating the session if needed.

        Args:
            session_name: Name of the target session.
            role: Message role (user/assistant/system).
            content: Message text.
            tokens: Token count recorded with the message.
            metadata: Extra data stored with the message; an empty dict is
                stored when omitted.

        Returns:
            The ``content`` that was passed in.
        """
        session = self.get_or_create_session(session_name)

        self.Message.create(
            session=session,
            role=role,
            content=content,
            tokens=tokens,
            timestamp=datetime.now(),
            metadata=metadata or {},
        )

        # Presumably refreshes the session's ``updated_at`` field; the method
        # is defined on the project's Session model.
        session.update_timestamp()

        return content

    def get_session_history(self, session_name: str, limit: int = 20) -> List[Dict]:
        """Return the most recent messages of a session, oldest first.

        The query selects the newest ``limit`` rows and then restores
        chronological order.  Ordering ascending before limiting would return
        only the oldest rows, so messages added after the first ``limit``
        would never be visible.  Because the window slides, the session's
        system prompt can fall outside it on long conversations.

        Args:
            session_name: Name of the session to read.
            limit: Maximum number of messages to return.

        Returns:
            A list of ``{'role': ..., 'content': ...}`` dicts, or an empty
            list when the session does not exist.
        """
        session = self._find_session(session_name)
        if not session:
            return []

        newest_first = (
            self.Message
            .select()
            .where(self.Message.session == session)
            .order_by(self.Message.timestamp.desc())
            .limit(limit)
        )

        return [
            {'role': msg.role, 'content': msg.content}
            for msg in reversed(list(newest_first))
        ]

    def get_session_info(self, session_name: str) -> Dict:
        """Return summary information about a session.

        Args:
            session_name: Name of the session to describe.

        Returns:
            A dict with the session's name, model, system prompt, message
            counts (total, user, assistant), ISO-formatted creation/update
            times and metadata; an empty dict when the session does not exist.
        """
        session = self._find_session(session_name)
        if not session:
            return {}

        messages = session.messages
        return {
            'name': session.name,
            'model': session.model,
            'system_prompt': session.system_prompt,
            'message_count': messages.count(),
            'user_messages': messages.where(self.Message.role == 'user').count(),
            'assistant_messages': messages.where(self.Message.role == 'assistant').count(),
            'created_at': session.created_at.isoformat(),
            'updated_at': session.updated_at.isoformat(),
            'metadata': session.metadata,
        }

    def list_sessions(self) -> List[Dict]:
        """List every session, most recently updated first.

        Returns:
            A list of dicts with ``name``, ``message_count`` and
            ``updated_at`` (ISO format) for each session.
        """
        sessions = self.Session.select().order_by(self.Session.updated_at.desc())

        # NOTE(review): ``messages.count()`` issues one extra query per
        # session; acceptable for small session lists.
        return [
            {
                'name': session.name,
                'message_count': session.messages.count(),
                'updated_at': session.updated_at.isoformat(),
            }
            for session in sessions
        ]

    def clear_session_history(self, session_name: str, keep_system: bool = True) -> bool:
        """Delete the messages of a session.

        Args:
            session_name: Name of the session to clear.
            keep_system: When true, messages with role ``system`` are kept,
                and the session's system prompt is re-added if no system
                message remains afterwards.

        Returns:
            ``False`` when the session does not exist, ``True`` otherwise.
        """
        session = self._find_session(session_name)
        if not session:
            return False

        # Build a DELETE query from the model.  ``session.messages`` is a
        # SELECT query, which (in peewee) offers no ``delete()`` method.
        removal = self.Message.delete().where(self.Message.session == session)
        if keep_system:
            removal = removal.where(self.Message.role != 'system')
        removal.execute()

        # Restore the system prompt when it should be kept but is not stored.
        if keep_system and session.system_prompt:
            system_messages = session.messages.where(self.Message.role == 'system')
            if not system_messages.exists():
                self.add_message(
                    session_name=session_name,
                    role='system',
                    content=session.system_prompt,
                )

        return True

    def close(self):
        """Close the database connection held by the database manager."""
        self.db_manager.close()