fewfefffwefweef
This commit is contained in:
203
ai_main/session_manager.py
Normal file
203
ai_main/session_manager.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""会话管理器"""
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from database import DatabaseManager, Session, Message
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""会话管理器"""
|
||||
|
||||
def __init__(self, db_path: str = None):
|
||||
"""
|
||||
初始化会话管理器
|
||||
|
||||
Args:
|
||||
db_path: 数据库路径,默认为 chat_sessions.db
|
||||
"""
|
||||
self.db_manager = DatabaseManager(db_path or 'chat_sessions.db')
|
||||
self.Session = Session
|
||||
self.Message = Message
|
||||
|
||||
def get_or_create_session(self, name: str, **kwargs) -> Session:
|
||||
"""
|
||||
获取或创建会话
|
||||
|
||||
Args:
|
||||
name: 会话名称
|
||||
**kwargs: 会话参数(model, system_prompt, metadata)
|
||||
|
||||
Returns:
|
||||
Session对象
|
||||
"""
|
||||
# 尝试获取现有会话
|
||||
session = self.Session.get_or_none(self.Session.name == name)
|
||||
|
||||
if session:
|
||||
return session
|
||||
|
||||
# 创建新会话
|
||||
defaults = {
|
||||
'model': kwargs.get('model', 'deepseek-chat'),
|
||||
'system_prompt': kwargs.get('system_prompt', '你是一个乐于助人的助手。'),
|
||||
'metadata': kwargs.get('metadata', {})
|
||||
}
|
||||
|
||||
session = self.Session.create(
|
||||
name=name,
|
||||
**defaults
|
||||
)
|
||||
|
||||
# 添加系统提示作为第一条消息
|
||||
if defaults['system_prompt']:
|
||||
self.Message.create(
|
||||
session=session,
|
||||
role='system',
|
||||
content=defaults['system_prompt']
|
||||
)
|
||||
|
||||
return session
|
||||
|
||||
def add_message(self, session_name: str, role: str, content: str,
|
||||
tokens: int = 0, metadata: Dict = None) -> str:
|
||||
"""
|
||||
添加消息到会话
|
||||
|
||||
Args:
|
||||
session_name: 会话名称
|
||||
role: 消息角色(user/assistant/system)
|
||||
content: 消息内容
|
||||
tokens: token数量
|
||||
metadata: 元数据
|
||||
|
||||
Returns:
|
||||
消息内容
|
||||
"""
|
||||
session = self.get_or_create_session(session_name)
|
||||
|
||||
self.Message.create(
|
||||
session=session,
|
||||
role=role,
|
||||
content=content,
|
||||
tokens=tokens,
|
||||
timestamp=datetime.now(),
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
# 更新会话时间戳
|
||||
session.update_timestamp()
|
||||
|
||||
return content
|
||||
|
||||
def get_session_history(self, session_name: str, limit: int = 20) -> List[Dict]:
|
||||
"""
|
||||
获取会话历史
|
||||
|
||||
Args:
|
||||
session_name: 会话名称
|
||||
limit: 返回消息数量限制
|
||||
|
||||
Returns:
|
||||
消息历史列表
|
||||
"""
|
||||
session = self.Session.get_or_none(self.Session.name == session_name)
|
||||
if not session:
|
||||
return []
|
||||
|
||||
messages = (self.Message
|
||||
.select()
|
||||
.where(self.Message.session == session)
|
||||
.order_by(self.Message.timestamp.asc())
|
||||
.limit(limit))
|
||||
|
||||
return [
|
||||
{'role': msg.role, 'content': msg.content}
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
def get_session_info(self, session_name: str) -> Dict:
|
||||
"""
|
||||
获取会话信息
|
||||
|
||||
Args:
|
||||
session_name: 会话名称
|
||||
|
||||
Returns:
|
||||
会话信息字典
|
||||
"""
|
||||
session = self.Session.get_or_none(self.Session.name == session_name)
|
||||
if not session:
|
||||
return {}
|
||||
|
||||
messages = session.messages
|
||||
message_count = messages.count()
|
||||
user_messages = messages.where(self.Message.role == 'user').count()
|
||||
assistant_messages = messages.where(self.Message.role == 'assistant').count()
|
||||
|
||||
return {
|
||||
'name': session.name,
|
||||
'model': session.model,
|
||||
'system_prompt': session.system_prompt,
|
||||
'message_count': message_count,
|
||||
'user_messages': user_messages,
|
||||
'assistant_messages': assistant_messages,
|
||||
'created_at': session.created_at.isoformat(),
|
||||
'updated_at': session.updated_at.isoformat(),
|
||||
'metadata': session.metadata
|
||||
}
|
||||
|
||||
def list_sessions(self) -> List[Dict]:
|
||||
"""
|
||||
列出所有会话
|
||||
|
||||
Returns:
|
||||
会话列表
|
||||
"""
|
||||
sessions = self.Session.select().order_by(self.Session.updated_at.desc())
|
||||
|
||||
return [
|
||||
{
|
||||
'name': session.name,
|
||||
'message_count': session.messages.count(),
|
||||
'updated_at': session.updated_at.isoformat()
|
||||
}
|
||||
for session in sessions
|
||||
]
|
||||
|
||||
def clear_session_history(self, session_name: str, keep_system: bool = True) -> bool:
|
||||
"""
|
||||
清空会话历史(保留系统提示)
|
||||
|
||||
Args:
|
||||
session_name: 会话名称
|
||||
keep_system: 是否保留系统提示
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
session = self.Session.get_or_none(self.Session.name == session_name)
|
||||
if not session:
|
||||
return False
|
||||
|
||||
# 删除消息
|
||||
query = session.messages
|
||||
if keep_system:
|
||||
query = query.where(self.Message.role != 'system')
|
||||
|
||||
query.delete().execute()
|
||||
|
||||
# 重新添加系统提示(如果需要)
|
||||
if keep_system and session.system_prompt:
|
||||
messages = session.messages.where(self.Message.role == 'system')
|
||||
if not messages.exists():
|
||||
self.add_message(
|
||||
session_name=session_name,
|
||||
role='system',
|
||||
content=session.system_prompt
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def close(self):
|
||||
"""关闭数据库连接"""
|
||||
self.db_manager.close()
|
||||
Reference in New Issue
Block a user