414 lines
No EOL
12 KiB
Python
414 lines
No EOL
12 KiB
Python
"""
Database operations for retrieving and managing Telegram data.

Uses SQLAlchemy ORM for database access instead of raw SQL.
"""
|
|
|
|
from sqlalchemy import create_engine, or_, and_, desc
|
|
from sqlalchemy.orm import sessionmaker, scoped_session
|
|
from datetime import datetime
|
|
from typing import Optional, List, Tuple
|
|
|
|
from . import MESSAGES_DB_PATH
|
|
from .models import Message, Chat, Contact, MessageContext
|
|
|
|
# Initialize the SQLAlchemy engine and session machinery for the messages DB.
engine = create_engine(
    f"sqlite:///{MESSAGES_DB_PATH}",
    # Passed through to sqlite3.connect(); turns off sqlite3's check that a
    # connection is only used from the thread that created it.
    connect_args={"check_same_thread": False},
)

# Plain factory bound to the engine above.
SessionFactory = sessionmaker(bind=engine)

# Registry wrapper around the factory; the functions in this module obtain
# their session by calling ``Session()``.
Session = scoped_session(SessionFactory)
|
|
|
|
# Import SQLAlchemy models from telegram-bridge
|
|
import sys
|
|
import os
|
|
|
|
# Make the sibling ``telegram-bridge`` directory importable: walk three
# directory levels up from this file and append ``telegram-bridge`` to it.
bridge_path = os.path.join(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
    'telegram-bridge',
)
# Guard against duplicate entries if this module is imported/reloaded twice.
if bridge_path not in sys.path:
    sys.path.append(bridge_path)
|
|
|
|
# Import models directly from telegram-bridge
|
|
from database.models import Chat as DBChat, Message as DBMessage
|
|
|
|
def search_contacts(query: str) -> List[Contact]:
    """Search contacts by name or username.

    Queries ``DBChat`` rows whose ``type`` is ``'user'`` and whose ``title``
    or ``username`` matches ``%query%`` via ``ILIKE``. Results are ordered by
    title and capped at 50 rows.

    Args:
        query: Substring to search for. It is placed into the LIKE pattern
            unescaped, so ``%`` and ``_`` inside it act as SQL wildcards.

    Returns:
        A list of ``Contact`` objects built from the matching rows, or an
        empty list if any exception occurs (the error is printed, not raised).
    """
    # Bind before ``try`` so the ``finally`` clause never hits an unbound name
    # when ``Session()`` itself fails.
    session = None
    try:
        session = Session()
        pattern = f"%{query}%"

        # Contacts are stored as chats whose type is 'user'.
        db_contacts = (
            session.query(DBChat)
            .filter(
                DBChat.type == 'user',
                or_(
                    DBChat.title.ilike(pattern),
                    DBChat.username.ilike(pattern),
                ),
            )
            .order_by(DBChat.title)
            .limit(50)
            .all()
        )

        return [
            Contact(
                id=db_contact.id,
                username=db_contact.username,
                name=db_contact.title,
            )
            for db_contact in db_contacts
        ]

    except Exception as e:
        print(f"Database error: {e}")
        return []
    finally:
        if session is not None:
            session.close()
