发布于 2026-01-06 9 阅读
0

逐步指南,助您微调 MedGemma 以进行乳腺肿瘤分类 DEV 的全球展示挑战赛,由 Mux 呈现:展示您的项目!

逐步指南:如何微调 MedGemma 以进行乳腺肿瘤分类

由 Mux 主办的 DEV 全球展示挑战赛:展示你的项目!


免责声明:本指南仅供参考和教育用途,不能替代专业的医疗建议、诊断或治疗。MedGemma
不应在未经适当验证、调整和/或开发者针对其特定用例进行有意义的修改的情况下使用。MedGemma 生成的输出结果并非旨在直接用于临床诊断、患者管理决策、治疗建议或任何其他直接的临床实践应用。性能基准测试突出了相关基准测试中的基线能力,但即使对于构成大量训练数据的图像和文本领域,模型输出仍可能不准确。MedGemma 的所有输出结果均应视为初步结果,需要通过既定的研发方法进行独立验证、临床关联分析和进一步研究。

人工智能(AI)正在革新医疗保健行业,但如何才能让功能强大的通用AI模型掌握病理学家的专业技能呢?从原型到生产的这段旅程通常始于笔记本,而这正是我们将要开始的地方。

在本指南中,我们将迈出至关重要的第一步。我们将逐步完成对Gemma 3 变体MedGemma 的微调过程。MedGemma 是谷歌面向医学界推出的一系列开放模型,用于对乳腺癌组织病理图像进行分类。我们使用全精度 MedGemma 模型,因为这是在许多临床任务中获得最佳性能所必需的。如果您担心计算成本,可以使用MedGemma 预配置的微调 notebook进行量化和微调

为了完成第一步,我们将使用Finetune Notebook。该 Notebook 提供了所有代码以及分步操作说明,是进行实验的理想环境。我还会分享我在此过程中学到的关键见解,包括一个至关重要的数据类型选择,它最终产生了决定性的影响。

在原型设计阶段完善模型后,我们就可以进入下一步了。在接下来的文章中,我们将向您展示如何使用Cloud Run 作业将这一工作流程迁移到可扩展的、可用于生产环境的环境中。

铺垫:我们的目标、模型和数据

在深入代码之前,我们先来了解一下背景。我们的目标是将乳腺组织的显微镜图像分类为八类:四类良性(非癌性)和四类恶性(癌性)。这种分类是病理学家为做出准确诊断而执行的众多关键任务之一,而我们拥有一套强大的工具来完成这项工作。

我们将使用MedGemma,这是谷歌推出的一系列强大的开放模型,它基于与 Gemini 模型相同的研究和技术构建而成。MedGemma 的独特之处在于,它并非通用模型,而是专门针对医疗领域进行了优化。

MedGemma 的视觉组件MedSigLIP已使用大量去标识化的医学图像进行预训练,其中包括我们正在使用的组织病理切片类型。如果您不需要 MedGemma 的预测能力,可以单独使用 MedSigLIP,它是一种更经济高效的预测任务选择,例如图像分类。您可以使用多个MedSigLIP 教程笔记本进行微调。

MedGemma 语言组件也经过了各种医学文本的训练,因此我们使用的google/medgemma-4b-it版本非常适合我们基于文本的提示。Google 为 MedGemma 提供了一个强大的基础,但它需要针对特定​​用例进行微调——而这正是我们即将要做的。

为了训练我们的模型,我们将使用乳腺癌组织病理图像分类(BreakHis)数据集。BreakHis 数据集是一个公开的数据集,包含数千张乳腺肿瘤组织的显微镜图像,这些图像来自 82 位患者,并使用了不同的放大倍数(40 倍、100 倍、200 倍和 400 倍)。该数据集可公开用于非商业研究,详情请参阅论文:FA Spanhol、LS Oliveira、C. Petitjean 和 L. Heudel,《乳腺癌组织病理图像分类数据集》。 <sup> 1</sup>

