Add CompanyTable model and refactor query handling; update requirements for new dependencies
This commit is contained in:
+34
-12
@@ -1,13 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
import chromadb
|
||||
from langchain import hub
|
||||
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
||||
from langchain_community.utilities import SQLDatabase
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from pydantic_schemas import Investor, InvestorList
|
||||
from settings import settings
|
||||
|
||||
# Add these imports for your databases
|
||||
# from sqlalchemy.ext.asyncio import AsyncSession
|
||||
# from your_vector_db import VectorDBClient
|
||||
# Connect to SQLite
|
||||
|
||||
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
|
||||
db = SQLDatabase.from_uri("sqlite:///investors.db")
|
||||
system_message = (
|
||||
prompt_template.format(dialect="SQLite", top_k=5)
|
||||
+ "\n Get answers from the Sql database and the vector database"
|
||||
)
|
||||
|
||||
|
||||
class QueryProcessor:
|
||||
@@ -19,12 +28,16 @@ class QueryProcessor:
|
||||
self.llm = ChatOpenAI(
|
||||
api_key=settings.OPENROUTER_API_KEY,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="openai/gpt-oss-120b:free",
|
||||
model="google/gemini-2.5-flash-lite",
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
self.structured_llm = self.llm.with_structured_output(InvestorList)
|
||||
self.sql_session = sql_session
|
||||
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
|
||||
self.agent = create_react_agent(
|
||||
model=self.llm,
|
||||
tools=self.toolkit.get_tools() + [self.query_vector_database],
|
||||
prompt=system_message,
|
||||
response_format=InvestorList,
|
||||
)
|
||||
self.vector_db_client = vector_db_client
|
||||
|
||||
self.vector_db_client = chromadb.PersistentClient(path="./chroma_db")
|
||||
@@ -49,13 +62,22 @@ class QueryProcessor:
|
||||
"""Query the vector database for investor information."""
|
||||
if not self.vector_db_client:
|
||||
return None
|
||||
print("VECTOR STORE WAS CALLED")
|
||||
|
||||
# Implement vector database querying logic here
|
||||
results = self.vector_db_client.query(collection=self.collection, query=query)
|
||||
investors = [Investor(**doc.metadata) for doc in results.documents]
|
||||
return InvestorList(investors=investors)
|
||||
# Query the collection directly, not passing collection as parameter
|
||||
results = self.collection.query(
|
||||
query_texts=[query], # ChromaDB expects a list of query texts
|
||||
n_results=3, # Specify how many results you want
|
||||
)
|
||||
print(results)
|
||||
|
||||
# ChromaDB returns results in a different structure
|
||||
# results will have 'documents', 'metadatas', 'ids', 'distances'
|
||||
return results
|
||||
|
||||
def process_query(self, question: str) -> InvestorList:
|
||||
"""Process a query using the LLM and return structured investor data."""
|
||||
response = self.structured_llm.predict(question=question)
|
||||
response = self.agent.invoke(
|
||||
{"messages": [("user", question)]},
|
||||
)
|
||||
return response
|
||||
|
||||
Reference in New Issue
Block a user