diff --git a/app/__pycache__/main.cpython-312.pyc b/app/__pycache__/main.cpython-312.pyc index 26e197f..9932cb9 100644 Binary files a/app/__pycache__/main.cpython-312.pyc and b/app/__pycache__/main.cpython-312.pyc differ diff --git a/app/db/__pycache__/models.cpython-312.pyc b/app/db/__pycache__/models.cpython-312.pyc index 5bcb5ac..e04c350 100644 Binary files a/app/db/__pycache__/models.cpython-312.pyc and b/app/db/__pycache__/models.cpython-312.pyc differ diff --git a/app/main.py b/app/main.py index 4a79176..a017939 100644 --- a/app/main.py +++ b/app/main.py @@ -44,6 +44,15 @@ class QueryRequest(BaseModel): } } +class CompanyQueryRequest(BaseModel): + question: str + + class Config: + json_schema_extra = { + "example": { + "question": "Find me companies in the fintech sector located in San Francisco." + } + } @app.get("/") def health(): @@ -120,7 +129,7 @@ async def query_investors(request: QueryRequest): @app.post( "/query-companies", response_model=PaginatedResponse[CompanyData], tags=["Querying"] ) -async def query_companies(request: QueryRequest): +async def query_companies(request: CompanyQueryRequest): """ Query companies using natural language. diff --git a/app/schemas/__pycache__/router_schemas.cpython-312.pyc b/app/schemas/__pycache__/router_schemas.cpython-312.pyc index dbc61b1..6b36456 100644 Binary files a/app/schemas/__pycache__/router_schemas.cpython-312.pyc and b/app/schemas/__pycache__/router_schemas.cpython-312.pyc differ diff --git a/app/services/__pycache__/querying.cpython-312.pyc b/app/services/__pycache__/querying.cpython-312.pyc index 0fc1072..4221e66 100644 Binary files a/app/services/__pycache__/querying.cpython-312.pyc and b/app/services/__pycache__/querying.cpython-312.pyc differ diff --git a/app/services/company_querying.py b/app/services/company_querying.py index 80a2b71..e002af3 100644 --- a/app/services/company_querying.py +++ b/app/services/company_querying.py @@ -1,21 +1,17 @@ +import hashlib import logging import os from typing import List -from db.db import DATABASE_URL, get_db +from db.db import get_db from db.models import CompanyTable -from langchain import hub -from langchain_community.agent_toolkits import SQLDatabaseToolkit -from langchain_community.utilities import SQLDatabase +from langchain_core.prompts import ChatPromptTemplate from langchain_openai import ChatOpenAI -from langgraph.prebuilt import create_react_agent from schemas.router_schemas import CompanyData, PaginatedResponse +from sqlalchemy import text from sqlalchemy.orm import selectinload logger = logging.getLogger(__name__) -# Connect to SQLite -prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt") -db = SQLDatabase.from_uri(DATABASE_URL) class CompanyQueryProcessor: @@ -26,96 +22,144 @@ class CompanyQueryProcessor: model="openai/gpt-4o-mini", temperature=0, ) - self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm) - # Update system message to specifically request only company IDs - system_message_updated = ( - prompt_template.format(dialect="SQLite", top_k=5) - + "\n\n=== CRITICAL INSTRUCTIONS ===" - + "\n- Your ONLY task is to run SQL queries and extract company IDs" - + "\n- When you get SQL results with company IDs, return them EXACTLY as shown" - + "\n- If the SQL query returns rows with company IDs like [(1,), (5,), (9,)], return all those IDs" - + "\n- Do NOT add any explanations, just list the IDs" - + "\n- If a query returns NO ROWS (empty result), then respond with 'NO_RESULTS'" - + "\n\n=== QUERY GUIDELINES ===" - + "\n1. For sector searches: SELECT companies.id FROM companies JOIN company_sector ON companies.id = company_sector.company_id JOIN sectors ON company_sector.sector_id = sectors.id WHERE sectors.name LIKE '%sector_name%'" - + "\n2. For industry searches: WHERE companies.industry LIKE '%search_term%'" - + "\n3. For location searches: WHERE companies.location LIKE '%location%'" - + "\n4. For founding year searches: WHERE companies.founded_year >= year" - + "\n5. For investor-related: JOIN investor_companies table" - ) - self.agent = create_react_agent( - model=self.llm, - tools=self.toolkit.get_tools(), - prompt=system_message_updated, + + # Query cache for performance + self.query_cache = {} + + # SQL generation prompt + self.sql_prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + """You are a SQL expert. Generate a SQLite query to find company IDs based on user requirements. + +Database Schema: +- companies: id, name, industry, location, description, founded_year, website +- company_sector: company_id, sector_id +- sectors: id, name +- investor_companies: investor_id, company_id +- investors: id, name, aum +- team_members: id, company_id, name, title + +IMPORTANT RULES: +1. ALWAYS return ONLY company IDs (companies.id) - use SELECT DISTINCT c.id +2. For industry: Check BOTH industry field AND sectors table with synonyms + - Use LEFT JOIN for sectors so companies without sector tags still match + - Include related terms: 'Fintech' → c.industry LIKE '%Fintech%' OR c.industry LIKE '%Finance%' OR sec.name LIKE '%Fintech%' OR sec.name LIKE '%Financial%' + - 'AI' → c.industry LIKE '%AI%' OR c.industry LIKE '%Artificial Intelligence%' OR c.industry LIKE '%Machine Learning%' OR sec.name LIKE '%AI%' OR sec.name LIKE '%ML%' +3. For location: Be FLEXIBLE with variations and abbreviations + - 'San Francisco' → c.location LIKE '%San Francisco%' OR c.location LIKE '%SF%' OR c.location LIKE '%Bay Area%' + - 'New York' → c.location LIKE '%New York%' OR c.location LIKE '%NYC%' OR c.location LIKE '%NY%' + - 'Europe' → c.location LIKE '%Europe%' OR c.location LIKE '%UK%' OR c.location LIKE '%London%' OR c.location LIKE '%Berlin%' OR c.location LIKE '%Paris%' +4. For sectors: Use LEFT JOIN and include multiple synonyms + - 'Healthcare' → sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR c.industry LIKE '%Health%' +5. For founding year filters (include NULL to be inclusive): + - "founded after 2020" → WHERE (founded_year >= 2020 OR founded_year IS NULL) + - "founded before 2018" → WHERE (founded_year <= 2018 OR founded_year IS NULL) + - "founded in 2020" → WHERE founded_year = 2020 +6. For investor-related queries: Use JOIN investor_companies +7. Use LEFT JOIN for sectors so companies without tags still match +8. Use DISTINCT to avoid duplicates from joins +9. Be INCLUSIVE - use OR conditions with synonyms and variations +10. Return a single, complete SELECT query + +Example Queries: +Q: "Fintech companies founded in 2020" +A: SELECT DISTINCT c.id FROM companies c + LEFT JOIN company_sector cs ON c.id = cs.company_id + LEFT JOIN sectors sec ON cs.sector_id = sec.id + WHERE (c.industry LIKE '%Fintech%' OR c.industry LIKE '%Finance%' OR c.industry LIKE '%Financial%' OR sec.name LIKE '%Fintech%' OR sec.name LIKE '%Financial Services%') + AND c.founded_year = 2020 + +Q: "AI companies in San Francisco" +A: SELECT DISTINCT c.id FROM companies c + LEFT JOIN company_sector cs ON c.id = cs.company_id + LEFT JOIN sectors sec ON cs.sector_id = sec.id + WHERE (c.industry LIKE '%AI%' OR c.industry LIKE '%Artificial Intelligence%' OR c.industry LIKE '%Machine Learning%' OR sec.name LIKE '%AI%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%') + AND (c.location LIKE '%San Francisco%' OR c.location LIKE '%SF%' OR c.location LIKE '%Bay Area%') + +Q: "Healthcare companies" +A: SELECT DISTINCT c.id FROM companies c + LEFT JOIN company_sector cs ON c.id = cs.company_id + LEFT JOIN sectors sec ON cs.sector_id = sec.id + WHERE c.industry LIKE '%Healthcare%' OR c.industry LIKE '%Health%' OR c.industry LIKE '%Medical%' OR sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR sec.name LIKE '%Pharma%' + +Q: "Companies funded by Sequoia" +A: SELECT DISTINCT c.id FROM companies c + JOIN investor_companies ic ON c.id = ic.company_id + JOIN investors i ON ic.investor_id = i.id + WHERE i.name LIKE '%Sequoia%' + +Q: "European startups founded after 2019" +A: SELECT DISTINCT c.id FROM companies c + WHERE (c.location LIKE '%Europe%' OR c.location LIKE '%UK%' OR c.location LIKE '%London%' OR c.location LIKE '%Germany%' OR c.location LIKE '%Berlin%' OR c.location LIKE '%France%' OR c.location LIKE '%Paris%') + AND (c.founded_year > 2019 OR c.founded_year IS NULL) + +Q: "SaaS companies" +A: SELECT DISTINCT c.id FROM companies c + LEFT JOIN company_sector cs ON c.id = cs.company_id + LEFT JOIN sectors sec ON cs.sector_id = sec.id + WHERE c.industry LIKE '%SaaS%' OR c.industry LIKE '%Software%' OR c.industry LIKE '%Cloud%' OR sec.name LIKE '%SaaS%' OR sec.name LIKE '%Software%' + +IMPORTANT: +- Use LEFT JOIN so companies without sector tags still match via industry field +- Use OR conditions with related keywords/synonyms to cast a wider net +- Include NULL checks for optional filters to avoid excluding companies with missing data + +Return ONLY the SQL query, no explanations or markdown.""", + ), + ("user", "{question}"), + ] ) + def _get_cache_key(self, question: str) -> str: + """Generate cache key from normalized question.""" + return hashlib.md5(question.lower().strip().encode()).hexdigest() + def process_query(self, question: str) -> PaginatedResponse[CompanyData]: - """Process a query using the LLM and return company response data. + """Process a query by generating and executing SQL directly. Args: question: The natural language query to process """ - # Let the LLM handle all database interactions and filtering to get company IDs - response = self.agent.invoke( - {"messages": [("user", question)]}, - config={"recursion_limit": 50}, - ) + cache_key = self._get_cache_key(question) - # Extract the actual message content - logger.info(f"{response}") + # Check cache first + if cache_key in self.query_cache: + sql_query = self.query_cache[cache_key] + logger.info(f"Using cached SQL: {sql_query}") + else: + # Generate SQL query + messages = self.sql_prompt.format_messages(question=question) + response = self.llm.invoke(messages) + sql_query = response.content.strip() - # Look through all messages to find the SQL query results (ToolMessage with actual data) - company_ids = [] - for message in response["messages"]: - if hasattr(message, "content") and message.content: - # Check if this looks like SQL results (contains tuples with numbers) - if "(" in str(message.content) and "," in str(message.content): - company_ids = self._extract_company_ids_from_response( - str(message.content) - ) - if company_ids: - logger.info( - f"Extracted {len(company_ids)} company IDs from results" - ) - break + # Clean up SQL (remove markdown code blocks if present) + sql_query = sql_query.replace("```sql", "").replace("```", "").strip() - # If no IDs found from ToolMessage, check the final AI message - if not company_ids: - final_message_content = response["messages"][-1].content - logger.info(f"AI Response: \n{final_message_content}") - company_ids = self._extract_company_ids_from_response(final_message_content) - - # Fetch full company data with relationships using the IDs - return self._fetch_companies_by_ids(company_ids) - - def _extract_company_ids_from_response(self, ai_response: str) -> List[int]: - """Extract company IDs from AI response.""" - import re - - company_ids = [] - - # Check if response is NO_RESULTS - if "NO_RESULTS" in ai_response.upper(): - return [] + # Cache the query + self.query_cache[cache_key] = sql_query + logger.info(f"Generated SQL: {sql_query}") + # Execute query to get company IDs + db_session = next(get_db()) try: - # The response contains tuples like (1,), (5,), etc. - # Extract numbers between parentheses - pattern = r"\((\d+),?\)" - matches = re.findall(pattern, ai_response) - if matches: - company_ids = [int(match) for match in matches] - else: - # Fallback: extract all numbers - numbers = re.findall(r"\b\d+\b", ai_response) - # Filter out very large numbers that might be tokens or timestamps - company_ids = [int(num) for num in numbers if int(num) < 100000] + result = db_session.execute(text(sql_query)) + company_ids = [row[0] for row in result.fetchall()] + logger.info( + f"Found {len(company_ids)} company IDs: {company_ids[:10]}{'...' if len(company_ids) > 10 else ''}" + ) + return self._fetch_companies_by_ids(company_ids) except Exception as e: - logger.error(f"Error extracting IDs from response: {e}") - return [] - - return company_ids + logger.error(f"SQL execution error: {e}") + logger.error(f"Failed SQL: {sql_query}") + # Return empty result + return PaginatedResponse( + items=[], total=0, page=1, page_size=10, total_pages=0 + ) + finally: + db_session.close() def _fetch_companies_by_ids( self, company_ids: List[int] @@ -130,7 +174,7 @@ class CompanyQueryProcessor: items=[], total=0, page=1, - page_size=len(company_ids) if company_ids else 10, + page_size=10, total_pages=0, ) diff --git a/app/services/querying.py b/app/services/querying.py index 5bd0219..2a566eb 100644 --- a/app/services/querying.py +++ b/app/services/querying.py @@ -1,29 +1,24 @@ -import json +import hashlib import logging import os from typing import List, Optional -from db.db import DATABASE_URL, get_db +from db.db import get_db from db.models import FundTable, InvestorTable, ProjectTable -from langchain import hub -from langchain_community.agent_toolkits import SQLDatabaseToolkit -from langchain_community.utilities import SQLDatabase +from langchain_core.prompts import ChatPromptTemplate from langchain_openai import ChatOpenAI -from langgraph.prebuilt import create_react_agent from schemas.router_schemas import ( CompanyMinimal, InvestmentResponse, PaginatedResponse, SectorMinimal, ) +from sqlalchemy import text from sqlalchemy.orm import selectinload from services.compatibility_score import calculate_project_investor_compatibility logger = logging.getLogger(__name__) -# Connect to SQLite -prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt") -db = SQLDatabase.from_uri(DATABASE_URL) class QueryProcessor: @@ -34,89 +29,150 @@ class QueryProcessor: model="openai/gpt-4o-mini", temperature=0, ) - self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm) - # Update system message to specifically request only fund IDs - system_message_updated = ( - prompt_template.format(dialect="SQLite", top_k=100) - + "\n\n=== IMPORTANT TERMINOLOGY ===" - + "\n- When users say 'investors' or 'find me investors', they mean FUNDS" - + "\n- Always query the 'funds' table for investment opportunities" - + "\n- The 'investors' table is for parent company information only" - + "\n- Relationship: investors (1) -> (many) funds" - + "\n\n=== YOUR TASK ===" - + "\nReturn ONLY fund IDs (funds.id) that match the user's criteria." - + "\nFormat: comma-separated numbers only (e.g., 1, 5, 12, 23)" - + "\nNo explanations, no other data." - + "\n\n=== QUERY GUIDELINES ===" - + "\n1. For geographic searches: use funds.geographic_focus" - + "\n2. For sector searches: JOIN with fund_sectors table" - + "\n3. For stage searches: JOIN with fund_investment_stages table" - + "\n4. Return ALL matching fund IDs, not just the first few" - + "\n5. If no results: respond with 'NO_RESULTS'" - + "\n6. Never repeat the same failed query" - + "\n\n=== GEOGRAPHIC SEARCH RULES (VERY IMPORTANT) ===" - + "\n- ALWAYS use LIKE '%keyword%' for geographic searches, NEVER use exact equality (=)" - + "\n- When user says 'Europe', match ANY location containing 'Europe' (e.g., 'Northern Europe', 'Western Europe', 'Europe', 'Central Europe')" - + "\n- When user says 'America', match locations like 'North America', 'South America', 'Latin America', 'United States'" - + "\n- When user says 'Asia', match 'Asia', 'Southeast Asia', 'East Asia', etc." - + "\n- Examples:" - + "\n * User: 'Europe' → SQL: WHERE geographic_focus LIKE '%Europe%'" - + "\n * User: 'America' → SQL: WHERE geographic_focus LIKE '%America%'" - + "\n * User: 'UK' → SQL: WHERE geographic_focus LIKE '%UK%' OR geographic_focus LIKE '%United Kingdom%'" - + "\n- Be INCLUSIVE: capture all relevant regional variations" - ) - self.agent = create_react_agent( - model=self.llm, - tools=self.toolkit.get_tools(), - prompt=system_message_updated, + + # Query cache for performance + self.query_cache = {} + + # SQL generation prompt + self.sql_prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + """You are a SQL expert. Generate a SQLite query to find fund IDs based on user requirements. + +Database Schema: +- funds: id, fund_name, investor_id, check_size_lower, check_size_upper, geographic_focus +- fund_sectors: fund_id, sector_id +- fund_investment_stages: fund_id, stage_id +- sectors: id, name +- investment_stages: id, name +- investors: id, name, aum + +IMPORTANT RULES: +1. ALWAYS return ONLY fund IDs (funds.id) - use SELECT DISTINCT f.id +2. For geography: Be FLEXIBLE - use OR with variations and partial matches + - 'Europe' → WHERE geographic_focus LIKE '%Europe%' OR geographic_focus LIKE '%European%' + - 'America' → WHERE geographic_focus LIKE '%America%' OR geographic_focus LIKE '%US%' OR geographic_focus LIKE '%United States%' + - 'Asia' → WHERE geographic_focus LIKE '%Asia%' OR geographic_focus LIKE '%Asian%' + - If no geography specified, DON'T filter by geography +3. For stages: Use LEFT JOIN and LIKE for flexible matching with synonyms + - 'Seed' → s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%' + - 'Series A' → s.name LIKE '%Series A%' OR s.name LIKE '%A%' + - 'Growth' → s.name LIKE '%Growth%' OR s.name LIKE '%Late%' OR s.name LIKE '%Expansion%' + - If stage not specified, include ALL funds +4. For sectors: Use LEFT JOIN and include related terms with OR + - 'Fintech' → sec.name LIKE '%Fintech%' OR sec.name LIKE '%Finance%' OR sec.name LIKE '%Financial%' + - 'AI' → sec.name LIKE '%AI%' OR sec.name LIKE '%Artificial Intelligence%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%' + - 'Healthcare' → sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' +5. For check size filters (be flexible with ranges): + - "under X" → WHERE (check_size_upper <= X OR check_size_upper IS NULL) + - "over X" → WHERE (check_size_lower >= X OR check_size_lower IS NULL) + - "between X and Y" → WHERE check_size_lower >= X AND check_size_upper <= Y +6. Use LEFT JOIN for stages and sectors so funds without tags still match +7. Use DISTINCT to avoid duplicates from joins +8. Be INCLUSIVE - use OR conditions to cast a wider net +9. If query is very simple (e.g., just "seed stage"), don't add unnecessary filters +10. Return a single, complete SELECT query + +Example Queries: +Q: "Seed stage investors in Europe" +A: SELECT DISTINCT f.id FROM funds f + LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id + LEFT JOIN investment_stages s ON fis.stage_id = s.id + WHERE (s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%' OR s.id IS NULL) + AND (f.geographic_focus LIKE '%Europe%' OR f.geographic_focus LIKE '%European%') + +Q: "Fintech investors with check size under 5 million" +A: SELECT DISTINCT f.id FROM funds f + LEFT JOIN fund_sectors fs ON f.id = fs.fund_id + LEFT JOIN sectors sec ON fs.sector_id = sec.id + WHERE (sec.name LIKE '%Fintech%' OR sec.name LIKE '%Finance%' OR sec.name LIKE '%Financial%' OR sec.id IS NULL) + AND (f.check_size_upper <= 5000000 OR f.check_size_upper IS NULL) + +Q: "Seed stage investors" +A: SELECT DISTINCT f.id FROM funds f + LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id + LEFT JOIN investment_stages s ON fis.stage_id = s.id + WHERE s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%' + +Q: "Growth stage investors" +A: SELECT DISTINCT f.id FROM funds f + LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id + LEFT JOIN investment_stages s ON fis.stage_id = s.id + WHERE s.name LIKE '%Growth%' OR s.name LIKE '%Late%' OR s.name LIKE '%Expansion%' OR s.name LIKE '%Series C%' OR s.name LIKE '%Series D%' + +Q: "AI investors in America" +A: SELECT DISTINCT f.id FROM funds f + LEFT JOIN fund_sectors fs ON f.id = fs.fund_id + LEFT JOIN sectors sec ON fs.sector_id = sec.id + WHERE (sec.name LIKE '%AI%' OR sec.name LIKE '%Artificial Intelligence%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%') + AND (f.geographic_focus LIKE '%America%' OR f.geographic_focus LIKE '%US%' OR f.geographic_focus LIKE '%United States%' OR f.geographic_focus LIKE '%USA%') + +Q: "Healthcare investors" +A: SELECT DISTINCT f.id FROM funds f + LEFT JOIN fund_sectors fs ON f.id = fs.fund_id + LEFT JOIN sectors sec ON fs.sector_id = sec.id + WHERE sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR sec.name LIKE '%Pharma%' + +IMPORTANT: Use LEFT JOIN so funds without sector/stage tags can still match. Include synonym terms with OR for better recall. + +Return ONLY the SQL query, no explanations or markdown.""", + ), + ("user", "{question}"), + ] ) + def _get_cache_key(self, question: str) -> str: + """Generate cache key from normalized question.""" + return hashlib.md5(question.lower().strip().encode()).hexdigest() + def process_query( self, question: str, project_id: Optional[int] = None ) -> PaginatedResponse[InvestmentResponse]: - """Process a query using the LLM and return investment response data. + """Process a query by generating and executing SQL directly. Args: question: The natural language query to process project_id: Optional project ID for compatibility scoring """ - # Let the LLM handle all database interactions and filtering to get fund IDs - response = self.agent.invoke( - {"messages": [("user", question)]}, - config={"recursion_limit": 50}, - ) + cache_key = self._get_cache_key(question) - # Extract the actual message content - logger.info(f"{response}") - final_message_content = response["messages"][-1].content - logger.info(f"AI Response: \n{final_message_content}") - # Extract fund IDs from the AI response - fund_ids = self._extract_fund_ids_from_response(final_message_content) + # Check cache first + if cache_key in self.query_cache: + sql_query = self.query_cache[cache_key] + logger.info(f"Using cached SQL: {sql_query}") + else: + # Generate SQL query + messages = self.sql_prompt.format_messages(question=question) + response = self.llm.invoke(messages) + sql_query = response.content.strip() - # Fetch full fund data with investor relationships using the IDs - return self._fetch_funds_by_ids(fund_ids, project_id) + # Clean up SQL (remove markdown code blocks if present) + sql_query = sql_query.replace("```sql", "").replace("```", "").strip() - def _extract_fund_ids_from_response(self, ai_response: str) -> List[int]: - """Extract fund IDs from AI response.""" - import re + # Cache the query + self.query_cache[cache_key] = sql_query + logger.info(f"Generated SQL: {sql_query}") - fund_ids = [] + # Execute query to get fund IDs + db_session = next(get_db()) try: - # Try multiple patterns to extract IDs from the response - # Pattern 1: Simple numbers (assuming they are IDs) - numbers = re.findall(r"\b\d+\b", ai_response) - fund_ids = [int(num) for num in numbers] - - # Pattern 2: If response contains explicit ID references - id_matches = re.findall(r"\bid[:\s]*(\d+)", ai_response.lower()) - if id_matches: - fund_ids = [int(id_str) for id_str in id_matches] + result = db_session.execute(text(sql_query)) + fund_ids = [row[0] for row in result.fetchall()] + logger.info( + f"Found {len(fund_ids)} fund IDs: {fund_ids[:10]}{'...' if len(fund_ids) > 10 else ''}" + ) + return self._fetch_funds_by_ids(fund_ids, project_id) except Exception as e: - print(f"Error extracting IDs from response: {e}") - return [] - - return fund_ids + logger.error(f"SQL execution error: {e}") + logger.error(f"Failed SQL: {sql_query}") + # Return empty result + return PaginatedResponse( + items=[], total=0, page=1, page_size=10, total_pages=0 + ) + finally: + db_session.close() def _fetch_funds_by_ids( self, fund_ids: List[int], project_id: Optional[int] = None