Gemma 3 Reasoning Fine-Tuning Crash Course #1: Medical Reasoning
Step-by-Step Guide on Gemma 3 Reasoning Fine-Tuning
Google’s new open-source model family, Gemma 3, is quickly gaining attention for its remarkable performance, rivaling that of some of the latest proprietary models.
With advanced multimodal features, improved reasoning abilities, and support for over 140 languages, Gemma 3 stands out as a versatile tool for various AI applications.
In this tutorial, we will dive into the capabilities of Gemma 3 and demonstrate how to fine-tune it using a medical reasoning question-answering dataset.
This fine-tuning process will enhance the model’s ability to accurately understand, reason, and respond to complex medical questions, providing contextually relevant and precise answers.
Table of Contents:
Introduction to Gemma 3 LLM
Setting Up the Working Environment
Loading the Model and Tokenizer
Test the Model with Zero-Shot Inference
Loading and Processing the Dataset
Setting up the Model Training Pipeline
Model Fine-Tuning with LoRA
Saving the Model and Tokenizer to Hugging Face
Model Inference After Fine-Tuning
1. Introduction to Gemma 3 LLM
Gemma 3 is the latest in a series of open models designed to make powerful AI more accessible and easier to deploy. Built on the same research foundation as Google’s Gemini 2.0 models, it delivers strong performance while staying lightweight and resource-efficient.
Available in four sizes (1B, 4B, 12B, and 27B parameters), Gemma 3 fits a wide range of hardware setups, making it practical for real-world use. Despite its compact design, the 27B model outperformed models such as Llama3-405B, DeepSeek-V3, and o3-mini in Google's preliminary human preference evaluations on LMArena, showing that efficiency doesn't have to come at the cost of capability.
Some of the key features of Gemma 3 models:
Multilingual support: Out-of-the-box support for over 35 languages, with pretraining coverage for over 140 languages.
Long context window: 128K-token context window (32K for the 1B model), suited to long documents and complex reasoning tasks.
Multimodal capabilities: The 4B, 12B, and 27B models can reason over text, images, and short videos; the 1B model is text-only.
Structured output: Native support for function calling and structured outputs.
Quantized models: Official quantized versions available to reduce model size and computational requirements.
Hardware integration: Compatible with CPUs, NVIDIA and AMD GPUs, and Google Cloud TPUs.
Integration with popular tools: Works with Hugging Face Transformers, PyTorch, Keras, JAX, Google AI Edge, and vLLM.
ShieldGemma 2: Released alongside Gemma 3, ShieldGemma 2 is a 4B-parameter image safety classifier that flags dangerous, sexually explicit, and violent content to support responsible AI usage.
2. Setting Up the Working Environment
We will start with installing and updating all the essential libraries needed to fine-tune and run inference with Google’s Gemma-3 model using Hugging Face tools. This ensures compatibility by updating transformers, datasets, accelerate, peft, trl, and bitsandbytes to their latest versions, which are required for efficient training and memory-optimized model loading.
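Finally, we install transformers from the v4.49.0-Gemma-3 tag, a preview release that added early support for Gemma 3, so the environment is ready for both LoRA fine-tuning and generating responses from the model. Stable transformers releases from 4.50.0 onward include Gemma 3 support, so on an up-to-date environment this last step may be unnecessary.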
%%capture
!pip install -U datasets
!pip install -U transformers
!pip install -U accelerate
!pip install -U peft
!pip install -U trl
!pip install -U bitsandbytes
!pip install git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3
Next, we authenticate with the Hugging Face Hub to access pre-trained models and datasets. We retrieve the access token from Kaggle Secrets and log in using the huggingface_hub package:
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient  # only available inside Kaggle notebooks
# Read the token stored as the Kaggle secret "HUGGINGFACE_TOKEN"
user_secrets = UserSecretsClient()
hf_token = user_secrets.get_secret("HUGGINGFACE_TOKEN")
login(hf_token)
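This step lets us download models from and upload models to the Hub. Gemma models are gated, so you also need to accept the Gemma license on the model page with the same Hugging Face account before the download will work. Now we are ready to download the model and tokenizer and try it with zero-shot inference.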
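The login cell above reads the token from Kaggle Secrets. If you move the notebook between Kaggle, Colab, and a local machine, a small fallback helper can keep the rest of the code unchanged. The sketch below is hypothetical (`get_hf_token` is not part of any library): it assumes the secret is named `HUGGINGFACE_TOKEN` on Kaggle and Colab, and otherwise falls back to the `HF_TOKEN` environment variable, which `huggingface_hub` also recognizes.

```python
import os
from typing import Optional


def get_hf_token(secret_name: str = "HUGGINGFACE_TOKEN") -> Optional[str]:
    """Hypothetical helper: return a Hugging Face token from whichever
    environment the notebook runs in, or None if no token is found."""
    # 1) Kaggle notebooks expose secrets through kaggle_secrets
    try:
        from kaggle_secrets import UserSecretsClient
        return UserSecretsClient().get_secret(secret_name)
    except Exception:
        pass  # not on Kaggle, or the secret is not configured

    # 2) Colab notebooks expose secrets through google.colab.userdata
    try:
        from google.colab import userdata
        return userdata.get(secret_name)
    except Exception:
        pass  # not on Colab, or the secret is not configured

    # 3) Local machines: fall back to the HF_TOKEN environment variable
    return os.environ.get("HF_TOKEN")
```

You would then call `login(get_hf_token())`. If the helper returns None, `login` falls back to prompting for a token interactively.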
3. Loading the Model and Tokenizer
To begin working with the Gemma 3 model for instruction tuning and reasoning tasks, we first load the model and tokenizer. We use the google/gemma-3-4b-it variant, the instruction-tuned 4B-parameter version of Gemma 3, which is suitable for tasks such as question answering and multi-step reasoning fine-tuning.
The model is instantiated using Gemma3ForConditionalGeneration.from_pretrained, specifying torch_dtype=torch.bfloat16 to optimize memory usage and inference efficiency without a significant loss in numerical precision.
Setting device_map="auto" places the model's layers automatically on the available devices, typically the GPU(s), and offloads parts to CPU memory if the GPU is too small.
The attn_implementation='eager' argument selects the plain PyTorch attention implementation instead of fused kernels such as SDPA or FlashAttention. At the time of writing, transformers recommends eager attention when training Gemma 3 models, so we use it here as well.
Finally, the .eval() call puts the model in evaluation mode, which disables dropout and other training-specific behavior. It does not turn off gradient tracking; that is handled separately with torch.no_grad() during inference. The corresponding tokenizer is loaded with AutoTokenizer.from_pretrained, so the input text is tokenized to match the model's vocabulary and expected format.
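from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
import torch
GEMMA_PATH = "google/gemma-3-4b-it"
# Load the instruction-tuned Gemma 3 4B model
model = Gemma3ForConditionalGeneration.from_pretrained(
    GEMMA_PATH,
    torch_dtype=torch.bfloat16,  # bfloat16 halves weight memory compared with float32
    device_map="auto",  # place layers on the available GPU(s)/CPU automatically
    attn_implementation='eager'  # plain PyTorch attention, recommended for Gemma 3 training
).eval()  # evaluation mode: disables dropout
tokenizer = AutoTokenizer.from_pretrained(GEMMA_PATH)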
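As a rough sanity check on the bfloat16 choice, the memory needed just to hold the weights is the parameter count times the bytes per parameter. The sketch below assumes a nominal 4 billion parameters for the 4B model (the exact count differs slightly) and ignores activations, the KV cache, and optimizer state, so real usage will be higher.

```python
# Bytes used to store one parameter in each dtype
BYTES_PER_PARAM = {
    "float32": 4,
    "bfloat16": 2,
    "float16": 2,
    "int8": 1,
    "int4": 0.5,
}


def weight_memory_gb(num_params: float, dtype: str) -> float:
    """Approximate memory (decimal GB) needed to store the weights only."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


for dtype in ("float32", "bfloat16", "int4"):
    print(f"{dtype}: {weight_memory_gb(4e9, dtype):.1f} GB")
```

For 4e9 parameters this gives 16.0 GB in float32, 8.0 GB in bfloat16, and 2.0 GB at 4 bits, which is why bfloat16 (or 4-bit loading via bitsandbytes) matters on a single consumer or notebook GPU.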
Now that we have loaded the model, we will first test it without any fine-tuning.
4. Test the Model with Zero-Shot Inference
With the model and tokenizer ready, we can now run zero-shot inference on a clinical reasoning question. The input is a medically detailed prompt describing a patient with symptoms consistent with stress urinary incontinence. This kind of prompt tests the model’s ability to reason about likely diagnostic outcomes — in this case, what cystometry would reveal about post-void residual volume and detrusor activity.
The question is tokenized using the tokenizer, which converts the raw string into model-readable token IDs and attention masks. We use the return_tensors="pt" argument to get PyTorch tensors, and .to(model.device) moves the input to the same device (GPU or CPU) as the model for inference.
To generate an answer, we invoke the model’s generate method inside a torch.no_grad() block, which disables gradient tracking and reduces memory usage during inference.
The generate function takes the tokenized input and produces output token IDs, with max_new_tokens=200 capping the length of the response. Setting do_sample=False selects greedy decoding, so the model produces the same answer for the same input. The temperature=0.7 argument only affects sampling, so with sampling turned off it has no effect here (transformers may emit a warning that the flag is ignored).
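To see why temperature cannot change a greedy decode, recall that temperature divides the logits before the softmax. Dividing by a positive constant preserves the ordering of the logits, so the highest-scoring token (the one greedy decoding picks) stays the same; only the shape of the distribution used for sampling changes. The toy logits below are made up for illustration and do not come from Gemma.

```python
import math
from typing import List


def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    """Convert logits to probabilities after scaling them by 1 / temperature."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def greedy_pick(probs: List[float]) -> int:
    """Greedy decoding: index of the most probable token."""
    return max(range(len(probs)), key=lambda i: probs[i])


toy_logits = [2.0, 1.0, 0.5, -1.0]  # hypothetical scores for 4 candidate tokens
for t in (0.7, 1.0, 1.5):
    probs = softmax_with_temperature(toy_logits, t)
    print(f"T={t}: greedy token={greedy_pick(probs)}, top prob={max(probs):.3f}")
```

The greedy choice is token 0 at every temperature, while the top probability shrinks as the temperature grows; that spread only matters once do_sample=True.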
Finally, we decode the model’s output IDs back into human-readable text using tokenizer.decode, skipping any special tokens that the model might generate (like <pad> or <eos>).
The resulting answer is printed. Because no fine-tuning has been applied yet, it reflects only the model's original instruction tuning and gives us a baseline to compare against after fine-tuning.
question = "A 55-year-old woman presents with a long history of urine leakage triggered by physical activities such as laughing or lifting heavy objects. She denies experiencing nocturnal enuresis or urgency. A gynecological exam and Q-tip test are performed. Based on these findings, what would cystometry most likely reveal regarding her post-void residual volume and detrusor muscle activity?"
# Model zero-shot inference
# Tokenize the question
inputs = tokenizer(question, return_tensors="pt").to(model.device)
# Generate a response without tracking gradients
with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=False,  # greedy decoding: deterministic output
        temperature=0.7,  # only used when sampling; ignored with do_sample=False
    )
# Decode and print the response (the decoded text includes the prompt)
response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(response)
A 55-year-old woman presents with a long history of urine leakage triggered by physical activities such as laughing or lifting heavy objects. She denies experiencing nocturnal enuresis or urgency. A gynecological exam and Q-tip test are performed. Based on these findings, what would cystometry most likely reveal regarding her post-void residual volume and detrusor muscle activity?
A. Elevated post-void residual volume and decreased detrusor muscle contractility.
B. Normal post-void residual volume and normal detrusor muscle contractility.
C. Decreased post-void residual volume and increased detrusor muscle contractility.
D. Elevated post-void residual volume and increased detrusor muscle contractility.The correct answer is **A. Elevated post-void residual volume and decreased detrusor muscle contractility.**
Here’s why:
* **Patient Presentation:** The patient’s symptoms (stress urinary incontinence) strongly suggest a problem with the support structures of the urethra, rather than a primary bladder issue. The negative gynecological exam and Q-tip test further support this.
* **Stress Urinary Incontinence (SUI):** SUI is characterized by leakage during activities that increase intra-abdominal pressure. This pressure can lead to bladder overdistension and incomplete bladder emptying.
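The printed text above starts by repeating the question. That is expected: generate returns the prompt token IDs followed by the newly generated ones, and the decode call decodes the whole sequence. To print only the answer, slice off the first inputs["input_ids"].shape[1] positions before decoding. The helper below (a hypothetical name) shows the slicing with plain Python lists of made-up token IDs; the same slice works on a 1-D PyTorch tensor such as output_ids[0].

```python
from typing import Sequence


def new_tokens_only(output_ids: Sequence[int], prompt_length: int) -> Sequence[int]:
    """Drop the echoed prompt and keep only the generated continuation."""
    if prompt_length > len(output_ids):
        raise ValueError("prompt is longer than the generated sequence")
    return output_ids[prompt_length:]


# Toy example: 4 prompt tokens followed by 3 generated tokens (IDs are made up)
prompt_ids = [2, 818, 235248, 3]
full_output = prompt_ids + [1417, 603, 1]
print(new_tokens_only(full_output, len(prompt_ids)))  # [1417, 603, 1]
```

In the notebook, the equivalent call would be tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).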
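We can see that, although the model's response includes some explanation, it does not reason before answering: it commits to an answer first and justifies it afterwards. It also appends its own multiple-choice options to the question, and the explanation is formatted as bullet points, which does not match the structure and style of the target fine-tuning dataset. The chosen option is also questionable: in stress urinary incontinence, cystometry typically shows a normal post-void residual volume and no involuntary detrusor contractions.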
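One way to make "reasoning before answering" checkable is to require the chain of thought to sit inside <think>...</think> tags, as described for the training template in the next section, and to verify that a non-empty answer follows the closing tag. The function below is a hypothetical checker written for this tutorial, not part of any library; it uses only Python's re module.

```python
import re
from typing import Optional, Tuple

# Non-greedy match for the first <think>...</think> block; DOTALL lets "." span newlines
THINK_PATTERN = re.compile(r"<think>(.*?)</think>(.*)", re.DOTALL)


def split_reasoning(response: str) -> Optional[Tuple[str, str]]:
    """Return (reasoning, answer) if the response reasons before answering,
    otherwise None."""
    match = THINK_PATTERN.search(response)
    if match is None:
        return None
    reasoning, answer = match.group(1).strip(), match.group(2).strip()
    if not reasoning or not answer:
        return None  # an empty thought or a missing final answer does not count
    return reasoning, answer


print(split_reasoning("<think>\nStep 1: ...\nStep 2: ...\n</think>\nFinal answer: B"))
print(split_reasoning("The correct answer is A. Here's why: ..."))  # None
```

A zero-shot response like the one above returns None, while a response in the target format splits cleanly into its reasoning and its final answer.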
5. Loading and Processing the Dataset
When fine-tuning a language model for reasoning, formatting each training example as a structured reasoning response is an important step before training.
This section covers how we define a structured prompt format, preprocess dataset entries, and apply transformations before feeding data into the model.
To ensure the model generates structured medical responses, we define a prompt template that includes an instruction, a medical question, and a structured reasoning process. The template follows this format:
train_prompt_style = """Below is an instruction that describes a task, paired with an input that provides further context.
Write a response that appropriately completes the request.
Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.

### Instruction:
You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning.
Please answer the following medical question.

### Question:
{}

### Response:
<think>
{}
</think>
{}"""
This template structures the model’s response by incorporating a chain of thought (CoT) inside <think></think> tags. The {} placeholders are dynamically replaced with a medical question, reasoning process, and final response.
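To turn raw dataset rows into training text, each example's question, reasoning, and answer are substituted into the template, and an end-of-sequence token is appended so the model learns where to stop. The sketch below assumes a batched datasets.map-style input with columns named Question, Complex_CoT, and Response; those column names, the function name, and the "<eos>" default are assumptions for illustration, and in the notebook you would pass tokenizer.eos_token instead. The template is repeated inside the block, following the description above (three placeholders, reasoning inside closed <think></think> tags), so the snippet runs on its own.

```python
from typing import Dict, List

# Copy of the training template: question, chain of thought, final answer
TRAIN_PROMPT_STYLE = """Below is an instruction that describes a task, paired with an input that provides further context.
Write a response that appropriately completes the request.
Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.

### Instruction:
You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning.
Please answer the following medical question.

### Question:
{}

### Response:
<think>
{}
</think>
{}"""


def formatting_prompts_func(batch: Dict[str, List[str]], eos_token: str = "<eos>") -> Dict[str, List[str]]:
    """Build one training string per row of a batched example dict."""
    texts = []
    for question, cot, answer in zip(batch["Question"], batch["Complex_CoT"], batch["Response"]):
        # Fill the placeholders in order, then mark the end of the sequence
        texts.append(TRAIN_PROMPT_STYLE.format(question, cot, answer) + eos_token)
    return {"text": texts}


sample = {
    "Question": ["Example question?"],
    "Complex_CoT": ["Example step-by-step reasoning."],
    "Response": ["Example final answer."],
}
print(formatting_prompts_func(sample)["text"][0])
```

With a Hugging Face dataset, this would typically be applied as dataset.map(lambda b: formatting_prompts_func(b, tokenizer.eos_token), batched=True), which adds a "text" column for the trainer.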