|
|
|
|
def list_messages(
    date_range: Optional[Tuple[datetime, datetime]] = None,
    sender_id: Optional[int] = None,
    chat_id: Optional[int] = None,
    query: Optional[str] = None,
    limit: int = 20,
    page: int = 0,
    include_context: bool = True,
    context_before: int = 1,
    context_after: int = 1
) -> List[Message]:
    """Get messages matching the specified criteria with optional context.

    All supplied filters are combined with AND. Matches are ordered by
    timestamp, newest first, and paginated with ``limit`` / ``page``.

    Args:
        date_range: Inclusive ``(start, end)`` bounds on the message timestamp.
        sender_id: Only messages with this sender ID. Falsy values (``None``,
            ``0``) disable the filter.
        chat_id: Only messages in this chat. Falsy values disable the filter.
        query: Substring matched against the message content with ``ILIKE``;
            ``%`` and ``_`` inside it act as SQL wildcards.
        limit: Page size.
        page: Zero-based page index; the row offset is ``page * limit``.
        include_context: When true and at least one message matched, each
            match is expanded via ``get_message_context``.
        context_before: Number of preceding messages requested per match.
        context_after: Number of following messages requested per match.

    Returns:
        Without context: the matching messages, newest first.
        With context: for each match in turn, its "before" messages, the
        match itself, then its "after" messages, concatenated into one flat
        list (neighbouring matches may therefore repeat messages).
        An empty list is returned if any exception occurs, including one
        raised by ``get_message_context``; the error is printed.
    """
    # Bind before ``try`` so the ``finally`` clause never hits an unbound name
    # when ``Session()`` itself fails.
    session = None
    try:
        session = Session()

        # Base query: each message together with its chat row.
        db_query = session.query(DBMessage, DBChat).join(DBChat)

        # Collect the optional filters.
        filters = []
        if date_range:
            filters.append(and_(
                DBMessage.timestamp >= date_range[0],
                DBMessage.timestamp <= date_range[1],
            ))
        if sender_id:
            filters.append(DBMessage.sender_id == sender_id)
        if chat_id:
            filters.append(DBMessage.chat_id == chat_id)
        if query:
            filters.append(DBMessage.content.ilike(f"%{query}%"))

        if filters:
            # Multiple criteria passed to filter() are AND-ed together.
            db_query = db_query.filter(*filters)

        # Newest first, then paginate.
        offset = page * limit
        db_query = (
            db_query.order_by(desc(DBMessage.timestamp))
            .limit(limit)
            .offset(offset)
        )

        messages = [
            Message(
                id=db_msg.id,
                chat_id=db_msg.chat_id,
                chat_title=db_chat.title,
                sender_name=db_msg.sender_name,
                content=db_msg.content,
                timestamp=db_msg.timestamp,
                is_from_me=db_msg.is_from_me,
                sender_id=db_msg.sender_id,
            )
            for db_msg, db_chat in db_query.all()
        ]

        if not (include_context and messages):
            return messages

        # Expand every match with its surrounding messages. The rows above are
        # already materialised, so nothing below depends on this function's
        # own query objects.
        # NOTE(review): ``Session`` is a scoped_session, so the nested call
        # presumably reuses and closes this same session - confirm this is
        # intended.
        messages_with_context = []
        for msg in messages:
            context = get_message_context(
                msg.id, msg.chat_id, context_before, context_after
            )
            messages_with_context.extend(context.before)
            messages_with_context.append(context.message)
            messages_with_context.extend(context.after)
        return messages_with_context

    except Exception as e:
        print(f"Database error: {e}")
        return []
    finally:
        if session is not None:
            session.close()
|
|
|
|
def _context_row_to_message(db_msg, db_chat) -> Message:
    """Build a ``Message`` from one ``(DBMessage, DBChat)`` result row."""
    return Message(
        id=db_msg.id,
        chat_id=db_msg.chat_id,
        chat_title=db_chat.title,
        sender_name=db_msg.sender_name,
        content=db_msg.content,
        timestamp=db_msg.timestamp,
        is_from_me=db_msg.is_from_me,
        sender_id=db_msg.sender_id,
    )


