feat: Implement company querying functionality with natural language processing and logging
This commit is contained in:
Binary file not shown.
+25
-1
@@ -1,4 +1,6 @@
|
|||||||
import io
|
import io
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from db.db import Base, db_dependency, engine
|
from db.db import Base, db_dependency, engine
|
||||||
@@ -13,7 +15,8 @@ from routers import (
|
|||||||
projects,
|
projects,
|
||||||
report_route,
|
report_route,
|
||||||
)
|
)
|
||||||
from schemas.router_schemas import InvestmentResponse, PaginatedResponse
|
from schemas.router_schemas import CompanyData, InvestmentResponse, PaginatedResponse
|
||||||
|
from services.company_querying import CompanyQueryProcessor
|
||||||
from services.llm_parser import InvestorProcessor
|
from services.llm_parser import InvestorProcessor
|
||||||
from services.querying import QueryProcessor
|
from services.querying import QueryProcessor
|
||||||
|
|
||||||
@@ -114,6 +117,27 @@ async def query_investors(request: QueryRequest):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
@app.post(
    "/query-companies", response_model=PaginatedResponse[CompanyData], tags=["Querying"]
)
async def query_companies(request: QueryRequest):
    """
    Query companies using natural language.

    Returns company matches with their investor relationships and sectors.

    Supports queries like:
    - "Show me fintech companies founded in 2020"
    - "Find healthcare companies in San Francisco"
    - "Companies in the AI sector"
    - "Companies that received funding from Sequoia"
    - "European startups founded after 2019"
    """
    # Imported locally so this handler stays self-contained.
    from fastapi.concurrency import run_in_threadpool

    def _answer():
        # Building the processor and running the agent are both synchronous
        # (LLM round-trips plus SQL), so they must not run on the event loop.
        return CompanyQueryProcessor().process_query(request.question)

    # Off-load the blocking work to the threadpool; the handler remains a
    # coroutine, so anything awaiting it keeps working.
    return await run_in_threadpool(_answer)
