423 lines
16 KiB
Python
423 lines
16 KiB
Python
"""
|
|
Embedding generation and semantic search services.
|
|
"""
|
|
|
|
import os
|
|
import numpy as np
|
|
from flask import current_app
|
|
from sqlalchemy.orm import joinedload
|
|
|
|
try:
|
|
from sentence_transformers import SentenceTransformer
|
|
from sklearn.metrics.pairwise import cosine_similarity
|
|
EMBEDDINGS_AVAILABLE = True
|
|
except ImportError:
|
|
EMBEDDINGS_AVAILABLE = False
|
|
cosine_similarity = None
|
|
|
|
from src.database import db
|
|
from src.models import Recording, TranscriptChunk, InternalShare, RecordingTag
|
|
|
|
# Feature flag: recordings shared through InternalShare are only included in
# search results when the ENABLE_INTERNAL_SHARING environment variable equals
# "true" (case-insensitive). Read once at import time.
ENABLE_INTERNAL_SHARING = os.getenv('ENABLE_INTERNAL_SHARING', 'false').lower() == 'true'

# Cached SentenceTransformer instance; stays None until get_embedding_model()
# loads it on first use.
_embedding_model = None
|
|
|
|
|
|
|
|
def get_embedding_model():
    """Return the shared 'all-MiniLM-L6-v2' SentenceTransformer, loading it lazily.

    The instance is cached in the module-level ``_embedding_model``.

    Returns:
        The loaded model, or None when the optional embedding dependencies
        are not installed or loading raised. A failed load is logged and
        nothing is cached, so the next call tries to load again.
    """
    global _embedding_model

    # Guard: without sentence_transformers there is nothing to load.
    if not EMBEDDINGS_AVAILABLE:
        return None

    # Fast path: already loaded by an earlier call.
    if _embedding_model is not None:
        return _embedding_model

    try:
        _embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        current_app.logger.info("Embedding model loaded successfully")
    except Exception as load_error:
        current_app.logger.error(f"Failed to load embedding model: {load_error}")
        return None

    return _embedding_model
|
|
|
|
|
|
|
|
def chunk_transcription(transcription, max_chunk_length=500, overlap=50):
    """
    Split transcription into overlapping chunks for better context retrieval.

    Each chunk covers at most ``max_chunk_length`` characters of the text.
    When a window ends before the end of the text, the cut is moved back to
    just after the last '.', '!' or '?' that is followed by whitespace within
    the window's final 100 characters (if there is one). The next window then
    starts ``overlap`` characters before the cut, and always at least one
    character after the previous start. Chunks are stripped of surrounding
    whitespace, and chunks that are empty after stripping are dropped.

    Args:
        transcription (str): The full transcription text
        max_chunk_length (int): Maximum characters per chunk
        overlap (int): Character overlap between chunks

    Returns:
        list: List of text chunks. An empty or None transcription yields an
        empty list. A transcription no longer than ``max_chunk_length`` is
        returned as a single, unmodified chunk.
    """
    if not transcription or len(transcription) <= max_chunk_length:
        return [transcription] if transcription else []

    text_length = len(transcription)
    chunks = []
    start = 0

    while start < text_length:
        end = start + max_chunk_length

        # Try to break at a sentence boundary. Scanning backwards, the first
        # hit is the last boundary in the window, so we can stop right there.
        if end < text_length:
            for i in range(end - 1, max(0, end - 100) - 1, -1):
                # Require trailing whitespace to skip things like "3.5" or "e.g.x".
                if transcription[i] in '.!?' and transcription[i + 1].isspace():
                    if i + 1 > start:
                        end = i + 1
                    break

        chunk = transcription[start:end].strip()
        if chunk:
            chunks.append(chunk)

        # This window already reached the end of the text. Continuing would
        # only emit a suffix fully contained in the chunk just added.
        if end >= text_length:
            break

        # Step forward keeping `overlap` characters of shared context, but
        # always advance by at least one character so the loop terminates.
        start = max(start + 1, end - overlap)

    return chunks
