made version 2

This commit is contained in:
bolade
2025-09-25 17:00:38 +01:00
parent b1b1c5ea1e
commit 0f7beca5e1
42 changed files with 660 additions and 2036 deletions
-577
View File
@@ -1,577 +0,0 @@
# LLM-Powered Investor & Company Management API
A comprehensive FastAPI-based system for managing investor and company data with LLM-powered CSV parsing, semantic search, and advanced filtering capabilities.
## Features
- **FastAPI REST API**: Modern, auto-documented API with OpenAPI/Swagger support
- **CSV Data Processing**: Parse complex investor data from CSV files using LLM assistance
- **Dual Database Storage**: Structured data in SQL database and semantic search via ChromaDB
- **Natural Language Queries**: AI-powered query processing for complex investor searches
- **Advanced Filtering**: Filter investors and companies by multiple criteria
- **Relationship Management**: Many-to-many relationships between investors, companies, and sectors
- **Auto-Generated Documentation**: Interactive API docs at `/docs`
## Architecture
### Components
1. **FastAPI Application (`app/main.py`)**: Main API server with route configuration
2. **Database Models (`app/db/models.py`)**: SQLAlchemy models for investors, companies, sectors
3. **Pydantic Schemas (`app/py_schemas.py`)**: Request/response validation and serialization
4. **API Routes**:
- `app/api/investors.py`: Investor CRUD operations and filtering
- `app/api/companies.py`: Company CRUD operations and filtering
5. **Services**:
- `app/services/openrouter.py`: LLM-powered CSV processing
- `app/services/querying.py`: Natural language query processing
6. **Database (`app/db/`)**: Database connection, models, and schemas
### Data Flow
```
CSV Upload → LLM Processing → Data Extraction → SQL Storage → Vector Storage → API Endpoints
Natural Language Query → AI Analysis → Database Filtering → Structured Response
```
## Installation
### Prerequisites
- Python 3.12+
- FastAPI and dependencies
### Setup
1. Clone the repository and navigate to the project directory:
```bash
cd /path/to/anton_wireframe
```
2. Install dependencies:
```bash
pip install -r requirements.txt
```
3. Configure environment variables:
```bash
cp .env.example .env
# Edit .env and add your OpenRouter API key for LLM features
```
4. Initialize the database:
```bash
cd app
python -c "from db.db import init_database; init_database()"
```
5. Start the API server:
```bash
cd app
uvicorn main:app --reload --host localhost --port 8000
```
The API will be available at:
- **API Base**: http://localhost:8000
- **Interactive Docs**: http://localhost:8000/docs
- **ReDoc**: http://localhost:8000/redoc
## Database Schema
### SQL Database (SQLite)
#### Investors Table
- **Basic Info**: name, description, geographic_focus
- **Investment Data**: aum, check_size_lower, check_size_upper
- **Stage Focus**: investment stage (SEED, SERIES_A, etc.)
- **Relationships**: Many-to-many with companies and sectors
- **Team**: One-to-many with team members
- **Metadata**: created_at, updated_at timestamps
#### Companies Table
- **Basic Info**: name, industry, location
- **Details**: founded_year, website
- **Relationships**: Many-to-many with investors
- **Metadata**: created_at, updated_at timestamps
#### Association Tables
- **investor_companies**: Links investors to their portfolio companies
- **investor_sectors**: Links investors to their focus sectors
- **investor_team**: Team member details for each investor
#### Supporting Tables
- **sectors**: Investment focus areas (fintech, healthcare, etc.)
### Vector Database (ChromaDB)
Stores embeddings for semantic search of:
- Investor descriptions
- Investment thesis focus areas
- Combined investor profiles
## API Usage
### Interactive Documentation
Visit http://localhost:8000/docs for the auto-generated Swagger UI where you can:
- Explore all endpoints
- Test API calls directly
- View request/response schemas
- See example requests
### Core Endpoints
#### Investor Management
```bash
# Get all investors with relationships
GET /investors
# Filter investors by criteria
GET /investors/filter?stage=GROWTH&geography=US&sector=fintech&min_check_size=1000000
# Get specific investor
GET /investors/{investor_id}
# Create new investor
POST /investors
{
"name": "Example VC",
"description": "Early stage fintech investor",
"aum": 50000000,
"check_size_lower": 100000,
"check_size_upper": 2000000,
"geographic_focus": "US",
"stage_focus": "SEED",
"number_of_investments": 25
}
# Update investor
PUT /investors/{investor_id}
# Delete investor
DELETE /investors/{investor_id}
```
#### Company Management
```bash
# Get all companies with investor relationships
GET /companies
# Filter companies by criteria
GET /companies/filter?industry=fintech&location=San Francisco&founded_after=2015
# Get specific company
GET /companies/{company_id}
# Create new company
POST /companies
{
"name": "Example Startup",
"industry": "fintech",
"location": "San Francisco",
"founded_year": 2020,
"website": "https://example.com"
}
# Update company
PUT /companies/{company_id}
# Delete company
DELETE /companies/{company_id}
```
#### CSV Processing
```bash
# Upload and process CSV file
POST /parse-csv
Content-Type: multipart/form-data
File: investors.csv
```
#### Natural Language Queries
```bash
# Query investors using natural language
POST /query
{
"question": "Show me growth stage fintech investors in Silicon Valley with check sizes over $1 million"
}
```
### Advanced Filtering Examples
#### Investor Filters
```bash
# Early stage investors in Europe
GET /investors/filter?stage=SEED&geography=Europe
# High AUM growth investors
GET /investors/filter?stage=GROWTH&min_aum=100000000
# Healthcare investors with large checks
GET /investors/filter?sector=healthcare&min_check_size=5000000
# Specific geographic focus
GET /investors/filter?geography=Silicon Valley
```
#### Company Filters
```bash
# Recent fintech companies
GET /companies/filter?industry=fintech&founded_after=2020
# Companies with websites
GET /companies/filter?has_website=true
# Companies backed by specific investor
GET /companies/filter?investor_name=Sequoia
# Location-based filtering
GET /companies/filter?location=New York
```
### Response Format
All endpoints return structured JSON with full relationship data:
```json
{
"investor": {
"id": 1,
"name": "Example VC",
"description": "Early stage investor",
"aum": 50000000,
"check_size_lower": 100000,
"check_size_upper": 2000000,
"geographic_focus": "US",
"stage_focus": "SEED",
"number_of_investments": 25
},
"portfolio_companies": [
{
"id": 1,
"name": "StartupCo",
"industry": "fintech",
"location": "San Francisco"
}
],
"team_members": [
{
"id": 1,
"name": "John Partner",
"role": "Managing Partner",
"email": "john@examplevc.com"
}
],
"sectors": [
{
"id": 1,
"name": "fintech"
}
]
}
```
## Data Processing Pipeline
### 1. CSV Parsing
- Reads CSV with pandas
- Handles nested JSON fields in columns
- Validates data with Pydantic models
### 2. JSON Field Processing
- Direct parsing for well-formed JSON
- LLM-assisted cleaning for malformed JSON (when enabled)
- Graceful fallback to empty objects
### 3. Data Extraction
Extracts key fields:
- Company name and website
- Investor description
- Investment thesis/focus areas
- Headquarters location
- Assets Under Management (AUM)
- Fund information
### 4. LLM Enhancement (Optional)
When `--use-llm` is enabled:
- Standardizes investor descriptions
- Normalizes investment focus areas
- Cleans headquarters location format
- Repairs malformed JSON data
### 5. Dual Storage
- **SQL Database**: Structured, queryable data
- **Vector Database**: Semantic search capabilities
## Configuration
### Environment Variables (.env)
```bash
# OpenRouter API Configuration (required for LLM features)
OPENROUTER_API_KEY=your_openrouter_api_key_here
# Database Configuration (optional, defaults to SQLite)
DATABASE_URL=sqlite:///investors.db
# FastAPI Configuration
API_HOST=localhost
API_PORT=8000
```
### LLM Configuration
- **Provider**: OpenRouter (supports multiple models)
- **Default Model**: google/gemini-2.5-flash-lite
- **Temperature**: 0.3 for enhancement, 0 for structured data
- **Fallback**: Graceful degradation when API unavailable
## Natural Language Query Processing
The system supports intelligent natural language queries that automatically extract filters and search criteria:
### Query Examples
```bash
# Stage-based queries
"Show me seed stage investors"
"Find growth stage VCs"
# Geographic queries
"Investors in Silicon Valley"
"European venture capital firms"
# Sector-specific queries
"Fintech investors"
"Healthcare and biotech VCs"
# Size-based queries
"Investors with $5M+ check sizes"
"High AUM growth investors"
# Combined queries
"Growth stage fintech investors in the US with check sizes over $1 million"
"European healthcare investors focusing on early stage"
```
### Query Processing Features
- **Automatic Filter Extraction**: Detects investment stages, geographies, sectors, and check sizes
- **Semantic Understanding**: Uses AI to interpret complex queries
- **Database Integration**: Combines AI analysis with efficient SQL filtering
- **Complete Relationships**: Returns full investor data with portfolio companies, team members, and sectors
### Query Response
The `/query` endpoint returns a structured `InvestorList` with complete relationship data, making it easy to get comprehensive information about matching investors.
## Error Handling
### API Error Responses
The API provides clear HTTP status codes and error messages:
```json
// 404 Not Found
{
"detail": "Investor not found"
}
// 422 Validation Error
{
"detail": [
{
"loc": ["body", "stage_focus"],
"msg": "value is not a valid enumeration member",
"type": "type_error.enum"
}
]
}
```
### Robust Processing
- **Data Validation**: Pydantic models ensure data integrity
- **Relationship Management**: Automatic handling of foreign key constraints
- **LLM Fallbacks**: Graceful degradation when AI services unavailable
- **Transaction Safety**: Database rollbacks on errors
- **Comprehensive Logging**: Detailed error tracking and debugging
### Common Issues and Solutions
1. **Invalid Enum Values**
- Solution: Use uppercase enum values (SEED, GROWTH, etc.)
- Check: Investment stages must match defined enum
2. **Missing OpenRouter API Key**
- Solution: Set OPENROUTER_API_KEY in environment
- Fallback: CSV processing continues without LLM enhancement
3. **Database Connection Issues**
- Solution: Verify DATABASE_URL configuration
- Default: Uses SQLite (no external dependencies)
4. **Relationship Errors**
- Solution: Ensure proper foreign key relationships
- Check: Use existing sector/company IDs or create new ones
## Performance
### Benchmarks (Approximate)
- **API Response Time**: <200ms for standard queries
- **Database Queries**: <50ms for filtered searches with relationships
- **CSV Processing**: ~5-15 seconds per row (depends on LLM API latency)
- **Natural Language Queries**: ~2-5 seconds (AI processing + database query)
- **Vector Search**: <100ms for semantic similarity queries
### Optimization Features
1. **Eager Loading**: Efficient relationship loading with `selectinload()`
2. **Query Optimization**: Smart filtering to reduce database load
3. **Caching**: Database connection pooling and session management
4. **Pagination**: Built-in limits to prevent overwhelming responses
5. **Async Processing**: FastAPI async capabilities for better performance
### Production Recommendations
1. **Database**: Consider PostgreSQL for production workloads
2. **Caching**: Add Redis for frequently accessed data
3. **Load Balancing**: Deploy multiple API instances behind a load balancer
4. **Monitoring**: Implement logging and metrics collection
5. **Rate Limiting**: Add API rate limiting for public endpoints
## File Structure
```
anton_wireframe/
├── app/
│ ├── main.py # FastAPI application and main endpoints
│ ├── py_schemas.py # Pydantic models for validation
│ ├── settings.py # Configuration management
│ ├── api/
│ │ ├── __init__.py
│ │ ├── investors.py # Investor CRUD and filtering endpoints
│ │ └── companies.py # Company CRUD and filtering endpoints
│ ├── db/
│ │ ├── __init__.py
│ │ ├── db.py # Database connection and session management
│ │ ├── models.py # SQLAlchemy database models
│ │ └── new_schema.py # Additional schema definitions
│ └── services/
│ ├── __init__.py
│ ├── openrouter.py # LLM-powered CSV processing
│ ├── querying.py # Natural language query processing
│ └── langgraph_agent.py # AI agent configuration
├── chroma_db/ # Vector database directory
├── requirements.txt # Python dependencies
├── README.md # This documentation
└── .env # Environment configuration
```
## Example Usage Scenarios
### 1. Upload and Process Investor Data
```bash
# Upload CSV file via API
curl -X POST "http://localhost:8000/parse-csv" \
-H "Content-Type: multipart/form-data" \
-F "file=@investors.csv"
```
### 2. Find Specific Investors
```bash
# Natural language search
curl -X POST "http://localhost:8000/query" \
-H "Content-Type: application/json" \
-d '{"question": "Show me growth stage fintech investors in Silicon Valley with check sizes over $2 million"}'
# Structured filtering
curl "http://localhost:8000/investors/filter?stage=GROWTH&sector=fintech&geography=Silicon%20Valley&min_check_size=2000000"
```
### 3. Company Research
```bash
# Find companies in specific sector
curl "http://localhost:8000/companies/filter?industry=fintech&founded_after=2020"
# Find companies backed by specific investor
curl "http://localhost:8000/companies/filter?investor_name=Sequoia"
```
### 4. Investment Analysis
```bash
# Get investor with full portfolio
curl "http://localhost:8000/investors/1"
# Find all companies in a specific location
curl "http://localhost:8000/companies/filter?location=San%20Francisco"
```
## Development
### Running in Development Mode
```bash
cd app
uvicorn main:app --reload --host localhost --port 8000
```
### Testing the API
1. **Interactive Testing**: Visit http://localhost:8000/docs
2. **Manual Testing**: Use curl or Postman with the examples above
3. **Database Inspection**: Use SQLite browser to inspect `investors_2.db`
### Adding New Features
1. **New Endpoints**: Add routes to `api/investors.py` or `api/companies.py`
2. **New Models**: Update `db/models.py` and `py_schemas.py`
3. **New Filters**: Extend filtering logic in route handlers
4. **New LLM Features**: Modify `services/openrouter.py` or `services/querying.py`
## License
This project is part of the MKD Anton Wireframe system.
## Support
For issues and questions:
1. Check logs for detailed error messages
2. Verify environment configuration
3. Test with limited datasets first
4. Review CSV data format requirements
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
-46
View File
@@ -1,46 +0,0 @@
from sqlalchemy.orm import Session
from db.models import InvestorTable
from db.db import get_db
def update_stage_focus_values():
"""Update existing stage_focus values from lowercase to uppercase"""
db = next(get_db())
try:
# Mapping of old lowercase values to new uppercase values
stage_mappings = {
'seed': 'SEED',
'series_a': 'SERIES_A',
'series_b': 'SERIES_B',
'series_c': 'SERIES_C',
'growth': 'GROWTH',
'late_stage': 'LATE_STAGE'
}
updated_count = 0
for old_value, new_value in stage_mappings.items():
# Update records with the old value
result = db.query(InvestorTable).filter(
InvestorTable.stage_focus == old_value
).update(
{InvestorTable.stage_focus: new_value},
synchronize_session=False
)
updated_count += result
print(f"Updated {result} records from '{old_value}' to '{new_value}'")
db.commit()
print(f"Successfully updated {updated_count} total records")
except Exception as e:
db.rollback()
print(f"Error updating stage_focus values: {e}")
raise
finally:
db.close()
# Run the update
if __name__ == "__main__":
update_stage_focus_values()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+5 -1
View File
@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session, sessionmaker
Base = declarative_base()
# Database configuration
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///investors.db")
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./investors.db")
# Create engine
engine = create_engine(DATABASE_URL, echo=False)
@@ -38,3 +38,7 @@ def init_database():
def get_session_sync() -> Session:
"""Get a database session for synchronous operations"""
return SessionLocal()
def get_db_session():
"""Get a database session for direct use."""
return SessionLocal()
+54 -30
View File
@@ -1,11 +1,17 @@
import datetime
import enum
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Table, Text
from sqlalchemy.orm import relationship
from db.db import Base
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Table, Text, func
from sqlalchemy.orm import declarative_mixin, relationship
from sqlalchemy.types import Enum
from db.db import Base
@declarative_mixin
class TimestampMixin:
created_at = Column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
class InvestmentStage(enum.Enum):
@@ -16,6 +22,7 @@ class InvestmentStage(enum.Enum):
GROWTH = "GROWTH"
LATE_STAGE = "LATE_STAGE"
# Association table for many-to-many relationship between investors and companies
investor_company_association = Table(
"investor_companies",
@@ -34,7 +41,15 @@ investor_sector_association = Table(
)
class InvestorTable(Base):
company_sector_association = Table(
"company_sector",
Base.metadata,
Column("company_id", Integer, ForeignKey("companies.id")),
Column("sector_id", Integer, ForeignKey("sectors.id")),
)
class InvestorTable(Base, TimestampMixin):
__tablename__ = "investors"
id = Column(Integer, primary_key=True, index=True)
@@ -46,12 +61,6 @@ class InvestorTable(Base):
geographic_focus = Column(String, nullable=False)
stage_focus = Column(Enum(InvestmentStage), nullable=False)
number_of_investments = Column(Integer, default=0)
created_at = Column(DateTime, default=datetime.datetime.now(datetime.UTC))
updated_at = Column(
DateTime,
default=datetime.datetime.now(datetime.UTC),
onupdate=datetime.datetime.now(datetime.UTC),
)
# Relationship to portfolio companies
portfolio_companies = relationship(
@@ -59,7 +68,7 @@ class InvestorTable(Base):
secondary=investor_company_association,
back_populates="investors",
)
team_members = relationship("InvestorTeamMember", back_populates="investor")
team_members = relationship("InvestorMember", back_populates="investor")
sectors = relationship(
"SectorTable",
secondary=investor_sector_association,
@@ -67,22 +76,29 @@ class InvestorTable(Base):
)
class CompanyTable(Base):
class InvestorMember(Base, TimestampMixin):
__tablename__ = "investor_members"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False)
role = Column(String, nullable=False)
email = Column(String, nullable=False)
investor_id = Column(Integer, ForeignKey("investors.id"))
investor = relationship("InvestorTable", back_populates="team_members")
class CompanyTable(Base, TimestampMixin):
__tablename__ = "companies"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False)
industry = Column(String, nullable=False)
location = Column(String, nullable=False)
description = Column(String, nullable=True)
founded_year = Column(Integer, nullable=True)
website = Column(String, nullable=True)
created_at = Column(DateTime, default=datetime.datetime.now(datetime.UTC))
updated_at = Column(
DateTime,
default=datetime.datetime.now(datetime.UTC),
onupdate=datetime.datetime.now(datetime.UTC),
)
members = relationship("CompanyMember", back_populates="company")
# Relationship back to investors
investors = relationship(
"InvestorTable",
@@ -90,8 +106,23 @@ class CompanyTable(Base):
back_populates="portfolio_companies",
)
sectors = relationship(
"SectorTable", secondary=company_sector_association, back_populates="companies"
)
class SectorTable(Base):
class CompanyMember(Base, TimestampMixin):
__tablename__ = "company_members"
id = Column(Integer, primary_key=True)
name = Column(String)
linkedin = Column(String)
role = Column(String)
company_id = Column(Integer, ForeignKey("companies.id"), nullable=False)
company = relationship("CompanyTable", back_populates="members")
class SectorTable(Base, TimestampMixin):
__tablename__ = "sectors"
id = Column(Integer, primary_key=True, index=True)
@@ -104,13 +135,6 @@ class SectorTable(Base):
back_populates="sectors",
)
class InvestorTeamMember(Base):
__tablename__ = "investor_team"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False)
role = Column(String, nullable=False)
email = Column(String, nullable=False)
investor_id = Column(Integer, ForeignKey("investors.id"))
investor = relationship("InvestorTable", back_populates="team_members")
companies = relationship(
"CompanyTable", secondary=company_sector_association, back_populates="sectors"
)
-115
View File
@@ -1,115 +0,0 @@
import json
from typing import List, Optional
from pydantic import BaseModel
from sqlalchemy import JSON, Column, DateTime, Integer, String, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import func
Base = declarative_base()
class Investor(Base):
__tablename__ = "investors"
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String(500), nullable=False)
website = Column(String(1000))
# Core investment information
investor_description = Column(Text)
investment_thesis_focus = Column(JSON) # List of focus areas
headquarters = Column(String(1000))
# AUM information
aum_amount = Column(String(200))
aum_as_of_date = Column(String(100))
aum_source_url = Column(String(1000))
# Fund information
funds_info = Column(JSON) # Complex fund data
# Raw data columns for reference
crunchbase_urls = Column(Text)
crunchbase_extract = Column(Text)
linkedin_profile = Column(Text)
source_truth_profile = Column(Text)
# Metadata
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
def __repr__(self):
return f"<Investor(name='{self.name}', website='{self.website}')>"
# Pydantic models for data validation and parsing
class AUMInfo(BaseModel):
aumAmount: Optional[str] = None
asOfDate: Optional[str] = None
sourceUrl: Optional[str] = None
class FundInfo(BaseModel):
fundName: Optional[str] = None
fundSize: Optional[str] = None
vintage: Optional[str] = None
status: Optional[str] = None
description: Optional[str] = None
class InvestorProfile(BaseModel):
websiteURL: Optional[str] = None
investorDescription: Optional[str] = None
investmentThesisFocus: Optional[List[str]] = None
headquarters: Optional[str] = None
overallAssetsUnderManagement: Optional[AUMInfo] = None
funds: Optional[List[FundInfo]] = None
class CSVRow(BaseModel):
name: str
website: Optional[str] = None
investment_firm_profile: Optional[str] = None
crunchbase_linkedin_urls: Optional[str] = None
crunchbase_firm_extract: Optional[str] = None
linkedin_investment_profile: Optional[str] = None
source_of_truth_profile: Optional[str] = None
def get_combined_description(self) -> str:
"""Combine all description fields for vector embedding"""
descriptions = []
if self.investment_firm_profile:
try:
profile_data = json.loads(self.investment_firm_profile)
if isinstance(profile_data, dict):
desc = profile_data.get("investorDescription", "")
if desc:
descriptions.append(desc)
except (json.JSONDecodeError, TypeError):
pass
if self.crunchbase_firm_extract:
descriptions.append(self.crunchbase_firm_extract)
if self.linkedin_investment_profile:
descriptions.append(self.linkedin_investment_profile)
if self.source_of_truth_profile:
descriptions.append(self.source_of_truth_profile)
return " ".join(descriptions)
def get_investment_focus(self) -> List[str]:
"""Extract investment thesis focus"""
if self.investment_firm_profile:
try:
profile_data = json.loads(self.investment_firm_profile)
if isinstance(profile_data, dict):
focus = profile_data.get("investmentThesisFocus", [])
if isinstance(focus, list):
return focus
except (json.JSONDecodeError, TypeError):
pass
return []
+18 -11
View File
@@ -1,17 +1,20 @@
import io
import pandas as pd
from api import companies, investors
from db.db import db_dependency, init_database
from fastapi import FastAPI, File, UploadFile
from py_schemas import InvestorList
from dotenv import load_dotenv
from fastapi import FastAPI, File, Form, UploadFile
from pydantic import BaseModel
from services.openrouter_v2 import InvestorProcessor
from routers import companies, investors
from schemas.router_schemas import InvestorList
from services.llm_parser import InvestorProcessor
from services.querying import QueryProcessor
app = FastAPI()
load_dotenv()
init_database()
app = FastAPI()
# Request models
class QueryRequest(BaseModel):
@@ -20,7 +23,7 @@ class QueryRequest(BaseModel):
class Config:
json_schema_extra = {
"example": {
"question": "Show me growth stage fintech investors in the US with check sizes over $1 million"
"question": "Find me deep tech investors that do deals in Europe under 5 million."
}
}
@@ -31,21 +34,25 @@ def health():
@app.post("/parse-csv", tags=["CSV Upload"], response_model=list[dict])
async def parse_csv(db: db_dependency, file: UploadFile = File(...)):
async def parse_csv(db: db_dependency, file: UploadFile = File(...), is_investor: int = Form(...)):
# Read uploaded CSV with pandas
content = await file.read()
df = pd.read_csv(io.StringIO(content.decode("utf-8")))
# Process the dataframe
processor = InvestorProcessor(sql_session=db)
results = await processor.process_csv(df)
processor = InvestorProcessor()
if is_investor == 1:
results = await processor.parse_investors(df)
else:
results = await processor.parse_companies(df)
# Convert Pydantic objects to dictionaries
return [r.model_dump() for r in results]
@app.post("/query", response_model=InvestorList, tags=["Querying"])
async def query_investors(db: db_dependency, request: QueryRequest):
async def query_investors(request: QueryRequest):
"""
Query investors using natural language.
@@ -55,7 +62,7 @@ async def query_investors(db: db_dependency, request: QueryRequest):
- "Growth stage investors with $5M+ check sizes"
- "Healthcare investors in Europe"
"""
processor = QueryProcessor(sql_session=db)
processor = QueryProcessor()
results = processor.process_query(request.question)
return results
-38
View File
@@ -1,38 +0,0 @@
from typing import List
from pydantic import BaseModel
class Investor(BaseModel):
name: str
aum: int
check_size: str
sector_focus: str
stage_focus: str
region: str
investment_thesis: str
investor_description: str
class InvestorList(BaseModel):
investor_list: List[Investor]
class QueryResponse(BaseModel):
name: str
aum: int
check_size: str
sector_focus: str
stage_focus: str
region: str
investment_thesis: str
investor_description: str
reason: str
class QueryRequest(BaseModel):
question: str
class QueryResponseList(BaseModel):
responses: List[QueryResponse]
Binary file not shown.
Binary file not shown.
@@ -3,8 +3,8 @@ from typing import List, Optional
from db.db import get_db
from db.models import CompanyTable, InvestorTable
from fastapi import APIRouter, Depends, HTTPException, Query
from py_schemas import CompanySchema
from pydantic import BaseModel
from schemas.router_schemas import CompanyData
from sqlalchemy.orm import Session, selectinload
router = APIRouter(tags=["Company Routes"])
@@ -15,6 +15,7 @@ class CompanyCreate(BaseModel):
name: str
industry: str
location: str
description: Optional[str] = None
founded_year: Optional[int] = None
website: Optional[str] = None
@@ -23,46 +24,33 @@ class CompanyUpdate(BaseModel):
name: Optional[str] = None
industry: Optional[str] = None
location: Optional[str] = None
description: Optional[str] = None
founded_year: Optional[int] = None
website: Optional[str] = None
# Response schema with relationships
class CompanyData(BaseModel):
"""Comprehensive company data schema"""
company: CompanySchema
investors: List["InvestorBasic"] = []
class Config:
from_attributes = True
class InvestorBasic(BaseModel):
"""Basic investor info for company responses"""
id: int
name: str
geographic_focus: str
stage_focus: str
check_size_lower: int
check_size_upper: int
class Config:
from_attributes = True
@router.get("/companies", response_model=List[CompanyData])
def read_companies(db: Session = Depends(get_db)):
"""Get all companies with their investor relationships"""
companies = (
db.query(CompanyTable).options(selectinload(CompanyTable.investors)).all()
db.query(CompanyTable)
.options(
selectinload(CompanyTable.investors),
selectinload(CompanyTable.members),
selectinload(CompanyTable.sectors),
)
.all()
)
# Transform CompanyTable objects to CompanyData format
company_data_list = []
for company in companies:
company_data = CompanyData(company=company, investors=company.investors)
company_data = CompanyData(
company=company,
investors=company.investors,
members=company.members,
sectors=company.sectors,
)
company_data_list.append(company_data)
return company_data_list
@@ -89,7 +77,11 @@ def filter_companies(
"""Filter companies based on various criteria"""
# Start with base query
query = db.query(CompanyTable).options(selectinload(CompanyTable.investors))
query = db.query(CompanyTable).options(
selectinload(CompanyTable.investors),
selectinload(CompanyTable.members),
selectinload(CompanyTable.sectors),
)
# Apply filters
if industry:
@@ -121,7 +113,12 @@ def filter_companies(
# Transform to CompanyData format
company_data_list = []
for company in companies:
company_data = CompanyData(company=company, investors=company.investors)
company_data = CompanyData(
company=company,
investors=company.investors,
members=company.members,
sectors=company.sectors,
)
company_data_list.append(company_data)
return company_data_list
@@ -132,7 +129,11 @@ def read_company(company_id: int, db: Session = Depends(get_db)):
"""Get a specific company by ID with its investors"""
company = (
db.query(CompanyTable)
.options(selectinload(CompanyTable.investors))
.options(
selectinload(CompanyTable.investors),
selectinload(CompanyTable.members),
selectinload(CompanyTable.sectors),
)
.filter(CompanyTable.id == company_id)
.first()
)
@@ -141,7 +142,12 @@ def read_company(company_id: int, db: Session = Depends(get_db)):
raise HTTPException(status_code=404, detail="Company not found")
# Transform to CompanyData format
return CompanyData(company=company, investors=company.investors)
return CompanyData(
company=company,
investors=company.investors,
members=company.members,
sectors=company.sectors,
)
@router.post("/companies", response_model=CompanyData)
@@ -155,14 +161,21 @@ def create_company(company: CompanyCreate, db: Session = Depends(get_db)):
# Reload with relationships
company_with_relations = (
db.query(CompanyTable)
.options(selectinload(CompanyTable.investors))
.options(
selectinload(CompanyTable.investors),
selectinload(CompanyTable.members),
selectinload(CompanyTable.sectors),
)
.filter(CompanyTable.id == db_company.id)
.first()
)
# Transform to CompanyData format
return CompanyData(
company=company_with_relations, investors=company_with_relations.investors
company=company_with_relations,
investors=company_with_relations.investors,
members=company_with_relations.members,
sectors=company_with_relations.sectors,
)
@@ -185,14 +198,21 @@ def update_company(
# Reload with relationships
company_with_relations = (
db.query(CompanyTable)
.options(selectinload(CompanyTable.investors))
.options(
selectinload(CompanyTable.investors),
selectinload(CompanyTable.members),
selectinload(CompanyTable.sectors),
)
.filter(CompanyTable.id == company_id)
.first()
)
# Transform to CompanyData format
return CompanyData(
company=company_with_relations, investors=company_with_relations.investors
company=company_with_relations,
investors=company_with_relations.investors,
members=company_with_relations.members,
sectors=company_with_relations.sectors,
)
@@ -1,9 +1,10 @@
from typing import List, Optional
from db.db import get_db
from db.models import InvestorTable, SectorTable
from fastapi import APIRouter, Depends, HTTPException, Query
from py_schemas import InvestmentStage, InvestorData
from schemas.router_schemas import InvestmentStage, InvestorData
from pydantic import BaseModel
from sqlalchemy.orm import Session, selectinload
@@ -13,7 +14,7 @@ router = APIRouter(tags=["Investor Routes"])
# Request schemas for creating/updating
class InvestorCreate(BaseModel):
name: str
description: str = None
description: Optional[str] = None
aum: int
check_size_lower: int
check_size_upper: int
@@ -23,14 +24,14 @@ class InvestorCreate(BaseModel):
class InvestorUpdate(BaseModel):
name: str = None
description: str = None
aum: int = None
check_size_lower: int = None
check_size_upper: int = None
geographic_focus: str = None
stage_focus: InvestmentStage = None
number_of_investments: int = None
name: Optional[str] = None
description: Optional[str] = None
aum: Optional[int] = None
check_size_lower: Optional[int] = None
check_size_upper: Optional[int] = None
geographic_focus: Optional[str] = None
stage_focus: Optional[InvestmentStage] = None
number_of_investments: Optional[int] = None
@router.get("/investors", response_model=List[InvestorData])
Binary file not shown.
Binary file not shown.
+111
View File
@@ -0,0 +1,111 @@
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel, field_validator
class InvestmentStage(str, Enum):
SEED = "SEED"
SERIES_A = "SERIES_A"
SERIES_B = "SERIES_B"
SERIES_C = "SERIES_C"
GROWTH = "GROWTH"
LATE_STAGE = "LATE_STAGE"
class SectorSchema(BaseModel):
id: int
name: str
class Config:
from_attributes = True
class InvestorMemberSchema(BaseModel):
id: int
name: str
role: str
email: str
investor_id: int
class Config:
from_attributes = True
class CompanyMemberSchema(BaseModel):
id: int
name: Optional[str] = None
linkedin: Optional[str] = None
role: Optional[str] = None
company_id: int
class Config:
from_attributes = True
class CompanySchema(BaseModel):
id: int
name: str
industry: str
location: str
description: Optional[str] = None # Fixed typo from 'nullabel'
founded_year: Optional[int] = None # Changed from str to int to match model
website: Optional[str] = None
@field_validator("founded_year", mode="before")
@classmethod
def validate_founded_year(cls, v):
if v is None or v == "Not Available" or v == "":
return None
if isinstance(v, str):
try:
return int(v)
except ValueError:
return None
return v
class Config:
from_attributes = True
class InvestorSchema(BaseModel):
id: int
name: str
description: Optional[str] = None
aum: int
check_size_lower: int
check_size_upper: int
geographic_focus: str
stage_focus: InvestmentStage
number_of_investments: int = 0
class Config:
from_attributes = True
class InvestorData(BaseModel):
"""Comprehensive investor data schema for LLM processing"""
investor: InvestorSchema
portfolio_companies: List[CompanySchema] = []
team_members: List[InvestorMemberSchema] = [] # Changed from TeamMember
sectors: List[SectorSchema] = []
class Config:
from_attributes = True
class CompanyData(BaseModel): # Renamed from CompaniesData for consistency
company: CompanySchema
sectors: List[SectorSchema] = []
members: List[CompanyMemberSchema] = [] # Changed to match model relationship name
investors: List[InvestorSchema] = []
class Config:
from_attributes = True
class InvestorList(BaseModel):
investors: List[InvestorData] = []
@@ -22,11 +22,31 @@ class SectorSchema(BaseModel):
from_attributes = True
class InvestorMemberSchema(BaseModel):
id: int
name: str
role: str
email: str
class Config:
from_attributes = True
class CompanyMemberSchema(BaseModel):
id: int
name: Optional[str] = None
linkedin: Optional[str] = None
role: Optional[str] = None
company_id: int
class Config:
from_attributes = True
class CompanySchema(BaseModel):
id: int
name: str
industry: str
location: str
description: Optional[str]
founded_year: Optional[int]
website: Optional[str]
created_at: Optional[datetime]
@@ -36,15 +56,6 @@ class CompanySchema(BaseModel):
from_attributes = True
class InvestorTeamMemberSchema(BaseModel):
id: int
name: str
role: str
email: str
class Config:
from_attributes = True
class InvestorSchema(BaseModel):
id: int
@@ -67,13 +78,22 @@ class InvestorData(BaseModel):
"""Comprehensive investor data schema for LLM processing"""
investor: InvestorSchema
portfolio_companies: List[CompanySchema] = []
team_members: List[InvestorTeamMemberSchema] = []
sectors: List[SectorSchema] = []
portfolio_companies: List[CompanySchema]
team_members: List[InvestorMemberSchema]
sectors: List[SectorSchema]
class Config:
from_attributes = True
class CompanyData(BaseModel): # Renamed from CompaniesData for consistency
company: CompanySchema
sectors: List[SectorSchema]
members: List[CompanyMemberSchema]
investors: List[InvestorSchema]
class Config:
from_attributes = True
class InvestorList(BaseModel):
investors: List[InvestorData]
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+305 -336
View File
@@ -1,368 +1,337 @@
import json
import logging
import asyncio
import os
from typing import Any, Dict, Optional
from typing import Optional
import chromadb
import pandas as pd
from dotenv import load_dotenv
from openai import OpenAI
from db import get_session, init_database
from py_schemas import CSVRow, Investor
# Load environment variables
load_dotenv()
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class LLMInvestorParser:
def __init__(self):
# Initialize OpenAI client
self.openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# Initialize ChromaDB
self.chroma_client = chromadb.PersistentClient(path="./chroma_db")
self.collection = self.chroma_client.get_or_create_collection(
name="investor_descriptions",
metadata={
"description": "Investor descriptions and investment thesis focus"
},
from db.db import get_db_session
from db.models import (
CompanyMember,
CompanyTable,
InvestorMember,
InvestorTable,
SectorTable,
)
from langchain_openai import ChatOpenAI
from schemas.py_schemas import CompanyData, InvestorData
from sqlalchemy.orm import Session
# Initialize database
init_database()
def parse_json_field(self, json_str: str) -> Dict[str, Any]:
"""Safely parse JSON string with LLM assistance if needed"""
if not json_str or json_str.strip() == "":
return {}
try:
# Try direct JSON parsing first
return json.loads(json_str)
except json.JSONDecodeError:
# If direct parsing fails, use LLM to clean and parse
logger.info("Direct JSON parsing failed, using LLM to clean JSON")
return self._llm_clean_json(json_str)
def _llm_clean_json(self, malformed_json: str) -> Dict[str, Any]:
"""Use LLM to clean and parse malformed JSON"""
try:
prompt = f"""
The following text appears to be malformed JSON. Please clean it up and return valid JSON.
If it's not possible to create valid JSON, return an empty object {{}}.
Original text:
{malformed_json[:2000]} # Limit length for API
Return only the cleaned JSON, no explanations:
"""
response = self.openai_client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": prompt}],
class InvestorProcessor:
def __init__(self):
self.llm = ChatOpenAI(
api_key=os.getenv("OPENROUTER_API_KEY"),
base_url="https://openrouter.ai/api/v1",
model="openai/gpt-5-nano",
temperature=0,
)
cleaned_json = response.choices[0].message.content.strip()
return json.loads(cleaned_json)
self.investor_structured_llm = self.llm.with_structured_output(InvestorData)
self.company_structured_llm = self.llm.with_structured_output(CompanyData)
except Exception as e:
logger.error(f"LLM JSON cleaning failed: {e}")
return {}
def _get_or_create_sector(self, db: Session, sector_name: str) -> SectorTable:
"""Get existing sector or create new one"""
sector = db.query(SectorTable).filter(SectorTable.name == sector_name).first()
if not sector:
sector = SectorTable(name=sector_name)
db.add(sector)
db.flush() # Get the ID without committing
return sector
def extract_structured_data(self, csv_row: CSVRow) -> Dict[str, Any]:
"""Extract and structure data from CSV row using LLM"""
# Parse the investment firm profile
profile_data = {}
if csv_row.investment_firm_profile:
profile_data = self.parse_json_field(csv_row.investment_firm_profile)
# Create structured output
structured_data = {
"name": csv_row.name,
"website": csv_row.website or profile_data.get("websiteURL"),
"investor_description": profile_data.get("investorDescription", ""),
"investment_thesis_focus": profile_data.get("investmentThesisFocus", []),
"headquarters": profile_data.get("headquarters", ""),
"aum_info": profile_data.get("overallAssetsUnderManagement", {}),
"funds_info": profile_data.get("funds", []),
"crunchbase_urls": csv_row.crunchbase_linkedin_urls or "",
"crunchbase_extract": csv_row.crunchbase_firm_extract or "",
"linkedin_profile": csv_row.linkedin_investment_profile or "",
"source_truth_profile": csv_row.source_of_truth_profile or "",
}
return structured_data
def enhance_with_llm(self, investor_data: Dict[str, Any]) -> Dict[str, Any]:
"""Use LLM to enhance and standardize investor data"""
try:
# Combine all available text for context
context_text = " ".join(
[
investor_data.get("investor_description", ""),
investor_data.get("crunchbase_extract", ""),
investor_data.get("linkedin_profile", ""),
investor_data.get("source_truth_profile", ""),
]
def _save_investor_to_db(
self, db: Session, investor_data: InvestorData
) -> InvestorTable:
"""Save investor data to database"""
# Create investor record
investor = InvestorTable(
name=investor_data.investor.name,
description=investor_data.investor.description,
aum=investor_data.investor.aum,
check_size_lower=investor_data.investor.check_size_lower,
check_size_upper=investor_data.investor.check_size_upper,
geographic_focus=investor_data.investor.geographic_focus,
stage_focus=investor_data.investor.stage_focus,
number_of_investments=investor_data.investor.number_of_investments,
)
db.add(investor)
db.flush() # Get the ID
if not context_text.strip():
return investor_data
prompt = f"""
Based on the following information about an investor, please extract and standardize:
1. A concise investor description (2-3 sentences)
2. Investment thesis focus areas (list of specific focus areas)
3. Headquarters location (city, country format)
Investor: {investor_data["name"]}
Context: {context_text[:3000]} # Limit for API
Return in JSON format:
{{
"enhanced_description": "concise description here",
"standardized_focus": ["focus area 1", "focus area 2", ...],
"standardized_headquarters": "City, Country"
}}
"""
response = self.openai_client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": prompt}],
temperature=0.3,
# Add team members
for member_data in investor_data.team_members:
member = InvestorMember(
name=member_data.name,
role=member_data.role,
email=member_data.email,
investor_id=investor.id,
)
db.add(member)
enhanced_data = json.loads(response.choices[0].message.content)
# Add sectors
for sector_data in investor_data.sectors:
sector = self._get_or_create_sector(db, sector_data.name)
investor.sectors.append(sector)
# Update investor data with enhanced information
if enhanced_data.get("enhanced_description"):
investor_data["enhanced_description"] = enhanced_data[
"enhanced_description"
]
# Add portfolio companies
for company_schema in investor_data.portfolio_companies:
# Convert CompanySchema to CompanyData format
company_data = CompanyData(
company=company_schema,
sectors=[], # Will be empty for portfolio companies
members=[], # Will be empty for portfolio companies
investors=[], # Will be empty for portfolio companies
)
company = self._save_company_to_db(db, company_data, skip_investors=True)
investor.portfolio_companies.append(company)
if enhanced_data.get("standardized_focus"):
investor_data["standardized_focus"] = enhanced_data[
"standardized_focus"
]
return investor
if enhanced_data.get("standardized_headquarters"):
investor_data["standardized_headquarters"] = enhanced_data[
"standardized_headquarters"
]
return investor_data
except Exception as e:
logger.error(f"LLM enhancement failed for {investor_data['name']}: {e}")
return investor_data
def save_to_sql(self, investor_data: Dict[str, Any]) -> int:
"""Save investor data to SQL database"""
try:
with get_session() as session:
# Check if investor already exists
existing = (
session.query(Investor)
.filter_by(name=investor_data["name"])
def _save_company_to_db(
self, db: Session, company_data: CompanyData, skip_investors: bool = False
) -> CompanyTable:
"""Save company data to database"""
# Check if company already exists
existing_company = (
db.query(CompanyTable)
.filter(CompanyTable.name == company_data.company.name)
.first()
)
if existing_company:
return existing_company
if existing:
logger.info(f"Updating existing investor: {investor_data['name']}")
investor = existing
# Create company record
company = CompanyTable(
name=company_data.company.name,
industry=company_data.company.industry,
location=company_data.company.location,
description=company_data.company.description,
founded_year=company_data.company.founded_year,
website=company_data.company.website,
)
db.add(company)
db.flush() # Get the ID
# Add company members
for member_data in company_data.members:
if member_data.name: # Only add members with names
member = CompanyMember(
name=member_data.name,
linkedin=member_data.linkedin,
role=member_data.role,
company_id=company.id,
)
db.add(member)
# Add sectors
for sector_data in company_data.sectors:
sector = self._get_or_create_sector(db, sector_data.name)
company.sectors.append(sector)
# Add investors (if not skipping to avoid circular references)
if not skip_investors:
for investor_data in company_data.investors:
# Look for existing investor by name
existing_investor = (
db.query(InvestorTable)
.filter(InvestorTable.name == investor_data.name)
.first()
)
if existing_investor:
company.investors.append(existing_investor)
return company
async def _process_row(
self, row: pd.Series, row_idx: int, is_investor: bool = True
) -> Optional[InvestorData | CompanyData]:
"""Process a single row of data"""
# Clean values to remove control characters
cleaned_row = {}
for key, value in row.items():
if pd.notna(value):
# Convert to string and clean control characters
clean_value = (
str(value).replace("\n", " ").replace("\r", " ").replace("\t", " ")
)
# Remove other control characters
clean_value = "".join(
char
for char in clean_value
if ord(char) >= 32 or char in ["\n", "\r", "\t"]
)
cleaned_row[key] = clean_value
row_str = ", ".join([f"{key}: {value}" for key, value in cleaned_row.items()])
try:
print(f"Processing row {row_idx + 1}...")
if is_investor:
result = await self.investor_structured_llm.ainvoke(row_str)
else:
logger.info(f"Creating new investor: {investor_data['name']}")
investor = Investor()
# Map data to investor object
investor.name = investor_data["name"]
investor.website = investor_data.get("website")
investor.investor_description = investor_data.get(
"enhanced_description"
) or investor_data.get("investor_description")
investor.investment_thesis_focus = investor_data.get(
"standardized_focus"
) or investor_data.get("investment_thesis_focus")
investor.headquarters = investor_data.get(
"standardized_headquarters"
) or investor_data.get("headquarters")
# AUM information
aum_info = investor_data.get("aum_info", {})
investor.aum_amount = aum_info.get("aumAmount")
investor.aum_as_of_date = aum_info.get("asOfDate")
investor.aum_source_url = aum_info.get("sourceUrl")
# Fund information
investor.funds_info = investor_data.get("funds_info", [])
# Raw data
investor.crunchbase_urls = investor_data.get("crunchbase_urls")
investor.crunchbase_extract = investor_data.get("crunchbase_extract")
investor.linkedin_profile = investor_data.get("linkedin_profile")
investor.source_truth_profile = investor_data.get(
"source_truth_profile"
)
if not existing:
session.add(investor)
session.flush() # Get the ID
return investor.id
result = await self.company_structured_llm.ainvoke(row_str)
if result:
return result.model_dump()
return None
except Exception as e:
logger.error(f"Failed to save to SQL: {e}")
raise
def save_to_vector_db(self, investor_id: int, investor_data: Dict[str, Any]):
"""Save investor description and focus to ChromaDB"""
try:
# Prepare text for embedding
description_text = investor_data.get(
"enhanced_description"
) or investor_data.get("investor_description", "")
focus_areas = investor_data.get("standardized_focus") or investor_data.get(
"investment_thesis_focus", []
)
if isinstance(focus_areas, list):
focus_text = " ".join(focus_areas)
else:
focus_text = str(focus_areas)
# Combine description and focus for embedding
combined_text = f"{description_text} {focus_text}".strip()
if not combined_text:
logger.warning(f"No text to embed for investor {investor_data['name']}")
return
# Create metadata
metadata = {
"investor_id": investor_id,
"name": investor_data["name"],
"website": investor_data.get("website", ""),
"headquarters": investor_data.get("standardized_headquarters")
or investor_data.get("headquarters", ""),
"focus_areas_count": len(focus_areas)
if isinstance(focus_areas, list)
else 0,
}
# Add to ChromaDB
self.collection.add(
documents=[combined_text],
metadatas=[metadata],
ids=[f"investor_{investor_id}"],
)
logger.info(f"Added investor {investor_data['name']} to vector database")
except Exception as e:
logger.error(f"Failed to save to vector DB: {e}")
def process_csv_file(self, csv_file_path: str, limit: Optional[int] = None):
"""Process the entire CSV file"""
logger.info(f"Starting to process CSV file: {csv_file_path}")
# Read CSV
df = pd.read_csv(csv_file_path)
logger.info(f"Loaded {len(df)} rows from CSV")
if limit:
df = df.head(limit)
logger.info(f"Processing limited to {limit} rows")
processed_count = 0
error_count = 0
for index, row in df.iterrows():
try:
logger.info(f"Processing row {index + 1}/{len(df)}: {row['Name']}")
# Create CSVRow object
csv_row = CSVRow(
name=row["Name"],
website=row.get("Website"),
investment_firm_profile=row.get("Investment Firm Profile"),
crunchbase_linkedin_urls=row.get("Crunchbase & LinkedIn URLs"),
crunchbase_firm_extract=row.get("Crunchbase Firm Extract"),
linkedin_investment_profile=row.get("LinkedIn Investment Profile"),
source_of_truth_profile=row.get("Source of Truth Profile"),
)
# Extract structured data
structured_data = self.extract_structured_data(csv_row)
# Enhance with LLM
enhanced_data = self.enhance_with_llm(structured_data)
# Save to SQL database
investor_id = self.save_to_sql(enhanced_data)
# Save to vector database
self.save_to_vector_db(investor_id, enhanced_data)
processed_count += 1
# Progress update every 10 rows
if (index + 1) % 10 == 0:
logger.info(
f"Processed {processed_count} rows successfully, {error_count} errors"
)
except Exception as e:
error_count += 1
logger.error(
f"Error processing row {index + 1} ({row.get('Name', 'Unknown')}): {e}"
)
continue
logger.info(
f"Processing complete! Processed: {processed_count}, Errors: {error_count}"
)
return processed_count, error_count
def search_investors(self, query: str, limit: int = 5):
"""Search investors using vector similarity"""
try:
results = self.collection.query(query_texts=[query], n_results=limit)
return results
except Exception as e:
logger.error(f"Search failed: {e}")
print(f"Error processing row {row_idx + 1}: {e}")
return None
async def parse_investors(self, df, save_to_db: bool = True):
"""Parse investors from DataFrame and optionally save to database"""
investors = []
def main():
"""Main function to run the parser"""
parser = LLMInvestorParser()
db = None
if save_to_db:
db = get_db_session()
# Process the CSV file
csv_file = "/home/oluwasanmi/Documents/Work/MKD/anton_wireframe/New Excerpt 5 investors - Sheet1 parse.csv"
try:
# Process rows in batches asynchronously
batch_size = 15 # Adjust batch size as needed
rows = [(idx, row) for idx, row in df.iterrows()]
# Start with a small sample for testing
processed, errors = parser.process_csv_file(csv_file, limit=5)
for i in range(0, len(rows), batch_size):
batch = rows[i : i + batch_size]
print("\nProcessing complete!")
print(f"Successfully processed: {processed} investors")
print(f"Errors encountered: {errors}")
# Process batch asynchronously
tasks = [
self._process_row(row, idx, is_investor=True) for idx, row in batch
]
# Test search functionality
print("\nTesting search functionality...")
results = parser.search_investors("bioeconomy circular economy")
if results:
print(f"Found {len(results['documents'][0])} similar investors")
for i, doc in enumerate(results["documents"][0]):
print(f" {i + 1}. {results['metadatas'][0][i]['name']}")
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
# Handle results from batch
for (idx, row), result in zip(batch, batch_results):
if isinstance(result, Exception):
print(f"Error processing row {idx}: {result}")
if db:
db.rollback()
continue
if result:
# Convert dict to InvestorData if needed
if isinstance(result, dict):
investor_data = InvestorData(**result)
else:
investor_data = result
investors.append(investor_data)
# Save to database if requested
if save_to_db and db:
try:
saved_investor = self._save_investor_to_db(
db, investor_data
)
db.commit()
print(
f"✅ Saved investor '{saved_investor.name}' to database"
)
except Exception as e:
db.rollback()
print(f"❌ Failed to save investor to database: {e}")
print(
f"Completed batch {i // batch_size + 1} of {(len(rows) + batch_size - 1) // batch_size}"
)
except Exception as e:
print(f"Error in batch processing: {e}")
if db:
db.rollback()
finally:
if db:
db.close()
return investors
async def parse_companies(self, df, save_to_db: bool = True):
"""Parse companies from DataFrame and optionally save to database"""
companies = []
db = None
if save_to_db:
db = get_db_session()
try:
# Process rows in batches asynchronously
batch_size = 15 # Adjust batch size as needed
rows = [(idx, row) for idx, row in df.iterrows()]
for i in range(0, len(rows), batch_size):
batch = rows[i : i + batch_size]
# Process batch asynchronously
tasks = [
self._process_row(row, idx, is_investor=False) for idx, row in batch
]
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
# Handle results from batch
for (idx, row), result in zip(batch, batch_results):
if isinstance(result, Exception):
print(f"Error processing row {idx}: {result}")
if db:
db.rollback()
continue
if result:
# Convert dict to CompanyData if needed
if isinstance(result, dict):
company_data = CompanyData(**result)
else:
company_data = result
companies.append(company_data)
# Save to database if requested
if save_to_db and db:
try:
saved_company = self._save_company_to_db(
db, company_data
)
db.commit()
print(
f"✅ Saved company '{saved_company.name}' to database"
)
except Exception as e:
db.rollback()
print(f"❌ Failed to save company to database: {e}")
print(
f"Completed batch {i // batch_size + 1} of {(len(rows) + batch_size - 1) // batch_size}"
)
except Exception as e:
print(f"Error processing row {idx}: {e}")
if db:
db.rollback()
finally:
if db:
db.close()
return companies
if __name__ == "__main__":
main()
# async def main():
# """Main execution function"""
# # Initialize database tables
# print("🔧 Initializing database...")
# init_database()
# # Create processor
# processor = InvestorProcessor()
# print("📊 Processing companies...")
# companies = await processor.parse_companies(
# "data/19 Companies data.csv", save_to_db=True
# )
# print(f"Processed {len(companies)} companies")
# print("\n💰 Processing investors...")
# investors = await processor.parse_investors(
# "data/19 Investors data.csv", save_to_db=True
# )
# print(f"Processed {len(investors)} investors")
# print("\n✨ Processing complete!")
# if __name__ == "__main__":
# asyncio.run(main())
-293
View File
@@ -1,293 +0,0 @@
import asyncio
from typing import List, Optional
import chromadb
import pandas as pd
from db.models import CompanyTable, InvestorTable, InvestorTeamMember, SectorTable
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from py_schemas import InvestorData
from pydantic import BaseModel
from settings import settings
class InvestorList(BaseModel):
"""Schema for LLM structured output"""
investor_list: List[InvestorData]
class InvestorProcessor:
def __init__(
self,
sql_session: Optional[object] = None,
vector_db_client: Optional[object] = None,
):
self.template = """You are an expert data extraction assistant. Extract investor information from the provided CSV data and return it as a list of structured records.
Given the following CSV data rows:
{question}
For each row, extract and structure the following fields for the investor:
- name: The investor's full name
- description: Description of the investor
- aum: Assets under management (as integer, use 0 if not available)
- check_size_lower: Lower bound of investment check size (as integer)
- check_size_upper: Upper bound of investment check size (as integer)
- geographic_focus: Geographic region focus
- stage_focus: Investment stage focus (must be one of: seed, series_a, series_b, series_c, growth, late_stage)
- number_of_investments: Number of investments made (default 0)
Also extract related data:
- portfolio_companies: List of companies they've invested in
- team_members: List of team members with name, role, email
- sectors: List of sectors they focus on
Important:
- If a field is not available, use appropriate defaults
- stage_focus must be one of the valid enum values
- Return clean, valid JSON only
Return the data as a structured list of comprehensive investor data."""
self.prompt = PromptTemplate(
template=self.template, input_variables=["question"]
)
self.llm = ChatOpenAI(
api_key=settings.OPENROUTER_API_KEY,
base_url="https://openrouter.ai/api/v1",
model="google/gemini-2.5-flash-lite",
temperature=0,
)
self.structured_llm = self.llm.with_structured_output(InvestorList)
self.sql_session = sql_session
self.vector_db_client = vector_db_client
self.vector_db_client = chromadb.PersistentClient(path="./chroma_db")
self.collection = self.vector_db_client.get_or_create_collection(
name="investor_descriptions",
metadata={
"description": "Investor descriptions and investment thesis focus"
},
)
async def _process_batch(
self, batch: pd.DataFrame, batch_idx: int
) -> List[InvestorData]:
"""Process a single batch of data"""
# Convert batch to string representation - clean the data
batch_str = ""
for idx, row in batch.iterrows():
# Clean values to remove control characters
cleaned_row = {}
for key, value in row.items():
if pd.notna(value):
# Convert to string and clean control characters
clean_value = (
str(value)
.replace("\n", " ")
.replace("\r", " ")
.replace("\t", " ")
)
# Remove other control characters
clean_value = "".join(
char
for char in clean_value
if ord(char) >= 32 or char in ["\n", "\r", "\t"]
)
cleaned_row[key] = clean_value
row_str = ", ".join(
[f"{key}: {value}" for key, value in cleaned_row.items()]
)
batch_str += f"Row {idx + 1}: {row_str}\n"
try:
print(f"Processing batch {batch_idx + 1}...")
batch_results = await self.structured_llm.ainvoke(batch_str)
return batch_results.investor_list
except Exception as e:
print(f"Error processing batch {batch_idx + 1}: {e}")
return []
async def _save_to_sql(self, investor_data_list: List[InvestorData]) -> None:
"""Save investors and related data to SQL database"""
if not self.sql_session:
return
try:
for investor_data in investor_data_list:
# Save investor
db_investor = InvestorTable(
name=investor_data.investor.name,
description=investor_data.investor.description,
aum=investor_data.investor.aum,
check_size_lower=investor_data.investor.check_size_lower,
check_size_upper=investor_data.investor.check_size_upper,
geographic_focus=investor_data.investor.geographic_focus,
stage_focus=investor_data.investor.stage_focus,
number_of_investments=investor_data.investor.number_of_investments,
)
self.sql_session.add(db_investor)
self.sql_session.flush() # Get the ID
# Save sectors and create associations
for sector_data in investor_data.sectors:
# Check if sector exists, create if not
existing_sector = (
self.sql_session.query(SectorTable)
.filter(SectorTable.name == sector_data.name)
.first()
)
if not existing_sector:
db_sector = SectorTable(name=sector_data.name)
self.sql_session.add(db_sector)
self.sql_session.flush()
# Add sector to investor's sectors
db_investor.sectors.append(db_sector)
else:
# Add existing sector to investor if not already there
if existing_sector not in db_investor.sectors:
db_investor.sectors.append(existing_sector)
# Save companies and create portfolio associations
for company_data in investor_data.portfolio_companies:
# Check if company exists, create if not
existing_company = (
self.sql_session.query(CompanyTable)
.filter(CompanyTable.name == company_data.name)
.first()
)
if not existing_company:
db_company = CompanyTable(
name=company_data.name,
industry=company_data.industry,
location=company_data.location,
founded_year=company_data.founded_year,
website=company_data.website,
)
self.sql_session.add(db_company)
self.sql_session.flush()
# Add to investor's portfolio
db_investor.portfolio_companies.append(db_company)
else:
# Add existing company to portfolio if not already there
if existing_company not in db_investor.portfolio_companies:
db_investor.portfolio_companies.append(existing_company)
# Save team members
for team_member_data in investor_data.team_members:
# Check if team member exists
existing_member = (
self.sql_session.query(InvestorTeamMember)
.filter(InvestorTeamMember.email == team_member_data.email)
.first()
)
if not existing_member:
db_team_member = InvestorTeamMember(
name=team_member_data.name,
role=team_member_data.role,
email=team_member_data.email,
investor_id=db_investor.id,
)
self.sql_session.add(db_team_member)
self.sql_session.commit()
print(f"Successfully saved {len(investor_data_list)} investors to database")
except Exception as e:
self.sql_session.rollback()
print(f"Error saving to SQL database: {e}")
raise
async def _save_to_vector_db(self, investor_data_list: List[InvestorData]) -> None:
"""Save investors to vector database"""
if not self.vector_db_client:
return
documents = []
metadatas = []
ids = []
for i, investor_data in enumerate(investor_data_list):
investor = investor_data.investor
sectors = ", ".join([s.name for s in investor_data.sectors])
companies = ", ".join([c.name for c in investor_data.portfolio_companies])
doc_text = f"""
Investor: {investor.name}
Description: {investor.description or "N/A"}
AUM: ${investor.aum:,}
Check Size: ${investor.check_size_lower:,} - ${investor.check_size_upper:,}
Geographic Focus: {investor.geographic_focus}
Stage Focus: {investor.stage_focus.value}
Sectors: {sectors}
Portfolio Companies: {companies}
""".strip()
documents.append(doc_text)
metadatas.append(
{
"name": investor.name,
"stage_focus": investor.stage_focus.value,
"geographic_focus": investor.geographic_focus,
"aum": investor.aum,
}
)
ids.append(
f"investor_{i}_{investor.name.replace(' ', '_').replace('/', '_')}"
)
if documents:
try:
self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
print(
f"Successfully saved {len(documents)} investors to vector database"
)
except Exception as e:
print(f"Error saving to vector database: {e}")
async def process_csv(
self, df: pd.DataFrame, batch_size: int = 10, max_concurrent: int = 10
) -> List[InvestorData]:
"""Process CSV data in parallel batches and save to databases"""
results = []
# Create batches
batches = []
for i in range(0, len(df), batch_size):
batch = df.iloc[i : i + batch_size]
batches.append((batch, i // batch_size))
# Process batches with concurrency control
semaphore = asyncio.Semaphore(max_concurrent)
async def process_with_semaphore(batch_data):
batch, batch_idx = batch_data
async with semaphore:
return await self._process_batch(batch, batch_idx)
# Execute all batches concurrently
batch_results = await asyncio.gather(
*[process_with_semaphore(batch_data) for batch_data in batches],
return_exceptions=True,
)
# Collect results, filtering out exceptions
for batch_result in batch_results:
if not isinstance(batch_result, Exception):
results.extend(batch_result)
# Save to databases
if results:
print(f"Successfully processed {len(results)} investors")
await self._save_to_sql(results)
await self._save_to_vector_db(results)
return results
-290
View File
@@ -1,290 +0,0 @@
import asyncio
from typing import List, Optional
import chromadb
import pandas as pd
from db.models import CompanyTable, InvestorTable, InvestorTeamMember, SectorTable
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from py_schemas import InvestorData
from pydantic import BaseModel
from settings import settings
class InvestorOutput(BaseModel):
"""Schema for LLM structured output"""
investor_data: InvestorData
class InvestorProcessor:
def __init__(
self,
sql_session: Optional[object] = None,
vector_db_client: Optional[object] = None,
):
self.template = """You are an expert data extraction assistant. Extract investor information from the provided CSV data and return it as a structured record.
Given the following CSV data row:
{question}
Extract and structure the following fields for the investor:
- name: The investor's full name
- description: Description of the investor
- aum: Assets under management (as integer, use 0 if not available)
- check_size_lower: Lower bound of investment check size (as integer)
- check_size_upper: Upper bound of investment check size (as integer)
- geographic_focus: Geographic region focus
- stage_focus: Investment stage focus (must be one of: seed, series_a, series_b, series_c, growth, late_stage)
- number_of_investments: Number of investments made (default 0)
Also extract related data:
- portfolio_companies: List of companies they've invested in
- team_members: List of team members with name, role, email
- sectors: List of sectors they focus on
Important:
- If a field is not available, use appropriate defaults
- stage_focus must be one of the valid enum values
- Return clean, valid JSON only
Return the data as a single comprehensive investor data record."""
self.prompt = PromptTemplate(
template=self.template, input_variables=["question"]
)
self.llm = ChatOpenAI(
api_key=settings.OPENROUTER_API_KEY,
base_url="https://openrouter.ai/api/v1",
model="google/gemini-2.5-flash-lite",
temperature=0,
)
self.structured_llm = self.llm.with_structured_output(InvestorOutput)
self.sql_session = sql_session
self.vector_db_client = vector_db_client
self.vector_db_client = chromadb.PersistentClient(path="./chroma_db")
self.collection = self.vector_db_client.get_or_create_collection(
name="investor_descriptions",
metadata={
"description": "Investor descriptions and investment thesis focus"
},
)
async def _process_row(
self, row: pd.Series, row_idx: int
) -> Optional[InvestorData]:
"""Process a single row of data"""
# Clean values to remove control characters
cleaned_row = {}
for key, value in row.items():
if pd.notna(value):
# Convert to string and clean control characters
clean_value = (
str(value)
.replace("\n", " ")
.replace("\r", " ")
.replace("\t", " ")
)
# Remove other control characters
clean_value = "".join(
char
for char in clean_value
if ord(char) >= 32 or char in ["\n", "\r", "\t"]
)
cleaned_row[key] = clean_value
row_str = ", ".join(
[f"{key}: {value}" for key, value in cleaned_row.items()]
)
try:
print(f"Processing row {row_idx + 1}...")
result = await self.structured_llm.ainvoke(row_str)
if result.investor_data:
return result.investor_data
return None
except Exception as e:
print(f"Error processing row {row_idx + 1}: {e}")
return None
async def _save_to_sql(self, investor_data_list: List[InvestorData]) -> None:
"""Save investors and related data to SQL database"""
if not self.sql_session:
return
try:
for investor_data in investor_data_list:
# Save investor
db_investor = InvestorTable(
name=investor_data.investor.name,
description=investor_data.investor.description,
aum=investor_data.investor.aum,
check_size_lower=investor_data.investor.check_size_lower,
check_size_upper=investor_data.investor.check_size_upper,
geographic_focus=investor_data.investor.geographic_focus,
stage_focus=investor_data.investor.stage_focus,
number_of_investments=investor_data.investor.number_of_investments,
)
self.sql_session.add(db_investor)
self.sql_session.flush() # Get the ID
# Save sectors and create associations
for sector_data in investor_data.sectors:
# Check if sector exists, create if not
existing_sector = (
self.sql_session.query(SectorTable)
.filter(SectorTable.name == sector_data.name)
.first()
)
if not existing_sector:
db_sector = SectorTable(name=sector_data.name)
self.sql_session.add(db_sector)
self.sql_session.flush()
# Add sector to investor's sectors
db_investor.sectors.append(db_sector)
else:
# Add existing sector to investor if not already there
if existing_sector not in db_investor.sectors:
db_investor.sectors.append(existing_sector)
# Save companies and create portfolio associations
for company_data in investor_data.portfolio_companies:
# Check if company exists, create if not
existing_company = (
self.sql_session.query(CompanyTable)
.filter(CompanyTable.name == company_data.name)
.first()
)
if not existing_company:
db_company = CompanyTable(
name=company_data.name,
industry=company_data.industry,
location=company_data.location,
founded_year=company_data.founded_year,
website=company_data.website,
)
self.sql_session.add(db_company)
self.sql_session.flush()
# Add to investor's portfolio
db_investor.portfolio_companies.append(db_company)
else:
# Add existing company to portfolio if not already there
if existing_company not in db_investor.portfolio_companies:
db_investor.portfolio_companies.append(existing_company)
# Save team members
for team_member_data in investor_data.team_members:
# Check if team member exists
existing_member = (
self.sql_session.query(InvestorTeamMember)
.filter(InvestorTeamMember.email == team_member_data.email)
.first()
)
if not existing_member:
db_team_member = InvestorTeamMember(
name=team_member_data.name,
role=team_member_data.role,
email=team_member_data.email,
investor_id=db_investor.id,
)
self.sql_session.add(db_team_member)
self.sql_session.commit()
print(f"Successfully saved {len(investor_data_list)} investors to database")
except Exception as e:
self.sql_session.rollback()
print(f"Error saving to SQL database: {e}")
raise
async def _save_to_vector_db(self, investor_data_list: List[InvestorData]) -> None:
"""Save investors to vector database"""
if not self.vector_db_client:
return
documents = []
metadatas = []
ids = []
for i, investor_data in enumerate(investor_data_list):
investor = investor_data.investor
sectors = ", ".join([s.name for s in investor_data.sectors])
companies = ", ".join([c.name for c in investor_data.portfolio_companies])
doc_text = f"""
Investor: {investor.name}
Description: {investor.description or "N/A"}
AUM: ${investor.aum:,}
Check Size: ${investor.check_size_lower:,} - ${investor.check_size_upper:,}
Geographic Focus: {investor.geographic_focus}
Stage Focus: {investor.stage_focus.value}
Sectors: {sectors}
Portfolio Companies: {companies}
""".strip()
documents.append(doc_text)
metadatas.append(
{
"name": investor.name,
"stage_focus": investor.stage_focus.value,
"geographic_focus": investor.geographic_focus,
"aum": investor.aum,
}
)
ids.append(
f"investor_{i}_{investor.name.replace(' ', '_').replace('/', '_')}"
)
if documents:
try:
self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
print(
f"Successfully saved {len(documents)} investors to vector database"
)
except Exception as e:
print(f"Error saving to vector database: {e}")
async def process_csv(
self, df: pd.DataFrame, max_concurrent: int = 10
) -> List[InvestorData]:
"""Process CSV data one row at a time and save to databases"""
results = []
# Create semaphore for concurrency control
semaphore = asyncio.Semaphore(max_concurrent)
async def process_row_with_semaphore(row_data):
row, row_idx = row_data
async with semaphore:
return await self._process_row(row, row_idx)
# Create row tasks
row_tasks = []
for idx, row in df.iterrows():
row_tasks.append((row, idx))
# Execute all rows concurrently
row_results = await asyncio.gather(
*[process_row_with_semaphore(row_data) for row_data in row_tasks],
return_exceptions=True,
)
# Collect results, filtering out exceptions and None values
for row_result in row_results:
if not isinstance(row_result, Exception) and row_result is not None:
results.append(row_result)
# Save to databases
if results:
print(f"Successfully processed {len(results)} investors")
await self._save_to_sql(results)
await self._save_to_vector_db(results)
return results
+53 -215
View File
@@ -1,88 +1,47 @@
from typing import List, Optional
import os
from typing import List
import chromadb
from db.db import DATABASE_URL, get_db
from db.models import InvestorTable
from langchain import hub
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from py_schemas import InvestorData, InvestorList
from settings import settings
from schemas.py_schemas import InvestorData, InvestorList
from sqlalchemy.orm import selectinload
# Connect to SQLite
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
db = SQLDatabase.from_uri("sqlite:///investors.db")
system_message = (
prompt_template.format(dialect="SQLite", top_k=5)
+ "\n Get answers from the Sql database and the vector database"
)
db = SQLDatabase.from_uri(DATABASE_URL)
class QueryProcessor:
def __init__(
self,
sql_session: Optional[object] = None,
vector_db_client: Optional[object] = None,
):
self.sql_session = sql_session
def __init__(self):
self.llm = ChatOpenAI(
api_key=settings.OPENROUTER_API_KEY,
api_key=os.getenv("OPENROUTER_API_KEY"),
base_url="https://openrouter.ai/api/v1",
model="google/gemini-2.5-flash-lite",
model="openai/gpt-5-nano",
temperature=0.3,
)
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
# Update system message to specifically request only investor IDs
system_message_updated = (
prompt_template.format(dialect="SQLite", top_k=5)
+ "\n\nIMPORTANT: You must ONLY return the investor IDs (id field) that match the user's criteria. "
+ "Do NOT return any other information, explanations, or data. "
+ "Your response should be ONLY a comma-separated list of numbers representing the investor IDs. "
+ "Example format: 1, 5, 12, 23"
)
self.agent = create_react_agent(
model=self.llm,
tools=self.toolkit.get_tools() + [self.query_vector_database],
prompt=system_message,
tools=self.toolkit.get_tools(),
prompt=system_message_updated,
)
self.vector_db_client = vector_db_client
self.vector_db_client = chromadb.PersistentClient(path="./chroma_db")
self.collection = self.vector_db_client.get_or_create_collection(
name="investor_descriptions",
metadata={
"description": "Investor descriptions and investment thesis focus"
},
)
def query_sql_database(self, query: str) -> Optional[InvestorList]:
"""Query the SQL database for investor information."""
if not self.sql_session:
return None
# Implement SQL querying logic here
result = self.sql_session.execute(query)
investors = result.scalars().all()
return InvestorList(investors=investors)
def query_vector_database(self, query: str) -> Optional[InvestorList]:
"""Query the vector database for investor information."""
if not self.vector_db_client:
return None
print("VECTOR STORE WAS CALLED")
# Query the collection directly, not passing collection as parameter
results = self.collection.query(
query_texts=[query], # ChromaDB expects a list of query texts
n_results=3, # Specify how many results you want
)
print(results)
# ChromaDB returns results in a different structure
# results will have 'documents', 'metadatas', 'ids', 'distances'
return results
def process_query(self, question: str) -> InvestorList:
"""Process a query using the LLM and return structured investor data."""
# Extract filters from the query first
filters = self._extract_filters_from_query(question)
# Get AI response for additional context
"""Process a query using the LLM and return investor data."""
# Let the LLM handle all database interactions and filtering to get IDs
response = self.agent.invoke(
{"messages": [("user", question)]},
)
@@ -92,178 +51,54 @@ class QueryProcessor:
response["messages"][-1].content if response.get("messages") else ""
)
# Try to extract investor IDs or names from the AI response
investor_ids = self._extract_investor_info_from_response(ai_response)
# Extract investor IDs from the AI response
investor_ids = self._extract_investor_ids_from_response(ai_response)
# Fetch filtered investor data with relationships from database
return self._fetch_investors_with_relationships(investor_ids, filters)
# Fetch full investor data using the IDs
return self._fetch_investors_by_ids(investor_ids)
def _extract_investor_ids_from_response(self, ai_response: str) -> List[int]:
"""Extract investor IDs from AI response."""
import re
def _extract_investor_info_from_response(self, ai_response: str) -> List[int]:
"""Extract investor IDs from AI response. This is a simple implementation."""
# This is a basic implementation - you might want to make it more sophisticated
# based on how your AI formats responses
investor_ids = []
# If the AI can't provide structured data, fall back to getting all investors
# that match basic criteria
try:
# Try to extract numbers that might be IDs
import re
# Try multiple patterns to extract IDs from the response
# Pattern 1: Simple numbers (assuming they are IDs)
numbers = re.findall(r"\b\d+\b", ai_response)
investor_ids = [int(num) for num in numbers]
ids = re.findall(r"\bid:\s*(\d+)", ai_response.lower())
investor_ids = [int(id_str) for id_str in ids]
except Exception:
pass
# Pattern 2: If response contains explicit ID references
id_matches = re.findall(r"\bid[:\s]*(\d+)", ai_response.lower())
if id_matches:
investor_ids = [int(id_str) for id_str in id_matches]
return investor_ids if investor_ids else []
except Exception as e:
print(f"Error extracting IDs from response: {e}")
return []
def _extract_filters_from_query(self, question: str) -> dict:
"""Extract filter criteria from natural language query."""
question_lower = question.lower()
filters = {}
return investor_ids
# Extract stage filters
if any(
stage in question_lower
for stage in [
"seed",
"series a",
"series b",
"series c",
"growth",
"late stage",
]
):
if "seed" in question_lower:
filters["stage"] = "SEED"
elif "series a" in question_lower:
filters["stage"] = "SERIES_A"
elif "series b" in question_lower:
filters["stage"] = "SERIES_B"
elif "series c" in question_lower:
filters["stage"] = "SERIES_C"
elif "growth" in question_lower:
filters["stage"] = "GROWTH"
elif "late stage" in question_lower:
filters["stage"] = "LATE_STAGE"
# Extract geographic filters
if any(
geo in question_lower
for geo in [
"us",
"usa",
"united states",
"europe",
"asia",
"silicon valley",
"bay area",
]
):
if (
"us" in question_lower
or "usa" in question_lower
or "united states" in question_lower
):
filters["geography"] = "US"
elif "europe" in question_lower:
filters["geography"] = "Europe"
elif "asia" in question_lower:
filters["geography"] = "Asia"
elif "silicon valley" in question_lower or "bay area" in question_lower:
filters["geography"] = "Silicon Valley"
# Extract sector filters
sectors = [
"fintech",
"healthcare",
"saas",
"ai",
"biotech",
"consumer",
"enterprise",
"crypto",
"blockchain",
]
for sector in sectors:
if sector in question_lower:
filters["sector"] = sector
break
# Extract check size filters (simple patterns)
import re
amounts = re.findall(
r"\$?(\d+(?:,\d{3})*(?:\.\d+)?)\s*(?:million|m|k|thousand)", question_lower
)
if amounts:
amount = amounts[0].replace(",", "")
if "million" in question_lower or "m" in question_lower:
filters["min_check_size"] = int(float(amount) * 1000000)
elif "thousand" in question_lower or "k" in question_lower:
filters["min_check_size"] = int(float(amount) * 1000)
return filters
def _fetch_investors_with_relationships(
self, investor_ids: List[int] = None, filters: dict = None
) -> InvestorList:
"""Fetch investors with all their relationships from the database."""
if not self.sql_session:
def _fetch_investors_by_ids(self, investor_ids: List[int]) -> InvestorList:
"""Fetch investors with all their relationships from the database using IDs."""
if not investor_ids:
return InvestorList(investors=[])
# Import here to avoid circular imports
from db.models import SectorTable
# Get database session
db_session = next(get_db())
try:
# Build query with all relationships loaded
query = self.sql_session.query(InvestorTable).options(
query = (
db_session.query(InvestorTable)
.options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
)
# Apply filters if provided
if filters:
if "stage" in filters:
from db.models import InvestmentStage
stage_enum = getattr(InvestmentStage, filters["stage"])
query = query.filter(InvestorTable.stage_focus == stage_enum)
if "geography" in filters:
query = query.filter(
InvestorTable.geographic_focus.ilike(f"%{filters['geography']}%")
.filter(InvestorTable.id.in_(investor_ids))
)
if "min_check_size" in filters:
query = query.filter(
InvestorTable.check_size_lower >= filters["min_check_size"]
)
if "max_check_size" in filters:
query = query.filter(
InvestorTable.check_size_upper <= filters["max_check_size"]
)
if "min_aum" in filters:
query = query.filter(InvestorTable.aum >= filters["min_aum"])
if "max_aum" in filters:
query = query.filter(InvestorTable.aum <= filters["max_aum"])
if "sector" in filters:
query = query.join(InvestorTable.sectors).filter(
SectorTable.name.ilike(f"%{filters['sector']}%")
)
# Filter by IDs if provided
if investor_ids:
query = query.filter(InvestorTable.id.in_(investor_ids))
else:
# If no specific IDs and no filters, limit to prevent overwhelming response
if not filters:
query = query.limit(10)
investors = query.all()
# Transform to InvestorData format
@@ -278,3 +113,6 @@ class QueryProcessor:
investor_data_list.append(investor_data)
return InvestorList(investors=investor_data_list)
finally:
db_session.close()
-11
View File
@@ -1,11 +0,0 @@
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
OPENROUTER_API_KEY: str
class Config:
env_file = ".env"
settings = Settings()