136 lines
5.3 KiB
Python
136 lines
5.3 KiB
Python
# Reranking services
|
|
import cohere
|
|
from typing import List, Dict, Any
|
|
from loguru import logger
|
|
from tenacity import retry, stop_after_attempt, wait_exponential
|
|
|
|
from app.core.config import settings
|
|
from app.core.models import ComplianceIssue, ComplianceReport, ComplianceLevel
|
|
|
|
class RankingService:
    """Service for ranking and prioritizing compliance issues using Cohere Reranker.

    Issues are ordered primarily by their compliance level and, within the
    same level, by the relevance score the reranker assigns against a fixed
    "needs immediate attention" query.
    """

    def __init__(self):
        """Initialize the ranking service with the Cohere client."""
        self.cohere_client = cohere.Client(settings.COHERE_API_KEY)
        self.reranker_model = settings.RERANKER_MODEL

    async def prioritize_issues(self, report: ComplianceReport, max_issues: int = 10) -> ComplianceReport:
        """
        Prioritize and rank compliance issues in a report.

        Args:
            report: The compliance report with issues to prioritize
            max_issues: Maximum number of issues to include in the final
                report. Negative values are treated as 0.

        Returns:
            A new compliance report whose issues are sorted by priority and
            truncated to ``max_issues``. The original report is returned
            unchanged when it holds fewer than two issues or when ranking
            fails for any reason (best-effort behaviour).
        """
        if not report.issues or len(report.issues) <= 1:
            # No need to rank if there's only 0 or 1 issues
            return report

        try:
            # Prepare issues for ranking
            issue_texts = [
                f"Section: {issue.section}. "
                f"Level: {issue.level.value}. "
                f"Description: {issue.description}. "
                f"Recommendation: {issue.recommendation}"
                for issue in report.issues
            ]

            # Query describing what the reranker should surface first
            query = "critical compliance issues that require immediate attention"

            # Rerank issues based on relevance to the query
            reranked_issues = await self._rerank_issues(query, issue_texts)

            # Build an index -> score lookup once instead of scanning the
            # reranked list for every issue. setdefault keeps the first score
            # seen for an index, matching a first-match linear search.
            score_by_index: Dict[int, float] = {}
            for item in reranked_issues:
                score_by_index.setdefault(item["index"], item["relevance_score"])

            # Sort issues based on:
            # 1. Compliance level (critical > major > minor > info)
            # 2. Reranker relevance score
            level_scores = {
                ComplianceLevel.CRITICAL: 4,
                ComplianceLevel.MAJOR: 3,
                ComplianceLevel.MINOR: 2,
                ComplianceLevel.INFO: 1,
            }

            # combined score = level_score * 100 + rerank_score; the factor of
            # 100 keeps the level as the primary sorting factor because the
            # rerank score only contributes a fractional part.
            scored_issues = [
                (
                    (level_scores.get(issue.level, 0) * 100) + score_by_index.get(i, 0.0),
                    issue,
                )
                for i, issue in enumerate(report.issues)
            ]

            # Sort by combined score (descending); the sort is stable, so ties
            # keep their original relative order.
            scored_issues.sort(key=lambda pair: pair[0], reverse=True)

            # Take top issues based on max_issues limit. Clamp at zero so a
            # negative limit cannot turn into a "drop from the end" slice.
            limit = max(max_issues, 0)
            sorted_issues = [issue for _, issue in scored_issues[:limit]]

            # Create updated report
            return ComplianceReport(
                report_id=report.report_id,
                document_id=report.document_id,
                timestamp=report.timestamp,
                compliance_score=report.compliance_score,
                summary=report.summary,
                issues=sorted_issues,
            )

        except Exception as e:
            logger.error(f"Error prioritizing issues: {str(e)}")
            # If ranking fails, return the original report
            return report

    async def _rerank_issues(self, query: str, issue_texts: List[str]) -> List[Dict[str, Any]]:
        """
        Rerank issues using Cohere Reranker.

        The API call is retried (see ``_call_reranker``); only after all
        attempts fail does this method fall back to a positional ranking.

        Args:
            query: The search query to compare issues against
            issue_texts: List of issue descriptions to rank

        Returns:
            List of dictionaries with keys ``"index"`` (position in
            ``issue_texts``) and ``"relevance_score"``. Empty when
            ``issue_texts`` is empty.
        """
        if not issue_texts:
            # Nothing to rank; avoid a pointless API round trip.
            return []

        try:
            response = await self._call_reranker(query, issue_texts)

            # Format results
            return [
                {
                    "index": result.index,  # Original index in the issues list
                    "relevance_score": result.relevance_score,
                }
                for result in response.results
            ]

        except Exception as e:
            logger.error(f"Error calling Cohere Reranker: {str(e)}")

            # Return basic ranking if reranking fails: scores decrease with
            # position and stay within (0, 1], so the original order is kept
            # inside each compliance level.
            total = len(issue_texts)
            return [
                {"index": i, "relevance_score": 1.0 - (i / total)}
                for i in range(total)
            ]

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        reraise=True,
    )
    async def _call_reranker(self, query: str, issue_texts: List[str]) -> Any:
        """
        Call the Cohere Rerank endpoint, retrying on failure.

        Exceptions are deliberately NOT caught here so that the retry
        decorator can observe them; after the last attempt the original
        exception is re-raised to the caller.

        Args:
            query: The search query to compare issues against
            issue_texts: List of issue descriptions to rank

        Returns:
            The raw response object returned by the Cohere client.
        """
        import asyncio  # local import keeps this block self-contained

        # The client created in __init__ is synchronous; run it in the default
        # executor so the HTTP round trip does not block the event loop.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: self.cohere_client.rerank(
                model=self.reranker_model,
                query=query,
                documents=issue_texts,
                top_n=len(issue_texts),
            ),
        )