|
|
|
|
|
|
|
|
def generate_embeddings(texts):
    """
    Encode each text into a float32 embedding vector in a single batch.

    Args:
        texts (list): Text strings to encode.

    Returns:
        list: One float32 numpy array per input text. An empty list is
        returned instead of raising in these cases:
        - the embedding libraries are missing (logged as a warning);
        - the model could not be obtained;
        - ``texts`` is empty;
        - encoding raised (logged as an error).
    """
    if not EMBEDDINGS_AVAILABLE:
        current_app.logger.warning("Embeddings not available - skipping embedding generation")
        return []

    model = get_embedding_model()
    # Equivalent to "no model, or nothing to encode".
    if not (model and texts):
        return []

    try:
        encoded = model.encode(texts)
        # Normalise dtype so every vector matches what deserialize_embedding expects.
        return [vector.astype(np.float32) for vector in encoded]
    except Exception as encode_error:
        current_app.logger.error(f"Error generating embeddings: {encode_error}")
        return []
|
|
|
|
|
|
|
|
def serialize_embedding(embedding):
    """Convert an embedding vector to raw bytes for database storage.

    The vector is coerced to float32 before serializing, because
    ``deserialize_embedding`` decodes stored bytes as float32. Serializing,
    say, a float64 array as-is would round-trip into twice as many
    meaningless values.

    Args:
        embedding: A numpy array (or array-like of numbers), or None.

    Returns:
        bytes | None: The float32 bytes of the vector. Returns None when
        ``embedding`` is None or the embedding libraries are unavailable.
    """
    if embedding is None or not EMBEDDINGS_AVAILABLE:
        return None
    # np.asarray does not copy when the input is already a float32 array.
    return np.asarray(embedding, dtype=np.float32).tobytes()
|
|
|
|
|
|
|
|
def deserialize_embedding(binary_data):
    """Decode stored embedding bytes back into a 1-D float32 numpy array.

    Returns None when ``binary_data`` is None or the embedding libraries are
    unavailable. The array is built with ``np.frombuffer``, so it shares
    memory with ``binary_data`` rather than copying it.
    """
    if binary_data is not None and EMBEDDINGS_AVAILABLE:
        return np.frombuffer(binary_data, dtype=np.float32)
    return None
|
|
|
|
|
|
|
|
def get_accessible_recording_ids(user_id):
    """
    Get all recording IDs that a user has access to.

    Includes:
        - Recordings owned by the user (``Recording.user_id == user_id``).
        - Recordings shared with the user via InternalShare, but only when
          the ENABLE_INTERNAL_SHARING feature flag is enabled.

    Group/tag-based sharing is NOT considered by this function.

    Args:
        user_id (int): User ID to check access for

    Returns:
        list: Recording IDs the user can access, without duplicates and in no
        particular order.
    """
    # 1. User's own recordings
    owned_rows = db.session.query(Recording.id).filter_by(user_id=user_id).all()
    accessible_ids = {row.id for row in owned_rows}

    # 2. Internally shared recordings (feature-flagged)
    if ENABLE_INTERNAL_SHARING:
        shared_rows = db.session.query(InternalShare.recording_id).filter_by(
            shared_with_user_id=user_id
        ).all()
        accessible_ids.update(row.recording_id for row in shared_rows)

    return list(accessible_ids)
|
|
|
|
|
|
|
|
def process_recording_chunks(recording_id):
    """
    Process a recording by creating chunks and generating embeddings.

    This should be called after a recording is transcribed. Any existing
    TranscriptChunk rows for the recording are deleted and replaced. Chunks
    are stored even when embeddings cannot be generated; their ``embedding``
    column is then left NULL, so the basic text-search fallback still has
    content to match against.

    Args:
        recording_id: Primary key of the Recording to process.

    Returns:
        bool: True on success, including when the transcription yields no
        chunks. False when the recording does not exist, has no
        transcription, or an error occurred; on error the session is rolled
        back.
    """
    try:
        recording = db.session.get(Recording, recording_id)
        if not recording or not recording.transcription:
            return False

        # Delete existing chunks for this recording
        TranscriptChunk.query.filter_by(recording_id=recording_id).delete()

        chunks = chunk_transcription(recording.transcription)

        if not chunks:
            # Make the removal of the stale chunks durable before returning.
            db.session.commit()
            return True

        embeddings = generate_embeddings(chunks)
        if len(embeddings) != len(chunks):
            # generate_embeddings returns [] when embeddings are unavailable
            # or failed. Zipping against [] would silently drop every chunk,
            # so store the text without vectors instead.
            embeddings = [None] * len(chunks)

        for index, (chunk_text, embedding) in enumerate(zip(chunks, embeddings)):
            db.session.add(TranscriptChunk(
                recording_id=recording_id,
                user_id=recording.user_id,
                chunk_index=index,
                content=chunk_text,
                # serialize_embedding maps None to None.
                embedding=serialize_embedding(embedding),
            ))

        db.session.commit()
        current_app.logger.info(f"Created {len(chunks)} chunks for recording {recording_id}")
        return True

    except Exception as e:
        current_app.logger.error(f"Error processing chunks for recording {recording_id}: {e}")
        db.session.rollback()
        return False
