2025-08-29 18:42:55 +01:00
import asyncio
from typing import List , Optional
import chromadb
import pandas as pd
2025-09-02 15:51:35 +01:00
from db . models import CompanyTable , InvestorTable , InvestorTeamMember , SectorTable
2025-08-29 18:42:55 +01:00
from langchain_core . prompts import PromptTemplate
from langchain_openai import ChatOpenAI
2025-09-02 15:51:35 +01:00
from py_schemas import InvestorData
from pydantic import BaseModel
2025-08-29 18:42:55 +01:00
from settings import settings
2025-09-02 15:51:35 +01:00
class InvestorList(BaseModel):
    """Wrapper model handed to ``with_structured_output``: the LLM fills it per batch."""

    # Records the model extracted from one batch of CSV rows.
    investor_list: List[InvestorData]
2025-08-29 18:42:55 +01:00
class InvestorProcessor:
    """Extract investor records from CSV rows with an LLM and persist them.

    Each batch of rows is rendered to text, substituted into an instruction
    prompt and sent to an OpenRouter-hosted chat model whose output is parsed
    into ``InvestorList``. The resulting ``InvestorData`` records are written
    to a SQL session (when one is supplied) and to a Chroma collection.
    """

    def __init__(
        self,
        sql_session: Optional[object] = None,
        vector_db_client: Optional[object] = None,
    ):
        """Configure the prompt, the LLM and the storage backends.

        Args:
            sql_session: Session used by ``_save_to_sql`` (``add``, ``flush``,
                ``query``, ``commit`` and ``rollback`` are called on it). When
                ``None``, SQL persistence is skipped.
            vector_db_client: Chroma client used to store investor
                descriptions. When ``None``, a ``chromadb.PersistentClient``
                rooted at ``./chroma_db`` is created.
        """
        self.template = """You are an expert data extraction assistant. Extract investor information from the provided CSV data and return it as a list of structured records.
Given the following CSV data rows:
{question}

For each row, extract and structure the following fields for the investor:
- name: The investor's full name
- description: Description of the investor
- aum: Assets under management (as integer, use 0 if not available)
- check_size_lower: Lower bound of investment check size (as integer)
- check_size_upper: Upper bound of investment check size (as integer)
- geographic_focus: Geographic region focus
- stage_focus: Investment stage focus (must be one of: seed, series_a, series_b, series_c, growth, late_stage)
- number_of_investments: Number of investments made (default 0)

Also extract related data:
- portfolio_companies: List of companies they've invested in
- team_members: List of team members with name, role, email
- sectors: List of sectors they focus on

Important:
- If a field is not available, use appropriate defaults
- stage_focus must be one of the valid enum values
- Return clean, valid JSON only

Return the data as a structured list of comprehensive investor data."""

        self.prompt = PromptTemplate(
            template=self.template, input_variables=["question"]
        )
        self.llm = ChatOpenAI(
            api_key=settings.OPENROUTER_API_KEY,
            base_url="https://openrouter.ai/api/v1",
            model="google/gemini-2.5-flash-lite",
            temperature=0,
        )
        self.structured_llm = self.llm.with_structured_output(InvestorList)
        self.sql_session = sql_session
        # Honour an injected client; previously it was unconditionally
        # overwritten by the local persistent store.
        self.vector_db_client = (
            vector_db_client
            if vector_db_client is not None
            else chromadb.PersistentClient(path="./chroma_db")
        )
        self.collection = self.vector_db_client.get_or_create_collection(
            name="investor_descriptions",
            metadata={
                "description": "Investor descriptions and investment thesis focus"
            },
        )

    @staticmethod
    def _clean_cell(value: object) -> str:
        """Return ``value`` as a single-line string without control characters.

        Newlines, carriage returns and tabs become spaces so that each row
        stays on one prompt line. Any other character below U+0020 is dropped.
        """
        text = str(value).replace("\n", " ").replace("\r", " ").replace("\t", " ")
        return "".join(char for char in text if ord(char) >= 32)

    @classmethod
    def _batch_to_text(cls, batch: pd.DataFrame) -> str:
        """Render ``batch`` as ``Row N: col: value, col: value`` lines.

        Missing cells (per ``pd.notna``) are omitted, and rows with no
        remaining cells are skipped. Rows are numbered by 1-based position
        within the batch, so any index type (including non-numeric labels)
        works.
        """
        lines = []
        for position, (_, row) in enumerate(batch.iterrows(), start=1):
            cells = [
                f"{key}: {cls._clean_cell(value)}"
                for key, value in row.items()
                if pd.notna(value)
            ]
            if cells:
                lines.append(f"Row {position}: {', '.join(cells)}\n")
        return "".join(lines)

    async def _process_batch(
        self, batch: pd.DataFrame, batch_idx: int
    ) -> List[InvestorData]:
        """Extract investors from one batch of rows via the structured LLM.

        Args:
            batch: Slice of the input DataFrame.
            batch_idx: Zero-based batch number, used only for progress output.

        Returns:
            The extracted records. Returns an empty list when the batch has no
            usable cells, or when the LLM call fails (the error is printed).
        """
        print(f"Processing batch {batch_idx + 1}...")
        batch_str = self._batch_to_text(batch)
        if not batch_str:
            # Nothing to extract; don't spend an LLM call on empty rows.
            return []
        try:
            # Send the instruction prompt with the rows substituted in. The
            # raw rows alone carry no extraction instructions.
            prompt_text = self.prompt.format(question=batch_str)
            batch_results = await self.structured_llm.ainvoke(prompt_text)
            return batch_results.investor_list
        except Exception as e:
            print(f"Error processing batch {batch_idx + 1}: {e}")
            return []

    def _get_or_create_by_name(self, model: type, name: str, **defaults):
        """Return the ``model`` row whose ``name`` equals ``name``, creating it if absent.

        A newly created row (built with ``name`` plus ``defaults``) is added to
        the session and flushed, so later lookups in the same transaction find
        it.
        """
        existing = (
            self.sql_session.query(model).filter(model.name == name).first()
        )
        if existing is not None:
            return existing
        created = model(name=name, **defaults)
        self.sql_session.add(created)
        self.sql_session.flush()
        return created

    async def _save_to_sql(self, investor_data_list: List[InvestorData]) -> None:
        """Save investors and related data to the SQL session.

        For each investor, the method:
        - inserts an ``InvestorTable`` row;
        - links sectors and portfolio companies, looked up by name and created
          if missing;
        - adds team members, skipping those whose non-empty email already
          exists.

        Everything is committed once at the end. On any error the session is
        rolled back and the exception is re-raised. The session calls are
        synchronous (nothing is awaited), so this coroutine blocks the event
        loop while it runs.
        """
        if self.sql_session is None:
            return

        try:
            for investor_data in investor_data_list:
                investor = investor_data.investor
                db_investor = InvestorTable(
                    name=investor.name,
                    description=investor.description,
                    aum=investor.aum,
                    check_size_lower=investor.check_size_lower,
                    check_size_upper=investor.check_size_upper,
                    geographic_focus=investor.geographic_focus,
                    stage_focus=investor.stage_focus,
                    number_of_investments=investor.number_of_investments,
                )
                self.sql_session.add(db_investor)
                # Flush so db_investor.id is populated for the team-member FK below.
                self.sql_session.flush()

                for sector_data in investor_data.sectors:
                    sector = self._get_or_create_by_name(SectorTable, sector_data.name)
                    if sector not in db_investor.sectors:
                        db_investor.sectors.append(sector)

                for company_data in investor_data.portfolio_companies:
                    company = self._get_or_create_by_name(
                        CompanyTable,
                        company_data.name,
                        industry=company_data.industry,
                        location=company_data.location,
                        founded_year=company_data.founded_year,
                        website=company_data.website,
                    )
                    if company not in db_investor.portfolio_companies:
                        db_investor.portfolio_companies.append(company)

                for member_data in investor_data.team_members:
                    email = member_data.email
                    # Only dedupe on a real email. Filtering on `== None`
                    # becomes `IS NULL`, which would drop every email-less
                    # member after the first one.
                    if email:
                        existing_member = (
                            self.sql_session.query(InvestorTeamMember)
                            .filter(InvestorTeamMember.email == email)
                            .first()
                        )
                        if existing_member is not None:
                            continue
                    self.sql_session.add(
                        InvestorTeamMember(
                            name=member_data.name,
                            role=member_data.role,
                            email=email,
                            investor_id=db_investor.id,
                        )
                    )

            self.sql_session.commit()
            print(f"Successfully saved {len(investor_data_list)} investors to database")
        except Exception as e:
            self.sql_session.rollback()
            print(f"Error saving to SQL database: {e}")
            raise

    @staticmethod
    def _format_money(value) -> str:
        """Format an amount as ``$1,234``. ``None`` becomes ``N/A``."""
        return "N/A" if value is None else f"${value:,}"

    async def _save_to_vector_db(self, investor_data_list: List[InvestorData]) -> None:
        """Upsert one text document per investor into the Chroma collection.

        Each document summarises the investor's profile, sectors and portfolio.
        Metadata carries name, stage focus, geographic focus and AUM, with
        missing values omitted. Storage errors are printed, not raised.
        """
        if self.vector_db_client is None:
            return

        documents = []
        metadatas = []
        ids = []
        for i, investor_data in enumerate(investor_data_list):
            investor = investor_data.investor
            # stage_focus is normally an Enum (hence `.value`); tolerate a
            # plain string or None rather than crashing after the SQL commit.
            stage = getattr(investor.stage_focus, "value", investor.stage_focus)
            sectors = ", ".join(s.name for s in investor_data.sectors)
            companies = ", ".join(c.name for c in investor_data.portfolio_companies)
            check_size = (
                f"{self._format_money(investor.check_size_lower)} - "
                f"{self._format_money(investor.check_size_upper)}"
            )
            documents.append(
                "\n".join(
                    [
                        f"Investor: {investor.name}",
                        f"Description: {investor.description or 'N/A'}",
                        f"AUM: {self._format_money(investor.aum)}",
                        f"Check Size: {check_size}",
                        f"Geographic Focus: {investor.geographic_focus or 'N/A'}",
                        f"Stage Focus: {stage if stage is not None else 'N/A'}",
                        f"Sectors: {sectors}",
                        f"Portfolio Companies: {companies}",
                    ]
                )
            )
            metadata = {
                "name": investor.name,
                "stage_focus": stage,
                "geographic_focus": investor.geographic_focus,
                "aum": investor.aum,
            }
            # Chroma metadata values are expected to be str/int/float/bool,
            # so drop missing fields instead of passing None.
            metadatas.append({k: v for k, v in metadata.items() if v is not None})
            safe_name = investor.name.replace(" ", "_").replace("/", "_")
            ids.append(f"investor_{i}_{safe_name}")

        if documents:
            try:
                # upsert (not add) so re-importing refreshes documents whose
                # IDs already exist instead of keeping the stale text.
                self.collection.upsert(documents=documents, metadatas=metadatas, ids=ids)
                print(f"Successfully saved {len(documents)} investors to vector database")
            except Exception as e:
                print(f"Error saving to vector database: {e}")

    async def process_csv(
        self, df: pd.DataFrame, batch_size: int = 10, max_concurrent: int = 10
    ) -> List[InvestorData]:
        """Process CSV data in parallel batches and save to databases.

        Args:
            df: Rows to extract investors from.
            batch_size: Rows per LLM call; must be >= 1.
            max_concurrent: Maximum number of batches in flight at once;
                must be >= 1.

        Returns:
            All investors extracted from the successful batches.

        Raises:
            ValueError: If ``batch_size`` or ``max_concurrent`` is below 1
                (``Semaphore(0)`` would otherwise deadlock).
            Exception: Whatever ``_save_to_sql`` re-raises on a database error.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if max_concurrent < 1:
            raise ValueError(f"max_concurrent must be >= 1, got {max_concurrent}")

        batches = [
            df.iloc[start : start + batch_size]
            for start in range(0, len(df), batch_size)
        ]
        semaphore = asyncio.Semaphore(max_concurrent)

        async def process_with_semaphore(batch: pd.DataFrame, batch_idx: int):
            async with semaphore:
                return await self._process_batch(batch, batch_idx)

        batch_results = await asyncio.gather(
            *(process_with_semaphore(batch, idx) for idx, batch in enumerate(batches)),
            return_exceptions=True,
        )

        results: List[InvestorData] = []
        for batch_idx, outcome in enumerate(batch_results):
            # BaseException also covers CancelledError, which gather returns
            # as a result when return_exceptions=True.
            if isinstance(outcome, BaseException):
                print(f"Batch {batch_idx + 1} failed: {outcome}")
                continue
            results.extend(outcome)

        if results:
            print(f"Successfully processed {len(results)} investors")
            await self._save_to_sql(results)
            await self._save_to_vector_db(results)
        return results