处理一个包含 40 亿个参数的模型需要强大的 GPU,因此我Vertex AI Workbench中使用了配备40 GB显存的NVIDIA A100。这款 GPU 拥有足够的性能,并且还配备了NVIDIA Tensor Core,能够出色地处理现代数据格式,我们将利用这一点来加快训练速度。在后续文章中,我们将解释如何计算微调所需的显存。

我的 float16 灾难:稳定性方面的重要教训

我第一次尝试加载模型时为了节省内存使用了常见的float16数据类型,结果惨败。模型的输出完全是乱码,快速调试后发现所有内部值都变成了NaN(非数字)

罪魁祸首是经典的数值溢出

要理解其中的原因,你需要了解这些 16 位格式之间的关键区别:

  • float16 (FP16) 的数值范围非常小,无法表示大于 65,504 的任何数字。在转换器执行数百万次计算的过程中,中间值很容易超过这个限制,导致溢出并产生 NaN 值。一旦出现 NaN 值,就会污染后续的所有计算。
  • bfloat16 (BF16):这种格式由谷歌大脑开发,它做出了一个关键的权衡。它牺牲了一些精度,以保持与完整的 32 位 float32 格式相同的巨大数值范围。

bfloat16 的超大取值范围可以防止溢出,从而保证训练过程的稳定性。修复方法虽然只是简单地修改了一行代码,但却是基于这一关键概念。

成功代码:

# The simple, stable solution
model_kwargs = dict(
    torch_dtype=torch.bfloat16,  # Use bfloat16 for its wide numerical range
    device_map="auto",
    attn_implementation="sdpa",
)


model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, **model_kwargs)
Enter fullscreen mode Exit fullscreen mode

经验教训:在微调大型模型时,始终优先选择bfloat16,因为它更稳定。这个小小的改动可以让你避免很多与 NaN 值相关的麻烦。

代码详解:分步指南

现在,让我们来看代码。我会将我的Finetune Notebook分解成清晰、合乎逻辑的步骤。

第一步:设置和安装

首先,您需要从 Hugging Face 生态系统安装必要的库,然后登录您的帐户以下载模型。

# Install required packages
!pip install --upgrade --quiet transformers datasets evaluate peft trl scikit-learn

import os
import re
import torch
import gc
from datasets import load_dataset, ClassLabel
from peft import LoraConfig, PeftModel
from transformers import AutoModelForImageTextToText, AutoProcessor
from trl import SFTTrainer, SFTConfig
import evaluate
Enter fullscreen mode Exit fullscreen mode

Hugging Face 身份验证以及处理密钥的推荐方法

⚠️ 重要安全提示:切勿将 API 密钥或令牌等敏感信息直接硬编码到代码或笔记本中,尤其是在生产环境中。这种做法不安全,会造成严重的安全风险。

在 Vertex AI Workbench 中,处理密钥(例如您的 Hugging Face 令牌)最安全、企业级的方法是使用 Google Cloud 的Secret Manager

如果您只是在进行实验,暂时不想设置 Secret Manager,可以使用交互式登录小部件。该小部件会将令牌临时保存在实例的文件系统中。

# Hugging Face authentication using interactive login widget:
from huggingface_hub import notebook_login
notebook_login()
Enter fullscreen mode Exit fullscreen mode

在即将发布的文章中,我们将把这个过程迁移到 Cloud Run Jobs,并向您展示使用 Secret Manager 来处理此令牌的正确和安全方法。

步骤 2:加载并准备数据集

接下来,我们使用库从 Kaggle下载BreakHis 数据kagglehub集。该数据集包含一个Folds.csv文件,其中概述了实验的数据划分方式。原始研究使用了 5 折交叉验证,但为了控制本次演示的训练时间,我们将重点关注第一折,并且仅使用 100 倍放大倍率的图像。您可以尝试使用其他折数和放大倍率进行更广泛的实验。

! pip install -q kagglehub
import kagglehub
import os
import pandas as pd
from PIL import Image
from datasets import Dataset, Image as HFImage, Features, ClassLabel