|
|
|
|
|
|
|
|
# Common English function words ignored by the keyword fallback search.
_STOP_WORDS = frozenset({
    'a', 'an', 'the', 'is', 'are', 'was', 'were', 'be', 'been',
    'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will',
    'would', 'could', 'should', 'may', 'might', 'shall', 'can',
    'to', 'of', 'in', 'for', 'on', 'with', 'at', 'by', 'from',
    'up', 'about', 'into', 'through', 'during', 'before', 'after',
    'and', 'but', 'or', 'nor', 'not', 'so', 'yet', 'both',
    'it', 'its', 'this', 'that', 'these', 'those', 'what', 'which',
    'who', 'whom', 'how', 'when', 'where', 'why',
    'i', 'me', 'my', 'we', 'our', 'you', 'your', 'he', 'she',
    'his', 'her', 'they', 'them', 'their',
})


def _apply_chunk_filters(chunks_query, filters):
    """Narrow a TranscriptChunk query by the optional search filters.

    The Recording table is joined at most once, however many
    recording-level filters are active. Tag matching uses an IN-subquery, so
    a recording carrying several of the requested tags does not produce
    duplicate chunk rows.
    """
    from sqlalchemy import or_, select

    if filters.get('tag_ids'):
        tagged_recording_ids = select(RecordingTag.recording_id).where(
            RecordingTag.tag_id.in_(filters['tag_ids'])
        )
        chunks_query = chunks_query.filter(
            TranscriptChunk.recording_id.in_(tagged_recording_ids)
        )

    speaker_names = filters.get('speaker_names')
    date_from = filters.get('date_from')
    date_to = filters.get('date_to')

    # Single join shared by the speaker and date filters.
    if speaker_names or date_from or date_to:
        chunks_query = chunks_query.join(Recording)

    if speaker_names:
        # Match any requested speaker against the recording's participants field.
        chunks_query = chunks_query.filter(
            or_(*(Recording.participants.ilike(f'%{name}%') for name in speaker_names))
        )
        current_app.logger.info(f"Applied speaker filter for: {filters['speaker_names']}")

    if filters.get('recording_ids'):
        chunks_query = chunks_query.filter(
            TranscriptChunk.recording_id.in_(filters['recording_ids'])
        )

    if date_from:
        chunks_query = chunks_query.filter(Recording.meeting_date >= date_from)
    if date_to:
        chunks_query = chunks_query.filter(Recording.meeting_date <= date_to)

    return chunks_query


def _accessible_chunks_query(user_id, filters=None):
    """Build the filtered TranscriptChunk query shared by both searches.

    Returns None when the user cannot access any recording.
    """
    accessible_recording_ids = get_accessible_recording_ids(user_id)
    if not accessible_recording_ids:
        return None

    # Eager-load each chunk's recording so callers can read it without extra queries.
    chunks_query = TranscriptChunk.query.options(joinedload(TranscriptChunk.recording)).filter(
        TranscriptChunk.recording_id.in_(accessible_recording_ids)
    )
    if filters:
        chunks_query = _apply_chunk_filters(chunks_query, filters)
    return chunks_query


def basic_text_search_chunks(user_id, query, filters=None, top_k=5):
    """
    Basic text search fallback when embeddings are not available.

    Uses simple keyword matching instead of semantic search. Searches across
    the user's own recordings and recordings shared with them.

    How matching and scoring work:
    - Query words shorter than two characters are dropped.
    - Stop words are dropped too, unless that would leave no words at all.
    - A chunk is a candidate when its content contains ANY remaining word
      (case-insensitive).
    - Candidates are scored by the fraction of query words they contain.

    Args:
        user_id (int): User ID for permission filtering
        query (str): Search query
        filters (dict): Optional filters; same keys as semantic_search_chunks
        top_k (int): Number of top chunks to return

    Returns:
        list: Up to ``top_k`` ``(TranscriptChunk, score)`` tuples, with score
        in [0.0, 1.0], highest first. Returns an empty list when nothing
        matches or an error occurs.
    """
    from sqlalchemy import or_

    try:
        chunks_query = _accessible_chunks_query(user_id, filters)
        if chunks_query is None:
            return []

        lowered_words = query.lower().split()
        query_words = [w for w in lowered_words if w not in _STOP_WORDS and len(w) > 1]
        if not query_words:
            # If all words were stop words, fall back to using original query words
            query_words = [w for w in lowered_words if len(w) > 1]
        if not query_words:
            # No usable query words
            return []

        # Filter: match ANY keyword (OR) to get candidates
        chunks_query = chunks_query.filter(
            or_(*(TranscriptChunk.content.ilike(f'%{word}%') for word in query_words))
        )

        # NOTE(review): only top_k * 5 candidates are fetched, in unspecified
        # database order, before ranking. A better-scoring chunk beyond that
        # cap can therefore be missed.
        candidates = chunks_query.limit(top_k * 5).all()

        # Rank by how many query words each chunk matches
        scored_chunks = []
        for chunk in candidates:
            content_lower = chunk.content.lower()
            match_count = sum(1 for word in query_words if word in content_lower)
            scored_chunks.append((chunk, match_count / len(query_words)))

        scored_chunks.sort(key=lambda pair: pair[1], reverse=True)
        return scored_chunks[:top_k]

    except Exception as e:
        current_app.logger.error(f"Error in basic text search: {e}")
        return []


