import logging
import re
from decimal import Decimal
from typing import List, Optional

import chromadb
from db.models import InvestorTable
from langchain import hub
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from py_schemas import InvestorData, InvestorList
from settings import settings
from sqlalchemy import text
from sqlalchemy.orm import selectinload

logger = logging.getLogger(__name__)

# System prompt for the SQL agent, pulled from the LangChain Hub when this
# module is imported.
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")

# Connect to SQLite. "sqlite:///<file>" is a relative path, resolved against
# the process's current working directory.
db = SQLDatabase.from_uri("sqlite:///investors_2.db")

system_message = (
    prompt_template.format(dialect="SQLite", top_k=5)
    + "\n Get answers from the Sql database and the vector database"
)

# Maximum investors returned when a question yields neither IDs nor filters.
_UNFILTERED_RESULT_LIMIT = 10

# Matches "id: 12", "ID = 12", "id #12" and '"id": 12' (capturing 12). The
# leading \b keeps words like "paid" or "valid" from matching.
_INVESTOR_ID_PATTERN = re.compile(r"""\bid["']?\s*[:=#]\s*(\d+)""", re.IGNORECASE)

# (stage name, pattern) pairs, checked in order; the first match wins.
# Stage names must be attribute names of db.models.InvestmentStage.
_STAGE_PATTERNS = (
    ("SEED", re.compile(r"\bseed", re.IGNORECASE)),
    ("SERIES_A", re.compile(r"\bseries a\b", re.IGNORECASE)),
    ("SERIES_B", re.compile(r"\bseries b\b", re.IGNORECASE)),
    ("SERIES_C", re.compile(r"\bseries c\b", re.IGNORECASE)),
    ("GROWTH", re.compile(r"\bgrowth\b", re.IGNORECASE)),
    ("LATE_STAGE", re.compile(r"\blate[\s-]stage\b", re.IGNORECASE)),
)

# (geography label, pattern) pairs, checked in order; the first match wins.
_GEOGRAPHY_PATTERNS = (
    # A bare "US"/"USA" must be written in capitals so that the pronoun "us"
    # and words such as "business" or "focus" do not imply a US filter.
    ("US", re.compile(r"\bUSA?\b|\bU\.S\.|(?i:\busa\b|\bunited states\b)")),
    ("Europe", re.compile(r"\beurope", re.IGNORECASE)),
    ("Asia", re.compile(r"\basia", re.IGNORECASE)),
    (
        "Silicon Valley",
        re.compile(r"\b(?:silicon valley|bay area)\b", re.IGNORECASE),
    ),
)

_SECTORS = (
    "fintech",
    "healthcare",
    "saas",
    "ai",
    "biotech",
    "consumer",
    "enterprise",
    "crypto",
    "blockchain",
)
# Sectors must start at a word boundary ("fintechs" and "cryptocurrency"
# still count); "ai" must also end at one, so "said" or "aim" do not match.
_SECTOR_PATTERNS = tuple(
    (
        sector,
        re.compile(
            rf"\b{sector}\b" if sector == "ai" else rf"\b{sector}", re.IGNORECASE
        ),
    )
    for sector in _SECTORS
)

# An amount followed by its unit, e.g. "$2.5M", "500k", "1,000 thousand",
# "3 millions". The trailing \b stops "5 more" or "10 km" from matching.
_CHECK_SIZE_PATTERN = re.compile(
    r"\$?(\d+(?:,\d{3})*(?:\.\d+)?)\s*(millions?|mm|m|thousands?|k)\b",
    re.IGNORECASE,
)
_UNIT_MULTIPLIERS = {
    "million": 1_000_000,
    "mm": 1_000_000,
    "m": 1_000_000,
    "thousand": 1_000,
    "k": 1_000,
}


def _first_matching_label(patterns, haystack: str) -> Optional[str]:
    """Return the label of the first ``(label, regex)`` pair found in *haystack*, else None."""
    return next((label for label, pattern in patterns if pattern.search(haystack)), None)


