diff --git a/README.md b/README.md index 6f7f306..d5890cc 100644 --- a/README.md +++ b/README.md @@ -338,7 +338,7 @@ When `--use-llm` is enabled: OPENROUTER_API_KEY=your_openrouter_api_key_here # Database Configuration (optional, defaults to SQLite) -DATABASE_URL=sqlite:///investors_2.db +DATABASE_URL=sqlite:///investors.db # FastAPI Configuration API_HOST=localhost diff --git a/app/__pycache__/main.cpython-312.pyc b/app/__pycache__/main.cpython-312.pyc index 15c2f23..9dcd943 100644 Binary files a/app/__pycache__/main.cpython-312.pyc and b/app/__pycache__/main.cpython-312.pyc differ diff --git a/app/db/__pycache__/db.cpython-312.pyc b/app/db/__pycache__/db.cpython-312.pyc index d31bfec..51bb0e0 100644 Binary files a/app/db/__pycache__/db.cpython-312.pyc and b/app/db/__pycache__/db.cpython-312.pyc differ diff --git a/app/db/__pycache__/models.cpython-312.pyc b/app/db/__pycache__/models.cpython-312.pyc index 5076c67..e6daac0 100644 Binary files a/app/db/__pycache__/models.cpython-312.pyc and b/app/db/__pycache__/models.cpython-312.pyc differ diff --git a/app/main.py b/app/main.py index ed1caa1..2eb0264 100644 --- a/app/main.py +++ b/app/main.py @@ -6,7 +6,7 @@ from db.db import db_dependency, init_database from fastapi import FastAPI, File, UploadFile from py_schemas import InvestorList from pydantic import BaseModel -from services.openrouter import InvestorProcessor +from services.openrouter_v2 import InvestorProcessor from services.querying import QueryProcessor app = FastAPI() diff --git a/app/services/__pycache__/openrouter_v2.cpython-312.pyc b/app/services/__pycache__/openrouter_v2.cpython-312.pyc new file mode 100644 index 0000000..a386d81 Binary files /dev/null and b/app/services/__pycache__/openrouter_v2.cpython-312.pyc differ diff --git a/app/services/__pycache__/querying.cpython-312.pyc b/app/services/__pycache__/querying.cpython-312.pyc index b6c62aa..ba949b0 100644 Binary files a/app/services/__pycache__/querying.cpython-312.pyc and b/app/services/__pycache__/querying.cpython-312.pyc differ diff --git a/app/services/openrouter_v2.py b/app/services/openrouter_v2.py new file mode 100644 index 0000000..d37120d --- /dev/null +++ b/app/services/openrouter_v2.py @@ -0,0 +1,290 @@ +import asyncio +from typing import List, Optional + +import chromadb +import pandas as pd +from db.models import CompanyTable, InvestorTable, InvestorTeamMember, SectorTable +from langchain_core.prompts import PromptTemplate +from langchain_openai import ChatOpenAI +from py_schemas import InvestorData +from pydantic import BaseModel +from settings import settings + + +class InvestorOutput(BaseModel): + """Schema for LLM structured output""" + + investor_data: InvestorData + + +class InvestorProcessor: + def __init__( + self, + sql_session: Optional[object] = None, + vector_db_client: Optional[object] = None, + ): + self.template = """You are an expert data extraction assistant. Extract investor information from the provided CSV data and return it as a structured record. + +Given the following CSV data row: +{question} + +Extract and structure the following fields for the investor: +- name: The investor's full name +- description: Description of the investor +- aum: Assets under management (as integer, use 0 if not available) +- check_size_lower: Lower bound of investment check size (as integer) +- check_size_upper: Upper bound of investment check size (as integer) +- geographic_focus: Geographic region focus +- stage_focus: Investment stage focus (must be one of: seed, series_a, series_b, series_c, growth, late_stage) +- number_of_investments: Number of investments made (default 0) + +Also extract related data: +- portfolio_companies: List of companies they've invested in +- team_members: List of team members with name, role, email +- sectors: List of sectors they focus on + +Important: +- If a field is not available, use appropriate defaults +- stage_focus must be one of the valid enum values +- Return clean, valid JSON only + +Return the data as a single comprehensive investor data record.""" + + self.prompt = PromptTemplate( + template=self.template, input_variables=["question"] + ) + + self.llm = ChatOpenAI( + api_key=settings.OPENROUTER_API_KEY, + base_url="https://openrouter.ai/api/v1", + model="google/gemini-2.5-flash-lite", + temperature=0, + ) + + self.structured_llm = self.llm.with_structured_output(InvestorOutput) + self.sql_session = sql_session + self.vector_db_client = vector_db_client + + self.vector_db_client = chromadb.PersistentClient(path="./chroma_db") + self.collection = self.vector_db_client.get_or_create_collection( + name="investor_descriptions", + metadata={ + "description": "Investor descriptions and investment thesis focus" + }, + ) + + async def _process_row( + self, row: pd.Series, row_idx: int + ) -> Optional[InvestorData]: + """Process a single row of data""" + # Clean values to remove control characters + cleaned_row = {} + for key, value in row.items(): + if pd.notna(value): + # Convert to string and clean control characters + clean_value = ( + str(value) + .replace("\n", " ") + .replace("\r", " ") + .replace("\t", " ") + ) + # Remove other control characters + clean_value = "".join( + char + for char in clean_value + if ord(char) >= 32 or char in ["\n", "\r", "\t"] + ) + cleaned_row[key] = clean_value + + row_str = ", ".join( + [f"{key}: {value}" for key, value in cleaned_row.items()] + ) + + try: + print(f"Processing row {row_idx + 1}...") + result = await self.structured_llm.ainvoke(row_str) + if result.investor_data: + return result.investor_data + return None + except Exception as e: + print(f"Error processing row {row_idx + 1}: {e}") + return None + + async def _save_to_sql(self, investor_data_list: List[InvestorData]) -> None: + """Save investors and related data to SQL database""" + if not self.sql_session: + return + + try: + for investor_data in investor_data_list: + # Save investor + db_investor = InvestorTable( + name=investor_data.investor.name, + description=investor_data.investor.description, + aum=investor_data.investor.aum, + check_size_lower=investor_data.investor.check_size_lower, + check_size_upper=investor_data.investor.check_size_upper, + geographic_focus=investor_data.investor.geographic_focus, + stage_focus=investor_data.investor.stage_focus, + number_of_investments=investor_data.investor.number_of_investments, + ) + self.sql_session.add(db_investor) + self.sql_session.flush() # Get the ID + + # Save sectors and create associations + for sector_data in investor_data.sectors: + # Check if sector exists, create if not + existing_sector = ( + self.sql_session.query(SectorTable) + .filter(SectorTable.name == sector_data.name) + .first() + ) + + if not existing_sector: + db_sector = SectorTable(name=sector_data.name) + self.sql_session.add(db_sector) + self.sql_session.flush() + # Add sector to investor's sectors + db_investor.sectors.append(db_sector) + else: + # Add existing sector to investor if not already there + if existing_sector not in db_investor.sectors: + db_investor.sectors.append(existing_sector) + + # Save companies and create portfolio associations + for company_data in investor_data.portfolio_companies: + # Check if company exists, create if not + existing_company = ( + self.sql_session.query(CompanyTable) + .filter(CompanyTable.name == company_data.name) + .first() + ) + + if not existing_company: + db_company = CompanyTable( + name=company_data.name, + industry=company_data.industry, + location=company_data.location, + founded_year=company_data.founded_year, + website=company_data.website, + ) + self.sql_session.add(db_company) + self.sql_session.flush() + + # Add to investor's portfolio + db_investor.portfolio_companies.append(db_company) + else: + # Add existing company to portfolio if not already there + if existing_company not in db_investor.portfolio_companies: + db_investor.portfolio_companies.append(existing_company) + + # Save team members + for team_member_data in investor_data.team_members: + # Check if team member exists + existing_member = ( + self.sql_session.query(InvestorTeamMember) + .filter(InvestorTeamMember.email == team_member_data.email) + .first() + ) + + if not existing_member: + db_team_member = InvestorTeamMember( + name=team_member_data.name, + role=team_member_data.role, + email=team_member_data.email, + investor_id=db_investor.id, + ) + self.sql_session.add(db_team_member) + + self.sql_session.commit() + print(f"Successfully saved {len(investor_data_list)} investors to database") + + except Exception as e: + self.sql_session.rollback() + print(f"Error saving to SQL database: {e}") + raise + + async def _save_to_vector_db(self, investor_data_list: List[InvestorData]) -> None: + """Save investors to vector database""" + if not self.vector_db_client: + return + + documents = [] + metadatas = [] + ids = [] + + for i, investor_data in enumerate(investor_data_list): + investor = investor_data.investor + sectors = ", ".join([s.name for s in investor_data.sectors]) + companies = ", ".join([c.name for c in investor_data.portfolio_companies]) + + doc_text = f""" + Investor: {investor.name} + Description: {investor.description or "N/A"} + AUM: ${investor.aum:,} + Check Size: ${investor.check_size_lower:,} - ${investor.check_size_upper:,} + Geographic Focus: {investor.geographic_focus} + Stage Focus: {investor.stage_focus.value} + Sectors: {sectors} + Portfolio Companies: {companies} + """.strip() + + documents.append(doc_text) + metadatas.append( + { + "name": investor.name, + "stage_focus": investor.stage_focus.value, + "geographic_focus": investor.geographic_focus, + "aum": investor.aum, + } + ) + ids.append( + f"investor_{i}_{investor.name.replace(' ', '_').replace('/', '_')}" + ) + + if documents: + try: + self.collection.add(documents=documents, metadatas=metadatas, ids=ids) + print( + f"Successfully saved {len(documents)} investors to vector database" + ) + except Exception as e: + print(f"Error saving to vector database: {e}") + + async def process_csv( + self, df: pd.DataFrame, max_concurrent: int = 10 + ) -> List[InvestorData]: + """Process CSV data one row at a time and save to databases""" + results = [] + + # Create semaphore for concurrency control + semaphore = asyncio.Semaphore(max_concurrent) + + async def process_row_with_semaphore(row_data): + row, row_idx = row_data + async with semaphore: + return await self._process_row(row, row_idx) + + # Create row tasks + row_tasks = [] + for idx, row in df.iterrows(): + row_tasks.append((row, idx)) + + # Execute all rows concurrently + row_results = await asyncio.gather( + *[process_row_with_semaphore(row_data) for row_data in row_tasks], + return_exceptions=True, + ) + + # Collect results, filtering out exceptions and None values + for row_result in row_results: + if not isinstance(row_result, Exception) and row_result is not None: + results.append(row_result) + + # Save to databases + if results: + print(f"Successfully processed {len(results)} investors") + await self._save_to_sql(results) + await self._save_to_vector_db(results) + + return results diff --git a/app/services/querying.py b/app/services/querying.py index da1c6a3..e76f94f 100644 --- a/app/services/querying.py +++ b/app/services/querying.py @@ -14,7 +14,7 @@ from sqlalchemy.orm import selectinload # Connect to SQLite prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt") -db = SQLDatabase.from_uri("sqlite:///investors_2.db") +db = SQLDatabase.from_uri("sqlite:///investors.db") system_message = ( prompt_template.format(dialect="SQLite", top_k=5) + "\n Get answers from the Sql database and the vector database"