diff --git a/ai_main/chat_history.db b/ai_main/chat_history.db new file mode 100644 index 000000000..3049b9d33 Binary files /dev/null and b/ai_main/chat_history.db differ diff --git a/ai_main/main.py b/ai_main/main.py new file mode 100644 index 000000000..9c9e9f4d9 --- /dev/null +++ b/ai_main/main.py @@ -0,0 +1,492 @@ +import json +import os +import uuid +from datetime import datetime +from typing import List, Dict, Optional, Any, Union + +from peewee import ( + Model, CharField, TextField, IntegerField, + DateTimeField, BooleanField, ForeignKeyField, + fn, SqliteDatabase +) +from playhouse.sqlite_ext import SqliteExtDatabase +from openai import OpenAI + + +# ==================== 数据库配置 ==================== +class DatabaseManager: + """数据库管理器""" + + def __init__(self, db_path: str = 'chat_sessions.db'): + self.db = SqliteExtDatabase( + db_path, + pragmas={ + 'journal_mode': 'wal', + 'cache_size': -1024 * 64, + 'foreign_keys': 1, + 'ignore_check_constraints': 0, + 'synchronous': 1 + } + ) + self.init_models() + + def init_models(self): + """初始化数据模型""" + + class BaseModel(Model): + class Meta: + database = self.db + + class Session(BaseModel): + """聊天会话模型""" + name = CharField(max_length=255, unique=True) # 会话名称唯一 + model = CharField(max_length=50, default='deepseek-chat') + system_prompt = TextField(default='你是一个乐于助人的助手。') + created_at = DateTimeField(default=datetime.now) + updated_at = DateTimeField(default=datetime.now) + is_active = BooleanField(default=True) + metadata_json = TextField(default='{}') + + @property + def metadata(self) -> Dict: + return json.loads(self.metadata_json) if self.metadata_json else {} + + @metadata.setter + def metadata(self, value: Dict): + self.metadata_json = json.dumps(value, ensure_ascii=False) + + def update_timestamp(self): + self.updated_at = datetime.now() + self.save() + + class Message(BaseModel): + """聊天消息模型""" + session = ForeignKeyField(Session, backref='messages', on_delete='CASCADE') + role = CharField(max_length=20, index=True) + content = TextField() + tokens = IntegerField(default=0) + timestamp = DateTimeField(default=datetime.now, index=True) + metadata_json = TextField(default='{}') + + @property + def metadata(self) -> Dict: + return json.loads(self.metadata_json) if self.metadata_json else {} + + @metadata.setter + def metadata(self, value: Dict): + self.metadata_json = json.dumps(value, ensure_ascii=False) + + class Meta: + indexes = ( + (('session', 'timestamp'), False), + ) + + # 保存模型类 + self.Session = Session + self.Message = Message + + # 创建表 + self.db.connect() + self.db.create_tables([Session, Message], safe=True) + + def close(self): + """关闭数据库连接""" + if not self.db.is_closed(): + self.db.close() + + +# ==================== 会话管理器 ==================== +class SessionManager: + """会话管理器""" + + def __init__(self, db_path: str = None): + self.db_manager = DatabaseManager(db_path or 'chat_sessions.db') + self.Session = self.db_manager.Session + self.Message = self.db_manager.Message + + def get_or_create_session(self, name: str, **kwargs) -> Any: + """获取或创建会话""" + # 尝试获取现有会话 + session = self.Session.get_or_none(self.Session.name == name) + + if session: + return session + else: + # 创建新会话 + defaults = { + 'model': kwargs.get('model', 'deepseek-chat'), + 'system_prompt': kwargs.get('system_prompt', '你是一个乐于助人的助手。'), + 'metadata': kwargs.get('metadata', {}) + } + + session = self.Session.create( + name=name, + **defaults + ) + + # 添加系统提示作为第一条消息 + if defaults['system_prompt']: + self.Message.create( + session=session, + role='system', + content=defaults['system_prompt'] + ) + + return session + + def add_message(self, session_name: str, role: str, content: str, + tokens: int = 0, metadata: Dict = None) -> str: + """添加消息到会话""" + session = self.get_or_create_session(session_name) + + self.Message.create( + session=session, + role=role, + content=content, + tokens=tokens, + timestamp=datetime.now(), + metadata=metadata or {} + ) + + # 更新会话时间戳 + session.update_timestamp() + + return content + + def get_session_history(self, session_name: str, limit: int = 20) -> List[Dict]: + """获取会话历史""" + session = self.Session.get_or_none(self.Session.name == session_name) + if not session: + return [] + + messages = (self.Message + .select() + .where(self.Message.session == session) + .order_by(self.Message.timestamp.asc()) + .limit(limit)) + + history = [] + for msg in messages: + history.append({ + 'role': msg.role, + 'content': msg.content + }) + + return history + + def get_session_info(self, session_name: str) -> Dict: + """获取会话信息""" + session = self.Session.get_or_none(self.Session.name == session_name) + if not session: + return {} + + messages = session.messages + message_count = messages.count() + user_messages = messages.where(self.Message.role == 'user').count() + assistant_messages = messages.where(self.Message.role == 'assistant').count() + + return { + 'name': session.name, + 'model': session.model, + 'system_prompt': session.system_prompt, + 'message_count': message_count, + 'user_messages': user_messages, + 'assistant_messages': assistant_messages, + 'created_at': session.created_at.isoformat(), + 'updated_at': session.updated_at.isoformat(), + 'metadata': session.metadata + } + + def list_sessions(self) -> List[Dict]: + """列出所有会话""" + sessions = self.Session.select().order_by(self.Session.updated_at.desc()) + + result = [] + for session in sessions: + message_count = session.messages.count() + result.append({ + 'name': session.name, + 'message_count': message_count, + 'updated_at': session.updated_at.isoformat() + }) + + return result + + def close(self): + """关闭连接""" + self.db_manager.close() + + +# ==================== 简化的客户端接口 ==================== +class DeepSeekChatClient: + """简化的DeepSeek聊天客户端""" + + def __init__(self, api_key: str, db_path: str = None): + """ + 初始化客户端 + + Args: + api_key: DeepSeek API密钥 + db_path: 数据库路径,默认为chat_sessions.db + """ + self.client = OpenAI(api_key=api_key, base_url='https://api.deepseek.com') + self.session_manager = SessionManager(db_path) + + # 缓存已加载的会话配置 + self._session_cache = {} + + def chat(self, user_input: str, name: str, model: str = None, + system_prompt: str = None, **kwargs) -> str: + """ + 发送聊天消息 + + Args: + user_input: 用户输入的问题 + name: 会话名称 + model: 模型名称(仅在创建新会话时使用) + system_prompt: 系统提示(仅在创建新会话时使用) + **kwargs: 其他OpenAI API参数 + + Returns: + AI的回复内容 + """ + # 获取或创建会话 + session = self.session_manager.get_or_create_session( + name=name, + model=model or 'deepseek-chat', + system_prompt=system_prompt + ) + + # 保存用户消息 + self.session_manager.add_message( + session_name=name, + role='user', + content=user_input, + tokens=len(user_input) * 0.8 + ) + + # 获取会话历史 + history = self.session_manager.get_session_history(name, limit=10) + + # 构建消息列表 + messages = [] + + # 添加系统提示(如果历史中没有) + if not any(msg['role'] == 'system' for msg in history): + messages.append({ + 'role': 'system', + 'content': session.system_prompt + }) + + # 添加上下文历史 + for msg in history: + messages.append({ + 'role': msg['role'], + 'content': msg['content'] + }) + + # 添加当前用户消息 + messages.append({ + 'role': 'user', + 'content': user_input + }) + + # 调用API + try: + response = self.client.chat.completions.create( + model=session.model, + messages=messages, + max_tokens=kwargs.get('max_tokens', 2000), + temperature=kwargs.get('temperature', 0.7), + stream=kwargs.get('stream', False) + ) + + # 处理响应 + ai_reply = response.choices[0].message.content + + # 保存AI回复 + self.session_manager.add_message( + session_name=name, + role='assistant', + content=ai_reply, + tokens=len(ai_reply) * 0.8 + ) + + return ai_reply + + except Exception as e: + error_msg = f"API调用错误: {str(e)}" + print(error_msg) + + # 保存错误信息 + self.session_manager.add_message( + session_name=name, + role='system', + content=error_msg, + metadata={'error': True, 'error_type': type(e).__name__} + ) + + return "" + + def get_session_info(self, name: str) -> Dict: + """获取会话信息""" + return self.session_manager.get_session_info(name) + + def list_sessions(self) -> List[Dict]: + """列出所有会话""" + return self.session_manager.list_sessions() + + def export_session(self, name: str) -> Dict: + """导出会话数据""" + info = self.get_session_info(name) + history = self.session_manager.get_session_history(name, limit=100) + + return { + 'session_info': info, + 'history': history, + 'exported_at': datetime.now().isoformat() + } + + def clear_session_history(self, name: str, keep_system: bool = True) -> bool: + """清空会话历史(保留系统提示)""" + session = self.session_manager.Session.get_or_none( + self.session_manager.Session.name == name + ) + + if not session: + return False + + # 删除消息 + query = session.messages + if keep_system: + query = query.where(self.session_manager.Message.role != 'system') + + query.delete().execute() + + # 重新添加系统提示(如果需要) + if keep_system and session.system_prompt: + messages = session.messages.where( + self.session_manager.Message.role == 'system' + ) + if not messages.exists(): + self.session_manager.add_message( + session_name=name, + role='system', + content=session.system_prompt + ) + + return True + + def close(self): + """关闭客户端""" + self.session_manager.close() + + +# ==================== 使用示例 ==================== +def main(): + # 初始化客户端(请使用环境变量管理API密钥) + API_KEY = "sk-bab97a2b9be042e18d945394f8feefa3" + + client = None + try: + # 创建客户端 + client = DeepSeekChatClient(api_key=API_KEY, db_path='my_chats.db') + print("✅ 客户端初始化成功!") + + # ========== 第一次调用:创建会话 ========== + print("\n=== 第一次调用:创建技术讨论会话 ===") + res1 = client.chat( + name='技术讨论', + model='deepseek-chat', + system_prompt='你是一个技术专家,请提供详细的解释和代码示例。', + user_input='请解释Python中的装饰器。' + ) + print("回答:", res1[:100] + "..." if len(res1) > 100 else res1) + + # ========== 第二次调用:使用现有会话 ========== + print("\n=== 第二次调用:使用现有技术讨论会话 ===") + res2 = client.chat( + name='技术讨论', # 使用相同名称 + user_input='请给我一个装饰器的实际应用例子。' # 只需要用户输入 + ) + print("回答:", res2[:100] + "..." if len(res2) > 100 else res2) + + # ========== 第三次调用:创建新类型会话 ========== + print("\n=== 第三次调用:创建创意写作会话 ===") + res3 = client.chat( + name='创意写作', + model='deepseek-chat', + system_prompt='你是一个创意作家,请提供富有想象力的回答。', + user_input='写一个关于人工智能的短故事开头' + ) + print("回答:", res3[:100] + "..." if len(res3) > 100 else res3) + + # ========== 查看会话信息 ========== + print("\n=== 查看技术讨论会话信息 ===") + session_info = client.get_session_info('技术讨论') + print(f"会话名称: {session_info['name']}") + print(f"模型: {session_info['model']}") + print(f"消息总数: {session_info['message_count']}") + print(f"用户消息: {session_info['user_messages']}") + print(f"助手消息: {session_info['assistant_messages']}") + + # ========== 列出所有会话 ========== + print("\n=== 所有会话列表 ===") + sessions = client.list_sessions() + for session in sessions: + print(f" {session['name']}: {session['message_count']} 条消息") + + # ========== 导出会话数据 ========== + print("\n=== 导出技术讨论会话 ===") + export_data = client.export_session('技术讨论') + print(f"导出了 {len(export_data['history'])} 条消息") + + print("\n🎉 测试完成!") + + except Exception as e: + print(f"❌ 程序运行出错: {e}") + import traceback + traceback.print_exc() + finally: + # 确保关闭连接 + if client: + client.close() + + +# ==================== 简洁的使用示例 ==================== +def simple_usage(): + """简洁的使用示例""" + API_KEY = "sk-bab97a2b9be042e18d945394f8feefa3" + + # 1. 创建客户端 + client = DeepSeekChatClient(api_key=API_KEY) + + # 2. 创建新会话(第一次调用) + response1 = client.chat( + name='数学辅导', + model='deepseek-chat', + system_prompt='你是一个数学老师,请用简单易懂的方式解释数学概念。', + user_input='什么是微积分?' + ) + print("第一次回答:", response1[:100]) + + # 3. 继续对话(只需要会话名称和用户输入) + response2 = client.chat( + name='数学辅导', + user_input='微积分有哪些实际应用?' + ) + print("第二次回答:", response2[:100]) + + # 4. 查看会话信息 + info = client.get_session_info('数学辅导') + print(f"会话 '{info['name']}' 有 {info['message_count']} 条消息") + + client.close() + + +if __name__ == '__main__': + # 运行完整示例 + main() + + # 或运行简洁示例 + # simple_usage() \ No newline at end of file diff --git a/ai_main/my_chats.db b/ai_main/my_chats.db new file mode 100644 index 000000000..09f0cc924 Binary files /dev/null and b/ai_main/my_chats.db differ