From 13b670f7127e6b75e70f64f740b7a4f193f6f040 Mon Sep 17 00:00:00 2001 From: Muhammad18557 Date: Sat, 5 Apr 2025 17:17:08 +0800 Subject: [PATCH] use sqlalchemy to query db without plain sql --- telegram-mcp-server/telegram/database.py | 435 +++++++++++------------ 1 file changed, 202 insertions(+), 233 deletions(-) diff --git a/telegram-mcp-server/telegram/database.py b/telegram-mcp-server/telegram/database.py index 2a1e6bb..d5f8e84 100644 --- a/telegram-mcp-server/telegram/database.py +++ b/telegram-mcp-server/telegram/database.py @@ -1,48 +1,65 @@ """ Database operations for retrieving and managing Telegram data. +Uses SQLAlchemy ORM for database access instead of raw SQL. """ -import sqlite3 +from sqlalchemy import create_engine, or_, and_, desc +from sqlalchemy.orm import sessionmaker, scoped_session from datetime import datetime from typing import Optional, List, Tuple from . import MESSAGES_DB_PATH from .models import Message, Chat, Contact, MessageContext +# Initialize SQLAlchemy engine and session +engine = create_engine( + f"sqlite:///{MESSAGES_DB_PATH}", + connect_args={"check_same_thread": False} # Needed for SQLite +) +SessionFactory = sessionmaker(bind=engine) +Session = scoped_session(SessionFactory) + +# Import SQLAlchemy models from telegram-bridge +import sys +import os + +# Add the parent directory to path to find telegram-bridge +bridge_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), 'telegram-bridge') +sys.path.append(bridge_path) + +# Import models directly from telegram-bridge +from database.models import Chat as DBChat, Message as DBMessage + def search_contacts(query: str) -> List[Contact]: """Search contacts by name or username.""" try: - conn = sqlite3.connect(MESSAGES_DB_PATH) - cursor = conn.cursor() + session = Session() # Search in chats where type is 'user' - cursor.execute(""" - SELECT id, username, title - FROM chats - WHERE type = 'user' AND (title LIKE ? OR username LIKE ?) - ORDER BY title - LIMIT 50 - """, (f"%{query}%", f"%{query}%")) - - contacts = cursor.fetchall() + db_contacts = session.query(DBChat).filter( + DBChat.type == 'user', + or_( + DBChat.title.ilike(f"%{query}%"), + DBChat.username.ilike(f"%{query}%") + ) + ).order_by(DBChat.title).limit(50).all() result = [] - for contact_data in contacts: + for db_contact in db_contacts: contact = Contact( - id=contact_data[0], - username=contact_data[1], - name=contact_data[2] + id=db_contact.id, + username=db_contact.username, + name=db_contact.title ) result.append(contact) return result - except sqlite3.Error as e: + except Exception as e: print(f"Database error: {e}") return [] finally: - if 'conn' in locals(): - conn.close() + session.close() def list_messages( date_range: Optional[Tuple[datetime, datetime]] = None, @@ -57,66 +74,50 @@ def list_messages( ) -> List[Message]: """Get messages matching the specified criteria with optional context.""" try: - conn = sqlite3.connect(MESSAGES_DB_PATH) - cursor = conn.cursor() + session = Session() - # Build base query - query_parts = [""" - SELECT - m.id, - m.chat_id, - c.title, - m.sender_name, - m.content, - m.timestamp, - m.is_from_me, - m.sender_id - FROM messages m - """] - query_parts.append("JOIN chats c ON m.chat_id = c.id") - where_clauses = [] - params = [] + # Build base query with join + db_query = session.query(DBMessage, DBChat).join(DBChat) # Add filters + filters = [] if date_range: - where_clauses.append("m.timestamp BETWEEN ? AND ?") - params.extend([date_range[0].isoformat(), date_range[1].isoformat()]) + filters.append(and_( + DBMessage.timestamp >= date_range[0], + DBMessage.timestamp <= date_range[1] + )) if sender_id: - where_clauses.append("m.sender_id = ?") - params.append(sender_id) + filters.append(DBMessage.sender_id == sender_id) if chat_id: - where_clauses.append("m.chat_id = ?") - params.append(chat_id) + filters.append(DBMessage.chat_id == chat_id) if query: - where_clauses.append("LOWER(m.content) LIKE LOWER(?)") - params.append(f"%{query}%") + filters.append(DBMessage.content.ilike(f"%{query}%")) - if where_clauses: - query_parts.append("WHERE " + " AND ".join(where_clauses)) + if filters: + db_query = db_query.filter(and_(*filters)) # Add pagination offset = page * limit - query_parts.append("ORDER BY m.timestamp DESC") - query_parts.append("LIMIT ? OFFSET ?") - params.extend([limit, offset]) + db_query = db_query.order_by(desc(DBMessage.timestamp)) + db_query = db_query.limit(limit).offset(offset) - cursor.execute(" ".join(query_parts), tuple(params)) - messages = cursor.fetchall() + # Execute query + db_results = db_query.all() result = [] - for msg in messages: + for db_msg, db_chat in db_results: message = Message( - id=msg[0], - chat_id=msg[1], - chat_title=msg[2], - sender_name=msg[3], - content=msg[4], - timestamp=datetime.fromisoformat(msg[5]), - is_from_me=bool(msg[6]), - sender_id=msg[7] + id=db_msg.id, + chat_id=db_msg.chat_id, + chat_title=db_chat.title, + sender_name=db_msg.sender_name, + content=db_msg.content, + timestamp=db_msg.timestamp, + is_from_me=db_msg.is_from_me, + sender_id=db_msg.sender_id ) result.append(message) @@ -132,12 +133,11 @@ def list_messages( return result - except sqlite3.Error as e: + except Exception as e: print(f"Database error: {e}") return [] finally: - if 'conn' in locals(): - conn.close() + session.close() def get_message_context( message_id: int, @@ -147,76 +147,75 @@ def get_message_context( ) -> MessageContext: """Get context around a specific message.""" try: - conn = sqlite3.connect(MESSAGES_DB_PATH) - cursor = conn.cursor() + session = Session() # Get the target message first - cursor.execute(""" - SELECT m.id, m.chat_id, c.title, m.sender_name, m.content, m.timestamp, m.is_from_me, m.sender_id - FROM messages m - JOIN chats c ON m.chat_id = c.id - WHERE m.id = ? AND m.chat_id = ? - """, (message_id, chat_id)) - msg_data = cursor.fetchone() + result = session.query(DBMessage, DBChat) \ + .join(DBChat) \ + .filter(DBMessage.id == message_id, DBMessage.chat_id == chat_id) \ + .first() - if not msg_data: + if not result: raise ValueError(f"Message with ID {message_id} in chat {chat_id} not found") + db_msg, db_chat = result target_message = Message( - id=msg_data[0], - chat_id=msg_data[1], - chat_title=msg_data[2], - sender_name=msg_data[3], - content=msg_data[4], - timestamp=datetime.fromisoformat(msg_data[5]), - is_from_me=bool(msg_data[6]), - sender_id=msg_data[7] + id=db_msg.id, + chat_id=db_msg.chat_id, + chat_title=db_chat.title, + sender_name=db_msg.sender_name, + content=db_msg.content, + timestamp=db_msg.timestamp, + is_from_me=db_msg.is_from_me, + sender_id=db_msg.sender_id ) # Get messages before - cursor.execute(""" - SELECT m.id, m.chat_id, c.title, m.sender_name, m.content, m.timestamp, m.is_from_me, m.sender_id - FROM messages m - JOIN chats c ON m.chat_id = c.id - WHERE m.chat_id = ? AND m.timestamp < ? - ORDER BY m.timestamp DESC - LIMIT ? - """, (chat_id, target_message.timestamp.isoformat(), before)) + before_results = session.query(DBMessage, DBChat) \ + .join(DBChat) \ + .filter( + DBMessage.chat_id == chat_id, + DBMessage.timestamp < target_message.timestamp + ) \ + .order_by(desc(DBMessage.timestamp)) \ + .limit(before) \ + .all() before_messages = [] - for msg in cursor.fetchall(): + for db_msg, db_chat in before_results: before_messages.append(Message( - id=msg[0], - chat_id=msg[1], - chat_title=msg[2], - sender_name=msg[3], - content=msg[4], - timestamp=datetime.fromisoformat(msg[5]), - is_from_me=bool(msg[6]), - sender_id=msg[7] + id=db_msg.id, + chat_id=db_msg.chat_id, + chat_title=db_chat.title, + sender_name=db_msg.sender_name, + content=db_msg.content, + timestamp=db_msg.timestamp, + is_from_me=db_msg.is_from_me, + sender_id=db_msg.sender_id )) # Get messages after - cursor.execute(""" - SELECT m.id, m.chat_id, c.title, m.sender_name, m.content, m.timestamp, m.is_from_me, m.sender_id - FROM messages m - JOIN chats c ON m.chat_id = c.id - WHERE m.chat_id = ? AND m.timestamp > ? - ORDER BY m.timestamp ASC - LIMIT ? - """, (chat_id, target_message.timestamp.isoformat(), after)) + after_results = session.query(DBMessage, DBChat) \ + .join(DBChat) \ + .filter( + DBMessage.chat_id == chat_id, + DBMessage.timestamp > target_message.timestamp + ) \ + .order_by(DBMessage.timestamp) \ + .limit(after) \ + .all() after_messages = [] - for msg in cursor.fetchall(): + for db_msg, db_chat in after_results: after_messages.append(Message( - id=msg[0], - chat_id=msg[1], - chat_title=msg[2], - sender_name=msg[3], - content=msg[4], - timestamp=datetime.fromisoformat(msg[5]), - is_from_me=bool(msg[6]), - sender_id=msg[7] + id=db_msg.id, + chat_id=db_msg.chat_id, + chat_title=db_chat.title, + sender_name=db_msg.sender_name, + content=db_msg.content, + timestamp=db_msg.timestamp, + is_from_me=db_msg.is_from_me, + sender_id=db_msg.sender_id )) return MessageContext( @@ -225,12 +224,11 @@ def get_message_context( after=after_messages ) - except sqlite3.Error as e: + except Exception as e: print(f"Database error: {e}") raise finally: - if 'conn' in locals(): - conn.close() + session.close() def list_chats( query: Optional[str] = None, @@ -241,124 +239,108 @@ def list_chats( ) -> List[Chat]: """Get chats matching the specified criteria.""" try: - conn = sqlite3.connect(MESSAGES_DB_PATH) - cursor = conn.cursor() + session = Session() # Build base query - query_parts = ["SELECT id, title, username, type, last_message_time FROM chats"] + db_query = session.query(DBChat) - where_clauses = [] - params = [] + # Add filters + filters = [] if query: - where_clauses.append("(LOWER(title) LIKE LOWER(?) OR LOWER(username) LIKE LOWER(?))") - params.extend([f"%{query}%", f"%{query}%"]) + filters.append(or_( + DBChat.title.ilike(f"%{query}%"), + DBChat.username.ilike(f"%{query}%") + )) if chat_type: - where_clauses.append("type = ?") - params.append(chat_type) + filters.append(DBChat.type == chat_type) - if where_clauses: - query_parts.append("WHERE " + " AND ".join(where_clauses)) + if filters: + db_query = db_query.filter(and_(*filters)) # Add sorting - order_by = "last_message_time DESC" if sort_by == "last_active" else "title" - query_parts.append(f"ORDER BY {order_by}") + if sort_by == "last_active": + db_query = db_query.order_by(desc(DBChat.last_message_time)) + else: + db_query = db_query.order_by(DBChat.title) # Add pagination - offset = (page) * limit - query_parts.append("LIMIT ? OFFSET ?") - params.extend([limit, offset]) + offset = page * limit + db_query = db_query.limit(limit).offset(offset) - cursor.execute(" ".join(query_parts), tuple(params)) - chats = cursor.fetchall() + # Execute query + db_chats = db_query.all() result = [] - for chat_data in chats: - last_message_time = datetime.fromisoformat(chat_data[4]) if chat_data[4] else None + for db_chat in db_chats: chat = Chat( - id=chat_data[0], - title=chat_data[1], - username=chat_data[2], - type=chat_data[3], - last_message_time=last_message_time + id=db_chat.id, + title=db_chat.title, + username=db_chat.username, + type=db_chat.type, + last_message_time=db_chat.last_message_time ) result.append(chat) return result - except sqlite3.Error as e: + except Exception as e: print(f"Database error: {e}") return [] finally: - if 'conn' in locals(): - conn.close() + session.close() def get_chat(chat_id: int) -> Optional[Chat]: """Get chat metadata by ID.""" try: - conn = sqlite3.connect(MESSAGES_DB_PATH) - cursor = conn.cursor() + session = Session() - cursor.execute(""" - SELECT id, title, username, type, last_message_time - FROM chats - WHERE id = ? - """, (chat_id,)) + db_chat = session.query(DBChat).filter(DBChat.id == chat_id).first() - chat_data = cursor.fetchone() - - if not chat_data: + if not db_chat: return None - last_message_time = datetime.fromisoformat(chat_data[4]) if chat_data[4] else None return Chat( - id=chat_data[0], - title=chat_data[1], - username=chat_data[2], - type=chat_data[3], - last_message_time=last_message_time + id=db_chat.id, + title=db_chat.title, + username=db_chat.username, + type=db_chat.type, + last_message_time=db_chat.last_message_time ) - except sqlite3.Error as e: + except Exception as e: print(f"Database error: {e}") return None finally: - if 'conn' in locals(): - conn.close() + session.close() def get_direct_chat_by_contact(contact_id: int) -> Optional[Chat]: """Get direct chat metadata by contact ID.""" try: - conn = sqlite3.connect(MESSAGES_DB_PATH) - cursor = conn.cursor() + session = Session() - cursor.execute(""" - SELECT id, title, username, type, last_message_time - FROM chats - WHERE id = ? AND type = 'user' - """, (contact_id,)) + db_chat = session.query(DBChat).filter( + DBChat.id == contact_id, + DBChat.type == 'user' + ).first() - chat_data = cursor.fetchone() - - if not chat_data: + if not db_chat: return None - last_message_time = datetime.fromisoformat(chat_data[4]) if chat_data[4] else None return Chat( - id=chat_data[0], - title=chat_data[1], - username=chat_data[2], - type=chat_data[3], - last_message_time=last_message_time + id=db_chat.id, + title=db_chat.title, + username=db_chat.username, + type=db_chat.type, + last_message_time=db_chat.last_message_time ) - except sqlite3.Error as e: + except Exception as e: print(f"Database error: {e}") return None finally: - if 'conn' in locals(): - conn.close() + session.close() def get_contact_chats(contact_id: int, limit: int = 20, page: int = 0) -> List[Chat]: """Get all chats involving the contact. @@ -369,77 +351,64 @@ def get_contact_chats(contact_id: int, limit: int = 20, page: int = 0) -> List[C page: Page number for pagination (default 0) """ try: - conn = sqlite3.connect(MESSAGES_DB_PATH) - cursor = conn.cursor() + session = Session() - cursor.execute(""" - SELECT DISTINCT - c.id, c.title, c.username, c.type, c.last_message_time - FROM chats c - JOIN messages m ON c.id = m.chat_id - WHERE m.sender_id = ? OR c.id = ? - ORDER BY c.last_message_time DESC - LIMIT ? OFFSET ? - """, (contact_id, contact_id, limit, page * limit)) - - chats = cursor.fetchall() + # Using a subquery to get distinct chats for the contact + db_chats = session.query(DBChat).join(DBMessage, DBChat.id == DBMessage.chat_id).filter( + or_( + DBMessage.sender_id == contact_id, + DBChat.id == contact_id + ) + ).distinct().order_by(desc(DBChat.last_message_time)).limit(limit).offset(page * limit).all() result = [] - for chat_data in chats: - last_message_time = datetime.fromisoformat(chat_data[4]) if chat_data[4] else None + for db_chat in db_chats: chat = Chat( - id=chat_data[0], - title=chat_data[1], - username=chat_data[2], - type=chat_data[3], - last_message_time=last_message_time + id=db_chat.id, + title=db_chat.title, + username=db_chat.username, + type=db_chat.type, + last_message_time=db_chat.last_message_time ) result.append(chat) return result - except sqlite3.Error as e: + except Exception as e: print(f"Database error: {e}") return [] finally: - if 'conn' in locals(): - conn.close() + session.close() def get_last_interaction(contact_id: int) -> Optional[Message]: """Get most recent message involving the contact.""" try: - conn = sqlite3.connect(MESSAGES_DB_PATH) - cursor = conn.cursor() + session = Session() - cursor.execute(""" - SELECT - m.id, m.chat_id, c.title, m.sender_name, m.content, m.timestamp, m.is_from_me, m.sender_id - FROM messages m - JOIN chats c ON m.chat_id = c.id - WHERE m.sender_id = ? OR c.id = ? - ORDER BY m.timestamp DESC - LIMIT 1 - """, (contact_id, contact_id)) + result = session.query(DBMessage, DBChat).join(DBChat).filter( + or_( + DBMessage.sender_id == contact_id, + DBChat.id == contact_id + ) + ).order_by(desc(DBMessage.timestamp)).first() - msg_data = cursor.fetchone() - - if not msg_data: + if not result: return None + db_msg, db_chat = result return Message( - id=msg_data[0], - chat_id=msg_data[1], - chat_title=msg_data[2], - sender_name=msg_data[3], - content=msg_data[4], - timestamp=datetime.fromisoformat(msg_data[5]), - is_from_me=bool(msg_data[6]), - sender_id=msg_data[7] + id=db_msg.id, + chat_id=db_msg.chat_id, + chat_title=db_chat.title, + sender_name=db_msg.sender_name, + content=db_msg.content, + timestamp=db_msg.timestamp, + is_from_me=db_msg.is_from_me, + sender_id=db_msg.sender_id ) - except sqlite3.Error as e: + except Exception as e: print(f"Database error: {e}") return None finally: - if 'conn' in locals(): - conn.close() \ No newline at end of file + session.close() \ No newline at end of file