# Download the dataset metadata
path = kagglehub.dataset_download("ambarish/breakhis")
print("Path to dataset files:", path)
folds = pd.read_csv('{}/Folds.csv'.format(path))

# Filter for 100X magnification from the first fold
folds_100x = folds[folds['mag']==100]
folds_100x = folds_100x[folds_100x['fold']==1]

# Get the train/test splits
folds_100x_test = folds_100x[folds_100x.grp=='test']
folds_100x_train = folds_100x[folds_100x.grp=='train']

# Define the base path for images
BASE_PATH = "/home/jupyter/.cache/kagglehub/datasets/ambarish/breakhis/versions/4/BreaKHis_v1"
Enter fullscreen mode Exit fullscreen mode

步骤 2.1:平衡数据集

初始的100倍放大倍率训练集和测试集划分显示良性和恶性样本数量不平衡。为了解决这个问题,我们将对训练集和测试集中数量较多的样本进行欠采样,以创建良恶性样本数量比例为50/50的平衡数据集。

# --- 1. Create Balanced TRAIN Set ---
train_benign_df = folds_100x_train[folds_100x_train['filename'].str.contains('benign')]
train_malignant_df = folds_100x_train[folds_100x_train['filename'].str.contains('malignant')]
min_train_count = min(len(train_benign_df), len(train_malignant_df))
balanced_train_benign = train_benign_df.sample(n=min_train_count, random_state=42)
balanced_train_malignant = train_malignant_df.sample(n=min_train_count, random_state=42)
balanced_train_df = pd.concat([balanced_train_benign, balanced_train_malignant])
# --- 2. Create Balanced TEST Set ---
test_benign_df = folds_100x_test[folds_100x_test['filename'].str.contains('benign')]
test_malignant_df = folds_100x_test[folds_100x_test['filename'].str.contains('malignant')]
min_test_count = min(len(test_benign_df), len(test_malignant_df))
balanced_test_benign = test_benign_df.sample(n=min_test_count, random_state=42)
balanced_test_malignant = test_malignant_df.sample(n=min_test_count, random_state=42)
balanced_test_df = pd.concat([balanced_test_benign, balanced_test_malignant])
# --- 3. Get the Final Filename Lists ---
train_filenames = balanced_train_df['filename'].values
test_filenames = balanced_test_df['filename'].values
print(f"Balanced Train: {len(train_filenames)} files")
print(f"Balanced Test: {len(test_filenames)} files")
Enter fullscreen mode Exit fullscreen mode

步骤 2.2:创建拥抱脸数据集

我们将数据转换为 Hugging Facedatasets格式,因为这是使用其 Transformers 库的最简便方法SFTTrainer。这种格式针对处理大型数据集(尤其是图像)进行了优化,因为它可以在需要时高效加载数据。此外,它还为我们提供了便捷的预处理工具,例如将我们的格式化函数应用于所有示例。

CLASS_NAMES = [
    'benign_adenosis', 'benign_fibroadenoma', 'benign_phyllodes_tumor',
    'benign_tubular_adenoma', 'malignant_ductal_carcinoma',
    'malignant_lobular_carcinoma', 'malignant_mucinous_carcinoma',
    'malignant_papillary_carcinoma'
]
def get_label_from_filename(filename):
     filename = filename.replace('\\', '/').lower()
     if '/adenosis/' in filename: return 0
     if '/fibroadenoma/' in filename: return 1
     if '/phyllodes_tumor/' in filename: return 2
     if '/tubular_adenoma/' in filename: return 3
     if '/ductal_carcinoma/' in filename: return 4
     if '/lobular_carcinoma/' in filename: return 5
     if '/mucinous_carcinoma/' in filename: return 6
     if '/papillary_carcinoma/' in filename: return 7
     return -1
