Files
mini_code/ai_main/session_manager.py
2026-01-09 17:51:09 +08:00

204 lines
5.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""会话管理器"""
from datetime import datetime
from typing import List, Dict, Any
from database import DatabaseManager, Session, Message
class SessionManager:
    """Manage chat sessions and their messages on top of the ``database`` models.

    NOTE(review): the query API used here (``get_or_none``, ``create``,
    ``select().where()``, the ``session.messages`` back-reference) looks like
    the peewee ORM -- confirm against the ``database`` module.
    """

    def __init__(self, db_path: str = None):
        """Create the database manager and keep handles to the model classes.

        Args:
            db_path: Database path. Falls back to ``'chat_sessions.db'`` when
                omitted or empty.
        """
        self.db_manager = DatabaseManager(db_path or 'chat_sessions.db')
        self.Session = Session
        self.Message = Message

    def get_or_create_session(self, name: str, **kwargs) -> "Session":
        """Return the session called ``name``, creating it if it does not exist.

        Args:
            name: Session name used for the lookup.
            **kwargs: Optional ``model``, ``system_prompt`` and ``metadata``
                values. They are only used when a new session is created; an
                existing session is returned unchanged.

        Returns:
            The existing or newly created Session object.
        """
        # Reuse the existing session when there is one.
        session = self.Session.get_or_none(self.Session.name == name)
        if session:
            return session

        # Otherwise create it, filling in defaults for anything not supplied.
        defaults = {
            'model': kwargs.get('model', 'deepseek-chat'),
            'system_prompt': kwargs.get('system_prompt', '你是一个乐于助人的助手。'),
            'metadata': kwargs.get('metadata', {}),
        }
        session = self.Session.create(name=name, **defaults)

        # Store a non-empty system prompt as the session's first message.
        if defaults['system_prompt']:
            self.Message.create(
                session=session,
                role='system',
                content=defaults['system_prompt'],
            )
        return session

    def add_message(self, session_name: str, role: str, content: str,
                    tokens: int = 0, metadata: Dict = None) -> str:
        """Append a message to a session, creating the session if needed.

        Args:
            session_name: Name of the target session.
            role: Message role (user / assistant / system).
            content: Message text.
            tokens: Token count stored with the message.
            metadata: Extra data stored with the message; an empty dict is
                stored when omitted.

        Returns:
            The ``content`` argument, unchanged.
        """
        session = self.get_or_create_session(session_name)
        self.Message.create(
            session=session,
            role=role,
            content=content,
            tokens=tokens,
            timestamp=datetime.now(),
            metadata=metadata or {},
        )
        # Let the session record that it was just modified.
        session.update_timestamp()
        return content

    def get_session_history(self, session_name: str, limit: int = 20) -> List[Dict]:
        """Return up to ``limit`` messages of a session in timestamp order.

        NOTE(review): the query sorts ascending and then applies ``limit``, so
        it yields the *earliest* ``limit`` messages. Once a session holds more
        messages than ``limit``, newer ones are not returned -- confirm this
        is the intended window.

        Args:
            session_name: Name of the session.
            limit: Maximum number of messages to return.

        Returns:
            A list of ``{'role': ..., 'content': ...}`` dicts, or an empty
            list when the session does not exist.
        """
        session = self.Session.get_or_none(self.Session.name == session_name)
        if not session:
            return []

        messages = (self.Message
                    .select()
                    .where(self.Message.session == session)
                    .order_by(self.Message.timestamp.asc())
                    .limit(limit))
        return [{'role': msg.role, 'content': msg.content} for msg in messages]

    def get_session_info(self, session_name: str) -> Dict:
        """Return a summary of one session.

        Args:
            session_name: Name of the session.

        Returns:
            A dict with the session's name, model, system prompt, message
            counts (total / user / assistant), ISO-formatted creation and
            update times and metadata; an empty dict when the session does
            not exist.
        """
        session = self.Session.get_or_none(self.Session.name == session_name)
        if not session:
            return {}

        # Each count starts from a fresh back-reference query so that the role
        # filters cannot accumulate on a shared query object.
        message_count = session.messages.count()
        user_messages = session.messages.where(
            self.Message.role == 'user').count()
        assistant_messages = session.messages.where(
            self.Message.role == 'assistant').count()

        return {
            'name': session.name,
            'model': session.model,
            'system_prompt': session.system_prompt,
            'message_count': message_count,
            'user_messages': user_messages,
            'assistant_messages': assistant_messages,
            'created_at': session.created_at.isoformat(),
            'updated_at': session.updated_at.isoformat(),
            'metadata': session.metadata,
        }

    def list_sessions(self) -> List[Dict]:
        """List all sessions, most recently updated first.

        Returns:
            A list of dicts holding each session's name, message count and
            ISO-formatted ``updated_at`` value.
        """
        sessions = self.Session.select().order_by(self.Session.updated_at.desc())
        # The message count is computed per session inside the loop.
        return [
            {
                'name': session.name,
                'message_count': session.messages.count(),
                'updated_at': session.updated_at.isoformat(),
            }
            for session in sessions
        ]

    def clear_session_history(self, session_name: str, keep_system: bool = True) -> bool:
        """Delete the messages of a session.

        Args:
            session_name: Name of the session.
            keep_system: When True, messages whose role is ``'system'`` are
                kept, and the session's system prompt is re-added if no system
                message is left afterwards.

        Returns:
            False when the session does not exist, True otherwise.
        """
        session = self.Session.get_or_none(self.Session.name == session_name)
        if not session:
            return False

        # Delete through the model class. The previous code called .delete()
        # on the ``session.messages`` select query; assuming the models are
        # peewee models, a select query has no delete() method, so that call
        # raised AttributeError instead of removing anything.
        conditions = [self.Message.session == session]
        if keep_system:
            conditions.append(self.Message.role != 'system')
        self.Message.delete().where(*conditions).execute()

        # Make sure a kept system prompt is actually present as a message.
        if keep_system and session.system_prompt:
            system_messages = session.messages.where(self.Message.role == 'system')
            if not system_messages.exists():
                self.add_message(
                    session_name=session_name,
                    role='system',
                    content=session.system_prompt,
                )
        return True

    def close(self):
        """Close the underlying database manager."""
        self.db_manager.close()