184 lines
4.9 KiB
Python
184 lines
4.9 KiB
Python
|
|
"""DeepSeek聊天客户端"""
|
|||
|
|
from datetime import datetime
|
|||
|
|
from typing import Dict, List, Optional
|
|||
|
|
|
|||
|
|
from openai import OpenAI
|
|||
|
|
|
|||
|
|
from session_manager import SessionManager
|
|||
|
|
|
|||
|
|
|
|||
|
|
class DeepSeekChatClient:
|
|||
|
|
"""简化的DeepSeek聊天客户端"""
|
|||
|
|
|
|||
|
|
def __init__(self, api_key: str, db_path: str = None):
|
|||
|
|
"""
|
|||
|
|
初始化客户端
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
api_key: DeepSeek API密钥
|
|||
|
|
db_path: 数据库路径,默认为chat_sessions.db
|
|||
|
|
"""
|
|||
|
|
self.client = OpenAI(api_key=api_key, base_url='https://api.deepseek.com')
|
|||
|
|
self.session_manager = SessionManager(db_path)
|
|||
|
|
|
|||
|
|
def chat(self, user_input: str, name: str, model: str = None,
|
|||
|
|
system_prompt: str = None, **kwargs) -> str:
|
|||
|
|
"""
|
|||
|
|
发送聊天消息
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
user_input: 用户输入的问题
|
|||
|
|
name: 会话名称
|
|||
|
|
model: 模型名称(仅在创建新会话时使用)
|
|||
|
|
system_prompt: 系统提示(仅在创建新会话时使用)
|
|||
|
|
**kwargs: 其他OpenAI API参数
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
AI的回复内容
|
|||
|
|
"""
|
|||
|
|
# 获取或创建会话
|
|||
|
|
session = self.session_manager.get_or_create_session(
|
|||
|
|
name=name,
|
|||
|
|
model=model or 'deepseek-chat',
|
|||
|
|
system_prompt=system_prompt
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 保存用户消息
|
|||
|
|
self.session_manager.add_message(
|
|||
|
|
session_name=name,
|
|||
|
|
role='user',
|
|||
|
|
content=user_input,
|
|||
|
|
tokens=int(len(user_input) * 0.8)
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 构建消息列表
|
|||
|
|
messages = self._build_messages(session, user_input)
|
|||
|
|
|
|||
|
|
# 调用API
|
|||
|
|
try:
|
|||
|
|
response = self.client.chat.completions.create(
|
|||
|
|
model=session.model,
|
|||
|
|
messages=messages,
|
|||
|
|
max_tokens=kwargs.get('max_tokens', 2000),
|
|||
|
|
temperature=kwargs.get('temperature', 0.7),
|
|||
|
|
stream=kwargs.get('stream', False)
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 处理响应
|
|||
|
|
ai_reply = response.choices[0].message.content
|
|||
|
|
|
|||
|
|
# 保存AI回复
|
|||
|
|
self.session_manager.add_message(
|
|||
|
|
session_name=name,
|
|||
|
|
role='assistant',
|
|||
|
|
content=ai_reply,
|
|||
|
|
tokens=int(len(ai_reply) * 0.8)
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
return ai_reply
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
error_msg = f"API调用错误: {str(e)}"
|
|||
|
|
print(error_msg)
|
|||
|
|
|
|||
|
|
# 保存错误信息
|
|||
|
|
self.session_manager.add_message(
|
|||
|
|
session_name=name,
|
|||
|
|
role='system',
|
|||
|
|
content=error_msg,
|
|||
|
|
metadata={'error': True, 'error_type': type(e).__name__}
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
return ""
|
|||
|
|
|
|||
|
|
def _build_messages(self, session, user_input: str) -> List[Dict]:
|
|||
|
|
"""
|
|||
|
|
构建消息列表
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
session: 会话对象
|
|||
|
|
user_input: 用户输入
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
消息列表
|
|||
|
|
"""
|
|||
|
|
# 获取会话历史
|
|||
|
|
history = self.session_manager.get_session_history(name=session.name, limit=10)
|
|||
|
|
|
|||
|
|
messages = []
|
|||
|
|
|
|||
|
|
# 添加系统提示(如果历史中没有)
|
|||
|
|
if not any(msg['role'] == 'system' for msg in history):
|
|||
|
|
messages.append({
|
|||
|
|
'role': 'system',
|
|||
|
|
'content': session.system_prompt
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
# 添加上下文历史
|
|||
|
|
messages.extend(history)
|
|||
|
|
|
|||
|
|
# 添加当前用户消息
|
|||
|
|
messages.append({
|
|||
|
|
'role': 'user',
|
|||
|
|
'content': user_input
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
return messages
|
|||
|
|
|
|||
|
|
def get_session_info(self, name: str) -> Dict:
|
|||
|
|
"""
|
|||
|
|
获取会话信息
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
name: 会话名称
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
会话信息字典
|
|||
|
|
"""
|
|||
|
|
return self.session_manager.get_session_info(name)
|
|||
|
|
|
|||
|
|
def list_sessions(self) -> List[Dict]:
|
|||
|
|
"""
|
|||
|
|
列出所有会话
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
会话列表
|
|||
|
|
"""
|
|||
|
|
return self.session_manager.list_sessions()
|
|||
|
|
|
|||
|
|
def export_session(self, name: str) -> Dict:
|
|||
|
|
"""
|
|||
|
|
导出会话数据
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
name: 会话名称
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
导出的会话数据
|
|||
|
|
"""
|
|||
|
|
info = self.get_session_info(name)
|
|||
|
|
history = self.session_manager.get_session_history(name, limit=100)
|
|||
|
|
|
|||
|
|
return {
|
|||
|
|
'session_info': info,
|
|||
|
|
'history': history,
|
|||
|
|
'exported_at': datetime.now().isoformat()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
def clear_session_history(self, name: str, keep_system: bool = True) -> bool:
|
|||
|
|
"""
|
|||
|
|
清空会话历史(保留系统提示)
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
name: 会话名称
|
|||
|
|
keep_system: 是否保留系统提示
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
是否成功
|
|||
|
|
"""
|
|||
|
|
return self.session_manager.clear_session_history(name, keep_system)
|
|||
|
|
|
|||
|
|
def close(self):
|
|||
|
|
"""关闭客户端"""
|
|||
|
|
self.session_manager.close()
|