diff --git a/app.py b/app.py index 89e6383..469d05e 100644 --- a/app.py +++ b/app.py @@ -14,22 +14,75 @@ import os from dotenv import load_dotenv from fastapi import FastAPI, HTTPException from pydantic import BaseModel -from scripts.rag import get_answer, response_agent +from scripts.rag import get_calculations, response_agent,setup_rag_system +import os +from typing import Optional +from fastapi import FastAPI, HTTPException, Security, Depends +from fastapi.security import APIKeyHeader +from fastapi.middleware.cors import CORSMiddleware +from dotenv import load_dotenv +load_dotenv() # Load environment variables from a .env file + # Initialize OpenAI LLM and Embeddings +llm = ChatOpenAI(model="gpt-4o-mini") # Use the appropriate model +embedding_function = OpenAIEmbeddings() + + # Set up the RAG system +file_path = r'docs/UpdateAI_training-company-data.md' +rag_chain = setup_rag_system(file_path, llm, embedding_function) # Initialize FastAPI app app = FastAPI() +load_dotenv() +API_KEY = os.getenv("API_KEY") +# Initialize FastAPI app +app = FastAPI( + title="Update stack AI API", + description="API For fire Update stack", + version="1.0.0" +) + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Setup API key authentication +api_key_header = APIKeyHeader(name="Authorization", auto_error=False) + + +async def get_api_key(api_key_header: str = Security(api_key_header)) -> str: + """Validate API key from header""" + if not api_key_header or not api_key_header.startswith('Bearer '): + raise HTTPException( + status_code=401, + detail={"error": "Unauthorized", "message": "API key is missing or invalid."} + ) + + token = api_key_header.split(' ')[1] + if token != API_KEY: + raise HTTPException( + status_code=401, + detail={"error": "Unauthorized", "message": "API key does not match."} + ) + + return token # Define a request model class QuestionRequest(BaseModel): question: str -@app.post("/ask") -def ask_question(request: QuestionRequest): +@app.post("/update-stack/v1/ask") +def ask_question(request: QuestionRequest, + api_key: str = Depends(get_api_key)): try: # Use the RAG system to get the answer - calculation_data = get_answer(rag_chain, request.question) + calculation_data = get_calculations(rag_chain, request.question) final_response = response_agent(request.question, calculation_data) return {"answer": final_response} except Exception as e: diff --git a/scripts/rag.py b/scripts/rag.py index ade2b3f..fc8faa5 100644 --- a/scripts/rag.py +++ b/scripts/rag.py @@ -70,7 +70,7 @@ def setup_rag_system(file_path, llm, embedding_function): question_answer_chain = create_stuff_documents_chain(llm, prompt) return create_retrieval_chain(retriever, question_answer_chain) -def get_answer(rag_chain, user_input): +def get_calculations(rag_chain, user_input): answer = rag_chain.invoke({"input": user_input}) return answer['answer']