Files
dictia-public/src/services/embeddings.py

423 lines
16 KiB
Python

"""
Embedding generation and semantic search services.
"""
import os
import numpy as np
from flask import current_app
from sqlalchemy.orm import joinedload
try:
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
EMBEDDINGS_AVAILABLE = True
except ImportError:
EMBEDDINGS_AVAILABLE = False
cosine_similarity = None
from src.database import db
from src.models import Recording, TranscriptChunk, InternalShare, RecordingTag
ENABLE_INTERNAL_SHARING = os.environ.get('ENABLE_INTERNAL_SHARING', 'false').lower() == 'true'
# Initialize embedding model (lazy loading)
_embedding_model = None
def get_embedding_model():
    """Return the shared SentenceTransformer instance, loading it lazily.

    Returns:
        SentenceTransformer | None: the cached model, or None when the
        optional embedding dependencies are missing or loading failed
        (failure is logged and will be retried on the next call).
    """
    global _embedding_model
    if not EMBEDDINGS_AVAILABLE:
        return None
    if _embedding_model is not None:
        return _embedding_model
    try:
        _embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        current_app.logger.info("Embedding model loaded successfully")
    except Exception as exc:
        current_app.logger.error(f"Failed to load embedding model: {exc}")
        return None
    return _embedding_model
def chunk_transcription(transcription, max_chunk_length=500, overlap=50):
    """
    Split transcription into overlapping chunks for better context retrieval.

    Windows are cut at the last sentence ending ('.', '!', '?' followed by
    whitespace) found within the final 100 characters of the window, or at
    exactly max_chunk_length characters when no such boundary exists.

    Args:
        transcription (str): The full transcription text
        max_chunk_length (int): Maximum characters per chunk
        overlap (int): Character overlap between consecutive chunks

    Returns:
        list: Non-empty, stripped text chunks ([] for empty/None input)
    """
    if not transcription or len(transcription) <= max_chunk_length:
        return [transcription] if transcription else []
    text_len = len(transcription)
    chunks = []
    start = 0
    while start < text_len:
        end = start + max_chunk_length
        if end < text_len:
            # Prefer to break at the LAST sentence ending in the final
            # 100 characters of the window.
            sentence_end = -1
            for i in range(max(0, end - 100), end):
                if transcription[i] in '.!?':
                    # Require trailing whitespace so abbreviations such
                    # as "e.g." don't trigger a split mid-word.
                    if i + 1 < text_len and transcription[i + 1].isspace():
                        sentence_end = i + 1
            if sentence_end > start:
                end = sentence_end
        chunk = transcription[start:end].strip()
        if chunk:
            chunks.append(chunk)
        # BUGFIX: stop as soon as the window reaches the end of the text.
        # Previously the loop continued with start = end - overlap and
        # emitted a redundant tail chunk fully contained in the last one.
        if end >= text_len:
            break
        # Advance with overlap; max() guarantees forward progress even
        # when overlap >= chunk length.
        start = max(start + 1, end - overlap)
    return chunks
def generate_embeddings(texts):
    """
    Encode a list of strings into float32 embedding vectors.

    Args:
        texts (list): Text strings to embed

    Returns:
        list: One numpy float32 array per input text; [] when embedding
        support is missing, the model failed to load, the input is empty,
        or encoding raised (the error is logged).
    """
    if not EMBEDDINGS_AVAILABLE:
        current_app.logger.warning("Embeddings not available - skipping embedding generation")
        return []
    model = get_embedding_model()
    if model is None or not texts:
        return []
    try:
        return [vector.astype(np.float32) for vector in model.encode(texts)]
    except Exception as exc:
        current_app.logger.error(f"Error generating embeddings: {exc}")
        return []
def serialize_embedding(embedding):
    """Pack a numpy embedding vector into raw bytes for DB storage.

    Returns None for a missing vector or when embedding support is off.
    """
    if not EMBEDDINGS_AVAILABLE or embedding is None:
        return None
    return embedding.tobytes()
