"""Manual smoke test for the backend's text-generation endpoints.

Sends one request to either ``/direct-model`` or ``/generate-copy`` on a
running backend, prints the request payload and the response, and exits with
status 0 on success or 1 on any failure (network error, non-200 status, or a
non-JSON response body).
"""

import argparse
import json
import sys

import requests

DEFAULT_BASE_URL = "http://localhost:8000"
# Seconds to wait for the server. Generation can be slow, so this is generous,
# but without any timeout ``requests`` would wait forever on a server that
# accepts the connection and never answers.
DEFAULT_TIMEOUT = 120.0


def _post_json(url, payload, timeout):
    """POST *payload* as JSON to *url*, printing the request first.

    Args:
        url: Full endpoint URL.
        payload: JSON-serializable request body.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The decoded JSON body on HTTP 200; otherwise ``None`` after printing
        a diagnostic.
    """
    print(f"Sending request to {url}")
    print(f"Payload: {json.dumps(payload, indent=2)}")
    try:
        response = requests.post(url, json=payload, timeout=timeout)
    except requests.RequestException as exc:
        # Connection refused / DNS failure / timeout: report it instead of
        # dumping a traceback -- "server not running" is the common case here.
        print(f"Error: request to {url} failed: {exc}")
        return None
    if response.status_code != 200:
        print(f"Error: {response.status_code}")
        print(response.text)
        return None
    try:
        return response.json()
    except ValueError:  # requests' JSON decode errors subclass ValueError
        print("Error: response body is not valid JSON")
        print(response.text)
        return None


def test_direct_model(prompt, max_length=200, num_return_sequences=1,
                      temperature=0.7, top_p=0.9, *,
                      base_url=DEFAULT_BASE_URL, timeout=DEFAULT_TIMEOUT):
    """Test the direct model inference endpoint.

    Args:
        prompt: Text prompt sent to the model.
        max_length: Sent as ``max_length`` in the payload.
        num_return_sequences: Sent as ``num_return_sequences``.
        temperature: Sent as ``temperature``.
        top_p: Sent as ``top_p``.
        base_url: Backend root URL, without the endpoint path.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The decoded response dict, or ``None`` if the request failed.
    """
    url = f"{base_url.rstrip('/')}/direct-model"
    payload = {
        "prompt": prompt,
        "max_length": max_length,
        "num_return_sequences": num_return_sequences,
        "temperature": temperature,
        "top_p": top_p,
    }
    result = _post_json(url, payload, timeout)
    if result is None:
        return None
    print("\nResponse:")
    for i, text in enumerate(result["generated_texts"], start=1):
        print(f"\n--- Generation {i} ---")
        print(text)
    return result


def test_generate_copy(prompt, content_type, tone=None, target_audience=None,
                       *, base_url=DEFAULT_BASE_URL, timeout=DEFAULT_TIMEOUT):
    """Test the generate-copy endpoint.

    Args:
        prompt: Text prompt describing the copy to generate.
        content_type: Sent as ``content_type`` (e.g. ``"email"``).
        tone: Optional; sent as ``tone`` only when non-empty.
        target_audience: Optional; sent as ``target_audience`` only when
            non-empty.
        base_url: Backend root URL, without the endpoint path.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The decoded response dict, or ``None`` if the request failed.
    """
    url = f"{base_url.rstrip('/')}/generate-copy"
    payload = {"prompt": prompt, "content_type": content_type}
    # Optional fields are left out entirely (not sent as null/"") when unset.
    if tone:
        payload["tone"] = tone
    if target_audience:
        payload["target_audience"] = target_audience
    result = _post_json(url, payload, timeout)
    if result is None:
        return None
    print("\nResponse:")
    print(f"Content: {result['content']}")
    print(f"Confidence Score: {result['confidence_score']}")
    print(f"Brand Alignment Score: {result['brand_alignment_score']}")
    return result


# These are manual smoke-test helpers, not pytest tests. If this script lives
# under a test_*.py name, pytest would otherwise collect them and fail on the
# required ``prompt`` argument (reported as a missing fixture).
test_direct_model.__test__ = False
test_generate_copy.__test__ = False


def _build_parser():
    """Return the argument parser for this script's command line."""
    parser = argparse.ArgumentParser(description="Test the backend with the finetuned model")
    parser.add_argument("--endpoint", type=str, choices=["direct-model", "generate-copy"],
                        default="direct-model", help="Endpoint to test")
    parser.add_argument("--prompt", type=str, default="Create a welcome message for new clients",
                        help="Prompt to generate text for")
    parser.add_argument("--content-type", type=str, default="email",
                        help="Content type (for generate-copy endpoint)")
    parser.add_argument("--tone", type=str, default=None,
                        help="Tone (for generate-copy endpoint)")
    parser.add_argument("--target-audience", type=str, default=None,
                        help="Target audience (for generate-copy endpoint)")
    parser.add_argument("--max-length", type=int, default=200,
                        help="Maximum length of the generated text (for direct-model endpoint)")
    parser.add_argument("--num-return-sequences", type=int, default=1,
                        help="Number of sequences to generate (for direct-model endpoint)")
    parser.add_argument("--temperature", type=float, default=0.7,
                        help="Temperature for sampling (for direct-model endpoint)")
    parser.add_argument("--top-p", type=float, default=0.9,
                        help="Top-p for sampling (for direct-model endpoint)")
    parser.add_argument("--base-url", type=str, default=DEFAULT_BASE_URL,
                        help="Backend root URL (default: %(default)s)")
    parser.add_argument("--timeout", type=float, default=DEFAULT_TIMEOUT,
                        help="Seconds to wait for a response (default: %(default)s)")
    return parser


def main():
    """Parse the command line, call the selected endpoint, return an exit code.

    Returns:
        0 if the request succeeded, 1 otherwise.
    """
    args = _build_parser().parse_args()
    connection = {"base_url": args.base_url, "timeout": args.timeout}
    if args.endpoint == "direct-model":
        result = test_direct_model(
            prompt=args.prompt,
            max_length=args.max_length,
            num_return_sequences=args.num_return_sequences,
            temperature=args.temperature,
            top_p=args.top_p,
            **connection,
        )
    else:
        result = test_generate_copy(
            prompt=args.prompt,
            content_type=args.content_type,
            tone=args.tone,
            target_audience=args.target_audience,
            **connection,
        )
    return 0 if result is not None else 1


if __name__ == "__main__":
    sys.exit(main())