def _rank_chunks_by_similarity(query_embedding, chunks):
    """Return ``(chunk, cosine similarity)`` pairs for chunks with a usable embedding.

    A chunk is skipped with a warning when its stored embedding cannot be
    decoded, differs in length from the query vector, or contains non-finite
    values. This keeps one bad row from aborting the whole search. All
    remaining similarities are computed in one vectorized call.
    """
    query_vector = np.asarray(query_embedding, dtype=np.float32).reshape(1, -1)
    expected_shape = (query_vector.shape[1],)

    usable_chunks = []
    vectors = []
    for chunk in chunks:
        try:
            vector = deserialize_embedding(chunk.embedding)
        except Exception as e:
            current_app.logger.warning(f"Error calculating similarity for chunk {chunk.id}: {e}")
            continue
        if vector is None:
            continue
        if vector.shape != expected_shape or not np.isfinite(vector).all():
            current_app.logger.warning(
                f"Error calculating similarity for chunk {chunk.id}: "
                f"unusable embedding of shape {vector.shape}"
            )
            continue
        usable_chunks.append(chunk)
        vectors.append(vector)

    if not usable_chunks:
        return []

    similarities = cosine_similarity(query_vector, np.vstack(vectors))[0]
    return [(chunk, float(score)) for chunk, score in zip(usable_chunks, similarities)]


def semantic_search_chunks(user_id, query, filters=None, top_k=5):
    """
    Perform semantic search on transcript chunks with filtering.

    Searches across the user's own recordings and recordings shared with
    them. Falls back to basic_text_search_chunks when the embedding
    libraries are missing or the model cannot be loaded.

    Args:
        user_id (int): User ID for permission filtering
        query (str): Search query
        filters (dict): Optional filters. Recognised keys:
            'tag_ids': recordings carrying any of these tag IDs.
            'speaker_names': any name appearing, case-insensitively, in
                Recording.participants.
            'recording_ids': restrict to these recordings.
            'date_from' / 'date_to': inclusive bounds on
                Recording.meeting_date.
        top_k (int): Number of top chunks to return

    Returns:
        list: Up to ``top_k`` ``(TranscriptChunk, similarity)`` tuples sorted
        by cosine similarity, highest first. Returns an empty list on error.
    """
    try:
        # If embeddings are not available, fall back to basic text search
        if not EMBEDDINGS_AVAILABLE:
            current_app.logger.info("Embeddings not available - using basic text search as fallback")
            return basic_text_search_chunks(user_id, query, filters, top_k)

        model = get_embedding_model()
        if not model:
            return basic_text_search_chunks(user_id, query, filters, top_k)

        chunks_query = _accessible_chunks_query(user_id, filters)
        if chunks_query is None:
            return []

        # Get chunks that have embeddings
        chunks = chunks_query.filter(TranscriptChunk.embedding.isnot(None)).all()
        if not chunks:
            return []

        # Encode the query only once there is something to compare it with.
        query_embedding = model.encode([query])[0]

        chunk_similarities = _rank_chunks_by_similarity(query_embedding, chunks)
        chunk_similarities.sort(key=lambda pair: pair[1], reverse=True)
        return chunk_similarities[:top_k]

    except Exception as e:
        current_app.logger.error(f"Error in semantic search: {e}")
        return []
|
|
|
|
# --- Helper Functions for Document Processing ---
|
|
|
|
|
|
|