ggrg
This commit is contained in:
BIN
ai_main/chat_history.db
Normal file
BIN
ai_main/chat_history.db
Normal file
Binary file not shown.
492
ai_main/main.py
Normal file
492
ai_main/main.py
Normal file
@@ -0,0 +1,492 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Optional, Any, Union
|
||||
|
||||
from peewee import (
|
||||
Model, CharField, TextField, IntegerField,
|
||||
DateTimeField, BooleanField, ForeignKeyField,
|
||||
fn, SqliteDatabase
|
||||
)
|
||||
from playhouse.sqlite_ext import SqliteExtDatabase
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
# ==================== 数据库配置 ====================
|
||||
class DatabaseManager:
|
||||
"""数据库管理器"""
|
||||
|
||||
def __init__(self, db_path: str = 'chat_sessions.db'):
|
||||
self.db = SqliteExtDatabase(
|
||||
db_path,
|
||||
pragmas={
|
||||
'journal_mode': 'wal',
|
||||
'cache_size': -1024 * 64,
|
||||
'foreign_keys': 1,
|
||||
'ignore_check_constraints': 0,
|
||||
'synchronous': 1
|
||||
}
|
||||
)
|
||||
self.init_models()
|
||||
|
||||
def init_models(self):
|
||||
"""初始化数据模型"""
|
||||
|
||||
class BaseModel(Model):
|
||||
class Meta:
|
||||
database = self.db
|
||||
|
||||
class Session(BaseModel):
|
||||
"""聊天会话模型"""
|
||||
name = CharField(max_length=255, unique=True) # 会话名称唯一
|
||||
model = CharField(max_length=50, default='deepseek-chat')
|
||||
system_prompt = TextField(default='你是一个乐于助人的助手。')
|
||||
created_at = DateTimeField(default=datetime.now)
|
||||
updated_at = DateTimeField(default=datetime.now)
|
||||
is_active = BooleanField(default=True)
|
||||
metadata_json = TextField(default='{}')
|
||||
|
||||
@property
|
||||
def metadata(self) -> Dict:
|
||||
return json.loads(self.metadata_json) if self.metadata_json else {}
|
||||
|
||||
@metadata.setter
|
||||
def metadata(self, value: Dict):
|
||||
self.metadata_json = json.dumps(value, ensure_ascii=False)
|
||||
|
||||
def update_timestamp(self):
|
||||
self.updated_at = datetime.now()
|
||||
self.save()
|
||||
|
||||
class Message(BaseModel):
|
||||
"""聊天消息模型"""
|
||||
session = ForeignKeyField(Session, backref='messages', on_delete='CASCADE')
|
||||
role = CharField(max_length=20, index=True)
|
||||
content = TextField()
|
||||
tokens = IntegerField(default=0)
|
||||
timestamp = DateTimeField(default=datetime.now, index=True)
|
||||
metadata_json = TextField(default='{}')
|
||||
|
||||
@property
|
||||
def metadata(self) -> Dict:
|
||||
return json.loads(self.metadata_json) if self.metadata_json else {}
|
||||
|
||||
@metadata.setter
|
||||
def metadata(self, value: Dict):
|
||||
self.metadata_json = json.dumps(value, ensure_ascii=False)
|
||||
|
||||
class Meta:
|
||||
indexes = (
|
||||
(('session', 'timestamp'), False),
|
||||
)
|
||||
|
||||
# 保存模型类
|
||||
self.Session = Session
|
||||
self.Message = Message
|
||||
|
||||
# 创建表
|
||||
self.db.connect()
|
||||
self.db.create_tables([Session, Message], safe=True)
|
||||
|
||||
def close(self):
|
||||
"""关闭数据库连接"""
|
||||
if not self.db.is_closed():
|
||||
self.db.close()
|
||||
|
||||
|
||||
# ==================== 会话管理器 ====================
|
||||
class SessionManager:
|
||||
"""会话管理器"""
|
||||
|
||||
def __init__(self, db_path: str = None):
|
||||
self.db_manager = DatabaseManager(db_path or 'chat_sessions.db')
|
||||
self.Session = self.db_manager.Session
|
||||
self.Message = self.db_manager.Message
|
||||
|
||||
def get_or_create_session(self, name: str, **kwargs) -> Any:
|
||||
"""获取或创建会话"""
|
||||
# 尝试获取现有会话
|
||||
session = self.Session.get_or_none(self.Session.name == name)
|
||||
|
||||
if session:
|
||||
return session
|
||||
else:
|
||||
# 创建新会话
|
||||
defaults = {
|
||||
'model': kwargs.get('model', 'deepseek-chat'),
|
||||
'system_prompt': kwargs.get('system_prompt', '你是一个乐于助人的助手。'),
|
||||
'metadata': kwargs.get('metadata', {})
|
||||
}
|
||||
|
||||
session = self.Session.create(
|
||||
name=name,
|
||||
**defaults
|
||||
)
|
||||
|
||||
# 添加系统提示作为第一条消息
|
||||
if defaults['system_prompt']:
|
||||
self.Message.create(
|
||||
session=session,
|
||||
role='system',
|
||||
content=defaults['system_prompt']
|
||||
)
|
||||
|
||||
return session
|
||||
|
||||
def add_message(self, session_name: str, role: str, content: str,
|
||||
tokens: int = 0, metadata: Dict = None) -> str:
|
||||
"""添加消息到会话"""
|
||||
session = self.get_or_create_session(session_name)
|
||||
|
||||
self.Message.create(
|
||||
session=session,
|
||||
role=role,
|
||||
content=content,
|
||||
tokens=tokens,
|
||||
timestamp=datetime.now(),
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
# 更新会话时间戳
|
||||
session.update_timestamp()
|
||||
|
||||
return content
|
||||
|
||||
def get_session_history(self, session_name: str, limit: int = 20) -> List[Dict]:
|
||||
"""获取会话历史"""
|
||||
session = self.Session.get_or_none(self.Session.name == session_name)
|
||||
if not session:
|
||||
return []
|
||||
|
||||
messages = (self.Message
|
||||
.select()
|
||||
.where(self.Message.session == session)
|
||||
.order_by(self.Message.timestamp.asc())
|
||||
.limit(limit))
|
||||
|
||||
history = []
|
||||
for msg in messages:
|
||||
history.append({
|
||||
'role': msg.role,
|
||||
'content': msg.content
|
||||
})
|
||||
|
||||
return history
|
||||
|
||||
def get_session_info(self, session_name: str) -> Dict:
|
||||
"""获取会话信息"""
|
||||
session = self.Session.get_or_none(self.Session.name == session_name)
|
||||
if not session:
|
||||
return {}
|
||||
|
||||
messages = session.messages
|
||||
message_count = messages.count()
|
||||
user_messages = messages.where(self.Message.role == 'user').count()
|
||||
assistant_messages = messages.where(self.Message.role == 'assistant').count()
|
||||
|
||||
return {
|
||||
'name': session.name,
|
||||
'model': session.model,
|
||||
'system_prompt': session.system_prompt,
|
||||
'message_count': message_count,
|
||||
'user_messages': user_messages,
|
||||
'assistant_messages': assistant_messages,
|
||||
'created_at': session.created_at.isoformat(),
|
||||
'updated_at': session.updated_at.isoformat(),
|
||||
'metadata': session.metadata
|
||||
}
|
||||
|
||||
def list_sessions(self) -> List[Dict]:
|
||||
"""列出所有会话"""
|
||||
sessions = self.Session.select().order_by(self.Session.updated_at.desc())
|
||||
|
||||
result = []
|
||||
for session in sessions:
|
||||
message_count = session.messages.count()
|
||||
result.append({
|
||||
'name': session.name,
|
||||
'message_count': message_count,
|
||||
'updated_at': session.updated_at.isoformat()
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
def close(self):
|
||||
"""关闭连接"""
|
||||
self.db_manager.close()
|
||||
|
||||
|
||||
# ==================== 简化的客户端接口 ====================
|
||||
class DeepSeekChatClient:
|
||||
"""简化的DeepSeek聊天客户端"""
|
||||
|
||||
def __init__(self, api_key: str, db_path: str = None):
|
||||
"""
|
||||
初始化客户端
|
||||
|
||||
Args:
|
||||
api_key: DeepSeek API密钥
|
||||
db_path: 数据库路径,默认为chat_sessions.db
|
||||
"""
|
||||
self.client = OpenAI(api_key=api_key, base_url='https://api.deepseek.com')
|
||||
self.session_manager = SessionManager(db_path)
|
||||
|
||||
# 缓存已加载的会话配置
|
||||
self._session_cache = {}
|
||||
|
||||
def chat(self, user_input: str, name: str, model: str = None,
|
||||
system_prompt: str = None, **kwargs) -> str:
|
||||
"""
|
||||
发送聊天消息
|
||||
|
||||
Args:
|
||||
user_input: 用户输入的问题
|
||||
name: 会话名称
|
||||
model: 模型名称(仅在创建新会话时使用)
|
||||
system_prompt: 系统提示(仅在创建新会话时使用)
|
||||
**kwargs: 其他OpenAI API参数
|
||||
|
||||
Returns:
|
||||
AI的回复内容
|
||||
"""
|
||||
# 获取或创建会话
|
||||
session = self.session_manager.get_or_create_session(
|
||||
name=name,
|
||||
model=model or 'deepseek-chat',
|
||||
system_prompt=system_prompt
|
||||
)
|
||||
|
||||
# 保存用户消息
|
||||
self.session_manager.add_message(
|
||||
session_name=name,
|
||||
role='user',
|
||||
content=user_input,
|
||||
tokens=len(user_input) * 0.8
|
||||
)
|
||||
|
||||
# 获取会话历史
|
||||
history = self.session_manager.get_session_history(name, limit=10)
|
||||
|
||||
# 构建消息列表
|
||||
messages = []
|
||||
|
||||
# 添加系统提示(如果历史中没有)
|
||||
if not any(msg['role'] == 'system' for msg in history):
|
||||
messages.append({
|
||||
'role': 'system',
|
||||
'content': session.system_prompt
|
||||
})
|
||||
|
||||
# 添加上下文历史
|
||||
for msg in history:
|
||||
messages.append({
|
||||
'role': msg['role'],
|
||||
'content': msg['content']
|
||||
})
|
||||
|
||||
# 添加当前用户消息
|
||||
messages.append({
|
||||
'role': 'user',
|
||||
'content': user_input
|
||||
})
|
||||
|
||||
# 调用API
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=session.model,
|
||||
messages=messages,
|
||||
max_tokens=kwargs.get('max_tokens', 2000),
|
||||
temperature=kwargs.get('temperature', 0.7),
|
||||
stream=kwargs.get('stream', False)
|
||||
)
|
||||
|
||||
# 处理响应
|
||||
ai_reply = response.choices[0].message.content
|
||||
|
||||
# 保存AI回复
|
||||
self.session_manager.add_message(
|
||||
session_name=name,
|
||||
role='assistant',
|
||||
content=ai_reply,
|
||||
tokens=len(ai_reply) * 0.8
|
||||
)
|
||||
|
||||
return ai_reply
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"API调用错误: {str(e)}"
|
||||
print(error_msg)
|
||||
|
||||
# 保存错误信息
|
||||
self.session_manager.add_message(
|
||||
session_name=name,
|
||||
role='system',
|
||||
content=error_msg,
|
||||
metadata={'error': True, 'error_type': type(e).__name__}
|
||||
)
|
||||
|
||||
return ""
|
||||
|
||||
def get_session_info(self, name: str) -> Dict:
|
||||
"""获取会话信息"""
|
||||
return self.session_manager.get_session_info(name)
|
||||
|
||||
def list_sessions(self) -> List[Dict]:
|
||||
"""列出所有会话"""
|
||||
return self.session_manager.list_sessions()
|
||||
|
||||
def export_session(self, name: str) -> Dict:
|
||||
"""导出会话数据"""
|
||||
info = self.get_session_info(name)
|
||||
history = self.session_manager.get_session_history(name, limit=100)
|
||||
|
||||
return {
|
||||
'session_info': info,
|
||||
'history': history,
|
||||
'exported_at': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
def clear_session_history(self, name: str, keep_system: bool = True) -> bool:
|
||||
"""清空会话历史(保留系统提示)"""
|
||||
session = self.session_manager.Session.get_or_none(
|
||||
self.session_manager.Session.name == name
|
||||
)
|
||||
|
||||
if not session:
|
||||
return False
|
||||
|
||||
# 删除消息
|
||||
query = session.messages
|
||||
if keep_system:
|
||||
query = query.where(self.session_manager.Message.role != 'system')
|
||||
|
||||
query.delete().execute()
|
||||
|
||||
# 重新添加系统提示(如果需要)
|
||||
if keep_system and session.system_prompt:
|
||||
messages = session.messages.where(
|
||||
self.session_manager.Message.role == 'system'
|
||||
)
|
||||
if not messages.exists():
|
||||
self.session_manager.add_message(
|
||||
session_name=name,
|
||||
role='system',
|
||||
content=session.system_prompt
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def close(self):
|
||||
"""关闭客户端"""
|
||||
self.session_manager.close()
|
||||
|
||||
|
||||
# ==================== 使用示例 ====================
|
||||
def main():
|
||||
# 初始化客户端(请使用环境变量管理API密钥)
|
||||
API_KEY = "sk-bab97a2b9be042e18d945394f8feefa3"
|
||||
|
||||
client = None
|
||||
try:
|
||||
# 创建客户端
|
||||
client = DeepSeekChatClient(api_key=API_KEY, db_path='my_chats.db')
|
||||
print("✅ 客户端初始化成功!")
|
||||
|
||||
# ========== 第一次调用:创建会话 ==========
|
||||
print("\n=== 第一次调用:创建技术讨论会话 ===")
|
||||
res1 = client.chat(
|
||||
name='技术讨论',
|
||||
model='deepseek-chat',
|
||||
system_prompt='你是一个技术专家,请提供详细的解释和代码示例。',
|
||||
user_input='请解释Python中的装饰器。'
|
||||
)
|
||||
print("回答:", res1[:100] + "..." if len(res1) > 100 else res1)
|
||||
|
||||
# ========== 第二次调用:使用现有会话 ==========
|
||||
print("\n=== 第二次调用:使用现有技术讨论会话 ===")
|
||||
res2 = client.chat(
|
||||
name='技术讨论', # 使用相同名称
|
||||
user_input='请给我一个装饰器的实际应用例子。' # 只需要用户输入
|
||||
)
|
||||
print("回答:", res2[:100] + "..." if len(res2) > 100 else res2)
|
||||
|
||||
# ========== 第三次调用:创建新类型会话 ==========
|
||||
print("\n=== 第三次调用:创建创意写作会话 ===")
|
||||
res3 = client.chat(
|
||||
name='创意写作',
|
||||
model='deepseek-chat',
|
||||
system_prompt='你是一个创意作家,请提供富有想象力的回答。',
|
||||
user_input='写一个关于人工智能的短故事开头'
|
||||
)
|
||||
print("回答:", res3[:100] + "..." if len(res3) > 100 else res3)
|
||||
|
||||
# ========== 查看会话信息 ==========
|
||||
print("\n=== 查看技术讨论会话信息 ===")
|
||||
session_info = client.get_session_info('技术讨论')
|
||||
print(f"会话名称: {session_info['name']}")
|
||||
print(f"模型: {session_info['model']}")
|
||||
print(f"消息总数: {session_info['message_count']}")
|
||||
print(f"用户消息: {session_info['user_messages']}")
|
||||
print(f"助手消息: {session_info['assistant_messages']}")
|
||||
|
||||
# ========== 列出所有会话 ==========
|
||||
print("\n=== 所有会话列表 ===")
|
||||
sessions = client.list_sessions()
|
||||
for session in sessions:
|
||||
print(f" {session['name']}: {session['message_count']} 条消息")
|
||||
|
||||
# ========== 导出会话数据 ==========
|
||||
print("\n=== 导出技术讨论会话 ===")
|
||||
export_data = client.export_session('技术讨论')
|
||||
print(f"导出了 {len(export_data['history'])} 条消息")
|
||||
|
||||
print("\n🎉 测试完成!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 程序运行出错: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# 确保关闭连接
|
||||
if client:
|
||||
client.close()
|
||||
|
||||
|
||||
# ==================== 简洁的使用示例 ====================
|
||||
def simple_usage():
|
||||
"""简洁的使用示例"""
|
||||
API_KEY = "sk-bab97a2b9be042e18d945394f8feefa3"
|
||||
|
||||
# 1. 创建客户端
|
||||
client = DeepSeekChatClient(api_key=API_KEY)
|
||||
|
||||
# 2. 创建新会话(第一次调用)
|
||||
response1 = client.chat(
|
||||
name='数学辅导',
|
||||
model='deepseek-chat',
|
||||
system_prompt='你是一个数学老师,请用简单易懂的方式解释数学概念。',
|
||||
user_input='什么是微积分?'
|
||||
)
|
||||
print("第一次回答:", response1[:100])
|
||||
|
||||
# 3. 继续对话(只需要会话名称和用户输入)
|
||||
response2 = client.chat(
|
||||
name='数学辅导',
|
||||
user_input='微积分有哪些实际应用?'
|
||||
)
|
||||
print("第二次回答:", response2[:100])
|
||||
|
||||
# 4. 查看会话信息
|
||||
info = client.get_session_info('数学辅导')
|
||||
print(f"会话 '{info['name']}' 有 {info['message_count']} 条消息")
|
||||
|
||||
client.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 运行完整示例
|
||||
main()
|
||||
|
||||
# 或运行简洁示例
|
||||
# simple_usage()
|
||||
BIN
ai_main/my_chats.db
Normal file
BIN
ai_main/my_chats.db
Normal file
Binary file not shown.
Reference in New Issue
Block a user