use sqlalchemy to query db without plain sql

This commit is contained in:
Muhammad18557 2025-04-05 17:17:08 +08:00
parent e12dca71a9
commit 13b670f712

View file

@ -1,48 +1,65 @@
""" """
Database operations for retrieving and managing Telegram data. Database operations for retrieving and managing Telegram data.
Uses SQLAlchemy ORM for database access instead of raw SQL.
""" """
import sqlite3 from sqlalchemy import create_engine, or_, and_, desc
from sqlalchemy.orm import sessionmaker, scoped_session
from datetime import datetime from datetime import datetime
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from . import MESSAGES_DB_PATH from . import MESSAGES_DB_PATH
from .models import Message, Chat, Contact, MessageContext from .models import Message, Chat, Contact, MessageContext
# Initialize SQLAlchemy engine and session
engine = create_engine(
f"sqlite:///{MESSAGES_DB_PATH}",
connect_args={"check_same_thread": False} # Needed for SQLite
)
SessionFactory = sessionmaker(bind=engine)
Session = scoped_session(SessionFactory)
# Import SQLAlchemy models from telegram-bridge
import sys
import os
# Add the parent directory to path to find telegram-bridge
bridge_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), 'telegram-bridge')
sys.path.append(bridge_path)
# Import models directly from telegram-bridge
from database.models import Chat as DBChat, Message as DBMessage
def search_contacts(query: str) -> List[Contact]: def search_contacts(query: str) -> List[Contact]:
"""Search contacts by name or username.""" """Search contacts by name or username."""
try: try:
conn = sqlite3.connect(MESSAGES_DB_PATH) session = Session()
cursor = conn.cursor()
# Search in chats where type is 'user' # Search in chats where type is 'user'
cursor.execute(""" db_contacts = session.query(DBChat).filter(
SELECT id, username, title DBChat.type == 'user',
FROM chats or_(
WHERE type = 'user' AND (title LIKE ? OR username LIKE ?) DBChat.title.ilike(f"%{query}%"),
ORDER BY title DBChat.username.ilike(f"%{query}%")
LIMIT 50 )
""", (f"%{query}%", f"%{query}%")) ).order_by(DBChat.title).limit(50).all()
contacts = cursor.fetchall()
result = [] result = []
for contact_data in contacts: for db_contact in db_contacts:
contact = Contact( contact = Contact(
id=contact_data[0], id=db_contact.id,
username=contact_data[1], username=db_contact.username,
name=contact_data[2] name=db_contact.title
) )
result.append(contact) result.append(contact)
return result return result
except sqlite3.Error as e: except Exception as e:
print(f"Database error: {e}") print(f"Database error: {e}")
return [] return []
finally: finally:
if 'conn' in locals(): session.close()
conn.close()
def list_messages( def list_messages(
date_range: Optional[Tuple[datetime, datetime]] = None, date_range: Optional[Tuple[datetime, datetime]] = None,
@ -57,66 +74,50 @@ def list_messages(
) -> List[Message]: ) -> List[Message]:
"""Get messages matching the specified criteria with optional context.""" """Get messages matching the specified criteria with optional context."""
try: try:
conn = sqlite3.connect(MESSAGES_DB_PATH) session = Session()
cursor = conn.cursor()
# Build base query # Build base query with join
query_parts = [""" db_query = session.query(DBMessage, DBChat).join(DBChat)
SELECT
m.id,
m.chat_id,
c.title,
m.sender_name,
m.content,
m.timestamp,
m.is_from_me,
m.sender_id
FROM messages m
"""]
query_parts.append("JOIN chats c ON m.chat_id = c.id")
where_clauses = []
params = []
# Add filters # Add filters
filters = []
if date_range: if date_range:
where_clauses.append("m.timestamp BETWEEN ? AND ?") filters.append(and_(
params.extend([date_range[0].isoformat(), date_range[1].isoformat()]) DBMessage.timestamp >= date_range[0],
DBMessage.timestamp <= date_range[1]
))
if sender_id: if sender_id:
where_clauses.append("m.sender_id = ?") filters.append(DBMessage.sender_id == sender_id)
params.append(sender_id)
if chat_id: if chat_id:
where_clauses.append("m.chat_id = ?") filters.append(DBMessage.chat_id == chat_id)
params.append(chat_id)
if query: if query:
where_clauses.append("LOWER(m.content) LIKE LOWER(?)") filters.append(DBMessage.content.ilike(f"%{query}%"))
params.append(f"%{query}%")
if where_clauses: if filters:
query_parts.append("WHERE " + " AND ".join(where_clauses)) db_query = db_query.filter(and_(*filters))
# Add pagination # Add pagination
offset = page * limit offset = page * limit
query_parts.append("ORDER BY m.timestamp DESC") db_query = db_query.order_by(desc(DBMessage.timestamp))
query_parts.append("LIMIT ? OFFSET ?") db_query = db_query.limit(limit).offset(offset)
params.extend([limit, offset])
cursor.execute(" ".join(query_parts), tuple(params)) # Execute query
messages = cursor.fetchall() db_results = db_query.all()
result = [] result = []
for msg in messages: for db_msg, db_chat in db_results:
message = Message( message = Message(
id=msg[0], id=db_msg.id,
chat_id=msg[1], chat_id=db_msg.chat_id,
chat_title=msg[2], chat_title=db_chat.title,
sender_name=msg[3], sender_name=db_msg.sender_name,
content=msg[4], content=db_msg.content,
timestamp=datetime.fromisoformat(msg[5]), timestamp=db_msg.timestamp,
is_from_me=bool(msg[6]), is_from_me=db_msg.is_from_me,
sender_id=msg[7] sender_id=db_msg.sender_id
) )
result.append(message) result.append(message)
@ -132,12 +133,11 @@ def list_messages(
return result return result
except sqlite3.Error as e: except Exception as e:
print(f"Database error: {e}") print(f"Database error: {e}")
return [] return []
finally: finally:
if 'conn' in locals(): session.close()
conn.close()
def get_message_context( def get_message_context(
message_id: int, message_id: int,
@ -147,76 +147,75 @@ def get_message_context(
) -> MessageContext: ) -> MessageContext:
"""Get context around a specific message.""" """Get context around a specific message."""
try: try:
conn = sqlite3.connect(MESSAGES_DB_PATH) session = Session()
cursor = conn.cursor()
# Get the target message first # Get the target message first
cursor.execute(""" result = session.query(DBMessage, DBChat) \
SELECT m.id, m.chat_id, c.title, m.sender_name, m.content, m.timestamp, m.is_from_me, m.sender_id .join(DBChat) \
FROM messages m .filter(DBMessage.id == message_id, DBMessage.chat_id == chat_id) \
JOIN chats c ON m.chat_id = c.id .first()
WHERE m.id = ? AND m.chat_id = ?
""", (message_id, chat_id))
msg_data = cursor.fetchone()
if not msg_data: if not result:
raise ValueError(f"Message with ID {message_id} in chat {chat_id} not found") raise ValueError(f"Message with ID {message_id} in chat {chat_id} not found")
db_msg, db_chat = result
target_message = Message( target_message = Message(
id=msg_data[0], id=db_msg.id,
chat_id=msg_data[1], chat_id=db_msg.chat_id,
chat_title=msg_data[2], chat_title=db_chat.title,
sender_name=msg_data[3], sender_name=db_msg.sender_name,
content=msg_data[4], content=db_msg.content,
timestamp=datetime.fromisoformat(msg_data[5]), timestamp=db_msg.timestamp,
is_from_me=bool(msg_data[6]), is_from_me=db_msg.is_from_me,
sender_id=msg_data[7] sender_id=db_msg.sender_id
) )
# Get messages before # Get messages before
cursor.execute(""" before_results = session.query(DBMessage, DBChat) \
SELECT m.id, m.chat_id, c.title, m.sender_name, m.content, m.timestamp, m.is_from_me, m.sender_id .join(DBChat) \
FROM messages m .filter(
JOIN chats c ON m.chat_id = c.id DBMessage.chat_id == chat_id,
WHERE m.chat_id = ? AND m.timestamp < ? DBMessage.timestamp < target_message.timestamp
ORDER BY m.timestamp DESC ) \
LIMIT ? .order_by(desc(DBMessage.timestamp)) \
""", (chat_id, target_message.timestamp.isoformat(), before)) .limit(before) \
.all()
before_messages = [] before_messages = []
for msg in cursor.fetchall(): for db_msg, db_chat in before_results:
before_messages.append(Message( before_messages.append(Message(
id=msg[0], id=db_msg.id,
chat_id=msg[1], chat_id=db_msg.chat_id,
chat_title=msg[2], chat_title=db_chat.title,
sender_name=msg[3], sender_name=db_msg.sender_name,
content=msg[4], content=db_msg.content,
timestamp=datetime.fromisoformat(msg[5]), timestamp=db_msg.timestamp,
is_from_me=bool(msg[6]), is_from_me=db_msg.is_from_me,
sender_id=msg[7] sender_id=db_msg.sender_id
)) ))
# Get messages after # Get messages after
cursor.execute(""" after_results = session.query(DBMessage, DBChat) \
SELECT m.id, m.chat_id, c.title, m.sender_name, m.content, m.timestamp, m.is_from_me, m.sender_id .join(DBChat) \
FROM messages m .filter(
JOIN chats c ON m.chat_id = c.id DBMessage.chat_id == chat_id,
WHERE m.chat_id = ? AND m.timestamp > ? DBMessage.timestamp > target_message.timestamp
ORDER BY m.timestamp ASC ) \
LIMIT ? .order_by(DBMessage.timestamp) \
""", (chat_id, target_message.timestamp.isoformat(), after)) .limit(after) \
.all()
after_messages = [] after_messages = []
for msg in cursor.fetchall(): for db_msg, db_chat in after_results:
after_messages.append(Message( after_messages.append(Message(
id=msg[0], id=db_msg.id,
chat_id=msg[1], chat_id=db_msg.chat_id,
chat_title=msg[2], chat_title=db_chat.title,
sender_name=msg[3], sender_name=db_msg.sender_name,
content=msg[4], content=db_msg.content,
timestamp=datetime.fromisoformat(msg[5]), timestamp=db_msg.timestamp,
is_from_me=bool(msg[6]), is_from_me=db_msg.is_from_me,
sender_id=msg[7] sender_id=db_msg.sender_id
)) ))
return MessageContext( return MessageContext(
@ -225,12 +224,11 @@ def get_message_context(
after=after_messages after=after_messages
) )
except sqlite3.Error as e: except Exception as e:
print(f"Database error: {e}") print(f"Database error: {e}")
raise raise
finally: finally:
if 'conn' in locals(): session.close()
conn.close()
def list_chats( def list_chats(
query: Optional[str] = None, query: Optional[str] = None,
@ -241,124 +239,108 @@ def list_chats(
) -> List[Chat]: ) -> List[Chat]:
"""Get chats matching the specified criteria.""" """Get chats matching the specified criteria."""
try: try:
conn = sqlite3.connect(MESSAGES_DB_PATH) session = Session()
cursor = conn.cursor()
# Build base query # Build base query
query_parts = ["SELECT id, title, username, type, last_message_time FROM chats"] db_query = session.query(DBChat)
where_clauses = [] # Add filters
params = [] filters = []
if query: if query:
where_clauses.append("(LOWER(title) LIKE LOWER(?) OR LOWER(username) LIKE LOWER(?))") filters.append(or_(
params.extend([f"%{query}%", f"%{query}%"]) DBChat.title.ilike(f"%{query}%"),
DBChat.username.ilike(f"%{query}%")
))
if chat_type: if chat_type:
where_clauses.append("type = ?") filters.append(DBChat.type == chat_type)
params.append(chat_type)
if where_clauses: if filters:
query_parts.append("WHERE " + " AND ".join(where_clauses)) db_query = db_query.filter(and_(*filters))
# Add sorting # Add sorting
order_by = "last_message_time DESC" if sort_by == "last_active" else "title" if sort_by == "last_active":
query_parts.append(f"ORDER BY {order_by}") db_query = db_query.order_by(desc(DBChat.last_message_time))
else:
db_query = db_query.order_by(DBChat.title)
# Add pagination # Add pagination
offset = (page) * limit offset = page * limit
query_parts.append("LIMIT ? OFFSET ?") db_query = db_query.limit(limit).offset(offset)
params.extend([limit, offset])
cursor.execute(" ".join(query_parts), tuple(params)) # Execute query
chats = cursor.fetchall() db_chats = db_query.all()
result = [] result = []
for chat_data in chats: for db_chat in db_chats:
last_message_time = datetime.fromisoformat(chat_data[4]) if chat_data[4] else None
chat = Chat( chat = Chat(
id=chat_data[0], id=db_chat.id,
title=chat_data[1], title=db_chat.title,
username=chat_data[2], username=db_chat.username,
type=chat_data[3], type=db_chat.type,
last_message_time=last_message_time last_message_time=db_chat.last_message_time
) )
result.append(chat) result.append(chat)
return result return result
except sqlite3.Error as e: except Exception as e:
print(f"Database error: {e}") print(f"Database error: {e}")
return [] return []
finally: finally:
if 'conn' in locals(): session.close()
conn.close()
def get_chat(chat_id: int) -> Optional[Chat]: def get_chat(chat_id: int) -> Optional[Chat]:
"""Get chat metadata by ID.""" """Get chat metadata by ID."""
try: try:
conn = sqlite3.connect(MESSAGES_DB_PATH) session = Session()
cursor = conn.cursor()
cursor.execute(""" db_chat = session.query(DBChat).filter(DBChat.id == chat_id).first()
SELECT id, title, username, type, last_message_time
FROM chats
WHERE id = ?
""", (chat_id,))
chat_data = cursor.fetchone() if not db_chat:
if not chat_data:
return None return None
last_message_time = datetime.fromisoformat(chat_data[4]) if chat_data[4] else None
return Chat( return Chat(
id=chat_data[0], id=db_chat.id,
title=chat_data[1], title=db_chat.title,
username=chat_data[2], username=db_chat.username,
type=chat_data[3], type=db_chat.type,
last_message_time=last_message_time last_message_time=db_chat.last_message_time
) )
except sqlite3.Error as e: except Exception as e:
print(f"Database error: {e}") print(f"Database error: {e}")
return None return None
finally: finally:
if 'conn' in locals(): session.close()
conn.close()
def get_direct_chat_by_contact(contact_id: int) -> Optional[Chat]: def get_direct_chat_by_contact(contact_id: int) -> Optional[Chat]:
"""Get direct chat metadata by contact ID.""" """Get direct chat metadata by contact ID."""
try: try:
conn = sqlite3.connect(MESSAGES_DB_PATH) session = Session()
cursor = conn.cursor()
cursor.execute(""" db_chat = session.query(DBChat).filter(
SELECT id, title, username, type, last_message_time DBChat.id == contact_id,
FROM chats DBChat.type == 'user'
WHERE id = ? AND type = 'user' ).first()
""", (contact_id,))
chat_data = cursor.fetchone() if not db_chat:
if not chat_data:
return None return None
last_message_time = datetime.fromisoformat(chat_data[4]) if chat_data[4] else None
return Chat( return Chat(
id=chat_data[0], id=db_chat.id,
title=chat_data[1], title=db_chat.title,
username=chat_data[2], username=db_chat.username,
type=chat_data[3], type=db_chat.type,
last_message_time=last_message_time last_message_time=db_chat.last_message_time
) )
except sqlite3.Error as e: except Exception as e:
print(f"Database error: {e}") print(f"Database error: {e}")
return None return None
finally: finally:
if 'conn' in locals(): session.close()
conn.close()
def get_contact_chats(contact_id: int, limit: int = 20, page: int = 0) -> List[Chat]: def get_contact_chats(contact_id: int, limit: int = 20, page: int = 0) -> List[Chat]:
"""Get all chats involving the contact. """Get all chats involving the contact.
@ -369,77 +351,64 @@ def get_contact_chats(contact_id: int, limit: int = 20, page: int = 0) -> List[C
page: Page number for pagination (default 0) page: Page number for pagination (default 0)
""" """
try: try:
conn = sqlite3.connect(MESSAGES_DB_PATH) session = Session()
cursor = conn.cursor()
cursor.execute(""" # Using a subquery to get distinct chats for the contact
SELECT DISTINCT db_chats = session.query(DBChat).join(DBMessage, DBChat.id == DBMessage.chat_id).filter(
c.id, c.title, c.username, c.type, c.last_message_time or_(
FROM chats c DBMessage.sender_id == contact_id,
JOIN messages m ON c.id = m.chat_id DBChat.id == contact_id
WHERE m.sender_id = ? OR c.id = ? )
ORDER BY c.last_message_time DESC ).distinct().order_by(desc(DBChat.last_message_time)).limit(limit).offset(page * limit).all()
LIMIT ? OFFSET ?
""", (contact_id, contact_id, limit, page * limit))
chats = cursor.fetchall()
result = [] result = []
for chat_data in chats: for db_chat in db_chats:
last_message_time = datetime.fromisoformat(chat_data[4]) if chat_data[4] else None
chat = Chat( chat = Chat(
id=chat_data[0], id=db_chat.id,
title=chat_data[1], title=db_chat.title,
username=chat_data[2], username=db_chat.username,
type=chat_data[3], type=db_chat.type,
last_message_time=last_message_time last_message_time=db_chat.last_message_time
) )
result.append(chat) result.append(chat)
return result return result
except sqlite3.Error as e: except Exception as e:
print(f"Database error: {e}") print(f"Database error: {e}")
return [] return []
finally: finally:
if 'conn' in locals(): session.close()
conn.close()
def get_last_interaction(contact_id: int) -> Optional[Message]: def get_last_interaction(contact_id: int) -> Optional[Message]:
"""Get most recent message involving the contact.""" """Get most recent message involving the contact."""
try: try:
conn = sqlite3.connect(MESSAGES_DB_PATH) session = Session()
cursor = conn.cursor()
cursor.execute(""" result = session.query(DBMessage, DBChat).join(DBChat).filter(
SELECT or_(
m.id, m.chat_id, c.title, m.sender_name, m.content, m.timestamp, m.is_from_me, m.sender_id DBMessage.sender_id == contact_id,
FROM messages m DBChat.id == contact_id
JOIN chats c ON m.chat_id = c.id )
WHERE m.sender_id = ? OR c.id = ? ).order_by(desc(DBMessage.timestamp)).first()
ORDER BY m.timestamp DESC
LIMIT 1
""", (contact_id, contact_id))
msg_data = cursor.fetchone() if not result:
if not msg_data:
return None return None
db_msg, db_chat = result
return Message( return Message(
id=msg_data[0], id=db_msg.id,
chat_id=msg_data[1], chat_id=db_msg.chat_id,
chat_title=msg_data[2], chat_title=db_chat.title,
sender_name=msg_data[3], sender_name=db_msg.sender_name,
content=msg_data[4], content=db_msg.content,
timestamp=datetime.fromisoformat(msg_data[5]), timestamp=db_msg.timestamp,
is_from_me=bool(msg_data[6]), is_from_me=db_msg.is_from_me,
sender_id=msg_data[7] sender_id=db_msg.sender_id
) )
except sqlite3.Error as e: except Exception as e:
print(f"Database error: {e}") print(f"Database error: {e}")
return None return None
finally: finally:
if 'conn' in locals(): session.close()
conn.close()