fixed querying

This commit is contained in:
bolade
2025-10-28 20:54:15 +01:00
parent ff0010019e
commit bb03f6ade4
7 changed files with 270 additions and 161 deletions
+131 -75
View File
@@ -1,29 +1,24 @@
import json
import hashlib
import logging
import os
from typing import List, Optional
from db.db import DATABASE_URL, get_db
from db.db import get_db
from db.models import FundTable, InvestorTable, ProjectTable
from langchain import hub
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from schemas.router_schemas import (
CompanyMinimal,
InvestmentResponse,
PaginatedResponse,
SectorMinimal,
)
from sqlalchemy import text
from sqlalchemy.orm import selectinload
from services.compatibility_score import calculate_project_investor_compatibility
logger = logging.getLogger(__name__)
# Connect to SQLite
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
db = SQLDatabase.from_uri(DATABASE_URL)
class QueryProcessor:
@@ -34,89 +29,150 @@ class QueryProcessor:
model="openai/gpt-4o-mini",
temperature=0,
)
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
# Update system message to specifically request only fund IDs
system_message_updated = (
prompt_template.format(dialect="SQLite", top_k=100)
+ "\n\n=== IMPORTANT TERMINOLOGY ==="
+ "\n- When users say 'investors' or 'find me investors', they mean FUNDS"
+ "\n- Always query the 'funds' table for investment opportunities"
+ "\n- The 'investors' table is for parent company information only"
+ "\n- Relationship: investors (1) -> (many) funds"
+ "\n\n=== YOUR TASK ==="
+ "\nReturn ONLY fund IDs (funds.id) that match the user's criteria."
+ "\nFormat: comma-separated numbers only (e.g., 1, 5, 12, 23)"
+ "\nNo explanations, no other data."
+ "\n\n=== QUERY GUIDELINES ==="
+ "\n1. For geographic searches: use funds.geographic_focus"
+ "\n2. For sector searches: JOIN with fund_sectors table"
+ "\n3. For stage searches: JOIN with fund_investment_stages table"
+ "\n4. Return ALL matching fund IDs, not just the first few"
+ "\n5. If no results: respond with 'NO_RESULTS'"
+ "\n6. Never repeat the same failed query"
+ "\n\n=== GEOGRAPHIC SEARCH RULES (VERY IMPORTANT) ==="
+ "\n- ALWAYS use LIKE '%keyword%' for geographic searches, NEVER use exact equality (=)"
+ "\n- When user says 'Europe', match ANY location containing 'Europe' (e.g., 'Northern Europe', 'Western Europe', 'Europe', 'Central Europe')"
+ "\n- When user says 'America', match locations like 'North America', 'South America', 'Latin America', 'United States'"
+ "\n- When user says 'Asia', match 'Asia', 'Southeast Asia', 'East Asia', etc."
+ "\n- Examples:"
+ "\n * User: 'Europe' → SQL: WHERE geographic_focus LIKE '%Europe%'"
+ "\n * User: 'America' → SQL: WHERE geographic_focus LIKE '%America%'"
+ "\n * User: 'UK' → SQL: WHERE geographic_focus LIKE '%UK%' OR geographic_focus LIKE '%United Kingdom%'"
+ "\n- Be INCLUSIVE: capture all relevant regional variations"
)
self.agent = create_react_agent(
model=self.llm,
tools=self.toolkit.get_tools(),
prompt=system_message_updated,
# Query cache for performance
self.query_cache = {}
# SQL generation prompt
self.sql_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""You are a SQL expert. Generate a SQLite query to find fund IDs based on user requirements.
Database Schema:
- funds: id, fund_name, investor_id, check_size_lower, check_size_upper, geographic_focus
- fund_sectors: fund_id, sector_id
- fund_investment_stages: fund_id, stage_id
- sectors: id, name
- investment_stages: id, name
- investors: id, name, aum
IMPORTANT RULES:
1. ALWAYS return ONLY fund IDs (funds.id) - use SELECT DISTINCT f.id
2. For geography: Be FLEXIBLE - use OR with variations and partial matches
- 'Europe' → WHERE geographic_focus LIKE '%Europe%' OR geographic_focus LIKE '%European%'
- 'America' → WHERE geographic_focus LIKE '%America%' OR geographic_focus LIKE '%US%' OR geographic_focus LIKE '%United States%'
- 'Asia' → WHERE geographic_focus LIKE '%Asia%' OR geographic_focus LIKE '%Asian%'
- If no geography specified, DON'T filter by geography
3. For stages: Use LEFT JOIN and LIKE for flexible matching with synonyms
- 'Seed' → s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%'
- 'Series A' → s.name LIKE '%Series A%' OR s.name LIKE '%A%'
- 'Growth' → s.name LIKE '%Growth%' OR s.name LIKE '%Late%' OR s.name LIKE '%Expansion%'
- If stage not specified, include ALL funds
4. For sectors: Use LEFT JOIN and include related terms with OR
- 'Fintech' → sec.name LIKE '%Fintech%' OR sec.name LIKE '%Finance%' OR sec.name LIKE '%Financial%'
- 'AI' → sec.name LIKE '%AI%' OR sec.name LIKE '%Artificial Intelligence%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%'
- 'Healthcare' → sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%'
5. For check size filters (be flexible with ranges):
- "under X" → WHERE (check_size_upper <= X OR check_size_upper IS NULL)
- "over X" → WHERE (check_size_lower >= X OR check_size_lower IS NULL)
- "between X and Y" → WHERE check_size_lower >= X AND check_size_upper <= Y
6. Use LEFT JOIN for stages and sectors so funds without tags still match
7. Use DISTINCT to avoid duplicates from joins
8. Be INCLUSIVE - use OR conditions to cast a wider net
9. If query is very simple (e.g., just "seed stage"), don't add unnecessary filters
10. Return a single, complete SELECT query
Example Queries:
Q: "Seed stage investors in Europe"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id
LEFT JOIN investment_stages s ON fis.stage_id = s.id
WHERE (s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%' OR s.id IS NULL)
AND (f.geographic_focus LIKE '%Europe%' OR f.geographic_focus LIKE '%European%')
Q: "Fintech investors with check size under 5 million"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_sectors fs ON f.id = fs.fund_id
LEFT JOIN sectors sec ON fs.sector_id = sec.id
WHERE (sec.name LIKE '%Fintech%' OR sec.name LIKE '%Finance%' OR sec.name LIKE '%Financial%' OR sec.id IS NULL)
AND (f.check_size_upper <= 5000000 OR f.check_size_upper IS NULL)
Q: "Seed stage investors"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id
LEFT JOIN investment_stages s ON fis.stage_id = s.id
WHERE s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%'
Q: "Growth stage investors"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id
LEFT JOIN investment_stages s ON fis.stage_id = s.id
WHERE s.name LIKE '%Growth%' OR s.name LIKE '%Late%' OR s.name LIKE '%Expansion%' OR s.name LIKE '%Series C%' OR s.name LIKE '%Series D%'
Q: "AI investors in America"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_sectors fs ON f.id = fs.fund_id
LEFT JOIN sectors sec ON fs.sector_id = sec.id
WHERE (sec.name LIKE '%AI%' OR sec.name LIKE '%Artificial Intelligence%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%')
AND (f.geographic_focus LIKE '%America%' OR f.geographic_focus LIKE '%US%' OR f.geographic_focus LIKE '%United States%' OR f.geographic_focus LIKE '%USA%')
Q: "Healthcare investors"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_sectors fs ON f.id = fs.fund_id
LEFT JOIN sectors sec ON fs.sector_id = sec.id
WHERE sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR sec.name LIKE '%Pharma%'
IMPORTANT: Use LEFT JOIN so funds without sector/stage tags can still match. Include synonym terms with OR for better recall.
Return ONLY the SQL query, no explanations or markdown.""",
),
("user", "{question}"),
]
)
def _get_cache_key(self, question: str) -> str:
"""Generate cache key from normalized question."""
return hashlib.md5(question.lower().strip().encode()).hexdigest()
def process_query(
self, question: str, project_id: Optional[int] = None
) -> PaginatedResponse[InvestmentResponse]:
"""Process a query using the LLM and return investment response data.
"""Process a query by generating and executing SQL directly.
Args:
question: The natural language query to process
project_id: Optional project ID for compatibility scoring
"""
# Let the LLM handle all database interactions and filtering to get fund IDs
response = self.agent.invoke(
{"messages": [("user", question)]},
config={"recursion_limit": 50},
)
cache_key = self._get_cache_key(question)
# Extract the actual message content
logger.info(f"{response}")
final_message_content = response["messages"][-1].content
logger.info(f"AI Response: \n{final_message_content}")
# Extract fund IDs from the AI response
fund_ids = self._extract_fund_ids_from_response(final_message_content)
# Check cache first
if cache_key in self.query_cache:
sql_query = self.query_cache[cache_key]
logger.info(f"Using cached SQL: {sql_query}")
else:
# Generate SQL query
messages = self.sql_prompt.format_messages(question=question)
response = self.llm.invoke(messages)
sql_query = response.content.strip()
# Fetch full fund data with investor relationships using the IDs
return self._fetch_funds_by_ids(fund_ids, project_id)
# Clean up SQL (remove markdown code blocks if present)
sql_query = sql_query.replace("```sql", "").replace("```", "").strip()
def _extract_fund_ids_from_response(self, ai_response: str) -> List[int]:
"""Extract fund IDs from AI response."""
import re
# Cache the query
self.query_cache[cache_key] = sql_query
logger.info(f"Generated SQL: {sql_query}")
fund_ids = []
# Execute query to get fund IDs
db_session = next(get_db())
try:
# Try multiple patterns to extract IDs from the response
# Pattern 1: Simple numbers (assuming they are IDs)
numbers = re.findall(r"\b\d+\b", ai_response)
fund_ids = [int(num) for num in numbers]
# Pattern 2: If response contains explicit ID references
id_matches = re.findall(r"\bid[:\s]*(\d+)", ai_response.lower())
if id_matches:
fund_ids = [int(id_str) for id_str in id_matches]
result = db_session.execute(text(sql_query))
fund_ids = [row[0] for row in result.fetchall()]
logger.info(
f"Found {len(fund_ids)} fund IDs: {fund_ids[:10]}{'...' if len(fund_ids) > 10 else ''}"
)
return self._fetch_funds_by_ids(fund_ids, project_id)
except Exception as e:
print(f"Error extracting IDs from response: {e}")
return []
return fund_ids
logger.error(f"SQL execution error: {e}")
logger.error(f"Failed SQL: {sql_query}")
# Return empty result
return PaginatedResponse(
items=[], total=0, page=1, page_size=10, total_pages=0
)
finally:
db_session.close()
def _fetch_funds_by_ids(
self, fund_ids: List[int], project_id: Optional[int] = None