def get_message_context(
    message_id: int,
    chat_id: int,
    before: int = 5,
    after: int = 5
) -> MessageContext:
    """Get context around a specific message.

    Args:
        message_id: ID of the target message.
        chat_id: ID of the chat the target message belongs to.
        before: Maximum number of earlier messages (strictly smaller
            timestamp, same chat) to return.
        after: Maximum number of later messages (strictly greater timestamp,
            same chat) to return.

    Returns:
        A ``MessageContext`` holding the target message, the ``before``
        messages ordered newest first (closest to the target first), and the
        ``after`` messages ordered oldest first. Messages sharing the target's
        exact timestamp appear in neither list.

    Raises:
        ValueError: If no message with ``message_id`` exists in ``chat_id``.
        Exception: Any other error raised during the lookup is printed and
            re-raised unchanged.
    """
    # Bind before ``try`` so the ``finally`` clause never hits an unbound name
    # (and masks the real error) when ``Session()`` itself fails.
    session = None
    try:
        session = Session()

        # Fetch the target message first.
        target_row = (
            session.query(DBMessage, DBChat)
            .join(DBChat)
            .filter(DBMessage.id == message_id, DBMessage.chat_id == chat_id)
            .first()
        )
        if not target_row:
            raise ValueError(f"Message with ID {message_id} in chat {chat_id} not found")

        target_message = _context_row_to_message(*target_row)

        # Earlier messages: descending order so LIMIT keeps the ones closest
        # to the target.
        before_results = (
            session.query(DBMessage, DBChat)
            .join(DBChat)
            .filter(
                DBMessage.chat_id == chat_id,
                DBMessage.timestamp < target_message.timestamp,
            )
            .order_by(desc(DBMessage.timestamp))
            .limit(before)
            .all()
        )
        before_messages = [
            _context_row_to_message(db_msg, db_chat)
            for db_msg, db_chat in before_results
        ]

        # Later messages: ascending order, again closest to the target first.
        after_results = (
            session.query(DBMessage, DBChat)
            .join(DBChat)
            .filter(
                DBMessage.chat_id == chat_id,
                DBMessage.timestamp > target_message.timestamp,
            )
            .order_by(DBMessage.timestamp)
            .limit(after)
            .all()
        )
        after_messages = [
            _context_row_to_message(db_msg, db_chat)
            for db_msg, db_chat in after_results
        ]

        return MessageContext(
            message=target_message,
            before=before_messages,
            after=after_messages,
        )

    except Exception as e:
        print(f"Database error: {e}")
        raise
    finally:
        if session is not None:
            session.close()
|
|
|
|
def list_chats(
    query: Optional[str] = None,
    limit: int = 20,
    page: int = 0,
    chat_type: Optional[str] = None,
    sort_by: str = "last_active"
) -> List[Chat]:
    """Get chats matching the specified criteria.

    Args:
        query: Optional substring matched with ``ILIKE`` against the chat
            title or username; ``%`` and ``_`` inside it act as SQL wildcards.
        limit: Page size.
        page: Zero-based page index; the row offset is ``page * limit``.
        chat_type: If given, only chats whose ``type`` equals this value.
        sort_by: ``"last_active"`` orders by ``last_message_time`` descending;
            any other value orders by title ascending.

    Returns:
        A list of ``Chat`` objects, or an empty list if any exception occurs
        (the error is printed, not raised).
    """
    # Bind before ``try`` so the ``finally`` clause never hits an unbound name
    # when ``Session()`` itself fails.
    session = None
    try:
        session = Session()
        db_query = session.query(DBChat)

        # Collect the optional filters.
        filters = []
        if query:
            pattern = f"%{query}%"
            filters.append(or_(
                DBChat.title.ilike(pattern),
                DBChat.username.ilike(pattern),
            ))
        if chat_type:
            filters.append(DBChat.type == chat_type)

        if filters:
            # Multiple criteria passed to filter() are AND-ed together.
            db_query = db_query.filter(*filters)

        # Sorting.
        if sort_by == "last_active":
            db_query = db_query.order_by(desc(DBChat.last_message_time))
        else:
            db_query = db_query.order_by(DBChat.title)

        # Pagination.
        db_query = db_query.limit(limit).offset(page * limit)

        return [
            Chat(
                id=db_chat.id,
                title=db_chat.title,
                username=db_chat.username,
                type=db_chat.type,
                last_message_time=db_chat.last_message_time,
            )
            for db_chat in db_query.all()
        ]

    except Exception as e:
        print(f"Database error: {e}")
        return []
    finally:
        if session is not None:
            session.close()
|
|
|
|
def get_chat(chat_id: int) -> Optional[Chat]:
    """Get chat metadata by ID.

    Args:
        chat_id: Value compared against ``DBChat.id``.

    Returns:
        A ``Chat`` built from the first matching row, or ``None`` when no row
        matches or any exception occurs (the error is printed, not raised).
    """
    # Bind before ``try`` so the ``finally`` clause never hits an unbound name
    # when ``Session()`` itself fails.
    session = None
    try:
        session = Session()

        db_chat = session.query(DBChat).filter(DBChat.id == chat_id).first()
        if not db_chat:
            return None

        return Chat(
            id=db_chat.id,
            title=db_chat.title,
            username=db_chat.username,
            type=db_chat.type,
            last_message_time=db_chat.last_message_time,
        )

    except Exception as e:
        print(f"Database error: {e}")
        return None
    finally:
        if session is not None:
            session.close()
