fixed querying
This commit is contained in:
Binary file not shown.
Binary file not shown.
+10
-1
@@ -44,6 +44,15 @@ class QueryRequest(BaseModel):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class CompanyQueryRequest(BaseModel):
|
||||||
|
question: str
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"question": "Find me companies in the fintech sector located in San Francisco."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
def health():
|
def health():
|
||||||
@@ -120,7 +129,7 @@ async def query_investors(request: QueryRequest):
|
|||||||
@app.post(
|
@app.post(
|
||||||
"/query-companies", response_model=PaginatedResponse[CompanyData], tags=["Querying"]
|
"/query-companies", response_model=PaginatedResponse[CompanyData], tags=["Querying"]
|
||||||
)
|
)
|
||||||
async def query_companies(request: QueryRequest):
|
async def query_companies(request: CompanyQueryRequest):
|
||||||
"""
|
"""
|
||||||
Query companies using natural language.
|
Query companies using natural language.
|
||||||
|
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
@@ -1,21 +1,17 @@
|
|||||||
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from db.db import DATABASE_URL, get_db
|
from db.db import get_db
|
||||||
from db.models import CompanyTable
|
from db.models import CompanyTable
|
||||||
from langchain import hub
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
|
||||||
from langchain_community.utilities import SQLDatabase
|
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
from langgraph.prebuilt import create_react_agent
|
|
||||||
from schemas.router_schemas import CompanyData, PaginatedResponse
|
from schemas.router_schemas import CompanyData, PaginatedResponse
|
||||||
|
from sqlalchemy import text
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
# Connect to SQLite
|
|
||||||
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
|
|
||||||
db = SQLDatabase.from_uri(DATABASE_URL)
|
|
||||||
|
|
||||||
|
|
||||||
class CompanyQueryProcessor:
|
class CompanyQueryProcessor:
|
||||||
@@ -26,96 +22,144 @@ class CompanyQueryProcessor:
|
|||||||
model="openai/gpt-4o-mini",
|
model="openai/gpt-4o-mini",
|
||||||
temperature=0,
|
temperature=0,
|
||||||
)
|
)
|
||||||
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
|
|
||||||
# Update system message to specifically request only company IDs
|
# Query cache for performance
|
||||||
system_message_updated = (
|
self.query_cache = {}
|
||||||
prompt_template.format(dialect="SQLite", top_k=5)
|
|
||||||
+ "\n\n=== CRITICAL INSTRUCTIONS ==="
|
# SQL generation prompt
|
||||||
+ "\n- Your ONLY task is to run SQL queries and extract company IDs"
|
self.sql_prompt = ChatPromptTemplate.from_messages(
|
||||||
+ "\n- When you get SQL results with company IDs, return them EXACTLY as shown"
|
[
|
||||||
+ "\n- If the SQL query returns rows with company IDs like [(1,), (5,), (9,)], return all those IDs"
|
(
|
||||||
+ "\n- Do NOT add any explanations, just list the IDs"
|
"system",
|
||||||
+ "\n- If a query returns NO ROWS (empty result), then respond with 'NO_RESULTS'"
|
"""You are a SQL expert. Generate a SQLite query to find company IDs based on user requirements.
|
||||||
+ "\n\n=== QUERY GUIDELINES ==="
|
|
||||||
+ "\n1. For sector searches: SELECT companies.id FROM companies JOIN company_sector ON companies.id = company_sector.company_id JOIN sectors ON company_sector.sector_id = sectors.id WHERE sectors.name LIKE '%sector_name%'"
|
Database Schema:
|
||||||
+ "\n2. For industry searches: WHERE companies.industry LIKE '%search_term%'"
|
- companies: id, name, industry, location, description, founded_year, website
|
||||||
+ "\n3. For location searches: WHERE companies.location LIKE '%location%'"
|
- company_sector: company_id, sector_id
|
||||||
+ "\n4. For founding year searches: WHERE companies.founded_year >= year"
|
- sectors: id, name
|
||||||
+ "\n5. For investor-related: JOIN investor_companies table"
|
- investor_companies: investor_id, company_id
|
||||||
)
|
- investors: id, name, aum
|
||||||
self.agent = create_react_agent(
|
- team_members: id, company_id, name, title
|
||||||
model=self.llm,
|
|
||||||
tools=self.toolkit.get_tools(),
|
IMPORTANT RULES:
|
||||||
prompt=system_message_updated,
|
1. ALWAYS return ONLY company IDs (companies.id) - use SELECT DISTINCT c.id
|
||||||
|
2. For industry: Check BOTH industry field AND sectors table with synonyms
|
||||||
|
- Use LEFT JOIN for sectors so companies without sector tags still match
|
||||||
|
- Include related terms: 'Fintech' → c.industry LIKE '%Fintech%' OR c.industry LIKE '%Finance%' OR sec.name LIKE '%Fintech%' OR sec.name LIKE '%Financial%'
|
||||||
|
- 'AI' → c.industry LIKE '%AI%' OR c.industry LIKE '%Artificial Intelligence%' OR c.industry LIKE '%Machine Learning%' OR sec.name LIKE '%AI%' OR sec.name LIKE '%ML%'
|
||||||
|
3. For location: Be FLEXIBLE with variations and abbreviations
|
||||||
|
- 'San Francisco' → c.location LIKE '%San Francisco%' OR c.location LIKE '%SF%' OR c.location LIKE '%Bay Area%'
|
||||||
|
- 'New York' → c.location LIKE '%New York%' OR c.location LIKE '%NYC%' OR c.location LIKE '%NY%'
|
||||||
|
- 'Europe' → c.location LIKE '%Europe%' OR c.location LIKE '%UK%' OR c.location LIKE '%London%' OR c.location LIKE '%Berlin%' OR c.location LIKE '%Paris%'
|
||||||
|
4. For sectors: Use LEFT JOIN and include multiple synonyms
|
||||||
|
- 'Healthcare' → sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR c.industry LIKE '%Health%'
|
||||||
|
5. For founding year filters (include NULL to be inclusive):
|
||||||
|
- "founded after 2020" → WHERE (founded_year >= 2020 OR founded_year IS NULL)
|
||||||
|
- "founded before 2018" → WHERE (founded_year <= 2018 OR founded_year IS NULL)
|
||||||
|
- "founded in 2020" → WHERE founded_year = 2020
|
||||||
|
6. For investor-related queries: Use JOIN investor_companies
|
||||||
|
7. Use LEFT JOIN for sectors so companies without tags still match
|
||||||
|
8. Use DISTINCT to avoid duplicates from joins
|
||||||
|
9. Be INCLUSIVE - use OR conditions with synonyms and variations
|
||||||
|
10. Return a single, complete SELECT query
|
||||||
|
|
||||||
|
Example Queries:
|
||||||
|
Q: "Fintech companies founded in 2020"
|
||||||
|
A: SELECT DISTINCT c.id FROM companies c
|
||||||
|
LEFT JOIN company_sector cs ON c.id = cs.company_id
|
||||||
|
LEFT JOIN sectors sec ON cs.sector_id = sec.id
|
||||||
|
WHERE (c.industry LIKE '%Fintech%' OR c.industry LIKE '%Finance%' OR c.industry LIKE '%Financial%' OR sec.name LIKE '%Fintech%' OR sec.name LIKE '%Financial Services%')
|
||||||
|
AND c.founded_year = 2020
|
||||||
|
|
||||||
|
Q: "AI companies in San Francisco"
|
||||||
|
A: SELECT DISTINCT c.id FROM companies c
|
||||||
|
LEFT JOIN company_sector cs ON c.id = cs.company_id
|
||||||
|
LEFT JOIN sectors sec ON cs.sector_id = sec.id
|
||||||
|
WHERE (c.industry LIKE '%AI%' OR c.industry LIKE '%Artificial Intelligence%' OR c.industry LIKE '%Machine Learning%' OR sec.name LIKE '%AI%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%')
|
||||||
|
AND (c.location LIKE '%San Francisco%' OR c.location LIKE '%SF%' OR c.location LIKE '%Bay Area%')
|
||||||
|
|
||||||
|
Q: "Healthcare companies"
|
||||||
|
A: SELECT DISTINCT c.id FROM companies c
|
||||||
|
LEFT JOIN company_sector cs ON c.id = cs.company_id
|
||||||
|
LEFT JOIN sectors sec ON cs.sector_id = sec.id
|
||||||
|
WHERE c.industry LIKE '%Healthcare%' OR c.industry LIKE '%Health%' OR c.industry LIKE '%Medical%' OR sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR sec.name LIKE '%Pharma%'
|
||||||
|
|
||||||
|
Q: "Companies funded by Sequoia"
|
||||||
|
A: SELECT DISTINCT c.id FROM companies c
|
||||||
|
JOIN investor_companies ic ON c.id = ic.company_id
|
||||||
|
JOIN investors i ON ic.investor_id = i.id
|
||||||
|
WHERE i.name LIKE '%Sequoia%'
|
||||||
|
|
||||||
|
Q: "European startups founded after 2019"
|
||||||
|
A: SELECT DISTINCT c.id FROM companies c
|
||||||
|
WHERE (c.location LIKE '%Europe%' OR c.location LIKE '%UK%' OR c.location LIKE '%London%' OR c.location LIKE '%Germany%' OR c.location LIKE '%Berlin%' OR c.location LIKE '%France%' OR c.location LIKE '%Paris%')
|
||||||
|
AND (c.founded_year > 2019 OR c.founded_year IS NULL)
|
||||||
|
|
||||||
|
Q: "SaaS companies"
|
||||||
|
A: SELECT DISTINCT c.id FROM companies c
|
||||||
|
LEFT JOIN company_sector cs ON c.id = cs.company_id
|
||||||
|
LEFT JOIN sectors sec ON cs.sector_id = sec.id
|
||||||
|
WHERE c.industry LIKE '%SaaS%' OR c.industry LIKE '%Software%' OR c.industry LIKE '%Cloud%' OR sec.name LIKE '%SaaS%' OR sec.name LIKE '%Software%'
|
||||||
|
|
||||||
|
IMPORTANT:
|
||||||
|
- Use LEFT JOIN so companies without sector tags still match via industry field
|
||||||
|
- Use OR conditions with related keywords/synonyms to cast a wider net
|
||||||
|
- Include NULL checks for optional filters to avoid excluding companies with missing data
|
||||||
|
|
||||||
|
Return ONLY the SQL query, no explanations or markdown.""",
|
||||||
|
),
|
||||||
|
("user", "{question}"),
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_cache_key(self, question: str) -> str:
|
||||||
|
"""Generate cache key from normalized question."""
|
||||||
|
return hashlib.md5(question.lower().strip().encode()).hexdigest()
|
||||||
|
|
||||||
def process_query(self, question: str) -> PaginatedResponse[CompanyData]:
|
def process_query(self, question: str) -> PaginatedResponse[CompanyData]:
|
||||||
"""Process a query using the LLM and return company response data.
|
"""Process a query by generating and executing SQL directly.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
question: The natural language query to process
|
question: The natural language query to process
|
||||||
"""
|
"""
|
||||||
# Let the LLM handle all database interactions and filtering to get company IDs
|
cache_key = self._get_cache_key(question)
|
||||||
response = self.agent.invoke(
|
|
||||||
{"messages": [("user", question)]},
|
|
||||||
config={"recursion_limit": 50},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract the actual message content
|
# Check cache first
|
||||||
logger.info(f"{response}")
|
if cache_key in self.query_cache:
|
||||||
|
sql_query = self.query_cache[cache_key]
|
||||||
|
logger.info(f"Using cached SQL: {sql_query}")
|
||||||
|
else:
|
||||||
|
# Generate SQL query
|
||||||
|
messages = self.sql_prompt.format_messages(question=question)
|
||||||
|
response = self.llm.invoke(messages)
|
||||||
|
sql_query = response.content.strip()
|
||||||
|
|
||||||
# Look through all messages to find the SQL query results (ToolMessage with actual data)
|
# Clean up SQL (remove markdown code blocks if present)
|
||||||
company_ids = []
|
sql_query = sql_query.replace("```sql", "").replace("```", "").strip()
|
||||||
for message in response["messages"]:
|
|
||||||
if hasattr(message, "content") and message.content:
|
|
||||||
# Check if this looks like SQL results (contains tuples with numbers)
|
|
||||||
if "(" in str(message.content) and "," in str(message.content):
|
|
||||||
company_ids = self._extract_company_ids_from_response(
|
|
||||||
str(message.content)
|
|
||||||
)
|
|
||||||
if company_ids:
|
|
||||||
logger.info(
|
|
||||||
f"Extracted {len(company_ids)} company IDs from results"
|
|
||||||
)
|
|
||||||
break
|
|
||||||
|
|
||||||
# If no IDs found from ToolMessage, check the final AI message
|
# Cache the query
|
||||||
if not company_ids:
|
self.query_cache[cache_key] = sql_query
|
||||||
final_message_content = response["messages"][-1].content
|
logger.info(f"Generated SQL: {sql_query}")
|
||||||
logger.info(f"AI Response: \n{final_message_content}")
|
|
||||||
company_ids = self._extract_company_ids_from_response(final_message_content)
|
|
||||||
|
|
||||||
# Fetch full company data with relationships using the IDs
|
|
||||||
return self._fetch_companies_by_ids(company_ids)
|
|
||||||
|
|
||||||
def _extract_company_ids_from_response(self, ai_response: str) -> List[int]:
|
|
||||||
"""Extract company IDs from AI response."""
|
|
||||||
import re
|
|
||||||
|
|
||||||
company_ids = []
|
|
||||||
|
|
||||||
# Check if response is NO_RESULTS
|
|
||||||
if "NO_RESULTS" in ai_response.upper():
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
# Execute query to get company IDs
|
||||||
|
db_session = next(get_db())
|
||||||
try:
|
try:
|
||||||
# The response contains tuples like (1,), (5,), etc.
|
result = db_session.execute(text(sql_query))
|
||||||
# Extract numbers between parentheses
|
company_ids = [row[0] for row in result.fetchall()]
|
||||||
pattern = r"\((\d+),?\)"
|
logger.info(
|
||||||
matches = re.findall(pattern, ai_response)
|
f"Found {len(company_ids)} company IDs: {company_ids[:10]}{'...' if len(company_ids) > 10 else ''}"
|
||||||
if matches:
|
)
|
||||||
company_ids = [int(match) for match in matches]
|
|
||||||
else:
|
|
||||||
# Fallback: extract all numbers
|
|
||||||
numbers = re.findall(r"\b\d+\b", ai_response)
|
|
||||||
# Filter out very large numbers that might be tokens or timestamps
|
|
||||||
company_ids = [int(num) for num in numbers if int(num) < 100000]
|
|
||||||
|
|
||||||
|
return self._fetch_companies_by_ids(company_ids)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error extracting IDs from response: {e}")
|
logger.error(f"SQL execution error: {e}")
|
||||||
return []
|
logger.error(f"Failed SQL: {sql_query}")
|
||||||
|
# Return empty result
|
||||||
return company_ids
|
return PaginatedResponse(
|
||||||
|
items=[], total=0, page=1, page_size=10, total_pages=0
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
db_session.close()
|
||||||
|
|
||||||
def _fetch_companies_by_ids(
|
def _fetch_companies_by_ids(
|
||||||
self, company_ids: List[int]
|
self, company_ids: List[int]
|
||||||
@@ -130,7 +174,7 @@ class CompanyQueryProcessor:
|
|||||||
items=[],
|
items=[],
|
||||||
total=0,
|
total=0,
|
||||||
page=1,
|
page=1,
|
||||||
page_size=len(company_ids) if company_ids else 10,
|
page_size=10,
|
||||||
total_pages=0,
|
total_pages=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+131
-75
@@ -1,29 +1,24 @@
|
|||||||
import json
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from db.db import DATABASE_URL, get_db
|
from db.db import get_db
|
||||||
from db.models import FundTable, InvestorTable, ProjectTable
|
from db.models import FundTable, InvestorTable, ProjectTable
|
||||||
from langchain import hub
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
|
||||||
from langchain_community.utilities import SQLDatabase
|
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
from langgraph.prebuilt import create_react_agent
|
|
||||||
from schemas.router_schemas import (
|
from schemas.router_schemas import (
|
||||||
CompanyMinimal,
|
CompanyMinimal,
|
||||||
InvestmentResponse,
|
InvestmentResponse,
|
||||||
PaginatedResponse,
|
PaginatedResponse,
|
||||||
SectorMinimal,
|
SectorMinimal,
|
||||||
)
|
)
|
||||||
|
from sqlalchemy import text
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from services.compatibility_score import calculate_project_investor_compatibility
|
from services.compatibility_score import calculate_project_investor_compatibility
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
# Connect to SQLite
|
|
||||||
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
|
|
||||||
db = SQLDatabase.from_uri(DATABASE_URL)
|
|
||||||
|
|
||||||
|
|
||||||
class QueryProcessor:
|
class QueryProcessor:
|
||||||
@@ -34,89 +29,150 @@ class QueryProcessor:
|
|||||||
model="openai/gpt-4o-mini",
|
model="openai/gpt-4o-mini",
|
||||||
temperature=0,
|
temperature=0,
|
||||||
)
|
)
|
||||||
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
|
|
||||||
# Update system message to specifically request only fund IDs
|
# Query cache for performance
|
||||||
system_message_updated = (
|
self.query_cache = {}
|
||||||
prompt_template.format(dialect="SQLite", top_k=100)
|
|
||||||
+ "\n\n=== IMPORTANT TERMINOLOGY ==="
|
# SQL generation prompt
|
||||||
+ "\n- When users say 'investors' or 'find me investors', they mean FUNDS"
|
self.sql_prompt = ChatPromptTemplate.from_messages(
|
||||||
+ "\n- Always query the 'funds' table for investment opportunities"
|
[
|
||||||
+ "\n- The 'investors' table is for parent company information only"
|
(
|
||||||
+ "\n- Relationship: investors (1) -> (many) funds"
|
"system",
|
||||||
+ "\n\n=== YOUR TASK ==="
|
"""You are a SQL expert. Generate a SQLite query to find fund IDs based on user requirements.
|
||||||
+ "\nReturn ONLY fund IDs (funds.id) that match the user's criteria."
|
|
||||||
+ "\nFormat: comma-separated numbers only (e.g., 1, 5, 12, 23)"
|
Database Schema:
|
||||||
+ "\nNo explanations, no other data."
|
- funds: id, fund_name, investor_id, check_size_lower, check_size_upper, geographic_focus
|
||||||
+ "\n\n=== QUERY GUIDELINES ==="
|
- fund_sectors: fund_id, sector_id
|
||||||
+ "\n1. For geographic searches: use funds.geographic_focus"
|
- fund_investment_stages: fund_id, stage_id
|
||||||
+ "\n2. For sector searches: JOIN with fund_sectors table"
|
- sectors: id, name
|
||||||
+ "\n3. For stage searches: JOIN with fund_investment_stages table"
|
- investment_stages: id, name
|
||||||
+ "\n4. Return ALL matching fund IDs, not just the first few"
|
- investors: id, name, aum
|
||||||
+ "\n5. If no results: respond with 'NO_RESULTS'"
|
|
||||||
+ "\n6. Never repeat the same failed query"
|
IMPORTANT RULES:
|
||||||
+ "\n\n=== GEOGRAPHIC SEARCH RULES (VERY IMPORTANT) ==="
|
1. ALWAYS return ONLY fund IDs (funds.id) - use SELECT DISTINCT f.id
|
||||||
+ "\n- ALWAYS use LIKE '%keyword%' for geographic searches, NEVER use exact equality (=)"
|
2. For geography: Be FLEXIBLE - use OR with variations and partial matches
|
||||||
+ "\n- When user says 'Europe', match ANY location containing 'Europe' (e.g., 'Northern Europe', 'Western Europe', 'Europe', 'Central Europe')"
|
- 'Europe' → WHERE geographic_focus LIKE '%Europe%' OR geographic_focus LIKE '%European%'
|
||||||
+ "\n- When user says 'America', match locations like 'North America', 'South America', 'Latin America', 'United States'"
|
- 'America' → WHERE geographic_focus LIKE '%America%' OR geographic_focus LIKE '%US%' OR geographic_focus LIKE '%United States%'
|
||||||
+ "\n- When user says 'Asia', match 'Asia', 'Southeast Asia', 'East Asia', etc."
|
- 'Asia' → WHERE geographic_focus LIKE '%Asia%' OR geographic_focus LIKE '%Asian%'
|
||||||
+ "\n- Examples:"
|
- If no geography specified, DON'T filter by geography
|
||||||
+ "\n * User: 'Europe' → SQL: WHERE geographic_focus LIKE '%Europe%'"
|
3. For stages: Use LEFT JOIN and LIKE for flexible matching with synonyms
|
||||||
+ "\n * User: 'America' → SQL: WHERE geographic_focus LIKE '%America%'"
|
- 'Seed' → s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%'
|
||||||
+ "\n * User: 'UK' → SQL: WHERE geographic_focus LIKE '%UK%' OR geographic_focus LIKE '%United Kingdom%'"
|
- 'Series A' → s.name LIKE '%Series A%' OR s.name LIKE '%A%'
|
||||||
+ "\n- Be INCLUSIVE: capture all relevant regional variations"
|
- 'Growth' → s.name LIKE '%Growth%' OR s.name LIKE '%Late%' OR s.name LIKE '%Expansion%'
|
||||||
)
|
- If stage not specified, include ALL funds
|
||||||
self.agent = create_react_agent(
|
4. For sectors: Use LEFT JOIN and include related terms with OR
|
||||||
model=self.llm,
|
- 'Fintech' → sec.name LIKE '%Fintech%' OR sec.name LIKE '%Finance%' OR sec.name LIKE '%Financial%'
|
||||||
tools=self.toolkit.get_tools(),
|
- 'AI' → sec.name LIKE '%AI%' OR sec.name LIKE '%Artificial Intelligence%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%'
|
||||||
prompt=system_message_updated,
|
- 'Healthcare' → sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%'
|
||||||
|
5. For check size filters (be flexible with ranges):
|
||||||
|
- "under X" → WHERE (check_size_upper <= X OR check_size_upper IS NULL)
|
||||||
|
- "over X" → WHERE (check_size_lower >= X OR check_size_lower IS NULL)
|
||||||
|
- "between X and Y" → WHERE check_size_lower >= X AND check_size_upper <= Y
|
||||||
|
6. Use LEFT JOIN for stages and sectors so funds without tags still match
|
||||||
|
7. Use DISTINCT to avoid duplicates from joins
|
||||||
|
8. Be INCLUSIVE - use OR conditions to cast a wider net
|
||||||
|
9. If query is very simple (e.g., just "seed stage"), don't add unnecessary filters
|
||||||
|
10. Return a single, complete SELECT query
|
||||||
|
|
||||||
|
Example Queries:
|
||||||
|
Q: "Seed stage investors in Europe"
|
||||||
|
A: SELECT DISTINCT f.id FROM funds f
|
||||||
|
LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id
|
||||||
|
LEFT JOIN investment_stages s ON fis.stage_id = s.id
|
||||||
|
WHERE (s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%' OR s.id IS NULL)
|
||||||
|
AND (f.geographic_focus LIKE '%Europe%' OR f.geographic_focus LIKE '%European%')
|
||||||
|
|
||||||
|
Q: "Fintech investors with check size under 5 million"
|
||||||
|
A: SELECT DISTINCT f.id FROM funds f
|
||||||
|
LEFT JOIN fund_sectors fs ON f.id = fs.fund_id
|
||||||
|
LEFT JOIN sectors sec ON fs.sector_id = sec.id
|
||||||
|
WHERE (sec.name LIKE '%Fintech%' OR sec.name LIKE '%Finance%' OR sec.name LIKE '%Financial%' OR sec.id IS NULL)
|
||||||
|
AND (f.check_size_upper <= 5000000 OR f.check_size_upper IS NULL)
|
||||||
|
|
||||||
|
Q: "Seed stage investors"
|
||||||
|
A: SELECT DISTINCT f.id FROM funds f
|
||||||
|
LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id
|
||||||
|
LEFT JOIN investment_stages s ON fis.stage_id = s.id
|
||||||
|
WHERE s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%'
|
||||||
|
|
||||||
|
Q: "Growth stage investors"
|
||||||
|
A: SELECT DISTINCT f.id FROM funds f
|
||||||
|
LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id
|
||||||
|
LEFT JOIN investment_stages s ON fis.stage_id = s.id
|
||||||
|
WHERE s.name LIKE '%Growth%' OR s.name LIKE '%Late%' OR s.name LIKE '%Expansion%' OR s.name LIKE '%Series C%' OR s.name LIKE '%Series D%'
|
||||||
|
|
||||||
|
Q: "AI investors in America"
|
||||||
|
A: SELECT DISTINCT f.id FROM funds f
|
||||||
|
LEFT JOIN fund_sectors fs ON f.id = fs.fund_id
|
||||||
|
LEFT JOIN sectors sec ON fs.sector_id = sec.id
|
||||||
|
WHERE (sec.name LIKE '%AI%' OR sec.name LIKE '%Artificial Intelligence%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%')
|
||||||
|
AND (f.geographic_focus LIKE '%America%' OR f.geographic_focus LIKE '%US%' OR f.geographic_focus LIKE '%United States%' OR f.geographic_focus LIKE '%USA%')
|
||||||
|
|
||||||
|
Q: "Healthcare investors"
|
||||||
|
A: SELECT DISTINCT f.id FROM funds f
|
||||||
|
LEFT JOIN fund_sectors fs ON f.id = fs.fund_id
|
||||||
|
LEFT JOIN sectors sec ON fs.sector_id = sec.id
|
||||||
|
WHERE sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR sec.name LIKE '%Pharma%'
|
||||||
|
|
||||||
|
IMPORTANT: Use LEFT JOIN so funds without sector/stage tags can still match. Include synonym terms with OR for better recall.
|
||||||
|
|
||||||
|
Return ONLY the SQL query, no explanations or markdown.""",
|
||||||
|
),
|
||||||
|
("user", "{question}"),
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_cache_key(self, question: str) -> str:
|
||||||
|
"""Generate cache key from normalized question."""
|
||||||
|
return hashlib.md5(question.lower().strip().encode()).hexdigest()
|
||||||
|
|
||||||
def process_query(
|
def process_query(
|
||||||
self, question: str, project_id: Optional[int] = None
|
self, question: str, project_id: Optional[int] = None
|
||||||
) -> PaginatedResponse[InvestmentResponse]:
|
) -> PaginatedResponse[InvestmentResponse]:
|
||||||
"""Process a query using the LLM and return investment response data.
|
"""Process a query by generating and executing SQL directly.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
question: The natural language query to process
|
question: The natural language query to process
|
||||||
project_id: Optional project ID for compatibility scoring
|
project_id: Optional project ID for compatibility scoring
|
||||||
"""
|
"""
|
||||||
# Let the LLM handle all database interactions and filtering to get fund IDs
|
cache_key = self._get_cache_key(question)
|
||||||
response = self.agent.invoke(
|
|
||||||
{"messages": [("user", question)]},
|
|
||||||
config={"recursion_limit": 50},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract the actual message content
|
# Check cache first
|
||||||
logger.info(f"{response}")
|
if cache_key in self.query_cache:
|
||||||
final_message_content = response["messages"][-1].content
|
sql_query = self.query_cache[cache_key]
|
||||||
logger.info(f"AI Response: \n{final_message_content}")
|
logger.info(f"Using cached SQL: {sql_query}")
|
||||||
# Extract fund IDs from the AI response
|
else:
|
||||||
fund_ids = self._extract_fund_ids_from_response(final_message_content)
|
# Generate SQL query
|
||||||
|
messages = self.sql_prompt.format_messages(question=question)
|
||||||
|
response = self.llm.invoke(messages)
|
||||||
|
sql_query = response.content.strip()
|
||||||
|
|
||||||
# Fetch full fund data with investor relationships using the IDs
|
# Clean up SQL (remove markdown code blocks if present)
|
||||||
return self._fetch_funds_by_ids(fund_ids, project_id)
|
sql_query = sql_query.replace("```sql", "").replace("```", "").strip()
|
||||||
|
|
||||||
def _extract_fund_ids_from_response(self, ai_response: str) -> List[int]:
|
# Cache the query
|
||||||
"""Extract fund IDs from AI response."""
|
self.query_cache[cache_key] = sql_query
|
||||||
import re
|
logger.info(f"Generated SQL: {sql_query}")
|
||||||
|
|
||||||
fund_ids = []
|
# Execute query to get fund IDs
|
||||||
|
db_session = next(get_db())
|
||||||
try:
|
try:
|
||||||
# Try multiple patterns to extract IDs from the response
|
result = db_session.execute(text(sql_query))
|
||||||
# Pattern 1: Simple numbers (assuming they are IDs)
|
fund_ids = [row[0] for row in result.fetchall()]
|
||||||
numbers = re.findall(r"\b\d+\b", ai_response)
|
logger.info(
|
||||||
fund_ids = [int(num) for num in numbers]
|
f"Found {len(fund_ids)} fund IDs: {fund_ids[:10]}{'...' if len(fund_ids) > 10 else ''}"
|
||||||
|
)
|
||||||
# Pattern 2: If response contains explicit ID references
|
|
||||||
id_matches = re.findall(r"\bid[:\s]*(\d+)", ai_response.lower())
|
|
||||||
if id_matches:
|
|
||||||
fund_ids = [int(id_str) for id_str in id_matches]
|
|
||||||
|
|
||||||
|
return self._fetch_funds_by_ids(fund_ids, project_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error extracting IDs from response: {e}")
|
logger.error(f"SQL execution error: {e}")
|
||||||
return []
|
logger.error(f"Failed SQL: {sql_query}")
|
||||||
|
# Return empty result
|
||||||
return fund_ids
|
return PaginatedResponse(
|
||||||
|
items=[], total=0, page=1, page_size=10, total_pages=0
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
db_session.close()
|
||||||
|
|
||||||
def _fetch_funds_by_ids(
|
def _fetch_funds_by_ids(
|
||||||
self, fund_ids: List[int], project_id: Optional[int] = None
|
self, fund_ids: List[int], project_id: Optional[int] = None
|
||||||
|
|||||||
Reference in New Issue
Block a user