"""
Embedding generation and semantic search services.
"""

import os
import numpy as np
from flask import current_app
from sqlalchemy.orm import joinedload

try:
    from sentence_transformers import SentenceTransformer
    from sklearn.metrics.pairwise import cosine_similarity
    EMBEDDINGS_AVAILABLE = True
except ImportError:
    EMBEDDINGS_AVAILABLE = False
    cosine_similarity = None

from src.database import db
from src.models import Recording, TranscriptChunk, InternalShare, RecordingTag

ENABLE_INTERNAL_SHARING = os.environ.get('ENABLE_INTERNAL_SHARING', 'false').lower() == 'true'

# Initialize embedding model (lazy loading)
_embedding_model = None

# Words ignored by the keyword fallback search. Built once at import time
# instead of on every search call.
_STOP_WORDS = frozenset({
    'a', 'an', 'the', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
    'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could',
    'should', 'may', 'might', 'shall', 'can', 'to', 'of', 'in', 'for',
    'on', 'with', 'at', 'by', 'from', 'up', 'about', 'into', 'through',
    'during', 'before', 'after', 'and', 'but', 'or', 'nor', 'not', 'so',
    'yet', 'both', 'it', 'its', 'this', 'that', 'these', 'those', 'what',
    'which', 'who', 'whom', 'how', 'when', 'where', 'why', 'i', 'me', 'my',
    'we', 'our', 'you', 'your', 'he', 'she', 'his', 'her', 'they', 'them',
    'their',
})


def get_embedding_model():
    """
    Get or initialize the sentence transformer model.

    The model is created on first use and cached in the module-level
    ``_embedding_model`` variable.

    Returns:
        The cached ``SentenceTransformer`` instance, or None when the
        embedding libraries are not installed or the model failed to load.
    """
    global _embedding_model

    if not EMBEDDINGS_AVAILABLE:
        return None

    if _embedding_model is None:
        try:
            _embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
            current_app.logger.info("Embedding model loaded successfully")
        except Exception as e:
            current_app.logger.error(f"Failed to load embedding model: {e}")
            return None

    return _embedding_model


def chunk_transcription(transcription, max_chunk_length=500, overlap=50):
    """
    Split transcription into overlapping chunks for better context retrieval.

    Each chunk is at most ``max_chunk_length`` characters. When a chunk does
    not reach the end of the text, it is shortened to end just after the last
    '.', '!' or '?' that is followed by whitespace within the final 100
    characters of the chunk, if there is one.

    Args:
        transcription (str): The full transcription text
        max_chunk_length (int): Maximum characters per chunk
        overlap (int): Character overlap between chunks

    Returns:
        list: List of text chunks (whitespace-stripped, empty chunks dropped)
    """
    if not transcription:
        return []

    text_length = len(transcription)
    if text_length <= max_chunk_length:
        return [transcription]

    chunks = []
    start = 0

    while start < text_length:
        end = start + max_chunk_length

        # Try to break at sentence boundaries
        if end < text_length:
            # Scan backwards so the first hit is the latest sentence ending in
            # the window; never look before the start of the current chunk.
            window_start = max(start, end - 100)
            for i in range(end - 1, window_start - 1, -1):
                # Punctuation followed by whitespace, so that "3.14" or
                # "e.g.x" are not treated as sentence endings.
                if transcription[i] in '.!?' and transcription[i + 1].isspace():
                    end = i + 1
                    break

        chunk = transcription[start:end].strip()
        if chunk:
            chunks.append(chunk)

        # This chunk already reached the end of the text. Stepping back by
        # `overlap` here would only emit a tail fragment that is fully
        # contained in the chunk just appended.
        if end >= text_length:
            break

        # Move start position with overlap; always advance by at least one
        # character so the loop terminates.
        start = max(start + 1, end - overlap)

    return chunks


def generate_embeddings(texts):
    """
    Generate embeddings for a list of texts.

    Args:
        texts (list): List of text strings

    Returns:
        list: List of embedding vectors as float32 numpy arrays, or empty list
              if embeddings are unavailable, ``texts`` is empty, or encoding fails
    """
    if not EMBEDDINGS_AVAILABLE:
        current_app.logger.warning("Embeddings not available - skipping embedding generation")
        return []

    model = get_embedding_model()
    if not model or not texts:
        return []

    try:
        embeddings = model.encode(texts)
        return [embedding.astype(np.float32) for embedding in embeddings]
    except Exception as e:
        current_app.logger.error(f"Error generating embeddings: {e}")
        return []


def serialize_embedding(embedding):
    """
    Convert an embedding vector to binary for database storage.

    The vector is coerced to float32 first because ``deserialize_embedding``
    always reads the stored bytes back as float32.

    Returns:
        bytes, or None when ``embedding`` is None or embeddings are unavailable.
    """
    if embedding is None or not EMBEDDINGS_AVAILABLE:
        return None
    return np.asarray(embedding, dtype=np.float32).tobytes()


