Initial commit
This commit is contained in:
@@ -0,0 +1,136 @@
|
||||
# Reranking services
|
||||
import cohere
|
||||
from typing import List, Dict, Any
|
||||
from loguru import logger
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.models import ComplianceIssue, ComplianceReport, ComplianceLevel
|
||||
|
||||
class RankingService:
|
||||
"""Service for ranking and prioritizing compliance issues using Cohere Reranker."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the ranking service with the Cohere client."""
|
||||
self.cohere_client = cohere.Client(settings.COHERE_API_KEY)
|
||||
self.reranker_model = settings.RERANKER_MODEL
|
||||
|
||||
async def prioritize_issues(self, report: ComplianceReport, max_issues: int = 10) -> ComplianceReport:
|
||||
"""
|
||||
Prioritize and rank compliance issues in a report.
|
||||
|
||||
Args:
|
||||
report: The compliance report with issues to prioritize
|
||||
max_issues: Maximum number of issues to include in the final report
|
||||
|
||||
Returns:
|
||||
Updated compliance report with prioritized issues
|
||||
"""
|
||||
if not report.issues or len(report.issues) <= 1:
|
||||
# No need to rank if there's only 0 or 1 issues
|
||||
return report
|
||||
|
||||
try:
|
||||
# Prepare issues for ranking
|
||||
issue_texts = [
|
||||
f"Section: {issue.section}. "
|
||||
f"Level: {issue.level.value}. "
|
||||
f"Description: {issue.description}. "
|
||||
f"Recommendation: {issue.recommendation}"
|
||||
for issue in report.issues
|
||||
]
|
||||
|
||||
# Query object representing what we're looking for
|
||||
query = "critical compliance issues that require immediate attention"
|
||||
|
||||
# Rerank issues based on relevance to the query
|
||||
reranked_issues = await self._rerank_issues(query, issue_texts)
|
||||
|
||||
# Sort issues based on:
|
||||
# 1. Compliance level (critical > major > minor > info)
|
||||
# 2. Reranker relevance score
|
||||
sorted_issues = []
|
||||
level_scores = {
|
||||
ComplianceLevel.CRITICAL: 4,
|
||||
ComplianceLevel.MAJOR: 3,
|
||||
ComplianceLevel.MINOR: 2,
|
||||
ComplianceLevel.INFO: 1
|
||||
}
|
||||
|
||||
# Combine original issues with reranked scores
|
||||
combined_issues = []
|
||||
for i, issue in enumerate(report.issues):
|
||||
rerank_score = next((item["relevance_score"] for item in reranked_issues
|
||||
if item["index"] == i), 0.0)
|
||||
|
||||
# Calculate combined score (level_score * 100 + rerank_score)
|
||||
# This ensures level is always the primary sorting factor
|
||||
level_score = level_scores.get(issue.level, 0)
|
||||
combined_score = (level_score * 100) + rerank_score
|
||||
|
||||
combined_issues.append({
|
||||
"issue": issue,
|
||||
"combined_score": combined_score,
|
||||
"rerank_score": rerank_score
|
||||
})
|
||||
|
||||
# Sort by combined score (descending)
|
||||
combined_issues.sort(key=lambda x: x["combined_score"], reverse=True)
|
||||
|
||||
# Take top issues based on max_issues limit
|
||||
sorted_issues = [item["issue"] for item in combined_issues[:max_issues]]
|
||||
|
||||
# Create updated report
|
||||
prioritized_report = ComplianceReport(
|
||||
report_id=report.report_id,
|
||||
document_id=report.document_id,
|
||||
timestamp=report.timestamp,
|
||||
compliance_score=report.compliance_score,
|
||||
summary=report.summary,
|
||||
issues=sorted_issues
|
||||
)
|
||||
|
||||
return prioritized_report
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error prioritizing issues: {str(e)}")
|
||||
# If ranking fails, return the original report
|
||||
return report
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
|
||||
async def _rerank_issues(self, query: str, issue_texts: List[str]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Rerank issues using Cohere Reranker.
|
||||
|
||||
Args:
|
||||
query: The search query to compare issues against
|
||||
issue_texts: List of issue descriptions to rank
|
||||
|
||||
Returns:
|
||||
List of dictionaries with reranked issues and scores
|
||||
"""
|
||||
try:
|
||||
# Call Cohere Rerank endpoint
|
||||
response = self.cohere_client.rerank(
|
||||
model=self.reranker_model,
|
||||
query=query,
|
||||
documents=issue_texts,
|
||||
top_n=len(issue_texts)
|
||||
)
|
||||
|
||||
# Format results
|
||||
reranked_issues = []
|
||||
for result in response.results:
|
||||
reranked_issues.append({
|
||||
"index": result.index, # Original index in the issues list
|
||||
"relevance_score": result.relevance_score
|
||||
})
|
||||
|
||||
return reranked_issues
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling Cohere Reranker: {str(e)}")
|
||||
|
||||
# Return basic ranking if reranking fails
|
||||
return [{"index": i, "relevance_score": 1.0 - (i * 0.1)}
|
||||
for i in range(len(issue_texts))]
|
||||
Reference in New Issue
Block a user