Files
mini_code/ai_main/main.py
2025-12-16 10:22:15 +08:00

492 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import uuid
from datetime import datetime
from typing import List, Dict, Optional, Any, Union
from peewee import (
Model, CharField, TextField, IntegerField,
DateTimeField, BooleanField, ForeignKeyField,
fn, SqliteDatabase
)
from playhouse.sqlite_ext import SqliteExtDatabase
from openai import OpenAI
# ==================== 数据库配置 ====================
class DatabaseManager:
"""数据库管理器"""
def __init__(self, db_path: str = 'chat_sessions.db'):
self.db = SqliteExtDatabase(
db_path,
pragmas={
'journal_mode': 'wal',
'cache_size': -1024 * 64,
'foreign_keys': 1,
'ignore_check_constraints': 0,
'synchronous': 1
}
)
self.init_models()
def init_models(self):
"""初始化数据模型"""
class BaseModel(Model):
class Meta:
database = self.db
class Session(BaseModel):
"""聊天会话模型"""
name = CharField(max_length=255, unique=True) # 会话名称唯一
model = CharField(max_length=50, default='deepseek-chat')
system_prompt = TextField(default='你是一个乐于助人的助手。')
created_at = DateTimeField(default=datetime.now)
updated_at = DateTimeField(default=datetime.now)
is_active = BooleanField(default=True)
metadata_json = TextField(default='{}')
@property
def metadata(self) -> Dict:
return json.loads(self.metadata_json) if self.metadata_json else {}
@metadata.setter
def metadata(self, value: Dict):
self.metadata_json = json.dumps(value, ensure_ascii=False)
def update_timestamp(self):
self.updated_at = datetime.now()
self.save()
class Message(BaseModel):
"""聊天消息模型"""
session = ForeignKeyField(Session, backref='messages', on_delete='CASCADE')
role = CharField(max_length=20, index=True)
content = TextField()
tokens = IntegerField(default=0)
timestamp = DateTimeField(default=datetime.now, index=True)
metadata_json = TextField(default='{}')
@property
def metadata(self) -> Dict:
return json.loads(self.metadata_json) if self.metadata_json else {}
@metadata.setter
def metadata(self, value: Dict):
self.metadata_json = json.dumps(value, ensure_ascii=False)
class Meta:
indexes = (
(('session', 'timestamp'), False),
)
# 保存模型类
self.Session = Session
self.Message = Message
# 创建表
self.db.connect()
self.db.create_tables([Session, Message], safe=True)
def close(self):
"""关闭数据库连接"""
if not self.db.is_closed():
self.db.close()
# ==================== 会话管理器 ====================
class SessionManager:
"""会话管理器"""
def __init__(self, db_path: str = None):
self.db_manager = DatabaseManager(db_path or 'chat_sessions.db')
self.Session = self.db_manager.Session
self.Message = self.db_manager.Message
def get_or_create_session(self, name: str, **kwargs) -> Any:
"""获取或创建会话"""
# 尝试获取现有会话
session = self.Session.get_or_none(self.Session.name == name)
if session:
return session
else:
# 创建新会话
defaults = {
'model': kwargs.get('model', 'deepseek-chat'),
'system_prompt': kwargs.get('system_prompt', '你是一个乐于助人的助手。'),
'metadata': kwargs.get('metadata', {})
}
session = self.Session.create(
name=name,
**defaults
)
# 添加系统提示作为第一条消息
if defaults['system_prompt']:
self.Message.create(
session=session,
role='system',
content=defaults['system_prompt']
)
return session
def add_message(self, session_name: str, role: str, content: str,
tokens: int = 0, metadata: Dict = None) -> str:
"""添加消息到会话"""
session = self.get_or_create_session(session_name)
self.Message.create(
session=session,
role=role,
content=content,
tokens=tokens,
timestamp=datetime.now(),
metadata=metadata or {}
)
# 更新会话时间戳
session.update_timestamp()
return content
def get_session_history(self, session_name: str, limit: int = 20) -> List[Dict]:
"""获取会话历史"""
session = self.Session.get_or_none(self.Session.name == session_name)
if not session:
return []
messages = (self.Message
.select()
.where(self.Message.session == session)
.order_by(self.Message.timestamp.asc())
.limit(limit))
history = []
for msg in messages:
history.append({
'role': msg.role,
'content': msg.content
})
return history
def get_session_info(self, session_name: str) -> Dict:
"""获取会话信息"""
session = self.Session.get_or_none(self.Session.name == session_name)
if not session:
return {}
messages = session.messages
message_count = messages.count()
user_messages = messages.where(self.Message.role == 'user').count()
assistant_messages = messages.where(self.Message.role == 'assistant').count()
return {
'name': session.name,
'model': session.model,
'system_prompt': session.system_prompt,
'message_count': message_count,
'user_messages': user_messages,
'assistant_messages': assistant_messages,
'created_at': session.created_at.isoformat(),
'updated_at': session.updated_at.isoformat(),
'metadata': session.metadata
}
def list_sessions(self) -> List[Dict]:
"""列出所有会话"""
sessions = self.Session.select().order_by(self.Session.updated_at.desc())
result = []
for session in sessions:
message_count = session.messages.count()
result.append({
'name': session.name,
'message_count': message_count,
'updated_at': session.updated_at.isoformat()
})
return result
def close(self):
"""关闭连接"""
self.db_manager.close()
# ==================== 简化的客户端接口 ====================
class DeepSeekChatClient:
"""简化的DeepSeek聊天客户端"""
def __init__(self, api_key: str, db_path: str = None):
"""
初始化客户端
Args:
api_key: DeepSeek API密钥
db_path: 数据库路径默认为chat_sessions.db
"""
self.client = OpenAI(api_key=api_key, base_url='https://api.deepseek.com')
self.session_manager = SessionManager(db_path)
# 缓存已加载的会话配置
self._session_cache = {}
def chat(self, user_input: str, name: str, model: str = None,
system_prompt: str = None, **kwargs) -> str:
"""
发送聊天消息
Args:
user_input: 用户输入的问题
name: 会话名称
model: 模型名称(仅在创建新会话时使用)
system_prompt: 系统提示(仅在创建新会话时使用)
**kwargs: 其他OpenAI API参数
Returns:
AI的回复内容
"""
# 获取或创建会话
session = self.session_manager.get_or_create_session(
name=name,
model=model or 'deepseek-chat',
system_prompt=system_prompt
)
# 保存用户消息
self.session_manager.add_message(
session_name=name,
role='user',
content=user_input,
tokens=len(user_input) * 0.8
)
# 获取会话历史
history = self.session_manager.get_session_history(name, limit=10)
# 构建消息列表
messages = []
# 添加系统提示(如果历史中没有)
if not any(msg['role'] == 'system' for msg in history):
messages.append({
'role': 'system',
'content': session.system_prompt
})
# 添加上下文历史
for msg in history:
messages.append({
'role': msg['role'],
'content': msg['content']
})
# 添加当前用户消息
messages.append({
'role': 'user',
'content': user_input
})
# 调用API
try:
response = self.client.chat.completions.create(
model=session.model,
messages=messages,
max_tokens=kwargs.get('max_tokens', 2000),
temperature=kwargs.get('temperature', 0.7),
stream=kwargs.get('stream', False)
)
# 处理响应
ai_reply = response.choices[0].message.content
# 保存AI回复
self.session_manager.add_message(
session_name=name,
role='assistant',
content=ai_reply,
tokens=len(ai_reply) * 0.8
)
return ai_reply
except Exception as e:
error_msg = f"API调用错误: {str(e)}"
print(error_msg)
# 保存错误信息
self.session_manager.add_message(
session_name=name,
role='system',
content=error_msg,
metadata={'error': True, 'error_type': type(e).__name__}
)
return ""
def get_session_info(self, name: str) -> Dict:
"""获取会话信息"""
return self.session_manager.get_session_info(name)
def list_sessions(self) -> List[Dict]:
"""列出所有会话"""
return self.session_manager.list_sessions()
def export_session(self, name: str) -> Dict:
"""导出会话数据"""
info = self.get_session_info(name)
history = self.session_manager.get_session_history(name, limit=100)
return {
'session_info': info,
'history': history,
'exported_at': datetime.now().isoformat()
}
def clear_session_history(self, name: str, keep_system: bool = True) -> bool:
"""清空会话历史(保留系统提示)"""
session = self.session_manager.Session.get_or_none(
self.session_manager.Session.name == name
)
if not session:
return False
# 删除消息
query = session.messages
if keep_system:
query = query.where(self.session_manager.Message.role != 'system')
query.delete().execute()
# 重新添加系统提示(如果需要)
if keep_system and session.system_prompt:
messages = session.messages.where(
self.session_manager.Message.role == 'system'
)
if not messages.exists():
self.session_manager.add_message(
session_name=name,
role='system',
content=session.system_prompt
)
return True
def close(self):
"""关闭客户端"""
self.session_manager.close()
# ==================== 使用示例 ====================
def main():
# 初始化客户端请使用环境变量管理API密钥
API_KEY = "sk-bab97a2b9be042e18d945394f8feefa3"
client = None
try:
# 创建客户端
client = DeepSeekChatClient(api_key=API_KEY, db_path='my_chats.db')
print("✅ 客户端初始化成功!")
# ========== 第一次调用:创建会话 ==========
print("\n=== 第一次调用:创建技术讨论会话 ===")
res1 = client.chat(
name='技术讨论',
model='deepseek-chat',
system_prompt='你是一个技术专家,请提供详细的解释和代码示例。',
user_input='请解释Python中的装饰器。'
)
print("回答:", res1[:100] + "..." if len(res1) > 100 else res1)
# ========== 第二次调用:使用现有会话 ==========
print("\n=== 第二次调用:使用现有技术讨论会话 ===")
res2 = client.chat(
name='技术讨论', # 使用相同名称
user_input='请给我一个装饰器的实际应用例子。' # 只需要用户输入
)
print("回答:", res2[:100] + "..." if len(res2) > 100 else res2)
# ========== 第三次调用:创建新类型会话 ==========
print("\n=== 第三次调用:创建创意写作会话 ===")
res3 = client.chat(
name='创意写作',
model='deepseek-chat',
system_prompt='你是一个创意作家,请提供富有想象力的回答。',
user_input='写一个关于人工智能的短故事开头'
)
print("回答:", res3[:100] + "..." if len(res3) > 100 else res3)
# ========== 查看会话信息 ==========
print("\n=== 查看技术讨论会话信息 ===")
session_info = client.get_session_info('技术讨论')
print(f"会话名称: {session_info['name']}")
print(f"模型: {session_info['model']}")
print(f"消息总数: {session_info['message_count']}")
print(f"用户消息: {session_info['user_messages']}")
print(f"助手消息: {session_info['assistant_messages']}")
# ========== 列出所有会话 ==========
print("\n=== 所有会话列表 ===")
sessions = client.list_sessions()
for session in sessions:
print(f" {session['name']}: {session['message_count']} 条消息")
# ========== 导出会话数据 ==========
print("\n=== 导出技术讨论会话 ===")
export_data = client.export_session('技术讨论')
print(f"导出了 {len(export_data['history'])} 条消息")
print("\n🎉 测试完成!")
except Exception as e:
print(f"❌ 程序运行出错: {e}")
import traceback
traceback.print_exc()
finally:
# 确保关闭连接
if client:
client.close()
# ==================== 简洁的使用示例 ====================
def simple_usage():
"""简洁的使用示例"""
API_KEY = "sk-bab97a2b9be042e18d945394f8feefa3"
# 1. 创建客户端
client = DeepSeekChatClient(api_key=API_KEY)
# 2. 创建新会话(第一次调用)
response1 = client.chat(
name='数学辅导',
model='deepseek-chat',
system_prompt='你是一个数学老师,请用简单易懂的方式解释数学概念。',
user_input='什么是微积分?'
)
print("第一次回答:", response1[:100])
# 3. 继续对话(只需要会话名称和用户输入)
response2 = client.chat(
name='数学辅导',
user_input='微积分有哪些实际应用?'
)
print("第二次回答:", response2[:100])
# 4. 查看会话信息
info = client.get_session_info('数学辅导')
print(f"会话 '{info['name']}'{info['message_count']} 条消息")
client.close()
if __name__ == '__main__':
# 运行完整示例
main()
# 或运行简洁示例
# simple_usage()