def deserialize_embedding(binary_data):
    """Unpack raw bytes from the DB back into a float32 numpy vector.

    Returns None for missing data or when embedding support is off.
    """
    if not EMBEDDINGS_AVAILABLE or binary_data is None:
        return None
    return np.frombuffer(binary_data, dtype=np.float32)
def get_accessible_recording_ids(user_id):
    """
    Collect every recording ID the given user may read.

    Includes the user's own recordings plus, when ENABLE_INTERNAL_SHARING
    is set, recordings shared with them through InternalShare.
    NOTE(review): earlier docs also mentioned group-tag sharing; that is
    not implemented here — confirm whether it lives elsewhere.

    Args:
        user_id (int): User ID to check access for

    Returns:
        list: Deduplicated recording IDs the user can access
    """
    # Own recordings.
    accessible = {row.id for row in
                  db.session.query(Recording.id).filter_by(user_id=user_id).all()}
    # Recordings shared with this user (feature-flagged).
    if ENABLE_INTERNAL_SHARING:
        shared = db.session.query(InternalShare.recording_id).filter_by(
            shared_with_user_id=user_id
        ).all()
        accessible.update(row.recording_id for row in shared)
    return list(accessible)
def process_recording_chunks(recording_id):
    """
    Split a transcribed recording into chunks and persist them with embeddings.

    Intended to run after transcription completes. Existing chunks for the
    recording are replaced. Chunks are stored even when embedding generation
    is unavailable (embedding=None) so the basic text-search fallback still
    has content to match against.

    Args:
        recording_id (int): Primary key of the recording to process

    Returns:
        bool: True on success (including "nothing to chunk"), False when the
        recording is missing/untranscribed or an error occurred (rolled back).
    """
    try:
        recording = db.session.get(Recording, recording_id)
        if not recording or not recording.transcription:
            return False
        # Replace chunks from any previous processing run.
        TranscriptChunk.query.filter_by(recording_id=recording_id).delete()
        chunks = chunk_transcription(recording.transcription)
        if not chunks:
            return True
        # May return fewer vectors than chunks — or none at all — when
        # embedding support is unavailable or encoding failed.
        embeddings = generate_embeddings(chunks)
        # BUGFIX: zip(chunks, embeddings) previously dropped EVERY chunk
        # when embeddings came back empty, leaving text search with no
        # content. Store all chunks, pairing embeddings by index.
        for i, chunk_text in enumerate(chunks):
            embedding = embeddings[i] if i < len(embeddings) else None
            db.session.add(TranscriptChunk(
                recording_id=recording_id,
                user_id=recording.user_id,
                chunk_index=i,
                content=chunk_text,
                embedding=serialize_embedding(embedding) if embedding is not None else None
            ))
        db.session.commit()
        current_app.logger.info(f"Created {len(chunks)} chunks for recording {recording_id}")
        return True
    except Exception as e:
        current_app.logger.error(f"Error processing chunks for recording {recording_id}: {e}")
        db.session.rollback()
        return False