|
||||||
|
|
||||||
|
|
||||||
app.include_router(investors.router)
|
app.include_router(investors.router)
|
||||||
app.include_router(companies.router)
|
app.include_router(companies.router)
|
||||||
app.include_router(projects.router)
|
app.include_router(projects.router)
|
||||||
|
|||||||
@@ -1,15 +1,21 @@
|
|||||||
|
import os
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from db.db import get_db
|
from db.db import get_db
|
||||||
from db.models import InvestorTable
|
from db.models import InvestorTable
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from services.crm import folk
|
from services.crm import FolkAPI
|
||||||
from sqlalchemy.orm import Session, selectinload
|
from sqlalchemy.orm import Session, selectinload
|
||||||
|
|
||||||
router = APIRouter(prefix="/folk", tags=["Folk CRM"])
|
router = APIRouter(prefix="/folk", tags=["Folk CRM"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_folk_client():
    """Build a Folk API client keyed from the ``FOLK_API_KEY`` environment variable.

    When the variable is unset, an empty string is passed as the key.
    """
    folk_api_key = os.environ.get("FOLK_API_KEY", "")
    return FolkAPI(api_key=folk_api_key)
|
||||||
|
|
||||||
|
|
||||||
class GroupResponse(BaseModel):
|
class GroupResponse(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
@@ -44,6 +50,7 @@ def get_folk_groups():
|
|||||||
to sync investors to Folk.
|
to sync investors to Folk.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
folk = get_folk_client()
|
||||||
groups_data = folk.get_groups()
|
groups_data = folk.get_groups()
|
||||||
items = groups_data.get("data", {}).get("items", [])
|
items = groups_data.get("data", {}).get("items", [])
|
||||||
|
|
||||||
@@ -71,6 +78,7 @@ def sync_investors_to_folk(
|
|||||||
Returns:
|
Returns:
|
||||||
Summary of sync operation including successes and errors
|
Summary of sync operation including successes and errors
|
||||||
"""
|
"""
|
||||||
|
folk = get_folk_client()
|
||||||
# Fetch investors with their team members
|
# Fetch investors with their team members
|
||||||
investors = (
|
investors = (
|
||||||
db.query(InvestorTable)
|
db.query(InvestorTable)
|
||||||
|
|||||||
Binary file not shown.
@@ -168,6 +168,7 @@ class InvestorFundData(BaseModel):
|
|||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class InvestorMinimal(BaseModel):
|
class InvestorMinimal(BaseModel):
|
||||||
"""Minimal investor info with just id and name"""
|
"""Minimal investor info with just id and name"""
|
||||||
|
|
||||||
@@ -177,6 +178,7 @@ class InvestorMinimal(BaseModel):
|
|||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class CompanySchemaMinimal(BaseModel):
|
class CompanySchemaMinimal(BaseModel):
|
||||||
id: int
|
id: int
|
||||||
name: str
|
name: str
|
||||||
@@ -188,9 +190,12 @@ class CompanySchemaMinimal(BaseModel):
|
|||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class CompanyData(BaseModel): # Renamed from CompaniesData for consistency
|
class CompanyData(BaseModel): # Renamed from CompaniesData for consistency
|
||||||
company: CompanySchemaMinimal
|
company: CompanySchemaMinimal
|
||||||
investors: List[InvestorMinimal]
|
investors: List[InvestorMinimal]
|
||||||
|
# members: List[CompanyMemberSchema] = []
|
||||||
|
sectors: List[SectorSchema] = []
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|||||||
@@ -0,0 +1,176 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from db.db import DATABASE_URL, get_db
|
||||||
|
from db.models import CompanyTable
|
||||||
|
from langchain import hub
|
||||||
|
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
||||||
|
from langchain_community.utilities import SQLDatabase
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langgraph.prebuilt import create_react_agent
|
||||||
|
from schemas.router_schemas import CompanyData, PaginatedResponse
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
# Pull the SQL-agent system prompt from LangChain Hub, then open the database at DATABASE_URL
|
||||||
|
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
|
||||||
|
db = SQLDatabase.from_uri(DATABASE_URL)
|
||||||
|
|
||||||
|
|
||||||
|
class CompanyQueryProcessor:
    """Answer natural-language company questions through an LLM-driven SQL agent.

    The agent is instructed to run SQL and report matching company IDs. Those
    IDs are then used to load the full company records (with their related
    investors, members and sectors) through the ORM.
    """

    # Name of the toolkit tool that actually executes SQL. Only its output is
    # trusted as "query results"; schema/list/checker tool output is ignored.
    # NOTE(review): matches LangChain's SQLDatabaseToolkit naming - confirm on upgrade.
    _QUERY_TOOL_NAME = "sql_db_query"

    # The free-text fallback discards numbers at or above this bound, since
    # they are more likely tokens or timestamps than company IDs.
    _MAX_FALLBACK_ID = 100000

    # Page size reported when there is nothing to return.
    _DEFAULT_PAGE_SIZE = 10

    def __init__(self) -> None:
        """Create the chat model, the SQL toolkit and the ReAct agent."""
        self.llm = ChatOpenAI(
            api_key=os.getenv("OPENROUTER_API_KEY"),
            base_url="https://openrouter.ai/api/v1",
            model="openai/gpt-4o-mini",
            temperature=0,
        )
        self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
        # Extend the stock system prompt so the agent answers with company IDs only.
        system_message_updated = (
            prompt_template.format(dialect="SQLite", top_k=5)
            + "\n\n=== CRITICAL INSTRUCTIONS ==="
            + "\n- Your ONLY task is to run SQL queries and extract company IDs"
            + "\n- When you get SQL results with company IDs, return them EXACTLY as shown"
            + "\n- If the SQL query returns rows with company IDs like [(1,), (5,), (9,)], return all those IDs"
            + "\n- Do NOT add any explanations, just list the IDs"
            + "\n- If a query returns NO ROWS (empty result), then respond with 'NO_RESULTS'"
            + "\n\n=== QUERY GUIDELINES ==="
            + "\n1. For sector searches: SELECT companies.id FROM companies JOIN company_sector ON companies.id = company_sector.company_id JOIN sectors ON company_sector.sector_id = sectors.id WHERE sectors.name LIKE '%sector_name%'"
            + "\n2. For industry searches: WHERE companies.industry LIKE '%search_term%'"
            + "\n3. For location searches: WHERE companies.location LIKE '%location%'"
            + "\n4. For founding year searches: WHERE companies.founded_year >= year"
            + "\n5. For investor-related: JOIN investor_companies table"
        )
        self.agent = create_react_agent(
            model=self.llm,
            tools=self.toolkit.get_tools(),
            prompt=system_message_updated,
        )

    def process_query(self, question: str) -> "PaginatedResponse[CompanyData]":
        """Process a query using the LLM and return company response data.

        Args:
            question: The natural language query to process

        Returns:
            A single-page response holding every company the agent matched.
        """
        response = self.agent.invoke(
            {"messages": [("user", question)]},
            config={"recursion_limit": 50},
        )
        logger.info("%s", response)

        messages = response["messages"]

        # Prefer the raw rows of the last executed SQL query: they are the
        # ground truth, whereas the model's prose may paraphrase or truncate.
        company_ids = self._ids_from_query_results(messages)
        if company_ids:
            logger.info("Extracted %d company IDs from results", len(company_ids))
        else:
            # No usable rows (empty result, error, or unexpected row shape):
            # fall back to whatever the agent said in its final answer.
            final_message_content = messages[-1].content
            logger.info("AI Response: \n%s", final_message_content)
            company_ids = self._extract_company_ids_from_response(
                final_message_content
            )

        return self._fetch_companies_by_ids(company_ids)

    @classmethod
    def _ids_from_query_results(cls, messages) -> List[int]:
        """Return the IDs found in the most recent SQL-query tool message.

        Only messages produced by the SQL execution tool are considered, so
        the user's question and schema dumps (e.g. ``VARCHAR(255)``) can never
        be mistaken for result rows. Returns an empty list when there is no
        such message or it holds no ``(id,)`` rows.
        """
        for message in reversed(messages):
            if getattr(message, "type", None) != "tool":
                continue
            if getattr(message, "name", None) != cls._QUERY_TOOL_NAME:
                continue
            content = getattr(message, "content", None)
            return cls._extract_tuple_ids("" if content is None else str(content))
        return []

    @staticmethod
    def _extract_tuple_ids(text: str) -> List[int]:
        """Return the integers of every single-column row such as ``(1,)`` in *text*."""
        import re

        return [int(match) for match in re.findall(r"\((\d+),?\)", text)]

    def _extract_company_ids_from_response(self, ai_response: str) -> List[int]:
        """Extract company IDs from AI response.

        Args:
            ai_response: Text of a message; non-string content is stringified
                and ``None`` is treated as empty.

        Returns:
            IDs from ``(n,)`` tuples when any are present, otherwise every
            standalone number below the fallback bound. Empty when the text
            contains the ``NO_RESULTS`` marker (case-insensitive).
        """
        import re

        if ai_response is None:
            return []
        text = ai_response if isinstance(ai_response, str) else str(ai_response)

        if "NO_RESULTS" in text.upper():
            return []

        tuple_ids = self._extract_tuple_ids(text)
        if tuple_ids:
            return tuple_ids

        # Fallback for plain listings such as "1, 5, 9".
        return [
            int(token)
            for token in re.findall(r"\b\d+\b", text)
            if int(token) < self._MAX_FALLBACK_ID
        ]

    @classmethod
    def _empty_page(cls) -> "PaginatedResponse[CompanyData]":
        """Return the response used whenever no company can be returned."""
        return PaginatedResponse(
            items=[],
            total=0,
            page=1,
            page_size=cls._DEFAULT_PAGE_SIZE,
            total_pages=0,
        )

    def _fetch_companies_by_ids(
        self, company_ids: List[int]
    ) -> "PaginatedResponse[CompanyData]":
        """Fetch companies with all their relationships from the database using company IDs.

        Args:
            company_ids: List of company IDs to fetch

        Returns:
            One page containing every matching company, or the empty page when
            no IDs were given or none of them exist.
        """
        if not company_ids:
            return self._empty_page()

        db_session = next(get_db())
        try:
            # Eager-load the relationships so building the response does not
            # issue one extra query per company.
            companies = (
                db_session.query(CompanyTable)
                .options(
                    selectinload(CompanyTable.investors),
                    selectinload(CompanyTable.members),
                    selectinload(CompanyTable.sectors),
                )
                .filter(CompanyTable.id.in_(company_ids))
                .all()
            )
            if not companies:
                # IDs reported by the agent that do not exist in the table.
                return self._empty_page()

            # NOTE(review): `members` is still passed as in the original; confirm
            # whether CompanyData is meant to expose it.
            company_data_list = [
                CompanyData(
                    company=company,
                    investors=company.investors,
                    members=company.members,
                    sectors=company.sectors,
                )
                for company in companies
            ]
            total_count = len(company_data_list)
            return PaginatedResponse(
                items=company_data_list,
                total=total_count,
                page=1,
                page_size=total_count,
                total_pages=1,
            )
        finally:
            db_session.close()
|
||||||
+10
-68
@@ -1,14 +1,24 @@
|
|||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
|
handlers=[logging.StreamHandler()],
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class FolkAPI:
|
class FolkAPI:
|
||||||
BASE_URL = "https://api.folk.app/v1"
|
BASE_URL = "https://api.folk.app/v1"
|
||||||
|
|
||||||
def __init__(self, api_key: str):
    """Store the authorization header used for Folk API requests.

    Args:
        api_key: Key to use when the ``FOLK_API_KEY`` environment variable
            is not set; the environment value wins when present.
    """
    resolved_key = os.environ.get("FOLK_API_KEY", api_key)
    self.headers = {"Authorization": f"Bearer {resolved_key}"}
    # Only the first four characters are written to the log.
    key_prefix = resolved_key[:4]
    logger.info(f"FolkAPI initialized with API key: {key_prefix}***")
|
||||||
|
|
||||||
def get_groups(self):
|
def get_groups(self):
|
||||||
"""Fetch all groups from Folk."""
|
"""Fetch all groups from Folk."""
|
||||||
@@ -190,71 +200,3 @@ class FolkAPI:
|
|||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
# Prefer getting the API key from the environment. If not set, fall back to the
|
|
||||||
# existing (hard-coded) key so behavior is unchanged for now.
|
|
||||||
DEFAULT_API_KEY = "FOLKfIGXuv74ML9EAajxyiUR39ePaNrZ"
|
|
||||||
api_key = os.environ.get("FOLK_API_KEY", DEFAULT_API_KEY)
|
|
||||||
|
|
||||||
folk = FolkAPI(api_key=api_key)
|
|
||||||
|
|
||||||
|
|
||||||
def example_flow():
|
|
||||||
# Step 1: Get groups
|
|
||||||
groups = folk.get_groups()
|
|
||||||
print(groups)
|
|
||||||
|
|
||||||
# Safely dig into the returned structure. The API returns groups under
|
|
||||||
# groups['data']['items'] (not groups['data'][0]). Handle missing/empty.
|
|
||||||
items = groups.get("data", {}).get("items", [])
|
|
||||||
if not items:
|
|
||||||
print("No groups returned by Folk API.")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# Choose the first group as an example
|
|
||||||
group_id = items[0].get("id")
|
|
||||||
if not group_id:
|
|
||||||
print("No id found for the first group item.")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# Step 2: Choose a group_id and create a company
|
|
||||||
company = folk.create_company(
|
|
||||||
name="2050 Investment Partners",
|
|
||||||
group_id=group_id,
|
|
||||||
website="https://2050.com",
|
|
||||||
linkedin_url="https://linkedin.com/company/2050-investments",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Step 3: Add a person to the same group or company
|
|
||||||
person = folk.create_person(
|
|
||||||
first_name="John",
|
|
||||||
last_name="Doe",
|
|
||||||
email="john@2050.com",
|
|
||||||
company_id=company.get("data", {}).get("id"),
|
|
||||||
group_id=group_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
print("Created company:", company)
|
|
||||||
print("Created person:", person)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
try:
|
|
||||||
example_flow()
|
|
||||||
except requests.HTTPError as e:
|
|
||||||
# Try to include response body for easier debugging if available
|
|
||||||
resp = getattr(e, "response", None)
|
|
||||||
if resp is not None:
|
|
||||||
try:
|
|
||||||
body = resp.text
|
|
||||||
except Exception:
|
|
||||||
body = "<unreadable response body>"
|
|
||||||
print("HTTP error while talking to Folk API:", e)
|
|
||||||
print("Response status:", resp.status_code)
|
|
||||||
print("Response body:", body)
|
|
||||||
else:
|
|
||||||
print("HTTP error while talking to Folk API:", e)
|
|
||||||
sys.exit(1)
|
|
||||||
except Exception as e: # pragma: no cover - top-level safety
|
|
||||||
print("Unexpected error:", e)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user