# Commit 859c17aad8
# - Update config.py with Pinecone settings and model configurations
# - Implement VectorStore class with Pinecone backend
# - Add comprehensive vector operations (add, search, delete)
# - Set up proper error handling and metadata management
# - Add .gitignore for Python project
# 101 lines, 3.8 KiB, Python
import requests
|
|
import json
|
|
import argparse
|
|
|
|
def test_direct_model(prompt, max_length=200, num_return_sequences=1, temperature=0.7, top_p=0.9, timeout=120.0):
    """Test the direct model inference endpoint.

    Posts a JSON payload to ``http://localhost:8000/direct-model`` and prints
    every entry of the ``generated_texts`` list found in the JSON response.
    On a non-200 status, the status code and the raw response body are
    printed instead. If the request itself fails (connection refused,
    timeout, ...), the error is printed and the function returns.

    Args:
        prompt: Sent as the ``prompt`` field of the payload.
        max_length: Sent as the ``max_length`` field.
        num_return_sequences: Sent as the ``num_return_sequences`` field.
        temperature: Sent as the ``temperature`` field.
        top_p: Sent as the ``top_p`` field.
        timeout: Seconds handed to ``requests.post`` as its timeout. Pass
            ``None`` to wait indefinitely (the previous behaviour).

    Returns:
        None. All output goes to stdout.
    """
    url = "http://localhost:8000/direct-model"

    payload = {
        "prompt": prompt,
        "max_length": max_length,
        "num_return_sequences": num_return_sequences,
        "temperature": temperature,
        "top_p": top_p,
    }

    print(f"Sending request to {url}")
    print(f"Payload: {json.dumps(payload, indent=2)}")

    # Without a timeout requests waits forever on a stalled server; a failed
    # connection is reported the same way as an HTTP error (print and return).
    try:
        response = requests.post(url, json=payload, timeout=timeout)
    except requests.RequestException as exc:
        print(f"Error: request to {url} failed: {exc}")
        return

    if response.status_code != 200:
        print(f"Error: {response.status_code}")
        print(response.text)
        return

    result = response.json()
    print("\nResponse:")
    for i, text in enumerate(result["generated_texts"]):
        print(f"\n--- Generation {i+1} ---")
        print(text)
|
|
|
|
def test_generate_copy(prompt, content_type, tone=None, target_audience=None, timeout=120.0):
    """Test the generate-copy endpoint.

    Posts a JSON payload to ``http://localhost:8000/generate-copy`` and prints
    the ``content``, ``confidence_score`` and ``brand_alignment_score`` fields
    of the JSON response. On a non-200 status, the status code and the raw
    response body are printed instead. If the request itself fails
    (connection refused, timeout, ...), the error is printed and the function
    returns.

    Args:
        prompt: Sent as the ``prompt`` field of the payload.
        content_type: Sent as the ``content_type`` field.
        tone: Sent as ``tone`` only when truthy; otherwise the key is omitted.
        target_audience: Sent as ``target_audience`` only when truthy;
            otherwise the key is omitted.
        timeout: Seconds handed to ``requests.post`` as its timeout. Pass
            ``None`` to wait indefinitely (the previous behaviour).

    Returns:
        None. All output goes to stdout.
    """
    url = "http://localhost:8000/generate-copy"

    payload = {
        "prompt": prompt,
        "content_type": content_type,
    }

    # Optional fields are left out of the payload entirely when not provided.
    if tone:
        payload["tone"] = tone
    if target_audience:
        payload["target_audience"] = target_audience

    print(f"Sending request to {url}")
    print(f"Payload: {json.dumps(payload, indent=2)}")

    # Without a timeout requests waits forever on a stalled server; a failed
    # connection is reported the same way as an HTTP error (print and return).
    try:
        response = requests.post(url, json=payload, timeout=timeout)
    except requests.RequestException as exc:
        print(f"Error: request to {url} failed: {exc}")
        return

    if response.status_code != 200:
        print(f"Error: {response.status_code}")
        print(response.text)
        return

    result = response.json()
    print("\nResponse:")
    print(f"Content: {result['content']}")
    print(f"Confidence Score: {result['confidence_score']}")
    print(f"Brand Alignment Score: {result['brand_alignment_score']}")
|
|
|
|
def main():
    """Parse the command line and call the tester for the chosen endpoint.

    ``--endpoint`` selects between ``direct-model`` (the default) and
    ``generate-copy``; the remaining options supply the arguments forwarded
    to the corresponding test function.
    """
    parser = argparse.ArgumentParser(description="Test the backend with the finetuned model")

    parser.add_argument(
        "--endpoint",
        type=str,
        choices=["direct-model", "generate-copy"],
        default="direct-model",
        help="Endpoint to test",
    )

    # (flag, type, default, help) for the remaining options, registered in
    # this order so the --help listing keeps its layout.
    option_table = [
        ("--prompt", str, "Create a welcome message for new clients",
         "Prompt to generate text for"),
        ("--content-type", str, "email",
         "Content type (for generate-copy endpoint)"),
        ("--tone", str, None,
         "Tone (for generate-copy endpoint)"),
        ("--target-audience", str, None,
         "Target audience (for generate-copy endpoint)"),
        ("--max-length", int, 200,
         "Maximum length of the generated text (for direct-model endpoint)"),
        ("--num-return-sequences", int, 1,
         "Number of sequences to generate (for direct-model endpoint)"),
        ("--temperature", float, 0.7,
         "Temperature for sampling (for direct-model endpoint)"),
        ("--top-p", float, 0.9,
         "Top-p for sampling (for direct-model endpoint)"),
    ]
    for flag, value_type, default_value, help_text in option_table:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    options = parser.parse_args()

    # argparse restricts --endpoint to the two choices above, so anything
    # other than "direct-model" is "generate-copy".
    if options.endpoint != "direct-model":
        test_generate_copy(
            prompt=options.prompt,
            content_type=options.content_type,
            tone=options.tone,
            target_audience=options.target_audience,
        )
        return

    test_direct_model(
        prompt=options.prompt,
        max_length=options.max_length,
        num_return_sequences=options.num_return_sequences,
        temperature=options.temperature,
        top_p=options.top_p,
    )
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()