2025-08-29 18:42:55 +01:00
|
|
|
import logging
from typing import Optional

import chromadb
from langchain import hub
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent

from pydantic_schemas import Investor, InvestorList
from settings import settings
|
|
|
|
|
|
2025-09-02 12:22:50 +01:00
|
|
|
# Fetch the SQL-agent system prompt template from the LangChain hub.
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")

# Connect to SQLite (the local investors.db file).
db = SQLDatabase.from_uri("sqlite:///investors.db")

# Render the template for the SQLite dialect with top_k=5, then append the
# extra instruction telling the agent to consult both data stores.
system_message = "".join(
    (
        prompt_template.format(dialect="SQLite", top_k=5),
        "\n Get answers from the Sql database and the vector database",
    )
)
|
2025-08-29 18:42:55 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class QueryProcessor:
    """Answer investor questions with an LLM agent over SQL and vector stores.

    The agent is a LangGraph ReAct agent whose tools are the SQL toolkit built
    on the module-level ``db`` plus :meth:`query_vector_database`. It is
    created with ``response_format=InvestorList`` so that each run ends with a
    structured ``InvestorList``.
    """

    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        sql_session: Optional[object] = None,
        vector_db_client: Optional[object] = None,
    ):
        """Create the LLM, the vector collection and the agent.

        Args:
            sql_session: Optional session object used by
                :meth:`query_sql_database`. When omitted, that method
                returns ``None``.
            vector_db_client: Optional ChromaDB-compatible client. When
                omitted, a persistent client stored under ``./chroma_db`` is
                created.
        """
        # Store the session: it used to be dropped, which made
        # query_sql_database fail with AttributeError on first use.
        self.sql_session = sql_session

        self.llm = ChatOpenAI(
            api_key=settings.OPENROUTER_API_KEY,
            base_url="https://openrouter.ai/api/v1",
            model="google/gemini-2.5-flash-lite",
            temperature=0,
        )

        # Honour an injected client; only fall back to the on-disk store when
        # the caller did not supply one (the argument used to be overwritten).
        if vector_db_client is None:
            vector_db_client = chromadb.PersistentClient(path="./chroma_db")
        self.vector_db_client = vector_db_client
        self.collection = self.vector_db_client.get_or_create_collection(
            name="investor_descriptions",
            metadata={
                "description": "Investor descriptions and investment thesis focus"
            },
        )

        # Build the agent last so every attribute its tools rely on
        # (notably self.collection) already exists.
        self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
        self.agent = create_react_agent(
            model=self.llm,
            tools=self.toolkit.get_tools() + [self.query_vector_database],
            prompt=system_message,
            response_format=InvestorList,
        )

    def query_sql_database(self, query: str) -> Optional[InvestorList]:
        """Query the SQL database for investor information.

        Args:
            query: Statement handed unchanged to ``sql_session.execute``.

        Returns:
            An ``InvestorList`` built from the scalar rows of the result, or
            ``None`` when no SQL session was configured.
        """
        if self.sql_session is None:
            return None

        # NOTE(review): the statement is executed verbatim, so it must come
        # from trusted code. If the session is a SQLAlchemy 2.x session it may
        # also require a text() construct rather than a plain str -- confirm
        # against the session type actually passed in.
        result = self.sql_session.execute(query)
        investors = result.scalars().all()
        return InvestorList(investors=investors)

    def query_vector_database(self, query: str) -> Optional[dict]:
        """Query the vector database for investor information.

        Args:
            query: Free-text search string matched against the stored
                investor descriptions.

        Returns:
            The raw ChromaDB query result (with 'documents', 'metadatas',
            'ids' and 'distances' entries) for the top 3 matches, or ``None``
            when no vector client is configured.
        """
        # NOTE: this method is registered as an agent tool in __init__;
        # LangChain presumably exposes the docstring above to the model as the
        # tool description, so keep it accurate.
        if self.vector_db_client is None:
            return None

        self._logger.debug("Vector store queried with %r", query)

        # ChromaDB expects a list of query texts, even for a single query.
        results = self.collection.query(
            query_texts=[query],
            n_results=3,
        )
        self._logger.debug("Vector store results: %s", results)
        return results

    def process_query(self, question: str) -> InvestorList:
        """Process a query using the LLM and return structured investor data.

        Args:
            question: Natural-language question from the user.

        Returns:
            The ``InvestorList`` produced by the agent's structured-output
            step.

        Raises:
            KeyError: If the agent's final state has no
                ``"structured_response"`` entry.
        """
        response = self.agent.invoke(
            {"messages": [("user", question)]},
        )
        # invoke() returns the whole final graph state; because the agent was
        # built with response_format=InvestorList, the parsed model is stored
        # under "structured_response" rather than being the state itself.
        return response["structured_response"]
|