First Commit

This commit is contained in:
2025-03-06 17:32:42 +00:00
parent f97d731c82
commit 0c1f372e17
2 changed files with 58 additions and 5 deletions
+57 -4
View File
@@ -14,22 +14,75 @@ import os
from dotenv import load_dotenv from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from scripts.rag import get_answer, response_agent from scripts.rag import get_calculations, response_agent,setup_rag_system
import os
from typing import Optional
from fastapi import FastAPI, HTTPException, Security, Depends
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
load_dotenv()
# Load environment variables from a .env file # Load environment variables from a .env file
# Initialize OpenAI LLM and Embeddings
llm = ChatOpenAI(model="gpt-4o-mini") # Use the appropriate model
embedding_function = OpenAIEmbeddings()
# Set up the RAG system
file_path = r'docs/UpdateAI_training-company-data.md'
rag_chain = setup_rag_system(file_path, llm, embedding_function)
# Initialize FastAPI app # Initialize FastAPI app
app = FastAPI() app = FastAPI()
load_dotenv()
API_KEY = os.getenv("API_KEY")
# Initialize FastAPI app
app = FastAPI(
title="Update stack AI API",
description="API For fire Update stack",
version="1.0.0"
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Setup API key authentication
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
async def get_api_key(api_key_header: str = Security(api_key_header)) -> str:
"""Validate API key from header"""
if not api_key_header or not api_key_header.startswith('Bearer '):
raise HTTPException(
status_code=401,
detail={"error": "Unauthorized", "message": "API key is missing or invalid."}
)
token = api_key_header.split(' ')[1]
if token != API_KEY:
raise HTTPException(
status_code=401,
detail={"error": "Unauthorized", "message": "API key does not match."}
)
return token
# Define a request model # Define a request model
class QuestionRequest(BaseModel): class QuestionRequest(BaseModel):
question: str question: str
@app.post("/ask") @app.post("/update-stack/v1/ask")
def ask_question(request: QuestionRequest): def ask_question(request: QuestionRequest,
api_key: str = Depends(get_api_key)):
try: try:
# Use the RAG system to get the answer # Use the RAG system to get the answer
calculation_data = get_answer(rag_chain, request.question) calculation_data = get_calculations(rag_chain, request.question)
final_response = response_agent(request.question, calculation_data) final_response = response_agent(request.question, calculation_data)
return {"answer": final_response} return {"answer": final_response}
except Exception as e: except Exception as e:
+1 -1
View File
@@ -70,7 +70,7 @@ def setup_rag_system(file_path, llm, embedding_function):
question_answer_chain = create_stuff_documents_chain(llm, prompt) question_answer_chain = create_stuff_documents_chain(llm, prompt)
return create_retrieval_chain(retriever, question_answer_chain) return create_retrieval_chain(retriever, question_answer_chain)
def get_answer(rag_chain, user_input): def get_calculations(rag_chain, user_input):
answer = rag_chain.invoke({"input": user_input}) answer = rag_chain.invoke({"input": user_input})
return answer['answer'] return answer['answer']