105 lines
3.2 KiB
Python
Executable File
105 lines
3.2 KiB
Python
Executable File
"""
|
|
Test script for the Ollama API with custom prompts.
|
|
"""
|
|
|
|
import requests
|
|
import json
|
|
import argparse
|
|
import time
|
|
|
|
class OllamaRequest:
    """Payload builder for a single Ollama chat request.

    Attributes:
        model: Name of the model placed in the payload.
        prompt: Text sent as the user message.
        system_prompt: Text sent as the system message. Any falsy value given
            to the constructor (``None``, ``""``) is replaced by a generic
            default.
    """

    def __init__(self, model, prompt, system_prompt=None):
        self.model = model
        self.prompt = prompt
        # Falsy input falls back to the stock assistant persona.
        self.system_prompt = (
            system_prompt if system_prompt else "You are a helpful assistant."
        )

    def to_json(self):
        """Return the request body as a JSON-serializable dict.

        The message list always holds exactly two entries: the system message
        first, then the user message. Streaming is switched off via
        ``"stream": False``.
        """
        conversation = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": self.prompt},
        ]
        return dict(model=self.model, messages=conversation, stream=False)
|
|
|
|
def test_ollama_api(api_url, model, prompt, system_prompt=None, timeout=300):
    """Send one chat request to the Ollama API and print the outcome.

    Args:
        api_url: Full URL of the endpoint to POST the request to.
        model: Model name placed in the request payload.
        prompt: User message content.
        system_prompt: Optional system message. ``OllamaRequest`` substitutes
            its default when this is falsy.
        timeout: Seconds passed to ``requests.post`` as its timeout
            (default 300, i.e. 5 minutes, the previously hard-coded value).

    Returns:
        True if the request succeeded with a non-error status and the body
        parsed as JSON. Otherwise False; the error is printed, not raised.
    """
    print(f"Testing Ollama API at: {api_url}")
    print(f"Model: {model}")
    print(f"Prompt: {prompt}")
    if system_prompt:
        print(f"System prompt: {system_prompt}")

    request_json = OllamaRequest(model, prompt, system_prompt).to_json()

    try:
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by wall-clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        response = requests.post(
            api_url,
            headers={"Content-Type": "application/json"},
            json=request_json,
            timeout=timeout,
        )
        elapsed_time = time.perf_counter() - start_time

        response.raise_for_status()
        result = response.json()

        print(f"\nResponse (took {elapsed_time:.2f} seconds):")
        print(json.dumps(result, indent=2))

        # Print just the assistant text when the reply carries a message.
        if 'message' in result:
            content = result['message'].get('content', 'No content in response')
            print("\nContent:")
            print(content)

        return True

    except requests.exceptions.HTTPError as e:
        print(f"ERROR: Ollama API test failed: {e}")
        # raise_for_status() only reports the status line and URL. The server
        # may explain the failure in the body, so show it when present.
        body = e.response.text if e.response is not None else ""
        if body:
            print(f"Response body: {body}")
        return False
    except Exception as e:
        # Deliberate catch-all: this function is the test's reporting boundary.
        # Any failure (connection, timeout, bad JSON, unexpected response shape)
        # is reported as a failed test instead of a traceback.
        print(f"ERROR: Ollama API test failed: {e}")
        return False
|
|
|
|
def main(argv=None):
    """Parse command-line options, run the Ollama API test, and print a summary.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behavior.

    Returns:
        Process exit status: 0 when the test succeeded, 1 when it failed.
    """
    parser = argparse.ArgumentParser(description='Test the Ollama API with custom prompts')
    parser.add_argument('--api-url', type=str, default='http://localhost:11434/api/chat', help='Ollama API URL')
    parser.add_argument('--model', type=str, default='llama3.1', help='Model to use for testing')
    parser.add_argument('--prompt', type=str, default='What is the capital of France?', help='Prompt to use for testing')
    parser.add_argument('--system-prompt', type=str, default=None, help='System prompt to use for testing')

    args = parser.parse_args(argv)

    print("=== Ollama API Test ===")
    print(f"API URL: {args.api_url}")
    print(f"Model: {args.model}")
    print(f"Prompt: {args.prompt}")
    if args.system_prompt:
        print(f"System prompt: {args.system_prompt}")
    print()

    # Test Ollama API
    success = test_ollama_api(args.api_url, args.model, args.prompt, args.system_prompt)

    # Print summary
    print("\n=== Test Summary ===")
    print(f"Ollama API Test: {'SUCCESS' if success else 'FAILED'}")

    # A non-zero status lets shells and CI pipelines detect a failed test;
    # previously the script exited 0 regardless of the outcome.
    return 0 if success else 1


if __name__ == "__main__":
    raise SystemExit(main())
|