train_data_dict = {
    'image': [os.path.join(BASE_PATH, f) for f in train_filenames],
    'label': [get_label_from_filename(f) for f in train_filenames]
}
test_data_dict = {
    'image': [os.path.join(BASE_PATH, f) for f in test_filenames],
    'label': [get_label_from_filename(f) for f in test_filenames]
}
features = Features({
    'image': HFImage(),
    'label': ClassLabel(names=CLASS_NAMES)
})
train_dataset = Dataset.from_dict(train_data_dict, features=features).cast_column("image", HFImage())
eval_dataset = Dataset.from_dict(test_data_dict, features=features).cast_column("image", HFImage())
print(train_dataset)
print(eval_dataset)
Enter fullscreen mode Exit fullscreen mode

步骤 3:快速工程

在这一步,我们需要告诉模型我们想要它做什么。我们创建一个清晰、结构化的提示,指示模型分析图像并仅返回与类别对应的数字。此提示使输出简洁易懂。然后,我们将此格式映射到整个数据集。

# Define the instruction prompt
PROMPT = """Analyze this breast tissue histopathology image and classify it.
​
Classes (0-7):
0: benign_adenosis
1: benign_fibroadenoma
2: benign_phyllodes_tumor
3: benign_tubular_adenoma
4: malignant_ductal_carcinoma
5: malignant_lobular_carcinoma
6: malignant_mucinous_carcinoma
7: malignant_papillary_carcinoma
​
Answer with only the number (0-7):"""

def format_data(example):
    """Format examples into the chat-style messages MedGemma expects."""
    example["messages"] = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": PROMPT},
            ],
        },
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": str(example["label"])},
            ],
        },
    ]
    return example

# Apply formatting
formatted_train = train_dataset.map(format_data, batched=False)
formatted_eval = eval_dataset.map(format_data, batched=False)

print("✓ Data formatted with instruction prompts")
Enter fullscreen mode Exit fullscreen mode

步骤 4:加载模型和处理器

在这里,我们加载 MedGemma 模型及其关联的处理器。该处理器是一个便捷的工具,用于准备模型所需的图像和文本。我们还将选择两个关键参数以提高效率:

  • torch_dtype=torch.bfloat16正如我们之前提到的,这种格式可以确保数值稳定性。
  • attn_implementation="sdpa"缩放点积注意力机制是 PyTorch 2.0 中提供的一种高度优化的注意力机制。你可以把它理解为告诉模型使用一个超快的内置引擎来执行最重要的计算。它能加速训练和推理,如果你的硬件支持,它甚至可以自动使用更高级的后端,例如 FlashAttention。
MODEL_ID = "google/medgemma-4b-it"

# Model configuration
model_kwargs = dict(
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
)

model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, **model_kwargs)
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Ensure right padding for training
processor.tokenizer.padding_side = "right"
Enter fullscreen mode Exit fullscreen mode

步骤 5:评估基线模型

在投入时间和计算资源进行微调之前,我们先来看看预训练模型自身的性能如何。这一步可以为我们提供一个基准,以便衡量我们改进的效果。

# Helper functions to run evaluation
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")

def compute_metrics(predictions, references):
    return {
        **accuracy_metric.compute(predictions=predictions, references=references),
        **f1_metric.compute(predictions=predictions, references=references, average="weighted")
    }

def postprocess_prediction(text):
    """Extract just the number from the model's text output."""
    digit_match = re.search(r'\b([0-7])\b', text.strip())
    return int(digit_match.group(1)) if digit_match else -1

def batch_predict(model, processor, prompts, images, batch_size=8, max_new_tokens=40):
    """A function to run inference in batches."""
    predictions = []
    for i in range(0, len(prompts), batch_size):
        batch_texts = prompts[i:i + batch_size]
        batch_images = [[img] for img in images[i:i + batch_size]]

        inputs = processor(text=batch_texts, images=images, padding=True, return_tensors="pt").to("cuda", torch.bfloat16)
        prompt_lengths = inputs["attention_mask"].sum(dim=1)

        with torch.inference_mode():
            outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, pad_token_id=processor.tokenizer.pad_token_id)

        for seq, length in zip(outputs, prompt_lengths):
            generated = processor.decode(seq[length:], skip_special_tokens=True)
            predictions.append(postprocess_prediction(generated))

    return predictions

