# Reranking services
import asyncio
import functools
from typing import Any, Dict, List

import cohere
from loguru import logger
from tenacity import retry, stop_after_attempt, wait_exponential

from app.core.config import settings
from app.core.models import ComplianceIssue, ComplianceLevel, ComplianceReport


class RankingService:
    """Service for ranking and prioritizing compliance issues using Cohere Reranker."""

    def __init__(self):
        """Initialize the ranking service with the Cohere client."""
        self.cohere_client = cohere.Client(settings.COHERE_API_KEY)
        self.reranker_model = settings.RERANKER_MODEL

    async def prioritize_issues(
        self, report: ComplianceReport, max_issues: int = 10
    ) -> ComplianceReport:
        """
        Prioritize and rank compliance issues in a report.

        Issues are ordered by compliance level first
        (critical > major > minor > info) and by reranker relevance score
        second; issues that tie on both keep their original relative order.

        Args:
            report: The compliance report with issues to prioritize
            max_issues: Maximum number of issues to include in the final
                report. Negative values are treated as 0.

        Returns:
            A new compliance report holding the top ``max_issues`` issues.
            A report with zero or one issue is returned unchanged, and the
            original report is also returned if prioritization raises.
        """
        if not report.issues or len(report.issues) <= 1:
            # No need to rank if there's only 0 or 1 issues
            return report

        try:
            # Prepare issues for ranking
            issue_texts = [
                f"Section: {issue.section}. "
                f"Level: {issue.level.value}. "
                f"Description: {issue.description}. "
                f"Recommendation: {issue.recommendation}"
                for issue in report.issues
            ]

            # Query describing what the reranker should favour
            query = "critical compliance issues that require immediate attention"

            reranked_issues = await self._rerank_issues(query, issue_texts)

            level_scores = {
                ComplianceLevel.CRITICAL: 4,
                ComplianceLevel.MAJOR: 3,
                ComplianceLevel.MINOR: 2,
                ComplianceLevel.INFO: 1,
            }

            # Build the index -> score lookup once instead of scanning the
            # rerank results for every issue. setdefault keeps the first
            # score seen for an index.
            rerank_scores: Dict[int, float] = {}
            for item in reranked_issues:
                rerank_scores.setdefault(item["index"], item["relevance_score"])

            def sort_key(indexed_issue):
                position, issue = indexed_issue
                # Tuple key: level is always the primary factor regardless
                # of the magnitude of the relevance score.
                return (
                    level_scores.get(issue.level, 0),
                    rerank_scores.get(position, 0.0),
                )

            # sorted() is stable even with reverse=True, so ties keep their
            # original relative order.
            ranked = sorted(enumerate(report.issues), key=sort_key, reverse=True)

            # Clamp so a negative limit cannot slice from the end of the list
            limit = max(max_issues, 0)
            sorted_issues = [issue for _, issue in ranked[:limit]]

            # Create updated report
            return ComplianceReport(
                report_id=report.report_id,
                document_id=report.document_id,
                timestamp=report.timestamp,
                compliance_score=report.compliance_score,
                summary=report.summary,
                issues=sorted_issues,
            )
        except Exception as e:
            logger.error(f"Error prioritizing issues: {str(e)}")
            # If ranking fails, return the original report
            return report

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        reraise=True,
    )
    async def _call_reranker(self, query: str, issue_texts: List[str]) -> List[Dict[str, Any]]:
        """
        Call the Cohere Rerank endpoint once, letting errors propagate.

        Exceptions are deliberately not caught here so the retry decorator
        can see them; after the final attempt the last exception is
        re-raised to the caller.

        Args:
            query: The search query to compare issues against
            issue_texts: List of issue descriptions to rank

        Returns:
            List of dictionaries with "index" (position in ``issue_texts``)
            and "relevance_score" for each result
        """
        # cohere_client.rerank is a synchronous call; run it in the default
        # executor so it does not block the event loop.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            functools.partial(
                self.cohere_client.rerank,
                model=self.reranker_model,
                query=query,
                documents=issue_texts,
                top_n=len(issue_texts),
            ),
        )

        return [
            {
                "index": result.index,  # Original index in the issues list
                "relevance_score": result.relevance_score,
            }
            for result in response.results
        ]

    async def _rerank_issues(self, query: str, issue_texts: List[str]) -> List[Dict[str, Any]]:
        """
        Rerank issues using Cohere Reranker.

        The call is retried by ``_call_reranker``; if it still fails, a
        basic ranking that preserves the input order is returned instead
        of raising.

        Args:
            query: The search query to compare issues against
            issue_texts: List of issue descriptions to rank

        Returns:
            List of dictionaries with reranked issues and scores
        """
        if not issue_texts:
            # Nothing to rank; skip the API call entirely
            return []

        try:
            return await self._call_reranker(query, issue_texts)
        except Exception as e:
            logger.error(f"Error calling Cohere Reranker: {str(e)}")
            # Return basic ranking if reranking fails: strictly decreasing
            # scores that stay within (0, 1] for any number of issues.
            total = len(issue_texts)
            return [
                {"index": i, "relevance_score": 1.0 - (i / total)}
                for i in range(total)
            ]