First Commit
This commit is contained in:
@@ -14,22 +14,75 @@ import os
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from scripts.rag import get_answer, response_agent
|
from scripts.rag import get_calculations, response_agent,setup_rag_system
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
from fastapi import FastAPI, HTTPException, Security, Depends
|
||||||
|
from fastapi.security import APIKeyHeader
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
load_dotenv()
|
||||||
# Load environment variables from a .env file
|
# Load environment variables from a .env file
|
||||||
|
|
||||||
|
# Initialize OpenAI LLM and Embeddings
|
||||||
|
llm = ChatOpenAI(model="gpt-4o-mini") # Use the appropriate model
|
||||||
|
embedding_function = OpenAIEmbeddings()
|
||||||
|
|
||||||
|
# Set up the RAG system
|
||||||
|
file_path = r'docs/UpdateAI_training-company-data.md'
|
||||||
|
rag_chain = setup_rag_system(file_path, llm, embedding_function)
|
||||||
|
|
||||||
# Initialize FastAPI app
|
# Initialize FastAPI app
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
load_dotenv()
|
||||||
|
API_KEY = os.getenv("API_KEY")
|
||||||
|
|
||||||
|
# Initialize FastAPI app
|
||||||
|
app = FastAPI(
|
||||||
|
title="Update stack AI API",
|
||||||
|
description="API For fire Update stack",
|
||||||
|
version="1.0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add CORS middleware
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup API key authentication
|
||||||
|
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_api_key(api_key_header: str = Security(api_key_header)) -> str:
|
||||||
|
"""Validate API key from header"""
|
||||||
|
if not api_key_header or not api_key_header.startswith('Bearer '):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail={"error": "Unauthorized", "message": "API key is missing or invalid."}
|
||||||
|
)
|
||||||
|
|
||||||
|
token = api_key_header.split(' ')[1]
|
||||||
|
if token != API_KEY:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail={"error": "Unauthorized", "message": "API key does not match."}
|
||||||
|
)
|
||||||
|
|
||||||
|
return token
|
||||||
# Define a request model
|
# Define a request model
|
||||||
class QuestionRequest(BaseModel):
|
class QuestionRequest(BaseModel):
|
||||||
question: str
|
question: str
|
||||||
|
|
||||||
@app.post("/ask")
|
@app.post("/update-stack/v1/ask")
|
||||||
def ask_question(request: QuestionRequest):
|
def ask_question(request: QuestionRequest,
|
||||||
|
api_key: str = Depends(get_api_key)):
|
||||||
try:
|
try:
|
||||||
# Use the RAG system to get the answer
|
# Use the RAG system to get the answer
|
||||||
calculation_data = get_answer(rag_chain, request.question)
|
calculation_data = get_calculations(rag_chain, request.question)
|
||||||
final_response = response_agent(request.question, calculation_data)
|
final_response = response_agent(request.question, calculation_data)
|
||||||
return {"answer": final_response}
|
return {"answer": final_response}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
+1
-1
@@ -70,7 +70,7 @@ def setup_rag_system(file_path, llm, embedding_function):
|
|||||||
question_answer_chain = create_stuff_documents_chain(llm, prompt)
|
question_answer_chain = create_stuff_documents_chain(llm, prompt)
|
||||||
return create_retrieval_chain(retriever, question_answer_chain)
|
return create_retrieval_chain(retriever, question_answer_chain)
|
||||||
|
|
||||||
def get_answer(rag_chain, user_input):
|
def get_calculations(rag_chain, user_input):
|
||||||
answer = rag_chain.invoke({"input": user_input})
|
answer = rag_chain.invoke({"input": user_input})
|
||||||
return answer['answer']
|
return answer['answer']
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user