def deserialize_embedding(binary_data):
    """
    Convert binary data back to a float32 numpy array.

    Returns:
        numpy array, or None when ``binary_data`` is None or embeddings are
        unavailable.
    """
    if binary_data is None or not EMBEDDINGS_AVAILABLE:
        return None
    return np.frombuffer(binary_data, dtype=np.float32)


def get_accessible_recording_ids(user_id):
    """
    Get all recording IDs that a user has access to.

    Includes:
    - Recordings owned by the user
    - Recordings shared with the user via InternalShare
      (only when ENABLE_INTERNAL_SHARING is true)

    Sharing via group tags is not handled by this function.

    Args:
        user_id (int): User ID to check access for

    Returns:
        list: List of recording IDs the user can access (no duplicates)
    """
    accessible_ids = set()

    # 1. User's own recordings
    own_recordings = db.session.query(Recording.id).filter_by(user_id=user_id).all()
    accessible_ids.update(row.id for row in own_recordings)

    # 2. Internally shared recordings
    if ENABLE_INTERNAL_SHARING:
        shared_recordings = db.session.query(InternalShare.recording_id).filter_by(
            shared_with_user_id=user_id
        ).all()
        accessible_ids.update(row.recording_id for row in shared_recordings)

    return list(accessible_ids)


def process_recording_chunks(recording_id):
    """
    Process a recording by creating chunks and generating embeddings.
    This should be called after a recording is transcribed.

    Existing chunks for the recording are replaced. Chunk text is stored even
    when no embeddings could be generated, so that the keyword fallback
    search still has content to match against.

    Args:
        recording_id: ID of the recording to process

    Returns:
        bool: True on success (including "nothing to chunk"), False when the
              recording is missing, has no transcription, or an error occurred
    """
    try:
        recording = db.session.get(Recording, recording_id)
        if not recording or not recording.transcription:
            return False

        # Delete existing chunks for this recording
        TranscriptChunk.query.filter_by(recording_id=recording_id).delete()

        # Create chunks
        chunks = chunk_transcription(recording.transcription)

        if not chunks:
            return True

        # Generate embeddings
        embeddings = generate_embeddings(chunks)

        # generate_embeddings returns [] when embeddings are unavailable or
        # encoding failed. zip() would then silently store nothing, so fall
        # back to storing every chunk without an embedding.
        if len(embeddings) != len(chunks):
            embeddings = [None] * len(chunks)

        # Store chunks in database
        for i, (chunk_text, embedding) in enumerate(zip(chunks, embeddings)):
            chunk = TranscriptChunk(
                recording_id=recording_id,
                user_id=recording.user_id,
                chunk_index=i,
                content=chunk_text,
                embedding=serialize_embedding(embedding) if embedding is not None else None
            )
            db.session.add(chunk)

        db.session.commit()
        current_app.logger.info(f"Created {len(chunks)} chunks for recording {recording_id}")
        return True

    except Exception as e:
        current_app.logger.error(f"Error processing chunks for recording {recording_id}: {e}")
        db.session.rollback()
        return False


def _apply_chunk_filters(chunks_query, filters):
    """
    Apply the optional search filters to a TranscriptChunk query.

    Recognised keys in ``filters``: 'tag_ids', 'speaker_names',
    'recording_ids', 'date_from', 'date_to'. Missing or falsy entries are
    ignored.

    Recording is joined at most once, no matter how many filters need it;
    joining the same table repeatedly makes the query invalid.

    Args:
        chunks_query: Query over TranscriptChunk
        filters (dict): Optional filters, may be None or empty

    Returns:
        The query with all applicable filters added
    """
    if not filters:
        return chunks_query

    tag_ids = filters.get('tag_ids')
    speaker_names = filters.get('speaker_names')
    recording_ids = filters.get('recording_ids')
    date_from = filters.get('date_from')
    date_to = filters.get('date_to')

    # Tag, speaker and date filters all constrain the parent recording.
    if tag_ids or speaker_names or date_from or date_to:
        chunks_query = chunks_query.join(Recording)

    if tag_ids:
        # Keep only recordings that have one of the specified tags
        chunks_query = chunks_query.join(
            RecordingTag, Recording.id == RecordingTag.recording_id
        ).filter(RecordingTag.tag_id.in_(tag_ids))

    if speaker_names:
        # Filter by participants field in recordings instead of chunk speaker_name;
        # a recording matches when any of the names appears in it.
        speaker_conditions = [
            Recording.participants.ilike(f'%{speaker_name}%')
            for speaker_name in speaker_names
        ]
        chunks_query = chunks_query.filter(db.or_(*speaker_conditions))
        current_app.logger.info(f"Applied speaker filter for: {filters['speaker_names']}")

    if recording_ids:
        chunks_query = chunks_query.filter(
            TranscriptChunk.recording_id.in_(recording_ids)
        )

    if date_from:
        chunks_query = chunks_query.filter(Recording.meeting_date >= date_from)
    if date_to:
        chunks_query = chunks_query.filter(Recording.meeting_date <= date_to)

    return chunks_query


