Classification: finetune

In this notebook we illustrate how to re-train the models on the user’s own data. Specifically, we remap the last layer of the model to the desired classes without modifying the model’s internal weights; this operation is called finetuning and is much less computationally intensive than re-training the full model. Nevertheless, this module benefits greatly from GPU compute, as long as the GPU(s) support CUDA and nvidia-smi is configured correctly.

This module uses two scripts, which must be executed in this order: classification/main_prepare_learning_sets.py, which prepares the data for training, and classification/main_classification_finetune.py, which finetunes the model.

The first step is to import the necessary libraries for main_prepare_learning_sets.py:

[1]:
import argparse
import shutil
import sys
import os
from pathlib import Path

import numpy as np
import pandas as pd
import yaml

from mzbsuite.utils import cfg_to_arguments

We then declare the running parameters of the script and load the configuration file:

[2]:
# Root of the mzb-workflow repository; adjust this to your own checkout.
ROOT_DIR = Path("/Users/mivolpi/Projects/BioDetect/mzb-workflow")
MODEL = "convnext-small-vtest-1"
# Folder holding the example learning sets and the taxonomy table.
LSET_FOLD = ROOT_DIR.absolute() / "data" / "mzb_example_data"

# Paths used by the data-preparation step and by the finetuning step below.
arguments = {
    "input_dir": LSET_FOLD / "curated_learning_sets",
    "taxonomy_file": LSET_FOLD / "MZB_taxonomy.csv",
    "output_dir": LSET_FOLD / "agg_lsets",
    "save_model": ROOT_DIR.absolute() / f"models/mzb-classification-models/{MODEL}",
    "config_file": ROOT_DIR.absolute() / "configs/mzb_example_config.yaml",
}

# The configuration is a plain YAML mapping of scalars and lists, so the safe
# loader is sufficient (it never constructs arbitrary Python objects).
with open(arguments["config_file"], "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# This data-preparation step does not benefit from GPU compute, so clear the GPU ids.
# NOTE(review): trcl_gpu_ids is not read by any cell of this notebook (the
# Trainer below uses accelerator="auto"); confirm whether it has any effect.
cfg["trcl_gpu_ids"] = None
cfg
[2]:
{'glob_random_seed': 222,
 'glob_root_folder': '/Users/mivolpi/Projects/BioDetect/mzb-workflow',
 'glob_blobs_folder': '/Users/mivolpi/Projects/BioDetect/mzb-workflow/data/derived/blobs/',
 'glob_local_format': 'pdf',
 'model_logger': 'wandb',
 'impa_image_format': 'jpg',
 'impa_clip_areas': [2700, 4700, -1, -1],
 'impa_area_threshold': 5000,
 'impa_gaussian_blur': [21, 21],
 'impa_gaussian_blur_passes': 3,
 'impa_adaptive_threshold_block_size': 351,
 'impa_mask_postprocess_kernel': [11, 11],
 'impa_mask_postprocess_passes': 5,
 'impa_bounding_box_buffer': 200,
 'impa_save_clips_plus_features': True,
 'lset_class_cut': 'order',
 'lset_val_size': 0.1,
 'trcl_learning_rate': 0.0001,
 'trcl_batch_size': 8,
 'trcl_weight_decay': 0,
 'trcl_step_size_decay': 5,
 'trcl_number_epochs': 75,
 'trcl_save_topk': 1,
 'trcl_num_classes': 8,
 'trcl_model_pretrarch': 'convnext-small',
 'trcl_num_workers': 16,
 'trcl_wandb_project_name': 'mzb-classifiers',
 'trcl_logger': 'wandb',
 'trsk_learning_rate': 0.001,
 'trsk_batch_size': 32,
 'trsk_weight_decay': 0,
 'trsk_step_size_decay': 25,
 'trsk_number_epochs': 400,
 'trsk_save_topk': 1,
 'trsk_num_classes': 2,
 'trsk_model_pretrarch': 'mit_b2',
 'trsk_num_workers': 16,
 'trsk_wandb_project_name': 'mzb-skeletons',
 'trsk_logger': 'wandb',
 'infe_model_ckpt': 'last',
 'infe_num_classes': 8,
 'infe_image_glob': '*_rgb.jpg',
 'skel_class_exclude': 'errors',
 'skel_conv_rate': 131.6625,
 'skel_label_thickness': 3,
 'skel_label_buffer_on_preds': 25,
 'skel_label_clip_with_mask': False,
 'trcl_gpu_ids': None}

Convert these dictionaries into argument objects whose entries can be accessed as attributes (e.g. `cfg.glob_random_seed`):

[3]:
# Wrap both plain dictionaries so their entries are reachable as attributes
# (args.input_dir, cfg.lset_class_cut, ...), like argparse results.
args, cfg = cfg_to_arguments(arguments), cfg_to_arguments(cfg)
print(cfg)
{'glob_random_seed': 222, 'glob_root_folder': '/Users/mivolpi/Projects/BioDetect/mzb-workflow', 'glob_blobs_folder': '/Users/mivolpi/Projects/BioDetect/mzb-workflow/data/derived/blobs/', 'glob_local_format': 'pdf', 'model_logger': 'wandb', 'impa_image_format': 'jpg', 'impa_clip_areas': [2700, 4700, -1, -1], 'impa_area_threshold': 5000, 'impa_gaussian_blur': [21, 21], 'impa_gaussian_blur_passes': 3, 'impa_adaptive_threshold_block_size': 351, 'impa_mask_postprocess_kernel': [11, 11], 'impa_mask_postprocess_passes': 5, 'impa_bounding_box_buffer': 200, 'impa_save_clips_plus_features': True, 'lset_class_cut': 'order', 'lset_val_size': 0.1, 'trcl_learning_rate': 0.0001, 'trcl_batch_size': 8, 'trcl_weight_decay': 0, 'trcl_step_size_decay': 5, 'trcl_number_epochs': 75, 'trcl_save_topk': 1, 'trcl_num_classes': 8, 'trcl_model_pretrarch': 'convnext-small', 'trcl_num_workers': 16, 'trcl_wandb_project_name': 'mzb-classifiers', 'trcl_logger': 'wandb', 'trsk_learning_rate': 0.001, 'trsk_batch_size': 32, 'trsk_weight_decay': 0, 'trsk_step_size_decay': 25, 'trsk_number_epochs': 400, 'trsk_save_topk': 1, 'trsk_num_classes': 2, 'trsk_model_pretrarch': 'mit_b2', 'trsk_num_workers': 16, 'trsk_wandb_project_name': 'mzb-skeletons', 'trsk_logger': 'wandb', 'infe_model_ckpt': 'last', 'infe_num_classes': 8, 'infe_image_glob': '*_rgb.jpg', 'skel_class_exclude': 'errors', 'skel_conv_rate': 131.6625, 'skel_label_thickness': 3, 'skel_label_buffer_on_preds': 25, 'skel_label_clip_with_mask': False, 'trcl_gpu_ids': None}

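We next create the output directory and check that it does not already contain trn_set or val_set subfolders; if it does, the script stops, so that an existing learning set is never overwritten: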

[4]:
# Seed NumPy's global RNG so the random train/validation split below is reproducible.
np.random.seed(cfg.glob_random_seed)

# root of the curated (raw) clip data
root_data = Path(args.input_dir)
outdir = Path(args.output_dir)
outdir.mkdir(parents=True, exist_ok=True)

# target folders definition
target_trn = outdir / "trn_set"
target_val = outdir / "val_set"

# Stop if a previous learning set is already present in the output directory,
# so that nothing is overwritten or mixed; the user must pick another directory.
# (outdir itself always exists here, since it was just created above.)
existing_sets = [p.name for p in (target_trn, target_val) if p.exists()]
if existing_sets:
    raise ValueError(
        # ANSI escapes print the message in red, then reset to normal
        f"\033[91m Output directory {outdir} already contains {', '.join(existing_sets)}. "
        "Please specify a different output directory.\033[0m"
    )

We now use the taxonomic rank given by the lset_class_cut parameter of the configuration file to cut the provided phylogenetic tree, and reorganize the images into directories corresponding to this rank. See the documentation for further details.

[5]:

# Build the recoding dictionary: the key is the current (curated) class name
# from the "query" column, the value is the target class at the taxonomic rank
# named by cfg.lset_class_cut (lower-cased).
mzb_taxonomy = pd.read_csv(Path(args.taxonomy_file))
# Drop the unnamed column pandas produces when a CSV carries a written-out
# index (presumably from DataFrame.to_csv with index=True).
if "Unnamed: 0" in mzb_taxonomy.columns:
    mzb_taxonomy = mzb_taxonomy.drop(columns=["Unnamed: 0"])

# Forward-fill along each row so that an empty rank cell takes the last valid
# entry to its left (assumes rank columns are ordered coarse -> fine).
mzb_taxonomy = mzb_taxonomy.ffill(axis=1)

if cfg.lset_class_cut not in mzb_taxonomy.columns:
    raise KeyError(
        f"lset_class_cut={cfg.lset_class_cut!r} is not a column of "
        f"{args.taxonomy_file}; available columns: {list(mzb_taxonomy.columns)}"
    )

recode_order = dict(
    zip(mzb_taxonomy["query"], mzb_taxonomy[cfg.lset_class_cut].str.lower())
)
print(f"Cutting phylogenetic tree at: {cfg.lset_class_cut}")
Cutting phylogenetic tree at: order

Now we copy the images over into the new folder structure according to the taxonomy:

[6]:
# Copy every file of each curated source folder into the training folder of
# the class it is recoded to; several source folders may share one class.
for source_name, class_name in recode_order.items():
    class_dir = target_trn / class_name
    class_dir.mkdir(parents=True, exist_ok=True)

    source_files = list((root_data / source_name).glob("*"))
    for src in source_files:
        shutil.copy(src, class_dir)

Next, we move a random subset of each class’s images into the validation set; its size is the proportion of images specified by the lset_val_size parameter in the configuration file (at least one image per class). We recommend using at least 10% of the images of each class for validation.

[7]:
# Move a random subset of each class's training images into the validation set.
# Each class contributes ceil(lset_val_size * n_images) files, at least one.
size = cfg.lset_val_size
trn_folds = [a.name for a in sorted(target_trn.glob("*"))]

for s_fo in trn_folds:
    target_folder = target_val / s_fo
    target_folder.mkdir(exist_ok=True, parents=True)

    # Sort so that, for a given seed, the drawn files do not depend on the
    # filesystem's directory listing order.
    list_class = sorted((target_trn / s_fo).glob("*"))
    if not list_class:
        print(f"Class {s_fo} has no training images; nothing moved to validation.")
        continue

    n_val_sam = min(len(list_class), max(1, int(np.ceil(size * len(list_class)))))
    if n_val_sam == len(list_class):
        print(f"Warning: all {len(list_class)} image(s) of class {s_fo} go to validation.")

    # Sample WITHOUT replacement: with replacement the same file can be drawn
    # twice, and the second move then fails because the file is already gone.
    val_idx = np.random.choice(len(list_class), n_val_sam, replace=False)

    for idx in val_idx:
        file = list_class[idx]
        try:
            shutil.move(str(file), target_folder)
        except (shutil.Error, OSError) as err:
            print(f"Could not move {file} into {target_folder}: {err}")
/Users/mivolpi/Projects/BioDetect/mzb-workflow/data/mzb_example_data/agg_lsets/trn_set/errors/32_hf2_protonemura_01_clip_11_rgb.png into /Users/mivolpi/Projects/BioDetect/mzb-workflow/data/mzb_example_data/agg_lsets/val_set/errors
/Users/mivolpi/Projects/BioDetect/mzb-workflow/data/mzb_example_data/agg_lsets/trn_set/errors/31_b1_isoperla_01_clip_11_rgb.png into /Users/mivolpi/Projects/BioDetect/mzb-workflow/data/mzb_example_data/agg_lsets/val_set/errors
/Users/mivolpi/Projects/BioDetect/mzb-workflow/data/mzb_example_data/agg_lsets/trn_set/errors/32_hf2_plecoptera_01_clip_5_rgb.png into /Users/mivolpi/Projects/BioDetect/mzb-workflow/data/mzb_example_data/agg_lsets/val_set/errors

Now we have the training dataset ready for model training, with a training set and a validation set containing the same classes.

We move on to the model finetuning, using the script classification/main_classification_finetune.py. First we import some additional libraries (PyTorch and PyTorch Lightning):

[8]:
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from pytorch_lightning.strategies.ddp import DDPStrategy

from mzbsuite.classification.mzb_classification_pilmodel import MZBModel
from mzbsuite.utils import cfg_to_arguments, SaveLogCallback

# Select the GNU OpenMP threading layer for Intel MKL. This only concerns CPU
# threading and does not enable GPU compute; it is typically set to avoid
# MKL/libgomp conflicts when PyTorch starts worker processes.
os.environ["MKL_THREADING_LAYER"] = "GNU"

Before we can launch the training, we need to define a few additional settings: the folder holding the training data, the callbacks that save checkpoints and monitor the training progress over time, and the model itself:

[9]:
# Finetuning reads its data from the aggregated learning-set folder built above.
setattr(args, "input_dir_tr", args.output_dir)
[10]:
# Checkpoint callbacks; both write into the model folder args.save_model.
# Best model on validation: keep the `trcl_save_topk` checkpoints with the
# lowest validation loss.
best_val_cb = pl.callbacks.ModelCheckpoint(
    dirpath=args.save_model,
    filename="best-val-{epoch}-{step}-{val_loss:.1f}",
    monitor="val_loss",
    mode="min",
    save_top_k=cfg.trcl_save_topk,
)

# Latest model in training, saved every 50 training steps. This callback
# monitors no metric, so it keeps only the most recent checkpoint
# (save_top_k=1). It must not follow `trcl_save_topk`, because Lightning rejects
# save_top_k > 1 when `monitor` is None.
last_mod_cb = pl.callbacks.ModelCheckpoint(
    dirpath=args.save_model,
    filename="last-{step}",
    every_n_train_steps=50,
    save_top_k=1,
)

# Progress bar, refreshed every 5 batches
pbar_cb = pl.callbacks.progress.TQDMProgressBar(refresh_rate=5)

# Project callback writing into the model folder; intended to log the training
# date. NOTE(review): confirm exact behaviour against SaveLogCallback.
trdatelog = SaveLogCallback(model_folder=args.save_model)

# Define the model from the trcl_* entries of the configuration
model = MZBModel(
    data_dir=args.input_dir_tr,
    pretrained_network=cfg.trcl_model_pretrarch,
    learning_rate=cfg.trcl_learning_rate,
    batch_size=cfg.trcl_batch_size,
    weight_decay=cfg.trcl_weight_decay,
    num_workers_loader=cfg.trcl_num_workers,
    step_size_decay=cfg.trcl_step_size_decay,
    num_classes=cfg.trcl_num_classes,
)

We now check whether a previously trained model is available in the model folder and, if so, load its weights. We then set up the logger: training progress is logged either to Weights & Biases (which requires an account) or to TensorBoard, as selected by the trcl_logger parameter. See the documentation for more details.

[11]:
# Check if there is a model to load; if there is, load it and train from there.
if args.save_model.is_dir():
    # NOTE(review): `verbose` is not one of the `arguments` defined at the top of
    # this notebook, so default to False instead of raising AttributeError
    # (assumes cfg_to_arguments objects raise AttributeError for missing keys).
    if getattr(args, "verbose", False):
        print(f"Loading model from {args.save_model}")

    # When several checkpoints match, take the most recently written one;
    # raw glob order is filesystem-dependent.
    checkpoints = sorted(
        args.save_model.glob("last-*.ckpt"), key=lambda p: p.stat().st_mtime
    )
    if not checkpoints:
        print("No last-* model in folder, loading best model")
        checkpoints = sorted(
            args.save_model.glob("best-val-epoch=*-step=*-val_loss=*.*.ckpt"),
            key=lambda p: p.stat().st_mtime,
        )

    if checkpoints:
        fmodel = checkpoints[-1]
        model = model.load_from_checkpoint(fmodel)
    else:
        # The folder can exist without checkpoints (e.g. the TensorBoard logger
        # below uses it as save_dir); keep the freshly built model instead of crashing.
        print(f"No checkpoint found in {args.save_model}, training from scratch")

# Define logger and name of run
name_run = f"classifier-{cfg.trcl_model_pretrarch}"
cbacks = [pbar_cb, best_val_cb, last_mod_cb, trdatelog]

# Define logger, and use either wandb or tensorboard
if cfg.trcl_logger == "wandb":
    logger = WandbLogger(project=cfg.trcl_wandb_project_name, name=name_run)
    # log="all" asks W&B to track both gradients and parameters of the model
    logger.watch(model, log="all")

elif cfg.trcl_logger == "tensorboard":
    logger = TensorBoardLogger(
        save_dir=args.save_model,
        name=name_run,
        log_graph=True,
    )

else:
    # Fail here rather than with a NameError on `logger` when the Trainer is built.
    raise ValueError(
        f"Unknown trcl_logger {cfg.trcl_logger!r}; expected 'wandb' or 'tensorboard'."
    )
wandb: Currently logged in as: mivolpi. Use `wandb login --relogin` to force relogin
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
Tracking run with wandb version 0.19.4
Run data is saved locally in ./wandb/run-20250122_162849-ed2na7k7
wandb: logging graph, to disable use `wandb.watch(log_graph=False)`

We are now finally ready to train our model!

[12]:

# Instantiate the trainer and start finetuning.
trainer = pl.Trainer(
    accelerator="auto",  # uses GPU(s) when available, else CPU (replaces cfg.trcl_num_gpus)
    max_epochs=cfg.trcl_number_epochs,
    strategy="ddp_notebook",
    # "16-mixed" replaces the deprecated precision=16; on CPU, Lightning falls
    # back to bf16-mixed automatically.
    precision="16-mixed",
    callbacks=cbacks,
    logger=logger,
    log_every_n_steps=1,
    # profiler="simple",
)
trainer.fit(model)
/Users/mivolpi/micromamba/envs/str-mzb/lib/python3.10/site-packages/lightning_fabric/connector.py:572: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
/Users/mivolpi/micromamba/envs/str-mzb/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:513: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=gloo
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

[W122 16:28:50.686888000 ProcessGroupGloo.cpp:745] Warning: Unable to resolve hostname to a (local) address. Using the loopback address as fallback. Manually set the network interface to bind to with GLOO_SOCKET_IFNAME. (function operator())

  | Name     | Type              | Params | Mode
-------------------------------------------------------
0 | model    | ConvNeXt          | 49.5 M | train
1 | accuracy | MulticlassF1Score | 0      | train
-------------------------------------------------------
6.2 K     Trainable params
49.5 M    Non-trainable params
49.5 M    Total params
197.843   Total estimated model params size (MB)
384       Modules in train mode
0         Modules in eval mode
Validation set size: 45
/Users/mivolpi/micromamba/envs/str-mzb/lib/python3.10/site-packages/torch/utils/data/dataloader.py:617: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 8 (`cpuset` is not taken into account), which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(
[rank0]:[W122 16:28:52.059777000 NNPACK.cpp:61] Could not initialize NNPACK! Reason: Unsupported hardware.
Training set size: 385
Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.
Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.
Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.
Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.

Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.
Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.
Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.
Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.
Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.
Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.
Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.
Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.
Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.
Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.
Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.
Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.

Detected KeyboardInterrupt, attempting graceful shutdown ...
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
File ~/micromamba/envs/str-mzb/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:46, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     45 if trainer.strategy.launcher is not None:
---> 46     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     47 return trainer_fn(*args, **kwargs)

File ~/micromamba/envs/str-mzb/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py:144, in _MultiProcessingLauncher.launch(self, function, trainer, *args, **kwargs)
    143 self.procs = process_context.processes
--> 144 while not process_context.join():
    145     pass

File ~/micromamba/envs/str-mzb/lib/python3.10/site-packages/torch/multiprocessing/spawn.py:132, in ProcessContext.join(self, timeout)
    131 # Wait for any process to fail or all of them to succeed.
--> 132 ready = multiprocessing.connection.wait(
    133     self.sentinels.keys(),
    134     timeout=timeout,
    135 )
    137 error_index = None

File ~/micromamba/envs/str-mzb/lib/python3.10/multiprocessing/connection.py:931, in wait(object_list, timeout)
    930 while True:
--> 931     ready = selector.select(timeout)
    932     if ready:

File ~/micromamba/envs/str-mzb/lib/python3.10/selectors.py:416, in _PollLikeSelector.select(self, timeout)
    415 try:
--> 416     fd_event_list = self._selector.poll(timeout)
    417 except InterruptedError:

KeyboardInterrupt:

During handling of the above exception, another exception occurred:

NameError                                 Traceback (most recent call last)
Cell In[12], line 15
      3 # instantiate trainer and train
      4 trainer = pl.Trainer(
      5     accelerator="auto",  # cfg.trcl_num_gpus outdated
      6     max_epochs=cfg.trcl_number_epochs,
   (...)
     12     # profiler="simple",
     13 )
---> 15 trainer.fit(model)

File ~/micromamba/envs/str-mzb/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:539, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    537 self.state.status = TrainerStatus.RUNNING
    538 self.training = True
--> 539 call._call_and_handle_interrupt(
    540     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    541 )

File ~/micromamba/envs/str-mzb/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:64, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     62     if isinstance(launcher, _SubprocessScriptLauncher):
     63         launcher.kill(_get_sigkill_signal())
---> 64     exit(1)
     66 except BaseException as exception:
     67     _interrupt(trainer, exception)

NameError: name 'exit' is not defined
[ ]:

[ ]: