telegram-mcp/telegram-mcp-server/telegram/database.py
2025-04-05 17:17:08 +08:00

414 lines
No EOL
12 KiB
Python

"""
Database operations for retrieving and managing Telegram data.
Uses SQLAlchemy ORM for database access instead of raw SQL.
"""
from sqlalchemy import create_engine, or_, and_, desc
from sqlalchemy.orm import sessionmaker, scoped_session
from datetime import datetime
from typing import Optional, List, Tuple
from . import MESSAGES_DB_PATH
from .models import Message, Chat, Contact, MessageContext
# One process-wide SQLAlchemy engine pointed at the SQLite message store.
engine = create_engine(
    f"sqlite:///{MESSAGES_DB_PATH}",
    # SQLite connections refuse use from a thread other than the one that
    # created them; lifting that check lets pooled connections be shared.
    connect_args=dict(check_same_thread=False),
)
SessionFactory = sessionmaker(bind=engine)
# Thread-local session registry: calling Session() returns the session bound
# to the current thread, creating it on first use.
Session = scoped_session(SessionFactory)
# Import SQLAlchemy models from telegram-bridge
import sys
import os
# The ORM models are owned by the sibling telegram-bridge project, which is not
# installed as a package. Three dirname() hops climb from this file
# (<root>/telegram-mcp-server/telegram/database.py) up to <root>, the directory
# that also contains telegram-bridge; appending that to sys.path makes its
# top-level `database` package importable.
_repo_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
bridge_path = os.path.join(_repo_root, 'telegram-bridge')
sys.path.append(bridge_path)
# Aliased so the ORM classes do not clash with the local Chat/Message models.
from database.models import Chat as DBChat, Message as DBMessage
def search_contacts(query: str, limit: int = 50) -> List[Contact]:
    """Search contacts by name or username.

    Contacts are chats whose ``type`` is ``'user'``. A contact matches when
    ``query`` occurs, case-insensitively (SQL ``ILIKE``), in its title or
    username.

    Args:
        query: Substring to look for in the chat title or username.
        limit: Maximum number of contacts to return (default 50).

    Returns:
        Matching contacts ordered by title, or an empty list if the database
        query fails.
    """
    # Acquire the session before ``try`` so ``finally`` can never reference an
    # unbound name and mask the original error.
    session = Session()
    try:
        pattern = f"%{query}%"
        db_contacts = (
            session.query(DBChat)
            .filter(
                DBChat.type == 'user',
                or_(DBChat.title.ilike(pattern), DBChat.username.ilike(pattern)),
            )
            .order_by(DBChat.title)
            .limit(limit)
            .all()
        )
        return [
            Contact(id=db_contact.id, username=db_contact.username, name=db_contact.title)
            for db_contact in db_contacts
        ]
    except Exception as e:
        # Report on stderr: stdout is reserved for protocol messages when the
        # server runs over a stdio transport.
        print(f"Database error: {e}", file=sys.stderr)
        return []
    finally:
        session.close()