def basic_text_search_chunks(user_id, query, filters=None, top_k=5):
    """
    Basic text search fallback when embeddings are not available.
    Uses simple text matching instead of semantic search.
    Searches across user's own recordings and recordings shared with them.

    Args:
        user_id (int): User ID for permission filtering
        query (str): Search query
        filters (dict): Optional filters for tags, speakers, dates, recording_ids
        top_k (int): Number of top chunks to return

    Returns:
        list: List of (chunk, score) tuples sorted by score descending, where
              score is the fraction of query keywords found in the chunk.
              Empty list when nothing matches or an error occurs.
    """
    try:
        # Get all accessible recording IDs (own + shared)
        accessible_recording_ids = get_accessible_recording_ids(user_id)

        if not accessible_recording_ids:
            return []

        # Build base query for chunks from accessible recordings with eager loading
        chunks_query = TranscriptChunk.query.options(joinedload(TranscriptChunk.recording)).filter(
            TranscriptChunk.recording_id.in_(accessible_recording_ids)
        )

        # Apply filters if provided
        chunks_query = _apply_chunk_filters(chunks_query, filters)

        # Text search - filter stop words and rank by match count.
        # If all words were stop words, fall back to the original query words.
        candidate_words = [w for w in query.lower().split() if len(w) > 1]
        query_words = [w for w in candidate_words if w not in _STOP_WORDS] or candidate_words

        if not query_words:
            # No usable query words
            return []

        from sqlalchemy import or_

        # Filter: match ANY keyword (OR) to get candidates
        text_conditions = [
            TranscriptChunk.content.ilike(f'%{word}%') for word in query_words
        ]
        chunks_query = chunks_query.filter(or_(*text_conditions))

        # Fetch more candidates than needed so we can rank them
        chunks = chunks_query.limit(top_k * 5).all()

        # Rank by how many query words each chunk matches
        scored_chunks = []
        for chunk in chunks:
            content_lower = chunk.content.lower()
            match_count = sum(1 for word in query_words if word in content_lower)
            score = match_count / len(query_words)  # 0.0 to 1.0
            scored_chunks.append((chunk, score))

        # Sort by score descending, take top_k
        scored_chunks.sort(key=lambda x: x[1], reverse=True)
        return scored_chunks[:top_k]

    except Exception as e:
        current_app.logger.error(f"Error in basic text search: {e}")
        return []


def semantic_search_chunks(user_id, query, filters=None, top_k=5):
    """
    Perform semantic search on transcript chunks with filtering.
    Searches across user's own recordings and recordings shared with them.

    Falls back to ``basic_text_search_chunks`` when the embedding libraries
    or the embedding model are unavailable.

    Args:
        user_id (int): User ID for permission filtering
        query (str): Search query
        filters (dict): Optional filters for tags, speakers, dates, recording_ids
        top_k (int): Number of top chunks to return

    Returns:
        list: List of (chunk, similarity) tuples sorted by similarity
              descending. Empty list when nothing matches or an error occurs.
    """
    try:
        # If embeddings are not available, fall back to basic text search
        if not EMBEDDINGS_AVAILABLE:
            current_app.logger.info("Embeddings not available - using basic text search as fallback")
            return basic_text_search_chunks(user_id, query, filters, top_k)

        # Generate embedding for the query
        model = get_embedding_model()
        if not model:
            return basic_text_search_chunks(user_id, query, filters, top_k)

        query_embedding = model.encode([query])[0]

        # Get all accessible recording IDs (own + shared)
        accessible_recording_ids = get_accessible_recording_ids(user_id)

        if not accessible_recording_ids:
            return []

        # Build base query for chunks from accessible recordings with eager loading
        chunks_query = TranscriptChunk.query.options(joinedload(TranscriptChunk.recording)).filter(
            TranscriptChunk.recording_id.in_(accessible_recording_ids)
        )

        # Apply filters if provided
        chunks_query = _apply_chunk_filters(chunks_query, filters)

        # Get chunks that have embeddings
        chunks = chunks_query.filter(TranscriptChunk.embedding.isnot(None)).all()

        if not chunks:
            return []

        # Calculate similarities. Each chunk is handled on its own so that one
        # unreadable or mismatched embedding does not fail the whole search.
        query_vector = query_embedding.reshape(1, -1)
        chunk_similarities = []

        for chunk in chunks:
            try:
                chunk_embedding = deserialize_embedding(chunk.embedding)
                if chunk_embedding is not None:
                    similarity = cosine_similarity(
                        query_vector,
                        chunk_embedding.reshape(1, -1)
                    )[0][0]
                    chunk_similarities.append((chunk, float(similarity)))
            except Exception as e:
                current_app.logger.warning(f"Error calculating similarity for chunk {chunk.id}: {e}")
                continue

        # Sort by similarity and return top k
        chunk_similarities.sort(key=lambda x: x[1], reverse=True)
        return chunk_similarities[:top_k]

    except Exception as e:
        current_app.logger.error(f"Error in semantic search: {e}")
        return []


# --- Helper Functions for Document Processing ---