def basic_text_search_chunks(user_id, query, filters=None, top_k=5):
    """
    Keyword-matching fallback search used when embeddings are unavailable.

    Searches chunks of the user's own recordings and recordings shared with
    them, drops stop words from the query, requires at least one keyword
    match, and ranks candidates by the fraction of keywords they contain.

    Args:
        user_id (int): User ID for permission filtering
        query (str): Free-text search query
        filters (dict): Optional tag_ids / speaker_names / recording_ids /
            date_from / date_to filters
        top_k (int): Maximum number of results to return

    Returns:
        list: (TranscriptChunk, score) tuples, score in [0.0, 1.0], best
        first; [] on error or when nothing matches.
    """
    try:
        # Permission scope: own + shared recordings.
        accessible_recording_ids = get_accessible_recording_ids(user_id)
        if not accessible_recording_ids:
            return []
        # Eager-load the parent recording to avoid N+1 queries later.
        chunks_query = TranscriptChunk.query.options(
            joinedload(TranscriptChunk.recording)
        ).filter(TranscriptChunk.recording_id.in_(accessible_recording_ids))
        if filters:
            # BUGFIX: track whether Recording is already joined. The old
            # column_descriptions check inspected selected entities (always
            # TranscriptChunk), never joins, so combining tag/speaker/date
            # filters joined Recording twice and broke the query.
            recording_joined = False
            if filters.get('tag_ids'):
                chunks_query = chunks_query.join(Recording).join(
                    RecordingTag, Recording.id == RecordingTag.recording_id
                ).filter(RecordingTag.tag_id.in_(filters['tag_ids']))
                recording_joined = True
            if filters.get('speaker_names'):
                # Filter by participants field in recordings instead of chunk speaker_name
                if not recording_joined:
                    chunks_query = chunks_query.join(Recording)
                    recording_joined = True
                speaker_conditions = [
                    Recording.participants.ilike(f'%{speaker_name}%')
                    for speaker_name in filters['speaker_names']
                ]
                chunks_query = chunks_query.filter(db.or_(*speaker_conditions))
                current_app.logger.info(f"Applied speaker filter for: {filters['speaker_names']}")
            if filters.get('recording_ids'):
                chunks_query = chunks_query.filter(
                    TranscriptChunk.recording_id.in_(filters['recording_ids'])
                )
            if filters.get('date_from') or filters.get('date_to'):
                if not recording_joined:
                    chunks_query = chunks_query.join(Recording)
                    recording_joined = True
                if filters.get('date_from'):
                    chunks_query = chunks_query.filter(Recording.meeting_date >= filters['date_from'])
                if filters.get('date_to'):
                    chunks_query = chunks_query.filter(Recording.meeting_date <= filters['date_to'])
        # Drop stop words so the ranking reflects meaningful terms only.
        stop_words = {'a', 'an', 'the', 'is', 'are', 'was', 'were', 'be', 'been',
                      'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will',
                      'would', 'could', 'should', 'may', 'might', 'shall', 'can',
                      'to', 'of', 'in', 'for', 'on', 'with', 'at', 'by', 'from',
                      'up', 'about', 'into', 'through', 'during', 'before', 'after',
                      'and', 'but', 'or', 'nor', 'not', 'so', 'yet', 'both',
                      'it', 'its', 'this', 'that', 'these', 'those', 'what', 'which',
                      'who', 'whom', 'how', 'when', 'where', 'why',
                      'i', 'me', 'my', 'we', 'our', 'you', 'your', 'he', 'she',
                      'his', 'her', 'they', 'them', 'their'}
        query_words = [w for w in query.lower().split() if w not in stop_words and len(w) > 1]
        if not query_words:
            # Query was all stop words; fall back to the raw terms.
            query_words = [w for w in query.lower().split() if len(w) > 1]
        if not query_words:
            # Nothing usable to search with.
            return []
        from sqlalchemy import or_
        # Candidate filter: chunk must contain at least ONE keyword.
        chunks_query = chunks_query.filter(
            or_(*[TranscriptChunk.content.ilike(f'%{word}%') for word in query_words])
        )
        # Over-fetch so the in-Python ranking has candidates to choose from.
        candidates = chunks_query.limit(top_k * 5).all()
        # Score = fraction of query words present in the chunk (0.0-1.0).
        scored_chunks = []
        for chunk in candidates:
            content_lower = chunk.content.lower()
            match_count = sum(1 for word in query_words if word in content_lower)
            scored_chunks.append((chunk, match_count / len(query_words)))
        scored_chunks.sort(key=lambda x: x[1], reverse=True)
        return scored_chunks[:top_k]
    except Exception as e:
        current_app.logger.error(f"Error in basic text search: {e}")
        return []
