Files
boladeE 859c17aad8 feat: Implement Pinecone vector store integration
- Update config.py with Pinecone settings and model configurations
- Implement VectorStore class with Pinecone backend
- Add comprehensive vector operations (add, search, delete)
- Set up proper error handling and metadata management
- Add .gitignore for Python project
2025-04-16 23:09:52 +01:00

101 lines
3.8 KiB
Python

import requests
import json
import argparse
def test_direct_model(prompt, max_length=200, num_return_sequences=1, temperature=0.7, top_p=0.9,
                      base_url="http://localhost:8000", timeout=300):
    """Send a prompt to the ``/direct-model`` endpoint and print every generation.

    Args:
        prompt: Text prompt forwarded to the model as ``prompt``.
        max_length: Sent as ``max_length`` in the JSON payload.
        num_return_sequences: Sent as ``num_return_sequences``.
        temperature: Sampling temperature, sent as ``temperature``.
        top_p: Nucleus-sampling threshold, sent as ``top_p``.
        base_url: Scheme/host/port of the backend; ``/direct-model`` is
            appended to it. Defaults to the local development server.
        timeout: Seconds ``requests`` waits while connecting and between
            received bytes. ``None`` waits forever (the previous behaviour).

    Raises:
        requests.RequestException: Connection failures and timeouts are left
            to propagate to the caller.
        KeyError: If a 200 response lacks a ``generated_texts`` field.
    """
    url = f"{base_url.rstrip('/')}/direct-model"
    payload = {
        "prompt": prompt,
        "max_length": max_length,
        "num_return_sequences": num_return_sequences,
        "temperature": temperature,
        "top_p": top_p,
    }
    print(f"Sending request to {url}")
    print(f"Payload: {json.dumps(payload, indent=2)}")
    # Without a timeout, requests.post blocks indefinitely if the server stalls.
    # The default is generous because local model generation can be slow.
    response = requests.post(url, json=payload, timeout=timeout)
    if response.status_code != 200:
        print(f"Error: {response.status_code}")
        print(response.text)
        return

    result = response.json()
    print("\nResponse:")
    for index, text in enumerate(result["generated_texts"], start=1):
        print(f"\n--- Generation {index} ---")
        print(text)
def test_generate_copy(prompt, content_type, tone=None, target_audience=None,
                       base_url="http://localhost:8000", timeout=300):
    """Request marketing copy from the ``/generate-copy`` endpoint and print it.

    Args:
        prompt: Text prompt sent as ``prompt``.
        content_type: Sent as ``content_type`` (e.g. ``"email"``).
        tone: Optional; included in the payload only when truthy.
        target_audience: Optional; included in the payload only when truthy.
        base_url: Scheme/host/port of the backend; ``/generate-copy`` is
            appended to it. Defaults to the local development server.
        timeout: Seconds ``requests`` waits while connecting and between
            received bytes. ``None`` waits forever (the previous behaviour).

    Raises:
        requests.RequestException: Connection failures and timeouts are left
            to propagate to the caller.
        KeyError: If a 200 response lacks ``content``, ``confidence_score``
            or ``brand_alignment_score``.
    """
    url = f"{base_url.rstrip('/')}/generate-copy"
    payload = {
        "prompt": prompt,
        "content_type": content_type,
    }
    # Unset optional fields are omitted rather than sent as null, presumably so
    # the server can apply its own defaults.
    if tone:
        payload["tone"] = tone
    if target_audience:
        payload["target_audience"] = target_audience
    print(f"Sending request to {url}")
    print(f"Payload: {json.dumps(payload, indent=2)}")
    # Without a timeout, requests.post blocks indefinitely if the server stalls.
    # The default is generous because local model generation can be slow.
    response = requests.post(url, json=payload, timeout=timeout)
    if response.status_code != 200:
        print(f"Error: {response.status_code}")
        print(response.text)
        return

    result = response.json()
    print("\nResponse:")
    print(f"Content: {result['content']}")
    print(f"Confidence Score: {result['confidence_score']}")
    print(f"Brand Alignment Score: {result['brand_alignment_score']}")
def main():
    """Parse command-line options and exercise the chosen backend endpoint.

    ``--endpoint direct-model`` (the default) runs ``test_direct_model`` with
    the sampling options. ``--endpoint generate-copy`` runs
    ``test_generate_copy`` with the copywriting options. Options that do not
    apply to the selected endpoint are parsed but ignored.
    """
    cli = argparse.ArgumentParser(description="Test the backend with the finetuned model")
    cli.add_argument(
        "--endpoint",
        type=str,
        choices=["direct-model", "generate-copy"],
        default="direct-model",
        help="Endpoint to test",
    )
    # (flag, type, default, help) for the remaining options, in --help order.
    option_specs = (
        ("--prompt", str, "Create a welcome message for new clients",
         "Prompt to generate text for"),
        ("--content-type", str, "email",
         "Content type (for generate-copy endpoint)"),
        ("--tone", str, None,
         "Tone (for generate-copy endpoint)"),
        ("--target-audience", str, None,
         "Target audience (for generate-copy endpoint)"),
        ("--max-length", int, 200,
         "Maximum length of the generated text (for direct-model endpoint)"),
        ("--num-return-sequences", int, 1,
         "Number of sequences to generate (for direct-model endpoint)"),
        ("--temperature", float, 0.7,
         "Temperature for sampling (for direct-model endpoint)"),
        ("--top-p", float, 0.9,
         "Top-p for sampling (for direct-model endpoint)"),
    )
    for flag, converter, fallback, description in option_specs:
        cli.add_argument(flag, type=converter, default=fallback, help=description)
    opts = cli.parse_args()

    if opts.endpoint != "direct-model":
        test_generate_copy(
            prompt=opts.prompt,
            content_type=opts.content_type,
            tone=opts.tone,
            target_audience=opts.target_audience,
        )
        return

    test_direct_model(
        prompt=opts.prompt,
        max_length=opts.max_length,
        num_return_sequences=opts.num_return_sequences,
        temperature=opts.temperature,
        top_p=opts.top_p,
    )
if __name__ == "__main__":
    # Start the CLI only when executed as a script, never on import.
    main()