63 lines
2.0 KiB
Python
63 lines
2.0 KiB
Python
|
|
import argparse
|
||
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||
|
|
|
||
|
|
def generate_text(model_path, prompt, max_length=100, num_return_sequences=1, temperature=0.7,
                  max_new_tokens=None):
    """Generate text using the finetuned model.

    Loads the tokenizer and causal-LM weights from ``model_path``, wraps
    ``prompt`` in the ``"Prompt: ...\\nCompletion:"`` template, samples
    continuations, and returns only the newly generated text of each sequence.

    Args:
        model_path: Path (or hub id) passed to ``from_pretrained`` for both the
            tokenizer and the model.
        prompt: Raw user prompt; the template is applied inside this function.
        max_length: Maximum total sequence length in tokens, *including* the
            prompt tokens. Ignored when ``max_new_tokens`` is given.
        num_return_sequences: Number of sampled completions to return.
        temperature: Sampling temperature (sampling is always enabled).
        max_new_tokens: Optional cap on the number of generated tokens,
            excluding the prompt. Defaults to ``None`` (use ``max_length``).

    Returns:
        A list of ``num_return_sequences`` completion strings with surrounding
        whitespace stripped.
    """
    # Load the finetuned model and tokenizer
    print(f"Loading model from {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)

    # Format the prompt the same way as the finetuning data (presumably; confirm
    # against the training script).
    formatted_prompt = f"Prompt: {prompt}\nCompletion:"

    # Generate text
    print(f"Generating text for prompt: {prompt}")
    inputs = tokenizer(formatted_prompt, return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[-1]

    # Pass exactly one length limit; generate() warns when both are supplied.
    if max_new_tokens is not None:
        length_kwargs = {"max_new_tokens": max_new_tokens}
    else:
        length_kwargs = {"max_length": max_length}

    outputs = model.generate(
        **inputs,
        **length_kwargs,
        num_return_sequences=num_return_sequences,
        temperature=temperature,
        do_sample=True,
        # Causal-LM tokenizers often lack a pad token; reuse EOS for padding.
        pad_token_id=tokenizer.eos_token_id,
    )

    # For decoder-only models each output row is prompt tokens followed by the
    # new tokens. Decoding only the tail avoids splitting on "Completion:", which
    # would truncate any completion that itself contains that marker.
    return [
        tokenizer.decode(sequence[prompt_len:], skip_special_tokens=True).strip()
        for sequence in outputs
    ]
|
||
|
|
|
||
|
|
def main(argv=None):
    """Parse options, generate completions with the finetuned model, and print them.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``. Every option has a default, so
            calling ``main()`` with no arguments keeps the previous behaviour.
    """
    parser = argparse.ArgumentParser(description="Generate text with a finetuned causal LM.")
    parser.add_argument("--model_path", default="finetuned_model",
                        help="Path to the finetuned model")
    parser.add_argument("--prompt", default="Create a welcome message for new clients",
                        help="Prompt to complete")
    parser.add_argument("--max_length", type=int, default=100,
                        help="Maximum total length in tokens, prompt included")
    parser.add_argument("--num_return_sequences", type=int, default=1,
                        help="Number of completions to sample")
    parser.add_argument("--temperature", type=float, default=0.7,
                        help="Sampling temperature")

    # The previous hard-coded settings ignored the command line entirely.
    # parse_known_args keeps tolerating unrelated launcher flags (e.g. those a
    # notebook kernel injects) instead of aborting on them.
    args, unknown = parser.parse_known_args(argv)
    if unknown:
        print(f"Ignoring unrecognized arguments: {unknown}")

    # Generate text
    generated_texts = generate_text(
        args.model_path,
        args.prompt,
        args.max_length,
        args.num_return_sequences,
        args.temperature,
    )

    # Print the generated text
    print("\nGenerated text:")
    for i, text in enumerate(generated_texts):
        print(f"\n--- Generation {i+1} ---")
        print(text)


if __name__ == "__main__":
    main()
|