Taming a Llama with Rotten Tomatoes: A Vertex AI Fine-Tuning Saga

Satish Iyer
5 min read · Dec 20, 2024

I’m on a mission to learn about LLMs. It’s been three years since I touched code, but I’m jumping back in with hands-on experimentation. This series will track my progress, starting with fine-tuning.

Movie reviews. Love ’em or hate ’em, they’re packed with emotional juice. So, I decided to teach a Llama how to squeeze it. Not just any Llama, mind you, but Llama 2, Meta’s open-source large language model that’s been making waves in the AI world. My mission? Fine-tune the 7B-chat-hf variant on the Rotten Tomatoes dataset for sentiment analysis. I wanted to see if I could build a powerful sentiment analyzer without breaking the bank (or my GPU’s back). The secret sauce? A sprinkle of 8-bit magic, a dash of LoRA, and a whole lot of help from Google Cloud’s Vertex AI.

Why Llama 2 and Rotten Tomatoes?

Llama 2, Meta’s open-source beast, has been turning heads. Its performance on various benchmarks is impressive, and the 7B-chat-hf variant seemed perfect for a fine-tuning experiment. It’s powerful enough to capture complex language patterns but not so massive that it becomes unwieldy.

And for the training ground? Rotten Tomatoes, with its binary positive/negative labels, is a classic for sentiment analysis. Plus, who doesn’t have an opinion on movies? But could I make this large language model truly understand the nuances of a scathing review versus a glowing one? This was the challenge I was itching to tackle.

The Tech Stack: Vertex AI, Dynamic Workload Scheduler, 8-bit, and LoRA — My Fine-Tuning Arsenal

For this adventure, I enlisted the help of Google Cloud’s Vertex AI. Think of it as your all-in-one ML command center. Custom job submission made it a breeze to package my code (thanks to a well-crafted Dockerfile) and launch it on powerful A100 GPUs. I built a custom Docker image, pushed it to Artifact Registry, and then ran one quick gcloud command:

Bash

gcloud ai custom-jobs create \
--region=$REGION \
--display-name=$JOB_NAME \
--worker-pool-spec=machine-type=a2-highgpu-8g,replica-count=1,accelerator-type=NVIDIA_TESLA_A100,accelerator-count=8,container-image-uri=$IMAGE_URI \
--args="--project_id=${PROJECT_ID},--location=${REGION},--hf_token=${HF_TOKEN},--output_dir=${OUTPUT_DIR}"

and Vertex AI handled the rest. The --worker-pool-spec flag defined my hardware (one a2-highgpu-8g machine with eight juicy A100 GPUs!), while --args passed crucial parameters to my training script as comma-separated command-line arguments. No more infrastructure headaches; Vertex AI takes care of the heavy lifting.
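Each comma-separated item in --args reaches the container as its own command-line argument, for example `--project_id=my-project`. Below is a minimal sketch of how a training script could parse them with argparse. The four flag names come from the gcloud command. The hyperparameter flags are assumptions, inferred from the `args.*` attributes used later in the post. Their defaults for learning rate, epochs, and accumulation steps mirror the values the post reports for this run. The weight_decay default of 0.0 is a placeholder, and `parse_args` is a hypothetical helper name.

```python
import argparse


def parse_args(argv=None):
    """Parse the flags Vertex AI forwards from `gcloud ai custom-jobs create --args`."""
    parser = argparse.ArgumentParser(description="Fine-tune Llama 2 on Rotten Tomatoes")
    # Flags passed explicitly in the gcloud command
    parser.add_argument("--project_id", required=True)
    parser.add_argument("--location", required=True)
    parser.add_argument("--hf_token", required=True)
    parser.add_argument("--output_dir", required=True)
    # Hyperparameters read later as args.*; these defaults are assumptions
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--num_train_epochs", type=int, default=3)
    parser.add_argument("--weight_decay", type=float, default=0.0)  # placeholder value
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    return parser.parse_args(argv)
```

For example, `parse_args(["--project_id=demo", "--location=us-central1", "--hf_token=hf_xxx", "--output_dir=/tmp/out"])` returns a Namespace with `learning_rate == 2e-05`. Passing `argv=None` makes argparse read the real `sys.argv` inside the container.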

Bonus tip: To reduce GPU costs on Vertex AI, use Dynamic Workload Scheduler (DWS) with Flex Start mode for custom training jobs. DWS Flex Start is ideal for jobs with a timeout of 7 days or less. It queues your GPU request, provisions the resources once they become available, and then runs the job without interruption for its full duration. In return, you get discounted GPU pricing. Sweet, right? Here’s the revised gcloud command. The only change is the --scheduling flag at the end, where max-wait-duration caps how long the request may wait in the queue (1800s, or 30 minutes):

Bash

gcloud ai custom-jobs create \
--region=$REGION \
--display-name=$JOB_NAME \
--worker-pool-spec=machine-type=a2-highgpu-8g,replica-count=1,accelerator-type=NVIDIA_TESLA_A100,accelerator-count=8,container-image-uri=$IMAGE_URI \
--args="--project_id=${PROJECT_ID},--location=${REGION},--hf_token=${HF_TOKEN},--output_dir=${OUTPUT_DIR}" \
--scheduling=strategy=FLEX_START,max-wait-duration=1800s

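But fine-tuning large models can be resource-intensive. Enter 8-bit quantization with bitsandbytes. Normally, the model’s linear-layer weights are kept in standard 32-bit or 16-bit floating point. The LLM.int8() scheme in bitsandbytes stores them as 8-bit integers instead, roughly halving weight memory compared with 16-bit. It keeps the handful of outlier features in 16-bit precision, which is how it avoids giving up much accuracy for the savings. (The NF4 data type and double quantization are 4-bit options in bitsandbytes, introduced with QLoRA; there is no "nf8" type.) Here’s the configuration, trimmed to the options that actually apply to 8-bit loading:

Python

from transformers import BitsAndBytesConfig

# LLM.int8() loading: linear-layer weights are stored as int8.
# BitsAndBytesConfig has no bnb_8bit_* options; unknown kwargs are ignored
# with an "Unused kwargs" warning. Quant type ("nf4"/"fp4"), compute dtype
# and double quantization exist only for 4-bit loading, for example:
#   BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
#                      bnb_4bit_compute_dtype=torch.bfloat16,
#                      bnb_4bit_use_double_quant=True)
bnb_config = BitsAndBytesConfig(load_in_8bit=True)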
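To make "8-bit" concrete, here is a simplified pure-Python sketch of absmax quantization, the basic int8 rounding step that LLM.int8() builds on. It scales a vector so its largest magnitude maps to 127, rounds to integers, and divides by the same scale to dequantize. This is an illustration only, not the bitsandbytes kernel. That kernel works row- and column-wise on GPU tensors and routes outlier features through 16-bit matmuls. The helper names are hypothetical, and the weight values are made up. The memory helper covers weights only, for a nominal 7-billion-parameter model. Activations, optimizer state, and LoRA adapters come on top.

```python
def absmax_quantize(values):
    """Quantize floats to the int8 range [-127, 127] using one absmax scale."""
    max_abs = max(abs(v) for v in values)
    scale = 127 / max_abs if max_abs else 1.0
    codes = [round(v * scale) for v in values]  # integers that fit in int8
    return codes, scale


def dequantize(codes, scale):
    """Map int8 codes back to approximate floats."""
    return [c / scale for c in codes]


def weight_memory_gb(n_params, bits):
    """Weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bits / 8 / 1e9


weights = [0.12, -0.5, 0.031, 0.44]
codes, scale = absmax_quantize(weights)
restored = dequantize(codes, scale)
print(codes)  # [30, -127, 8, 112]

for bits in (32, 16, 8):
    print(f"{bits}-bit: {weight_memory_gb(7e9, bits):.0f} GB")  # 28, 14, 7 GB
```

Each restored value is within half a quantization step (0.5 / scale) of the original. The whole vector shares one scale, so a single large outlier would coarsen every other value. That is the problem LLM.int8() sidesteps by keeping outlier features in 16-bit.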

And to make things even leaner, I used LoRA (Low-Rank Adaptation). Instead of fine-tuning every single parameter in Llama 2 (which would be, let’s face it, a lot), LoRA keeps the original weights frozen and attaches a pair of small trainable matrices to selected layers. Their low-rank product is added to each layer’s output as a learned correction. It’s like giving the model a specialized add-on without a complete overhaul. Only those small matrices need gradients and optimizer state, which made the training process significantly faster and less memory-hungry. Here’s a peek at my LoRA configuration:

Python

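from peft import LoraConfig

peft_config = LoraConfig(
    r=8,               # rank of the low-rank update matrices
    lora_alpha=16,     # the update is scaled by lora_alpha / r = 2
    lora_dropout=0.1,  # dropout applied to the adapter's input
    bias="none",       # leave bias terms frozen
    task_type="CAUSAL_LM",
    # every attention and MLP projection in each Llama decoder layer
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "down_proj", "up_proj"],
)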

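Notice the target_modules. The list covers every linear projection inside each decoder layer: the query, key, value, and output projections in attention, plus the gate, up, and down projections in the MLP. Adapting all of them, rather than only the query and value projections, gives the adapters more room to learn the task while keeping the trainable parameter count tiny.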
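Two small pure-Python sketches make the LoRA idea concrete. The first is the adapted forward pass. The frozen weight W is untouched, and the adapter adds (lora_alpha / r) · B·A·x, where A is r×d_in and B is d_out×r. Because B starts at zero, the adapted layer initially behaves exactly like the original. The second counts trainable adapter weights for the configuration above, since each adapted d_in→d_out projection adds r·(d_in + d_out) of them. It assumes Llama-2-7B’s published shapes: hidden size 4096, MLP size 11008, and 32 decoder layers. The toy matrix sizes and helper names are illustrative assumptions, not peft’s internals.

```python
import random


def matvec(M, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def lora_forward(W, A, B, x, lora_alpha, r):
    """h = W x + (lora_alpha / r) * B (A x); W stays frozen, only A and B train."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scaling = lora_alpha / r
    return [b + scaling * u for b, u in zip(base, update)]


def lora_trainable_params(r, shapes, num_layers):
    """Per module: A has r * d_in weights, B has d_out * r weights."""
    return num_layers * sum(r * (d_in + d_out) for d_in, d_out in shapes.values())


# Llama-2-7B projection shapes (d_in, d_out) within one decoder layer
LLAMA2_7B_SHAPES = {
    "q_proj": (4096, 4096), "k_proj": (4096, 4096),
    "v_proj": (4096, 4096), "o_proj": (4096, 4096),
    "gate_proj": (4096, 11008), "up_proj": (4096, 11008),
    "down_proj": (11008, 4096),
}

# Toy layer: A gets a small random init, B starts at zero
random.seed(0)
d_in, d_out, r = 6, 4, 2
W = [[random.gauss(0, 1) for _ in range(d_in)] for _ in range(d_out)]
A = [[random.gauss(0, 0.01) for _ in range(d_in)] for _ in range(r)]
B = [[0.0] * r for _ in range(d_out)]
x = [random.gauss(0, 1) for _ in range(d_in)]
print(lora_forward(W, A, B, x, lora_alpha=16, r=r) == matvec(W, x))  # True

print(lora_trainable_params(8, LLAMA2_7B_SHAPES, num_layers=32))  # 19988480
```

That works out to about 20 million trainable weights, roughly 0.3% of the base model.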

The Code: A Glimpse Behind the Curtain

Let’s dive into a few crucial snippets. First, I had to get the Rotten Tomatoes dataset ready for Llama chow. Here’s how I preprocessed the data:

Python

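import torch


def preprocess_data(tokenizer, examples):
    # Fail fast if the split has no labels
    if "label" not in examples:
        raise ValueError("The dataset does not contain a 'label' column.")
    # Wrap each review in an instruction-style prompt
    inputs = [f"Classify the sentiment of this text: {text}" for text in examples["text"]]
    # Llama 2's tokenizer ships without a pad token; set
    # tokenizer.pad_token = tokenizer.eos_token before calling this
    model_inputs = tokenizer(inputs, max_length=256, truncation=True, padding="max_length")
    # Rotten Tomatoes labels: 0 = negative, 1 = positive
    model_inputs["labels"] = torch.tensor(examples["label"])
    return model_inputs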

The prompt engineering is key: "Classify the sentiment of this text: {text}". I chose this format to tell the model explicitly what task it is performing. Padding and truncation to max_length=256 gave every input the same length of 256 tokens.
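One caveat worth knowing: AutoModelForCausalLM computes its loss by comparing each position’s prediction with the next token. It therefore expects labels to be a token-ID sequence aligned with input_ids, not a single 0/1 class per review. A common recipe is sketched below as an assumption, not as what the original script does. Append the answer word to the prompt, copy input_ids into labels, and set every position that shouldn’t be scored (prompt and padding) to -100, the index the loss ignores. The helper is hypothetical and works on already-tokenized IDs so it runs without a tokenizer. The token IDs in the example are made up.

```python
IGNORE_INDEX = -100  # labels with this value are skipped by the cross-entropy loss


def build_causal_lm_labels(prompt_ids, answer_ids, max_length, pad_id):
    """Concatenate prompt + answer, right-pad to max_length, and mask non-answer labels.

    Returns input_ids, attention_mask and labels as equal-length lists.
    Note: right truncation can cut the answer off a very long review,
    so in practice you would truncate the review text first.
    """
    ids = (prompt_ids + answer_ids)[:max_length]
    n_prompt = min(len(prompt_ids), len(ids))
    n_pad = max_length - len(ids)

    input_ids = ids + [pad_id] * n_pad
    attention_mask = [1] * len(ids) + [0] * n_pad
    # Score only the answer tokens; prompt and padding positions are ignored
    labels = [IGNORE_INDEX] * n_prompt + ids[n_prompt:] + [IGNORE_INDEX] * n_pad
    return input_ids, attention_mask, labels


# Hypothetical IDs: a 5-token prompt ending in "Sentiment:" and a 2-token answer
prompt = [1, 4, 5, 6, 7]
answer = [9, 2]
input_ids, attention_mask, labels = build_causal_lm_labels(prompt, answer, max_length=10, pad_id=0)
print(labels)  # [-100, -100, -100, -100, -100, 9, 2, -100, -100, -100]
```

The model shifts labels internally, so labels stay aligned with input_ids rather than being offset by one. Masking by attention_mask position rather than by token ID also keeps this correct when the pad token is reused as EOS.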

Next, I loaded the pre-trained Llama 2 model with the transformers library, passing in the 8-bit configuration:

Python

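from transformers import AutoModelForCausalLM

MODEL_ID = "meta-llama/Llama-2-7b-chat-hf"  # gated repo: the HF token needs approved access

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=args.hf_token,             # passed in through --args
    quantization_config=bnb_config,  # the 8-bit config from above
    device_map="auto",               # let Accelerate place layers on the available GPUs
)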

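I’m using device_map="auto" to let Accelerate decide where each layer lives, spreading the model across the available GPUs. Slick, right? One nuance: this splits a single copy of the model across the GPUs rather than replicating it. The GPUs therefore work through each batch in turn instead of processing different batches in parallel.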

Now, the training arguments:

Python

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir=args.output_dir,
    evaluation_strategy="epoch",  # named eval_strategy in newer transformers releases
    learning_rate=args.learning_rate,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    num_train_epochs=args.num_train_epochs,
    weight_decay=args.weight_decay,
    logging_dir=f"{args.output_dir}/logs",
    logging_strategy="steps",
    logging_steps=10,
    save_strategy="epoch",
    fp16=True,                    # mixed precision; A100s also support bf16=True
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    gradient_checkpointing=True,  # recompute activations in the backward pass to save memory
    remove_unused_columns=False,
)

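I went with a learning_rate of 2e-5, a per_device_train_batch_size of 4, and 3 training epochs (num_train_epochs=3). To save memory, I set gradient_accumulation_steps=4. Gradients from four micro-batches are accumulated before each optimizer step, so every update is based on four times the per-device batch (16 examples) while only 4 sit in memory at once. gradient_checkpointing=True also played a crucial role. It discards most intermediate activations during the forward pass and recomputes them during the backward pass, trading extra compute for significant memory savings.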
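The batch-size arithmetic is worth spelling out. The effective batch per optimizer step is per_device_train_batch_size × gradient_accumulation_steps × the number of data-parallel processes. A minimal sketch follows, assuming the script is launched as a single Python process. In that case, device_map="auto" makes the eight GPUs hold slices of one model rather than separate replicas, so the data-parallel factor is 1. It also assumes the 8,530-review Rotten Tomatoes training split. The helper names are hypothetical, and the step count is approximate because Trainer’s own rounding can differ by a step per epoch.

```python
import math


def effective_batch_size(per_device_batch, grad_accum_steps, num_processes=1):
    """Examples contributing to each optimizer step."""
    return per_device_batch * grad_accum_steps * num_processes


def optimizer_steps(num_examples, per_step_batch, epochs):
    """Approximate optimizer updates for a run (a partial last batch is rounded up)."""
    return math.ceil(num_examples / per_step_batch) * epochs


batch = effective_batch_size(per_device_batch=4, grad_accum_steps=4)
steps = optimizer_steps(num_examples=8530, per_step_batch=batch, epochs=3)
print(batch, steps)  # 16 1602
```

Under one-process-per-GPU data parallelism on all eight A100s, num_processes would be 8 and the effective batch 128.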

Finally, unleashing the Trainer:

Python

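from transformers import DataCollatorWithPadding, Trainer

# Pads each batch to a common length and converts it to tensors
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

trainer = Trainer(
    model=model,  # must already carry trainable LoRA adapters
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
)
trainer.train()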

DataCollatorWithPadding pads the sequences in each batch to a common length and stacks them into tensors, which batching requires. Since preprocess_data already pads every example to 256 tokens, here it mainly does the stacking.
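A common alternative is dynamic padding: drop padding="max_length" and let the collator pad each batch only to its own longest sequence, which usually saves compute on short reviews. Below is a pure-Python sketch of the idea, not the transformers implementation. Token labels are padded with -100 so padding is never scored. In transformers, DataCollatorForSeq2Seq pads a labels field this way, while DataCollatorWithPadding does not pad labels. The pad ID of 0, the toy sequences, and the `pad_batch` name are assumptions.

```python
def pad_batch(batch, pad_id=0, label_pad_id=-100):
    """Right-pad a list of {"input_ids", "labels"} dicts to the batch's longest sequence."""
    longest = max(len(ex["input_ids"]) for ex in batch)
    out = {"input_ids": [], "attention_mask": [], "labels": []}
    for ex in batch:
        n_pad = longest - len(ex["input_ids"])
        out["input_ids"].append(ex["input_ids"] + [pad_id] * n_pad)
        out["attention_mask"].append([1] * len(ex["input_ids"]) + [0] * n_pad)
        out["labels"].append(ex["labels"] + [label_pad_id] * n_pad)
    return out


batch = [
    {"input_ids": [1, 5, 6], "labels": [-100, 5, 6]},
    {"input_ids": [1, 7, 8, 9, 10], "labels": [-100, 7, 8, 9, 10]},
]
padded = pad_batch(batch)
tokens_dynamic = sum(len(row) for row in padded["input_ids"])  # padded to longest (5)
tokens_fixed = len(batch) * 256                                # padded to max_length=256
print(tokens_dynamic, tokens_fixed)  # 10 512
```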

Results and Insights: Did the Llama Learn to Love Movies?

The training process was surprisingly smooth. The loss decreased steadily, and I kept a close eye on the evaluation metrics. The model started to truly get the sentiment behind the reviews.

Here are a couple of examples:

Review: “This movie was an absolute masterpiece! The acting, the cinematography, the story — everything was perfect.”

  • Llama’s Prediction: Positive (Correct!)

Review: “What a waste of time. The plot was predictable, the characters were unlikeable, and the ending was utterly disappointing.”

  • Llama’s Prediction: Negative (Nailed it!)

Of course, it wasn’t perfect. The model occasionally stumbled on more nuanced or sarcastic reviews, which is a common challenge in sentiment analysis. But overall, the results were very promising.
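A causal LM answers in free text, so turning its reply into a Positive/Negative prediction takes a small parsing step before accuracy can be computed. Below is a minimal sketch, assuming the model’s reply contains the word "positive" or "negative". Replies that contain neither count as wrong. The helper names are hypothetical, and the sample replies are made up for illustration, not outputs from this run.

```python
import re

LABELS = {"negative": 0, "positive": 1}  # Rotten Tomatoes label IDs


def parse_sentiment(reply):
    """Return 0/1 for the first sentiment word in the reply, or None if absent."""
    match = re.search(r"\b(positive|negative)\b", reply.lower())
    return LABELS[match.group(1)] if match else None


def accuracy(replies, gold_labels):
    """Fraction of replies whose parsed label matches the gold label."""
    correct = sum(parse_sentiment(r) == g for r, g in zip(replies, gold_labels))
    return correct / len(gold_labels)


replies = ["Positive", "Sentiment: negative.", "It's hard to say."]
print(accuracy(replies, [1, 0, 0]))
```

A keyword match like this is deliberately crude. A reply such as "not positive" would be scored as positive, the same kind of nuance that trips up sentiment models themselves.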

This wasn’t just about getting a high accuracy score. It was about understanding how to efficiently fine-tune a powerful LLM for a specific task, making the most of available resources. I learned a ton about quantization, LoRA, and the intricacies of Vertex AI.

Conclusion: The Adventure Continues

So, there you have it. We tamed a Llama, fed it Rotten Tomatoes, and taught it to understand the fickle world of movie reviews — all while being mindful of our computational budget. The beauty of open-source models like Llama 2 is that anyone can jump in and experiment. I encourage you to try fine-tuning it on your own datasets!

Perhaps next, I’ll explore how to serve an LLM on Vertex AI. Got suggestions? Drop them in the comments!

Get the code from my repo.
