Compare commits: main..version_two (8 Commits)
| Author | SHA1 | Date |
|---|---|---|
| | c5c94936f3 | |
| | 17bc5acbc8 | |
| | 6caea96658 | |
| | 6d902345c0 | |
| | d36367fbe9 | |
| | abac19c6ae | |
| | f2bbcb96f3 | |
| | 0f7beca5e1 | |
+4
-2
@@ -8,8 +8,10 @@

/chroma_db

/*__pycache__*/
*__pycache__

/*.db

/*.cypython-*
*.cypython

/preprocessor
@@ -1,577 +0,0 @@
# LLM-Powered Investor & Company Management API

A comprehensive FastAPI-based system for managing investor and company data with LLM-powered CSV parsing, semantic search, and advanced filtering capabilities.

## Features

- **FastAPI REST API**: Modern, auto-documented API with OpenAPI/Swagger support
- **CSV Data Processing**: Parse complex investor data from CSV files using LLM assistance
- **Dual Database Storage**: Structured data in SQL database and semantic search via ChromaDB
- **Natural Language Queries**: AI-powered query processing for complex investor searches
- **Advanced Filtering**: Filter investors and companies by multiple criteria
- **Relationship Management**: Many-to-many relationships between investors, companies, and sectors
- **Auto-Generated Documentation**: Interactive API docs at `/docs`

## Architecture

### Components

1. **FastAPI Application (`app/main.py`)**: Main API server with route configuration
2. **Database Models (`app/db/models.py`)**: SQLAlchemy models for investors, companies, sectors
3. **Pydantic Schemas (`app/py_schemas.py`)**: Request/response validation and serialization
4. **API Routes**:
   - `app/api/investors.py`: Investor CRUD operations and filtering
   - `app/api/companies.py`: Company CRUD operations and filtering
5. **Services**:
   - `app/services/openrouter.py`: LLM-powered CSV processing
   - `app/services/querying.py`: Natural language query processing
6. **Database (`app/db/`)**: Database connection, models, and schemas

### Data Flow

```
CSV Upload → LLM Processing → Data Extraction → SQL Storage → Vector Storage → API Endpoints
↓
Natural Language Query → AI Analysis → Database Filtering → Structured Response
```

## Installation

### Prerequisites

- Python 3.12+
- FastAPI and dependencies

### Setup

1. Clone the repository and navigate to the project directory:

   ```bash
   cd /path/to/anton_wireframe
   ```

2. Install dependencies:

   ```bash
   pip install -r requirements.txt
   ```

3. Configure environment variables:

   ```bash
   cp .env.example .env
   # Edit .env and add your OpenRouter API key for LLM features
   ```

4. Initialize the database:

   ```bash
   cd app
   python -c "from db.db import init_database; init_database()"
   ```

5. Start the API server:

   ```bash
   cd app
   uvicorn main:app --reload --host localhost --port 8000
   ```

The API will be available at:

- **API Base**: http://localhost:8000
- **Interactive Docs**: http://localhost:8000/docs
- **ReDoc**: http://localhost:8000/redoc

## Database Schema

### SQL Database (SQLite)

#### Investors Table

- **Basic Info**: name, description, geographic_focus
- **Investment Data**: aum, check_size_lower, check_size_upper
- **Stage Focus**: investment stage (SEED, SERIES_A, etc.)
- **Relationships**: Many-to-many with companies and sectors
- **Team**: One-to-many with team members
- **Metadata**: created_at, updated_at timestamps

#### Companies Table

- **Basic Info**: name, industry, location
- **Details**: founded_year, website
- **Relationships**: Many-to-many with investors
- **Metadata**: created_at, updated_at timestamps

#### Association Tables

- **investor_companies**: Links investors to their portfolio companies
- **investor_sectors**: Links investors to their focus sectors
- **investor_team**: Team member details for each investor

#### Supporting Tables

- **sectors**: Investment focus areas (fintech, healthcare, etc.)

### Vector Database (ChromaDB)

Stores embeddings for semantic search of:

- Investor descriptions
- Investment thesis focus areas
- Combined investor profiles

## API Usage

### Interactive Documentation

Visit http://localhost:8000/docs for the auto-generated Swagger UI where you can:

- Explore all endpoints
- Test API calls directly
- View request/response schemas
- See example requests

### Core Endpoints

#### Investor Management

```bash
# Get all investors with relationships
GET /investors

# Filter investors by criteria
GET /investors/filter?stage=GROWTH&geography=US&sector=fintech&min_check_size=1000000

# Get specific investor
GET /investors/{investor_id}

# Create new investor
POST /investors
{
  "name": "Example VC",
  "description": "Early stage fintech investor",
  "aum": 50000000,
  "check_size_lower": 100000,
  "check_size_upper": 2000000,
  "geographic_focus": "US",
  "stage_focus": "SEED",
  "number_of_investments": 25
}

# Update investor
PUT /investors/{investor_id}

# Delete investor
DELETE /investors/{investor_id}
```

#### Company Management

```bash
# Get all companies with investor relationships
GET /companies

# Filter companies by criteria
GET /companies/filter?industry=fintech&location=San Francisco&founded_after=2015

# Get specific company
GET /companies/{company_id}

# Create new company
POST /companies
{
  "name": "Example Startup",
  "industry": "fintech",
  "location": "San Francisco",
  "founded_year": 2020,
  "website": "https://example.com"
}

# Update company
PUT /companies/{company_id}

# Delete company
DELETE /companies/{company_id}
```

#### CSV Processing

```bash
# Upload and process CSV file
POST /parse-csv
Content-Type: multipart/form-data
File: investors.csv
```

#### Natural Language Queries

```bash
# Query investors using natural language
POST /query
{
  "question": "Show me growth stage fintech investors in Silicon Valley with check sizes over $1 million"
}
```

### Advanced Filtering Examples

#### Investor Filters

```bash
# Early stage investors in Europe
GET /investors/filter?stage=SEED&geography=Europe

# High AUM growth investors
GET /investors/filter?stage=GROWTH&min_aum=100000000

# Healthcare investors with large checks
GET /investors/filter?sector=healthcare&min_check_size=5000000

# Specific geographic focus
GET /investors/filter?geography=Silicon Valley
```
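
Filter values that contain spaces, such as `Silicon Valley`, need percent-encoding when sent as a real HTTP request. The following is a minimal sketch of one way to build such URLs with the standard library. `build_filter_url` is a hypothetical helper, not part of the project, and the base URL assumes the local development server from the setup steps.

```python
from urllib.parse import quote, urlencode

BASE_URL = "http://localhost:8000"  # assumed local development server


def build_filter_url(resource: str, **filters) -> str:
    """Build a /<resource>/filter URL, skipping filters whose value is None."""
    params = {key: value for key, value in filters.items() if value is not None}
    # quote_via=quote encodes spaces as %20 instead of "+".
    query = urlencode(params, quote_via=quote)
    return f"{BASE_URL}/{resource}/filter?{query}"
```

Calling `build_filter_url("investors", stage="GROWTH", geography="Silicon Valley")` yields a URL whose query string carries `geography=Silicon%20Valley`, matching the encoded form used in the curl examples.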

#### Company Filters

```bash
# Recent fintech companies
GET /companies/filter?industry=fintech&founded_after=2020

# Companies with websites
GET /companies/filter?has_website=true

# Companies backed by specific investor
GET /companies/filter?investor_name=Sequoia

# Location-based filtering
GET /companies/filter?location=New York
```

### Response Format

All endpoints return structured JSON with full relationship data:

```json
{
  "investor": {
    "id": 1,
    "name": "Example VC",
    "description": "Early stage investor",
    "aum": 50000000,
    "check_size_lower": 100000,
    "check_size_upper": 2000000,
    "geographic_focus": "US",
    "stage_focus": "SEED",
    "number_of_investments": 25
  },
  "portfolio_companies": [
    {
      "id": 1,
      "name": "StartupCo",
      "industry": "fintech",
      "location": "San Francisco"
    }
  ],
  "team_members": [
    {
      "id": 1,
      "name": "John Partner",
      "role": "Managing Partner",
      "email": "john@examplevc.com"
    }
  ],
  "sectors": [
    {
      "id": 1,
      "name": "fintech"
    }
  ]
}
```

## Data Processing Pipeline

### 1. CSV Parsing

- Reads CSV with pandas
- Handles nested JSON fields in columns
- Validates data with Pydantic models

### 2. JSON Field Processing

- Direct parsing for well-formed JSON
- LLM-assisted cleaning for malformed JSON (when enabled)
- Graceful fallback to empty objects
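
The three-step strategy above (direct parse, optional repair, empty-object fallback) could look roughly like the sketch below. It is an illustration only: `parse_json_field` and its `repair` hook are hypothetical names, and the hook merely stands in for the LLM-assisted cleaning, whose real interface is not shown in this document.

```python
import json
from typing import Callable, Optional


def parse_json_field(raw: object,
                     repair: Optional[Callable[[str], str]] = None) -> dict:
    """Parse a JSON object stored in a CSV cell, falling back to {}."""
    # Missing cells (None, NaN, blank strings) become an empty object.
    if not isinstance(raw, str) or not raw.strip():
        return {}
    try:
        parsed = json.loads(raw)  # step 1: direct parsing
    except json.JSONDecodeError:
        if repair is None:
            return {}  # step 3: graceful fallback
        try:
            # step 2: hypothetical repair hook (stand-in for LLM cleaning)
            parsed = json.loads(repair(raw))
        except Exception:
            return {}  # repair failed, fall back as well
    # Only JSON objects are accepted; lists and scalars also fall back.
    return parsed if isinstance(parsed, dict) else {}
```

Passing the repair step in as a callable keeps the parser usable when LLM features are disabled, which mirrors the "when enabled" wording above.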

### 3. Data Extraction

Extracts key fields:

- Company name and website
- Investor description
- Investment thesis/focus areas
- Headquarters location
- Assets Under Management (AUM)
- Fund information
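
As a rough illustration of the extraction step, the sketch below pulls the listed fields out of an already-parsed profile dictionary. The input key names (`websiteURL`, `investorDescription`, `investmentThesisFocus`, `headquarters`, `overallAssetsUnderManagement`, `funds`) are taken from the `InvestorProfile` model that appears elsewhere in this changeset. Whether the live pipeline reads exactly these keys is an assumption, and `extract_profile_fields` with its output key names is hypothetical.

```python
from typing import Any


def extract_profile_fields(profile: dict[str, Any]) -> dict[str, Any]:
    """Flatten a parsed investor profile into the fields listed above."""
    aum = profile.get("overallAssetsUnderManagement")
    focus = profile.get("investmentThesisFocus")
    return {
        "website": profile.get("websiteURL"),
        "description": profile.get("investorDescription"),
        # Tolerate a missing or malformed focus list.
        "focus_areas": focus if isinstance(focus, list) else [],
        "headquarters": profile.get("headquarters"),
        # AUM is a nested object; anything else is treated as missing.
        "aum_amount": aum.get("aumAmount") if isinstance(aum, dict) else None,
        "funds": profile.get("funds") or [],
    }
```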

### 4. LLM Enhancement (Optional)

When `--use-llm` is enabled:

- Standardizes investor descriptions
- Normalizes investment focus areas
- Cleans headquarters location format
- Repairs malformed JSON data

### 5. Dual Storage

- **SQL Database**: Structured, queryable data
- **Vector Database**: Semantic search capabilities

## Configuration

### Environment Variables (.env)

```bash
# OpenRouter API Configuration (required for LLM features)
OPENROUTER_API_KEY=your_openrouter_api_key_here

# Database Configuration (optional, defaults to SQLite)
DATABASE_URL=sqlite:///investors.db

# FastAPI Configuration
API_HOST=localhost
API_PORT=8000
```
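
One way these variables might be read in Python, with the defaults shown above, is sketched below. The `Settings` dataclass and `load_settings` function are hypothetical and are not the contents of the project's `settings.py`. The sketch also assumes the `.env` file has already been loaded into the process environment.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class Settings:
    openrouter_api_key: Optional[str]
    database_url: str
    api_host: str
    api_port: int

    @property
    def llm_enabled(self) -> bool:
        # LLM features are only usable when an API key is configured.
        return bool(self.openrouter_api_key)


def load_settings(env: Mapping[str, str] = os.environ) -> Settings:
    """Read configuration from the environment, applying documented defaults."""
    return Settings(
        openrouter_api_key=env.get("OPENROUTER_API_KEY") or None,
        database_url=env.get("DATABASE_URL", "sqlite:///investors.db"),
        api_host=env.get("API_HOST", "localhost"),
        api_port=int(env.get("API_PORT", "8000")),
    )
```

Accepting the environment as a mapping argument makes the loader easy to exercise in tests without touching real environment variables.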

### LLM Configuration

- **Provider**: OpenRouter (supports multiple models)
- **Default Model**: google/gemini-2.5-flash-lite
- **Temperature**: 0.3 for enhancement, 0 for structured data
- **Fallback**: Graceful degradation when API unavailable

## Natural Language Query Processing

The system supports intelligent natural language queries that automatically extract filters and search criteria:

### Query Examples

```bash
# Stage-based queries
"Show me seed stage investors"
"Find growth stage VCs"

# Geographic queries
"Investors in Silicon Valley"
"European venture capital firms"

# Sector-specific queries
"Fintech investors"
"Healthcare and biotech VCs"

# Size-based queries
"Investors with $5M+ check sizes"
"High AUM growth investors"

# Combined queries
"Growth stage fintech investors in the US with check sizes over $1 million"
"European healthcare investors focusing on early stage"
```

### Query Processing Features

- **Automatic Filter Extraction**: Detects investment stages, geographies, sectors, and check sizes
- **Semantic Understanding**: Uses AI to interpret complex queries
- **Database Integration**: Combines AI analysis with efficient SQL filtering
- **Complete Relationships**: Returns full investor data with portfolio companies, team members, and sectors
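
To make the idea of filter extraction concrete, here is a deliberately simple keyword-and-regex sketch. It is not the project's method: the document states that an LLM interprets queries, whereas this stand-in uses fixed keyword tables. The function name, the sector list, and the choice to treat any dollar amount as a minimum check size are all assumptions made for illustration.

```python
import re

# Keyword tables are illustrative; the stage values mirror the documented enum.
STAGE_KEYWORDS = {
    "seed": "SEED",
    "series a": "SERIES_A",
    "series b": "SERIES_B",
    "series c": "SERIES_C",
    "growth": "GROWTH",
    "late stage": "LATE_STAGE",
}
SECTOR_KEYWORDS = ("fintech", "healthcare", "biotech")
MULTIPLIERS = {"k": 1_000, "m": 1_000_000, "million": 1_000_000,
               "b": 1_000_000_000, "billion": 1_000_000_000}
AMOUNT_RE = re.compile(r"\$\s*(\d+(?:\.\d+)?)\s*(million|billion|k|m|b)?\b")


def extract_filters(question: str) -> dict:
    """Derive structured filters from a free-text question (keyword-based)."""
    text = question.lower()
    filters: dict = {}
    for keyword, stage in STAGE_KEYWORDS.items():
        if keyword in text:
            filters["stage"] = stage
            break
    sectors = [sector for sector in SECTOR_KEYWORDS if sector in text]
    if sectors:
        filters["sector"] = sectors[0]
    match = AMOUNT_RE.search(text)
    if match:
        # Simplification: any dollar amount is read as a minimum check size.
        amount = float(match.group(1))
        filters["min_check_size"] = int(amount * MULTIPLIERS.get(match.group(2), 1))
    return filters
```

The resulting dictionary uses the same parameter names as the `/investors/filter` endpoint, so it could be passed on to structured filtering. Geography is left out because it cannot be detected reliably with a short keyword list.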

### Query Response

The `/query` endpoint returns a structured `InvestorList` with complete relationship data, making it easy to get comprehensive information about matching investors.

## Error Handling

### API Error Responses

The API provides clear HTTP status codes and error messages:

```json
// 404 Not Found
{
  "detail": "Investor not found"
}

// 422 Validation Error
{
  "detail": [
    {
      "loc": ["body", "stage_focus"],
      "msg": "value is not a valid enumeration member",
      "type": "type_error.enum"
    }
  ]
}
```
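
A client has to cope with both shapes of `detail` shown above: a plain string for 404 and a list of validation entries for 422. The helper below is a small sketch of how that might be done. `describe_error` is a hypothetical name, and the function assumes only the two response shapes documented here.

```python
import json


def describe_error(status_code: int, body: str) -> str:
    """Turn an error response body into a one-line, human-readable message."""
    try:
        payload = json.loads(body)
    except json.JSONDecodeError:
        return f"HTTP {status_code}: {body}"
    detail = payload.get("detail") if isinstance(payload, dict) else payload
    if isinstance(detail, list):
        # Validation errors: join "location: message" for each entry.
        parts = []
        for item in detail:
            if isinstance(item, dict):
                location = ".".join(str(part) for part in item.get("loc", []))
                parts.append(f"{location}: {item.get('msg', '')}")
            else:
                parts.append(str(item))
        return f"HTTP {status_code}: " + "; ".join(parts)
    return f"HTTP {status_code}: {detail}"
```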

### Robust Processing

- **Data Validation**: Pydantic models ensure data integrity
- **Relationship Management**: Automatic handling of foreign key constraints
- **LLM Fallbacks**: Graceful degradation when AI services unavailable
- **Transaction Safety**: Database rollbacks on errors
- **Comprehensive Logging**: Detailed error tracking and debugging

### Common Issues and Solutions

1. **Invalid Enum Values**

   - Solution: Use uppercase enum values (SEED, GROWTH, etc.)
   - Check: Investment stages must match defined enum

2. **Missing OpenRouter API Key**

   - Solution: Set OPENROUTER_API_KEY in environment
   - Fallback: CSV processing continues without LLM enhancement

3. **Database Connection Issues**

   - Solution: Verify DATABASE_URL configuration
   - Default: Uses SQLite (no external dependencies)

4. **Relationship Errors**

   - Solution: Ensure proper foreign key relationships
   - Check: Use existing sector/company IDs or create new ones
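
For the invalid enum value issue (item 1 above), callers can normalize user input before sending it to the API. The sketch below redefines the stage enum locally so that it is self-contained. The member names follow the stage values that appear in this changeset (SEED, SERIES_A, SERIES_B, SERIES_C, GROWTH, LATE_STAGE), while `normalize_stage` itself is a hypothetical helper.

```python
import enum
from typing import Optional


class InvestmentStage(enum.Enum):
    SEED = "SEED"
    SERIES_A = "SERIES_A"
    SERIES_B = "SERIES_B"
    SERIES_C = "SERIES_C"
    GROWTH = "GROWTH"
    LATE_STAGE = "LATE_STAGE"


def normalize_stage(value: Optional[str]) -> Optional[InvestmentStage]:
    """Map loosely formatted input such as 'series a' to an enum member."""
    if value is None or not value.strip():
        return None
    # Uppercase and unify separators so 'late-stage' becomes 'LATE_STAGE'.
    key = value.strip().upper().replace(" ", "_").replace("-", "_")
    try:
        return InvestmentStage[key]
    except KeyError:
        raise ValueError(f"Unknown investment stage: {value!r}") from None
```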

## Performance

### Benchmarks (Approximate)

- **API Response Time**: <200ms for standard queries
- **Database Queries**: <50ms for filtered searches with relationships
- **CSV Processing**: ~5-15 seconds per row (depends on LLM API latency)
- **Natural Language Queries**: ~2-5 seconds (AI processing + database query)
- **Vector Search**: <100ms for semantic similarity queries

### Optimization Features

1. **Eager Loading**: Efficient relationship loading with `selectinload()`
2. **Query Optimization**: Smart filtering to reduce database load
3. **Caching**: Database connection pooling and session management
4. **Pagination**: Built-in limits to prevent overwhelming responses
5. **Async Processing**: FastAPI async capabilities for better performance

### Production Recommendations

1. **Database**: Consider PostgreSQL for production workloads
2. **Caching**: Add Redis for frequently accessed data
3. **Load Balancing**: Deploy multiple API instances behind a load balancer
4. **Monitoring**: Implement logging and metrics collection
5. **Rate Limiting**: Add API rate limiting for public endpoints

## File Structure

```
anton_wireframe/
├── app/
│   ├── main.py                 # FastAPI application and main endpoints
│   ├── py_schemas.py           # Pydantic models for validation
│   ├── settings.py             # Configuration management
│   ├── api/
│   │   ├── __init__.py
│   │   ├── investors.py        # Investor CRUD and filtering endpoints
│   │   └── companies.py        # Company CRUD and filtering endpoints
│   ├── db/
│   │   ├── __init__.py
│   │   ├── db.py               # Database connection and session management
│   │   ├── models.py           # SQLAlchemy database models
│   │   └── new_schema.py       # Additional schema definitions
│   └── services/
│       ├── __init__.py
│       ├── openrouter.py       # LLM-powered CSV processing
│       ├── querying.py         # Natural language query processing
│       └── langgraph_agent.py  # AI agent configuration
├── chroma_db/                  # Vector database directory
├── requirements.txt            # Python dependencies
├── README.md                   # This documentation
└── .env                        # Environment configuration
```

## Example Usage Scenarios

### 1. Upload and Process Investor Data

```bash
# Upload CSV file via API
curl -X POST "http://localhost:8000/parse-csv" \
  -H "Content-Type: multipart/form-data" \
  -F "file=@investors.csv"
```

### 2. Find Specific Investors

```bash
# Natural language search
curl -X POST "http://localhost:8000/query" \
  -H "Content-Type: application/json" \
  -d '{"question": "Show me growth stage fintech investors in Silicon Valley with check sizes over $2 million"}'

# Structured filtering
curl "http://localhost:8000/investors/filter?stage=GROWTH&sector=fintech&geography=Silicon%20Valley&min_check_size=2000000"
```

### 3. Company Research

```bash
# Find companies in specific sector
curl "http://localhost:8000/companies/filter?industry=fintech&founded_after=2020"

# Find companies backed by specific investor
curl "http://localhost:8000/companies/filter?investor_name=Sequoia"
```

### 4. Investment Analysis

```bash
# Get investor with full portfolio
curl "http://localhost:8000/investors/1"

# Find all companies in a specific location
curl "http://localhost:8000/companies/filter?location=San%20Francisco"
```

## Development

### Running in Development Mode

```bash
cd app
uvicorn main:app --reload --host localhost --port 8000
```

### Testing the API

1. **Interactive Testing**: Visit http://localhost:8000/docs
2. **Manual Testing**: Use curl or Postman with the examples above
3. **Database Inspection**: Use SQLite browser to inspect `investors_2.db`

### Adding New Features

1. **New Endpoints**: Add routes to `api/investors.py` or `api/companies.py`
2. **New Models**: Update `db/models.py` and `py_schemas.py`
3. **New Filters**: Extend filtering logic in route handlers
4. **New LLM Features**: Modify `services/openrouter.py` or `services/querying.py`

## License

This project is part of the MKD Anton Wireframe system.

## Support

For issues and questions:

1. Check logs for detailed error messages
2. Verify environment configuration
3. Test with limited datasets first
4. Review CSV data format requirements
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,46 +0,0 @@
from sqlalchemy.orm import Session
from db.models import InvestorTable
from db.db import get_db

def update_stage_focus_values():
    """Update existing stage_focus values from lowercase to uppercase"""
    db = next(get_db())

    try:
        # Mapping of old lowercase values to new uppercase values
        stage_mappings = {
            'seed': 'SEED',
            'series_a': 'SERIES_A',
            'series_b': 'SERIES_B',
            'series_c': 'SERIES_C',
            'growth': 'GROWTH',
            'late_stage': 'LATE_STAGE'
        }

        updated_count = 0

        for old_value, new_value in stage_mappings.items():
            # Update records with the old value
            result = db.query(InvestorTable).filter(
                InvestorTable.stage_focus == old_value
            ).update(
                {InvestorTable.stage_focus: new_value},
                synchronize_session=False
            )

            updated_count += result
            print(f"Updated {result} records from '{old_value}' to '{new_value}'")

        db.commit()
        print(f"Successfully updated {updated_count} total records")

    except Exception as e:
        db.rollback()
        print(f"Error updating stage_focus values: {e}")
        raise
    finally:
        db.close()

# Run the update
if __name__ == "__main__":
    update_stage_focus_values()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+5
-2
@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session, sessionmaker
Base = declarative_base()

# Database configuration
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///investors.db")
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./investors.db")

# Create engine
engine = create_engine(DATABASE_URL, echo=False)
@@ -32,9 +32,12 @@ db_dependency = Annotated[Session, Depends(get_db)]
def init_database():
    """Initialize the database by creating all tables"""
    Base.metadata.create_all(bind=engine)
    print("Database initialized successfully!")


def get_session_sync() -> Session:
    """Get a database session for synchronous operations"""
    return SessionLocal()

def get_db_session():
    """Get a database session for direct use."""
    return SessionLocal()
+123
-33
@@ -1,13 +1,20 @@
import datetime
import enum

from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Table, Text
from sqlalchemy.orm import relationship
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Table, Text, func
from sqlalchemy.orm import declarative_mixin, relationship
from sqlalchemy.types import Enum

from db.db import Base


@declarative_mixin
class TimestampMixin:
    created_at = Column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )
    updated_at = Column(DateTime(timezone=True), onupdate=func.now())


class InvestmentStage(enum.Enum):
    SEED = "SEED"
    SERIES_A = "SERIES_A"
@@ -16,6 +23,7 @@ class InvestmentStage(enum.Enum):
    GROWTH = "GROWTH"
    LATE_STAGE = "LATE_STAGE"


# Association table for many-to-many relationship between investors and companies
investor_company_association = Table(
    "investor_companies",
@@ -34,24 +42,49 @@ investor_sector_association = Table(
)


class InvestorTable(Base):
company_sector_association = Table(
    "company_sector",
    Base.metadata,
    Column("company_id", Integer, ForeignKey("companies.id")),
    Column("sector_id", Integer, ForeignKey("sectors.id")),
)

project_sector_association = Table(
    "project_sector",
    Base.metadata,
    Column("project_id", Integer, ForeignKey("projects.id")),
    Column("sector_id", Integer, ForeignKey("sectors.id")),
)

project_investor_association = Table(
    "project_investors",
    Base.metadata,
    Column("project_id", Integer, ForeignKey("projects.id")),
    Column("investor_id", Integer, ForeignKey("investors.id")),
)

project_company_association = Table(
    "project_companies",
    Base.metadata,
    Column("project_id", Integer, ForeignKey("projects.id")),
    Column("company_id", Integer, ForeignKey("companies.id")),
)


class InvestorTable(Base, TimestampMixin):
    __tablename__ = "investors"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False)
    description = Column(Text, nullable=True)
    aum = Column(Integer, nullable=False)  # Assets Under Management
    check_size_lower = Column(Integer, nullable=False)  # Lower bound
    check_size_upper = Column(Integer, nullable=False)  # Upper bound
    geographic_focus = Column(String, nullable=False)
    stage_focus = Column(Enum(InvestmentStage), nullable=False)
    number_of_investments = Column(Integer, default=0)
    created_at = Column(DateTime, default=datetime.datetime.now(datetime.UTC))
    updated_at = Column(
        DateTime,
        default=datetime.datetime.now(datetime.UTC),
        onupdate=datetime.datetime.now(datetime.UTC),
    )
    aum = Column(Integer, nullable=True)  # Assets Under Management
    check_size_lower = Column(Integer, nullable=True)  # Lower bound
    check_size_upper = Column(Integer, nullable=True)  # Upper bound
    geographic_focus = Column(String, nullable=True)
    stage_focus = Column(Enum(InvestmentStage), nullable=True)
    number_of_investments = Column(Integer, default=0, nullable=True)

    team_members = relationship("InvestorMember", back_populates="investor")

    # Relationship to portfolio companies
    portfolio_companies = relationship(
@@ -59,30 +92,43 @@ class InvestorTable(Base):
        secondary=investor_company_association,
        back_populates="investors",
    )
    team_members = relationship("InvestorTeamMember", back_populates="investor")

    sectors = relationship(
        "SectorTable",
        secondary=investor_sector_association,
        back_populates="investors",
    )

    projects = relationship(
        "ProjectTable",
        secondary=project_investor_association,
        back_populates="investors",
    )

class CompanyTable(Base):

class InvestorMember(Base, TimestampMixin):
    __tablename__ = "investor_members"
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False)
    role = Column(String, nullable=True)
    email = Column(String, nullable=True)

    investor_id = Column(Integer, ForeignKey("investors.id"))
    investor = relationship("InvestorTable", back_populates="team_members")


class CompanyTable(Base, TimestampMixin):
    __tablename__ = "companies"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False)
    industry = Column(String, nullable=False)
    location = Column(String, nullable=False)
    industry = Column(String, nullable=True)
    location = Column(String, nullable=True)
    description = Column(String, nullable=True)
    founded_year = Column(Integer, nullable=True)
    website = Column(String, nullable=True)
    created_at = Column(DateTime, default=datetime.datetime.now(datetime.UTC))
    updated_at = Column(
        DateTime,
        default=datetime.datetime.now(datetime.UTC),
        onupdate=datetime.datetime.now(datetime.UTC),
    )

    members = relationship("CompanyMember", back_populates="company")
    # Relationship back to investors
    investors = relationship(
        "InvestorTable",
@@ -90,8 +136,29 @@ class CompanyTable(Base):
        back_populates="portfolio_companies",
    )

    sectors = relationship(
        "SectorTable", secondary=company_sector_association, back_populates="companies"
    )

class SectorTable(Base):
    projects = relationship(
        "ProjectTable",
        secondary=project_company_association,
        back_populates="companies",
    )


class CompanyMember(Base, TimestampMixin):
    __tablename__ = "company_members"
    id = Column(Integer, primary_key=True)
    name = Column(String)
    linkedin = Column(String, nullable=True)
    role = Column(String, nullable=True)
    company_id = Column(Integer, ForeignKey("companies.id"), nullable=False)

    company = relationship("CompanyTable", back_populates="members")


class SectorTable(Base, TimestampMixin):
    __tablename__ = "sectors"

    id = Column(Integer, primary_key=True, index=True)
@@ -104,13 +171,36 @@ class SectorTable(Base):
        back_populates="sectors",
    )

    companies = relationship(
        "CompanyTable", secondary=company_sector_association, back_populates="sectors"
    )

    projects = relationship(
        "ProjectTable", secondary=project_sector_association, back_populates="sector"
    )


class ProjectTable(Base, TimestampMixin):
    __tablename__ = "projects"

class InvestorTeamMember(Base):
    __tablename__ = "investor_team"
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False)
    role = Column(String, nullable=False)
    email = Column(String, nullable=False)
    valuation = Column(Integer, nullable=True)

    investor_id = Column(Integer, ForeignKey("investors.id"))
    investor = relationship("InvestorTable", back_populates="team_members")
    stage = Column(Enum(InvestmentStage), nullable=True)
    location = Column(String, nullable=True)
    description = Column(Text, nullable=True)
    start_date = Column(DateTime, nullable=True)
    end_date = Column(DateTime, nullable=True)

    sector = relationship(
        "SectorTable", secondary=project_sector_association, back_populates="projects"
    )
    investors = relationship(
        "InvestorTable",
        secondary=project_investor_association,
        back_populates="projects",
    )
    companies = relationship(
        "CompanyTable", secondary=project_company_association, back_populates="projects"
    )
@@ -1,115 +0,0 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import JSON, Column, DateTime, Integer, String, Text
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class Investor(Base):
|
||||
__tablename__ = "investors"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
name = Column(String(500), nullable=False)
|
||||
website = Column(String(1000))
|
||||
|
||||
# Core investment information
|
||||
investor_description = Column(Text)
|
||||
investment_thesis_focus = Column(JSON) # List of focus areas
|
||||
headquarters = Column(String(1000))
|
||||
|
||||
# AUM information
|
||||
aum_amount = Column(String(200))
|
||||
aum_as_of_date = Column(String(100))
|
||||
aum_source_url = Column(String(1000))
|
||||
|
||||
# Fund information
|
||||
funds_info = Column(JSON) # Complex fund data
|
||||
|
||||
# Raw data columns for reference
|
||||
crunchbase_urls = Column(Text)
|
||||
crunchbase_extract = Column(Text)
|
||||
linkedin_profile = Column(Text)
|
||||
source_truth_profile = Column(Text)
|
||||
|
||||
# Metadata
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Investor(name='{self.name}', website='{self.website}')>"
|
||||
|
||||
|
||||
# Pydantic models for data validation and parsing
class AUMInfo(BaseModel):
    aumAmount: Optional[str] = None
    asOfDate: Optional[str] = None
    sourceUrl: Optional[str] = None


class FundInfo(BaseModel):
    fundName: Optional[str] = None
    fundSize: Optional[str] = None
    vintage: Optional[str] = None
    status: Optional[str] = None
    description: Optional[str] = None


class InvestorProfile(BaseModel):
    websiteURL: Optional[str] = None
    investorDescription: Optional[str] = None
    investmentThesisFocus: Optional[List[str]] = None
    headquarters: Optional[str] = None
    overallAssetsUnderManagement: Optional[AUMInfo] = None
    funds: Optional[List[FundInfo]] = None


class CSVRow(BaseModel):
    name: str
    website: Optional[str] = None
    investment_firm_profile: Optional[str] = None
    crunchbase_linkedin_urls: Optional[str] = None
    crunchbase_firm_extract: Optional[str] = None
    linkedin_investment_profile: Optional[str] = None
    source_of_truth_profile: Optional[str] = None

    def get_combined_description(self) -> str:
        """Combine all description fields for vector embedding"""
        descriptions = []

        if self.investment_firm_profile:
            try:
                profile_data = json.loads(self.investment_firm_profile)
                if isinstance(profile_data, dict):
                    desc = profile_data.get("investorDescription", "")
                    if desc:
                        descriptions.append(desc)
            except (json.JSONDecodeError, TypeError):
                pass

        if self.crunchbase_firm_extract:
            descriptions.append(self.crunchbase_firm_extract)

        if self.linkedin_investment_profile:
            descriptions.append(self.linkedin_investment_profile)

        if self.source_of_truth_profile:
            descriptions.append(self.source_of_truth_profile)

        return " ".join(descriptions)

    def get_investment_focus(self) -> List[str]:
        """Extract investment thesis focus"""
        if self.investment_firm_profile:
            try:
                profile_data = json.loads(self.investment_firm_profile)
                if isinstance(profile_data, dict):
                    focus = profile_data.get("investmentThesisFocus", [])
                    if isinstance(focus, list):
                        return focus
            except (json.JSONDecodeError, TypeError):
                pass
        return []
|
||||
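Review note: `get_combined_description` and `get_investment_focus` in the removed `CSVRow` class each decode `investment_firm_profile` with the same try/except. If that logic is carried over to the new parser, a shared decoding helper might be worth considering. The sketch below is only an illustration; `load_profile`, `combined_description`, and `investment_focus` are hypothetical names that do not appear in this diff.

```python
import json
from typing import Any, Dict, List, Optional


def load_profile(raw: Optional[str]) -> Dict[str, Any]:
    """Decode the profile JSON once; return {} if it is missing, malformed, or not an object."""
    if not raw:
        return {}
    try:
        data = json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        return {}
    return data if isinstance(data, dict) else {}


def combined_description(raw_profile: Optional[str], *extracts: Optional[str]) -> str:
    """Join the profile description with any non-empty free-text extracts."""
    parts = []
    desc = load_profile(raw_profile).get("investorDescription", "")
    if desc:
        parts.append(desc)
    parts.extend(text for text in extracts if text)
    return " ".join(parts)


def investment_focus(raw_profile: Optional[str]) -> List[str]:
    """Return the focus list, or [] when the field is absent or not a list."""
    focus = load_profile(raw_profile).get("investmentThesisFocus", [])
    return focus if isinstance(focus, list) else []
```

Both accessors then share one failure path, so a malformed payload yields an empty description and an empty focus list rather than two separately maintained fallbacks.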
+31
-13
@@ -1,17 +1,27 @@
|
||||
import io
|
||||
|
||||
import pandas as pd
|
||||
from api import companies, investors
|
||||
from db.db import db_dependency, init_database
|
||||
from fastapi import FastAPI, File, UploadFile
|
||||
from py_schemas import InvestorList
|
||||
from db.db import Base, db_dependency, engine
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, File, Form, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from services.openrouter_v2 import InvestorProcessor
|
||||
from routers import companies, investors, projects
|
||||
from schemas.router_schemas import InvestorList
|
||||
from services.llm_parser import InvestorProcessor
|
||||
from services.querying import QueryProcessor
|
||||
|
||||
app = FastAPI()
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def init_database():
|
||||
"""Initialize the database by creating all tables"""
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
init_database()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
# Request models
|
||||
class QueryRequest(BaseModel):
|
||||
@@ -20,7 +30,7 @@ class QueryRequest(BaseModel):
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"question": "Show me growth stage fintech investors in the US with check sizes over $1 million"
|
||||
"question": "Find me deep tech investors that do deals in Europe under 5 million."
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,21 +41,27 @@ def health():
|
||||
|
||||
|
||||
@app.post("/parse-csv", tags=["CSV Upload"], response_model=list[dict])
|
||||
async def parse_csv(db: db_dependency, file: UploadFile = File(...)):
|
||||
async def parse_csv(
|
||||
db: db_dependency, file: UploadFile = File(...), is_investor: int = Form(...)
|
||||
):
|
||||
# Read uploaded CSV with pandas
|
||||
content = await file.read()
|
||||
df = pd.read_csv(io.StringIO(content.decode("utf-8")))
|
||||
|
||||
# Process the dataframe
|
||||
processor = InvestorProcessor(sql_session=db)
|
||||
results = await processor.process_csv(df)
|
||||
processor = InvestorProcessor()
|
||||
|
||||
if is_investor == 1:
|
||||
results = await processor.parse_investors(df)
|
||||
else:
|
||||
results = await processor.parse_companies(df)
|
||||
|
||||
# Convert Pydantic objects to dictionaries
|
||||
return [r.model_dump() for r in results]
|
||||
|
||||
|
||||
@app.post("/query", response_model=InvestorList, tags=["Querying"])
|
||||
async def query_investors(db: db_dependency, request: QueryRequest):
|
||||
async def query_investors(request: QueryRequest):
|
||||
"""
|
||||
Query investors using natural language.
|
||||
|
||||
@@ -55,14 +71,16 @@ async def query_investors(db: db_dependency, request: QueryRequest):
|
||||
- "Growth stage investors with $5M+ check sizes"
|
||||
- "Healthcare investors in Europe"
|
||||
"""
|
||||
processor = QueryProcessor(sql_session=db)
|
||||
processor = QueryProcessor()
|
||||
results = processor.process_query(request.question)
|
||||
return results
|
||||
|
||||
|
||||
app.include_router(investors.router)
|
||||
app.include_router(companies.router)
|
||||
app.include_router(projects.router)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app="main:app", host="localhost", port=8000, reload=True)
|
||||
uvicorn.run(app="main:app", host="0.0.0.0", port=8585, reload=True)
|
||||
|
||||
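Review note: `parse_csv` decodes the upload with `content.decode("utf-8")`. Files saved by spreadsheet tools as "CSV UTF-8" often start with a byte-order mark, and Python's plain `utf-8` codec keeps it as a `\ufeff` character while `utf-8-sig` strips it. The sketch below shows the difference using only the standard library's `csv` module as a stand-in for pandas; `read_csv_rows` is a hypothetical helper, not part of this PR, and the sketch makes no claim about how pandas treats the BOM.

```python
import csv
import io
from typing import Dict, List


def read_csv_rows(content: bytes) -> List[Dict[str, str]]:
    """Decode uploaded CSV bytes, dropping a leading BOM if present, and return the rows."""
    text = content.decode("utf-8-sig")  # also accepts input that has no BOM
    return list(csv.DictReader(io.StringIO(text)))


# A payload that begins with a BOM, as some spreadsheet exports do
payload = "\ufeffname,website\r\nAcme,https://acme.example\r\n".encode("utf-8")
rows = read_csv_rows(payload)
first_header_plain = payload.decode("utf-8").split(",")[0]
```

Here `first_header_plain` still carries the BOM character in front of `name`, whereas the keys of `rows[0]` do not.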
@@ -1,38 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Investor(BaseModel):
|
||||
name: str
|
||||
aum: int
|
||||
check_size: str
|
||||
sector_focus: str
|
||||
stage_focus: str
|
||||
region: str
|
||||
investment_thesis: str
|
||||
investor_description: str
|
||||
|
||||
|
||||
class InvestorList(BaseModel):
|
||||
investor_list: List[Investor]
|
||||
|
||||
|
||||
class QueryResponse(BaseModel):
|
||||
name: str
|
||||
aum: int
|
||||
check_size: str
|
||||
sector_focus: str
|
||||
stage_focus: str
|
||||
region: str
|
||||
investment_thesis: str
|
||||
investor_description: str
|
||||
reason: str
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
question: str
|
||||
|
||||
|
||||
class QueryResponseList(BaseModel):
|
||||
responses: List[QueryResponse]
|
||||
Binary file not shown.
@@ -3,8 +3,8 @@ from typing import List, Optional
|
||||
from db.db import get_db
|
||||
from db.models import CompanyTable, InvestorTable
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from py_schemas import CompanySchema
|
||||
from pydantic import BaseModel
|
||||
from schemas.router_schemas import CompanyData
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
router = APIRouter(tags=["Company Routes"])
|
||||
@@ -15,6 +15,7 @@ class CompanyCreate(BaseModel):
|
||||
name: str
|
||||
industry: str
|
||||
location: str
|
||||
description: Optional[str] = None
|
||||
founded_year: Optional[int] = None
|
||||
website: Optional[str] = None
|
||||
|
||||
@@ -23,46 +24,37 @@ class CompanyUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
industry: Optional[str] = None
|
||||
location: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
founded_year: Optional[int] = None
|
||||
website: Optional[str] = None
|
||||
|
||||
|
||||
# Response schema with relationships
|
||||
class CompanyData(BaseModel):
|
||||
"""Comprehensive company data schema"""
|
||||
|
||||
company: CompanySchema
|
||||
investors: List["InvestorBasic"] = []
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestorBasic(BaseModel):
|
||||
"""Basic investor info for company responses"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
geographic_focus: str
|
||||
stage_focus: str
|
||||
check_size_lower: int
|
||||
check_size_upper: int
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
@router.get("/companies", response_model=List[CompanyData])
|
||||
def read_companies(db: Session = Depends(get_db)):
|
||||
"""Get all companies with their investor relationships"""
|
||||
companies = (
|
||||
db.query(CompanyTable).options(selectinload(CompanyTable.investors)).all()
|
||||
db.query(CompanyTable)
|
||||
.filter(
|
||||
CompanyTable.name.isnot(None),
|
||||
CompanyTable.description.isnot(None)
|
||||
)
|
||||
.options(
|
||||
selectinload(CompanyTable.investors),
|
||||
selectinload(CompanyTable.members),
|
||||
selectinload(CompanyTable.sectors),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Transform CompanyTable objects to CompanyData format
|
||||
company_data_list = []
|
||||
for company in companies:
|
||||
company_data = CompanyData(company=company, investors=company.investors)
|
||||
company_data = CompanyData(
|
||||
company=company,
|
||||
investors=company.investors,
|
||||
members=company.members,
|
||||
sectors=company.sectors,
|
||||
)
|
||||
company_data_list.append(company_data)
|
||||
|
||||
return company_data_list
|
||||
@@ -89,7 +81,11 @@ def filter_companies(
|
||||
"""Filter companies based on various criteria"""
|
||||
|
||||
# Start with base query
|
||||
query = db.query(CompanyTable).options(selectinload(CompanyTable.investors))
|
||||
query = db.query(CompanyTable).options(
|
||||
selectinload(CompanyTable.investors),
|
||||
selectinload(CompanyTable.members),
|
||||
selectinload(CompanyTable.sectors),
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if industry:
|
||||
@@ -121,7 +117,12 @@ def filter_companies(
|
||||
# Transform to CompanyData format
|
||||
company_data_list = []
|
||||
for company in companies:
|
||||
company_data = CompanyData(company=company, investors=company.investors)
|
||||
company_data = CompanyData(
|
||||
company=company,
|
||||
investors=company.investors,
|
||||
members=company.members,
|
||||
sectors=company.sectors,
|
||||
)
|
||||
company_data_list.append(company_data)
|
||||
|
||||
return company_data_list
|
||||
@@ -132,7 +133,11 @@ def read_company(company_id: int, db: Session = Depends(get_db)):
|
||||
"""Get a specific company by ID with its investors"""
|
||||
company = (
|
||||
db.query(CompanyTable)
|
||||
.options(selectinload(CompanyTable.investors))
|
||||
.options(
|
||||
selectinload(CompanyTable.investors),
|
||||
selectinload(CompanyTable.members),
|
||||
selectinload(CompanyTable.sectors),
|
||||
)
|
||||
.filter(CompanyTable.id == company_id)
|
||||
.first()
|
||||
)
|
||||
@@ -141,7 +146,12 @@ def read_company(company_id: int, db: Session = Depends(get_db)):
|
||||
raise HTTPException(status_code=404, detail="Company not found")
|
||||
|
||||
# Transform to CompanyData format
|
||||
return CompanyData(company=company, investors=company.investors)
|
||||
return CompanyData(
|
||||
company=company,
|
||||
investors=company.investors,
|
||||
members=company.members,
|
||||
sectors=company.sectors,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/companies", response_model=CompanyData)
|
||||
@@ -155,14 +165,21 @@ def create_company(company: CompanyCreate, db: Session = Depends(get_db)):
|
||||
# Reload with relationships
|
||||
company_with_relations = (
|
||||
db.query(CompanyTable)
|
||||
.options(selectinload(CompanyTable.investors))
|
||||
.options(
|
||||
selectinload(CompanyTable.investors),
|
||||
selectinload(CompanyTable.members),
|
||||
selectinload(CompanyTable.sectors),
|
||||
)
|
||||
.filter(CompanyTable.id == db_company.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Transform to CompanyData format
|
||||
return CompanyData(
|
||||
company=company_with_relations, investors=company_with_relations.investors
|
||||
company=company_with_relations,
|
||||
investors=company_with_relations.investors,
|
||||
members=company_with_relations.members,
|
||||
sectors=company_with_relations.sectors,
|
||||
)
|
||||
|
||||
|
||||
@@ -185,14 +202,21 @@ def update_company(
|
||||
# Reload with relationships
|
||||
company_with_relations = (
|
||||
db.query(CompanyTable)
|
||||
.options(selectinload(CompanyTable.investors))
|
||||
.options(
|
||||
selectinload(CompanyTable.investors),
|
||||
selectinload(CompanyTable.members),
|
||||
selectinload(CompanyTable.sectors),
|
||||
)
|
||||
.filter(CompanyTable.id == company_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Transform to CompanyData format
|
||||
return CompanyData(
|
||||
company=company_with_relations, investors=company_with_relations.investors
|
||||
company=company_with_relations,
|
||||
investors=company_with_relations.investors,
|
||||
members=company_with_relations.members,
|
||||
sectors=company_with_relations.sectors,
|
||||
)
|
||||
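Review note: after this change the same four-argument `CompanyData(...)` construction appears in `read_companies`, `filter_companies`, `read_company`, `create_company`, and `update_company`. One possible way to keep them in sync is a single conversion helper. The sketch below uses a dataclass as a stand-in for the real schema so that it runs without the project's dependencies; `CompanyDataSketch` and `to_company_data` are hypothetical names.

```python
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import Any, List


@dataclass
class CompanyDataSketch:
    """Stand-in for the CompanyData response schema (hypothetical, not the real class)."""
    company: Any
    investors: List[Any] = field(default_factory=list)
    members: List[Any] = field(default_factory=list)
    sectors: List[Any] = field(default_factory=list)


def to_company_data(company: Any) -> CompanyDataSketch:
    """Build the response wrapper from an already-loaded company object in one place."""
    return CompanyDataSketch(
        company=company,
        investors=list(company.investors),
        members=list(company.members),
        sectors=list(company.sectors),
    )


# A fake row standing in for a CompanyTable instance with its relationships loaded
row = SimpleNamespace(name="Acme", investors=["Fund A"], members=[], sectors=["fintech"])
data = to_company_data(row)
```

With a helper like this, adding a further relationship to the response would touch one function instead of five call sites.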
|
||||
|
||||
@@ -3,8 +3,9 @@ from typing import List, Optional
|
||||
from db.db import get_db
|
||||
from db.models import InvestorTable, SectorTable
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from py_schemas import InvestmentStage, InvestorData
|
||||
from pydantic import BaseModel
|
||||
from schemas.router_schemas import InvestmentStage, InvestorData
|
||||
from services.querying import QueryProcessor
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
router = APIRouter(tags=["Investor Routes"])
|
||||
@@ -13,7 +14,7 @@ router = APIRouter(tags=["Investor Routes"])
|
||||
# Request schemas for creating/updating
|
||||
class InvestorCreate(BaseModel):
|
||||
name: str
|
||||
description: str = None
|
||||
description: Optional[str] = None
|
||||
aum: int
|
||||
check_size_lower: int
|
||||
check_size_upper: int
|
||||
@@ -23,14 +24,14 @@ class InvestorCreate(BaseModel):
|
||||
|
||||
|
||||
class InvestorUpdate(BaseModel):
|
||||
name: str = None
|
||||
description: str = None
|
||||
aum: int = None
|
||||
check_size_lower: int = None
|
||||
check_size_upper: int = None
|
||||
geographic_focus: str = None
|
||||
stage_focus: InvestmentStage = None
|
||||
number_of_investments: int = None
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
aum: Optional[int] = None
|
||||
check_size_lower: Optional[int] = None
|
||||
check_size_upper: Optional[int] = None
|
||||
geographic_focus: Optional[str] = None
|
||||
stage_focus: Optional[InvestmentStage] = None
|
||||
number_of_investments: Optional[int] = None
|
||||
|
||||
|
||||
@router.get("/investors", response_model=List[InvestorData])
|
||||
@@ -180,26 +181,16 @@ def create_investor(investor: InvestorCreate, db: Session = Depends(get_db)):
|
||||
)
|
||||
|
||||
|
||||
@router.put("/investors/{investor_id}", response_model=InvestorData)
|
||||
def update_investor(
|
||||
investor_id: int, investor: InvestorUpdate, db: Session = Depends(get_db)
|
||||
@router.get("/investors/{investor_id}/similar", response_model=List[InvestorData])
|
||||
def find_similar_investors(
|
||||
investor_id: int,
|
||||
limit: int = Query(10, description="Maximum number of similar investors to return"),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Update an existing investor"""
|
||||
db_investor = (
|
||||
db.query(InvestorTable).filter(InvestorTable.id == investor_id).first()
|
||||
)
|
||||
if not db_investor:
|
||||
raise HTTPException(status_code=404, detail="Investor not found")
|
||||
|
||||
update_data = investor.dict(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(db_investor, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_investor)
|
||||
|
||||
# Reload with relationships
|
||||
investor_with_relations = (
|
||||
"""Find investors similar to a given investor based on characteristics"""
|
||||
|
||||
# Get the target investor
|
||||
target_investor = (
|
||||
db.query(InvestorTable)
|
||||
.options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
@@ -210,24 +201,81 @@ def update_investor(
|
||||
.first()
|
||||
)
|
||||
|
||||
# Transform to InvestorData format
|
||||
return InvestorData(
|
||||
investor=investor_with_relations,
|
||||
portfolio_companies=investor_with_relations.portfolio_companies,
|
||||
team_members=investor_with_relations.team_members,
|
||||
sectors=investor_with_relations.sectors,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/investors/{investor_id}")
|
||||
def delete_investor(investor_id: int, db: Session = Depends(get_db)):
|
||||
"""Delete an investor"""
|
||||
db_investor = (
|
||||
db.query(InvestorTable).filter(InvestorTable.id == investor_id).first()
|
||||
)
|
||||
if not db_investor:
|
||||
if not target_investor:
|
||||
raise HTTPException(status_code=404, detail="Investor not found")
|
||||
|
||||
db.delete(db_investor)
|
||||
db.commit()
|
||||
return {"message": "Investor deleted successfully"}
|
||||
# Get target investor's sector IDs for comparison
|
||||
target_sector_ids = {sector.id for sector in target_investor.sectors}
|
||||
|
||||
# Query all other investors with their relationships
|
||||
candidates = (
|
||||
db.query(InvestorTable)
|
||||
.options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
)
|
||||
.filter(InvestorTable.id != investor_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Calculate similarity scores
|
||||
scored_investors = []
|
||||
for candidate in candidates:
|
||||
score = 0
|
||||
|
||||
# Stage focus match (30 points)
|
||||
if candidate.stage_focus == target_investor.stage_focus:
|
||||
score += 30
|
||||
|
||||
# Geographic focus match (20 points for exact, 10 for partial)
|
||||
if candidate.geographic_focus and target_investor.geographic_focus:
|
||||
if candidate.geographic_focus.lower() == target_investor.geographic_focus.lower():
|
||||
score += 20
|
||||
elif (candidate.geographic_focus.lower() in target_investor.geographic_focus.lower() or
|
||||
target_investor.geographic_focus.lower() in candidate.geographic_focus.lower()):
|
||||
score += 10
|
||||
|
||||
# Check size overlap (20 points max)
|
||||
if (candidate.check_size_lower is not None and candidate.check_size_upper is not None and
|
||||
target_investor.check_size_lower is not None and target_investor.check_size_upper is not None):
|
||||
# Calculate overlap percentage
|
||||
overlap_start = max(candidate.check_size_lower, target_investor.check_size_lower)
|
||||
overlap_end = min(candidate.check_size_upper, target_investor.check_size_upper)
|
||||
if overlap_end > overlap_start:
|
||||
overlap = overlap_end - overlap_start
|
||||
target_range = target_investor.check_size_upper - target_investor.check_size_lower
|
||||
overlap_ratio = overlap / target_range if target_range > 0 else 0
|
||||
score += int(20 * overlap_ratio)
|
||||
|
||||
# AUM similarity (15 points max)
|
||||
if candidate.aum and target_investor.aum:
|
||||
aum_diff = abs(candidate.aum - target_investor.aum)
|
||||
max_aum = max(candidate.aum, target_investor.aum)
|
||||
similarity_ratio = 1 - (aum_diff / max_aum) if max_aum > 0 else 0
|
||||
score += int(15 * similarity_ratio)
|
||||
|
||||
# Sector overlap (30 points max)
|
||||
candidate_sector_ids = {sector.id for sector in candidate.sectors}
|
||||
if target_sector_ids and candidate_sector_ids:
|
||||
common_sectors = target_sector_ids.intersection(candidate_sector_ids)
|
||||
overlap_ratio = len(common_sectors) / len(target_sector_ids)
|
||||
score += int(30 * overlap_ratio)
|
||||
|
||||
if score > 0: # Only include investors with some similarity
|
||||
scored_investors.append((score, candidate))
|
||||
|
||||
# Sort by score (descending) and take top N
|
||||
scored_investors.sort(key=lambda x: x[0], reverse=True)
|
||||
similar_investors = [inv for score, inv in scored_investors[:limit]]
|
||||
|
||||
# Transform to InvestorData format
|
||||
return [
|
||||
InvestorData(
|
||||
investor=inv,
|
||||
portfolio_companies=inv.portfolio_companies,
|
||||
team_members=inv.team_members,
|
||||
sectors=inv.sectors,
|
||||
)
|
||||
for inv in similar_investors
|
||||
]
|
||||
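Review note: the scoring inside `find_similar_investors` can be read as a pure function of two investors, which makes it easier to test without a database. The sketch below mirrors the weights from the diff (30 stage, 20/10 geography, 20 check size, 15 AUM, 30 sectors). `InvestorFacts` and `similarity_score` are hypothetical names, and two details are deliberate departures from the diff: missing values are detected with `is not None`, so a check-size bound of 0 is not skipped, and two missing stages are not counted as a stage match.

```python
from dataclasses import dataclass, field
from typing import Optional, Set


@dataclass
class InvestorFacts:
    """Hypothetical plain-data stand-in for the fields the route reads from InvestorTable."""
    stage_focus: Optional[str] = None
    geographic_focus: Optional[str] = None
    check_size_lower: Optional[int] = None
    check_size_upper: Optional[int] = None
    aum: Optional[int] = None
    sector_ids: Set[int] = field(default_factory=set)


def similarity_score(target: InvestorFacts, candidate: InvestorFacts) -> int:
    score = 0

    # Stage focus match (30 points); assumption: a missing stage never matches
    if target.stage_focus is not None and candidate.stage_focus == target.stage_focus:
        score += 30

    # Geographic focus: 20 points for a case-insensitive exact match, 10 for a substring match
    if candidate.geographic_focus and target.geographic_focus:
        cand_geo = candidate.geographic_focus.lower()
        target_geo = target.geographic_focus.lower()
        if cand_geo == target_geo:
            score += 20
        elif cand_geo in target_geo or target_geo in cand_geo:
            score += 10

    # Check size: up to 20 points, scaled by how much of the target's range is covered
    bounds = (candidate.check_size_lower, candidate.check_size_upper,
              target.check_size_lower, target.check_size_upper)
    if all(b is not None for b in bounds):
        overlap = (min(candidate.check_size_upper, target.check_size_upper)
                   - max(candidate.check_size_lower, target.check_size_lower))
        target_range = target.check_size_upper - target.check_size_lower
        if overlap > 0 and target_range > 0:
            score += int(20 * overlap / target_range)

    # AUM: up to 15 points, scaled by relative closeness
    if candidate.aum and target.aum:
        larger = max(candidate.aum, target.aum)
        if larger > 0:
            score += int(15 * (1 - abs(candidate.aum - target.aum) / larger))

    # Sectors: up to 30 points, scaled by the share of the target's sectors that are shared
    if target.sector_ids and candidate.sector_ids:
        common = target.sector_ids & candidate.sector_ids
        score += int(30 * len(common) / len(target.sector_ids))

    return score


target = InvestorFacts("seed", "Europe", 1_000_000, 5_000_000, 100_000_000, {1, 2})
partial = InvestorFacts("seed", "Western Europe", 3_000_000, 10_000_000, 50_000_000, {2, 3})
```

With these inputs `similarity_score(target, partial)` is 72 (30 + 10 + 10 + 7 + 15), and an identical investor scores 115. The weights as written therefore sum to 115 rather than 100, which matters if the score is ever shown as a percentage.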
@@ -0,0 +1,447 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from db.db import get_db
|
||||
from db.models import (
|
||||
CompanyTable,
|
||||
InvestorTable,
|
||||
ProjectTable,
|
||||
SectorTable,
|
||||
)
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from schemas.project_schemas import (
|
||||
InvestmentStage,
|
||||
ProjectCreate,
|
||||
ProjectData,
|
||||
ProjectUpdate,
|
||||
)
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
router = APIRouter(tags=["Project Routes"])
|
||||
|
||||
|
||||
@router.get("/projects", response_model=List[ProjectData])
|
||||
def read_projects(db: Session = Depends(get_db)):
|
||||
"""Get all projects with their related data"""
|
||||
projects = (
|
||||
db.query(ProjectTable)
|
||||
.options(
|
||||
selectinload(ProjectTable.sector),
|
||||
selectinload(ProjectTable.investors),
|
||||
selectinload(ProjectTable.companies),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Transform ProjectTable objects to ProjectData format
|
||||
project_data_list = []
|
||||
for project in projects:
|
||||
project_data = ProjectData(
|
||||
project=project,
|
||||
sector=project.sector,
|
||||
investors=project.investors,
|
||||
companies=project.companies,
|
||||
)
|
||||
project_data_list.append(project_data)
|
||||
|
||||
return project_data_list
|
||||
|
||||
|
||||
@router.get("/projects/{project_id}", response_model=ProjectData)
|
||||
def read_project(project_id: int, db: Session = Depends(get_db)):
|
||||
"""Get a specific project by ID"""
|
||||
project = (
|
||||
db.query(ProjectTable)
|
||||
.options(
|
||||
selectinload(ProjectTable.sector),
|
||||
selectinload(ProjectTable.investors),
|
||||
selectinload(ProjectTable.companies),
|
||||
)
|
||||
.filter(ProjectTable.id == project_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
return ProjectData(
|
||||
project=project,
|
||||
sector=project.sector,
|
||||
investors=project.investors,
|
||||
companies=project.companies,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/projects", response_model=ProjectData)
def create_project(project: ProjectCreate, db: Session = Depends(get_db)):
    """Create a new project"""
    # model_dump() is the Pydantic v2 replacement for the deprecated .dict()
    db_project = ProjectTable(**project.model_dump())
    db.add(db_project)
    db.commit()
    db.refresh(db_project)

    # Reload with relationships
    db_project = (
        db.query(ProjectTable)
        .options(
            selectinload(ProjectTable.sector),
            selectinload(ProjectTable.investors),
            selectinload(ProjectTable.companies),
        )
        .filter(ProjectTable.id == db_project.id)
        .first()
    )

    return ProjectData(
        project=db_project,
        sector=db_project.sector,
        investors=db_project.investors,
        companies=db_project.companies,
    )
|
||||
|
||||
@router.put("/projects/{project_id}", response_model=ProjectData)
def update_project(
    project_id: int, project: ProjectUpdate, db: Session = Depends(get_db)
):
    """Update an existing project"""
    db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()

    if not db_project:
        raise HTTPException(status_code=404, detail="Project not found")

    # Update only provided fields (model_dump() replaces the deprecated .dict() in Pydantic v2)
    update_data = project.model_dump(exclude_unset=True)
    for key, value in update_data.items():
        setattr(db_project, key, value)

    db.commit()
    db.refresh(db_project)

    # Reload with relationships
    db_project = (
        db.query(ProjectTable)
        .options(
            selectinload(ProjectTable.sector),
            selectinload(ProjectTable.investors),
            selectinload(ProjectTable.companies),
        )
        .filter(ProjectTable.id == project_id)
        .first()
    )

    return ProjectData(
        project=db_project,
        sector=db_project.sector,
        investors=db_project.investors,
        companies=db_project.companies,
    )
|
||||
|
||||
@router.delete("/projects/{project_id}")
|
||||
def delete_project(project_id: int, db: Session = Depends(get_db)):
|
||||
"""Delete a project"""
|
||||
db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
|
||||
|
||||
if not db_project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
db.delete(db_project)
|
||||
db.commit()
|
||||
|
||||
return {"message": "Project deleted successfully"}
|
||||
|
||||
|
||||
# NOTE: FastAPI matches routes in registration order. GET /projects/{project_id} is
# registered earlier in this module, so a request to /projects/filter is routed there
# and rejected with a 422 because "filter" is not an int. This route needs to be
# declared before the parameterized one to be reachable.
@router.get("/projects/filter", response_model=List[ProjectData])
def filter_projects(
    stage: Optional[InvestmentStage] = Query(
        None, description="Filter by project stage"
    ),
    min_valuation: Optional[int] = Query(None, description="Minimum valuation"),
    max_valuation: Optional[int] = Query(None, description="Maximum valuation"),
    location: Optional[str] = Query(None, description="Location (partial match)"),
    sector: Optional[str] = Query(None, description="Sector name (partial match)"),
    investor_name: Optional[str] = Query(
        None, description="Investor name (partial match)"
    ),
    company_name: Optional[str] = Query(
        None, description="Company name (partial match)"
    ),
    db: Session = Depends(get_db),
):
    """Filter projects based on various criteria"""

    # Start with base query
    query = db.query(ProjectTable).options(
        selectinload(ProjectTable.sector),
        selectinload(ProjectTable.investors),
        selectinload(ProjectTable.companies),
    )

    # Apply filters
    if stage:
        query = query.filter(ProjectTable.stage == stage)

    if min_valuation is not None:
        query = query.filter(ProjectTable.valuation >= min_valuation)

    if max_valuation is not None:
        query = query.filter(ProjectTable.valuation <= max_valuation)

    if location:
        query = query.filter(ProjectTable.location.ilike(f"%{location}%"))

    if sector:
        query = query.join(ProjectTable.sector).filter(
            SectorTable.name.ilike(f"%{sector}%")
        )

    if investor_name:
        query = query.join(ProjectTable.investors).filter(
            InvestorTable.name.ilike(f"%{investor_name}%")
        )

    if company_name:
        query = query.join(ProjectTable.companies).filter(
            CompanyTable.name.ilike(f"%{company_name}%")
        )

    projects = query.all()

    # Transform to ProjectData format
    project_data_list = []
    for project in projects:
        project_data = ProjectData(
            project=project,
            sector=project.sector,
            investors=project.investors,
            companies=project.companies,
        )
        project_data_list.append(project_data)

    return project_data_list
|
||||
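Review note: in this module `GET /projects/{project_id}` is registered before `GET /projects/filter`. FastAPI tries routes in the order they were added, so the literal path is shadowed by the parameterized one. The sketch below is a toy first-match resolver written for illustration only; it is not Starlette's implementation, and `compile_path` and `resolve` are hypothetical helpers. It shows why swapping the registration order changes which handler is chosen.

```python
import re
from typing import List, Optional, Tuple


def compile_path(path: str):
    """Turn "/projects/{project_id}" into a regex where each {param} matches one path segment."""
    pattern = re.sub(r"\{(\w+)\}", r"(?P<\1>[^/]+)", path)
    return re.compile(f"^{pattern}$")


def resolve(routes: List[Tuple[str, str]], request_path: str) -> Optional[str]:
    """Return the handler name of the first route whose pattern matches, in registration order."""
    for path, handler in routes:
        if compile_path(path).match(request_path):
            return handler
    return None


# Registration order as it appears in the diff
as_in_diff = [
    ("/projects/{project_id}", "read_project"),
    ("/projects/filter", "filter_projects"),
]

# Literal path registered first
reordered = [
    ("/projects/filter", "filter_projects"),
    ("/projects/{project_id}", "read_project"),
]
```

In this toy model `resolve(as_in_diff, "/projects/filter")` picks `read_project`, while `resolve(reordered, "/projects/filter")` picks `filter_projects` and numeric IDs still reach `read_project`.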
|
||||
|
||||
# Association management routes
|
||||
@router.post("/projects/{project_id}/investors/{investor_id}")
|
||||
def add_investor_to_project(
|
||||
project_id: int, investor_id: int, db: Session = Depends(get_db)
|
||||
):
|
||||
"""Add an investor to a project"""
|
||||
# Check if project exists
|
||||
project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
# Check if investor exists
|
||||
investor = db.query(InvestorTable).filter(InvestorTable.id == investor_id).first()
|
||||
if not investor:
|
||||
raise HTTPException(status_code=404, detail="Investor not found")
|
||||
|
||||
# Check if association already exists
|
||||
if investor in project.investors:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Investor already associated with project"
|
||||
)
|
||||
|
||||
# Add association
|
||||
project.investors.append(investor)
|
||||
db.commit()
|
||||
|
||||
return {"message": "Investor added to project successfully"}
|
||||
|
||||
|
||||
@router.delete("/projects/{project_id}/investors/{investor_id}")
|
||||
def remove_investor_from_project(
|
||||
project_id: int, investor_id: int, db: Session = Depends(get_db)
|
||||
):
|
||||
"""Remove an investor from a project"""
|
||||
# Check if project exists
|
||||
project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
# Check if investor exists
|
||||
investor = db.query(InvestorTable).filter(InvestorTable.id == investor_id).first()
|
||||
if not investor:
|
||||
raise HTTPException(status_code=404, detail="Investor not found")
|
||||
|
||||
# Check if association exists
|
||||
if investor not in project.investors:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Investor not associated with project"
|
||||
)
|
||||
|
||||
# Remove association
|
||||
project.investors.remove(investor)
|
||||
db.commit()
|
||||
|
||||
return {"message": "Investor removed from project successfully"}
|
||||
|
||||
|
||||
@router.post("/projects/{project_id}/companies/{company_id}")
|
||||
def add_company_to_project(
|
||||
project_id: int, company_id: int, db: Session = Depends(get_db)
|
||||
):
|
||||
"""Add a company to a project"""
|
||||
# Check if project exists
|
||||
project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
# Check if company exists
|
||||
company = db.query(CompanyTable).filter(CompanyTable.id == company_id).first()
|
||||
if not company:
|
||||
raise HTTPException(status_code=404, detail="Company not found")
|
||||
|
||||
# Check if association already exists
|
||||
if company in project.companies:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Company already associated with project"
|
||||
)
|
||||
|
||||
# Add association
|
||||
project.companies.append(company)
|
||||
db.commit()
|
||||
|
||||
return {"message": "Company added to project successfully"}
|
||||
|
||||
|
||||
@router.delete("/projects/{project_id}/companies/{company_id}")
|
||||
def remove_company_from_project(
|
||||
project_id: int, company_id: int, db: Session = Depends(get_db)
|
||||
):
|
||||
"""Remove a company from a project"""
|
||||
# Check if project exists
|
||||
project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
# Check if company exists
|
||||
company = db.query(CompanyTable).filter(CompanyTable.id == company_id).first()
|
||||
if not company:
|
||||
raise HTTPException(status_code=404, detail="Company not found")
|
||||
|
||||
# Check if association exists
|
||||
if company not in project.companies:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Company not associated with project"
|
||||
)
|
||||
|
||||
# Remove association
|
||||
project.companies.remove(company)
|
||||
db.commit()
|
||||
|
||||
return {"message": "Company removed from project successfully"}
|
||||
|
||||
|
||||
@router.post("/projects/{project_id}/sectors/{sector_id}")
|
||||
def add_sector_to_project(
|
||||
project_id: int, sector_id: int, db: Session = Depends(get_db)
|
||||
):
|
||||
"""Add a sector to a project"""
|
||||
# Check if project exists
|
||||
project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
# Check if sector exists
|
||||
sector = db.query(SectorTable).filter(SectorTable.id == sector_id).first()
|
||||
if not sector:
|
||||
raise HTTPException(status_code=404, detail="Sector not found")
|
||||
|
||||
# Check if association already exists
|
||||
if sector in project.sector:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Sector already associated with project"
|
||||
)
|
||||
|
||||
# Add association
|
||||
project.sector.append(sector)
|
||||
db.commit()
|
||||
|
||||
return {"message": "Sector added to project successfully"}
|
||||
|
||||
|
||||
@router.delete("/projects/{project_id}/sectors/{sector_id}")
def remove_sector_from_project(
    project_id: int, sector_id: int, db: Session = Depends(get_db)
):
    """Remove a sector from a project"""
    # Check if project exists
    project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    # Check if sector exists
    sector = db.query(SectorTable).filter(SectorTable.id == sector_id).first()
    if not sector:
        raise HTTPException(status_code=404, detail="Sector not found")

    # Check if association exists
    if sector not in project.sector:
        raise HTTPException(
            status_code=400, detail="Sector not associated with project"
        )

    # Remove association
    project.sector.remove(sector)
    db.commit()

    return {"message": "Sector removed from project successfully"}


# Bulk association management
@router.post("/projects/{project_id}/investors")
def add_multiple_investors_to_project(
    project_id: int, investor_ids: List[int], db: Session = Depends(get_db)
):
    """Add multiple investors to a project"""
    # Check if project exists
    project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    # Get all investors
    investors = db.query(InvestorTable).filter(InvestorTable.id.in_(investor_ids)).all()

    # Compare against the distinct IDs: in_() returns each row once, so a
    # repeated ID in the request would otherwise cause a spurious 404
    if len(investors) != len(set(investor_ids)):
        raise HTTPException(status_code=404, detail="One or more investors not found")

    # Add associations (only if not already associated)
    added_count = 0
    for investor in investors:
        if investor not in project.investors:
            project.investors.append(investor)
            added_count += 1

    db.commit()

    return {"message": f"Added {added_count} investors to project successfully"}


@router.post("/projects/{project_id}/companies")
def add_multiple_companies_to_project(
    project_id: int, company_ids: List[int], db: Session = Depends(get_db)
):
    """Add multiple companies to a project"""
    # Check if project exists
    project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    # Get all companies
    companies = db.query(CompanyTable).filter(CompanyTable.id.in_(company_ids)).all()

    # Compare against the distinct IDs: in_() returns each row once, so a
    # repeated ID in the request would otherwise cause a spurious 404
    if len(companies) != len(set(company_ids)):
        raise HTTPException(status_code=404, detail="One or more companies not found")

    # Add associations (only if not already associated)
    added_count = 0
    for company in companies:
        if company not in project.companies:
            project.companies.append(company)
            added_count += 1

    db.commit()

    return {"message": f"Added {added_count} companies to project successfully"}
|
||||
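The bulk endpoints above decide whether every requested record exists by comparing counts. A minimal sketch, assuming plain integer IDs, of how the missing IDs could instead be computed explicitly so that repeated IDs in a request are tolerated and the error can name what is missing. `missing_ids` is a hypothetical helper and is not part of this diff.

```python
from typing import Iterable, List


def missing_ids(requested: Iterable[int], found: Iterable[int]) -> List[int]:
    """Return requested IDs absent from `found`, ignoring duplicate requests."""
    found_set = set(found)
    # dict.fromkeys() drops duplicates while keeping the request order.
    return [i for i in dict.fromkeys(requested) if i not in found_set]


# ID 2 is requested twice but exists once; only 3 is actually missing.
print(missing_ids([1, 2, 2, 3], found=[1, 2]))  # prints [3]
```

In an endpoint, `found` would be the IDs of the rows returned by the `in_()` query, and a non-empty result could be reported in the 404 detail.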
Binary file not shown.
@@ -0,0 +1,117 @@
from datetime import datetime
from enum import Enum
from typing import List, Optional

from pydantic import BaseModel


class InvestmentStage(str, Enum):
    SEED = "SEED"
    SERIES_A = "SERIES_A"
    SERIES_B = "SERIES_B"
    SERIES_C = "SERIES_C"
    GROWTH = "GROWTH"
    LATE_STAGE = "LATE_STAGE"


class SectorSchema(BaseModel):
    id: int
    name: str

    class Config:
        from_attributes = True


class InvestorSchema(BaseModel):
    id: int
    name: str
    description: Optional[str]
    aum: int | None
    check_size_lower: int | None
    check_size_upper: int | None
    geographic_focus: str | None
    stage_focus: InvestmentStage
    number_of_investments: int | None
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    class Config:
        from_attributes = True


class CompanySchema(BaseModel):
    id: int
    name: str
    industry: str | None
    location: str | None
    description: Optional[str]
    founded_year: Optional[int]
    website: Optional[str]
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    class Config:
        from_attributes = True


class ProjectSchema(BaseModel):
    id: int
    name: str
    valuation: int | None
    stage: InvestmentStage | None
    location: str | None
    description: Optional[str]
    start_date: Optional[datetime]
    end_date: Optional[datetime]
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    class Config:
        from_attributes = True


class ProjectCreate(BaseModel):
    name: str
    valuation: Optional[int] = None
    stage: Optional[InvestmentStage] = None
    location: Optional[str] = None
    description: Optional[str] = None
    start_date: Optional[datetime] = None
    end_date: Optional[datetime] = None


class ProjectUpdate(BaseModel):
    name: Optional[str] = None
    valuation: Optional[int] = None
    stage: Optional[InvestmentStage] = None
    location: Optional[str] = None
    description: Optional[str] = None
    start_date: Optional[datetime] = None
    end_date: Optional[datetime] = None


class ProjectData(BaseModel):
    """Comprehensive project data schema"""

    project: ProjectSchema
    sector: List[SectorSchema]
    investors: List[InvestorSchema]
    companies: List[CompanySchema]

    class Config:
        from_attributes = True


class ProjectInvestorAssociation(BaseModel):
    project_id: int
    investor_id: int


class ProjectCompanyAssociation(BaseModel):
    project_id: int
    company_id: int


class ProjectSectorAssociation(BaseModel):
    project_id: int
    sector_id: int
|
||||
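Every field in `ProjectUpdate` defaults to `None`, which is the usual shape for a partial-update payload. A minimal sketch, assuming Pydantic v2, of one common way such a model is applied so that omitted fields are left alone while an explicit `null` still clears a value. `ProjectUpdateSketch` and `apply_update` are hypothetical names for illustration; the diff does not show how the real update route consumes the schema.

```python
from typing import Optional

from pydantic import BaseModel


class ProjectUpdateSketch(BaseModel):
    # Trimmed copy of ProjectUpdate from the diff, for illustration only.
    name: Optional[str] = None
    valuation: Optional[int] = None
    location: Optional[str] = None


def apply_update(current: dict, patch: ProjectUpdateSketch) -> dict:
    """Overlay only the fields the client actually sent."""
    # exclude_unset=True keeps explicitly passed fields (even None) and
    # drops fields that merely fell back to their default.
    changes = patch.model_dump(exclude_unset=True)
    return {**current, **changes}


stored = {"name": "Alpha", "valuation": 100, "location": "Berlin"}
patch = ProjectUpdateSketch(valuation=250, location=None)
print(apply_update(stored, patch))
# prints {'name': 'Alpha', 'valuation': 250, 'location': None}
```

Here `name` is untouched because it was never sent, whereas `location` is cleared because the client passed `None` explicitly.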
@@ -0,0 +1,356 @@
from enum import Enum
from typing import List, Optional

from pydantic import BaseModel, Field, field_validator


class InvestmentStage(str, Enum):
    SEED = "SEED"
    SERIES_A = "SERIES_A"
    SERIES_B = "SERIES_B"
    SERIES_C = "SERIES_C"
    GROWTH = "GROWTH"
    LATE_STAGE = "LATE_STAGE"


class SectorSchema(BaseModel):
    """
    Expert parser: Only extract sector information if clearly identifiable.
    Leave name empty if uncertain about the sector classification.
    """

    id: Optional[int] = Field(
        default=None,
        ge=0,
        description="Sector ID, must be 0 or greater. Use 0 if uncertain.",
    )
    name: Optional[str] = Field(
        default=None,
        description="Sector name. Leave empty string if not clearly identifiable from the data.",
    )

    @field_validator("name", mode="before")
    @classmethod
    def empty_string_to_none(cls, v):
        """Convert empty strings to None"""
        if v == "" or (isinstance(v, str) and v.strip() == ""):
            return None
        return v

    @field_validator("id", mode="before")
    @classmethod
    def zero_to_none(cls, v):
        """Convert 0 to None for optional id field"""
        if v == 0:
            return None
        return v

    class Config:
        from_attributes = True


class InvestorMemberSchema(BaseModel):
    """
    Expert parser: Only extract team member information if clearly identifiable.
    Leave fields empty if uncertain about the member details.
    """

    id: Optional[int] = Field(
        default=None,
        ge=0,
        description="Member ID, must be 0 or greater. Use 0 if uncertain.",
    )
    name: Optional[str] = Field(
        default=None,
        description="Team member name. Leave empty string if not clearly identifiable.",
    )
    role: Optional[str] = Field(
        default=None,
        description="Team member role/title. Leave empty string if not clearly identifiable.",
    )
    email: Optional[str] = Field(
        default=None,
        description="Team member email. Leave empty string if not clearly identifiable or not provided.",
    )
    investor_id: Optional[int] = Field(
        default=None,
        ge=0,
        description="Investor ID, must be 0 or greater. Use 0 if uncertain.",
    )

    @field_validator("name", "role", "email", mode="before")
    @classmethod
    def empty_string_to_none(cls, v):
        """Convert empty strings to None"""
        if v == "" or (isinstance(v, str) and v.strip() == ""):
            return None
        return v

    @field_validator("id", "investor_id", mode="before")
    @classmethod
    def zero_to_none(cls, v):
        """Convert 0 to None for optional integer fields"""
        if v == 0:
            return None
        return v

    class Config:
        from_attributes = True


class CompanyMemberSchema(BaseModel):
    """
    Expert parser: Only extract company member information if clearly identifiable.
    Leave fields empty if uncertain about the member details.
    """

    id: Optional[int] = Field(
        default=None,
        ge=0,
        description="Member ID, must be 0 or greater. Use 0 if uncertain.",
    )
    name: Optional[str] = Field(
        default=None,
        description="Company member name. Leave empty if not clearly identifiable.",
    )
    linkedin: Optional[str] = Field(
        default=None,
        description="LinkedIn profile URL. Leave empty if not provided or uncertain.",
    )
    role: Optional[str] = Field(
        default=None,
        description="Company member role/title. Leave empty if not clearly identifiable.",
    )
    company_id: Optional[int] = Field(
        default=None,
        ge=0,
        description="Company ID, must be 0 or greater. Use 0 if uncertain.",
    )

    @field_validator("name", "linkedin", "role", mode="before")
    @classmethod
    def empty_string_to_none(cls, v):
        """Convert empty strings to None"""
        if v == "" or (isinstance(v, str) and v.strip() == ""):
            return None
        return v

    @field_validator("id", "company_id", mode="before")
    @classmethod
    def zero_to_none(cls, v):
        """Convert 0 to None for optional integer fields"""
        if v == 0:
            return None
        return v

    class Config:
        from_attributes = True


class CompanySchema(BaseModel):
    """
    Expert parser: Only extract company information if clearly identifiable.
    Leave optional fields empty if uncertain. Integer values must be 0 or greater.
    """

    id: Optional[int] = Field(
        default=None,
        ge=0,
        description="Company ID, must be 0 or greater. Use 0 if uncertain.",
    )
    name: Optional[str] = Field(
        default=None,
        description="Company name. Leave empty string if not clearly identifiable.",
    )
    industry: Optional[str] = Field(
        default=None,
        description="Company industry/sector. Leave empty string if not clearly identifiable.",
    )
    location: Optional[str] = Field(
        default=None,
        description="Company location/address. Leave empty string if not clearly identifiable.",
    )
    description: Optional[str] = Field(
        default=None,
        description="Company description. Leave empty if not clearly available or uncertain.",
    )
    founded_year: Optional[int] = Field(
        default=None,
        ge=0,
        description="Year company was founded, must be 0 or greater. Leave None if not clearly identifiable or uncertain.",
    )
    website: Optional[str] = Field(
        default=None,
        description="Company website URL. Leave empty if not provided or uncertain.",
    )

    @field_validator(
        "name", "industry", "location", "description", "website", mode="before"
    )
    @classmethod
    def empty_string_to_none(cls, v):
        """Convert empty strings to None"""
        if v == "" or (isinstance(v, str) and v.strip() == ""):
            return None
        return v

    @field_validator("id", "founded_year", mode="before")
    @classmethod
    def zero_to_none(cls, v):
        """Convert 0 to None for id and founded_year"""
        if v == 0:
            return None
        return v

    @field_validator("founded_year", mode="before")
    @classmethod
    def validate_founded_year(cls, v):
        """Expert parser: Only accept clearly identifiable founding years"""
        if v is None or v == "Not Available" or v == "" or v == "Unknown":
            return None
        if isinstance(v, str):
            try:
                year = int(v)
                return year if year >= 0 else None
            except ValueError:
                return None
        return v if isinstance(v, int) and v >= 0 else None

    class Config:
        from_attributes = True


class InvestorSchema(BaseModel):
    """
    Expert parser: Only extract investor information if clearly identifiable.
    Leave optional fields empty if uncertain. All numeric values must be 0 or greater.
    """

    id: Optional[int] = Field(
        default=None,
        ge=0,
        description="Investor ID, must be 0 or greater. Use 0 if uncertain.",
    )
    name: Optional[str] = Field(
        default=None,
        description="Investor name. Do not return any special characters; return just the name as a string.",
    )
    description: Optional[str] = Field(
        default=None,
        description="Investor description. Leave empty if not clearly available or uncertain.",
    )
    aum: Optional[int] = Field(
        default=None,
        ge=0,
        description="Assets Under Management in USD, must be 0 or greater. Use 0 if not clearly identifiable or uncertain.",
    )
    check_size_lower: Optional[int] = Field(
        default=None,
        ge=0,
        description="Lower bound of typical investment check size in USD, must be 0 or greater. Use 0 if not clearly identifiable.",
    )
    check_size_upper: Optional[int] = Field(
        default=None,
        ge=0,
        description="Upper bound of typical investment check size in USD, must be 0 or greater. Use 0 if not clearly identifiable.",
    )
    geographic_focus: Optional[str] = Field(
        default=None,
        description="Geographic investment focus. Do not return any special characters; return just locations separated by commas. Leave empty if not clearly identifiable.",
    )
    stage_focus: InvestmentStage = Field(
        default=InvestmentStage.SEED,
        description="Investment stage focus. Use SEED as default if uncertain.",
    )
    number_of_investments: Optional[int] = Field(
        default=None,
        ge=0,
        description="Total number of investments made, must be 0 or greater. Use 0 if not clearly identifiable.",
    )

    @field_validator("name", "description", "geographic_focus", mode="before")
    @classmethod
    def empty_string_to_none(cls, v):
        """Convert empty strings to None"""
        if v == "" or (isinstance(v, str) and v.strip() == ""):
            return None
        return v

    @field_validator(
        "id",
        "aum",
        "check_size_lower",
        "check_size_upper",
        "number_of_investments",
        mode="before",
    )
    @classmethod
    def zero_to_none(cls, v):
        """Convert 0 to None for optional integer fields"""
        if v == 0:
            return None
        return v

    class Config:
        from_attributes = True


class InvestorData(BaseModel):
    """
    Expert parser: Comprehensive investor data schema for LLM processing.
    Only populate fields with clearly identifiable information. Leave lists empty if uncertain.
    """

    investor: InvestorSchema = Field(
        description="Core investor information. Only populate with clearly identifiable data."
    )
    portfolio_companies: List[CompanySchema] = Field(
        default=[],
        description="List of portfolio companies. Leave empty if not clearly identifiable.",
    )
    team_members: List[InvestorMemberSchema] = Field(
        default=[],
        description="List of team members. Leave empty if not clearly identifiable.",
    )
    sectors: List[SectorSchema] = Field(
        default=[],
        description="List of investment sectors. Leave empty if not clearly identifiable.",
    )

    class Config:
        from_attributes = True


class CompanyData(BaseModel):
    """
    Expert parser: Comprehensive company data schema for LLM processing.
    Only populate fields with clearly identifiable information. Leave lists empty if uncertain.
    """

    company: CompanySchema = Field(
        description="Core company information. Only populate with clearly identifiable data."
    )
    sectors: List[SectorSchema] = Field(
        default=[],
        description="List of company sectors. Leave empty if not clearly identifiable.",
    )
    members: List[CompanyMemberSchema] = Field(
        default=[],
        description="List of company members. Leave empty if not clearly identifiable.",
    )
    investors: List[InvestorSchema] = Field(
        default=[],
        description="List of investors. Leave empty if not clearly identifiable.",
    )

    class Config:
        from_attributes = True


class InvestorList(BaseModel):
    """Expert parser: List of investors with clearly identifiable information only."""

    investors: List[InvestorData] = Field(
        default=[],
        description="List of investors. Leave empty if no clearly identifiable investors.",
    )
|
||||
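The LLM-facing schemas above tell the model to emit `0` or an empty string when it is unsure, then rely on `mode="before"` validators to turn those placeholders into `None`. A minimal sketch, assuming Pydantic v2, of how that normalisation behaves on a trimmed stand-in for `SectorSchema`. `SectorSketch` is a hypothetical name used only here.

```python
from typing import Optional

from pydantic import BaseModel, Field, field_validator


class SectorSketch(BaseModel):
    # Trimmed stand-in for SectorSchema; same validator pattern as the diff.
    id: Optional[int] = Field(default=None, ge=0)
    name: Optional[str] = Field(default=None)

    @field_validator("name", mode="before")
    @classmethod
    def empty_string_to_none(cls, v):
        # Blank or whitespace-only strings are treated as "not provided".
        if isinstance(v, str) and not v.strip():
            return None
        return v

    @field_validator("id", mode="before")
    @classmethod
    def zero_to_none(cls, v):
        # 0 is the "uncertain" placeholder the field description asks for.
        return None if v == 0 else v


placeholder = SectorSketch.model_validate({"id": 0, "name": "   "})
real = SectorSketch.model_validate({"id": 7, "name": "Fintech"})
print(placeholder.id, placeholder.name)  # prints None None
print(real.id, real.name)  # prints 7 Fintech
```

Because the validators run before type validation, the `ge=0` constraint is only ever checked against real integers, never against the placeholder.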
@@ -22,25 +22,37 @@ class SectorSchema(BaseModel):
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class CompanySchema(BaseModel):
|
||||
class InvestorMemberSchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
industry: str
|
||||
location: str
|
||||
founded_year: Optional[int]
|
||||
website: Optional[str]
|
||||
created_at: Optional[datetime]
|
||||
updated_at: Optional[datetime]
|
||||
role: str | None
|
||||
email: str | None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestorTeamMemberSchema(BaseModel):
|
||||
class CompanyMemberSchema(BaseModel):
|
||||
id: int
|
||||
name: Optional[str]
|
||||
linkedin: Optional[str]
|
||||
role: Optional[str]
|
||||
company_id: int
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class CompanySchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
role: str
|
||||
email: str
|
||||
industry: str | None
|
||||
location: str | None
|
||||
description: Optional[str]
|
||||
founded_year: Optional[int]
|
||||
website: Optional[str]
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
@@ -50,14 +62,14 @@ class InvestorSchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: Optional[str]
|
||||
aum: int
|
||||
check_size_lower: int
|
||||
check_size_upper: int
|
||||
geographic_focus: str
|
||||
aum: int | None
|
||||
check_size_lower: int | None
|
||||
check_size_upper: int | None
|
||||
geographic_focus: str | None
|
||||
stage_focus: InvestmentStage
|
||||
number_of_investments: int
|
||||
created_at: Optional[datetime]
|
||||
updated_at: Optional[datetime]
|
||||
number_of_investments: int | None
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
@@ -67,9 +79,19 @@ class InvestorData(BaseModel):
|
||||
"""Comprehensive investor data schema for LLM processing"""
|
||||
|
||||
investor: InvestorSchema
|
||||
portfolio_companies: List[CompanySchema] = []
|
||||
team_members: List[InvestorTeamMemberSchema] = []
|
||||
sectors: List[SectorSchema] = []
|
||||
portfolio_companies: List[CompanySchema]
|
||||
team_members: List[InvestorMemberSchema]
|
||||
sectors: List[SectorSchema]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class CompanyData(BaseModel): # Renamed from CompaniesData for consistency
|
||||
company: CompanySchema
|
||||
sectors: List[SectorSchema]
|
||||
members: List[CompanyMemberSchema]
|
||||
investors: List[InvestorSchema]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
Binary file not shown.
+298
-329
@@ -1,368 +1,337 @@
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
import chromadb
|
||||
import pandas as pd
|
||||
from dotenv import load_dotenv
|
||||
from openai import OpenAI
|
||||
|
||||
from db import get_session, init_database
|
||||
from py_schemas import CSVRow, Investor
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
from db.db import get_db_session
|
||||
from db.models import (
|
||||
CompanyMember,
|
||||
CompanyTable,
|
||||
InvestorMember,
|
||||
InvestorTable,
|
||||
SectorTable,
|
||||
)
|
||||
from langchain_openai import ChatOpenAI
|
||||
from schemas.py_schemas import CompanyData, InvestorData
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class LLMInvestorParser:
|
||||
class InvestorProcessor:
|
||||
def __init__(self):
|
||||
# Initialize OpenAI client
|
||||
self.openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
# Initialize ChromaDB
|
||||
self.chroma_client = chromadb.PersistentClient(path="./chroma_db")
|
||||
self.collection = self.chroma_client.get_or_create_collection(
|
||||
name="investor_descriptions",
|
||||
metadata={
|
||||
"description": "Investor descriptions and investment thesis focus"
|
||||
},
|
||||
self.llm = ChatOpenAI(
|
||||
api_key=os.getenv("OPENROUTER_API_KEY"),
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="openai/gpt-4o-mini",
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
# Initialize database
|
||||
init_database()
|
||||
self.investor_structured_llm = self.llm.with_structured_output(InvestorData)
|
||||
self.company_structured_llm = self.llm.with_structured_output(CompanyData)
|
||||
|
||||
def parse_json_field(self, json_str: str) -> Dict[str, Any]:
|
||||
"""Safely parse JSON string with LLM assistance if needed"""
|
||||
if not json_str or json_str.strip() == "":
|
||||
return {}
|
||||
def _get_or_create_sector(self, db: Session, sector_name: str) -> SectorTable:
|
||||
"""Get existing sector or create new one"""
|
||||
sector = db.query(SectorTable).filter(SectorTable.name == sector_name).first()
|
||||
if not sector:
|
||||
sector = SectorTable(name=sector_name)
|
||||
db.add(sector)
|
||||
db.flush() # Get the ID without committing
|
||||
return sector
|
||||
|
||||
try:
|
||||
# Try direct JSON parsing first
|
||||
return json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
# If direct parsing fails, use LLM to clean and parse
|
||||
logger.info("Direct JSON parsing failed, using LLM to clean JSON")
|
||||
return self._llm_clean_json(json_str)
|
||||
def _save_investor_to_db(
|
||||
self, db: Session, investor_data: InvestorData
|
||||
) -> InvestorTable:
|
||||
"""Save investor data to database"""
|
||||
# Create investor record
|
||||
investor = InvestorTable(
|
||||
name=investor_data.investor.name,
|
||||
description=investor_data.investor.description,
|
||||
aum=investor_data.investor.aum,
|
||||
check_size_lower=investor_data.investor.check_size_lower,
|
||||
check_size_upper=investor_data.investor.check_size_upper,
|
||||
geographic_focus=investor_data.investor.geographic_focus,
|
||||
stage_focus=investor_data.investor.stage_focus,
|
||||
number_of_investments=investor_data.investor.number_of_investments,
|
||||
)
|
||||
db.add(investor)
|
||||
db.flush() # Get the ID
|
||||
|
||||
def _llm_clean_json(self, malformed_json: str) -> Dict[str, Any]:
|
||||
"""Use LLM to clean and parse malformed JSON"""
|
||||
try:
|
||||
prompt = f"""
|
||||
The following text appears to be malformed JSON. Please clean it up and return valid JSON.
|
||||
If it's not possible to create valid JSON, return an empty object {{}}.
|
||||
|
||||
Original text:
|
||||
{malformed_json[:2000]} # Limit length for API
|
||||
|
||||
Return only the cleaned JSON, no explanations:
|
||||
"""
|
||||
|
||||
response = self.openai_client.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0,
|
||||
# Add team members
|
||||
for member_data in investor_data.team_members:
|
||||
member = InvestorMember(
|
||||
name=member_data.name,
|
||||
role=member_data.role,
|
||||
email=member_data.email,
|
||||
investor_id=investor.id,
|
||||
)
|
||||
db.add(member)
|
||||
|
||||
cleaned_json = response.choices[0].message.content.strip()
|
||||
return json.loads(cleaned_json)
|
||||
        # Add sectors (skip entries whose name was left empty)
        for sector_data in investor_data.sectors:
            if sector_data.name:
                sector = self._get_or_create_sector(db, sector_data.name)
                investor.sectors.append(sector)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM JSON cleaning failed: {e}")
|
||||
return {}
|
||||
|
||||
def extract_structured_data(self, csv_row: CSVRow) -> Dict[str, Any]:
|
||||
"""Extract and structure data from CSV row using LLM"""
|
||||
# Parse the investment firm profile
|
||||
profile_data = {}
|
||||
if csv_row.investment_firm_profile:
|
||||
profile_data = self.parse_json_field(csv_row.investment_firm_profile)
|
||||
|
||||
# Create structured output
|
||||
structured_data = {
|
||||
"name": csv_row.name,
|
||||
"website": csv_row.website or profile_data.get("websiteURL"),
|
||||
"investor_description": profile_data.get("investorDescription", ""),
|
||||
"investment_thesis_focus": profile_data.get("investmentThesisFocus", []),
|
||||
"headquarters": profile_data.get("headquarters", ""),
|
||||
"aum_info": profile_data.get("overallAssetsUnderManagement", {}),
|
||||
"funds_info": profile_data.get("funds", []),
|
||||
"crunchbase_urls": csv_row.crunchbase_linkedin_urls or "",
|
||||
"crunchbase_extract": csv_row.crunchbase_firm_extract or "",
|
||||
"linkedin_profile": csv_row.linkedin_investment_profile or "",
|
||||
"source_truth_profile": csv_row.source_of_truth_profile or "",
|
||||
}
|
||||
|
||||
return structured_data
|
||||
|
||||
def enhance_with_llm(self, investor_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Use LLM to enhance and standardize investor data"""
|
||||
try:
|
||||
# Combine all available text for context
|
||||
context_text = " ".join(
|
||||
[
|
||||
investor_data.get("investor_description", ""),
|
||||
investor_data.get("crunchbase_extract", ""),
|
||||
investor_data.get("linkedin_profile", ""),
|
||||
investor_data.get("source_truth_profile", ""),
|
||||
]
|
||||
# Add portfolio companies
|
||||
for company_schema in investor_data.portfolio_companies:
|
||||
# Convert CompanySchema to CompanyData format
|
||||
company_data = CompanyData(
|
||||
company=company_schema,
|
||||
sectors=[], # Will be empty for portfolio companies
|
||||
members=[], # Will be empty for portfolio companies
|
||||
investors=[], # Will be empty for portfolio companies
|
||||
)
|
||||
company = self._save_company_to_db(db, company_data, skip_investors=True)
|
||||
investor.portfolio_companies.append(company)
|
||||
|
||||
if not context_text.strip():
|
||||
return investor_data
|
||||
return investor
|
||||
|
||||
prompt = f"""
|
||||
Based on the following information about an investor, please extract and standardize:
|
||||
1. A concise investor description (2-3 sentences)
|
||||
2. Investment thesis focus areas (list of specific focus areas)
|
||||
3. Headquarters location (city, country format)
|
||||
|
||||
Investor: {investor_data["name"]}
|
||||
Context: {context_text[:3000]} # Limit for API
|
||||
|
||||
Return in JSON format:
|
||||
{{
|
||||
"enhanced_description": "concise description here",
|
||||
"standardized_focus": ["focus area 1", "focus area 2", ...],
|
||||
"standardized_headquarters": "City, Country"
|
||||
}}
|
||||
"""
|
||||
def _save_company_to_db(
|
||||
self, db: Session, company_data: CompanyData, skip_investors: bool = False
|
||||
) -> CompanyTable:
|
||||
"""Save company data to database"""
|
||||
# Check if company already exists
|
||||
existing_company = (
|
||||
db.query(CompanyTable)
|
||||
.filter(CompanyTable.name == company_data.company.name)
|
||||
.first()
|
||||
)
|
||||
if existing_company:
|
||||
return existing_company
|
||||
|
||||
response = self.openai_client.chat.completions.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.3,
|
||||
)
|
||||
# Create company record
|
||||
company = CompanyTable(
|
||||
name=company_data.company.name,
|
||||
industry=company_data.company.industry,
|
||||
location=company_data.company.location,
|
||||
description=company_data.company.description,
|
||||
founded_year=company_data.company.founded_year,
|
||||
website=company_data.company.website,
|
||||
)
|
||||
db.add(company)
|
||||
db.flush() # Get the ID
|
||||
|
||||
enhanced_data = json.loads(response.choices[0].message.content)
|
||||
# Add company members
|
||||
for member_data in company_data.members:
|
||||
if member_data.name: # Only add members with names
|
||||
member = CompanyMember(
|
||||
name=member_data.name,
|
||||
linkedin=member_data.linkedin,
|
||||
role=member_data.role,
|
||||
company_id=company.id,
|
||||
)
|
||||
db.add(member)
|
||||
|
||||
# Update investor data with enhanced information
|
||||
if enhanced_data.get("enhanced_description"):
|
||||
investor_data["enhanced_description"] = enhanced_data[
|
||||
"enhanced_description"
|
||||
]
|
||||
        # Add sectors (skip entries whose name was left empty)
        for sector_data in company_data.sectors:
            if sector_data.name:
                sector = self._get_or_create_sector(db, sector_data.name)
                company.sectors.append(sector)
|
||||
|
||||
if enhanced_data.get("standardized_focus"):
|
||||
investor_data["standardized_focus"] = enhanced_data[
|
||||
"standardized_focus"
|
||||
]
|
||||
|
||||
if enhanced_data.get("standardized_headquarters"):
|
||||
investor_data["standardized_headquarters"] = enhanced_data[
|
||||
"standardized_headquarters"
|
||||
]
|
||||
|
||||
return investor_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM enhancement failed for {investor_data['name']}: {e}")
|
||||
return investor_data
|
||||
|
||||
def save_to_sql(self, investor_data: Dict[str, Any]) -> int:
|
||||
"""Save investor data to SQL database"""
|
||||
try:
|
||||
with get_session() as session:
|
||||
# Check if investor already exists
|
||||
existing = (
|
||||
session.query(Investor)
|
||||
.filter_by(name=investor_data["name"])
|
||||
# Add investors (if not skipping to avoid circular references)
|
||||
if not skip_investors:
|
||||
for investor_data in company_data.investors:
|
||||
# Look for existing investor by name
|
||||
existing_investor = (
|
||||
db.query(InvestorTable)
|
||||
.filter(InvestorTable.name == investor_data.name)
|
||||
.first()
|
||||
)
|
||||
if existing_investor:
|
||||
company.investors.append(existing_investor)
|
||||
|
||||
if existing:
|
||||
logger.info(f"Updating existing investor: {investor_data['name']}")
|
||||
investor = existing
|
||||
else:
|
||||
logger.info(f"Creating new investor: {investor_data['name']}")
|
||||
investor = Investor()
|
||||
return company
|
||||
|
||||
# Map data to investor object
|
||||
investor.name = investor_data["name"]
|
||||
investor.website = investor_data.get("website")
|
||||
investor.investor_description = investor_data.get(
|
||||
"enhanced_description"
|
||||
) or investor_data.get("investor_description")
|
||||
investor.investment_thesis_focus = investor_data.get(
|
||||
"standardized_focus"
|
||||
) or investor_data.get("investment_thesis_focus")
|
||||
investor.headquarters = investor_data.get(
|
||||
"standardized_headquarters"
|
||||
) or investor_data.get("headquarters")
|
||||
|
||||
# AUM information
|
||||
aum_info = investor_data.get("aum_info", {})
|
||||
investor.aum_amount = aum_info.get("aumAmount")
|
||||
investor.aum_as_of_date = aum_info.get("asOfDate")
|
||||
investor.aum_source_url = aum_info.get("sourceUrl")
|
||||
|
||||
# Fund information
|
||||
investor.funds_info = investor_data.get("funds_info", [])
|
||||
|
||||
# Raw data
|
||||
investor.crunchbase_urls = investor_data.get("crunchbase_urls")
|
||||
investor.crunchbase_extract = investor_data.get("crunchbase_extract")
|
||||
investor.linkedin_profile = investor_data.get("linkedin_profile")
|
||||
investor.source_truth_profile = investor_data.get(
|
||||
"source_truth_profile"
|
||||
async def _process_row(
|
||||
self, row: pd.Series, row_idx: int, is_investor: bool = True
|
||||
) -> Optional[InvestorData | CompanyData]:
|
||||
"""Process a single row of data"""
|
||||
# Clean values to remove control characters
|
||||
cleaned_row = {}
|
||||
for key, value in row.items():
|
||||
if pd.notna(value):
|
||||
# Convert to string and clean control characters
|
||||
clean_value = (
|
||||
str(value).replace("\n", " ").replace("\r", " ").replace("\t", " ")
|
||||
)
|
||||
# Remove other control characters
|
||||
clean_value = "".join(
|
||||
char
|
||||
for char in clean_value
|
||||
if ord(char) >= 32 or char in ["\n", "\r", "\t"]
|
||||
)
|
||||
cleaned_row[key] = clean_value
|
||||
|
||||
if not existing:
|
||||
session.add(investor)
|
||||
|
||||
session.flush() # Get the ID
|
||||
return investor.id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save to SQL: {e}")
|
||||
raise
|
||||
|
||||
def save_to_vector_db(self, investor_id: int, investor_data: Dict[str, Any]):
|
||||
"""Save investor description and focus to ChromaDB"""
|
||||
row_str = ", ".join([f"{key}: {value}" for key, value in cleaned_row.items()])
|
||||
try:
|
||||
# Prepare text for embedding
|
||||
description_text = investor_data.get(
|
||||
"enhanced_description"
|
||||
) or investor_data.get("investor_description", "")
|
||||
focus_areas = investor_data.get("standardized_focus") or investor_data.get(
|
||||
"investment_thesis_focus", []
|
||||
)
|
||||
|
||||
if isinstance(focus_areas, list):
|
||||
focus_text = " ".join(focus_areas)
|
||||
print(f"Processing row {row_idx + 1}...")
|
||||
if is_investor:
|
||||
result = await self.investor_structured_llm.ainvoke(row_str)
|
||||
else:
|
||||
focus_text = str(focus_areas)
|
||||
|
||||
# Combine description and focus for embedding
|
||||
combined_text = f"{description_text} {focus_text}".strip()
|
||||
|
||||
if not combined_text:
|
||||
logger.warning(f"No text to embed for investor {investor_data['name']}")
|
||||
return
|
||||
|
||||
# Create metadata
|
||||
metadata = {
|
||||
"investor_id": investor_id,
|
||||
"name": investor_data["name"],
|
||||
"website": investor_data.get("website", ""),
|
||||
"headquarters": investor_data.get("standardized_headquarters")
|
||||
or investor_data.get("headquarters", ""),
|
||||
"focus_areas_count": len(focus_areas)
|
||||
if isinstance(focus_areas, list)
|
||||
else 0,
|
||||
}
|
||||
|
||||
# Add to ChromaDB
|
||||
self.collection.add(
|
||||
documents=[combined_text],
|
||||
metadatas=[metadata],
|
||||
ids=[f"investor_{investor_id}"],
|
||||
)
|
||||
|
||||
logger.info(f"Added investor {investor_data['name']} to vector database")
|
||||
|
||||
result = await self.company_structured_llm.ainvoke(row_str)
|
||||
if result:
|
||||
return result.model_dump()
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save to vector DB: {e}")
|
||||
|
||||
def process_csv_file(self, csv_file_path: str, limit: Optional[int] = None):
|
||||
"""Process the entire CSV file"""
|
||||
logger.info(f"Starting to process CSV file: {csv_file_path}")
|
||||
|
||||
# Read CSV
|
||||
df = pd.read_csv(csv_file_path)
|
||||
logger.info(f"Loaded {len(df)} rows from CSV")
|
||||
|
||||
if limit:
|
||||
df = df.head(limit)
|
||||
logger.info(f"Processing limited to {limit} rows")
|
||||
|
||||
processed_count = 0
|
||||
error_count = 0
|
||||
|
||||
for index, row in df.iterrows():
|
||||
try:
|
||||
logger.info(f"Processing row {index + 1}/{len(df)}: {row['Name']}")
|
||||
|
||||
# Create CSVRow object
|
||||
csv_row = CSVRow(
|
||||
name=row["Name"],
|
||||
website=row.get("Website"),
|
||||
investment_firm_profile=row.get("Investment Firm Profile"),
|
||||
crunchbase_linkedin_urls=row.get("Crunchbase & LinkedIn URLs"),
|
||||
crunchbase_firm_extract=row.get("Crunchbase Firm Extract"),
|
||||
linkedin_investment_profile=row.get("LinkedIn Investment Profile"),
|
||||
source_of_truth_profile=row.get("Source of Truth Profile"),
|
||||
)
|
||||
|
||||
# Extract structured data
|
||||
structured_data = self.extract_structured_data(csv_row)
|
||||
|
||||
# Enhance with LLM
|
||||
enhanced_data = self.enhance_with_llm(structured_data)
|
||||
|
||||
# Save to SQL database
|
||||
investor_id = self.save_to_sql(enhanced_data)
|
||||
|
||||
# Save to vector database
|
||||
self.save_to_vector_db(investor_id, enhanced_data)
|
||||
|
||||
processed_count += 1
|
||||
|
||||
# Progress update every 10 rows
|
||||
if (index + 1) % 10 == 0:
|
||||
logger.info(
|
||||
f"Processed {processed_count} rows successfully, {error_count} errors"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_count += 1
|
||||
logger.error(
|
||||
f"Error processing row {index + 1} ({row.get('Name', 'Unknown')}): {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"Processing complete! Processed: {processed_count}, Errors: {error_count}"
|
||||
)
|
||||
return processed_count, error_count
|
||||
|
||||
def search_investors(self, query: str, limit: int = 5):
|
||||
"""Search investors using vector similarity"""
|
||||
try:
|
||||
results = self.collection.query(query_texts=[query], n_results=limit)
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Search failed: {e}")
|
||||
print(f"Error processing row {row_idx + 1}: {e}")
|
||||
return None
|
||||
|
||||
async def parse_investors(self, df, save_to_db: bool = True):
|
||||
"""Parse investors from DataFrame and optionally save to database"""
|
||||
investors = []
|
||||
df = df[20:]
|
||||
db = None
|
||||
if save_to_db:
|
||||
db = get_db_session()
|
||||
|
||||
def main():
|
||||
"""Main function to run the parser"""
|
||||
parser = LLMInvestorParser()
|
||||
try:
|
||||
# Process rows in batches asynchronously
|
||||
batch_size = 20 # Adjust batch size as needed
|
||||
rows = [(idx, row) for idx, row in df.iterrows()]
|
||||
|
||||
# Process the CSV file
|
||||
csv_file = "/home/oluwasanmi/Documents/Work/MKD/anton_wireframe/New Excerpt 5 investors - Sheet1 parse.csv"
|
||||
for i in range(0, len(rows), batch_size):
|
||||
batch = rows[i : i + batch_size]
|
||||
|
||||
# Start with a small sample for testing
|
||||
processed, errors = parser.process_csv_file(csv_file, limit=5)
|
||||
# Process batch asynchronously
|
||||
tasks = [
|
||||
self._process_row(row, idx, is_investor=True) for idx, row in batch
|
||||
]
|
||||
|
||||
print("\nProcessing complete!")
|
||||
print(f"Successfully processed: {processed} investors")
|
||||
print(f"Errors encountered: {errors}")
|
||||
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Test search functionality
|
||||
print("\nTesting search functionality...")
|
||||
results = parser.search_investors("bioeconomy circular economy")
|
||||
if results:
|
||||
print(f"Found {len(results['documents'][0])} similar investors")
|
||||
for i, doc in enumerate(results["documents"][0]):
|
||||
print(f" {i + 1}. {results['metadatas'][0][i]['name']}")
|
||||
# Handle results from batch
|
||||
for (idx, row), result in zip(batch, batch_results):
|
||||
if isinstance(result, Exception):
|
||||
print(f"Error processing row {idx}: {result}")
|
||||
if db:
|
||||
db.rollback()
|
||||
continue
|
||||
|
||||
if result:
|
||||
# Convert dict to InvestorData if needed
|
||||
if isinstance(result, dict):
|
||||
investor_data = InvestorData(**result)
|
||||
else:
|
||||
investor_data = result
|
||||
|
||||
investors.append(investor_data)
|
||||
|
||||
# Save to database if requested
|
||||
if save_to_db and db:
|
||||
try:
|
||||
saved_investor = self._save_investor_to_db(
|
||||
db, investor_data
|
||||
)
|
||||
db.commit()
|
||||
print(
|
||||
f"✅ Saved investor '{saved_investor.name}' to database"
|
||||
)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
print(f"❌ Failed to save investor to database: {e}")
|
||||
|
||||
print(
|
||||
f"Completed batch {i // batch_size + 1} of {(len(rows) + batch_size - 1) // batch_size}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in batch processing: {e}")
|
||||
if db:
|
||||
db.rollback()
|
||||
finally:
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
return investors
|
||||
|
||||
async def parse_companies(self, df, save_to_db: bool = True):
|
||||
"""Parse companies from DataFrame and optionally save to database"""
|
||||
companies = []
|
||||
df = df[20:]
|
||||
db = None
|
||||
if save_to_db:
|
||||
db = get_db_session()
|
||||
|
||||
try:
|
||||
# Process rows in batches asynchronously
|
||||
batch_size = 20 # Adjust batch size as needed
|
||||
rows = [(idx, row) for idx, row in df.iterrows()]
|
||||
|
||||
for i in range(0, len(rows), batch_size):
|
||||
batch = rows[i : i + batch_size]
|
||||
|
||||
# Process batch asynchronously
|
||||
tasks = [
|
||||
self._process_row(row, idx, is_investor=False) for idx, row in batch
|
||||
]
|
||||
|
||||
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Handle results from batch
|
||||
for (idx, row), result in zip(batch, batch_results):
|
||||
if isinstance(result, Exception):
|
||||
print(f"Error processing row {idx}: {result}")
|
||||
if db:
|
||||
db.rollback()
|
||||
continue
|
||||
|
||||
if result:
|
||||
# Convert dict to CompanyData if needed
|
||||
if isinstance(result, dict):
|
||||
company_data = CompanyData(**result)
|
||||
else:
|
||||
company_data = result
|
||||
|
||||
companies.append(company_data)
|
||||
|
||||
# Save to database if requested
|
||||
if save_to_db and db:
|
||||
try:
|
||||
saved_company = self._save_company_to_db(
|
||||
db, company_data
|
||||
)
|
||||
db.commit()
|
||||
print(
|
||||
f"✅ Saved company '{saved_company.name}' to database"
|
||||
)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
print(f"❌ Failed to save company to database: {e}")
|
||||
|
||||
print(
|
||||
f"Completed batch {i // batch_size + 1} of {(len(rows) + batch_size - 1) // batch_size}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in batch processing: {e}")
|
||||
if db:
|
||||
db.rollback()
|
||||
finally:
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
return companies
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# async def main():
|
||||
# """Main execution function"""
|
||||
# # Initialize database tables
|
||||
# print("🔧 Initializing database...")
|
||||
# init_database()
|
||||
|
||||
# # Create processor
|
||||
# processor = InvestorProcessor()
|
||||
|
||||
# print("📊 Processing companies...")
|
||||
# companies = await processor.parse_companies(
|
||||
# "data/19 Companies data.csv", save_to_db=True
|
||||
# )
|
||||
# print(f"Processed {len(companies)} companies")
|
||||
|
||||
# print("\n💰 Processing investors...")
|
||||
# investors = await processor.parse_investors(
|
||||
# "data/19 Investors data.csv", save_to_db=True
|
||||
# )
|
||||
# print(f"Processed {len(investors)} investors")
|
||||
# print("\n✨ Processing complete!")
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# asyncio.run(main())
|
||||
|
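Review note: the cleaning step in `_process_row` above replaces newlines, carriage returns, and tabs with spaces and then drops any remaining control characters. Because those three characters are already gone by the time the filter runs, the `char in ["\n", "\r", "\t"]` branch can never match. The following is a minimal standalone sketch of the same step, not the author's implementation. The helper names `clean_cell` and `row_to_prompt` are hypothetical, and a plain `dict` with `None` for missing cells stands in for the pandas row and its `pd.notna` check.

```python
def clean_cell(value: object) -> str:
    """Hypothetical helper: flatten whitespace controls, then drop other control characters."""
    # Turn line breaks and tabs into spaces so each cell stays on one line.
    text = str(value).replace("\n", " ").replace("\r", " ").replace("\t", " ")
    # After the replacements above, anything below U+0020 is an unwanted control character.
    return "".join(ch for ch in text if ord(ch) >= 32)


def row_to_prompt(row: dict) -> str:
    """Join the non-missing cells into the 'key: value' string sent to the LLM.

    Assumption: missing cells are represented as None (the original uses pd.notna).
    """
    return ", ".join(
        f"{key}: {clean_cell(value)}" for key, value in row.items() if value is not None
    )
```

Keeping the cleaning in one helper would also let the investor and company code paths share it instead of repeating the loop in each processor.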
||||
@@ -1,293 +0,0 @@
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
|
||||
import chromadb
|
||||
import pandas as pd
|
||||
from db.models import CompanyTable, InvestorTable, InvestorTeamMember, SectorTable
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_openai import ChatOpenAI
|
||||
from py_schemas import InvestorData
|
||||
from pydantic import BaseModel
|
||||
from settings import settings
|
||||
|
||||
|
||||
class InvestorList(BaseModel):
|
||||
"""Schema for LLM structured output"""
|
||||
|
||||
investor_list: List[InvestorData]
|
||||
|
||||
|
||||
class InvestorProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
sql_session: Optional[object] = None,
|
||||
vector_db_client: Optional[object] = None,
|
||||
):
|
||||
self.template = """You are an expert data extraction assistant. Extract investor information from the provided CSV data and return it as a list of structured records.
|
||||
|
||||
Given the following CSV data rows:
|
||||
{question}
|
||||
|
||||
For each row, extract and structure the following fields for the investor:
|
||||
- name: The investor's full name
|
||||
- description: Description of the investor
|
||||
- aum: Assets under management (as integer, use 0 if not available)
|
||||
- check_size_lower: Lower bound of investment check size (as integer)
|
||||
- check_size_upper: Upper bound of investment check size (as integer)
|
||||
- geographic_focus: Geographic region focus
|
||||
- stage_focus: Investment stage focus (must be one of: seed, series_a, series_b, series_c, growth, late_stage)
|
||||
- number_of_investments: Number of investments made (default 0)
|
||||
|
||||
Also extract related data:
|
||||
- portfolio_companies: List of companies they've invested in
|
||||
- team_members: List of team members with name, role, email
|
||||
- sectors: List of sectors they focus on
|
||||
|
||||
Important:
|
||||
- If a field is not available, use appropriate defaults
|
||||
- stage_focus must be one of the valid enum values
|
||||
- Return clean, valid JSON only
|
||||
|
||||
Return the data as a structured list of comprehensive investor data."""
|
||||
|
||||
self.prompt = PromptTemplate(
|
||||
template=self.template, input_variables=["question"]
|
||||
)
|
||||
|
||||
self.llm = ChatOpenAI(
|
||||
api_key=settings.OPENROUTER_API_KEY,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="google/gemini-2.5-flash-lite",
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
self.structured_llm = self.llm.with_structured_output(InvestorList)
|
||||
self.sql_session = sql_session
|
||||
self.vector_db_client = vector_db_client
|
||||
|
||||
self.vector_db_client = vector_db_client or chromadb.PersistentClient(path="./chroma_db")
|
||||
self.collection = self.vector_db_client.get_or_create_collection(
|
||||
name="investor_descriptions",
|
||||
metadata={
|
||||
"description": "Investor descriptions and investment thesis focus"
|
||||
},
|
||||
)
|
||||
|
||||
async def _process_batch(
|
||||
self, batch: pd.DataFrame, batch_idx: int
|
||||
) -> List[InvestorData]:
|
||||
"""Process a single batch of data"""
|
||||
# Convert batch to string representation - clean the data
|
||||
batch_str = ""
|
||||
for idx, row in batch.iterrows():
|
||||
# Clean values to remove control characters
|
||||
cleaned_row = {}
|
||||
for key, value in row.items():
|
||||
if pd.notna(value):
|
||||
# Convert to string and clean control characters
|
||||
clean_value = (
|
||||
str(value)
|
||||
.replace("\n", " ")
|
||||
.replace("\r", " ")
|
||||
.replace("\t", " ")
|
||||
)
|
||||
# Remove other control characters
|
||||
clean_value = "".join(
|
||||
char
|
||||
for char in clean_value
|
||||
if ord(char) >= 32 or char in ["\n", "\r", "\t"]
|
||||
)
|
||||
cleaned_row[key] = clean_value
|
||||
|
||||
row_str = ", ".join(
|
||||
[f"{key}: {value}" for key, value in cleaned_row.items()]
|
||||
)
|
||||
batch_str += f"Row {idx + 1}: {row_str}\n"
|
||||
|
||||
try:
|
||||
print(f"Processing batch {batch_idx + 1}...")
|
||||
batch_results = await self.structured_llm.ainvoke(self.prompt.format(question=batch_str))
|
||||
return batch_results.investor_list
|
||||
except Exception as e:
|
||||
print(f"Error processing batch {batch_idx + 1}: {e}")
|
||||
return []
|
||||
|
||||
async def _save_to_sql(self, investor_data_list: List[InvestorData]) -> None:
|
||||
"""Save investors and related data to SQL database"""
|
||||
if not self.sql_session:
|
||||
return
|
||||
|
||||
try:
|
||||
for investor_data in investor_data_list:
|
||||
# Save investor
|
||||
db_investor = InvestorTable(
|
||||
name=investor_data.investor.name,
|
||||
description=investor_data.investor.description,
|
||||
aum=investor_data.investor.aum,
|
||||
check_size_lower=investor_data.investor.check_size_lower,
|
||||
check_size_upper=investor_data.investor.check_size_upper,
|
||||
geographic_focus=investor_data.investor.geographic_focus,
|
||||
stage_focus=investor_data.investor.stage_focus,
|
||||
number_of_investments=investor_data.investor.number_of_investments,
|
||||
)
|
||||
self.sql_session.add(db_investor)
|
||||
self.sql_session.flush() # Get the ID
|
||||
|
||||
# Save sectors and create associations
|
||||
for sector_data in investor_data.sectors:
|
||||
# Check if sector exists, create if not
|
||||
existing_sector = (
|
||||
self.sql_session.query(SectorTable)
|
||||
.filter(SectorTable.name == sector_data.name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not existing_sector:
|
||||
db_sector = SectorTable(name=sector_data.name)
|
||||
self.sql_session.add(db_sector)
|
||||
self.sql_session.flush()
|
||||
# Add sector to investor's sectors
|
||||
db_investor.sectors.append(db_sector)
|
||||
else:
|
||||
# Add existing sector to investor if not already there
|
||||
if existing_sector not in db_investor.sectors:
|
||||
db_investor.sectors.append(existing_sector)
|
||||
|
||||
# Save companies and create portfolio associations
|
||||
for company_data in investor_data.portfolio_companies:
|
||||
# Check if company exists, create if not
|
||||
existing_company = (
|
||||
self.sql_session.query(CompanyTable)
|
||||
.filter(CompanyTable.name == company_data.name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not existing_company:
|
||||
db_company = CompanyTable(
|
||||
name=company_data.name,
|
||||
industry=company_data.industry,
|
||||
location=company_data.location,
|
||||
founded_year=company_data.founded_year,
|
||||
website=company_data.website,
|
||||
)
|
||||
self.sql_session.add(db_company)
|
||||
self.sql_session.flush()
|
||||
|
||||
# Add to investor's portfolio
|
||||
db_investor.portfolio_companies.append(db_company)
|
||||
else:
|
||||
# Add existing company to portfolio if not already there
|
||||
if existing_company not in db_investor.portfolio_companies:
|
||||
db_investor.portfolio_companies.append(existing_company)
|
||||
|
||||
# Save team members
|
||||
for team_member_data in investor_data.team_members:
|
||||
# Check if team member exists
|
||||
existing_member = (
|
||||
self.sql_session.query(InvestorTeamMember)
|
||||
.filter(InvestorTeamMember.email == team_member_data.email)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not existing_member:
|
||||
db_team_member = InvestorTeamMember(
|
||||
name=team_member_data.name,
|
||||
role=team_member_data.role,
|
||||
email=team_member_data.email,
|
||||
investor_id=db_investor.id,
|
||||
)
|
||||
self.sql_session.add(db_team_member)
|
||||
|
||||
self.sql_session.commit()
|
||||
print(f"Successfully saved {len(investor_data_list)} investors to database")
|
||||
|
||||
except Exception as e:
|
||||
self.sql_session.rollback()
|
||||
print(f"Error saving to SQL database: {e}")
|
||||
raise
|
||||
|
||||
async def _save_to_vector_db(self, investor_data_list: List[InvestorData]) -> None:
|
||||
"""Save investors to vector database"""
|
||||
if not self.vector_db_client:
|
||||
return
|
||||
|
||||
documents = []
|
||||
metadatas = []
|
||||
ids = []
|
||||
|
||||
for i, investor_data in enumerate(investor_data_list):
|
||||
investor = investor_data.investor
|
||||
sectors = ", ".join([s.name for s in investor_data.sectors])
|
||||
companies = ", ".join([c.name for c in investor_data.portfolio_companies])
|
||||
|
||||
doc_text = f"""
|
||||
Investor: {investor.name}
|
||||
Description: {investor.description or "N/A"}
|
||||
AUM: ${investor.aum:,}
|
||||
Check Size: ${investor.check_size_lower:,} - ${investor.check_size_upper:,}
|
||||
Geographic Focus: {investor.geographic_focus}
|
||||
Stage Focus: {investor.stage_focus.value}
|
||||
Sectors: {sectors}
|
||||
Portfolio Companies: {companies}
|
||||
""".strip()
|
||||
|
||||
documents.append(doc_text)
|
||||
metadatas.append(
|
||||
{
|
||||
"name": investor.name,
|
||||
"stage_focus": investor.stage_focus.value,
|
||||
"geographic_focus": investor.geographic_focus,
|
||||
"aum": investor.aum,
|
||||
}
|
||||
)
|
||||
ids.append(
|
||||
f"investor_{i}_{investor.name.replace(' ', '_').replace('/', '_')}"
|
||||
)
|
||||
|
||||
if documents:
|
||||
try:
|
||||
self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
|
||||
print(
|
||||
f"Successfully saved {len(documents)} investors to vector database"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error saving to vector database: {e}")
|
||||
|
||||
async def process_csv(
|
||||
self, df: pd.DataFrame, batch_size: int = 10, max_concurrent: int = 10
|
||||
) -> List[InvestorData]:
|
||||
"""Process CSV data in parallel batches and save to databases"""
|
||||
results = []
|
||||
|
||||
# Create batches
|
||||
batches = []
|
||||
for i in range(0, len(df), batch_size):
|
||||
batch = df.iloc[i : i + batch_size]
|
||||
batches.append((batch, i // batch_size))
|
||||
|
||||
# Process batches with concurrency control
|
||||
semaphore = asyncio.Semaphore(max_concurrent)
|
||||
|
||||
async def process_with_semaphore(batch_data):
|
||||
batch, batch_idx = batch_data
|
||||
async with semaphore:
|
||||
return await self._process_batch(batch, batch_idx)
|
||||
|
||||
# Execute all batches concurrently
|
||||
batch_results = await asyncio.gather(
|
||||
*[process_with_semaphore(batch_data) for batch_data in batches],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
# Collect results, filtering out exceptions
|
||||
for batch_result in batch_results:
|
||||
if not isinstance(batch_result, Exception):
|
||||
results.extend(batch_result)
|
||||
|
||||
# Save to databases
|
||||
if results:
|
||||
print(f"Successfully processed {len(results)} investors")
|
||||
await self._save_to_sql(results)
|
||||
await self._save_to_vector_db(results)
|
||||
|
||||
return results
|
||||
@@ -1,290 +0,0 @@
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
|
||||
import chromadb
|
||||
import pandas as pd
|
||||
from db.models import CompanyTable, InvestorTable, InvestorTeamMember, SectorTable
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_openai import ChatOpenAI
|
||||
from py_schemas import InvestorData
|
||||
from pydantic import BaseModel
|
||||
from settings import settings
|
||||
|
||||
|
||||
class InvestorOutput(BaseModel):
|
||||
"""Schema for LLM structured output"""
|
||||
|
||||
investor_data: InvestorData
|
||||
|
||||
|
||||
class InvestorProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
sql_session: Optional[object] = None,
|
||||
vector_db_client: Optional[object] = None,
|
||||
):
|
||||
self.template = """You are an expert data extraction assistant. Extract investor information from the provided CSV data and return it as a structured record.
|
||||
|
||||
Given the following CSV data row:
|
||||
{question}
|
||||
|
||||
Extract and structure the following fields for the investor:
|
||||
- name: The investor's full name
|
||||
- description: Description of the investor
|
||||
- aum: Assets under management (as integer, use 0 if not available)
|
||||
- check_size_lower: Lower bound of investment check size (as integer)
|
||||
- check_size_upper: Upper bound of investment check size (as integer)
|
||||
- geographic_focus: Geographic region focus
|
||||
- stage_focus: Investment stage focus (must be one of: seed, series_a, series_b, series_c, growth, late_stage)
|
||||
- number_of_investments: Number of investments made (default 0)
|
||||
|
||||
Also extract related data:
|
||||
- portfolio_companies: List of companies they've invested in
|
||||
- team_members: List of team members with name, role, email
|
||||
- sectors: List of sectors they focus on
|
||||
|
||||
Important:
|
||||
- If a field is not available, use appropriate defaults
|
||||
- stage_focus must be one of the valid enum values
|
||||
- Return clean, valid JSON only
|
||||
|
||||
Return the data as a single comprehensive investor data record."""
|
||||
|
||||
self.prompt = PromptTemplate(
|
||||
template=self.template, input_variables=["question"]
|
||||
)
|
||||
|
||||
self.llm = ChatOpenAI(
|
||||
api_key=settings.OPENROUTER_API_KEY,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="google/gemini-2.5-flash-lite",
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
self.structured_llm = self.llm.with_structured_output(InvestorOutput)
|
||||
self.sql_session = sql_session
|
||||
self.vector_db_client = vector_db_client
|
||||
|
||||
self.vector_db_client = vector_db_client or chromadb.PersistentClient(path="./chroma_db")
|
||||
self.collection = self.vector_db_client.get_or_create_collection(
|
||||
name="investor_descriptions",
|
||||
metadata={
|
||||
"description": "Investor descriptions and investment thesis focus"
|
||||
},
|
||||
)
|
||||
|
||||
async def _process_row(
|
||||
self, row: pd.Series, row_idx: int
|
||||
) -> Optional[InvestorData]:
|
||||
"""Process a single row of data"""
|
||||
# Clean values to remove control characters
|
||||
cleaned_row = {}
|
||||
for key, value in row.items():
|
||||
if pd.notna(value):
|
||||
# Convert to string and clean control characters
|
||||
clean_value = (
|
||||
str(value)
|
||||
.replace("\n", " ")
|
||||
.replace("\r", " ")
|
||||
.replace("\t", " ")
|
||||
)
|
||||
# Remove other control characters
|
||||
clean_value = "".join(
|
||||
char
|
||||
for char in clean_value
|
||||
if ord(char) >= 32 or char in ["\n", "\r", "\t"]
|
||||
)
|
||||
cleaned_row[key] = clean_value
|
||||
|
||||
row_str = ", ".join(
|
||||
[f"{key}: {value}" for key, value in cleaned_row.items()]
|
||||
)
|
||||
|
||||
try:
|
||||
print(f"Processing row {row_idx + 1}...")
|
||||
result = await self.structured_llm.ainvoke(self.prompt.format(question=row_str))
|
||||
if result.investor_data:
|
||||
return result.investor_data
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error processing row {row_idx + 1}: {e}")
|
||||
return None
|
||||
|
||||
async def _save_to_sql(self, investor_data_list: List[InvestorData]) -> None:
|
||||
"""Save investors and related data to SQL database"""
|
||||
if not self.sql_session:
|
||||
return
|
||||
|
||||
try:
|
||||
for investor_data in investor_data_list:
|
||||
# Save investor
|
||||
db_investor = InvestorTable(
|
||||
name=investor_data.investor.name,
|
||||
description=investor_data.investor.description,
|
||||
aum=investor_data.investor.aum,
|
||||
check_size_lower=investor_data.investor.check_size_lower,
|
||||
check_size_upper=investor_data.investor.check_size_upper,
|
||||
geographic_focus=investor_data.investor.geographic_focus,
|
||||
stage_focus=investor_data.investor.stage_focus,
|
||||
number_of_investments=investor_data.investor.number_of_investments,
|
||||
)
|
||||
self.sql_session.add(db_investor)
|
||||
self.sql_session.flush() # Get the ID
|
||||
|
||||
# Save sectors and create associations
|
||||
for sector_data in investor_data.sectors:
|
||||
# Check if sector exists, create if not
|
||||
existing_sector = (
|
||||
self.sql_session.query(SectorTable)
|
||||
.filter(SectorTable.name == sector_data.name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not existing_sector:
|
||||
db_sector = SectorTable(name=sector_data.name)
|
||||
self.sql_session.add(db_sector)
|
||||
self.sql_session.flush()
|
||||
# Add sector to investor's sectors
|
||||
db_investor.sectors.append(db_sector)
|
||||
else:
|
||||
# Add existing sector to investor if not already there
|
||||
if existing_sector not in db_investor.sectors:
|
||||
db_investor.sectors.append(existing_sector)
|
||||
|
||||
# Save companies and create portfolio associations
|
||||
for company_data in investor_data.portfolio_companies:
|
||||
# Check if company exists, create if not
|
||||
existing_company = (
|
||||
self.sql_session.query(CompanyTable)
|
||||
.filter(CompanyTable.name == company_data.name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not existing_company:
|
||||
db_company = CompanyTable(
|
||||
name=company_data.name,
|
||||
industry=company_data.industry,
|
||||
location=company_data.location,
|
||||
founded_year=company_data.founded_year,
|
||||
website=company_data.website,
|
||||
)
|
||||
self.sql_session.add(db_company)
|
||||
self.sql_session.flush()
|
||||
|
||||
# Add to investor's portfolio
|
||||
db_investor.portfolio_companies.append(db_company)
|
||||
else:
|
||||
# Add existing company to portfolio if not already there
|
||||
if existing_company not in db_investor.portfolio_companies:
|
||||
db_investor.portfolio_companies.append(existing_company)
|
||||
|
||||
# Save team members
|
||||
for team_member_data in investor_data.team_members:
|
||||
# Check if team member exists
|
||||
existing_member = (
|
||||
self.sql_session.query(InvestorTeamMember)
|
||||
.filter(InvestorTeamMember.email == team_member_data.email)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not existing_member:
|
||||
db_team_member = InvestorTeamMember(
|
||||
name=team_member_data.name,
|
||||
role=team_member_data.role,
|
||||
email=team_member_data.email,
|
||||
investor_id=db_investor.id,
|
||||
)
|
||||
self.sql_session.add(db_team_member)
|
||||
|
||||
self.sql_session.commit()
|
||||
print(f"Successfully saved {len(investor_data_list)} investors to database")
|
||||
|
||||
except Exception as e:
|
||||
self.sql_session.rollback()
|
||||
print(f"Error saving to SQL database: {e}")
|
||||
raise
|
||||
|
||||
    async def _save_to_vector_db(self, investor_data_list: List[InvestorData]) -> None:
        """Save investors to vector database"""
        if not self.vector_db_client:
            return

        documents = []
        metadatas = []
        ids = []

        for i, investor_data in enumerate(investor_data_list):
            investor = investor_data.investor
            sectors = ", ".join([s.name for s in investor_data.sectors])
            companies = ", ".join([c.name for c in investor_data.portfolio_companies])

            doc_text = f"""
            Investor: {investor.name}
            Description: {investor.description or "N/A"}
            AUM: ${investor.aum:,}
            Check Size: ${investor.check_size_lower:,} - ${investor.check_size_upper:,}
            Geographic Focus: {investor.geographic_focus}
            Stage Focus: {investor.stage_focus.value}
            Sectors: {sectors}
            Portfolio Companies: {companies}
            """.strip()

            documents.append(doc_text)
            metadatas.append(
                {
                    "name": investor.name,
                    "stage_focus": investor.stage_focus.value,
                    "geographic_focus": investor.geographic_focus,
                    "aum": investor.aum,
                }
            )
            ids.append(
                f"investor_{i}_{investor.name.replace(' ', '_').replace('/', '_')}"
            )

        if documents:
            try:
                self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
                print(
                    f"Successfully saved {len(documents)} investors to vector database"
                )
            except Exception as e:
                print(f"Error saving to vector database: {e}")

    async def process_csv(
        self, df: pd.DataFrame, max_concurrent: int = 10
    ) -> List[InvestorData]:
        """Process CSV data one row at a time and save to databases"""
        results = []

        # Create semaphore for concurrency control
        semaphore = asyncio.Semaphore(max_concurrent)

        async def process_row_with_semaphore(row_data):
            row, row_idx = row_data
            async with semaphore:
                return await self._process_row(row, row_idx)

        # Create row tasks
        row_tasks = []
        for idx, row in df.iterrows():
            row_tasks.append((row, idx))

        # Execute all rows concurrently
        row_results = await asyncio.gather(
            *[process_row_with_semaphore(row_data) for row_data in row_tasks],
            return_exceptions=True,
        )

        # Collect results, filtering out exceptions and None values
        for row_result in row_results:
            if not isinstance(row_result, Exception) and row_result is not None:
                results.append(row_result)

        # Save to databases
        if results:
            print(f"Successfully processed {len(results)} investors")
            await self._save_to_sql(results)
            await self._save_to_vector_db(results)

        return results
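The `process_csv` method above bounds concurrency with an `asyncio.Semaphore` and collects outcomes through `asyncio.gather(..., return_exceptions=True)`, then drops exceptions and `None` values. A minimal, self-contained sketch of that pattern follows. `process_rows` and `handle` are hypothetical names used only for illustration, and the upper-casing step is an assumed stand-in for the real per-row work.

```python
import asyncio
from typing import List, Optional


async def process_rows(rows: List[str], max_concurrent: int = 2) -> List[str]:
    """Run one task per row, with at most `max_concurrent` running at once."""
    semaphore = asyncio.Semaphore(max_concurrent)

    async def handle(row: str) -> Optional[str]:
        # Hypothetical per-row worker; the real code would call the LLM here.
        async with semaphore:
            await asyncio.sleep(0)  # yield control, as an I/O call would
            if not row:
                raise ValueError("empty row")
            return row.upper()

    # return_exceptions=True keeps one failing row from cancelling the batch;
    # failures come back as exception objects in the result list.
    outcomes = await asyncio.gather(
        *(handle(row) for row in rows), return_exceptions=True
    )

    # Keep only successful, non-None results (gather preserves input order).
    return [o for o in outcomes if not isinstance(o, Exception) and o is not None]


results = asyncio.run(process_rows(["a", "", "b"]))
print(results)
```

Because failed rows are silently filtered out, a caller that needs to know which rows failed would have to log or return the exception objects as well.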
+74
-236
@@ -1,88 +1,47 @@
-from typing import List, Optional
+import os
+from typing import List
 
-import chromadb
+from db.db import DATABASE_URL, get_db
 from db.models import InvestorTable
 from langchain import hub
 from langchain_community.agent_toolkits import SQLDatabaseToolkit
 from langchain_community.utilities import SQLDatabase
 from langchain_openai import ChatOpenAI
 from langgraph.prebuilt import create_react_agent
-from py_schemas import InvestorData, InvestorList
-from settings import settings
+from schemas.py_schemas import InvestorData, InvestorList
 from sqlalchemy.orm import selectinload
 
 # Connect to SQLite
-
 prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
-db = SQLDatabase.from_uri("sqlite:///investors.db")
-system_message = (
-    prompt_template.format(dialect="SQLite", top_k=5)
-    + "\n Get answers from the Sql database and the vector database"
-)
+db = SQLDatabase.from_uri(DATABASE_URL)
 
 
 class QueryProcessor:
-    def __init__(
-        self,
-        sql_session: Optional[object] = None,
-        vector_db_client: Optional[object] = None,
-    ):
-        self.sql_session = sql_session
+    def __init__(self):
         self.llm = ChatOpenAI(
-            api_key=settings.OPENROUTER_API_KEY,
+            api_key=os.getenv("OPENROUTER_API_KEY"),
             base_url="https://openrouter.ai/api/v1",
-            model="google/gemini-2.5-flash-lite",
-            temperature=0.3,
+            model="openai/gpt-4o-mini",
+            temperature=0,
         )
         self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
+        # Update system message to specifically request only investor IDs
+        system_message_updated = (
+            prompt_template.format(dialect="SQLite", top_k=5)
+            + "\n\nIMPORTANT: You must ONLY return the investor IDs (id field) that match the user's criteria. "
+            + "Do NOT return any other information, explanations, or data. "
+            + "Your response should be ONLY a comma-separated list of numbers representing the investor IDs. "
+            + "Example format: 1, 5, 12, 23"
+        )
         self.agent = create_react_agent(
             model=self.llm,
-            tools=self.toolkit.get_tools() + [self.query_vector_database],
-            prompt=system_message,
+            tools=self.toolkit.get_tools(),
+            prompt=system_message_updated,
         )
-        self.vector_db_client = vector_db_client
-
-        self.vector_db_client = chromadb.PersistentClient(path="./chroma_db")
-        self.collection = self.vector_db_client.get_or_create_collection(
-            name="investor_descriptions",
-            metadata={
-                "description": "Investor descriptions and investment thesis focus"
-            },
-        )
-
-    def query_sql_database(self, query: str) -> Optional[InvestorList]:
-        """Query the SQL database for investor information."""
-        if not self.sql_session:
-            return None
-
-        # Implement SQL querying logic here
-        result = self.sql_session.execute(query)
-        investors = result.scalars().all()
-        return InvestorList(investors=investors)
-
-    def query_vector_database(self, query: str) -> Optional[InvestorList]:
-        """Query the vector database for investor information."""
-        if not self.vector_db_client:
-            return None
-        print("VECTOR STORE WAS CALLED")
-
-        # Query the collection directly, not passing collection as parameter
-        results = self.collection.query(
-            query_texts=[query],  # ChromaDB expects a list of query texts
-            n_results=3,  # Specify how many results you want
-        )
-        print(results)
-
-        # ChromaDB returns results in a different structure
-        # results will have 'documents', 'metadatas', 'ids', 'distances'
-        return results
 
     def process_query(self, question: str) -> InvestorList:
-        """Process a query using the LLM and return structured investor data."""
-        # Extract filters from the query first
-        filters = self._extract_filters_from_query(question)
-
-        # Get AI response for additional context
+        """Process a query using the LLM and return investor data."""
+        # Let the LLM handle all database interactions and filtering to get IDs
         response = self.agent.invoke(
             {"messages": [("user", question)]},
         )
@@ -92,189 +51,68 @@ class QueryProcessor:
             response["messages"][-1].content if response.get("messages") else ""
         )
 
-        # Try to extract investor IDs or names from the AI response
-        investor_ids = self._extract_investor_info_from_response(ai_response)
+        # Extract investor IDs from the AI response
+        investor_ids = self._extract_investor_ids_from_response(ai_response)
 
-        # Fetch filtered investor data with relationships from database
-        return self._fetch_investors_with_relationships(investor_ids, filters)
+        # Fetch full investor data using the IDs
+        return self._fetch_investors_by_ids(investor_ids)
 
-    def _extract_investor_info_from_response(self, ai_response: str) -> List[int]:
-        """Extract investor IDs from AI response. This is a simple implementation."""
-        # This is a basic implementation - you might want to make it more sophisticated
-        # based on how your AI formats responses
-        investor_ids = []
-
-        # If the AI can't provide structured data, fall back to getting all investors
-        # that match basic criteria
-        try:
-            # Try to extract numbers that might be IDs
-            import re
-
-            ids = re.findall(r"\bid:\s*(\d+)", ai_response.lower())
-            investor_ids = [int(id_str) for id_str in ids]
-        except Exception:
-            pass
-
-        return investor_ids if investor_ids else []
-
-    def _extract_filters_from_query(self, question: str) -> dict:
-        """Extract filter criteria from natural language query."""
-        question_lower = question.lower()
-        filters = {}
-
-        # Extract stage filters
-        if any(
-            stage in question_lower
-            for stage in [
-                "seed",
-                "series a",
-                "series b",
-                "series c",
-                "growth",
-                "late stage",
-            ]
-        ):
-            if "seed" in question_lower:
-                filters["stage"] = "SEED"
-            elif "series a" in question_lower:
-                filters["stage"] = "SERIES_A"
-            elif "series b" in question_lower:
-                filters["stage"] = "SERIES_B"
-            elif "series c" in question_lower:
-                filters["stage"] = "SERIES_C"
-            elif "growth" in question_lower:
-                filters["stage"] = "GROWTH"
-            elif "late stage" in question_lower:
-                filters["stage"] = "LATE_STAGE"
-
-        # Extract geographic filters
-        if any(
-            geo in question_lower
-            for geo in [
-                "us",
-                "usa",
-                "united states",
-                "europe",
-                "asia",
-                "silicon valley",
-                "bay area",
-            ]
-        ):
-            if (
-                "us" in question_lower
-                or "usa" in question_lower
-                or "united states" in question_lower
-            ):
-                filters["geography"] = "US"
-            elif "europe" in question_lower:
-                filters["geography"] = "Europe"
-            elif "asia" in question_lower:
-                filters["geography"] = "Asia"
-            elif "silicon valley" in question_lower or "bay area" in question_lower:
-                filters["geography"] = "Silicon Valley"
-
-        # Extract sector filters
-        sectors = [
-            "fintech",
-            "healthcare",
-            "saas",
-            "ai",
-            "biotech",
-            "consumer",
-            "enterprise",
-            "crypto",
-            "blockchain",
-        ]
-        for sector in sectors:
-            if sector in question_lower:
-                filters["sector"] = sector
-                break
-
-        # Extract check size filters (simple patterns)
+    def _extract_investor_ids_from_response(self, ai_response: str) -> List[int]:
+        """Extract investor IDs from AI response."""
         import re
 
-        amounts = re.findall(
-            r"\$?(\d+(?:,\d{3})*(?:\.\d+)?)\s*(?:million|m|k|thousand)", question_lower
-        )
-        if amounts:
-            amount = amounts[0].replace(",", "")
-            if "million" in question_lower or "m" in question_lower:
-                filters["min_check_size"] = int(float(amount) * 1000000)
-            elif "thousand" in question_lower or "k" in question_lower:
-                filters["min_check_size"] = int(float(amount) * 1000)
+        investor_ids = []
+        try:
+            # Try multiple patterns to extract IDs from the response
+            # Pattern 1: Simple numbers (assuming they are IDs)
+            numbers = re.findall(r"\b\d+\b", ai_response)
+            investor_ids = [int(num) for num in numbers]
 
-        return filters
+            # Pattern 2: If response contains explicit ID references
+            id_matches = re.findall(r"\bid[:\s]*(\d+)", ai_response.lower())
+            if id_matches:
+                investor_ids = [int(id_str) for id_str in id_matches]
 
-    def _fetch_investors_with_relationships(
-        self, investor_ids: List[int] = None, filters: dict = None
-    ) -> InvestorList:
-        """Fetch investors with all their relationships from the database."""
-        if not self.sql_session:
+        except Exception as e:
+            print(f"Error extracting IDs from response: {e}")
+            return []
+
+        return investor_ids
+
+    def _fetch_investors_by_ids(self, investor_ids: List[int]) -> InvestorList:
+        """Fetch investors with all their relationships from the database using IDs."""
+        if not investor_ids:
             return InvestorList(investors=[])
 
-        # Import here to avoid circular imports
-        from db.models import SectorTable
+        # Get database session
+        db_session = next(get_db())
 
-        # Build query with all relationships loaded
-        query = self.sql_session.query(InvestorTable).options(
-            selectinload(InvestorTable.portfolio_companies),
-            selectinload(InvestorTable.team_members),
-            selectinload(InvestorTable.sectors),
-        )
-
-        # Apply filters if provided
-        if filters:
-            if "stage" in filters:
-                from db.models import InvestmentStage
-
-                stage_enum = getattr(InvestmentStage, filters["stage"])
-                query = query.filter(InvestorTable.stage_focus == stage_enum)
-
-            if "geography" in filters:
-                query = query.filter(
-                    InvestorTable.geographic_focus.ilike(f"%{filters['geography']}%")
+        try:
+            # Build query with all relationships loaded
+            query = (
+                db_session.query(InvestorTable)
+                .options(
+                    selectinload(InvestorTable.portfolio_companies),
+                    selectinload(InvestorTable.team_members),
+                    selectinload(InvestorTable.sectors),
                 )
-
-            if "min_check_size" in filters:
-                query = query.filter(
-                    InvestorTable.check_size_lower >= filters["min_check_size"]
-                )
-
-            if "max_check_size" in filters:
-                query = query.filter(
-                    InvestorTable.check_size_upper <= filters["max_check_size"]
-                )
-
-            if "min_aum" in filters:
-                query = query.filter(InvestorTable.aum >= filters["min_aum"])
-
-            if "max_aum" in filters:
-                query = query.filter(InvestorTable.aum <= filters["max_aum"])
-
-            if "sector" in filters:
-                query = query.join(InvestorTable.sectors).filter(
-                    SectorTable.name.ilike(f"%{filters['sector']}%")
-                )
-
-        # Filter by IDs if provided
-        if investor_ids:
-            query = query.filter(InvestorTable.id.in_(investor_ids))
-        else:
-            # If no specific IDs and no filters, limit to prevent overwhelming response
-            if not filters:
-                query = query.limit(10)
-
-        investors = query.all()
-
-        # Transform to InvestorData format
-        investor_data_list = []
-        for investor in investors:
-            investor_data = InvestorData(
-                investor=investor,
-                portfolio_companies=investor.portfolio_companies,
-                team_members=investor.team_members,
-                sectors=investor.sectors,
+                .filter(InvestorTable.id.in_(investor_ids))
             )
-            investor_data_list.append(investor_data)
 
-        return InvestorList(investors=investor_data_list)
+            investors = query.all()
+
+            # Transform to InvestorData format
+            investor_data_list = []
+            for investor in investors:
+                investor_data = InvestorData(
+                    investor=investor,
+                    portfolio_companies=investor.portfolio_companies,
+                    team_members=investor.team_members,
+                    sectors=investor.sectors,
+                )
+                investor_data_list.append(investor_data)
+
+            return InvestorList(investors=investor_data_list)
+
+        finally:
+            db_session.close()
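The new `_extract_investor_ids_from_response` treats every number in the model's reply as an ID, so a reply such as "Top 5 investors: 1, 2" would also yield 5. One possible stricter variant is sketched below. It assumes the agent obeys the comma-separated format requested in the system prompt and accepts the reply only when the whole string matches that format. `parse_id_list` is a hypothetical helper, not part of this commit.

```python
import re
from typing import List

# The whole reply must be integers separated by commas, e.g. "1, 5, 12, 23".
_ID_LIST = re.compile(r"\s*\d+(?:\s*,\s*\d+)*\s*")


def parse_id_list(ai_response: str) -> List[int]:
    """Return the IDs in a strictly comma-separated reply, else an empty list."""
    if not _ID_LIST.fullmatch(ai_response):
        return []  # reply contains prose or other text; do not guess

    ids: List[int] = []
    for token in ai_response.split(","):
        value = int(token)  # int() tolerates surrounding whitespace
        if value not in ids:  # drop duplicates, keep first-seen order
            ids.append(value)
    return ids


print(parse_id_list("1, 5, 12, 23"))
print(parse_id_list("Top 5 investors: 1, 2"))
```

Returning an empty list for a malformed reply trades recall for precision: the caller gets no investors instead of possibly unrelated ones, which may or may not suit this API.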
@@ -1,11 +0,0 @@
-from pydantic_settings import BaseSettings
-
-
-class Settings(BaseSettings):
-    OPENROUTER_API_KEY: str
-
-    class Config:
-        env_file = ".env"
-
-
-settings = Settings()