def list_messages(
    date_range: Optional[Tuple[datetime, datetime]] = None,
    sender_id: Optional[int] = None,
    chat_id: Optional[int] = None,
    query: Optional[str] = None,
    limit: int = 20,
    page: int = 0,
    include_context: bool = True,
    context_before: int = 1,
    context_after: int = 1
) -> List[Message]:
    """Get messages matching the specified criteria with optional context.

    All supplied filters are combined with AND; a filter left as ``None``
    (or an empty ``date_range``/``query``) is not applied. Matches are ordered
    newest-first and paginated with ``limit``/``page``.

    Args:
        date_range: Inclusive ``(start, end)`` bounds on the message timestamp.
        sender_id: Only return messages sent by this sender.
        chat_id: Only return messages from this chat.
        query: Case-insensitive substring to look for in message content.
        limit: Page size.
        page: Zero-based page index.
        include_context: If true, surround each match with neighbouring
            messages from the same chat.
        context_before: Number of earlier messages to include per match.
        context_after: Number of later messages to include per match.

    Returns:
        The matching messages. With context enabled this is a flat list in
        which each match is preceded by its earlier neighbours (oldest first)
        and followed by its later ones; neighbours shared by nearby matches
        may appear more than once. An empty list is returned on any error.
    """
    # Acquire the session before ``try`` so ``finally`` can never reference an
    # unbound name and mask the original error.
    session = Session()
    try:
        filters = []
        if date_range:
            filters.append(and_(
                DBMessage.timestamp >= date_range[0],
                DBMessage.timestamp <= date_range[1],
            ))
        if sender_id is not None:
            filters.append(DBMessage.sender_id == sender_id)
        if chat_id is not None:
            filters.append(DBMessage.chat_id == chat_id)
        if query:
            filters.append(DBMessage.content.ilike(f"%{query}%"))

        db_query = session.query(DBMessage, DBChat).join(DBChat)
        if filters:
            db_query = db_query.filter(and_(*filters))
        db_results = (
            db_query
            .order_by(desc(DBMessage.timestamp))
            .limit(limit)
            .offset(page * limit)
            .all()
        )

        result = [
            Message(
                id=db_msg.id,
                chat_id=db_msg.chat_id,
                chat_title=db_chat.title,
                sender_name=db_msg.sender_name,
                content=db_msg.content,
                timestamp=db_msg.timestamp,
                is_from_me=db_msg.is_from_me,
                sender_id=db_msg.sender_id,
            )
            for db_msg, db_chat in db_results
        ]
        if not (include_context and result):
            return result

        # get_message_context() obtains this thread's session from the scoped
        # registry and closes it when done; that is harmless here because all
        # rows were already copied into Message objects above.
        messages_with_context = []
        for msg in result:
            context = get_message_context(msg.id, msg.chat_id, context_before, context_after)
            # Sort so the preceding messages read oldest-first regardless of
            # the order get_message_context() returns them in.
            messages_with_context.extend(sorted(context.before, key=lambda m: m.timestamp))
            messages_with_context.append(context.message)
            messages_with_context.extend(context.after)
        return messages_with_context
    except Exception as e:
        # Report on stderr: stdout is reserved for protocol messages when the
        # server runs over a stdio transport.
        print(f"Database error: {e}", file=sys.stderr)
        return []
    finally:
        session.close()
def get_message_context(
    message_id: int,
    chat_id: int,
    before: int = 5,
    after: int = 5
) -> MessageContext:
    """Get a message together with its neighbours in the same chat.

    Args:
        message_id: ID of the target message.
        chat_id: Chat the target message belongs to.
        before: Maximum number of strictly earlier messages to include.
        after: Maximum number of strictly later messages to include.

    Returns:
        A MessageContext whose ``before`` and ``after`` lists are both in
        chronological (oldest-first) order, so ``before + [message] + after``
        reads as a continuous transcript.

    Raises:
        ValueError: If no message ``message_id`` exists in chat ``chat_id``.
        Exception: Any other database error is reported on stderr and
            re-raised.
    """
    def to_message(db_msg, db_chat) -> Message:
        # Copy ORM row data into a plain Message so it outlives the session.
        return Message(
            id=db_msg.id,
            chat_id=db_msg.chat_id,
            chat_title=db_chat.title,
            sender_name=db_msg.sender_name,
            content=db_msg.content,
            timestamp=db_msg.timestamp,
            is_from_me=db_msg.is_from_me,
            sender_id=db_msg.sender_id,
        )

    # Acquire the session before ``try`` so ``finally`` can never reference an
    # unbound name and mask the original error.
    session = Session()
    try:
        row = (
            session.query(DBMessage, DBChat)
            .join(DBChat)
            .filter(DBMessage.id == message_id, DBMessage.chat_id == chat_id)
            .first()
        )
        if row is None:
            raise ValueError(f"Message with ID {message_id} in chat {chat_id} not found")
        target_message = to_message(*row)

        in_chat = session.query(DBMessage, DBChat).join(DBChat).filter(DBMessage.chat_id == chat_id)
        # Fetch newest-first so LIMIT keeps the messages closest to the target,
        # then reverse below to restore chronological order.
        before_rows = (
            in_chat
            .filter(DBMessage.timestamp < target_message.timestamp)
            .order_by(desc(DBMessage.timestamp))
            .limit(before)
            .all()
        )
        after_rows = (
            in_chat
            .filter(DBMessage.timestamp > target_message.timestamp)
            .order_by(DBMessage.timestamp)
            .limit(after)
            .all()
        )

        return MessageContext(
            message=target_message,
            before=[to_message(m, c) for m, c in reversed(before_rows)],
            after=[to_message(m, c) for m, c in after_rows],
        )
    except Exception as e:
        # Report on stderr: stdout is reserved for protocol messages when the
        # server runs over a stdio transport.
        print(f"Database error: {e}", file=sys.stderr)
        raise
    finally:
        session.close()
