import asyncio
from typing import List, Optional

import chromadb
import pandas as pd
from db.models import CompanyTable, InvestorTable, InvestorTeamMember, SectorTable
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from py_schemas import InvestorData
from pydantic import BaseModel
from settings import settings


class InvestorList(BaseModel):
    """Schema for LLM structured output"""

    investor_list: List[InvestorData]


class InvestorProcessor:
    """Extract investor records from CSV rows with an LLM and persist them.

    Rows are sent to the LLM in batches, parsed into ``InvestorData``
    objects, then written to the SQL session (when one is supplied) and to
    the ``investor_descriptions`` vector collection.
    """

    def __init__(
        self,
        sql_session: Optional[object] = None,
        vector_db_client: Optional[object] = None,
    ):
        """Set up the prompt, the structured-output LLM and the storage handles.

        Args:
            sql_session: Session used by ``_save_to_sql``; when falsy the SQL
                save step is skipped.
            vector_db_client: Client exposing ``get_or_create_collection``.
                When ``None`` a ``chromadb.PersistentClient`` rooted at
                ``./chroma_db`` is created.
        """
        self.template = """You are an expert data extraction assistant. Extract investor information from the provided CSV data and return it as a list of structured records.

Given the following CSV data rows:
{question}

For each row, extract and structure the following fields for the investor:
- name: The investor's full name
- description: Description of the investor
- aum: Assets under management (as integer, use 0 if not available)
- check_size_lower: Lower bound of investment check size (as integer)
- check_size_upper: Upper bound of investment check size (as integer)
- geographic_focus: Geographic region focus
- stage_focus: Investment stage focus (must be one of: seed, series_a, series_b, series_c, growth, late_stage)
- number_of_investments: Number of investments made (default 0)

Also extract related data:
- portfolio_companies: List of companies they've invested in
- team_members: List of team members with name, role, email
- sectors: List of sectors they focus on

Important:
- If a field is not available, use appropriate defaults
- stage_focus must be one of the valid enum values
- Return clean, valid JSON only

Return the data as a structured list of comprehensive investor data."""

        self.prompt = PromptTemplate(
            template=self.template, input_variables=["question"]
        )

        self.llm = ChatOpenAI(
            api_key=settings.OPENROUTER_API_KEY,
            base_url="https://openrouter.ai/api/v1",
            model="google/gemini-2.5-flash-lite",
            temperature=0,
        )
        self.structured_llm = self.llm.with_structured_output(InvestorList)

        self.sql_session = sql_session

        # Honor an injected client; only fall back to the on-disk store when
        # the caller did not provide one (previously the argument was
        # unconditionally overwritten).
        if vector_db_client is None:
            vector_db_client = chromadb.PersistentClient(path="./chroma_db")
        self.vector_db_client = vector_db_client
        self.collection = self.vector_db_client.get_or_create_collection(
            name="investor_descriptions",
            metadata={
                "description": "Investor descriptions and investment thesis focus"
            },
        )

    @staticmethod
    def _clean_value(value: object) -> str:
        """Return ``str(value)`` with control characters neutralized.

        Newlines, carriage returns and tabs become single spaces; every other
        character with a code point below 32 is removed.
        """
        text = str(value).replace("\n", " ").replace("\r", " ").replace("\t", " ")
        return "".join(char for char in text if ord(char) >= 32)

    async def _process_batch(
        self, batch: pd.DataFrame, batch_idx: int
    ) -> List[InvestorData]:
        """Process a single batch of data.

        Args:
            batch: Rows to send to the LLM in one request.
            batch_idx: Zero-based batch number, used only in log messages.

        Returns:
            The investors parsed from the batch, or an empty list when the
            LLM call or the parsing of its output fails.
        """
        # Convert batch to string representation - clean the data
        row_lines = []
        for idx, row in batch.iterrows():
            # Missing (NaN/None) cells are left out of the row entirely.
            cleaned_row = {
                key: self._clean_value(value)
                for key, value in row.items()
                if pd.notna(value)
            }
            row_str = ", ".join(
                f"{key}: {value}" for key, value in cleaned_row.items()
            )
            row_lines.append(f"Row {idx + 1}: {row_str}\n")
        batch_str = "".join(row_lines)

        # Wrap the rows in the extraction instructions; without this the
        # model only receives the raw row dump.
        prompt_text = self.prompt.format(question=batch_str)

        try:
            print(f"Processing batch {batch_idx + 1}...")
            batch_results = await self.structured_llm.ainvoke(prompt_text)
            return batch_results.investor_list
        except Exception as e:
            # Best effort: one failed batch must not abort the whole run.
            print(f"Error processing batch {batch_idx + 1}: {e}")
            return []

    async def _save_to_sql(self, investor_data_list: List[InvestorData]) -> None:
        """Save investors and related data to SQL database.

        All investors are written in one transaction: a single commit at the
        end, and a rollback followed by re-raise on any error. Does nothing
        when no SQL session was configured.

        Args:
            investor_data_list: Parsed investors to persist.

        Raises:
            Exception: Whatever the session raised, after rolling back.
        """
        if not self.sql_session:
            return

        try:
            for investor_data in investor_data_list:
                # Save investor
                db_investor = InvestorTable(
                    name=investor_data.investor.name,
                    description=investor_data.investor.description,
                    aum=investor_data.investor.aum,
                    check_size_lower=investor_data.investor.check_size_lower,
                    check_size_upper=investor_data.investor.check_size_upper,
                    geographic_focus=investor_data.investor.geographic_focus,
                    stage_focus=investor_data.investor.stage_focus,
                    number_of_investments=investor_data.investor.number_of_investments,
                )
                self.sql_session.add(db_investor)
                self.sql_session.flush()  # Get the ID

                # Save sectors and create associations
                for sector_data in investor_data.sectors:
                    # Check if sector exists, create if not
                    existing_sector = (
                        self.sql_session.query(SectorTable)
                        .filter(SectorTable.name == sector_data.name)
                        .first()
                    )
                    if not existing_sector:
                        db_sector = SectorTable(name=sector_data.name)
                        self.sql_session.add(db_sector)
                        self.sql_session.flush()
                        # Add sector to investor's sectors
                        db_investor.sectors.append(db_sector)
                    elif existing_sector not in db_investor.sectors:
                        # Add existing sector to investor if not already there
                        db_investor.sectors.append(existing_sector)

                # Save companies and create portfolio associations
                for company_data in investor_data.portfolio_companies:
                    # Check if company exists, create if not
                    existing_company = (
                        self.sql_session.query(CompanyTable)
                        .filter(CompanyTable.name == company_data.name)
                        .first()
                    )
                    if not existing_company:
                        db_company = CompanyTable(
                            name=company_data.name,
                            industry=company_data.industry,
                            location=company_data.location,
                            founded_year=company_data.founded_year,
                            website=company_data.website,
                        )
                        self.sql_session.add(db_company)
                        self.sql_session.flush()
                        # Add to investor's portfolio
                        db_investor.portfolio_companies.append(db_company)
                    elif existing_company not in db_investor.portfolio_companies:
                        # Add existing company to portfolio if not already there
                        db_investor.portfolio_companies.append(existing_company)

                # Save team members
                for team_member_data in investor_data.team_members:
                    # Check if team member exists (matched by email only)
                    existing_member = (
                        self.sql_session.query(InvestorTeamMember)
                        .filter(InvestorTeamMember.email == team_member_data.email)
                        .first()
                    )
                    if not existing_member:
                        db_team_member = InvestorTeamMember(
                            name=team_member_data.name,
                            role=team_member_data.role,
                            email=team_member_data.email,
                            investor_id=db_investor.id,
                        )
                        self.sql_session.add(db_team_member)

            self.sql_session.commit()
            print(f"Successfully saved {len(investor_data_list)} investors to database")

        except Exception as e:
            self.sql_session.rollback()
            print(f"Error saving to SQL database: {e}")
            raise

    async def _save_to_vector_db(self, investor_data_list: List[InvestorData]) -> None:
        """Save investors to vector database.

        Builds one text document per investor plus filterable metadata and
        adds them to ``self.collection``. Errors from the collection are
        printed and swallowed.

        Args:
            investor_data_list: Parsed investors to index.
        """
        if not self.vector_db_client:
            return

        documents = []
        metadatas = []
        ids = []

        for i, investor_data in enumerate(investor_data_list):
            investor = investor_data.investor
            sectors = ", ".join([s.name for s in investor_data.sectors])
            companies = ", ".join([c.name for c in investor_data.portfolio_companies])

            doc_text = "\n".join(
                [
                    f"Investor: {investor.name}",
                    f"Description: {investor.description or 'N/A'}",
                    f"AUM: ${investor.aum:,}",
                    f"Check Size: ${investor.check_size_lower:,} - ${investor.check_size_upper:,}",
                    f"Geographic Focus: {investor.geographic_focus}",
                    f"Stage Focus: {investor.stage_focus.value}",
                    f"Sectors: {sectors}",
                    f"Portfolio Companies: {companies}",
                ]
            ).strip()

            documents.append(doc_text)
            metadatas.append(
                {
                    "name": investor.name,
                    "stage_focus": investor.stage_focus.value,
                    "geographic_focus": investor.geographic_focus,
                    "aum": investor.aum,
                }
            )
            # NOTE(review): ids depend on the position within this call, so
            # re-processing the same data in a different order yields
            # different ids - confirm this is intended.
            ids.append(
                f"investor_{i}_{investor.name.replace(' ', '_').replace('/', '_')}"
            )

        if documents:
            try:
                self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
                print(
                    f"Successfully saved {len(documents)} investors to vector database"
                )
            except Exception as e:
                print(f"Error saving to vector database: {e}")

    async def process_csv(
        self, df: pd.DataFrame, batch_size: int = 10, max_concurrent: int = 10
    ) -> List[InvestorData]:
        """Process CSV data in parallel batches and save to databases.

        Args:
            df: Rows to extract investors from.
            batch_size: Number of rows sent to the LLM per request.
            max_concurrent: Maximum number of batches in flight at once.

        Returns:
            All investors extracted from the batches that succeeded.

        Raises:
            ValueError: If ``batch_size`` or ``max_concurrent`` is below 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        if max_concurrent < 1:
            # A zero-slot semaphore would block every batch forever.
            raise ValueError("max_concurrent must be at least 1")

        results = []

        # Create batches
        batches = [
            (df.iloc[start : start + batch_size], start // batch_size)
            for start in range(0, len(df), batch_size)
        ]

        # Process batches with concurrency control
        semaphore = asyncio.Semaphore(max_concurrent)

        async def process_with_semaphore(batch_data):
            batch, batch_idx = batch_data
            async with semaphore:
                return await self._process_batch(batch, batch_idx)

        # Execute all batches concurrently
        batch_results = await asyncio.gather(
            *[process_with_semaphore(batch_data) for batch_data in batches],
            return_exceptions=True,
        )

        # Collect results; report (rather than silently drop) failed batches
        for (_, batch_idx), batch_result in zip(batches, batch_results):
            if isinstance(batch_result, Exception):
                print(f"Error processing batch {batch_idx + 1}: {batch_result}")
                continue
            results.extend(batch_result)

        # Save to databases
        if results:
            print(f"Successfully processed {len(results)} investors")
            await self._save_to_sql(results)
            await self._save_to_vector_db(results)

        return results