Files
ds_task_marketing_assistant…/test_model.py
T

63 lines
2.0 KiB
Python
Raw Normal View History

import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer
def generate_text(model_path, prompt, max_length=100, num_return_sequences=1, temperature=0.7):
    """Generate completions for *prompt* using a finetuned causal language model.

    The model and tokenizer are loaded from ``model_path`` on every call.
    Before generation the prompt is wrapped as
    ``"Prompt: {prompt}\\nCompletion:"``, and text is produced by sampling
    (``do_sample=True``).

    Args:
        model_path: Path (or identifier) accepted by ``from_pretrained`` for
            both the tokenizer and the model.
        prompt: The raw user prompt, without the ``Prompt:``/``Completion:``
            framing, which this function adds.
        max_length: Maximum total sequence length passed to ``generate``.
            This counts the formatted prompt's tokens as well as the newly
            generated ones.
        num_return_sequences: Number of sampled completions to return.
        temperature: Sampling temperature passed to ``generate``.

    Returns:
        A list with one string per returned sequence. Each string holds only
        the newly generated text, with special tokens removed and surrounding
        whitespace stripped.
    """
    # Load the finetuned model and tokenizer.
    print(f"Loading model from {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)

    # Wrap the prompt in the same framing the completion marker relies on.
    formatted_prompt = f"Prompt: {prompt}\nCompletion:"

    print(f"Generating text for prompt: {prompt}")
    inputs = tokenizer(formatted_prompt, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_length=max_length,
        num_return_sequences=num_return_sequences,
        temperature=temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # For a causal (decoder-only) model every output row begins with the
    # prompt tokens. Drop them by token count instead of splitting the decoded
    # text on "Completion:". Splitting keeps only the text after the LAST
    # marker, so it silently loses output whenever the model emits the
    # marker itself.
    prompt_length = inputs["input_ids"].shape[-1]
    return [
        tokenizer.decode(output[prompt_length:], skip_special_tokens=True).strip()
        for output in outputs
    ]
def main(argv=None):
    """Command-line entry point: sample completions from a finetuned model.

    Every option defaults to the value this script used to hard-code, so a
    run with no arguments behaves exactly as before.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes ``argparse`` read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Generate text with a finetuned causal language model."
    )
    parser.add_argument(
        "--model_path",
        default="finetuned_model",
        help="Path to the finetuned model (default: %(default)s)",
    )
    parser.add_argument(
        "--prompt",
        default="Create a welcome message for new clients",
        help="Prompt to complete (default: %(default)r)",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=100,
        help="Maximum total length in tokens, prompt included (default: %(default)s)",
    )
    parser.add_argument(
        "--num_return_sequences",
        type=int,
        default=1,
        help="Number of completions to sample (default: %(default)s)",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature, must be > 0 (default: %(default)s)",
    )
    args = parser.parse_args(argv)

    # Reject obviously invalid values up front. The model load is expensive,
    # so it is better to fail before it starts.
    if args.max_length < 1:
        parser.error("--max_length must be at least 1")
    if args.num_return_sequences < 1:
        parser.error("--num_return_sequences must be at least 1")
    if args.temperature <= 0:
        parser.error("--temperature must be greater than 0 when sampling")

    generated_texts = generate_text(
        args.model_path,
        args.prompt,
        args.max_length,
        args.num_return_sequences,
        args.temperature,
    )

    print("\nGenerated text:")
    for i, text in enumerate(generated_texts, start=1):
        print(f"\n--- Generation {i} ---")
        print(text)
if __name__ == "__main__":
    # Run the demo only when this file is executed as a script, not on import.
    main()