def _message_text(content) -> str:
    """Flatten chat-message content (a string or a list of content parts) into plain text."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return " ".join(
            str(part.get("text", "")) if isinstance(part, dict) else str(part)
            for part in content
        )
    return "" if content is None else str(content)


class QueryProcessor:
    """Answer investor questions with an LLM agent over SQL and a Chroma vector store.

    The agent can call the SQL toolkit for the module-level ``db`` and
    ``query_vector_database``. Its final answer is scanned for investor IDs.
    Those IDs are combined with keyword filters parsed from the question to
    load ``InvestorTable`` rows through ``sql_session``.
    """

    def __init__(
        self,
        sql_session: Optional[object] = None,
        vector_db_client: Optional[object] = None,
    ):
        """Set up the vector store collection, the LLM and the agent.

        Args:
            sql_session: SQLAlchemy-style session used to load investors.
                When None, SQL lookups return no results.
            vector_db_client: Chroma client to use. When None, a persistent
                client stored at ``./chroma_db`` is opened.
        """
        self.sql_session = sql_session

        # Honor an injected client (e.g. a test double); fall back to the
        # on-disk store only when none is given.
        if vector_db_client is None:
            vector_db_client = chromadb.PersistentClient(path="./chroma_db")
        self.vector_db_client = vector_db_client
        self.collection = self.vector_db_client.get_or_create_collection(
            name="investor_descriptions",
            metadata={
                "description": "Investor descriptions and investment thesis focus"
            },
        )

        self.llm = ChatOpenAI(
            api_key=settings.OPENROUTER_API_KEY,
            # ChatOpenAI is pointed at OpenRouter's OpenAI-compatible endpoint.
            base_url="https://openrouter.ai/api/v1",
            model="google/gemini-2.5-flash-lite",
            temperature=0.3,
        )
        self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
        self.agent = create_react_agent(
            model=self.llm,
            tools=self.toolkit.get_tools() + [self.query_vector_database],
            prompt=system_message,
        )

    def query_sql_database(self, query: str) -> Optional[InvestorList]:
        """Query the SQL database for investor information.

        Args:
            query: A raw SQL statement. It is executed as-is, so never pass
                unsanitized user text here.

        Returns:
            The first column of each result row wrapped in ``InvestorList``,
            or None when there is no SQL session.
        """
        if not self.sql_session:
            return None

        # SQLAlchemy 2.0 rejects plain strings in Session.execute (1.4
        # deprecates them); textual SQL must be wrapped in text().
        statement = text(query) if isinstance(query, str) else query
        result = self.sql_session.execute(statement)
        investors = result.scalars().all()
        # NOTE(review): scalars() keeps only the first column of each row, while
        # _fetch_investors_with_relationships builds InvestorList from
        # InvestorData objects. Confirm this shape validates.
        return InvestorList(investors=investors)

    def query_vector_database(self, query: str) -> Optional[dict]:
        """Semantic search over investor descriptions and investment theses.

        Returns the raw ChromaDB query result for the 3 closest documents
        (keys include 'ids', 'documents', 'metadatas' and 'distances'), or
        None when no vector database client is configured.
        """
        if not self.vector_db_client:
            return None

        logger.debug("Vector store queried with %r", query)
        # ChromaDB takes a batch of query texts; this sends a batch of one.
        results = self.collection.query(query_texts=[query], n_results=3)
        logger.debug("Vector store results: %s", results)
        return results

    def process_query(self, question: str) -> InvestorList:
        """Answer *question* and return the matching investors with their relationships.

        Keyword filters are parsed from the question itself. The agent's final
        message is scanned for explicit investor IDs. Both narrow the SQL
        lookup.
        """
        filters = self._extract_filters_from_query(question)

        response = self.agent.invoke({"messages": [("user", question)]})
        messages = response.get("messages") or []
        # Message content may be a plain string or a list of content parts.
        ai_response = _message_text(messages[-1].content) if messages else ""

        investor_ids = self._extract_investor_info_from_response(ai_response)
        return self._fetch_investors_with_relationships(investor_ids, filters)

    def _extract_investor_info_from_response(self, ai_response: str) -> List[int]:
        """Extract investor IDs explicitly mentioned in the agent's answer.

        Recognizes ``id: 12``, ``ID = 12``, ``id #12`` and ``"id": 12``
        (case-insensitive). IDs are returned in first-seen order without
        duplicates. An empty list means no IDs were found.
        """
        if not ai_response:
            return []
        matches = _INVESTOR_ID_PATTERN.findall(ai_response)
        return list(dict.fromkeys(int(match) for match in matches))

    def _extract_filters_from_query(self, question: str) -> dict:
        """Extract filter criteria from a natural-language question.

        These are keyword heuristics only. Possible keys (each set at most once):

        - ``stage``: an ``InvestmentStage`` member name such as ``"SEED"``.
        - ``geography``: ``"US"``, ``"Europe"``, ``"Asia"`` or
          ``"Silicon Valley"``.
        - ``sector``: one of the sector keywords, e.g. ``"fintech"``.
        - ``min_check_size``: whole dollars, from the first amount written
          with a million/thousand/m/mm/k unit.

        Returns:
            A dict with only the keys that were detected.
        """
        filters = {}

        for key, patterns in (
            ("stage", _STAGE_PATTERNS),
            ("geography", _GEOGRAPHY_PATTERNS),
            ("sector", _SECTOR_PATTERNS),
        ):
            label = _first_matching_label(patterns, question)
            if label:
                filters[key] = label

        amount_match = _CHECK_SIZE_PATTERN.search(question)
        if amount_match:
            amount, unit = amount_match.groups()
            # Scale by the unit written next to this amount, not by whether an
            # "m" or "k" appears anywhere else in the question.
            multiplier = _UNIT_MULTIPLIERS[unit.lower().rstrip("s")]
            # Decimal keeps e.g. "2.3M" exact instead of going through float.
            filters["min_check_size"] = int(
                Decimal(amount.replace(",", "")) * multiplier
            )

        return filters

    def _fetch_investors_with_relationships(
        self,
        investor_ids: Optional[List[int]] = None,
        filters: Optional[dict] = None,
    ) -> InvestorList:
        """Load investors, with portfolio companies, team members and sectors, from SQL.

        Args:
            investor_ids: If non-empty, restrict results to these IDs.
            filters: Criteria as returned by ``_extract_filters_from_query``.
                ``max_check_size``, ``min_aum`` and ``max_aum`` are also
                honored.

        Returns:
            Matching investors as ``InvestorData`` entries. The list is empty
            when there is no SQL session. With neither IDs nor filters, at
            most ``_UNFILTERED_RESULT_LIMIT`` investors are returned.
        """
        if not self.sql_session:
            return InvestorList(investors=[])

        # Imported here rather than at module level; the original code
        # notes this is to avoid circular imports.
        from db.models import SectorTable

        filters = filters or {}

        # selectinload eager-loads each collection with a separate
        # SELECT ... IN query, so reading them below causes no per-row lazy loads.
        query = self.sql_session.query(InvestorTable).options(
            selectinload(InvestorTable.portfolio_companies),
            selectinload(InvestorTable.team_members),
            selectinload(InvestorTable.sectors),
        )

        if "stage" in filters:
            from db.models import InvestmentStage

            stage_enum = getattr(InvestmentStage, filters["stage"], None)
            if stage_enum is None:
                # Keyword heuristics can name a stage the enum lacks; skip
                # the filter rather than failing the whole request.
                logger.warning(
                    "Unknown investment stage %r; ignoring stage filter",
                    filters["stage"],
                )
            else:
                query = query.filter(InvestorTable.stage_focus == stage_enum)
        if "geography" in filters:
            # NOTE(review): substring match, so "US" also matches values
            # such as "Australia"; consider a normalized geography column.
            query = query.filter(
                InvestorTable.geographic_focus.ilike(f"%{filters['geography']}%")
            )
        if "min_check_size" in filters:
            query = query.filter(
                InvestorTable.check_size_lower >= filters["min_check_size"]
            )
        if "max_check_size" in filters:
            query = query.filter(
                InvestorTable.check_size_upper <= filters["max_check_size"]
            )
        if "min_aum" in filters:
            query = query.filter(InvestorTable.aum >= filters["min_aum"])
        if "max_aum" in filters:
            query = query.filter(InvestorTable.aum <= filters["max_aum"])
        if "sector" in filters:
            # EXISTS subquery instead of a JOIN: the SQL yields one row per
            # investor however many of its sectors match.
            query = query.filter(
                InvestorTable.sectors.any(
                    SectorTable.name.ilike(f"%{filters['sector']}%")
                )
            )

        if investor_ids:
            query = query.filter(InvestorTable.id.in_(investor_ids))
        elif not filters:
            # Nothing narrows the search; avoid returning the whole table.
            query = query.limit(_UNFILTERED_RESULT_LIMIT)

        return InvestorList(
            investors=[
                InvestorData(
                    investor=investor,
                    portfolio_companies=investor.portfolio_companies,
                    team_members=investor.team_members,
                    sectors=investor.sectors,
                )
                for investor in query.all()
            ]
        )