use sqlalchemy to query db without plain sql

This commit is contained in:
Muhammad18557 2025-04-05 17:17:08 +08:00
parent e12dca71a9
commit 13b670f712

View file

@ -1,48 +1,65 @@
"""
Database operations for retrieving and managing Telegram data.
Uses SQLAlchemy ORM for database access instead of raw SQL.
"""
import sqlite3
from sqlalchemy import create_engine, or_, and_, desc
from sqlalchemy.orm import sessionmaker, scoped_session
from datetime import datetime
from typing import Optional, List, Tuple
from . import MESSAGES_DB_PATH
from .models import Message, Chat, Contact, MessageContext
# Initialize SQLAlchemy engine and session
engine = create_engine(
f"sqlite:///{MESSAGES_DB_PATH}",
connect_args={"check_same_thread": False} # Needed for SQLite
)
SessionFactory = sessionmaker(bind=engine)
Session = scoped_session(SessionFactory)
# Import SQLAlchemy models from telegram-bridge
import sys
import os
# Add the parent directory to path to find telegram-bridge
bridge_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), 'telegram-bridge')
sys.path.append(bridge_path)
# Import models directly from telegram-bridge
from database.models import Chat as DBChat, Message as DBMessage
def search_contacts(query: str) -> List[Contact]:
"""Search contacts by name or username."""
try:
conn = sqlite3.connect(MESSAGES_DB_PATH)
cursor = conn.cursor()
session = Session()
# Search in chats where type is 'user'
cursor.execute("""
SELECT id, username, title
FROM chats
WHERE type = 'user' AND (title LIKE ? OR username LIKE ?)
ORDER BY title
LIMIT 50
""", (f"%{query}%", f"%{query}%"))
contacts = cursor.fetchall()
db_contacts = session.query(DBChat).filter(
DBChat.type == 'user',
or_(
DBChat.title.ilike(f"%{query}%"),
DBChat.username.ilike(f"%{query}%")
)
).order_by(DBChat.title).limit(50).all()
result = []
for contact_data in contacts:
for db_contact in db_contacts:
contact = Contact(
id=contact_data[0],
username=contact_data[1],
name=contact_data[2]
id=db_contact.id,
username=db_contact.username,
name=db_contact.title
)
result.append(contact)
return result
except sqlite3.Error as e:
except Exception as e:
print(f"Database error: {e}")
return []
finally:
if 'conn' in locals():
conn.close()
session.close()
def list_messages(
date_range: Optional[Tuple[datetime, datetime]] = None,
@ -57,66 +74,50 @@ def list_messages(
) -> List[Message]:
"""Get messages matching the specified criteria with optional context."""
try:
conn = sqlite3.connect(MESSAGES_DB_PATH)
cursor = conn.cursor()
session = Session()
# Build base query
query_parts = ["""
SELECT
m.id,
m.chat_id,
c.title,
m.sender_name,
m.content,
m.timestamp,
m.is_from_me,
m.sender_id
FROM messages m
"""]
query_parts.append("JOIN chats c ON m.chat_id = c.id")
where_clauses = []
params = []
# Build base query with join
db_query = session.query(DBMessage, DBChat).join(DBChat)
# Add filters
filters = []
if date_range:
where_clauses.append("m.timestamp BETWEEN ? AND ?")
params.extend([date_range[0].isoformat(), date_range[1].isoformat()])
filters.append(and_(
DBMessage.timestamp >= date_range[0],
DBMessage.timestamp <= date_range[1]
))
if sender_id:
where_clauses.append("m.sender_id = ?")
params.append(sender_id)
filters.append(DBMessage.sender_id == sender_id)
if chat_id:
where_clauses.append("m.chat_id = ?")
params.append(chat_id)
filters.append(DBMessage.chat_id == chat_id)
if query:
where_clauses.append("LOWER(m.content) LIKE LOWER(?)")
params.append(f"%{query}%")
filters.append(DBMessage.content.ilike(f"%{query}%"))
if where_clauses:
query_parts.append("WHERE " + " AND ".join(where_clauses))
if filters:
db_query = db_query.filter(and_(*filters))
# Add pagination
offset = page * limit
query_parts.append("ORDER BY m.timestamp DESC")
query_parts.append("LIMIT ? OFFSET ?")
params.extend([limit, offset])
db_query = db_query.order_by(desc(DBMessage.timestamp))
db_query = db_query.limit(limit).offset(offset)
cursor.execute(" ".join(query_parts), tuple(params))
messages = cursor.fetchall()
# Execute query
db_results = db_query.all()
result = []
for msg in messages:
for db_msg, db_chat in db_results:
message = Message(
id=msg[0],
chat_id=msg[1],
chat_title=msg[2],
sender_name=msg[3],
content=msg[4],
timestamp=datetime.fromisoformat(msg[5]),
is_from_me=bool(msg[6]),
sender_id=msg[7]
id=db_msg.id,
chat_id=db_msg.chat_id,
chat_title=db_chat.title,
sender_name=db_msg.sender_name,
content=db_msg.content,
timestamp=db_msg.timestamp,
is_from_me=db_msg.is_from_me,
sender_id=db_msg.sender_id
)
result.append(message)
@ -132,12 +133,11 @@ def list_messages(
return result
except sqlite3.Error as e:
except Exception as e:
print(f"Database error: {e}")
return []
finally:
if 'conn' in locals():
conn.close()
session.close()
def get_message_context(
message_id: int,
@ -147,76 +147,75 @@ def get_message_context(
) -> MessageContext:
"""Get context around a specific message."""
try:
conn = sqlite3.connect(MESSAGES_DB_PATH)
cursor = conn.cursor()
session = Session()
# Get the target message first
cursor.execute("""
SELECT m.id, m.chat_id, c.title, m.sender_name, m.content, m.timestamp, m.is_from_me, m.sender_id
FROM messages m
JOIN chats c ON m.chat_id = c.id
WHERE m.id = ? AND m.chat_id = ?
""", (message_id, chat_id))
msg_data = cursor.fetchone()
result = session.query(DBMessage, DBChat) \
.join(DBChat) \
.filter(DBMessage.id == message_id, DBMessage.chat_id == chat_id) \
.first()
if not msg_data:
if not result:
raise ValueError(f"Message with ID {message_id} in chat {chat_id} not found")
db_msg, db_chat = result
target_message = Message(
id=msg_data[0],
chat_id=msg_data[1],
chat_title=msg_data[2],
sender_name=msg_data[3],
content=msg_data[4],
timestamp=datetime.fromisoformat(msg_data[5]),
is_from_me=bool(msg_data[6]),
sender_id=msg_data[7]
id=db_msg.id,
chat_id=db_msg.chat_id,
chat_title=db_chat.title,
sender_name=db_msg.sender_name,
content=db_msg.content,
timestamp=db_msg.timestamp,
is_from_me=db_msg.is_from_me,
sender_id=db_msg.sender_id
)
# Get messages before
cursor.execute("""
SELECT m.id, m.chat_id, c.title, m.sender_name, m.content, m.timestamp, m.is_from_me, m.sender_id
FROM messages m
JOIN chats c ON m.chat_id = c.id
WHERE m.chat_id = ? AND m.timestamp < ?
ORDER BY m.timestamp DESC
LIMIT ?
""", (chat_id, target_message.timestamp.isoformat(), before))
before_results = session.query(DBMessage, DBChat) \
.join(DBChat) \
.filter(
DBMessage.chat_id == chat_id,
DBMessage.timestamp < target_message.timestamp
) \
.order_by(desc(DBMessage.timestamp)) \
.limit(before) \
.all()
before_messages = []
for msg in cursor.fetchall():
for db_msg, db_chat in before_results:
before_messages.append(Message(
id=msg[0],
chat_id=msg[1],
chat_title=msg[2],
sender_name=msg[3],
content=msg[4],
timestamp=datetime.fromisoformat(msg[5]),
is_from_me=bool(msg[6]),
sender_id=msg[7]
id=db_msg.id,
chat_id=db_msg.chat_id,
chat_title=db_chat.title,
sender_name=db_msg.sender_name,
content=db_msg.content,
timestamp=db_msg.timestamp,
is_from_me=db_msg.is_from_me,
sender_id=db_msg.sender_id
))
# Get messages after
cursor.execute("""
SELECT m.id, m.chat_id, c.title, m.sender_name, m.content, m.timestamp, m.is_from_me, m.sender_id
FROM messages m
JOIN chats c ON m.chat_id = c.id
WHERE m.chat_id = ? AND m.timestamp > ?
ORDER BY m.timestamp ASC
LIMIT ?
""", (chat_id, target_message.timestamp.isoformat(), after))
after_results = session.query(DBMessage, DBChat) \
.join(DBChat) \
.filter(
DBMessage.chat_id == chat_id,
DBMessage.timestamp > target_message.timestamp
) \
.order_by(DBMessage.timestamp) \
.limit(after) \
.all()
after_messages = []
for msg in cursor.fetchall():
for db_msg, db_chat in after_results:
after_messages.append(Message(
id=msg[0],
chat_id=msg[1],
chat_title=msg[2],
sender_name=msg[3],
content=msg[4],
timestamp=datetime.fromisoformat(msg[5]),
is_from_me=bool(msg[6]),
sender_id=msg[7]
id=db_msg.id,
chat_id=db_msg.chat_id,
chat_title=db_chat.title,
sender_name=db_msg.sender_name,
content=db_msg.content,
timestamp=db_msg.timestamp,
is_from_me=db_msg.is_from_me,
sender_id=db_msg.sender_id
))
return MessageContext(
@ -225,12 +224,11 @@ def get_message_context(
after=after_messages
)
except sqlite3.Error as e:
except Exception as e:
print(f"Database error: {e}")
raise
finally:
if 'conn' in locals():
conn.close()
session.close()
def list_chats(
query: Optional[str] = None,
@ -241,124 +239,108 @@ def list_chats(
) -> List[Chat]:
"""Get chats matching the specified criteria."""
try:
conn = sqlite3.connect(MESSAGES_DB_PATH)
cursor = conn.cursor()
session = Session()
# Build base query
query_parts = ["SELECT id, title, username, type, last_message_time FROM chats"]
db_query = session.query(DBChat)
where_clauses = []
params = []
# Add filters
filters = []
if query:
where_clauses.append("(LOWER(title) LIKE LOWER(?) OR LOWER(username) LIKE LOWER(?))")
params.extend([f"%{query}%", f"%{query}%"])
filters.append(or_(
DBChat.title.ilike(f"%{query}%"),
DBChat.username.ilike(f"%{query}%")
))
if chat_type:
where_clauses.append("type = ?")
params.append(chat_type)
filters.append(DBChat.type == chat_type)
if where_clauses:
query_parts.append("WHERE " + " AND ".join(where_clauses))
if filters:
db_query = db_query.filter(and_(*filters))
# Add sorting
order_by = "last_message_time DESC" if sort_by == "last_active" else "title"
query_parts.append(f"ORDER BY {order_by}")
if sort_by == "last_active":
db_query = db_query.order_by(desc(DBChat.last_message_time))
else:
db_query = db_query.order_by(DBChat.title)
# Add pagination
offset = (page) * limit
query_parts.append("LIMIT ? OFFSET ?")
params.extend([limit, offset])
offset = page * limit
db_query = db_query.limit(limit).offset(offset)
cursor.execute(" ".join(query_parts), tuple(params))
chats = cursor.fetchall()
# Execute query
db_chats = db_query.all()
result = []
for chat_data in chats:
last_message_time = datetime.fromisoformat(chat_data[4]) if chat_data[4] else None
for db_chat in db_chats:
chat = Chat(
id=chat_data[0],
title=chat_data[1],
username=chat_data[2],
type=chat_data[3],
last_message_time=last_message_time
id=db_chat.id,
title=db_chat.title,
username=db_chat.username,
type=db_chat.type,
last_message_time=db_chat.last_message_time
)
result.append(chat)
return result
except sqlite3.Error as e:
except Exception as e:
print(f"Database error: {e}")
return []
finally:
if 'conn' in locals():
conn.close()
session.close()
def get_chat(chat_id: int) -> Optional[Chat]:
"""Get chat metadata by ID."""
try:
conn = sqlite3.connect(MESSAGES_DB_PATH)
cursor = conn.cursor()
session = Session()
cursor.execute("""
SELECT id, title, username, type, last_message_time
FROM chats
WHERE id = ?
""", (chat_id,))
db_chat = session.query(DBChat).filter(DBChat.id == chat_id).first()
chat_data = cursor.fetchone()
if not chat_data:
if not db_chat:
return None
last_message_time = datetime.fromisoformat(chat_data[4]) if chat_data[4] else None
return Chat(
id=chat_data[0],
title=chat_data[1],
username=chat_data[2],
type=chat_data[3],
last_message_time=last_message_time
id=db_chat.id,
title=db_chat.title,
username=db_chat.username,
type=db_chat.type,
last_message_time=db_chat.last_message_time
)
except sqlite3.Error as e:
except Exception as e:
print(f"Database error: {e}")
return None
finally:
if 'conn' in locals():
conn.close()
session.close()
def get_direct_chat_by_contact(contact_id: int) -> Optional[Chat]:
"""Get direct chat metadata by contact ID."""
try:
conn = sqlite3.connect(MESSAGES_DB_PATH)
cursor = conn.cursor()
session = Session()
cursor.execute("""
SELECT id, title, username, type, last_message_time
FROM chats
WHERE id = ? AND type = 'user'
""", (contact_id,))
db_chat = session.query(DBChat).filter(
DBChat.id == contact_id,
DBChat.type == 'user'
).first()
chat_data = cursor.fetchone()
if not chat_data:
if not db_chat:
return None
last_message_time = datetime.fromisoformat(chat_data[4]) if chat_data[4] else None
return Chat(
id=chat_data[0],
title=chat_data[1],
username=chat_data[2],
type=chat_data[3],
last_message_time=last_message_time
id=db_chat.id,
title=db_chat.title,
username=db_chat.username,
type=db_chat.type,
last_message_time=db_chat.last_message_time
)
except sqlite3.Error as e:
except Exception as e:
print(f"Database error: {e}")
return None
finally:
if 'conn' in locals():
conn.close()
session.close()
def get_contact_chats(contact_id: int, limit: int = 20, page: int = 0) -> List[Chat]:
"""Get all chats involving the contact.
@ -369,77 +351,64 @@ def get_contact_chats(contact_id: int, limit: int = 20, page: int = 0) -> List[C
page: Page number for pagination (default 0)
"""
try:
conn = sqlite3.connect(MESSAGES_DB_PATH)
cursor = conn.cursor()
session = Session()
cursor.execute("""
SELECT DISTINCT
c.id, c.title, c.username, c.type, c.last_message_time
FROM chats c
JOIN messages m ON c.id = m.chat_id
WHERE m.sender_id = ? OR c.id = ?
ORDER BY c.last_message_time DESC
LIMIT ? OFFSET ?
""", (contact_id, contact_id, limit, page * limit))
chats = cursor.fetchall()
# Using a subquery to get distinct chats for the contact
db_chats = session.query(DBChat).join(DBMessage, DBChat.id == DBMessage.chat_id).filter(
or_(
DBMessage.sender_id == contact_id,
DBChat.id == contact_id
)
).distinct().order_by(desc(DBChat.last_message_time)).limit(limit).offset(page * limit).all()
result = []
for chat_data in chats:
last_message_time = datetime.fromisoformat(chat_data[4]) if chat_data[4] else None
for db_chat in db_chats:
chat = Chat(
id=chat_data[0],
title=chat_data[1],
username=chat_data[2],
type=chat_data[3],
last_message_time=last_message_time
id=db_chat.id,
title=db_chat.title,
username=db_chat.username,
type=db_chat.type,
last_message_time=db_chat.last_message_time
)
result.append(chat)
return result
except sqlite3.Error as e:
except Exception as e:
print(f"Database error: {e}")
return []
finally:
if 'conn' in locals():
conn.close()
session.close()
def get_last_interaction(contact_id: int) -> Optional[Message]:
"""Get most recent message involving the contact."""
try:
conn = sqlite3.connect(MESSAGES_DB_PATH)
cursor = conn.cursor()
session = Session()
cursor.execute("""
SELECT
m.id, m.chat_id, c.title, m.sender_name, m.content, m.timestamp, m.is_from_me, m.sender_id
FROM messages m
JOIN chats c ON m.chat_id = c.id
WHERE m.sender_id = ? OR c.id = ?
ORDER BY m.timestamp DESC
LIMIT 1
""", (contact_id, contact_id))
result = session.query(DBMessage, DBChat).join(DBChat).filter(
or_(
DBMessage.sender_id == contact_id,
DBChat.id == contact_id
)
).order_by(desc(DBMessage.timestamp)).first()
msg_data = cursor.fetchone()
if not msg_data:
if not result:
return None
db_msg, db_chat = result
return Message(
id=msg_data[0],
chat_id=msg_data[1],
chat_title=msg_data[2],
sender_name=msg_data[3],
content=msg_data[4],
timestamp=datetime.fromisoformat(msg_data[5]),
is_from_me=bool(msg_data[6]),
sender_id=msg_data[7]
id=db_msg.id,
chat_id=db_msg.chat_id,
chat_title=db_chat.title,
sender_name=db_msg.sender_name,
content=db_msg.content,
timestamp=db_msg.timestamp,
is_from_me=db_msg.is_from_me,
sender_id=db_msg.sender_id
)
except sqlite3.Error as e:
except Exception as e:
print(f"Database error: {e}")
return None
finally:
if 'conn' in locals():
conn.close()
session.close()