|
|
|
|
def get_direct_chat_by_contact(contact_id: int) -> Optional[Chat]:
    """Get direct chat metadata by contact ID.

    Looks for a ``DBChat`` row whose ``id`` equals ``contact_id`` and whose
    ``type`` is ``'user'``.

    Args:
        contact_id: The contact's ID, compared against ``DBChat.id``.

    Returns:
        A ``Chat`` built from the first matching row, or ``None`` when no row
        matches or any exception occurs (the error is printed, not raised).
    """
    # Bind before ``try`` so the ``finally`` clause never hits an unbound name
    # when ``Session()`` itself fails.
    session = None
    try:
        session = Session()

        db_chat = (
            session.query(DBChat)
            .filter(
                DBChat.id == contact_id,
                DBChat.type == 'user',
            )
            .first()
        )
        if not db_chat:
            return None

        return Chat(
            id=db_chat.id,
            title=db_chat.title,
            username=db_chat.username,
            type=db_chat.type,
            last_message_time=db_chat.last_message_time,
        )

    except Exception as e:
        print(f"Database error: {e}")
        return None
    finally:
        if session is not None:
            session.close()
|
|
|
|
def get_contact_chats(contact_id: int, limit: int = 20, page: int = 0) -> List[Chat]:
    """Get all chats involving the contact.

    A chat qualifies when it contains at least one message sent by
    ``contact_id``, or when the chat's own ID equals ``contact_id`` (and it
    has at least one message, since chats are inner-joined to messages).

    Args:
        contact_id: The contact's ID to search for
        limit: Maximum number of chats to return (default 20)
        page: Page number for pagination (default 0)

    Returns:
        Distinct matching chats ordered by ``last_message_time`` descending,
        or an empty list if any exception occurs (the error is printed, not
        raised).
    """
    # Bind before ``try`` so the ``finally`` clause never hits an unbound name
    # when ``Session()`` itself fails.
    session = None
    try:
        session = Session()

        # Join chats to their messages and de-duplicate with DISTINCT, since a
        # chat matches once per qualifying message.
        db_chats = (
            session.query(DBChat)
            .join(DBMessage, DBChat.id == DBMessage.chat_id)
            .filter(
                or_(
                    DBMessage.sender_id == contact_id,
                    DBChat.id == contact_id,
                )
            )
            .distinct()
            .order_by(desc(DBChat.last_message_time))
            .limit(limit)
            .offset(page * limit)
            .all()
        )

        return [
            Chat(
                id=db_chat.id,
                title=db_chat.title,
                username=db_chat.username,
                type=db_chat.type,
                last_message_time=db_chat.last_message_time,
            )
            for db_chat in db_chats
        ]

    except Exception as e:
        print(f"Database error: {e}")
        return []
    finally:
        if session is not None:
            session.close()
|
|
|
|
def get_last_interaction(contact_id: int) -> Optional[Message]:
    """Get most recent message involving the contact.

    A message qualifies when its sender ID equals ``contact_id`` or when it
    belongs to the chat whose ID equals ``contact_id``.

    Args:
        contact_id: The contact's ID.

    Returns:
        The qualifying message with the greatest timestamp as a ``Message``,
        or ``None`` when there is none or any exception occurs (the error is
        printed, not raised).
    """
    # Bind before ``try`` so the ``finally`` clause never hits an unbound name
    # when ``Session()`` itself fails.
    session = None
    try:
        session = Session()

        row = (
            session.query(DBMessage, DBChat)
            .join(DBChat)
            .filter(
                or_(
                    DBMessage.sender_id == contact_id,
                    DBChat.id == contact_id,
                )
            )
            .order_by(desc(DBMessage.timestamp))
            .first()
        )
        if not row:
            return None

        db_msg, db_chat = row
        return Message(
            id=db_msg.id,
            chat_id=db_msg.chat_id,
            chat_title=db_chat.title,
            sender_name=db_msg.sender_name,
            content=db_msg.content,
            timestamp=db_msg.timestamp,
            is_from_me=db_msg.is_from_me,
            sender_id=db_msg.sender_id,
        )

    except Exception as e:
        print(f"Database error: {e}")
        return None
    finally:
        if session is not None:
            session.close()