# Prepare data for evaluation
eval_prompts = [processor.apply_chat_template([msg[0]], add_generation_prompt=True, tokenize=False) for msg in formatted_eval["messages"]]
eval_images = formatted_eval["image"]
eval_labels = formatted_eval["label"]

# Run baseline evaluation
print("Running baseline evaluation...")
baseline_preds = batch_predict(model, processor, eval_prompts, eval_images)
baseline_metrics = compute_metrics(baseline_preds, eval_labels)

print(f"\n{'BASELINE RESULTS':-^80}")
print(f"Accuracy: {baseline_metrics['accuracy']:.1%}")
print(f"F1 Score: {baseline_metrics['f1']:.3f}")
print("-" * 80)
Enter fullscreen mode Exit fullscreen mode

我们对基线模型在 8 类分类和二元(良性/恶性)分类上的性能进行了评估:

  • 8类分类准确率:32.6%
  • 8 类 F1 得分(加权):0.241
  • 二元准确率:59.6%
  • 二元F1评分(恶性):0.639

该输出结果表明,该模型的性能优于随机猜测(12.5%),但仍有很大的改进空间,尤其是在细粒度的 8 类分类方面。

顺便提一下:少样本学习与微调

在开始训练之前,值得问一下:微调是唯一的方法吗?另一种流行的技术是少样本学习

小样本学习就像在考试前给聪明的学生几个新数学题的例子。你不是在重新教他们代数,而是在题目中直接提供例子,向他们展示你希望他们遵循的特定模式。这是一种强大的技巧,尤其是在使用通过 API 实现的封闭模型,而无法访问内部权重时。

那么,我们为什么选择微调呢?

  1. 我们可以托管该模型:由于 MedGemma 是一个开放模型,我们可以直接访问其架构。这种访问权限使我们能够进行微调,从而创建一个新的、永久更新的模型版本。
  2. 我们拥有一个很好的数据集:微调可以让模型学习数百张训练图像中深层、潜在的模式,比仅仅在提示中向它展示几个例子要有效得多。

简而言之,微调为我们的任务创建了一个真正的专业模型,这正是我们所想要的。

步骤 6:配置并运行 LoRa 微调

这才是重头戏!我们将使用低秩自适应(LoRA),它比传统的微调方法速度更快、内存效率更高。LoRA 的工作原理是冻结原始模型权重,仅训练一小部分新的适配器权重。以下是我们的参数选择明细:

  • r=8LoRa 等级。等级越低,可训练参数越少,速度越快,但表达能力越弱。等级越高,容量越大,但在小数据集上更容易过拟合。等级 8 是一个很好的起点,兼顾了性能和效率。
  • lora_alpha=16:LoRA权重的缩放因子。一个常用的经验法则是将其设置为秩的两倍(2 × r)。
  • lora_dropout=0.1:一种正则化技术。它在训练过程中随机停用一些LoRA神经元,以防止模型过度专业化而丧失泛化能力。
# LoRA Configuration
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)

# Custom data collator to handle images and text
def collate_fn(examples):
    texts, images = [], []
    for example in examples:
        images.append([example["image"]])
        texts.append(processor.apply_chat_template(example["messages"], add_generation_prompt=False, tokenize=False).strip())
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.special_tokens_map["boi_token"])
    labels[labels == image_token_id] = -100
    labels[labels == 262144] = -100
    batch["labels"] = labels
    return batch

# Training arguments
training_args = SFTConfig(
    output_dir="medgemma-breastcancer-finetuned",
    num_train_epochs=5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,
    gradient_checkpointing=True,
    optim="paged_adamw_8bit",
    learning_rate=5e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,  # Warm up LR for first 3% of training
    max_grad_norm=0.3,  # Clip gradients to prevent instability
    bf16=True,  # Use bfloat16 precision
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    eval_strategy="epoch",
    push_to_hub=False,
    report_to="none",
    gradient_checkpointing_kwargs={"use_reentrant": False},
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False,
    label_names=["labels"], 
)

