94 lines
3.0 KiB
Python
94 lines
3.0 KiB
Python
import faiss
|
|
from langchain_community.docstore.in_memory import InMemoryDocstore
|
|
from langchain_community.vectorstores import FAISS
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
from langchain.chains import create_retrieval_chain
|
|
from langchain.chains.combine_documents import create_stuff_documents_chain
|
|
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
|
|
import openai
|
|
from openai import OpenAI
|
|
from langchain_community.document_loaders.csv_loader import CSVLoader
|
|
from pathlib import Path
|
|
from langchain_openai import ChatOpenAI,OpenAIEmbeddings
|
|
import os
|
|
from dotenv import load_dotenv
|
|
from fastapi import FastAPI, HTTPException
|
|
from pydantic import BaseModel
|
|
from scripts.rag import get_calculations, response_agent,setup_rag_system
|
|
import os
|
|
from typing import Optional
|
|
from fastapi import FastAPI, HTTPException, Security, Depends
|
|
from fastapi.security import APIKeyHeader
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from dotenv import load_dotenv
|
|
# Load environment variables (e.g. API_KEY) from a .env file before anything
# below reads them. One call is enough; a second call would not change values.
load_dotenv()

# Initialize the OpenAI LLM and embeddings used to build the RAG chain.
llm = ChatOpenAI(model="gpt-4o-mini")  # Use the appropriate model
embedding_function = OpenAIEmbeddings()

# Set up the RAG system once at import time; request handlers reuse rag_chain.
file_path = r'docs/UpdateAI_training-company-data.md'
rag_chain = setup_rag_system(file_path, llm, embedding_function)

# Shared secret for bearer-token authentication.
# os.getenv returns None when the variable is unset.
# NOTE(review): consider failing fast at startup when it is missing.
API_KEY = os.getenv("API_KEY")

# Initialize the FastAPI app exactly once, with its API metadata.
# (Previously a bare FastAPI() was created first and immediately discarded.)
app = FastAPI(
    title="Update stack AI API",
    description="API For fire Update stack",
    version="1.0.0",
)

# Add CORS middleware.
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# maximally permissive. Authentication here uses a bearer header rather than
# cookies, so confirm whether credentials are needed at all. Consider pinning
# an explicit origin list in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
# Setup API key authentication: read the raw "Authorization" header.
# auto_error=False passes a missing header through as None, so get_api_key
# can answer with its own structured 401 body.
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)


async def get_api_key(api_key_header: str = Security(api_key_header)) -> str:
    """Validate the bearer token sent in the ``Authorization`` header.

    The header must look like ``Bearer <token>``, and ``<token>`` must equal
    the module-level ``API_KEY``.

    Args:
        api_key_header: Raw ``Authorization`` header value, or ``None`` when
            the header is absent.

    Returns:
        The validated token.

    Raises:
        HTTPException: 401 when the header is missing or not a bearer token.
            Also 401 when the token is empty or does not match ``API_KEY``.
            An unset or empty ``API_KEY`` never matches.
    """
    import secrets  # local import keeps this block self-contained

    prefix = 'Bearer '
    if not api_key_header or not api_key_header.startswith(prefix):
        raise HTTPException(
            status_code=401,
            detail={"error": "Unauthorized", "message": "API key is missing or invalid."}
        )

    # Take everything after the scheme, instead of split(' ')[1], which
    # truncated tokens at the first space and produced '' on a double space.
    token = api_key_header[len(prefix):].strip()

    expected = API_KEY
    # An empty/unset API_KEY must not let "Bearer " through. compare_digest
    # compares in constant time; the operands are bytes so non-ASCII input
    # cannot raise TypeError.
    if (
        not token
        or not expected
        or not secrets.compare_digest(token.encode("utf-8"), expected.encode("utf-8"))
    ):
        raise HTTPException(
            status_code=401,
            detail={"error": "Unauthorized", "message": "API key does not match."}
        )

    return token
|
|
# Request body schema for the question-answering endpoint.
class QuestionRequest(BaseModel):
    """JSON payload carrying the user's question.

    Attributes:
        question: Free-form question text passed to the RAG pipeline.
    """

    question: str
|
|
|
|
@app.post("/update-stack/v1/ask")
def ask_question(request: QuestionRequest,
                 api_key: str = Depends(get_api_key)):
    """Answer a question using the RAG chain plus the response agent.

    Args:
        request: Body containing the ``question`` to answer.
        api_key: Result of ``get_api_key``. The dependency enforces
            authentication before this body runs; the value itself is unused.

    Returns:
        ``{"answer": <response_agent output>}``.

    Raises:
        HTTPException: Re-raised unchanged if raised by the pipeline.
            Any other failure becomes a 500 with a generic detail; the
            underlying error is logged, not sent to the client.
    """
    import logging  # local import keeps this block self-contained

    try:
        # Use the RAG system to gather data, then let the agent phrase the answer.
        calculation_data = get_calculations(rag_chain, request.question)
        final_response = response_agent(request.question, calculation_data)
    except HTTPException:
        # Already a deliberate HTTP error; don't mask its status as a 500.
        raise
    except Exception as e:
        # Log the full traceback server-side. Do not echo str(e) to the client,
        # since it may contain internal or upstream provider details.
        logging.getLogger(__name__).exception("Failed to answer question")
        raise HTTPException(
            status_code=500,
            detail="Internal server error while answering the question.",
        ) from e

    return {"answer": final_response}
|
|
|
|
|
|
if __name__ == "__main__":
    # Serve the API directly when this module is executed as a script,
    # listening on every IPv4 interface at port 5079.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=5079)