diff --git a/app/__pycache__/main.cpython-312.pyc b/app/__pycache__/main.cpython-312.pyc index 88f6657..ba7a368 100644 Binary files a/app/__pycache__/main.cpython-312.pyc and b/app/__pycache__/main.cpython-312.pyc differ diff --git a/app/main.py b/app/main.py index fb93d85..a923ecd 100644 --- a/app/main.py +++ b/app/main.py @@ -6,7 +6,7 @@ from dotenv import load_dotenv from fastapi import FastAPI, File, Form, UploadFile from pydantic import BaseModel from routers import companies, investors, projects -from schemas.router_schemas import InvestorList +from schemas.router_schemas import InvestmentResponse, PaginatedResponse from services.llm_parser import InvestorProcessor from services.querying import QueryProcessor @@ -84,11 +84,16 @@ async def parse_csv( return results -@app.post("/query", response_model=InvestorList, tags=["Querying"]) +@app.post( + "/query", response_model=PaginatedResponse[InvestmentResponse], tags=["Querying"] +) async def query_investors(request: QueryRequest): """ Query investors using natural language. + Returns fund-level matches (one row per fund) with investor details. + This ensures only relevant funds are included in the response. + Supports queries like: - "Show me seed stage investors" - "Find fintech investors in Silicon Valley" diff --git a/app/services/__pycache__/querying.cpython-312.pyc b/app/services/__pycache__/querying.cpython-312.pyc index 88b87c9..3159ece 100644 Binary files a/app/services/__pycache__/querying.cpython-312.pyc and b/app/services/__pycache__/querying.cpython-312.pyc differ diff --git a/app/services/compatibility_score.py b/app/services/compatibility_score.py new file mode 100644 index 0000000..e69de29 diff --git a/app/services/crm.py b/app/services/crm.py new file mode 100644 index 0000000..e69de29 diff --git a/app/services/insight.py b/app/services/insight.py new file mode 100644 index 0000000..e69de29 diff --git a/app/services/querying.py b/app/services/querying.py index 27df87a..05b3fae 100644 --- a/app/services/querying.py +++ b/app/services/querying.py @@ -2,13 +2,18 @@ import os from typing import List from db.db import DATABASE_URL, get_db -from db.models import InvestorTable +from db.models import FundTable, InvestorTable from langchain import hub from langchain_community.agent_toolkits import SQLDatabaseToolkit from langchain_community.utilities import SQLDatabase from langchain_openai import ChatOpenAI from langgraph.prebuilt import create_react_agent -from schemas.py_schemas import InvestorData, InvestorList +from schemas.router_schemas import ( + CompanyMinimal, + InvestmentResponse, + PaginatedResponse, + SectorMinimal, +) from sqlalchemy.orm import selectinload # Connect to SQLite @@ -21,16 +26,16 @@ class QueryProcessor: self.llm = ChatOpenAI( api_key=os.getenv("OPENROUTER_API_KEY"), base_url="https://openrouter.ai/api/v1", - model="openai/gpt-4o-mini", + model="x-ai/grok-4-fast", temperature=0, ) self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm) - # Update system message to specifically request only investor IDs + # Update system message to specifically request only fund IDs system_message_updated = ( prompt_template.format(dialect="SQLite", top_k=5) - + "\n\nIMPORTANT: You must ONLY return the investor IDs (id field) that match the user's criteria. " + + "\n\nIMPORTANT: You must ONLY return the fund IDs (id field from the funds table) that match the user's criteria. " + "Do NOT return any other information, explanations, or data. " - + "Your response should be ONLY a comma-separated list of numbers representing the investor IDs. " + + "Your response should be ONLY a comma-separated list of numbers representing the fund IDs. " + "Example format: 1, 5, 12, 23" ) self.agent = create_react_agent( @@ -39,9 +44,9 @@ class QueryProcessor: prompt=system_message_updated, ) - def process_query(self, question: str) -> InvestorList: - """Process a query using the LLM and return investor data.""" - # Let the LLM handle all database interactions and filtering to get IDs + def process_query(self, question: str) -> PaginatedResponse[InvestmentResponse]: + """Process a query using the LLM and return investment response data.""" + # Let the LLM handle all database interactions and filtering to get fund IDs response = self.agent.invoke( {"messages": [("user", question)]}, ) @@ -51,70 +56,122 @@ class QueryProcessor: response["messages"][-1].content if response.get("messages") else "" ) - # Extract investor IDs from the AI response - investor_ids = self._extract_investor_ids_from_response(ai_response) + # Extract fund IDs from the AI response + fund_ids = self._extract_fund_ids_from_response(ai_response) - # Fetch full investor data using the IDs - return self._fetch_investors_by_ids(investor_ids) + # Fetch full fund data with investor relationships using the IDs + return self._fetch_funds_by_ids(fund_ids) - def _extract_investor_ids_from_response(self, ai_response: str) -> List[int]: - """Extract investor IDs from AI response.""" + def _extract_fund_ids_from_response(self, ai_response: str) -> List[int]: + """Extract fund IDs from AI response.""" import re - investor_ids = [] + fund_ids = [] try: # Try multiple patterns to extract IDs from the response # Pattern 1: Simple numbers (assuming they are IDs) numbers = re.findall(r"\b\d+\b", ai_response) - investor_ids = [int(num) for num in numbers] + fund_ids = [int(num) for num in numbers] # Pattern 2: If response contains explicit ID references id_matches = re.findall(r"\bid[:\s]*(\d+)", ai_response.lower()) if id_matches: - investor_ids = [int(id_str) for id_str in id_matches] + fund_ids = [int(id_str) for id_str in id_matches] except Exception as e: print(f"Error extracting IDs from response: {e}") return [] - return investor_ids + return fund_ids - def _fetch_investors_by_ids(self, investor_ids: List[int]) -> InvestorList: - """Fetch investors with all their relationships from the database using IDs.""" - if not investor_ids: - return InvestorList(investors=[]) + def _fetch_funds_by_ids( + self, fund_ids: List[int] + ) -> PaginatedResponse[InvestmentResponse]: + """Fetch funds with all their relationships from the database using fund IDs. + Constructs response similar to read_investors but starting from funds.""" + if not fund_ids: + return PaginatedResponse( + items=[], + total=0, + page=1, + page_size=len(fund_ids) if fund_ids else 10, + total_pages=0, + ) # Get database session db_session = next(get_db()) try: - # Build query with all relationships loaded - query = ( - db_session.query(InvestorTable) + # Query funds with all necessary relationships loaded + funds = ( + db_session.query(FundTable) .options( - selectinload(InvestorTable.portfolio_companies), - selectinload(InvestorTable.team_members), - selectinload(InvestorTable.sectors), - selectinload(InvestorTable.funds), + selectinload(FundTable.investor).selectinload( + InvestorTable.portfolio_companies + ), + selectinload(FundTable.investor).selectinload( + InvestorTable.team_members + ), + selectinload(FundTable.investor).selectinload( + InvestorTable.sectors + ), + selectinload(FundTable.investment_stages), + selectinload(FundTable.sectors), ) - .filter(InvestorTable.id.in_(investor_ids)) + .filter(FundTable.id.in_(fund_ids)) + .all() ) - investors = query.all() + # Transform to InvestmentResponse format (one row per fund) + investment_responses = [] + for fund in funds: + investor = fund.investor - # Transform to InvestorData format - investor_data_list = [] - for investor in investors: - investor_data = InvestorData( - investor=investor, - portfolio_companies=investor.portfolio_companies, - team_members=investor.team_members, - sectors=investor.sectors, - funds=investor.funds, + # Get top 3 portfolio companies (id and name only) + portfolio_companies = [ + CompanyMinimal(id=company.id, name=company.name) + for company in investor.portfolio_companies[:3] + ] + + # Get stage focus as comma-separated string + stage_focus = ( + ", ".join([stage.name for stage in fund.investment_stages]) + if fund.investment_stages + else None ) - investor_data_list.append(investor_data) - return InvestorList(investors=investor_data_list) + # Get top 3 sectors from fund (id and name only) + fund_sectors = [ + SectorMinimal(id=sector.id, name=sector.name) + for sector in (fund.sectors[:3] if fund.sectors else []) + ] + + investment_response = InvestmentResponse( + id=investor.id, + name=f"{investor.name} - {fund.fund_name}" + if fund.fund_name + else investor.name, + aum=investor.aum, + check_size_lower=fund.check_size_lower, + check_size_upper=fund.check_size_upper, + geographic_focus=fund.geographic_focus, + stage_focus=stage_focus, + portfolio_companies=portfolio_companies, + sectors=fund_sectors, + compatibility_score=1.0, + ) + investment_responses.append(investment_response) + + total_count = len(investment_responses) + total_pages = 1 if total_count > 0 else 0 + + return PaginatedResponse( + items=investment_responses, + total=total_count, + page=1, + page_size=total_count, + total_pages=total_pages, + ) finally: db_session.close() diff --git a/app/services/report_gen.py b/app/services/report_gen.py new file mode 100644 index 0000000..e69de29