First Commit
This commit is contained in:
@@ -14,22 +14,75 @@ import os
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from scripts.rag import get_answer, response_agent
|
||||
from scripts.rag import get_calculations, response_agent,setup_rag_system
|
||||
import os
|
||||
from typing import Optional
|
||||
from fastapi import FastAPI, HTTPException, Security, Depends
|
||||
from fastapi.security import APIKeyHeader
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
# Load environment variables from a .env file
|
||||
|
||||
# Initialize OpenAI LLM and Embeddings
|
||||
llm = ChatOpenAI(model="gpt-4o-mini") # Use the appropriate model
|
||||
embedding_function = OpenAIEmbeddings()
|
||||
|
||||
# Set up the RAG system
|
||||
file_path = r'docs/UpdateAI_training-company-data.md'
|
||||
rag_chain = setup_rag_system(file_path, llm, embedding_function)
|
||||
|
||||
# Initialize FastAPI app
|
||||
app = FastAPI()
|
||||
load_dotenv()
|
||||
API_KEY = os.getenv("API_KEY")
|
||||
|
||||
# Initialize FastAPI app
|
||||
app = FastAPI(
|
||||
title="Update stack AI API",
|
||||
description="API For fire Update stack",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Setup API key authentication
|
||||
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
||||
|
||||
|
||||
async def get_api_key(api_key_header: str = Security(api_key_header)) -> str:
|
||||
"""Validate API key from header"""
|
||||
if not api_key_header or not api_key_header.startswith('Bearer '):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={"error": "Unauthorized", "message": "API key is missing or invalid."}
|
||||
)
|
||||
|
||||
token = api_key_header.split(' ')[1]
|
||||
if token != API_KEY:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={"error": "Unauthorized", "message": "API key does not match."}
|
||||
)
|
||||
|
||||
return token
|
||||
# Define a request model
|
||||
class QuestionRequest(BaseModel):
|
||||
question: str
|
||||
|
||||
@app.post("/ask")
|
||||
def ask_question(request: QuestionRequest):
|
||||
@app.post("/update-stack/v1/ask")
|
||||
def ask_question(request: QuestionRequest,
|
||||
api_key: str = Depends(get_api_key)):
|
||||
try:
|
||||
# Use the RAG system to get the answer
|
||||
calculation_data = get_answer(rag_chain, request.question)
|
||||
calculation_data = get_calculations(rag_chain, request.question)
|
||||
final_response = response_agent(request.question, calculation_data)
|
||||
return {"answer": final_response}
|
||||
except Exception as e:
|
||||
|
||||
+1
-1
@@ -70,7 +70,7 @@ def setup_rag_system(file_path, llm, embedding_function):
|
||||
question_answer_chain = create_stuff_documents_chain(llm, prompt)
|
||||
return create_retrieval_chain(retriever, question_answer_chain)
|
||||
|
||||
def get_answer(rag_chain, user_input):
|
||||
def get_calculations(rag_chain, user_input):
|
||||
answer = rag_chain.invoke({"input": user_input})
|
||||
return answer['answer']
|
||||
|
||||
|
||||
Reference in New Issue
Block a user