feat: Implement Pinecone vector store integration
- Update config.py with Pinecone settings and model configurations - Implement VectorStore class with Pinecone backend - Add comprehensive vector operations (add, search, delete) - Set up proper error handling and metadata management - Add .gitignore for Python project
This commit is contained in:
+101
@@ -0,0 +1,101 @@
|
||||
import requests
|
||||
import json
|
||||
import argparse
|
||||
|
||||
def test_direct_model(prompt, max_length=200, num_return_sequences=1, temperature=0.7, top_p=0.9):
|
||||
"""Test the direct model inference endpoint."""
|
||||
url = "http://localhost:8000/direct-model"
|
||||
|
||||
payload = {
|
||||
"prompt": prompt,
|
||||
"max_length": max_length,
|
||||
"num_return_sequences": num_return_sequences,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p
|
||||
}
|
||||
|
||||
print(f"Sending request to {url}")
|
||||
print(f"Payload: {json.dumps(payload, indent=2)}")
|
||||
|
||||
response = requests.post(url, json=payload)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
print("\nResponse:")
|
||||
for i, text in enumerate(result["generated_texts"]):
|
||||
print(f"\n--- Generation {i+1} ---")
|
||||
print(text)
|
||||
else:
|
||||
print(f"Error: {response.status_code}")
|
||||
print(response.text)
|
||||
|
||||
def test_generate_copy(prompt, content_type, tone=None, target_audience=None):
|
||||
"""Test the generate-copy endpoint."""
|
||||
url = "http://localhost:8000/generate-copy"
|
||||
|
||||
payload = {
|
||||
"prompt": prompt,
|
||||
"content_type": content_type
|
||||
}
|
||||
|
||||
if tone:
|
||||
payload["tone"] = tone
|
||||
if target_audience:
|
||||
payload["target_audience"] = target_audience
|
||||
|
||||
print(f"Sending request to {url}")
|
||||
print(f"Payload: {json.dumps(payload, indent=2)}")
|
||||
|
||||
response = requests.post(url, json=payload)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
print("\nResponse:")
|
||||
print(f"Content: {result['content']}")
|
||||
print(f"Confidence Score: {result['confidence_score']}")
|
||||
print(f"Brand Alignment Score: {result['brand_alignment_score']}")
|
||||
else:
|
||||
print(f"Error: {response.status_code}")
|
||||
print(response.text)
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Test the backend with the finetuned model")
|
||||
parser.add_argument("--endpoint", type=str, choices=["direct-model", "generate-copy"], default="direct-model",
|
||||
help="Endpoint to test")
|
||||
parser.add_argument("--prompt", type=str, default="Create a welcome message for new clients",
|
||||
help="Prompt to generate text for")
|
||||
parser.add_argument("--content-type", type=str, default="email",
|
||||
help="Content type (for generate-copy endpoint)")
|
||||
parser.add_argument("--tone", type=str, default=None,
|
||||
help="Tone (for generate-copy endpoint)")
|
||||
parser.add_argument("--target-audience", type=str, default=None,
|
||||
help="Target audience (for generate-copy endpoint)")
|
||||
parser.add_argument("--max-length", type=int, default=200,
|
||||
help="Maximum length of the generated text (for direct-model endpoint)")
|
||||
parser.add_argument("--num-return-sequences", type=int, default=1,
|
||||
help="Number of sequences to generate (for direct-model endpoint)")
|
||||
parser.add_argument("--temperature", type=float, default=0.7,
|
||||
help="Temperature for sampling (for direct-model endpoint)")
|
||||
parser.add_argument("--top-p", type=float, default=0.9,
|
||||
help="Top-p for sampling (for direct-model endpoint)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.endpoint == "direct-model":
|
||||
test_direct_model(
|
||||
prompt=args.prompt,
|
||||
max_length=args.max_length,
|
||||
num_return_sequences=args.num_return_sequences,
|
||||
temperature=args.temperature,
|
||||
top_p=args.top_p
|
||||
)
|
||||
else:
|
||||
test_generate_copy(
|
||||
prompt=args.prompt,
|
||||
content_type=args.content_type,
|
||||
tone=args.tone,
|
||||
target_audience=args.target_audience
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user