import asyncio
import inspect
from typing import List, Optional

import chromadb
import pandas as pd
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI

from db.tables import InvestorTable
from pydantic_schemas import Investor, InvestorList
from settings import settings

# Add these imports for your databases
# from sqlalchemy.ext.asyncio import AsyncSession
# from your_vector_db import VectorDBClient


class InvestorProcessor:
    """Extract structured investor records from a DataFrame with an LLM.

    Rows are rendered to text in batches, sent to a structured-output chat
    model, and the resulting ``Investor`` records are stored through the
    optional SQL session and in a Chroma collection.
    """

    # Single-pass cleanup table for ``str.translate``: newline, carriage
    # return and tab become a space; every other character below U+0020 is
    # removed.
    _CONTROL_CHAR_TABLE = {
        **{code: None for code in range(32)},
        ord("\n"): " ",
        ord("\r"): " ",
        ord("\t"): " ",
    }

    def __init__(
        self,
        sql_session: Optional[object] = None,
        vector_db_client: Optional[object] = None,
    ):
        """Set up the prompt, the LLM and the storage backends.

        Args:
            sql_session: Session object exposing ``add`` and ``commit``.
                When omitted, SQL persistence is skipped.
            vector_db_client: Client exposing ``get_or_create_collection``.
                When omitted, a ``chromadb.PersistentClient`` rooted at
                ``./chroma_db`` is created.
        """
        self.template = (
            "You are an expert data extraction assistant. Extract investor "
            "information from the provided CSV data and return it as a list "
            "of structured records.\n"
            "\n"
            "Given the following CSV data rows:\n"
            "{question}\n"
            "\n"
            "For each row, extract and structure the following fields:\n"
            "- name: The investor's full name\n"
            "- aum: Assets under management (as integer, use 0 if not "
            "available)\n"
            "- check_size: Investment check size (as string)\n"
            "- sector_focus: Sector focus (as string)\n"
            "- stage_focus: Investment stage focus (as string)\n"
            "- region: Geographic region (as string)\n"
            "- investment_thesis: Investment thesis (as string)\n"
            "- investor_description: Description of the investor (as string)\n"
            "\n"
            "Important:\n"
            "- If a field is not available in the data, use appropriate "
            "default values (empty string for text fields, 0 for numbers)\n"
            "- Ensure all text fields are properly escaped and contain no "
            "control characters\n"
            "- Return clean, valid JSON only\n"
            "\n"
            "Return the data as a structured list of investors."
        )
        self.prompt = PromptTemplate(
            template=self.template, input_variables=["question"]
        )

        self.llm = ChatOpenAI(
            api_key=settings.OPENROUTER_API_KEY,
            base_url="https://openrouter.ai/api/v1",
            # The variant suffix was misspelled as ":fre".
            model="openai/gpt-oss-120b:free",
            temperature=0,
        )
        self.structured_llm = self.llm.with_structured_output(InvestorList)

        self.sql_session = sql_session

        # Honour an injected client; previously it was assigned and then
        # unconditionally overwritten by the persistent client.
        if vector_db_client is None:
            vector_db_client = chromadb.PersistentClient(path="./chroma_db")
        self.vector_db_client = vector_db_client
        self.collection = self.vector_db_client.get_or_create_collection(
            name="investor_descriptions",
            metadata={
                "description": "Investor descriptions and investment thesis focus"
            },
        )

    @classmethod
    def _clean_value(cls, value: object) -> str:
        """Return ``str(value)`` with control characters neutralised."""
        return str(value).translate(cls._CONTROL_CHAR_TABLE)

    @classmethod
    def _format_batch(cls, batch: pd.DataFrame) -> str:
        """Render *batch* as one ``Row <label>: key: value, ...`` line per row.

        Cells for which ``pd.notna`` is false are omitted. Each line ends
        with a newline.
        """
        lines = []
        for idx, row in batch.iterrows():
            fields = ", ".join(
                f"{key}: {cls._clean_value(value)}"
                for key, value in row.items()
                if pd.notna(value)
            )
            # Index labels are not guaranteed to be numeric; fall back to
            # the raw label instead of failing the whole batch.
            try:
                label = idx + 1
            except TypeError:
                label = idx
            lines.append(f"Row {label}: {fields}\n")
        return "".join(lines)

    async def _process_batch(
        self, batch: pd.DataFrame, batch_idx: int
    ) -> List[Investor]:
        """Extract investors from a single batch.

        Args:
            batch: Slice of the input frame to process.
            batch_idx: Zero-based batch position, used only in messages.

        Returns:
            The ``investor_list`` of the model's structured response, or an
            empty list when the call fails.
        """
        batch_str = self._format_batch(batch)
        # Wrap the rows in the extraction instructions; the raw rows alone
        # carry no task description for the model.
        prompt_text = self.prompt.format(question=batch_str)

        try:
            print(f"Processing batch {batch_idx + 1}...")
            batch_results = await self.structured_llm.ainvoke(prompt_text)
            return batch_results.investor_list
        except Exception as e:
            # Best effort: one failed batch must not abort the others.
            print(f"Error processing batch {batch_idx + 1}: {e}")
            return []

    async def _save_to_sql(self, investors: List[Investor]) -> None:
        """Add one ``InvestorTable`` row per investor and commit.

        Does nothing when no SQL session was supplied.
        """
        if not self.sql_session:
            return

        for investor in investors:
            db_investor = InvestorTable(
                name=investor.name,
                aum=investor.aum,
                check_size=investor.check_size,
                sector_focus=investor.sector_focus,
                stage_focus=investor.stage_focus,
                region=investor.region,
            )
            self.sql_session.add(db_investor)

        # An async session returns an awaitable from commit(); without
        # awaiting it nothing would be persisted.
        commit_result = self.sql_session.commit()
        if inspect.isawaitable(commit_result):
            await commit_result

    async def _save_to_vector_db(self, investors: List[Investor]) -> None:
        """Store each investor's description and thesis in the collection.

        Does nothing when there is no vector client or no investors.
        """
        if not self.vector_db_client:
            return

        documents = []
        metadatas = []
        ids = []
        for i, investor in enumerate(investors):
            documents.append(
                f"{investor.investor_description}\n"
                f"Investment Thesis: {investor.investment_thesis}"
            )
            metadatas.append({"name": investor.name})
            ids.append(f"investor_{i}_{investor.name.replace(' ', '_')}")

        if documents:
            self.collection.add(documents=documents, metadatas=metadatas, ids=ids)

    async def process_csv(
        self, df: pd.DataFrame, batch_size: int = 10, max_concurrent: int = 10
    ) -> List[Investor]:
        """Process *df* in concurrent batches and persist the results.

        Args:
            df: Input rows, one investor per row.
            batch_size: Number of rows sent to the model per request.
            max_concurrent: Upper bound on simultaneous model requests.

        Returns:
            All investors extracted from the batches that succeeded.

        Raises:
            ValueError: If ``batch_size`` or ``max_concurrent`` is below 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        if max_concurrent < 1:
            # Semaphore(0) would block every batch forever.
            raise ValueError("max_concurrent must be at least 1")

        batches = [
            (df.iloc[start : start + batch_size], start // batch_size)
            for start in range(0, len(df), batch_size)
        ]

        semaphore = asyncio.Semaphore(max_concurrent)

        async def process_with_semaphore(batch_data):
            batch, batch_idx = batch_data
            async with semaphore:
                return await self._process_batch(batch, batch_idx)

        batch_results = await asyncio.gather(
            *[process_with_semaphore(batch_data) for batch_data in batches],
            return_exceptions=True,
        )

        results: List[Investor] = []
        for position, batch_result in enumerate(batch_results):
            # BaseException also covers CancelledError, which is not an
            # Exception subclass and must not be passed to extend().
            if isinstance(batch_result, BaseException):
                print(f"Batch {position + 1} failed: {batch_result!r}")
                continue
            results.extend(batch_result)

        if results:
            await self._save_to_sql(results)
            await self._save_to_vector_db(results)

        return results