def list_chats(
    query: Optional[str] = None,
    limit: int = 20,
    page: int = 0,
    chat_type: Optional[str] = None,
    sort_by: str = "last_active"
) -> List[Chat]:
    """Get chats matching the specified criteria.

    Args:
        query: Case-insensitive substring to look for in the chat title or
            username; not applied when empty or None.
        limit: Page size.
        page: Zero-based page index.
        chat_type: Only return chats of this type; not applied when empty or
            None.
        sort_by: ``"last_active"`` orders by most recent message time
            (newest first); any other value orders alphabetically by title.

    Returns:
        The matching chats, or an empty list if the database query fails.
    """
    # Acquire the session before ``try`` so ``finally`` can never reference an
    # unbound name and mask the original error.
    session = Session()
    try:
        filters = []
        if query:
            filters.append(or_(
                DBChat.title.ilike(f"%{query}%"),
                DBChat.username.ilike(f"%{query}%"),
            ))
        if chat_type:
            filters.append(DBChat.type == chat_type)

        db_query = session.query(DBChat)
        if filters:
            db_query = db_query.filter(and_(*filters))

        if sort_by == "last_active":
            ordering = desc(DBChat.last_message_time)
        else:
            ordering = DBChat.title
        db_chats = db_query.order_by(ordering).limit(limit).offset(page * limit).all()

        return [
            Chat(
                id=db_chat.id,
                title=db_chat.title,
                username=db_chat.username,
                type=db_chat.type,
                last_message_time=db_chat.last_message_time,
            )
            for db_chat in db_chats
        ]
    except Exception as e:
        # Report on stderr: stdout is reserved for protocol messages when the
        # server runs over a stdio transport.
        print(f"Database error: {e}", file=sys.stderr)
        return []
    finally:
        session.close()
def get_chat(chat_id: int) -> Optional[Chat]:
    """Get chat metadata by ID.

    Args:
        chat_id: ID of the chat to look up.

    Returns:
        The chat, or None if it does not exist or the query fails.
    """
    # Acquire the session before ``try`` so ``finally`` can never reference an
    # unbound name and mask the original error.
    session = Session()
    try:
        db_chat = session.query(DBChat).filter(DBChat.id == chat_id).first()
        if db_chat is None:
            return None
        return Chat(
            id=db_chat.id,
            title=db_chat.title,
            username=db_chat.username,
            type=db_chat.type,
            last_message_time=db_chat.last_message_time,
        )
    except Exception as e:
        # Report on stderr: stdout is reserved for protocol messages when the
        # server runs over a stdio transport.
        print(f"Database error: {e}", file=sys.stderr)
        return None
    finally:
        session.close()
