made version 2
This commit is contained in:
@@ -1,577 +0,0 @@
|
||||
# LLM-Powered Investor & Company Management API
|
||||
|
||||
A comprehensive FastAPI-based system for managing investor and company data with LLM-powered CSV parsing, semantic search, and advanced filtering capabilities.
|
||||
|
||||
## Features
|
||||
|
||||
- **FastAPI REST API**: Modern, auto-documented API with OpenAPI/Swagger support
|
||||
- **CSV Data Processing**: Parse complex investor data from CSV files using LLM assistance
|
||||
- **Dual Database Storage**: Structured data in SQL database and semantic search via ChromaDB
|
||||
- **Natural Language Queries**: AI-powered query processing for complex investor searches
|
||||
- **Advanced Filtering**: Filter investors and companies by multiple criteria
|
||||
- **Relationship Management**: Many-to-many relationships between investors, companies, and sectors
|
||||
- **Auto-Generated Documentation**: Interactive API docs at `/docs`
|
||||
|
||||
## Architecture
|
||||
|
||||
### Components
|
||||
|
||||
1. **FastAPI Application (`app/main.py`)**: Main API server with route configuration
|
||||
2. **Database Models (`app/db/models.py`)**: SQLAlchemy models for investors, companies, sectors
|
||||
3. **Pydantic Schemas (`app/py_schemas.py`)**: Request/response validation and serialization
|
||||
4. **API Routes**:
|
||||
- `app/api/investors.py`: Investor CRUD operations and filtering
|
||||
- `app/api/companies.py`: Company CRUD operations and filtering
|
||||
5. **Services**:
|
||||
- `app/services/openrouter.py`: LLM-powered CSV processing
|
||||
- `app/services/querying.py`: Natural language query processing
|
||||
6. **Database (`app/db/`)**: Database connection, models, and schemas
|
||||
|
||||
### Data Flow
|
||||
|
||||
```
|
||||
CSV Upload → LLM Processing → Data Extraction → SQL Storage → Vector Storage → API Endpoints
|
||||
↓
|
||||
Natural Language Query → AI Analysis → Database Filtering → Structured Response
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Python 3.12+
|
||||
- FastAPI and dependencies
|
||||
|
||||
### Setup
|
||||
|
||||
1. Clone the repository and navigate to the project directory:
|
||||
|
||||
```bash
|
||||
cd /path/to/anton_wireframe
|
||||
```
|
||||
|
||||
2. Install dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
3. Configure environment variables:
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
# Edit .env and add your OpenRouter API key for LLM features
|
||||
```
|
||||
|
||||
4. Initialize the database:
|
||||
|
||||
```bash
|
||||
cd app
|
||||
python -c "from db.db import init_database; init_database()"
|
||||
```
|
||||
|
||||
5. Start the API server:
|
||||
|
||||
```bash
|
||||
cd app
|
||||
uvicorn main:app --reload --host localhost --port 8000
|
||||
```
|
||||
|
||||
The API will be available at:
|
||||
|
||||
- **API Base**: http://localhost:8000
|
||||
- **Interactive Docs**: http://localhost:8000/docs
|
||||
- **ReDoc**: http://localhost:8000/redoc
|
||||
|
||||
## Database Schema
|
||||
|
||||
### SQL Database (SQLite)
|
||||
|
||||
#### Investors Table
|
||||
|
||||
- **Basic Info**: name, description, geographic_focus
|
||||
- **Investment Data**: aum, check_size_lower, check_size_upper
|
||||
- **Stage Focus**: investment stage (SEED, SERIES_A, etc.)
|
||||
- **Relationships**: Many-to-many with companies and sectors
|
||||
- **Team**: One-to-many with team members
|
||||
- **Metadata**: created_at, updated_at timestamps
|
||||
|
||||
#### Companies Table
|
||||
|
||||
- **Basic Info**: name, industry, location
|
||||
- **Details**: founded_year, website
|
||||
- **Relationships**: Many-to-many with investors
|
||||
- **Metadata**: created_at, updated_at timestamps
|
||||
|
||||
#### Association Tables
|
||||
|
||||
- **investor_companies**: Links investors to their portfolio companies
|
||||
- **investor_sectors**: Links investors to their focus sectors
|
||||
- **investor_team**: Team member details for each investor
|
||||
|
||||
#### Supporting Tables
|
||||
|
||||
- **sectors**: Investment focus areas (fintech, healthcare, etc.)
|
||||
|
||||
### Vector Database (ChromaDB)
|
||||
|
||||
Stores embeddings for semantic search of:
|
||||
|
||||
- Investor descriptions
|
||||
- Investment thesis focus areas
|
||||
- Combined investor profiles
|
||||
|
||||
## API Usage
|
||||
|
||||
### Interactive Documentation
|
||||
|
||||
Visit http://localhost:8000/docs for the auto-generated Swagger UI where you can:
|
||||
|
||||
- Explore all endpoints
|
||||
- Test API calls directly
|
||||
- View request/response schemas
|
||||
- See example requests
|
||||
|
||||
### Core Endpoints
|
||||
|
||||
#### Investor Management
|
||||
|
||||
```bash
|
||||
# Get all investors with relationships
|
||||
GET /investors
|
||||
|
||||
# Filter investors by criteria
|
||||
GET /investors/filter?stage=GROWTH&geography=US§or=fintech&min_check_size=1000000
|
||||
|
||||
# Get specific investor
|
||||
GET /investors/{investor_id}
|
||||
|
||||
# Create new investor
|
||||
POST /investors
|
||||
{
|
||||
"name": "Example VC",
|
||||
"description": "Early stage fintech investor",
|
||||
"aum": 50000000,
|
||||
"check_size_lower": 100000,
|
||||
"check_size_upper": 2000000,
|
||||
"geographic_focus": "US",
|
||||
"stage_focus": "SEED",
|
||||
"number_of_investments": 25
|
||||
}
|
||||
|
||||
# Update investor
|
||||
PUT /investors/{investor_id}
|
||||
|
||||
# Delete investor
|
||||
DELETE /investors/{investor_id}
|
||||
```
|
||||
|
||||
#### Company Management
|
||||
|
||||
```bash
|
||||
# Get all companies with investor relationships
|
||||
GET /companies
|
||||
|
||||
# Filter companies by criteria
|
||||
GET /companies/filter?industry=fintech&location=San Francisco&founded_after=2015
|
||||
|
||||
# Get specific company
|
||||
GET /companies/{company_id}
|
||||
|
||||
# Create new company
|
||||
POST /companies
|
||||
{
|
||||
"name": "Example Startup",
|
||||
"industry": "fintech",
|
||||
"location": "San Francisco",
|
||||
"founded_year": 2020,
|
||||
"website": "https://example.com"
|
||||
}
|
||||
|
||||
# Update company
|
||||
PUT /companies/{company_id}
|
||||
|
||||
# Delete company
|
||||
DELETE /companies/{company_id}
|
||||
```
|
||||
|
||||
#### CSV Processing
|
||||
|
||||
```bash
|
||||
# Upload and process CSV file
|
||||
POST /parse-csv
|
||||
Content-Type: multipart/form-data
|
||||
File: investors.csv
|
||||
```
|
||||
|
||||
#### Natural Language Queries
|
||||
|
||||
```bash
|
||||
# Query investors using natural language
|
||||
POST /query
|
||||
{
|
||||
"question": "Show me growth stage fintech investors in Silicon Valley with check sizes over $1 million"
|
||||
}
|
||||
```
|
||||
|
||||
### Advanced Filtering Examples
|
||||
|
||||
#### Investor Filters
|
||||
|
||||
```bash
|
||||
# Early stage investors in Europe
|
||||
GET /investors/filter?stage=SEED&geography=Europe
|
||||
|
||||
# High AUM growth investors
|
||||
GET /investors/filter?stage=GROWTH&min_aum=100000000
|
||||
|
||||
# Healthcare investors with large checks
|
||||
GET /investors/filter?sector=healthcare&min_check_size=5000000
|
||||
|
||||
# Specific geographic focus
|
||||
GET /investors/filter?geography=Silicon Valley
|
||||
```
|
||||
|
||||
#### Company Filters
|
||||
|
||||
```bash
|
||||
# Recent fintech companies
|
||||
GET /companies/filter?industry=fintech&founded_after=2020
|
||||
|
||||
# Companies with websites
|
||||
GET /companies/filter?has_website=true
|
||||
|
||||
# Companies backed by specific investor
|
||||
GET /companies/filter?investor_name=Sequoia
|
||||
|
||||
# Location-based filtering
|
||||
GET /companies/filter?location=New York
|
||||
```
|
||||
|
||||
### Response Format
|
||||
|
||||
All endpoints return structured JSON with full relationship data:
|
||||
|
||||
```json
|
||||
{
|
||||
"investor": {
|
||||
"id": 1,
|
||||
"name": "Example VC",
|
||||
"description": "Early stage investor",
|
||||
"aum": 50000000,
|
||||
"check_size_lower": 100000,
|
||||
"check_size_upper": 2000000,
|
||||
"geographic_focus": "US",
|
||||
"stage_focus": "SEED",
|
||||
"number_of_investments": 25
|
||||
},
|
||||
"portfolio_companies": [
|
||||
{
|
||||
"id": 1,
|
||||
"name": "StartupCo",
|
||||
"industry": "fintech",
|
||||
"location": "San Francisco"
|
||||
}
|
||||
],
|
||||
"team_members": [
|
||||
{
|
||||
"id": 1,
|
||||
"name": "John Partner",
|
||||
"role": "Managing Partner",
|
||||
"email": "john@examplevc.com"
|
||||
}
|
||||
],
|
||||
"sectors": [
|
||||
{
|
||||
"id": 1,
|
||||
"name": "fintech"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Data Processing Pipeline
|
||||
|
||||
### 1. CSV Parsing
|
||||
|
||||
- Reads CSV with pandas
|
||||
- Handles nested JSON fields in columns
|
||||
- Validates data with Pydantic models
|
||||
|
||||
### 2. JSON Field Processing
|
||||
|
||||
- Direct parsing for well-formed JSON
|
||||
- LLM-assisted cleaning for malformed JSON (when enabled)
|
||||
- Graceful fallback to empty objects
|
||||
|
||||
### 3. Data Extraction
|
||||
|
||||
Extracts key fields:
|
||||
|
||||
- Company name and website
|
||||
- Investor description
|
||||
- Investment thesis/focus areas
|
||||
- Headquarters location
|
||||
- Assets Under Management (AUM)
|
||||
- Fund information
|
||||
|
||||
### 4. LLM Enhancement (Optional)
|
||||
|
||||
When `--use-llm` is enabled:
|
||||
|
||||
- Standardizes investor descriptions
|
||||
- Normalizes investment focus areas
|
||||
- Cleans headquarters location format
|
||||
- Repairs malformed JSON data
|
||||
|
||||
### 5. Dual Storage
|
||||
|
||||
- **SQL Database**: Structured, queryable data
|
||||
- **Vector Database**: Semantic search capabilities
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables (.env)
|
||||
|
||||
```bash
|
||||
# OpenRouter API Configuration (required for LLM features)
|
||||
OPENROUTER_API_KEY=your_openrouter_api_key_here
|
||||
|
||||
# Database Configuration (optional, defaults to SQLite)
|
||||
DATABASE_URL=sqlite:///investors.db
|
||||
|
||||
# FastAPI Configuration
|
||||
API_HOST=localhost
|
||||
API_PORT=8000
|
||||
```
|
||||
|
||||
### LLM Configuration
|
||||
|
||||
- **Provider**: OpenRouter (supports multiple models)
|
||||
- **Default Model**: google/gemini-2.5-flash-lite
|
||||
- **Temperature**: 0.3 for enhancement, 0 for structured data
|
||||
- **Fallback**: Graceful degradation when API unavailable
|
||||
|
||||
## Natural Language Query Processing
|
||||
|
||||
The system supports intelligent natural language queries that automatically extract filters and search criteria:
|
||||
|
||||
### Query Examples
|
||||
|
||||
```bash
|
||||
# Stage-based queries
|
||||
"Show me seed stage investors"
|
||||
"Find growth stage VCs"
|
||||
|
||||
# Geographic queries
|
||||
"Investors in Silicon Valley"
|
||||
"European venture capital firms"
|
||||
|
||||
# Sector-specific queries
|
||||
"Fintech investors"
|
||||
"Healthcare and biotech VCs"
|
||||
|
||||
# Size-based queries
|
||||
"Investors with $5M+ check sizes"
|
||||
"High AUM growth investors"
|
||||
|
||||
# Combined queries
|
||||
"Growth stage fintech investors in the US with check sizes over $1 million"
|
||||
"European healthcare investors focusing on early stage"
|
||||
```
|
||||
|
||||
### Query Processing Features
|
||||
|
||||
- **Automatic Filter Extraction**: Detects investment stages, geographies, sectors, and check sizes
|
||||
- **Semantic Understanding**: Uses AI to interpret complex queries
|
||||
- **Database Integration**: Combines AI analysis with efficient SQL filtering
|
||||
- **Complete Relationships**: Returns full investor data with portfolio companies, team members, and sectors
|
||||
|
||||
### Query Response
|
||||
|
||||
The `/query` endpoint returns a structured `InvestorList` with complete relationship data, making it easy to get comprehensive information about matching investors.
|
||||
|
||||
## Error Handling
|
||||
|
||||
### API Error Responses
|
||||
|
||||
The API provides clear HTTP status codes and error messages:
|
||||
|
||||
```json
|
||||
// 404 Not Found
|
||||
{
|
||||
"detail": "Investor not found"
|
||||
}
|
||||
|
||||
// 422 Validation Error
|
||||
{
|
||||
"detail": [
|
||||
{
|
||||
"loc": ["body", "stage_focus"],
|
||||
"msg": "value is not a valid enumeration member",
|
||||
"type": "type_error.enum"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Robust Processing
|
||||
|
||||
- **Data Validation**: Pydantic models ensure data integrity
|
||||
- **Relationship Management**: Automatic handling of foreign key constraints
|
||||
- **LLM Fallbacks**: Graceful degradation when AI services unavailable
|
||||
- **Transaction Safety**: Database rollbacks on errors
|
||||
- **Comprehensive Logging**: Detailed error tracking and debugging
|
||||
|
||||
### Common Issues and Solutions
|
||||
|
||||
1. **Invalid Enum Values**
|
||||
|
||||
- Solution: Use uppercase enum values (SEED, GROWTH, etc.)
|
||||
- Check: Investment stages must match defined enum
|
||||
|
||||
2. **Missing OpenRouter API Key**
|
||||
|
||||
- Solution: Set OPENROUTER_API_KEY in environment
|
||||
- Fallback: CSV processing continues without LLM enhancement
|
||||
|
||||
3. **Database Connection Issues**
|
||||
|
||||
- Solution: Verify DATABASE_URL configuration
|
||||
- Default: Uses SQLite (no external dependencies)
|
||||
|
||||
4. **Relationship Errors**
|
||||
- Solution: Ensure proper foreign key relationships
|
||||
- Check: Use existing sector/company IDs or create new ones
|
||||
|
||||
## Performance
|
||||
|
||||
### Benchmarks (Approximate)
|
||||
|
||||
- **API Response Time**: <200ms for standard queries
|
||||
- **Database Queries**: <50ms for filtered searches with relationships
|
||||
- **CSV Processing**: ~5-15 seconds per row (depends on LLM API latency)
|
||||
- **Natural Language Queries**: ~2-5 seconds (AI processing + database query)
|
||||
- **Vector Search**: <100ms for semantic similarity queries
|
||||
|
||||
### Optimization Features
|
||||
|
||||
1. **Eager Loading**: Efficient relationship loading with `selectinload()`
|
||||
2. **Query Optimization**: Smart filtering to reduce database load
|
||||
3. **Caching**: Database connection pooling and session management
|
||||
4. **Pagination**: Built-in limits to prevent overwhelming responses
|
||||
5. **Async Processing**: FastAPI async capabilities for better performance
|
||||
|
||||
### Production Recommendations
|
||||
|
||||
1. **Database**: Consider PostgreSQL for production workloads
|
||||
2. **Caching**: Add Redis for frequently accessed data
|
||||
3. **Load Balancing**: Deploy multiple API instances behind a load balancer
|
||||
4. **Monitoring**: Implement logging and metrics collection
|
||||
5. **Rate Limiting**: Add API rate limiting for public endpoints
|
||||
|
||||
## File Structure
|
||||
|
||||
```
|
||||
anton_wireframe/
|
||||
├── app/
|
||||
│ ├── main.py # FastAPI application and main endpoints
|
||||
│ ├── py_schemas.py # Pydantic models for validation
|
||||
│ ├── settings.py # Configuration management
|
||||
│ ├── api/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── investors.py # Investor CRUD and filtering endpoints
|
||||
│ │ └── companies.py # Company CRUD and filtering endpoints
|
||||
│ ├── db/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── db.py # Database connection and session management
|
||||
│ │ ├── models.py # SQLAlchemy database models
|
||||
│ │ └── new_schema.py # Additional schema definitions
|
||||
│ └── services/
|
||||
│ ├── __init__.py
|
||||
│ ├── openrouter.py # LLM-powered CSV processing
|
||||
│ ├── querying.py # Natural language query processing
|
||||
│ └── langgraph_agent.py # AI agent configuration
|
||||
├── chroma_db/ # Vector database directory
|
||||
├── requirements.txt # Python dependencies
|
||||
├── README.md # This documentation
|
||||
└── .env # Environment configuration
|
||||
```
|
||||
|
||||
## Example Usage Scenarios
|
||||
|
||||
### 1. Upload and Process Investor Data
|
||||
|
||||
```bash
|
||||
# Upload CSV file via API
|
||||
curl -X POST "http://localhost:8000/parse-csv" \
|
||||
-H "Content-Type: multipart/form-data" \
|
||||
-F "file=@investors.csv"
|
||||
```
|
||||
|
||||
### 2. Find Specific Investors
|
||||
|
||||
```bash
|
||||
# Natural language search
|
||||
curl -X POST "http://localhost:8000/query" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"question": "Show me growth stage fintech investors in Silicon Valley with check sizes over $2 million"}'
|
||||
|
||||
# Structured filtering
|
||||
curl "http://localhost:8000/investors/filter?stage=GROWTH§or=fintech&geography=Silicon%20Valley&min_check_size=2000000"
|
||||
```
|
||||
|
||||
### 3. Company Research
|
||||
|
||||
```bash
|
||||
# Find companies in specific sector
|
||||
curl "http://localhost:8000/companies/filter?industry=fintech&founded_after=2020"
|
||||
|
||||
# Find companies backed by specific investor
|
||||
curl "http://localhost:8000/companies/filter?investor_name=Sequoia"
|
||||
```
|
||||
|
||||
### 4. Investment Analysis
|
||||
|
||||
```bash
|
||||
# Get investor with full portfolio
|
||||
curl "http://localhost:8000/investors/1"
|
||||
|
||||
# Find all companies in a specific location
|
||||
curl "http://localhost:8000/companies/filter?location=San%20Francisco"
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
### Running in Development Mode
|
||||
|
||||
```bash
|
||||
cd app
|
||||
uvicorn main:app --reload --host localhost --port 8000
|
||||
```
|
||||
|
||||
### Testing the API
|
||||
|
||||
1. **Interactive Testing**: Visit http://localhost:8000/docs
|
||||
2. **Manual Testing**: Use curl or Postman with the examples above
|
||||
3. **Database Inspection**: Use SQLite browser to inspect `investors_2.db`
|
||||
|
||||
### Adding New Features
|
||||
|
||||
1. **New Endpoints**: Add routes to `api/investors.py` or `api/companies.py`
|
||||
2. **New Models**: Update `db/models.py` and `py_schemas.py`
|
||||
3. **New Filters**: Extend filtering logic in route handlers
|
||||
4. **New LLM Features**: Modify `services/openrouter.py` or `services/querying.py`
|
||||
|
||||
## License
|
||||
|
||||
This project is part of the MKD Anton Wireframe system.
|
||||
|
||||
## Support
|
||||
|
||||
For issues and questions:
|
||||
|
||||
1. Check logs for detailed error messages
|
||||
2. Verify environment configuration
|
||||
3. Test with limited datasets first
|
||||
4. Review CSV data format requirements
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,46 +0,0 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from db.models import InvestorTable
|
||||
from db.db import get_db
|
||||
|
||||
def update_stage_focus_values():
|
||||
"""Update existing stage_focus values from lowercase to uppercase"""
|
||||
db = next(get_db())
|
||||
|
||||
try:
|
||||
# Mapping of old lowercase values to new uppercase values
|
||||
stage_mappings = {
|
||||
'seed': 'SEED',
|
||||
'series_a': 'SERIES_A',
|
||||
'series_b': 'SERIES_B',
|
||||
'series_c': 'SERIES_C',
|
||||
'growth': 'GROWTH',
|
||||
'late_stage': 'LATE_STAGE'
|
||||
}
|
||||
|
||||
updated_count = 0
|
||||
|
||||
for old_value, new_value in stage_mappings.items():
|
||||
# Update records with the old value
|
||||
result = db.query(InvestorTable).filter(
|
||||
InvestorTable.stage_focus == old_value
|
||||
).update(
|
||||
{InvestorTable.stage_focus: new_value},
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
updated_count += result
|
||||
print(f"Updated {result} records from '{old_value}' to '{new_value}'")
|
||||
|
||||
db.commit()
|
||||
print(f"Successfully updated {updated_count} total records")
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
print(f"Error updating stage_focus values: {e}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Run the update
|
||||
if __name__ == "__main__":
|
||||
update_stage_focus_values()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+5
-1
@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session, sessionmaker
|
||||
Base = declarative_base()
|
||||
|
||||
# Database configuration
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///investors.db")
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./investors.db")
|
||||
|
||||
# Create engine
|
||||
engine = create_engine(DATABASE_URL, echo=False)
|
||||
@@ -38,3 +38,7 @@ def init_database():
|
||||
def get_session_sync() -> Session:
|
||||
"""Get a database session for synchronous operations"""
|
||||
return SessionLocal()
|
||||
|
||||
def get_db_session():
|
||||
"""Get a database session for direct use."""
|
||||
return SessionLocal()
|
||||
|
||||
+54
-30
@@ -1,11 +1,17 @@
|
||||
import datetime
|
||||
import enum
|
||||
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Table, Text
|
||||
from sqlalchemy.orm import relationship
|
||||
from db.db import Base
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Table, Text, func
|
||||
from sqlalchemy.orm import declarative_mixin, relationship
|
||||
from sqlalchemy.types import Enum
|
||||
|
||||
from db.db import Base
|
||||
|
||||
@declarative_mixin
|
||||
class TimestampMixin:
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
|
||||
class InvestmentStage(enum.Enum):
|
||||
@@ -16,6 +22,7 @@ class InvestmentStage(enum.Enum):
|
||||
GROWTH = "GROWTH"
|
||||
LATE_STAGE = "LATE_STAGE"
|
||||
|
||||
|
||||
# Association table for many-to-many relationship between investors and companies
|
||||
investor_company_association = Table(
|
||||
"investor_companies",
|
||||
@@ -34,7 +41,15 @@ investor_sector_association = Table(
|
||||
)
|
||||
|
||||
|
||||
class InvestorTable(Base):
|
||||
company_sector_association = Table(
|
||||
"company_sector",
|
||||
Base.metadata,
|
||||
Column("company_id", Integer, ForeignKey("companies.id")),
|
||||
Column("sector_id", Integer, ForeignKey("sectors.id")),
|
||||
)
|
||||
|
||||
|
||||
class InvestorTable(Base, TimestampMixin):
|
||||
__tablename__ = "investors"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
@@ -46,12 +61,6 @@ class InvestorTable(Base):
|
||||
geographic_focus = Column(String, nullable=False)
|
||||
stage_focus = Column(Enum(InvestmentStage), nullable=False)
|
||||
number_of_investments = Column(Integer, default=0)
|
||||
created_at = Column(DateTime, default=datetime.datetime.now(datetime.UTC))
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
default=datetime.datetime.now(datetime.UTC),
|
||||
onupdate=datetime.datetime.now(datetime.UTC),
|
||||
)
|
||||
|
||||
# Relationship to portfolio companies
|
||||
portfolio_companies = relationship(
|
||||
@@ -59,7 +68,7 @@ class InvestorTable(Base):
|
||||
secondary=investor_company_association,
|
||||
back_populates="investors",
|
||||
)
|
||||
team_members = relationship("InvestorTeamMember", back_populates="investor")
|
||||
team_members = relationship("InvestorMember", back_populates="investor")
|
||||
sectors = relationship(
|
||||
"SectorTable",
|
||||
secondary=investor_sector_association,
|
||||
@@ -67,22 +76,29 @@ class InvestorTable(Base):
|
||||
)
|
||||
|
||||
|
||||
class CompanyTable(Base):
|
||||
class InvestorMember(Base, TimestampMixin):
|
||||
__tablename__ = "investor_members"
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
role = Column(String, nullable=False)
|
||||
email = Column(String, nullable=False)
|
||||
|
||||
investor_id = Column(Integer, ForeignKey("investors.id"))
|
||||
investor = relationship("InvestorTable", back_populates="team_members")
|
||||
|
||||
|
||||
class CompanyTable(Base, TimestampMixin):
|
||||
__tablename__ = "companies"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
industry = Column(String, nullable=False)
|
||||
location = Column(String, nullable=False)
|
||||
description = Column(String, nullable=True)
|
||||
founded_year = Column(Integer, nullable=True)
|
||||
website = Column(String, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.datetime.now(datetime.UTC))
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
default=datetime.datetime.now(datetime.UTC),
|
||||
onupdate=datetime.datetime.now(datetime.UTC),
|
||||
)
|
||||
|
||||
members = relationship("CompanyMember", back_populates="company")
|
||||
# Relationship back to investors
|
||||
investors = relationship(
|
||||
"InvestorTable",
|
||||
@@ -90,8 +106,23 @@ class CompanyTable(Base):
|
||||
back_populates="portfolio_companies",
|
||||
)
|
||||
|
||||
sectors = relationship(
|
||||
"SectorTable", secondary=company_sector_association, back_populates="companies"
|
||||
)
|
||||
|
||||
class SectorTable(Base):
|
||||
|
||||
class CompanyMember(Base, TimestampMixin):
|
||||
__tablename__ = "company_members"
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String)
|
||||
linkedin = Column(String)
|
||||
role = Column(String)
|
||||
company_id = Column(Integer, ForeignKey("companies.id"), nullable=False)
|
||||
|
||||
company = relationship("CompanyTable", back_populates="members")
|
||||
|
||||
|
||||
class SectorTable(Base, TimestampMixin):
|
||||
__tablename__ = "sectors"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
@@ -104,13 +135,6 @@ class SectorTable(Base):
|
||||
back_populates="sectors",
|
||||
)
|
||||
|
||||
|
||||
class InvestorTeamMember(Base):
|
||||
__tablename__ = "investor_team"
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
role = Column(String, nullable=False)
|
||||
email = Column(String, nullable=False)
|
||||
|
||||
investor_id = Column(Integer, ForeignKey("investors.id"))
|
||||
investor = relationship("InvestorTable", back_populates="team_members")
|
||||
companies = relationship(
|
||||
"CompanyTable", secondary=company_sector_association, back_populates="sectors"
|
||||
)
|
||||
|
||||
@@ -1,115 +0,0 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import JSON, Column, DateTime, Integer, String, Text
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class Investor(Base):
|
||||
__tablename__ = "investors"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
name = Column(String(500), nullable=False)
|
||||
website = Column(String(1000))
|
||||
|
||||
# Core investment information
|
||||
investor_description = Column(Text)
|
||||
investment_thesis_focus = Column(JSON) # List of focus areas
|
||||
headquarters = Column(String(1000))
|
||||
|
||||
# AUM information
|
||||
aum_amount = Column(String(200))
|
||||
aum_as_of_date = Column(String(100))
|
||||
aum_source_url = Column(String(1000))
|
||||
|
||||
# Fund information
|
||||
funds_info = Column(JSON) # Complex fund data
|
||||
|
||||
# Raw data columns for reference
|
||||
crunchbase_urls = Column(Text)
|
||||
crunchbase_extract = Column(Text)
|
||||
linkedin_profile = Column(Text)
|
||||
source_truth_profile = Column(Text)
|
||||
|
||||
# Metadata
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Investor(name='{self.name}', website='{self.website}')>"
|
||||
|
||||
|
||||
# Pydantic models for data validation and parsing
|
||||
class AUMInfo(BaseModel):
|
||||
aumAmount: Optional[str] = None
|
||||
asOfDate: Optional[str] = None
|
||||
sourceUrl: Optional[str] = None
|
||||
|
||||
|
||||
class FundInfo(BaseModel):
|
||||
fundName: Optional[str] = None
|
||||
fundSize: Optional[str] = None
|
||||
vintage: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class InvestorProfile(BaseModel):
|
||||
websiteURL: Optional[str] = None
|
||||
investorDescription: Optional[str] = None
|
||||
investmentThesisFocus: Optional[List[str]] = None
|
||||
headquarters: Optional[str] = None
|
||||
overallAssetsUnderManagement: Optional[AUMInfo] = None
|
||||
funds: Optional[List[FundInfo]] = None
|
||||
|
||||
|
||||
class CSVRow(BaseModel):
|
||||
name: str
|
||||
website: Optional[str] = None
|
||||
investment_firm_profile: Optional[str] = None
|
||||
crunchbase_linkedin_urls: Optional[str] = None
|
||||
crunchbase_firm_extract: Optional[str] = None
|
||||
linkedin_investment_profile: Optional[str] = None
|
||||
source_of_truth_profile: Optional[str] = None
|
||||
|
||||
def get_combined_description(self) -> str:
|
||||
"""Combine all description fields for vector embedding"""
|
||||
descriptions = []
|
||||
|
||||
if self.investment_firm_profile:
|
||||
try:
|
||||
profile_data = json.loads(self.investment_firm_profile)
|
||||
if isinstance(profile_data, dict):
|
||||
desc = profile_data.get("investorDescription", "")
|
||||
if desc:
|
||||
descriptions.append(desc)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
if self.crunchbase_firm_extract:
|
||||
descriptions.append(self.crunchbase_firm_extract)
|
||||
|
||||
if self.linkedin_investment_profile:
|
||||
descriptions.append(self.linkedin_investment_profile)
|
||||
|
||||
if self.source_of_truth_profile:
|
||||
descriptions.append(self.source_of_truth_profile)
|
||||
|
||||
return " ".join(descriptions)
|
||||
|
||||
def get_investment_focus(self) -> List[str]:
|
||||
"""Extract investment thesis focus"""
|
||||
if self.investment_firm_profile:
|
||||
try:
|
||||
profile_data = json.loads(self.investment_firm_profile)
|
||||
if isinstance(profile_data, dict):
|
||||
focus = profile_data.get("investmentThesisFocus", [])
|
||||
if isinstance(focus, list):
|
||||
return focus
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
return []
|
||||
+18
-11
@@ -1,17 +1,20 @@
|
||||
import io
|
||||
|
||||
import pandas as pd
|
||||
from api import companies, investors
|
||||
from db.db import db_dependency, init_database
|
||||
from fastapi import FastAPI, File, UploadFile
|
||||
from py_schemas import InvestorList
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, File, Form, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from services.openrouter_v2 import InvestorProcessor
|
||||
from routers import companies, investors
|
||||
from schemas.router_schemas import InvestorList
|
||||
from services.llm_parser import InvestorProcessor
|
||||
from services.querying import QueryProcessor
|
||||
|
||||
app = FastAPI()
|
||||
load_dotenv()
|
||||
init_database()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
# Request models
|
||||
class QueryRequest(BaseModel):
|
||||
@@ -20,7 +23,7 @@ class QueryRequest(BaseModel):
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"question": "Show me growth stage fintech investors in the US with check sizes over $1 million"
|
||||
"question": "Find me deep tech investors that do deals in Europe under 5 million."
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,21 +34,25 @@ def health():
|
||||
|
||||
|
||||
@app.post("/parse-csv", tags=["CSV Upload"], response_model=list[dict])
|
||||
async def parse_csv(db: db_dependency, file: UploadFile = File(...)):
|
||||
async def parse_csv(db: db_dependency, file: UploadFile = File(...), is_investor: int = Form(...)):
|
||||
# Read uploaded CSV with pandas
|
||||
content = await file.read()
|
||||
df = pd.read_csv(io.StringIO(content.decode("utf-8")))
|
||||
|
||||
# Process the dataframe
|
||||
processor = InvestorProcessor(sql_session=db)
|
||||
results = await processor.process_csv(df)
|
||||
processor = InvestorProcessor()
|
||||
|
||||
if is_investor == 1:
|
||||
results = await processor.parse_investors(df)
|
||||
else:
|
||||
results = await processor.parse_companies(df)
|
||||
|
||||
# Convert Pydantic objects to dictionaries
|
||||
return [r.model_dump() for r in results]
|
||||
|
||||
|
||||
@app.post("/query", response_model=InvestorList, tags=["Querying"])
|
||||
async def query_investors(db: db_dependency, request: QueryRequest):
|
||||
async def query_investors(request: QueryRequest):
|
||||
"""
|
||||
Query investors using natural language.
|
||||
|
||||
@@ -55,7 +62,7 @@ async def query_investors(db: db_dependency, request: QueryRequest):
|
||||
- "Growth stage investors with $5M+ check sizes"
|
||||
- "Healthcare investors in Europe"
|
||||
"""
|
||||
processor = QueryProcessor(sql_session=db)
|
||||
processor = QueryProcessor()
|
||||
results = processor.process_query(request.question)
|
||||
return results
|
||||
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Investor(BaseModel):
|
||||
name: str
|
||||
aum: int
|
||||
check_size: str
|
||||
sector_focus: str
|
||||
stage_focus: str
|
||||
region: str
|
||||
investment_thesis: str
|
||||
investor_description: str
|
||||
|
||||
|
||||
class InvestorList(BaseModel):
|
||||
investor_list: List[Investor]
|
||||
|
||||
|
||||
class QueryResponse(BaseModel):
|
||||
name: str
|
||||
aum: int
|
||||
check_size: str
|
||||
sector_focus: str
|
||||
stage_focus: str
|
||||
region: str
|
||||
investment_thesis: str
|
||||
investor_description: str
|
||||
reason: str
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
question: str
|
||||
|
||||
|
||||
class QueryResponseList(BaseModel):
|
||||
responses: List[QueryResponse]
|
||||
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
@@ -3,8 +3,8 @@ from typing import List, Optional
|
||||
from db.db import get_db
|
||||
from db.models import CompanyTable, InvestorTable
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from py_schemas import CompanySchema
|
||||
from pydantic import BaseModel
|
||||
from schemas.router_schemas import CompanyData
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
router = APIRouter(tags=["Company Routes"])
|
||||
@@ -15,6 +15,7 @@ class CompanyCreate(BaseModel):
|
||||
name: str
|
||||
industry: str
|
||||
location: str
|
||||
description: Optional[str] = None
|
||||
founded_year: Optional[int] = None
|
||||
website: Optional[str] = None
|
||||
|
||||
@@ -23,46 +24,33 @@ class CompanyUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
industry: Optional[str] = None
|
||||
location: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
founded_year: Optional[int] = None
|
||||
website: Optional[str] = None
|
||||
|
||||
|
||||
# Response schema with relationships
|
||||
class CompanyData(BaseModel):
|
||||
"""Comprehensive company data schema"""
|
||||
|
||||
company: CompanySchema
|
||||
investors: List["InvestorBasic"] = []
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestorBasic(BaseModel):
|
||||
"""Basic investor info for company responses"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
geographic_focus: str
|
||||
stage_focus: str
|
||||
check_size_lower: int
|
||||
check_size_upper: int
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
@router.get("/companies", response_model=List[CompanyData])
|
||||
def read_companies(db: Session = Depends(get_db)):
|
||||
"""Get all companies with their investor relationships"""
|
||||
companies = (
|
||||
db.query(CompanyTable).options(selectinload(CompanyTable.investors)).all()
|
||||
db.query(CompanyTable)
|
||||
.options(
|
||||
selectinload(CompanyTable.investors),
|
||||
selectinload(CompanyTable.members),
|
||||
selectinload(CompanyTable.sectors),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Transform CompanyTable objects to CompanyData format
|
||||
company_data_list = []
|
||||
for company in companies:
|
||||
company_data = CompanyData(company=company, investors=company.investors)
|
||||
company_data = CompanyData(
|
||||
company=company,
|
||||
investors=company.investors,
|
||||
members=company.members,
|
||||
sectors=company.sectors,
|
||||
)
|
||||
company_data_list.append(company_data)
|
||||
|
||||
return company_data_list
|
||||
@@ -89,7 +77,11 @@ def filter_companies(
|
||||
"""Filter companies based on various criteria"""
|
||||
|
||||
# Start with base query
|
||||
query = db.query(CompanyTable).options(selectinload(CompanyTable.investors))
|
||||
query = db.query(CompanyTable).options(
|
||||
selectinload(CompanyTable.investors),
|
||||
selectinload(CompanyTable.members),
|
||||
selectinload(CompanyTable.sectors),
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if industry:
|
||||
@@ -121,7 +113,12 @@ def filter_companies(
|
||||
# Transform to CompanyData format
|
||||
company_data_list = []
|
||||
for company in companies:
|
||||
company_data = CompanyData(company=company, investors=company.investors)
|
||||
company_data = CompanyData(
|
||||
company=company,
|
||||
investors=company.investors,
|
||||
members=company.members,
|
||||
sectors=company.sectors,
|
||||
)
|
||||
company_data_list.append(company_data)
|
||||
|
||||
return company_data_list
|
||||
@@ -132,7 +129,11 @@ def read_company(company_id: int, db: Session = Depends(get_db)):
|
||||
"""Get a specific company by ID with its investors"""
|
||||
company = (
|
||||
db.query(CompanyTable)
|
||||
.options(selectinload(CompanyTable.investors))
|
||||
.options(
|
||||
selectinload(CompanyTable.investors),
|
||||
selectinload(CompanyTable.members),
|
||||
selectinload(CompanyTable.sectors),
|
||||
)
|
||||
.filter(CompanyTable.id == company_id)
|
||||
.first()
|
||||
)
|
||||
@@ -141,7 +142,12 @@ def read_company(company_id: int, db: Session = Depends(get_db)):
|
||||
raise HTTPException(status_code=404, detail="Company not found")
|
||||
|
||||
# Transform to CompanyData format
|
||||
return CompanyData(company=company, investors=company.investors)
|
||||
return CompanyData(
|
||||
company=company,
|
||||
investors=company.investors,
|
||||
members=company.members,
|
||||
sectors=company.sectors,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/companies", response_model=CompanyData)
|
||||
@@ -155,14 +161,21 @@ def create_company(company: CompanyCreate, db: Session = Depends(get_db)):
|
||||
# Reload with relationships
|
||||
company_with_relations = (
|
||||
db.query(CompanyTable)
|
||||
.options(selectinload(CompanyTable.investors))
|
||||
.options(
|
||||
selectinload(CompanyTable.investors),
|
||||
selectinload(CompanyTable.members),
|
||||
selectinload(CompanyTable.sectors),
|
||||
)
|
||||
.filter(CompanyTable.id == db_company.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Transform to CompanyData format
|
||||
return CompanyData(
|
||||
company=company_with_relations, investors=company_with_relations.investors
|
||||
company=company_with_relations,
|
||||
investors=company_with_relations.investors,
|
||||
members=company_with_relations.members,
|
||||
sectors=company_with_relations.sectors,
|
||||
)
|
||||
|
||||
|
||||
@@ -185,14 +198,21 @@ def update_company(
|
||||
# Reload with relationships
|
||||
company_with_relations = (
|
||||
db.query(CompanyTable)
|
||||
.options(selectinload(CompanyTable.investors))
|
||||
.options(
|
||||
selectinload(CompanyTable.investors),
|
||||
selectinload(CompanyTable.members),
|
||||
selectinload(CompanyTable.sectors),
|
||||
)
|
||||
.filter(CompanyTable.id == company_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Transform to CompanyData format
|
||||
return CompanyData(
|
||||
company=company_with_relations, investors=company_with_relations.investors
|
||||
company=company_with_relations,
|
||||
investors=company_with_relations.investors,
|
||||
members=company_with_relations.members,
|
||||
sectors=company_with_relations.sectors,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from db.db import get_db
|
||||
from db.models import InvestorTable, SectorTable
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from py_schemas import InvestmentStage, InvestorData
|
||||
from schemas.router_schemas import InvestmentStage, InvestorData
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
@@ -13,7 +14,7 @@ router = APIRouter(tags=["Investor Routes"])
|
||||
# Request schemas for creating/updating
|
||||
class InvestorCreate(BaseModel):
|
||||
name: str
|
||||
description: str = None
|
||||
description: Optional[str] = None
|
||||
aum: int
|
||||
check_size_lower: int
|
||||
check_size_upper: int
|
||||
@@ -23,14 +24,14 @@ class InvestorCreate(BaseModel):
|
||||
|
||||
|
||||
class InvestorUpdate(BaseModel):
|
||||
name: str = None
|
||||
description: str = None
|
||||
aum: int = None
|
||||
check_size_lower: int = None
|
||||
check_size_upper: int = None
|
||||
geographic_focus: str = None
|
||||
stage_focus: InvestmentStage = None
|
||||
number_of_investments: int = None
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
aum: Optional[int] = None
|
||||
check_size_lower: Optional[int] = None
|
||||
check_size_upper: Optional[int] = None
|
||||
geographic_focus: Optional[str] = None
|
||||
stage_focus: Optional[InvestmentStage] = None
|
||||
number_of_investments: Optional[int] = None
|
||||
|
||||
|
||||
@router.get("/investors", response_model=List[InvestorData])
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,111 @@
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class InvestmentStage(str, Enum):
|
||||
SEED = "SEED"
|
||||
SERIES_A = "SERIES_A"
|
||||
SERIES_B = "SERIES_B"
|
||||
SERIES_C = "SERIES_C"
|
||||
GROWTH = "GROWTH"
|
||||
LATE_STAGE = "LATE_STAGE"
|
||||
|
||||
|
||||
class SectorSchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestorMemberSchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
role: str
|
||||
email: str
|
||||
investor_id: int
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class CompanyMemberSchema(BaseModel):
|
||||
id: int
|
||||
name: Optional[str] = None
|
||||
linkedin: Optional[str] = None
|
||||
role: Optional[str] = None
|
||||
company_id: int
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class CompanySchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
industry: str
|
||||
location: str
|
||||
description: Optional[str] = None # Fixed typo from 'nullabel'
|
||||
founded_year: Optional[int] = None # Changed from str to int to match model
|
||||
website: Optional[str] = None
|
||||
|
||||
@field_validator("founded_year", mode="before")
|
||||
@classmethod
|
||||
def validate_founded_year(cls, v):
|
||||
if v is None or v == "Not Available" or v == "":
|
||||
return None
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return int(v)
|
||||
except ValueError:
|
||||
return None
|
||||
return v
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestorSchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
aum: int
|
||||
check_size_lower: int
|
||||
check_size_upper: int
|
||||
geographic_focus: str
|
||||
stage_focus: InvestmentStage
|
||||
number_of_investments: int = 0
|
||||
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestorData(BaseModel):
|
||||
"""Comprehensive investor data schema for LLM processing"""
|
||||
|
||||
investor: InvestorSchema
|
||||
portfolio_companies: List[CompanySchema] = []
|
||||
team_members: List[InvestorMemberSchema] = [] # Changed from TeamMember
|
||||
sectors: List[SectorSchema] = []
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class CompanyData(BaseModel): # Renamed from CompaniesData for consistency
|
||||
company: CompanySchema
|
||||
sectors: List[SectorSchema] = []
|
||||
members: List[CompanyMemberSchema] = [] # Changed to match model relationship name
|
||||
investors: List[InvestorSchema] = []
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestorList(BaseModel):
|
||||
investors: List[InvestorData] = []
|
||||
|
||||
@@ -22,11 +22,31 @@ class SectorSchema(BaseModel):
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestorMemberSchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
role: str
|
||||
email: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class CompanyMemberSchema(BaseModel):
|
||||
id: int
|
||||
name: Optional[str] = None
|
||||
linkedin: Optional[str] = None
|
||||
role: Optional[str] = None
|
||||
company_id: int
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class CompanySchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
industry: str
|
||||
location: str
|
||||
description: Optional[str]
|
||||
founded_year: Optional[int]
|
||||
website: Optional[str]
|
||||
created_at: Optional[datetime]
|
||||
@@ -36,15 +56,6 @@ class CompanySchema(BaseModel):
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestorTeamMemberSchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
role: str
|
||||
email: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestorSchema(BaseModel):
|
||||
id: int
|
||||
@@ -67,13 +78,22 @@ class InvestorData(BaseModel):
|
||||
"""Comprehensive investor data schema for LLM processing"""
|
||||
|
||||
investor: InvestorSchema
|
||||
portfolio_companies: List[CompanySchema] = []
|
||||
team_members: List[InvestorTeamMemberSchema] = []
|
||||
sectors: List[SectorSchema] = []
|
||||
portfolio_companies: List[CompanySchema]
|
||||
team_members: List[InvestorMemberSchema]
|
||||
sectors: List[SectorSchema]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class CompanyData(BaseModel): # Renamed from CompaniesData for consistency
|
||||
company: CompanySchema
|
||||
sectors: List[SectorSchema]
|
||||
members: List[CompanyMemberSchema]
|
||||
investors: List[InvestorSchema]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class InvestorList(BaseModel):
|
||||
investors: List[InvestorData]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+305
-336
@@ -1,368 +1,337 @@
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
import chromadb
|
||||
import pandas as pd
|
||||
from dotenv import load_dotenv
|
||||
from openai import OpenAI
|
||||
|
||||
from db import get_session, init_database
|
||||
from py_schemas import CSVRow, Investor
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
from db.db import get_db_session
|
||||
from db.models import (
|
||||
CompanyMember,
|
||||
CompanyTable,
|
||||
InvestorMember,
|
||||
InvestorTable,
|
||||
SectorTable,
|
||||
)
|
||||
from langchain_openai import ChatOpenAI
|
||||
from schemas.py_schemas import CompanyData, InvestorData
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class LLMInvestorParser:
|
||||
class InvestorProcessor:
|
||||
def __init__(self):
|
||||
# Initialize OpenAI client
|
||||
self.openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
# Initialize ChromaDB
|
||||
self.chroma_client = chromadb.PersistentClient(path="./chroma_db")
|
||||
self.collection = self.chroma_client.get_or_create_collection(
|
||||
name="investor_descriptions",
|
||||
metadata={
|
||||
"description": "Investor descriptions and investment thesis focus"
|
||||
},
|
||||
)
|
||||
|
||||
# Initialize database
|
||||
init_database()
|
||||
|
||||
def parse_json_field(self, json_str: str) -> Dict[str, Any]:
|
||||
"""Safely parse JSON string with LLM assistance if needed"""
|
||||
if not json_str or json_str.strip() == "":
|
||||
return {}
|
||||
|
||||
try:
|
||||
# Try direct JSON parsing first
|
||||
return json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
# If direct parsing fails, use LLM to clean and parse
|
||||
logger.info("Direct JSON parsing failed, using LLM to clean JSON")
|
||||
return self._llm_clean_json(json_str)
|
||||
|
||||
def _llm_clean_json(self, malformed_json: str) -> Dict[str, Any]:
|
||||
"""Use LLM to clean and parse malformed JSON"""
|
||||
try:
|
||||
prompt = f"""
|
||||
The following text appears to be malformed JSON. Please clean it up and return valid JSON.
|
||||
If it's not possible to create valid JSON, return an empty object {{}}.
|
||||
|
||||
Original text:
|
||||
{malformed_json[:2000]} # Limit length for API
|
||||
|
||||
Return only the cleaned JSON, no explanations:
|
||||
"""
|
||||
|
||||
response = self.openai_client.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
self.llm = ChatOpenAI(
|
||||
api_key=os.getenv("OPENROUTER_API_KEY"),
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="openai/gpt-5-nano",
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
cleaned_json = response.choices[0].message.content.strip()
|
||||
return json.loads(cleaned_json)
|
||||
self.investor_structured_llm = self.llm.with_structured_output(InvestorData)
|
||||
self.company_structured_llm = self.llm.with_structured_output(CompanyData)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM JSON cleaning failed: {e}")
|
||||
return {}
|
||||
def _get_or_create_sector(self, db: Session, sector_name: str) -> SectorTable:
|
||||
"""Get existing sector or create new one"""
|
||||
sector = db.query(SectorTable).filter(SectorTable.name == sector_name).first()
|
||||
if not sector:
|
||||
sector = SectorTable(name=sector_name)
|
||||
db.add(sector)
|
||||
db.flush() # Get the ID without committing
|
||||
return sector
|
||||
|
||||
def extract_structured_data(self, csv_row: CSVRow) -> Dict[str, Any]:
|
||||
"""Extract and structure data from CSV row using LLM"""
|
||||
# Parse the investment firm profile
|
||||
profile_data = {}
|
||||
if csv_row.investment_firm_profile:
|
||||
profile_data = self.parse_json_field(csv_row.investment_firm_profile)
|
||||
|
||||
# Create structured output
|
||||
structured_data = {
|
||||
"name": csv_row.name,
|
||||
"website": csv_row.website or profile_data.get("websiteURL"),
|
||||
"investor_description": profile_data.get("investorDescription", ""),
|
||||
"investment_thesis_focus": profile_data.get("investmentThesisFocus", []),
|
||||
"headquarters": profile_data.get("headquarters", ""),
|
||||
"aum_info": profile_data.get("overallAssetsUnderManagement", {}),
|
||||
"funds_info": profile_data.get("funds", []),
|
||||
"crunchbase_urls": csv_row.crunchbase_linkedin_urls or "",
|
||||
"crunchbase_extract": csv_row.crunchbase_firm_extract or "",
|
||||
"linkedin_profile": csv_row.linkedin_investment_profile or "",
|
||||
"source_truth_profile": csv_row.source_of_truth_profile or "",
|
||||
}
|
||||
|
||||
return structured_data
|
||||
|
||||
def enhance_with_llm(self, investor_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Use LLM to enhance and standardize investor data"""
|
||||
try:
|
||||
# Combine all available text for context
|
||||
context_text = " ".join(
|
||||
[
|
||||
investor_data.get("investor_description", ""),
|
||||
investor_data.get("crunchbase_extract", ""),
|
||||
investor_data.get("linkedin_profile", ""),
|
||||
investor_data.get("source_truth_profile", ""),
|
||||
]
|
||||
def _save_investor_to_db(
|
||||
self, db: Session, investor_data: InvestorData
|
||||
) -> InvestorTable:
|
||||
"""Save investor data to database"""
|
||||
# Create investor record
|
||||
investor = InvestorTable(
|
||||
name=investor_data.investor.name,
|
||||
description=investor_data.investor.description,
|
||||
aum=investor_data.investor.aum,
|
||||
check_size_lower=investor_data.investor.check_size_lower,
|
||||
check_size_upper=investor_data.investor.check_size_upper,
|
||||
geographic_focus=investor_data.investor.geographic_focus,
|
||||
stage_focus=investor_data.investor.stage_focus,
|
||||
number_of_investments=investor_data.investor.number_of_investments,
|
||||
)
|
||||
db.add(investor)
|
||||
db.flush() # Get the ID
|
||||
|
||||
if not context_text.strip():
|
||||
return investor_data
|
||||
|
||||
prompt = f"""
|
||||
Based on the following information about an investor, please extract and standardize:
|
||||
1. A concise investor description (2-3 sentences)
|
||||
2. Investment thesis focus areas (list of specific focus areas)
|
||||
3. Headquarters location (city, country format)
|
||||
|
||||
Investor: {investor_data["name"]}
|
||||
Context: {context_text[:3000]} # Limit for API
|
||||
|
||||
Return in JSON format:
|
||||
{{
|
||||
"enhanced_description": "concise description here",
|
||||
"standardized_focus": ["focus area 1", "focus area 2", ...],
|
||||
"standardized_headquarters": "City, Country"
|
||||
}}
|
||||
"""
|
||||
|
||||
response = self.openai_client.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.3,
|
||||
# Add team members
|
||||
for member_data in investor_data.team_members:
|
||||
member = InvestorMember(
|
||||
name=member_data.name,
|
||||
role=member_data.role,
|
||||
email=member_data.email,
|
||||
investor_id=investor.id,
|
||||
)
|
||||
db.add(member)
|
||||
|
||||
enhanced_data = json.loads(response.choices[0].message.content)
|
||||
# Add sectors
|
||||
for sector_data in investor_data.sectors:
|
||||
sector = self._get_or_create_sector(db, sector_data.name)
|
||||
investor.sectors.append(sector)
|
||||
|
||||
# Update investor data with enhanced information
|
||||
if enhanced_data.get("enhanced_description"):
|
||||
investor_data["enhanced_description"] = enhanced_data[
|
||||
"enhanced_description"
|
||||
]
|
||||
# Add portfolio companies
|
||||
for company_schema in investor_data.portfolio_companies:
|
||||
# Convert CompanySchema to CompanyData format
|
||||
company_data = CompanyData(
|
||||
company=company_schema,
|
||||
sectors=[], # Will be empty for portfolio companies
|
||||
members=[], # Will be empty for portfolio companies
|
||||
investors=[], # Will be empty for portfolio companies
|
||||
)
|
||||
company = self._save_company_to_db(db, company_data, skip_investors=True)
|
||||
investor.portfolio_companies.append(company)
|
||||
|
||||
if enhanced_data.get("standardized_focus"):
|
||||
investor_data["standardized_focus"] = enhanced_data[
|
||||
"standardized_focus"
|
||||
]
|
||||
return investor
|
||||
|
||||
if enhanced_data.get("standardized_headquarters"):
|
||||
investor_data["standardized_headquarters"] = enhanced_data[
|
||||
"standardized_headquarters"
|
||||
]
|
||||
|
||||
return investor_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM enhancement failed for {investor_data['name']}: {e}")
|
||||
return investor_data
|
||||
|
||||
def save_to_sql(self, investor_data: Dict[str, Any]) -> int:
|
||||
"""Save investor data to SQL database"""
|
||||
try:
|
||||
with get_session() as session:
|
||||
# Check if investor already exists
|
||||
existing = (
|
||||
session.query(Investor)
|
||||
.filter_by(name=investor_data["name"])
|
||||
def _save_company_to_db(
|
||||
self, db: Session, company_data: CompanyData, skip_investors: bool = False
|
||||
) -> CompanyTable:
|
||||
"""Save company data to database"""
|
||||
# Check if company already exists
|
||||
existing_company = (
|
||||
db.query(CompanyTable)
|
||||
.filter(CompanyTable.name == company_data.company.name)
|
||||
.first()
|
||||
)
|
||||
if existing_company:
|
||||
return existing_company
|
||||
|
||||
if existing:
|
||||
logger.info(f"Updating existing investor: {investor_data['name']}")
|
||||
investor = existing
|
||||
# Create company record
|
||||
company = CompanyTable(
|
||||
name=company_data.company.name,
|
||||
industry=company_data.company.industry,
|
||||
location=company_data.company.location,
|
||||
description=company_data.company.description,
|
||||
founded_year=company_data.company.founded_year,
|
||||
website=company_data.company.website,
|
||||
)
|
||||
db.add(company)
|
||||
db.flush() # Get the ID
|
||||
|
||||
# Add company members
|
||||
for member_data in company_data.members:
|
||||
if member_data.name: # Only add members with names
|
||||
member = CompanyMember(
|
||||
name=member_data.name,
|
||||
linkedin=member_data.linkedin,
|
||||
role=member_data.role,
|
||||
company_id=company.id,
|
||||
)
|
||||
db.add(member)
|
||||
|
||||
# Add sectors
|
||||
for sector_data in company_data.sectors:
|
||||
sector = self._get_or_create_sector(db, sector_data.name)
|
||||
company.sectors.append(sector)
|
||||
|
||||
# Add investors (if not skipping to avoid circular references)
|
||||
if not skip_investors:
|
||||
for investor_data in company_data.investors:
|
||||
# Look for existing investor by name
|
||||
existing_investor = (
|
||||
db.query(InvestorTable)
|
||||
.filter(InvestorTable.name == investor_data.name)
|
||||
.first()
|
||||
)
|
||||
if existing_investor:
|
||||
company.investors.append(existing_investor)
|
||||
|
||||
return company
|
||||
|
||||
async def _process_row(
|
||||
self, row: pd.Series, row_idx: int, is_investor: bool = True
|
||||
) -> Optional[InvestorData | CompanyData]:
|
||||
"""Process a single row of data"""
|
||||
# Clean values to remove control characters
|
||||
cleaned_row = {}
|
||||
for key, value in row.items():
|
||||
if pd.notna(value):
|
||||
# Convert to string and clean control characters
|
||||
clean_value = (
|
||||
str(value).replace("\n", " ").replace("\r", " ").replace("\t", " ")
|
||||
)
|
||||
# Remove other control characters
|
||||
clean_value = "".join(
|
||||
char
|
||||
for char in clean_value
|
||||
if ord(char) >= 32 or char in ["\n", "\r", "\t"]
|
||||
)
|
||||
cleaned_row[key] = clean_value
|
||||
|
||||
row_str = ", ".join([f"{key}: {value}" for key, value in cleaned_row.items()])
|
||||
try:
|
||||
print(f"Processing row {row_idx + 1}...")
|
||||
if is_investor:
|
||||
result = await self.investor_structured_llm.ainvoke(row_str)
|
||||
else:
|
||||
logger.info(f"Creating new investor: {investor_data['name']}")
|
||||
investor = Investor()
|
||||
|
||||
# Map data to investor object
|
||||
investor.name = investor_data["name"]
|
||||
investor.website = investor_data.get("website")
|
||||
investor.investor_description = investor_data.get(
|
||||
"enhanced_description"
|
||||
) or investor_data.get("investor_description")
|
||||
investor.investment_thesis_focus = investor_data.get(
|
||||
"standardized_focus"
|
||||
) or investor_data.get("investment_thesis_focus")
|
||||
investor.headquarters = investor_data.get(
|
||||
"standardized_headquarters"
|
||||
) or investor_data.get("headquarters")
|
||||
|
||||
# AUM information
|
||||
aum_info = investor_data.get("aum_info", {})
|
||||
investor.aum_amount = aum_info.get("aumAmount")
|
||||
investor.aum_as_of_date = aum_info.get("asOfDate")
|
||||
investor.aum_source_url = aum_info.get("sourceUrl")
|
||||
|
||||
# Fund information
|
||||
investor.funds_info = investor_data.get("funds_info", [])
|
||||
|
||||
# Raw data
|
||||
investor.crunchbase_urls = investor_data.get("crunchbase_urls")
|
||||
investor.crunchbase_extract = investor_data.get("crunchbase_extract")
|
||||
investor.linkedin_profile = investor_data.get("linkedin_profile")
|
||||
investor.source_truth_profile = investor_data.get(
|
||||
"source_truth_profile"
|
||||
)
|
||||
|
||||
if not existing:
|
||||
session.add(investor)
|
||||
|
||||
session.flush() # Get the ID
|
||||
return investor.id
|
||||
|
||||
result = await self.company_structured_llm.ainvoke(row_str)
|
||||
if result:
|
||||
return result.model_dump()
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save to SQL: {e}")
|
||||
raise
|
||||
|
||||
def save_to_vector_db(self, investor_id: int, investor_data: Dict[str, Any]):
|
||||
"""Save investor description and focus to ChromaDB"""
|
||||
try:
|
||||
# Prepare text for embedding
|
||||
description_text = investor_data.get(
|
||||
"enhanced_description"
|
||||
) or investor_data.get("investor_description", "")
|
||||
focus_areas = investor_data.get("standardized_focus") or investor_data.get(
|
||||
"investment_thesis_focus", []
|
||||
)
|
||||
|
||||
if isinstance(focus_areas, list):
|
||||
focus_text = " ".join(focus_areas)
|
||||
else:
|
||||
focus_text = str(focus_areas)
|
||||
|
||||
# Combine description and focus for embedding
|
||||
combined_text = f"{description_text} {focus_text}".strip()
|
||||
|
||||
if not combined_text:
|
||||
logger.warning(f"No text to embed for investor {investor_data['name']}")
|
||||
return
|
||||
|
||||
# Create metadata
|
||||
metadata = {
|
||||
"investor_id": investor_id,
|
||||
"name": investor_data["name"],
|
||||
"website": investor_data.get("website", ""),
|
||||
"headquarters": investor_data.get("standardized_headquarters")
|
||||
or investor_data.get("headquarters", ""),
|
||||
"focus_areas_count": len(focus_areas)
|
||||
if isinstance(focus_areas, list)
|
||||
else 0,
|
||||
}
|
||||
|
||||
# Add to ChromaDB
|
||||
self.collection.add(
|
||||
documents=[combined_text],
|
||||
metadatas=[metadata],
|
||||
ids=[f"investor_{investor_id}"],
|
||||
)
|
||||
|
||||
logger.info(f"Added investor {investor_data['name']} to vector database")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save to vector DB: {e}")
|
||||
|
||||
def process_csv_file(self, csv_file_path: str, limit: Optional[int] = None):
|
||||
"""Process the entire CSV file"""
|
||||
logger.info(f"Starting to process CSV file: {csv_file_path}")
|
||||
|
||||
# Read CSV
|
||||
df = pd.read_csv(csv_file_path)
|
||||
logger.info(f"Loaded {len(df)} rows from CSV")
|
||||
|
||||
if limit:
|
||||
df = df.head(limit)
|
||||
logger.info(f"Processing limited to {limit} rows")
|
||||
|
||||
processed_count = 0
|
||||
error_count = 0
|
||||
|
||||
for index, row in df.iterrows():
|
||||
try:
|
||||
logger.info(f"Processing row {index + 1}/{len(df)}: {row['Name']}")
|
||||
|
||||
# Create CSVRow object
|
||||
csv_row = CSVRow(
|
||||
name=row["Name"],
|
||||
website=row.get("Website"),
|
||||
investment_firm_profile=row.get("Investment Firm Profile"),
|
||||
crunchbase_linkedin_urls=row.get("Crunchbase & LinkedIn URLs"),
|
||||
crunchbase_firm_extract=row.get("Crunchbase Firm Extract"),
|
||||
linkedin_investment_profile=row.get("LinkedIn Investment Profile"),
|
||||
source_of_truth_profile=row.get("Source of Truth Profile"),
|
||||
)
|
||||
|
||||
# Extract structured data
|
||||
structured_data = self.extract_structured_data(csv_row)
|
||||
|
||||
# Enhance with LLM
|
||||
enhanced_data = self.enhance_with_llm(structured_data)
|
||||
|
||||
# Save to SQL database
|
||||
investor_id = self.save_to_sql(enhanced_data)
|
||||
|
||||
# Save to vector database
|
||||
self.save_to_vector_db(investor_id, enhanced_data)
|
||||
|
||||
processed_count += 1
|
||||
|
||||
# Progress update every 10 rows
|
||||
if (index + 1) % 10 == 0:
|
||||
logger.info(
|
||||
f"Processed {processed_count} rows successfully, {error_count} errors"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_count += 1
|
||||
logger.error(
|
||||
f"Error processing row {index + 1} ({row.get('Name', 'Unknown')}): {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"Processing complete! Processed: {processed_count}, Errors: {error_count}"
|
||||
)
|
||||
return processed_count, error_count
|
||||
|
||||
def search_investors(self, query: str, limit: int = 5):
|
||||
"""Search investors using vector similarity"""
|
||||
try:
|
||||
results = self.collection.query(query_texts=[query], n_results=limit)
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Search failed: {e}")
|
||||
print(f"Error processing row {row_idx + 1}: {e}")
|
||||
return None
|
||||
|
||||
async def parse_investors(self, df, save_to_db: bool = True):
|
||||
"""Parse investors from DataFrame and optionally save to database"""
|
||||
investors = []
|
||||
|
||||
def main():
|
||||
"""Main function to run the parser"""
|
||||
parser = LLMInvestorParser()
|
||||
db = None
|
||||
if save_to_db:
|
||||
db = get_db_session()
|
||||
|
||||
# Process the CSV file
|
||||
csv_file = "/home/oluwasanmi/Documents/Work/MKD/anton_wireframe/New Excerpt 5 investors - Sheet1 parse.csv"
|
||||
try:
|
||||
# Process rows in batches asynchronously
|
||||
batch_size = 15 # Adjust batch size as needed
|
||||
rows = [(idx, row) for idx, row in df.iterrows()]
|
||||
|
||||
# Start with a small sample for testing
|
||||
processed, errors = parser.process_csv_file(csv_file, limit=5)
|
||||
for i in range(0, len(rows), batch_size):
|
||||
batch = rows[i : i + batch_size]
|
||||
|
||||
print("\nProcessing complete!")
|
||||
print(f"Successfully processed: {processed} investors")
|
||||
print(f"Errors encountered: {errors}")
|
||||
# Process batch asynchronously
|
||||
tasks = [
|
||||
self._process_row(row, idx, is_investor=True) for idx, row in batch
|
||||
]
|
||||
|
||||
# Test search functionality
|
||||
print("\nTesting search functionality...")
|
||||
results = parser.search_investors("bioeconomy circular economy")
|
||||
if results:
|
||||
print(f"Found {len(results['documents'][0])} similar investors")
|
||||
for i, doc in enumerate(results["documents"][0]):
|
||||
print(f" {i + 1}. {results['metadatas'][0][i]['name']}")
|
||||
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Handle results from batch
|
||||
for (idx, row), result in zip(batch, batch_results):
|
||||
if isinstance(result, Exception):
|
||||
print(f"Error processing row {idx}: {result}")
|
||||
if db:
|
||||
db.rollback()
|
||||
continue
|
||||
|
||||
if result:
|
||||
# Convert dict to InvestorData if needed
|
||||
if isinstance(result, dict):
|
||||
investor_data = InvestorData(**result)
|
||||
else:
|
||||
investor_data = result
|
||||
|
||||
investors.append(investor_data)
|
||||
|
||||
# Save to database if requested
|
||||
if save_to_db and db:
|
||||
try:
|
||||
saved_investor = self._save_investor_to_db(
|
||||
db, investor_data
|
||||
)
|
||||
db.commit()
|
||||
print(
|
||||
f"✅ Saved investor '{saved_investor.name}' to database"
|
||||
)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
print(f"❌ Failed to save investor to database: {e}")
|
||||
|
||||
print(
|
||||
f"Completed batch {i // batch_size + 1} of {(len(rows) + batch_size - 1) // batch_size}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in batch processing: {e}")
|
||||
if db:
|
||||
db.rollback()
|
||||
finally:
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
return investors
|
||||
|
||||
async def parse_companies(self, df, save_to_db: bool = True):
|
||||
"""Parse companies from DataFrame and optionally save to database"""
|
||||
companies = []
|
||||
|
||||
db = None
|
||||
if save_to_db:
|
||||
db = get_db_session()
|
||||
|
||||
try:
|
||||
# Process rows in batches asynchronously
|
||||
batch_size = 15 # Adjust batch size as needed
|
||||
rows = [(idx, row) for idx, row in df.iterrows()]
|
||||
|
||||
for i in range(0, len(rows), batch_size):
|
||||
batch = rows[i : i + batch_size]
|
||||
|
||||
# Process batch asynchronously
|
||||
tasks = [
|
||||
self._process_row(row, idx, is_investor=False) for idx, row in batch
|
||||
]
|
||||
|
||||
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Handle results from batch
|
||||
for (idx, row), result in zip(batch, batch_results):
|
||||
if isinstance(result, Exception):
|
||||
print(f"Error processing row {idx}: {result}")
|
||||
if db:
|
||||
db.rollback()
|
||||
continue
|
||||
|
||||
if result:
|
||||
# Convert dict to CompanyData if needed
|
||||
if isinstance(result, dict):
|
||||
company_data = CompanyData(**result)
|
||||
else:
|
||||
company_data = result
|
||||
|
||||
companies.append(company_data)
|
||||
|
||||
# Save to database if requested
|
||||
if save_to_db and db:
|
||||
try:
|
||||
saved_company = self._save_company_to_db(
|
||||
db, company_data
|
||||
)
|
||||
db.commit()
|
||||
print(
|
||||
f"✅ Saved company '{saved_company.name}' to database"
|
||||
)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
print(f"❌ Failed to save company to database: {e}")
|
||||
|
||||
print(
|
||||
f"Completed batch {i // batch_size + 1} of {(len(rows) + batch_size - 1) // batch_size}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing row {idx}: {e}")
|
||||
if db:
|
||||
db.rollback()
|
||||
finally:
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
return companies
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# async def main():
|
||||
# """Main execution function"""
|
||||
# # Initialize database tables
|
||||
# print("🔧 Initializing database...")
|
||||
# init_database()
|
||||
|
||||
# # Create processor
|
||||
# processor = InvestorProcessor()
|
||||
|
||||
# print("📊 Processing companies...")
|
||||
# companies = await processor.parse_companies(
|
||||
# "data/19 Companies data.csv", save_to_db=True
|
||||
# )
|
||||
# print(f"Processed {len(companies)} companies")
|
||||
|
||||
# print("\n💰 Processing investors...")
|
||||
# investors = await processor.parse_investors(
|
||||
# "data/19 Investors data.csv", save_to_db=True
|
||||
# )
|
||||
# print(f"Processed {len(investors)} investors")
|
||||
# print("\n✨ Processing complete!")
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# asyncio.run(main())
|
||||
|
||||
@@ -1,293 +0,0 @@
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
|
||||
import chromadb
|
||||
import pandas as pd
|
||||
from db.models import CompanyTable, InvestorTable, InvestorTeamMember, SectorTable
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_openai import ChatOpenAI
|
||||
from py_schemas import InvestorData
|
||||
from pydantic import BaseModel
|
||||
from settings import settings
|
||||
|
||||
|
||||
class InvestorList(BaseModel):
|
||||
"""Schema for LLM structured output"""
|
||||
|
||||
investor_list: List[InvestorData]
|
||||
|
||||
|
||||
class InvestorProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
sql_session: Optional[object] = None,
|
||||
vector_db_client: Optional[object] = None,
|
||||
):
|
||||
self.template = """You are an expert data extraction assistant. Extract investor information from the provided CSV data and return it as a list of structured records.
|
||||
|
||||
Given the following CSV data rows:
|
||||
{question}
|
||||
|
||||
For each row, extract and structure the following fields for the investor:
|
||||
- name: The investor's full name
|
||||
- description: Description of the investor
|
||||
- aum: Assets under management (as integer, use 0 if not available)
|
||||
- check_size_lower: Lower bound of investment check size (as integer)
|
||||
- check_size_upper: Upper bound of investment check size (as integer)
|
||||
- geographic_focus: Geographic region focus
|
||||
- stage_focus: Investment stage focus (must be one of: seed, series_a, series_b, series_c, growth, late_stage)
|
||||
- number_of_investments: Number of investments made (default 0)
|
||||
|
||||
Also extract related data:
|
||||
- portfolio_companies: List of companies they've invested in
|
||||
- team_members: List of team members with name, role, email
|
||||
- sectors: List of sectors they focus on
|
||||
|
||||
Important:
|
||||
- If a field is not available, use appropriate defaults
|
||||
- stage_focus must be one of the valid enum values
|
||||
- Return clean, valid JSON only
|
||||
|
||||
Return the data as a structured list of comprehensive investor data."""
|
||||
|
||||
self.prompt = PromptTemplate(
|
||||
template=self.template, input_variables=["question"]
|
||||
)
|
||||
|
||||
self.llm = ChatOpenAI(
|
||||
api_key=settings.OPENROUTER_API_KEY,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="google/gemini-2.5-flash-lite",
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
self.structured_llm = self.llm.with_structured_output(InvestorList)
|
||||
self.sql_session = sql_session
|
||||
self.vector_db_client = vector_db_client
|
||||
|
||||
self.vector_db_client = chromadb.PersistentClient(path="./chroma_db")
|
||||
self.collection = self.vector_db_client.get_or_create_collection(
|
||||
name="investor_descriptions",
|
||||
metadata={
|
||||
"description": "Investor descriptions and investment thesis focus"
|
||||
},
|
||||
)
|
||||
|
||||
async def _process_batch(
|
||||
self, batch: pd.DataFrame, batch_idx: int
|
||||
) -> List[InvestorData]:
|
||||
"""Process a single batch of data"""
|
||||
# Convert batch to string representation - clean the data
|
||||
batch_str = ""
|
||||
for idx, row in batch.iterrows():
|
||||
# Clean values to remove control characters
|
||||
cleaned_row = {}
|
||||
for key, value in row.items():
|
||||
if pd.notna(value):
|
||||
# Convert to string and clean control characters
|
||||
clean_value = (
|
||||
str(value)
|
||||
.replace("\n", " ")
|
||||
.replace("\r", " ")
|
||||
.replace("\t", " ")
|
||||
)
|
||||
# Remove other control characters
|
||||
clean_value = "".join(
|
||||
char
|
||||
for char in clean_value
|
||||
if ord(char) >= 32 or char in ["\n", "\r", "\t"]
|
||||
)
|
||||
cleaned_row[key] = clean_value
|
||||
|
||||
row_str = ", ".join(
|
||||
[f"{key}: {value}" for key, value in cleaned_row.items()]
|
||||
)
|
||||
batch_str += f"Row {idx + 1}: {row_str}\n"
|
||||
|
||||
try:
|
||||
print(f"Processing batch {batch_idx + 1}...")
|
||||
batch_results = await self.structured_llm.ainvoke(batch_str)
|
||||
return batch_results.investor_list
|
||||
except Exception as e:
|
||||
print(f"Error processing batch {batch_idx + 1}: {e}")
|
||||
return []
|
||||
|
||||
async def _save_to_sql(self, investor_data_list: List[InvestorData]) -> None:
|
||||
"""Save investors and related data to SQL database"""
|
||||
if not self.sql_session:
|
||||
return
|
||||
|
||||
try:
|
||||
for investor_data in investor_data_list:
|
||||
# Save investor
|
||||
db_investor = InvestorTable(
|
||||
name=investor_data.investor.name,
|
||||
description=investor_data.investor.description,
|
||||
aum=investor_data.investor.aum,
|
||||
check_size_lower=investor_data.investor.check_size_lower,
|
||||
check_size_upper=investor_data.investor.check_size_upper,
|
||||
geographic_focus=investor_data.investor.geographic_focus,
|
||||
stage_focus=investor_data.investor.stage_focus,
|
||||
number_of_investments=investor_data.investor.number_of_investments,
|
||||
)
|
||||
self.sql_session.add(db_investor)
|
||||
self.sql_session.flush() # Get the ID
|
||||
|
||||
# Save sectors and create associations
|
||||
for sector_data in investor_data.sectors:
|
||||
# Check if sector exists, create if not
|
||||
existing_sector = (
|
||||
self.sql_session.query(SectorTable)
|
||||
.filter(SectorTable.name == sector_data.name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not existing_sector:
|
||||
db_sector = SectorTable(name=sector_data.name)
|
||||
self.sql_session.add(db_sector)
|
||||
self.sql_session.flush()
|
||||
# Add sector to investor's sectors
|
||||
db_investor.sectors.append(db_sector)
|
||||
else:
|
||||
# Add existing sector to investor if not already there
|
||||
if existing_sector not in db_investor.sectors:
|
||||
db_investor.sectors.append(existing_sector)
|
||||
|
||||
# Save companies and create portfolio associations
|
||||
for company_data in investor_data.portfolio_companies:
|
||||
# Check if company exists, create if not
|
||||
existing_company = (
|
||||
self.sql_session.query(CompanyTable)
|
||||
.filter(CompanyTable.name == company_data.name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not existing_company:
|
||||
db_company = CompanyTable(
|
||||
name=company_data.name,
|
||||
industry=company_data.industry,
|
||||
location=company_data.location,
|
||||
founded_year=company_data.founded_year,
|
||||
website=company_data.website,
|
||||
)
|
||||
self.sql_session.add(db_company)
|
||||
self.sql_session.flush()
|
||||
|
||||
# Add to investor's portfolio
|
||||
db_investor.portfolio_companies.append(db_company)
|
||||
else:
|
||||
# Add existing company to portfolio if not already there
|
||||
if existing_company not in db_investor.portfolio_companies:
|
||||
db_investor.portfolio_companies.append(existing_company)
|
||||
|
||||
# Save team members
|
||||
for team_member_data in investor_data.team_members:
|
||||
# Check if team member exists
|
||||
existing_member = (
|
||||
self.sql_session.query(InvestorTeamMember)
|
||||
.filter(InvestorTeamMember.email == team_member_data.email)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not existing_member:
|
||||
db_team_member = InvestorTeamMember(
|
||||
name=team_member_data.name,
|
||||
role=team_member_data.role,
|
||||
email=team_member_data.email,
|
||||
investor_id=db_investor.id,
|
||||
)
|
||||
self.sql_session.add(db_team_member)
|
||||
|
||||
self.sql_session.commit()
|
||||
print(f"Successfully saved {len(investor_data_list)} investors to database")
|
||||
|
||||
except Exception as e:
|
||||
self.sql_session.rollback()
|
||||
print(f"Error saving to SQL database: {e}")
|
||||
raise
|
||||
|
||||
async def _save_to_vector_db(self, investor_data_list: List[InvestorData]) -> None:
|
||||
"""Save investors to vector database"""
|
||||
if not self.vector_db_client:
|
||||
return
|
||||
|
||||
documents = []
|
||||
metadatas = []
|
||||
ids = []
|
||||
|
||||
for i, investor_data in enumerate(investor_data_list):
|
||||
investor = investor_data.investor
|
||||
sectors = ", ".join([s.name for s in investor_data.sectors])
|
||||
companies = ", ".join([c.name for c in investor_data.portfolio_companies])
|
||||
|
||||
doc_text = f"""
|
||||
Investor: {investor.name}
|
||||
Description: {investor.description or "N/A"}
|
||||
AUM: ${investor.aum:,}
|
||||
Check Size: ${investor.check_size_lower:,} - ${investor.check_size_upper:,}
|
||||
Geographic Focus: {investor.geographic_focus}
|
||||
Stage Focus: {investor.stage_focus.value}
|
||||
Sectors: {sectors}
|
||||
Portfolio Companies: {companies}
|
||||
""".strip()
|
||||
|
||||
documents.append(doc_text)
|
||||
metadatas.append(
|
||||
{
|
||||
"name": investor.name,
|
||||
"stage_focus": investor.stage_focus.value,
|
||||
"geographic_focus": investor.geographic_focus,
|
||||
"aum": investor.aum,
|
||||
}
|
||||
)
|
||||
ids.append(
|
||||
f"investor_{i}_{investor.name.replace(' ', '_').replace('/', '_')}"
|
||||
)
|
||||
|
||||
if documents:
|
||||
try:
|
||||
self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
|
||||
print(
|
||||
f"Successfully saved {len(documents)} investors to vector database"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error saving to vector database: {e}")
|
||||
|
||||
async def process_csv(
|
||||
self, df: pd.DataFrame, batch_size: int = 10, max_concurrent: int = 10
|
||||
) -> List[InvestorData]:
|
||||
"""Process CSV data in parallel batches and save to databases"""
|
||||
results = []
|
||||
|
||||
# Create batches
|
||||
batches = []
|
||||
for i in range(0, len(df), batch_size):
|
||||
batch = df.iloc[i : i + batch_size]
|
||||
batches.append((batch, i // batch_size))
|
||||
|
||||
# Process batches with concurrency control
|
||||
semaphore = asyncio.Semaphore(max_concurrent)
|
||||
|
||||
async def process_with_semaphore(batch_data):
|
||||
batch, batch_idx = batch_data
|
||||
async with semaphore:
|
||||
return await self._process_batch(batch, batch_idx)
|
||||
|
||||
# Execute all batches concurrently
|
||||
batch_results = await asyncio.gather(
|
||||
*[process_with_semaphore(batch_data) for batch_data in batches],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
# Collect results, filtering out exceptions
|
||||
for batch_result in batch_results:
|
||||
if not isinstance(batch_result, Exception):
|
||||
results.extend(batch_result)
|
||||
|
||||
# Save to databases
|
||||
if results:
|
||||
print(f"Successfully processed {len(results)} investors")
|
||||
await self._save_to_sql(results)
|
||||
await self._save_to_vector_db(results)
|
||||
|
||||
return results
|
||||
@@ -1,290 +0,0 @@
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
|
||||
import chromadb
|
||||
import pandas as pd
|
||||
from db.models import CompanyTable, InvestorTable, InvestorTeamMember, SectorTable
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_openai import ChatOpenAI
|
||||
from py_schemas import InvestorData
|
||||
from pydantic import BaseModel
|
||||
from settings import settings
|
||||
|
||||
|
||||
class InvestorOutput(BaseModel):
|
||||
"""Schema for LLM structured output"""
|
||||
|
||||
investor_data: InvestorData
|
||||
|
||||
|
||||
class InvestorProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
sql_session: Optional[object] = None,
|
||||
vector_db_client: Optional[object] = None,
|
||||
):
|
||||
self.template = """You are an expert data extraction assistant. Extract investor information from the provided CSV data and return it as a structured record.
|
||||
|
||||
Given the following CSV data row:
|
||||
{question}
|
||||
|
||||
Extract and structure the following fields for the investor:
|
||||
- name: The investor's full name
|
||||
- description: Description of the investor
|
||||
- aum: Assets under management (as integer, use 0 if not available)
|
||||
- check_size_lower: Lower bound of investment check size (as integer)
|
||||
- check_size_upper: Upper bound of investment check size (as integer)
|
||||
- geographic_focus: Geographic region focus
|
||||
- stage_focus: Investment stage focus (must be one of: seed, series_a, series_b, series_c, growth, late_stage)
|
||||
- number_of_investments: Number of investments made (default 0)
|
||||
|
||||
Also extract related data:
|
||||
- portfolio_companies: List of companies they've invested in
|
||||
- team_members: List of team members with name, role, email
|
||||
- sectors: List of sectors they focus on
|
||||
|
||||
Important:
|
||||
- If a field is not available, use appropriate defaults
|
||||
- stage_focus must be one of the valid enum values
|
||||
- Return clean, valid JSON only
|
||||
|
||||
Return the data as a single comprehensive investor data record."""
|
||||
|
||||
self.prompt = PromptTemplate(
|
||||
template=self.template, input_variables=["question"]
|
||||
)
|
||||
|
||||
self.llm = ChatOpenAI(
|
||||
api_key=settings.OPENROUTER_API_KEY,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="google/gemini-2.5-flash-lite",
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
self.structured_llm = self.llm.with_structured_output(InvestorOutput)
|
||||
self.sql_session = sql_session
|
||||
self.vector_db_client = vector_db_client
|
||||
|
||||
self.vector_db_client = chromadb.PersistentClient(path="./chroma_db")
|
||||
self.collection = self.vector_db_client.get_or_create_collection(
|
||||
name="investor_descriptions",
|
||||
metadata={
|
||||
"description": "Investor descriptions and investment thesis focus"
|
||||
},
|
||||
)
|
||||
|
||||
async def _process_row(
|
||||
self, row: pd.Series, row_idx: int
|
||||
) -> Optional[InvestorData]:
|
||||
"""Process a single row of data"""
|
||||
# Clean values to remove control characters
|
||||
cleaned_row = {}
|
||||
for key, value in row.items():
|
||||
if pd.notna(value):
|
||||
# Convert to string and clean control characters
|
||||
clean_value = (
|
||||
str(value)
|
||||
.replace("\n", " ")
|
||||
.replace("\r", " ")
|
||||
.replace("\t", " ")
|
||||
)
|
||||
# Remove other control characters
|
||||
clean_value = "".join(
|
||||
char
|
||||
for char in clean_value
|
||||
if ord(char) >= 32 or char in ["\n", "\r", "\t"]
|
||||
)
|
||||
cleaned_row[key] = clean_value
|
||||
|
||||
row_str = ", ".join(
|
||||
[f"{key}: {value}" for key, value in cleaned_row.items()]
|
||||
)
|
||||
|
||||
try:
|
||||
print(f"Processing row {row_idx + 1}...")
|
||||
result = await self.structured_llm.ainvoke(row_str)
|
||||
if result.investor_data:
|
||||
return result.investor_data
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error processing row {row_idx + 1}: {e}")
|
||||
return None
|
||||
|
||||
async def _save_to_sql(self, investor_data_list: List[InvestorData]) -> None:
|
||||
"""Save investors and related data to SQL database"""
|
||||
if not self.sql_session:
|
||||
return
|
||||
|
||||
try:
|
||||
for investor_data in investor_data_list:
|
||||
# Save investor
|
||||
db_investor = InvestorTable(
|
||||
name=investor_data.investor.name,
|
||||
description=investor_data.investor.description,
|
||||
aum=investor_data.investor.aum,
|
||||
check_size_lower=investor_data.investor.check_size_lower,
|
||||
check_size_upper=investor_data.investor.check_size_upper,
|
||||
geographic_focus=investor_data.investor.geographic_focus,
|
||||
stage_focus=investor_data.investor.stage_focus,
|
||||
number_of_investments=investor_data.investor.number_of_investments,
|
||||
)
|
||||
self.sql_session.add(db_investor)
|
||||
self.sql_session.flush() # Get the ID
|
||||
|
||||
# Save sectors and create associations
|
||||
for sector_data in investor_data.sectors:
|
||||
# Check if sector exists, create if not
|
||||
existing_sector = (
|
||||
self.sql_session.query(SectorTable)
|
||||
.filter(SectorTable.name == sector_data.name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not existing_sector:
|
||||
db_sector = SectorTable(name=sector_data.name)
|
||||
self.sql_session.add(db_sector)
|
||||
self.sql_session.flush()
|
||||
# Add sector to investor's sectors
|
||||
db_investor.sectors.append(db_sector)
|
||||
else:
|
||||
# Add existing sector to investor if not already there
|
||||
if existing_sector not in db_investor.sectors:
|
||||
db_investor.sectors.append(existing_sector)
|
||||
|
||||
# Save companies and create portfolio associations
|
||||
for company_data in investor_data.portfolio_companies:
|
||||
# Check if company exists, create if not
|
||||
existing_company = (
|
||||
self.sql_session.query(CompanyTable)
|
||||
.filter(CompanyTable.name == company_data.name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not existing_company:
|
||||
db_company = CompanyTable(
|
||||
name=company_data.name,
|
||||
industry=company_data.industry,
|
||||
location=company_data.location,
|
||||
founded_year=company_data.founded_year,
|
||||
website=company_data.website,
|
||||
)
|
||||
self.sql_session.add(db_company)
|
||||
self.sql_session.flush()
|
||||
|
||||
# Add to investor's portfolio
|
||||
db_investor.portfolio_companies.append(db_company)
|
||||
else:
|
||||
# Add existing company to portfolio if not already there
|
||||
if existing_company not in db_investor.portfolio_companies:
|
||||
db_investor.portfolio_companies.append(existing_company)
|
||||
|
||||
# Save team members
|
||||
for team_member_data in investor_data.team_members:
|
||||
# Check if team member exists
|
||||
existing_member = (
|
||||
self.sql_session.query(InvestorTeamMember)
|
||||
.filter(InvestorTeamMember.email == team_member_data.email)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not existing_member:
|
||||
db_team_member = InvestorTeamMember(
|
||||
name=team_member_data.name,
|
||||
role=team_member_data.role,
|
||||
email=team_member_data.email,
|
||||
investor_id=db_investor.id,
|
||||
)
|
||||
self.sql_session.add(db_team_member)
|
||||
|
||||
self.sql_session.commit()
|
||||
print(f"Successfully saved {len(investor_data_list)} investors to database")
|
||||
|
||||
except Exception as e:
|
||||
self.sql_session.rollback()
|
||||
print(f"Error saving to SQL database: {e}")
|
||||
raise
|
||||
|
||||
async def _save_to_vector_db(self, investor_data_list: List[InvestorData]) -> None:
|
||||
"""Save investors to vector database"""
|
||||
if not self.vector_db_client:
|
||||
return
|
||||
|
||||
documents = []
|
||||
metadatas = []
|
||||
ids = []
|
||||
|
||||
for i, investor_data in enumerate(investor_data_list):
|
||||
investor = investor_data.investor
|
||||
sectors = ", ".join([s.name for s in investor_data.sectors])
|
||||
companies = ", ".join([c.name for c in investor_data.portfolio_companies])
|
||||
|
||||
doc_text = f"""
|
||||
Investor: {investor.name}
|
||||
Description: {investor.description or "N/A"}
|
||||
AUM: ${investor.aum:,}
|
||||
Check Size: ${investor.check_size_lower:,} - ${investor.check_size_upper:,}
|
||||
Geographic Focus: {investor.geographic_focus}
|
||||
Stage Focus: {investor.stage_focus.value}
|
||||
Sectors: {sectors}
|
||||
Portfolio Companies: {companies}
|
||||
""".strip()
|
||||
|
||||
documents.append(doc_text)
|
||||
metadatas.append(
|
||||
{
|
||||
"name": investor.name,
|
||||
"stage_focus": investor.stage_focus.value,
|
||||
"geographic_focus": investor.geographic_focus,
|
||||
"aum": investor.aum,
|
||||
}
|
||||
)
|
||||
ids.append(
|
||||
f"investor_{i}_{investor.name.replace(' ', '_').replace('/', '_')}"
|
||||
)
|
||||
|
||||
if documents:
|
||||
try:
|
||||
self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
|
||||
print(
|
||||
f"Successfully saved {len(documents)} investors to vector database"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error saving to vector database: {e}")
|
||||
|
||||
async def process_csv(
|
||||
self, df: pd.DataFrame, max_concurrent: int = 10
|
||||
) -> List[InvestorData]:
|
||||
"""Process CSV data one row at a time and save to databases"""
|
||||
results = []
|
||||
|
||||
# Create semaphore for concurrency control
|
||||
semaphore = asyncio.Semaphore(max_concurrent)
|
||||
|
||||
async def process_row_with_semaphore(row_data):
|
||||
row, row_idx = row_data
|
||||
async with semaphore:
|
||||
return await self._process_row(row, row_idx)
|
||||
|
||||
# Create row tasks
|
||||
row_tasks = []
|
||||
for idx, row in df.iterrows():
|
||||
row_tasks.append((row, idx))
|
||||
|
||||
# Execute all rows concurrently
|
||||
row_results = await asyncio.gather(
|
||||
*[process_row_with_semaphore(row_data) for row_data in row_tasks],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
# Collect results, filtering out exceptions and None values
|
||||
for row_result in row_results:
|
||||
if not isinstance(row_result, Exception) and row_result is not None:
|
||||
results.append(row_result)
|
||||
|
||||
# Save to databases
|
||||
if results:
|
||||
print(f"Successfully processed {len(results)} investors")
|
||||
await self._save_to_sql(results)
|
||||
await self._save_to_vector_db(results)
|
||||
|
||||
return results
|
||||
+53
-215
@@ -1,88 +1,47 @@
|
||||
from typing import List, Optional
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import chromadb
|
||||
from db.db import DATABASE_URL, get_db
|
||||
from db.models import InvestorTable
|
||||
from langchain import hub
|
||||
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
||||
from langchain_community.utilities import SQLDatabase
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from py_schemas import InvestorData, InvestorList
|
||||
from settings import settings
|
||||
from schemas.py_schemas import InvestorData, InvestorList
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
# Connect to SQLite
|
||||
|
||||
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
|
||||
db = SQLDatabase.from_uri("sqlite:///investors.db")
|
||||
system_message = (
|
||||
prompt_template.format(dialect="SQLite", top_k=5)
|
||||
+ "\n Get answers from the Sql database and the vector database"
|
||||
)
|
||||
db = SQLDatabase.from_uri(DATABASE_URL)
|
||||
|
||||
|
||||
class QueryProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
sql_session: Optional[object] = None,
|
||||
vector_db_client: Optional[object] = None,
|
||||
):
|
||||
self.sql_session = sql_session
|
||||
def __init__(self):
|
||||
self.llm = ChatOpenAI(
|
||||
api_key=settings.OPENROUTER_API_KEY,
|
||||
api_key=os.getenv("OPENROUTER_API_KEY"),
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="google/gemini-2.5-flash-lite",
|
||||
model="openai/gpt-5-nano",
|
||||
temperature=0.3,
|
||||
)
|
||||
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
|
||||
# Update system message to specifically request only investor IDs
|
||||
system_message_updated = (
|
||||
prompt_template.format(dialect="SQLite", top_k=5)
|
||||
+ "\n\nIMPORTANT: You must ONLY return the investor IDs (id field) that match the user's criteria. "
|
||||
+ "Do NOT return any other information, explanations, or data. "
|
||||
+ "Your response should be ONLY a comma-separated list of numbers representing the investor IDs. "
|
||||
+ "Example format: 1, 5, 12, 23"
|
||||
)
|
||||
self.agent = create_react_agent(
|
||||
model=self.llm,
|
||||
tools=self.toolkit.get_tools() + [self.query_vector_database],
|
||||
prompt=system_message,
|
||||
tools=self.toolkit.get_tools(),
|
||||
prompt=system_message_updated,
|
||||
)
|
||||
self.vector_db_client = vector_db_client
|
||||
|
||||
self.vector_db_client = chromadb.PersistentClient(path="./chroma_db")
|
||||
self.collection = self.vector_db_client.get_or_create_collection(
|
||||
name="investor_descriptions",
|
||||
metadata={
|
||||
"description": "Investor descriptions and investment thesis focus"
|
||||
},
|
||||
)
|
||||
|
||||
def query_sql_database(self, query: str) -> Optional[InvestorList]:
|
||||
"""Query the SQL database for investor information."""
|
||||
if not self.sql_session:
|
||||
return None
|
||||
|
||||
# Implement SQL querying logic here
|
||||
result = self.sql_session.execute(query)
|
||||
investors = result.scalars().all()
|
||||
return InvestorList(investors=investors)
|
||||
|
||||
def query_vector_database(self, query: str) -> Optional[InvestorList]:
|
||||
"""Query the vector database for investor information."""
|
||||
if not self.vector_db_client:
|
||||
return None
|
||||
print("VECTOR STORE WAS CALLED")
|
||||
|
||||
# Query the collection directly, not passing collection as parameter
|
||||
results = self.collection.query(
|
||||
query_texts=[query], # ChromaDB expects a list of query texts
|
||||
n_results=3, # Specify how many results you want
|
||||
)
|
||||
print(results)
|
||||
|
||||
# ChromaDB returns results in a different structure
|
||||
# results will have 'documents', 'metadatas', 'ids', 'distances'
|
||||
return results
|
||||
|
||||
def process_query(self, question: str) -> InvestorList:
|
||||
"""Process a query using the LLM and return structured investor data."""
|
||||
# Extract filters from the query first
|
||||
filters = self._extract_filters_from_query(question)
|
||||
|
||||
# Get AI response for additional context
|
||||
"""Process a query using the LLM and return investor data."""
|
||||
# Let the LLM handle all database interactions and filtering to get IDs
|
||||
response = self.agent.invoke(
|
||||
{"messages": [("user", question)]},
|
||||
)
|
||||
@@ -92,178 +51,54 @@ class QueryProcessor:
|
||||
response["messages"][-1].content if response.get("messages") else ""
|
||||
)
|
||||
|
||||
# Try to extract investor IDs or names from the AI response
|
||||
investor_ids = self._extract_investor_info_from_response(ai_response)
|
||||
# Extract investor IDs from the AI response
|
||||
investor_ids = self._extract_investor_ids_from_response(ai_response)
|
||||
|
||||
# Fetch filtered investor data with relationships from database
|
||||
return self._fetch_investors_with_relationships(investor_ids, filters)
|
||||
# Fetch full investor data using the IDs
|
||||
return self._fetch_investors_by_ids(investor_ids)
|
||||
|
||||
def _extract_investor_ids_from_response(self, ai_response: str) -> List[int]:
|
||||
"""Extract investor IDs from AI response."""
|
||||
import re
|
||||
|
||||
def _extract_investor_info_from_response(self, ai_response: str) -> List[int]:
|
||||
"""Extract investor IDs from AI response. This is a simple implementation."""
|
||||
# This is a basic implementation - you might want to make it more sophisticated
|
||||
# based on how your AI formats responses
|
||||
investor_ids = []
|
||||
|
||||
# If the AI can't provide structured data, fall back to getting all investors
|
||||
# that match basic criteria
|
||||
try:
|
||||
# Try to extract numbers that might be IDs
|
||||
import re
|
||||
# Try multiple patterns to extract IDs from the response
|
||||
# Pattern 1: Simple numbers (assuming they are IDs)
|
||||
numbers = re.findall(r"\b\d+\b", ai_response)
|
||||
investor_ids = [int(num) for num in numbers]
|
||||
|
||||
ids = re.findall(r"\bid:\s*(\d+)", ai_response.lower())
|
||||
investor_ids = [int(id_str) for id_str in ids]
|
||||
except Exception:
|
||||
pass
|
||||
# Pattern 2: If response contains explicit ID references
|
||||
id_matches = re.findall(r"\bid[:\s]*(\d+)", ai_response.lower())
|
||||
if id_matches:
|
||||
investor_ids = [int(id_str) for id_str in id_matches]
|
||||
|
||||
return investor_ids if investor_ids else []
|
||||
except Exception as e:
|
||||
print(f"Error extracting IDs from response: {e}")
|
||||
return []
|
||||
|
||||
def _extract_filters_from_query(self, question: str) -> dict:
|
||||
"""Extract filter criteria from natural language query."""
|
||||
question_lower = question.lower()
|
||||
filters = {}
|
||||
return investor_ids
|
||||
|
||||
# Extract stage filters
|
||||
if any(
|
||||
stage in question_lower
|
||||
for stage in [
|
||||
"seed",
|
||||
"series a",
|
||||
"series b",
|
||||
"series c",
|
||||
"growth",
|
||||
"late stage",
|
||||
]
|
||||
):
|
||||
if "seed" in question_lower:
|
||||
filters["stage"] = "SEED"
|
||||
elif "series a" in question_lower:
|
||||
filters["stage"] = "SERIES_A"
|
||||
elif "series b" in question_lower:
|
||||
filters["stage"] = "SERIES_B"
|
||||
elif "series c" in question_lower:
|
||||
filters["stage"] = "SERIES_C"
|
||||
elif "growth" in question_lower:
|
||||
filters["stage"] = "GROWTH"
|
||||
elif "late stage" in question_lower:
|
||||
filters["stage"] = "LATE_STAGE"
|
||||
|
||||
# Extract geographic filters
|
||||
if any(
|
||||
geo in question_lower
|
||||
for geo in [
|
||||
"us",
|
||||
"usa",
|
||||
"united states",
|
||||
"europe",
|
||||
"asia",
|
||||
"silicon valley",
|
||||
"bay area",
|
||||
]
|
||||
):
|
||||
if (
|
||||
"us" in question_lower
|
||||
or "usa" in question_lower
|
||||
or "united states" in question_lower
|
||||
):
|
||||
filters["geography"] = "US"
|
||||
elif "europe" in question_lower:
|
||||
filters["geography"] = "Europe"
|
||||
elif "asia" in question_lower:
|
||||
filters["geography"] = "Asia"
|
||||
elif "silicon valley" in question_lower or "bay area" in question_lower:
|
||||
filters["geography"] = "Silicon Valley"
|
||||
|
||||
# Extract sector filters
|
||||
sectors = [
|
||||
"fintech",
|
||||
"healthcare",
|
||||
"saas",
|
||||
"ai",
|
||||
"biotech",
|
||||
"consumer",
|
||||
"enterprise",
|
||||
"crypto",
|
||||
"blockchain",
|
||||
]
|
||||
for sector in sectors:
|
||||
if sector in question_lower:
|
||||
filters["sector"] = sector
|
||||
break
|
||||
|
||||
# Extract check size filters (simple patterns)
|
||||
import re
|
||||
|
||||
amounts = re.findall(
|
||||
r"\$?(\d+(?:,\d{3})*(?:\.\d+)?)\s*(?:million|m|k|thousand)", question_lower
|
||||
)
|
||||
if amounts:
|
||||
amount = amounts[0].replace(",", "")
|
||||
if "million" in question_lower or "m" in question_lower:
|
||||
filters["min_check_size"] = int(float(amount) * 1000000)
|
||||
elif "thousand" in question_lower or "k" in question_lower:
|
||||
filters["min_check_size"] = int(float(amount) * 1000)
|
||||
|
||||
return filters
|
||||
|
||||
def _fetch_investors_with_relationships(
|
||||
self, investor_ids: List[int] = None, filters: dict = None
|
||||
) -> InvestorList:
|
||||
"""Fetch investors with all their relationships from the database."""
|
||||
if not self.sql_session:
|
||||
def _fetch_investors_by_ids(self, investor_ids: List[int]) -> InvestorList:
|
||||
"""Fetch investors with all their relationships from the database using IDs."""
|
||||
if not investor_ids:
|
||||
return InvestorList(investors=[])
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from db.models import SectorTable
|
||||
# Get database session
|
||||
db_session = next(get_db())
|
||||
|
||||
try:
|
||||
# Build query with all relationships loaded
|
||||
query = self.sql_session.query(InvestorTable).options(
|
||||
query = (
|
||||
db_session.query(InvestorTable)
|
||||
.options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
)
|
||||
|
||||
# Apply filters if provided
|
||||
if filters:
|
||||
if "stage" in filters:
|
||||
from db.models import InvestmentStage
|
||||
|
||||
stage_enum = getattr(InvestmentStage, filters["stage"])
|
||||
query = query.filter(InvestorTable.stage_focus == stage_enum)
|
||||
|
||||
if "geography" in filters:
|
||||
query = query.filter(
|
||||
InvestorTable.geographic_focus.ilike(f"%{filters['geography']}%")
|
||||
.filter(InvestorTable.id.in_(investor_ids))
|
||||
)
|
||||
|
||||
if "min_check_size" in filters:
|
||||
query = query.filter(
|
||||
InvestorTable.check_size_lower >= filters["min_check_size"]
|
||||
)
|
||||
|
||||
if "max_check_size" in filters:
|
||||
query = query.filter(
|
||||
InvestorTable.check_size_upper <= filters["max_check_size"]
|
||||
)
|
||||
|
||||
if "min_aum" in filters:
|
||||
query = query.filter(InvestorTable.aum >= filters["min_aum"])
|
||||
|
||||
if "max_aum" in filters:
|
||||
query = query.filter(InvestorTable.aum <= filters["max_aum"])
|
||||
|
||||
if "sector" in filters:
|
||||
query = query.join(InvestorTable.sectors).filter(
|
||||
SectorTable.name.ilike(f"%{filters['sector']}%")
|
||||
)
|
||||
|
||||
# Filter by IDs if provided
|
||||
if investor_ids:
|
||||
query = query.filter(InvestorTable.id.in_(investor_ids))
|
||||
else:
|
||||
# If no specific IDs and no filters, limit to prevent overwhelming response
|
||||
if not filters:
|
||||
query = query.limit(10)
|
||||
|
||||
investors = query.all()
|
||||
|
||||
# Transform to InvestorData format
|
||||
@@ -278,3 +113,6 @@ class QueryProcessor:
|
||||
investor_data_list.append(investor_data)
|
||||
|
||||
return InvestorList(investors=investor_data_list)
|
||||
|
||||
finally:
|
||||
db_session.close()
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
OPENROUTER_API_KEY: str
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
Reference in New Issue
Block a user