# Initialize and run the trainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=formatted_train,
    eval_dataset=formatted_eval,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
)

print("Starting training...")
trainer.train()
trainer.save_model()
Enter fullscreen mode Exit fullscreen mode

在配备 40 GB 显存的 A100 GPU 上进行训练大约耗时80 分钟。结果看起来很有希望,验证损失稳步下降。

重要提示(节省时间!):如果您的训练因任何原因(例如连接问题或超出资源限制)中断,您可以使用参数从已保存的检查点恢复训练过程resume_from_checkpointtrainer.train()检查点可以节省您宝贵的时间,因为它们会save_steps按照定义中定义的时间间隔进行保存TrainingArguments

步骤 7:最终结论——评估我们微调后的模型

训练完成后,就到了检验结果的时刻。我们将加载新的 LoRa 适配器权重,将其与基础模型合并,然后运行与基线模型相同的评估。

# Clear memory and load the final model
del model
torch.cuda.empty_cache()
gc.collect()

# Load base model again
base_model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa"
)

# Load LoRA adapters and merge them into a single model
finetuned_model = PeftModel.from_pretrained(base_model, training_args.output_dir)
finetuned_model = finetuned_model.merge_and_unload()

# Configure for generation
finetuned_model.generation_config.max_new_tokens = 50
finetuned_model.generation_config.pad_token_id = processor_finetuned.tokenizer.pad_token_id
finetuned_model.config.pad_token_id = processor_finetuned.tokenizer.pad_token_id

# Load the processor and run evaluation
processor_finetuned = AutoProcessor.from_pretrained(training_args.output_dir)
finetuned_preds = batch_predict(finetuned_model, processor_finetuned, eval_prompts, eval_images, batch_size=4)
finetuned_metrics = compute_metrics(finetuned_preds, eval_labels)
Enter fullscreen mode Exit fullscreen mode

最终结果

那么,微调对性能有何影响呢?让我们来看看8级精度和宏F1的数据。

--- 8-Class Classification (0-7) ---
Model                Accuracy     F1 (Weighted)
-----------------------------------------------
Baseline                  32.6%         0.241
Fine-tuned                87.2%         0.865
-----------------------------------------------
--- Binary (Benign/Malignant) Classification ---
Model                Accuracy     F1 (Malignant)
-----------------------------------------------
Baseline                  59.6%         0.639
Fine-tuned                99.0%         0.991
-----------------------------------------------
Enter fullscreen mode Exit fullscreen mode

结果非常棒!经过微调后,我们看到了显著的提升:

  • 8 类:准确率从 32.6% 跃升至 87.2%(+54.6%),F1 值从 0.241 跃升至 0.865。
  • 二进制:准确率从 59.6% 提高到 99.0%(+39.4%),F1 值从 0.639 提高到 0.991。

这个项目展现了微调现代基础模型的强大能力。我们选取了一个已经基于相关医学数据预训练的通用人工智能模型,为其提供一个小型、专门的数据集,并以惊人的效率教会它一项新技能。从通用模型到专门分类器的转变比以往任何时候都更加便捷,这为人工智能在医学及其他领域的应用开辟了令人兴奋的可能性。

所有信息都可以在Finetune Notebook中找到。您可以使用Vertex AI Workbench上的 GPU 实例运行它

想把它投入生产环境吗?别忘了关注即将发布的文章,文章将向您展示如何将微调和评估功能引入Cloud Run 作业

希望这篇指南对您有所帮助。祝您编程愉快!

特别感谢 MedGemma 团队的 Fereshteh Mahvar 和 Dave Steiner 对本文提出的宝贵意见和反馈。


1 IEEE生物医学工程学报,第63卷,第7期,第1455-1462页,2016年

文章来源:https://dev.to/googleai/a-step-by-step-guide-to-fine-tuning-medgemma-for-breast-tumor-classification-35af