def get_direct_chat_by_contact(contact_id: int) -> Optional[Chat]:
    """Get direct chat metadata by contact ID.

    The direct chat is the chat whose ID equals ``contact_id`` and whose
    type is ``'user'``.

    Args:
        contact_id: The contact's ID.

    Returns:
        The direct chat, or None if it does not exist or the query fails.
    """
    # Acquire the session before ``try`` so ``finally`` can never reference an
    # unbound name and mask the original error.
    session = Session()
    try:
        db_chat = (
            session.query(DBChat)
            .filter(DBChat.id == contact_id, DBChat.type == 'user')
            .first()
        )
        if db_chat is None:
            return None
        return Chat(
            id=db_chat.id,
            title=db_chat.title,
            username=db_chat.username,
            type=db_chat.type,
            last_message_time=db_chat.last_message_time,
        )
    except Exception as e:
        # Report on stderr: stdout is reserved for protocol messages when the
        # server runs over a stdio transport.
        print(f"Database error: {e}", file=sys.stderr)
        return None
    finally:
        session.close()
def get_contact_chats(contact_id: int, limit: int = 20, page: int = 0) -> List[Chat]:
    """Get all chats involving the contact.

    A chat involves the contact if its ID equals ``contact_id`` (the direct
    chat) or if the contact sent at least one stored message in it.

    Args:
        contact_id: The contact's ID to search for
        limit: Maximum number of chats to return (default 20)
        page: Page number for pagination (default 0)

    Returns:
        Matching chats ordered by most recent message time (newest first),
        or an empty list if the database query fails.
    """
    # Acquire the session before ``try`` so ``finally`` can never reference an
    # unbound name and mask the original error.
    session = Session()
    try:
        # IDs of every chat the contact has posted in. Filtering chats with an
        # IN subquery avoids joining one row per message and then DISTINCT-ing
        # them, and it still finds the direct chat when it has no stored messages.
        chats_with_contact_messages = session.query(DBMessage.chat_id).filter(
            DBMessage.sender_id == contact_id
        )
        db_chats = (
            session.query(DBChat)
            .filter(or_(
                DBChat.id == contact_id,
                DBChat.id.in_(chats_with_contact_messages),
            ))
            .order_by(desc(DBChat.last_message_time))
            .limit(limit)
            .offset(page * limit)
            .all()
        )
        return [
            Chat(
                id=db_chat.id,
                title=db_chat.title,
                username=db_chat.username,
                type=db_chat.type,
                last_message_time=db_chat.last_message_time,
            )
            for db_chat in db_chats
        ]
    except Exception as e:
        # Report on stderr: stdout is reserved for protocol messages when the
        # server runs over a stdio transport.
        print(f"Database error: {e}", file=sys.stderr)
        return []
    finally:
        session.close()
def get_last_interaction(contact_id: int) -> Optional[Message]:
    """Get the most recent message involving the contact.

    A message involves the contact if the contact sent it, or if it belongs
    to the chat whose ID equals ``contact_id``.

    Args:
        contact_id: The contact's ID.

    Returns:
        The newest such message, or None if there is none or the query fails.
    """
    # Acquire the session before ``try`` so ``finally`` can never reference an
    # unbound name and mask the original error.
    session = Session()
    try:
        row = (
            session.query(DBMessage, DBChat)
            .join(DBChat)
            .filter(or_(DBMessage.sender_id == contact_id, DBChat.id == contact_id))
            .order_by(desc(DBMessage.timestamp))
            .first()
        )
        if row is None:
            return None
        db_msg, db_chat = row
        return Message(
            id=db_msg.id,
            chat_id=db_msg.chat_id,
            chat_title=db_chat.title,
            sender_name=db_msg.sender_name,
            content=db_msg.content,
            timestamp=db_msg.timestamp,
            is_from_me=db_msg.is_from_me,
            sender_id=db_msg.sender_id,
        )
    except Exception as e:
        # Report on stderr: stdout is reserved for protocol messages when the
        # server runs over a stdio transport.
        print(f"Database error: {e}", file=sys.stderr)
        return None
    finally:
        session.close()