Files
mini_code/ai_main/client.py
2026-01-09 17:51:09 +08:00

184 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""DeepSeek聊天客户端"""
from datetime import datetime
from typing import Dict, List, Optional
from openai import OpenAI
from session_manager import SessionManager
class DeepSeekChatClient:
"""简化的DeepSeek聊天客户端"""
def __init__(self, api_key: str, db_path: str = None):
"""
初始化客户端
Args:
api_key: DeepSeek API密钥
db_path: 数据库路径默认为chat_sessions.db
"""
self.client = OpenAI(api_key=api_key, base_url='https://api.deepseek.com')
self.session_manager = SessionManager(db_path)
def chat(self, user_input: str, name: str, model: str = None,
system_prompt: str = None, **kwargs) -> str:
"""
发送聊天消息
Args:
user_input: 用户输入的问题
name: 会话名称
model: 模型名称(仅在创建新会话时使用)
system_prompt: 系统提示(仅在创建新会话时使用)
**kwargs: 其他OpenAI API参数
Returns:
AI的回复内容
"""
# 获取或创建会话
session = self.session_manager.get_or_create_session(
name=name,
model=model or 'deepseek-chat',
system_prompt=system_prompt
)
# 保存用户消息
self.session_manager.add_message(
session_name=name,
role='user',
content=user_input,
tokens=int(len(user_input) * 0.8)
)
# 构建消息列表
messages = self._build_messages(session, user_input)
# 调用API
try:
response = self.client.chat.completions.create(
model=session.model,
messages=messages,
max_tokens=kwargs.get('max_tokens', 2000),
temperature=kwargs.get('temperature', 0.7),
stream=kwargs.get('stream', False)
)
# 处理响应
ai_reply = response.choices[0].message.content
# 保存AI回复
self.session_manager.add_message(
session_name=name,
role='assistant',
content=ai_reply,
tokens=int(len(ai_reply) * 0.8)
)
return ai_reply
except Exception as e:
error_msg = f"API调用错误: {str(e)}"
print(error_msg)
# 保存错误信息
self.session_manager.add_message(
session_name=name,
role='system',
content=error_msg,
metadata={'error': True, 'error_type': type(e).__name__}
)
return ""
def _build_messages(self, session, user_input: str) -> List[Dict]:
"""
构建消息列表
Args:
session: 会话对象
user_input: 用户输入
Returns:
消息列表
"""
# 获取会话历史
history = self.session_manager.get_session_history(name=session.name, limit=10)
messages = []
# 添加系统提示(如果历史中没有)
if not any(msg['role'] == 'system' for msg in history):
messages.append({
'role': 'system',
'content': session.system_prompt
})
# 添加上下文历史
messages.extend(history)
# 添加当前用户消息
messages.append({
'role': 'user',
'content': user_input
})
return messages
def get_session_info(self, name: str) -> Dict:
"""
获取会话信息
Args:
name: 会话名称
Returns:
会话信息字典
"""
return self.session_manager.get_session_info(name)
def list_sessions(self) -> List[Dict]:
"""
列出所有会话
Returns:
会话列表
"""
return self.session_manager.list_sessions()
def export_session(self, name: str) -> Dict:
"""
导出会话数据
Args:
name: 会话名称
Returns:
导出的会话数据
"""
info = self.get_session_info(name)
history = self.session_manager.get_session_history(name, limit=100)
return {
'session_info': info,
'history': history,
'exported_at': datetime.now().isoformat()
}
def clear_session_history(self, name: str, keep_system: bool = True) -> bool:
"""
清空会话历史(保留系统提示)
Args:
name: 会话名称
keep_system: 是否保留系统提示
Returns:
是否成功
"""
return self.session_manager.clear_session_history(name, keep_system)
def close(self):
"""关闭客户端"""
self.session_manager.close()