338 lines
12 KiB
Python
338 lines
12 KiB
Python
import asyncio
|
|
import os
|
|
from typing import Optional
|
|
|
|
import pandas as pd
|
|
from db.db import get_db_session
|
|
from db.models import (
|
|
CompanyMember,
|
|
CompanyTable,
|
|
InvestorMember,
|
|
InvestorTable,
|
|
SectorTable,
|
|
)
|
|
from langchain_openai import ChatOpenAI
|
|
from schemas.py_schemas import CompanyData, InvestorData
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
class InvestorProcessor:
    """Turn tabular investor/company rows into structured records and store them.

    Each DataFrame row is flattened into a ``"column: value, ..."`` string and
    sent to a chat model wrapped with structured output (``InvestorData`` or
    ``CompanyData``).  The parsed records are returned and, optionally, written
    to the database through the session returned by ``get_db_session()``.
    """

    # Translation table used by ``_clean_value``: tab, newline and carriage
    # return become a single space; every other code point below 32 is removed.
    _CONTROL_CHAR_TABLE = {**dict.fromkeys(range(32)), 9: " ", 10: " ", 13: " "}

    def __init__(self):
        """Create the chat model and its two structured-output wrappers.

        The API key is read from the ``OPENROUTER_API_KEY`` environment
        variable.
        """
        self.llm = ChatOpenAI(
            api_key=os.getenv("OPENROUTER_API_KEY"),
            base_url="https://openrouter.ai/api/v1",
            model="openai/gpt-4o-mini",
            temperature=0,
        )

        self.investor_structured_llm = self.llm.with_structured_output(InvestorData)
        self.company_structured_llm = self.llm.with_structured_output(CompanyData)

    def _get_or_create_sector(self, db: Session, sector_name: str) -> SectorTable:
        """Return the sector named ``sector_name``, creating it if it is missing.

        A newly created sector is added to the session and flushed (not
        committed); committing is left to the caller.
        """
        sector = db.query(SectorTable).filter(SectorTable.name == sector_name).first()
        if not sector:
            sector = SectorTable(name=sector_name)
            db.add(sector)
            db.flush()  # Get the ID without committing
        return sector

    def _save_investor_to_db(
        self, db: Session, investor_data: InvestorData
    ) -> InvestorTable:
        """Add an investor with its team, sectors and portfolio to the session.

        Nothing is committed here; the caller decides whether to commit or
        roll back.

        Args:
            db: Open SQLAlchemy session.
            investor_data: Parsed investor record to persist.

        Returns:
            The newly created ``InvestorTable`` row.
        """
        details = investor_data.investor
        investor = InvestorTable(
            name=details.name,
            description=details.description,
            aum=details.aum,
            check_size_lower=details.check_size_lower,
            check_size_upper=details.check_size_upper,
            geographic_focus=details.geographic_focus,
            stage_focus=details.stage_focus,
            number_of_investments=details.number_of_investments,
        )
        db.add(investor)
        db.flush()  # Get the ID so team members can reference it

        # Team members
        for member_data in investor_data.team_members:
            db.add(
                InvestorMember(
                    name=member_data.name,
                    role=member_data.role,
                    email=member_data.email,
                    investor_id=investor.id,
                )
            )

        # Sectors
        for sector_data in investor_data.sectors:
            investor.sectors.append(self._get_or_create_sector(db, sector_data.name))

        # Portfolio companies: wrap each bare company schema in a CompanyData
        # with empty related collections so the company save path can be reused.
        for company_schema in investor_data.portfolio_companies:
            company_data = CompanyData(
                company=company_schema,
                sectors=[],
                members=[],
                investors=[],
            )
            company = self._save_company_to_db(db, company_data, skip_investors=True)
            investor.portfolio_companies.append(company)

        return investor

    def _save_company_to_db(
        self, db: Session, company_data: CompanyData, skip_investors: bool = False
    ) -> CompanyTable:
        """Add a company with its members, sectors and investors to the session.

        If a company with the same name already exists it is returned as-is
        and nothing is modified.  Nothing is committed here.

        Args:
            db: Open SQLAlchemy session.
            company_data: Parsed company record to persist.
            skip_investors: When true, do not link investors (used when the
                company is being saved as part of an investor's portfolio, to
                avoid circular references).

        Returns:
            The existing or newly created ``CompanyTable`` row.
        """
        details = company_data.company

        existing_company = (
            db.query(CompanyTable).filter(CompanyTable.name == details.name).first()
        )
        if existing_company:
            return existing_company

        company = CompanyTable(
            name=details.name,
            industry=details.industry,
            location=details.location,
            description=details.description,
            founded_year=details.founded_year,
            website=details.website,
        )
        db.add(company)
        db.flush()  # Get the ID so members can reference it

        # Company members (entries without a name are skipped)
        for member_data in company_data.members:
            if member_data.name:
                db.add(
                    CompanyMember(
                        name=member_data.name,
                        linkedin=member_data.linkedin,
                        role=member_data.role,
                        company_id=company.id,
                    )
                )

        # Sectors
        for sector_data in company_data.sectors:
            company.sectors.append(self._get_or_create_sector(db, sector_data.name))

        # Investors: only link investors that already exist (matched by name);
        # unknown investors are not created here.
        if not skip_investors:
            for investor_data in company_data.investors:
                existing_investor = (
                    db.query(InvestorTable)
                    .filter(InvestorTable.name == investor_data.name)
                    .first()
                )
                if existing_investor:
                    company.investors.append(existing_investor)

        return company

    @staticmethod
    def _clean_value(value) -> str:
        """Return ``str(value)`` with ASCII control characters neutralised.

        Tab, newline and carriage return are replaced by a space; all other
        characters with a code point below 32 are removed.
        """
        return str(value).translate(InvestorProcessor._CONTROL_CHAR_TABLE)

    @staticmethod
    def _row_to_prompt(row: pd.Series) -> str:
        """Flatten a row into ``"key: value, key: value"``, skipping NA cells."""
        return ", ".join(
            f"{key}: {InvestorProcessor._clean_value(value)}"
            for key, value in row.items()
            if pd.notna(value)
        )

    async def _process_row(
        self, row: pd.Series, row_idx: int, is_investor: bool = True
    ) -> Optional[dict]:
        """Run one row through the structured-output model.

        Args:
            row: The DataFrame row to process.
            row_idx: Index label of the row; only used in progress messages.
            is_investor: Selects the investor model when true, otherwise the
                company model.

        Returns:
            The model result as a plain dict (``model_dump()``), or ``None``
            when the model returned nothing or an error occurred.  Errors are
            printed, not raised.
        """
        row_str = self._row_to_prompt(row)
        try:
            print(f"Processing row {row_idx + 1}...")
            structured_llm = (
                self.investor_structured_llm
                if is_investor
                else self.company_structured_llm
            )
            result = await structured_llm.ainvoke(row_str)
            if result:
                return result.model_dump()
            return None
        except Exception as e:
            print(f"Error processing row {row_idx + 1}: {e}")
            return None

    async def _parse_rows(
        self,
        df,
        *,
        is_investor: bool,
        save_to_db: bool,
        start_row: int,
        batch_size: int,
    ) -> list:
        """Shared batch loop behind ``parse_investors`` and ``parse_companies``."""
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        kind = "investor" if is_investor else "company"
        save = self._save_investor_to_db if is_investor else self._save_company_to_db

        parsed = []
        df = df.iloc[start_row:]
        db = get_db_session() if save_to_db else None

        try:
            rows = list(df.iterrows())
            total_batches = (len(rows) + batch_size - 1) // batch_size

            for batch_no, offset in enumerate(range(0, len(rows), batch_size), start=1):
                batch = rows[offset : offset + batch_size]

                # Rows of one batch are sent to the model concurrently.
                tasks = [
                    self._process_row(row, idx, is_investor=is_investor)
                    for idx, row in batch
                ]
                batch_results = await asyncio.gather(*tasks, return_exceptions=True)

                for (idx, _row), result in zip(batch, batch_results):
                    # BaseException (not Exception) so that a cancelled task's
                    # CancelledError is reported instead of being stored as a record.
                    if isinstance(result, BaseException):
                        print(f"Error processing row {idx}: {result}")
                        if db is not None:
                            db.rollback()
                        continue

                    if not result:
                        continue

                    # _process_row returns a dict; rebuild the schema object from it.
                    if isinstance(result, dict):
                        schema_cls = InvestorData if is_investor else CompanyData
                        record = schema_cls(**result)
                    else:
                        record = result

                    parsed.append(record)

                    # Each record is committed on its own so one failed save
                    # does not discard the others.
                    if db is not None:
                        try:
                            saved = save(db, record)
                            db.commit()
                            print(f"✅ Saved {kind} '{saved.name}' to database")
                        except Exception as e:
                            db.rollback()
                            print(f"❌ Failed to save {kind} to database: {e}")

                print(f"Completed batch {batch_no} of {total_batches}")

        except Exception as e:
            # Must not reference per-row variables here: the failure may have
            # happened before any row was reached.
            print(f"Error in batch processing: {e}")
            if db is not None:
                db.rollback()
        finally:
            if db is not None:
                db.close()

        return parsed

    async def parse_investors(
        self,
        df,
        save_to_db: bool = True,
        *,
        start_row: int = 20,
        batch_size: int = 20,
    ):
        """Parse investors from a DataFrame and optionally save them.

        Args:
            df: DataFrame with one investor per row.
            save_to_db: When true, each parsed investor is saved and committed
                individually.
            start_row: Number of leading rows to skip (positional).
                NOTE(review): the default of 20 preserves the previously
                hard-coded ``df[20:]`` slice - confirm that skipping the first
                20 rows is intended.
            batch_size: Number of rows sent to the model concurrently.

        Returns:
            List of ``InvestorData`` objects that were parsed, including any
            whose database save failed.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        return await self._parse_rows(
            df,
            is_investor=True,
            save_to_db=save_to_db,
            start_row=start_row,
            batch_size=batch_size,
        )

    async def parse_companies(
        self,
        df,
        save_to_db: bool = True,
        *,
        start_row: int = 20,
        batch_size: int = 20,
    ):
        """Parse companies from a DataFrame and optionally save them.

        Args:
            df: DataFrame with one company per row.
            save_to_db: When true, each parsed company is saved and committed
                individually.
            start_row: Number of leading rows to skip (positional).
                NOTE(review): the default of 20 preserves the previously
                hard-coded ``df[20:]`` slice - confirm that skipping the first
                20 rows is intended.
            batch_size: Number of rows sent to the model concurrently.

        Returns:
            List of ``CompanyData`` objects that were parsed, including any
            whose database save failed.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        return await self._parse_rows(
            df,
            is_investor=False,
            save_to_db=save_to_db,
            start_row=start_row,
            batch_size=batch_size,
        )
|
|
|
|
|
|
# async def main():
|
|
# """Main execution function"""
|
|
# # Initialize database tables
|
|
# print("🔧 Initializing database...")
|
|
# init_database()
|
|
|
|
# # Create processor
|
|
# processor = InvestorProcessor()
|
|
|
|
# print("📊 Processing companies...")
|
|
# companies = await processor.parse_companies(
|
|
# "data/19 Companies data.csv", save_to_db=True
|
|
# )
|
|
# print(f"Processed {len(companies)} companies")
|
|
|
|
# print("\n💰 Processing investors...")
|
|
# investors = await processor.parse_investors(
|
|
# "data/19 Investors data.csv", save_to_db=True
|
|
# )
|
|
# print(f"Processed {len(investors)} investors")
|
|
# print("\n✨ Processing complete!")
|
|
|
|
|
|
# if __name__ == "__main__":
|
|
# asyncio.run(main())
|