def semantic_search_chunks(user_id, query, filters=None, top_k=5):
    """
    Perform semantic search on transcript chunks with filtering.

    Searches across the user's own recordings and recordings shared with
    them; falls back to basic text search when embedding support or the
    model is unavailable.

    Args:
        user_id (int): User ID for permission filtering
        query (str): Search query
        filters (dict): Optional filters for tags, speakers, dates, recording_ids
        top_k (int): Number of top chunks to return

    Returns:
        list: (TranscriptChunk, similarity) tuples sorted best-first;
        [] on error or when nothing is accessible/embedded.
    """
    try:
        # Fall back to keyword search when embeddings can't be used.
        if not EMBEDDINGS_AVAILABLE:
            current_app.logger.info("Embeddings not available - using basic text search as fallback")
            return basic_text_search_chunks(user_id, query, filters, top_k)
        model = get_embedding_model()
        if not model:
            return basic_text_search_chunks(user_id, query, filters, top_k)
        query_embedding = model.encode([query])[0]
        # Permission scope: own + shared recordings.
        accessible_recording_ids = get_accessible_recording_ids(user_id)
        if not accessible_recording_ids:
            return []
        # Eager-load the parent recording to avoid N+1 queries later.
        chunks_query = TranscriptChunk.query.options(
            joinedload(TranscriptChunk.recording)
        ).filter(TranscriptChunk.recording_id.in_(accessible_recording_ids))
        if filters:
            # BUGFIX: track whether Recording is already joined. The old
            # column_descriptions check inspected selected entities (always
            # TranscriptChunk), never joins, so combining tag/speaker/date
            # filters joined Recording twice and broke the query.
            recording_joined = False
            if filters.get('tag_ids'):
                chunks_query = chunks_query.join(Recording).join(
                    RecordingTag, Recording.id == RecordingTag.recording_id
                ).filter(RecordingTag.tag_id.in_(filters['tag_ids']))
                recording_joined = True
            if filters.get('speaker_names'):
                # Filter by participants field in recordings instead of chunk speaker_name
                if not recording_joined:
                    chunks_query = chunks_query.join(Recording)
                    recording_joined = True
                speaker_conditions = [
                    Recording.participants.ilike(f'%{speaker_name}%')
                    for speaker_name in filters['speaker_names']
                ]
                chunks_query = chunks_query.filter(db.or_(*speaker_conditions))
                current_app.logger.info(f"Applied speaker filter for: {filters['speaker_names']}")
            if filters.get('recording_ids'):
                chunks_query = chunks_query.filter(
                    TranscriptChunk.recording_id.in_(filters['recording_ids'])
                )
            if filters.get('date_from') or filters.get('date_to'):
                if not recording_joined:
                    chunks_query = chunks_query.join(Recording)
                    recording_joined = True
                if filters.get('date_from'):
                    chunks_query = chunks_query.filter(Recording.meeting_date >= filters['date_from'])
                if filters.get('date_to'):
                    chunks_query = chunks_query.filter(Recording.meeting_date <= filters['date_to'])
        # Only chunks with stored embeddings can be compared.
        chunks = chunks_query.filter(TranscriptChunk.embedding.isnot(None)).all()
        if not chunks:
            return []
        # Rank by cosine similarity between query and chunk vectors.
        chunk_similarities = []
        for chunk in chunks:
            try:
                chunk_embedding = deserialize_embedding(chunk.embedding)
                if chunk_embedding is not None:
                    similarity = cosine_similarity(
                        query_embedding.reshape(1, -1),
                        chunk_embedding.reshape(1, -1)
                    )[0][0]
                    chunk_similarities.append((chunk, float(similarity)))
            except Exception as e:
                # A single corrupt embedding shouldn't abort the search.
                current_app.logger.warning(f"Error calculating similarity for chunk {chunk.id}: {e}")
                continue
        chunk_similarities.sort(key=lambda x: x[1], reverse=True)
        return chunk_similarities[:top_k]
    except Exception as e:
        current_app.logger.error(f"Error in semantic search: {e}")
        return []
# --- Helper Functions for Document Processing ---