"""SAMv2 model for 2D/3D tomogram segmentation with a prompt predictor for automated segmentation.
Code is based on the original SAM2 training code from https://github.com/facebookresearch/sam2/blob/main/training/model/sam2.py.
"""
import logging
from pathlib import Path
from typing import Any, Literal
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from omegaconf import OmegaConf
from sam2.modeling.sam2_base import SAM2Base
from torch import Tensor
from torch.optim import Optimizer
from cryovit.config import BaseModel as BaseModelConfig
from cryovit.models.base_model import BaseModel
from cryovit.models.sam2_blocks import LoRAMaskDecoderFactory, PromptPredictor
from cryovit.types import BatchedTomogramData
# Clear SAM2 hydra initialization if it exists.
# NOTE(review): presumably the `sam2` import above is what leaves a global
# Hydra instance initialized; clearing it runs once, at import time of this
# module, and only when such an instance is actually present.
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()
## Pre-trained Model Weights ##
# Each entry is a pair of (Hugging Face repo id, {"config": file name,
# "weights": file name}). `_download_model_weights` resolves the file names
# relative to the local SAM directory and downloads the repo when either file
# is missing.
sam2_model = (
    "facebook/sam2.1-hiera-tiny",
    {"config": "sam2.1_hiera_t.yaml", "weights": "sam2.1_hiera_tiny.pt"},
)  # the tiny variant of SAMv2.1
medical_sam2_model = (
    "wanglab/MedSAM2",
    {"config": "sam2.1_hiera_t.yaml", "weights": "MedSAM2_latest.pt"},
)  # SAMv2 fine-tuned on medical data
# Temporary maximum depth (number of slices) accepted by the SAMv2 wrapper;
# deeper inputs are truncated to this depth in `SAM2.forward`
# (due to implementation error in CUDA - https://github.com/pytorch/pytorch/issues/142228)
MAX_SAM_DEPTH = 255
# a large negative value as a placeholder score for missing objects
NO_OBJ_SCORE = -1024.0
[docs]
class SAM2(BaseModel):
"""Lightning wrapper over the SAM2 model."""
[docs]
def __init__(
    self, sam_model: "SAM2Train", custom_kwargs, **kwargs
) -> None:
    """Build the wrapped SAM2 network together with its prompt predictor.

    Args:
        sam_model: Callable that returns the SAM2 network when invoked with
            the remaining entries of ``custom_kwargs`` as keyword arguments.
        custom_kwargs: Mapping of extra options. ``prompt_lr`` (default
            ``3e-05``) is stored on the wrapper as ``self.prompt_lr``; every
            other entry is forwarded to ``sam_model``.
        **kwargs: Forwarded unchanged to ``BaseModel.__init__``.
    """
    super().__init__(**kwargs)
    # Work on a shallow copy: the caller's mapping (possibly a shared config
    # object) must keep its ``prompt_lr`` entry after construction.
    sam_kwargs = dict(custom_kwargs)
    self.prompt_lr = sam_kwargs.pop("prompt_lr", 3e-05)
    self.model = sam_model(**sam_kwargs)
    self.prompt_predictor = PromptPredictor()
    self.freeze_parameters()
    self.log_masks = False  # whether to log predicted masks during training for debugging
[docs]
def freeze_parameters(self):
    """Turn off gradients for the image encoder, prompt encoder, memory encoder and memory attention.

    Every other parameter of the wrapper (e.g. the prompt predictor and the
    mask decoder) is left untouched.
    """
    frozen_submodules = (
        "image_encoder",
        "sam_prompt_encoder",
        "memory_encoder",
        "memory_attention",
    )
    for submodule_name in frozen_submodules:
        for param in getattr(self.model, submodule_name).parameters():
            param.requires_grad = False
def _masked_predict(
    self,
    batch: BatchedTomogramData,  # type: ignore
    use_mito_mask: bool = False,
) -> dict[str, Tensor]:
    """Run the model and keep only the voxels that carry a valid label.

    Voxels whose label is greater than -1 count as valid. When
    ``use_mito_mask`` is set, the selection is further restricted to voxels
    where ``batch.aux_data["labels/mito"][0]`` is positive.

    Returns:
        A dict with the selected predictions (``"preds"``), selected prompt
        masks (``"masks"``) and selected labels (``"labels"``), each reshaped
        to ``(-1, 1)``, plus the unmasked model outputs under
        ``"preds_full"`` and ``"masks_full"``.
    """
    outputs = self(batch)
    labels = batch.labels  # (B, D, H, W)
    full_preds = outputs["preds"]
    full_prompts = outputs["prompts"]

    valid = (labels > -1.0).detach()
    if use_mito_mask:
        assert (
            batch.aux_data is not None and "labels/mito" in batch.aux_data
        ), "Batch aux_data must contain 'labels/mito' key for mito masking."
        # assumes eval code with batch size of 1
        mito = torch.from_numpy(batch.aux_data["labels/mito"][0]) > 0
        valid = valid & mito.to(dtype=valid.dtype, device=valid.device)

    def _select(volume: Tensor) -> Tensor:
        # Flatten the valid voxels into a single column.
        return torch.masked_select(volume, valid).view(-1, 1)

    return {
        "preds": _select(full_preds),
        "masks": _select(full_prompts),
        "labels": _select(labels),
        "preds_full": full_preds,
        "masks_full": full_prompts,
    }
def _do_step(
    self,
    batch: BatchedTomogramData,  # type: ignore
    batch_idx: int,
    prefix: Literal["train", "val", "test"],
) -> Tensor:
    """Compute the losses for one step, including the prompt predictor's dice loss.

    The segmentation losses come from ``compute_losses`` on the masked
    predictions; the ``"dice_loss"`` of the masked prompt masks is stored as
    ``"mask_loss"`` and added to ``"total"``. Metrics for ``prefix`` are
    updated and the losses are logged. When training with ``log_masks``
    enabled, the middle slice of the first tomogram is logged to wandb.

    Returns:
        The total loss (segmentation total plus prompt mask loss).
    """
    predictions = self._masked_predict(batch)
    y_pred = predictions["preds"]
    y_true = predictions["labels"]

    losses = self.compute_losses(y_pred, y_true)
    prompt_losses = self.compute_losses(predictions["masks"], y_true)
    losses["mask_loss"] = prompt_losses["dice_loss"]
    losses["total"] = losses["total"] + losses["mask_loss"]

    for metric_fn in self.metric_fns[prefix.upper()].values():  # type: ignore
        metric_fn(y_pred, y_true)
    self.log_stats(losses, prefix, batch.num_tomos)

    if not (self.training and self.log_masks):
        return losses["total"]

    import wandb

    # debug logging predicted masks: middle slice of the first tomogram
    mid = batch.num_slices // 2

    def _mid_slice(volume: Tensor) -> Tensor:
        return volume[0, mid].detach().cpu().unsqueeze(0)

    raw_image = batch.tomo_batch[0, mid].detach().cpu()[[0]]
    pred_image = _mid_slice(predictions["preds_full"])
    mask_image = _mid_slice(predictions["masks_full"])
    prompt_image = (mask_image > 0.9).float()
    panels = [raw_image, pred_image, mask_image, prompt_image]
    combined_image = torch.cat(panels, dim=-1)  # concat along width
    wandb.log(
        {
            "raw-pred-mask-prompt": [
                wandb.Image(
                    combined_image,
                    caption="Raw | Pred | Mask | Prompt",
                )
            ]
        },
        commit=False,
    )
    return losses["total"]
[docs]
def forward(self, data: BatchedTomogramData) -> dict[str, Tensor]:  # type: ignore
    """Predict prompts from image features and segment the batch with SAM2.

    ``data`` is modified in place: a single-channel ``tomo_batch`` is
    expanded to 3 channels, and inputs deeper than ``MAX_SAM_DEPTH`` are
    truncated (``tomo_batch``, ``tomo_sizes`` and ``min_slices``). If the
    spatial size differs from ``self.model.image_size`` the batch is resized
    for the duration of the call and the un-resized tensor is put back
    afterwards.

    Args:
        data: Batch whose ``tomo_batch`` is laid out as [B, D, C, H, W].

    Returns:
        ``{"preds": ..., "prompts": ...}``: the SAM2 predictions and the
        (non-binarized) predicted mask prompts. Both are zero-padded along
        depth if the input was truncated and resized back to the input's
        (H, W) if it was resized.
    """
    C, H, W = data.tomo_batch.shape[-3:]  # channels, height, width
    truncate_size = 0

    # Expand channels for expected RGB input
    if C == 1:
        data.tomo_batch = data.tomo_batch.expand(-1, -1, 3, -1, -1)
        C = 3

    # Truncate if too many slices
    do_truncate = data.num_slices > MAX_SAM_DEPTH
    do_resize = self.model.image_size != H or self.model.image_size != W
    if do_truncate:
        logging.warning(
            "Truncating input tomogram from %d to %d slices for SAM2 model.",
            data.num_slices,
            MAX_SAM_DEPTH,
        )
        truncate_size = data.num_slices - MAX_SAM_DEPTH
        data.tomo_batch = data.tomo_batch[:, :MAX_SAM_DEPTH]
        data.tomo_sizes = torch.clamp(data.tomo_sizes, max=MAX_SAM_DEPTH)
        data.min_slices = min(data.min_slices, MAX_SAM_DEPTH)

    # Keep a handle on the un-resized batch so it can be restored exactly,
    # instead of interpolating the resized copy back (lossy and wasteful).
    unresized_batch = data.tomo_batch
    if do_resize:
        # Resize the input tomogram batch to the target size. The trilinear
        # call sees (C, H, W) as the spatial dims; C is kept unchanged.
        data.tomo_batch = F.interpolate(
            data.tomo_batch,
            size=(C, self.model.image_size, self.model.image_size),
            mode="trilinear",
            align_corners=False,
        )

    flat_tensor = data.flat_tomo_batch  # [BxDxCxHxW] -> [[BxD]xCxHxW]
    backbone_out = self.model.image_encoder(flat_tensor)
    flat_box_prompts, flat_mask_prompts = self.prompt_predictor(
        backbone_out["backbone_fpn"][0], num_batches=data.num_tomos
    )  # flat tensor form
    binary_flat_mask_prompts = (
        flat_mask_prompts > 0.9
    ).bool()  # binarize the mask prompts for SAM2 input conservatively
    preds = self.model(
        data, flat_box_prompts, binary_flat_mask_prompts
    )  # forward pass through SAM2
    masks = flat_mask_prompts.view(
        data.num_tomos, data.num_slices, *flat_mask_prompts.shape[-2:]
    )  # reshape to [B, D, H, W]
    out = {"preds": preds, "prompts": masks}

    # Pad outputs if truncated (zeros appended at the end of the depth axis)
    if do_truncate:
        pad_size = (0, 0, 0, 0, 0, truncate_size)
        for k in out:
            out[k] = F.pad(out[k], pad_size, mode="constant", value=0)

    if do_resize:
        # Upsample the output to the original size
        for k in out:
            out[k] = F.interpolate(
                out[k], size=(H, W), mode="bilinear", align_corners=False
            )
        # Restore the input tomogram batch at its original spatial size
        data.tomo_batch = unresized_batch
    return out
[docs]
def load_sam_state_dict(
self,
state_dict: dict[str, Tensor],
strict: bool = False,
assign: bool = True,
) -> tuple:
"""Override load_state_dict to handle loading of SAM2 weights."""
return self.model.load_state_dict(
state_dict, strict=strict, assign=assign
)
[docs]
def compile(self) -> None:
    """Wrap the ``forward`` of the image encoder, memory encoder and memory attention with ``torch.compile``."""
    for submodule_name in (
        "image_encoder",
        "memory_encoder",
        "memory_attention",
    ):
        submodule = getattr(self.model, submodule_name)
        submodule.forward = torch.compile(submodule.forward)
class SAM2Train(SAM2Base):
"""SAMv2 model implementation."""
def __init__(
    self,
    image_encoder: nn.Module,
    memory_attention: nn.Module,
    memory_encoder: nn.Module,
    num_init_cond_slices: tuple[int, int] = (1, 1),
    rand_init_cond_slices: tuple[bool, bool] = (True, False),
    **kwargs,
) -> None:
    """Set up the SAM2 base model and the conditioning-slice settings.

    Args:
        image_encoder: Passed to ``SAM2Base.__init__``.
        memory_attention: Passed to ``SAM2Base.__init__``.
        memory_encoder: Passed to ``SAM2Base.__init__``.
        num_init_cond_slices: Number of initial conditioning slices as a
            ``(training, evaluation)`` pair.
        rand_init_cond_slices: Whether that number is drawn at random, as a
            ``(training, evaluation)`` pair.
        **kwargs: Remaining options forwarded to ``SAM2Base.__init__``.
    """
    super().__init__(
        image_encoder,
        memory_attention,
        memory_encoder,
        **kwargs,
    )
    # Index 0 is read while training, index 1 otherwise
    # (see `prepare_prompt_inputs`).
    self.num_init_cond_slices, self.rand_init_cond_slices = (
        num_init_cond_slices,
        rand_init_cond_slices,
    )
def _apply_lora_to_mask_decoder(
    self, lora_r: int = 128, lora_alpha: int = 128
):
    """Replace ``self.sam_mask_decoder`` with its LoRA-adapted version.

    This is deliberately a separate step so it can run after the pre-trained
    weights have been loaded into the plain mask decoder.

    Args:
        lora_r: ``lora_r`` value handed to ``LoRAMaskDecoderFactory``.
        lora_alpha: ``lora_alpha`` value handed to
            ``LoRAMaskDecoderFactory``. The default keeps alpha equal to r.
    """
    decoder_factory = LoRAMaskDecoderFactory(
        lora_r=lora_r, lora_alpha=lora_alpha
    )
    self.sam_mask_decoder = decoder_factory.apply(self.sam_mask_decoder)
def forward(
    self, data: BatchedTomogramData, box_prompts, mask_prompts  # type: ignore
) -> dict[str, Any] | Tensor:
    """Segment a batch, conditioning the tracker on the middle slice.

    Image features are computed for the flattened batch, the prompts are
    attached starting at slice ``data.num_slices // 2``, and every slice is
    then tracked. A tensor result is passed through a sigmoid; a dict result
    is returned unchanged.
    """
    features = self.forward_image(data.flat_tomo_batch)
    features = self.prepare_prompt_inputs(
        features,
        box_prompts,
        mask_prompts,
        data,
        start_slice_idx=data.num_slices // 2,
    )
    tracked = self.forward_tracking(features, data)
    return tracked if isinstance(tracked, dict) else torch.sigmoid(tracked)
def prepare_prompt_inputs(
    self,
    backbone_out: dict[str, Any],
    box_prompts: Tensor,
    mask_prompts: Tensor,
    data: BatchedTomogramData,  # type: ignore
    start_slice_idx: int = 0,
) -> dict[str, Any]:
    """Choose the conditioning slices and attach their prompts to ``backbone_out``.

    Args:
        backbone_out: Dict to extend; it is modified in place and returned.
        box_prompts: Flat box prompts, indexed with
            ``data.index_to_flat_batch(n)`` and scaled by
            ``self.sam_prompt_encoder.input_image_size[0]``.
        mask_prompts: Flat mask prompts, indexed the same way.
        data: Batch providing ``num_slices``, ``min_slices`` and
            ``index_to_flat_batch``.
        start_slice_idx: Slice that is always used as a conditioning slice.

    Returns:
        ``backbone_out`` with the added keys ``"num_slices"``,
        ``"init_cond_slices"``, ``"slices_not_in_init_cond"``,
        ``"box_inputs_per_slice"`` and ``"mask_inputs_per_slice"``.
    """
    backbone_out["num_slices"] = data.num_slices

    # Setup prompt parameters: index 0 while training, index 1 otherwise
    mode_idx = 0 if self.training else 1
    num_init_cond_slices = self.num_init_cond_slices[mode_idx]
    rand_init_cond_slices = self.rand_init_cond_slices[mode_idx]
    assert (
        num_init_cond_slices >= 1
    ), "Number of initial conditioning slices must be at least 1."
    if rand_init_cond_slices and num_init_cond_slices > 1:
        # Randomly select number of initial conditioning slices
        num_init_cond_slices = np.random.randint(
            1, num_init_cond_slices + 1
        )

    # Select initial conditioning slices. Extra slices are drawn without
    # replacement from the slices after the start slice; clamp the request to
    # the number of available candidates so the draw cannot fail when the
    # tomogram is too shallow.
    init_cond_slices = [start_slice_idx]
    candidates = range(start_slice_idx + 1, data.min_slices)
    num_extra = min(num_init_cond_slices - 1, len(candidates))
    if num_extra > 0:
        init_cond_slices += np.random.choice(
            a=candidates,
            size=num_extra,
            replace=False,
        ).tolist()
    backbone_out["init_cond_slices"] = init_cond_slices
    init_cond_set = set(init_cond_slices)
    backbone_out["slices_not_in_init_cond"] = [
        n for n in range(data.num_slices) if n not in init_cond_set
    ]

    # Prepare mask and box inputs for slices
    backbone_out["box_inputs_per_slice"] = {}
    backbone_out["mask_inputs_per_slice"] = {}
    box_scale = self.sam_prompt_encoder.input_image_size[0]
    for n in init_cond_slices:
        idxs = data.index_to_flat_batch(n)
        backbone_out["box_inputs_per_slice"][n] = box_prompts[idxs] * box_scale
        backbone_out["mask_inputs_per_slice"][n] = mask_prompts[idxs]
    return backbone_out
def forward_tracking(
    self,
    backbone_out: dict[str, Any],
    data: BatchedTomogramData,  # type: ignore
    return_dict: bool = False,
) -> dict[str, Any] | Tensor:
    """Track every slice and assemble the per-slice mask predictions.

    The initial conditioning slices are processed first, followed by the
    remaining slices, so the latter can be conditioned on the former.

    Args:
        backbone_out: Backbone features extended by ``prepare_prompt_inputs``
            (reads ``"num_slices"``, ``"init_cond_slices"``,
            ``"slices_not_in_init_cond"``, ``"box_inputs_per_slice"`` and
            ``"mask_inputs_per_slice"``).
        data: Batch providing ``index_to_flat_batch``.
        return_dict: If true, return the raw per-slice outputs instead of a
            tensor.

    Returns:
        With ``return_dict``: a dict with ``"cond_frame_outputs"`` and
        ``"non_cond_frame_outputs"``, each mapping slice index to the output
        of ``track_step``. Otherwise: the ``"pred_masks"`` of all slices,
        upsampled by a factor of 4 and concatenated along dim 1 in ascending
        slice order (expected [B, D, H, W]).
    """
    # Prepare backbone features
    # backbone_out is [[BxD]xCxHxW]
    # vision_feats and vision_pos_embeds are [(HW), (BD), C]
    _, vision_feats, vision_pos_embeds, feat_sizes = (
        self._prepare_backbone_features(backbone_out)
    )

    # Start loop over slices
    num_slices = backbone_out["num_slices"]
    init_cond_slices = backbone_out["init_cond_slices"]
    init_cond_set = set(init_cond_slices)
    # First process initial conditioning slices, then condition on them for memory
    processing_order = (
        init_cond_slices + backbone_out["slices_not_in_init_cond"]
    )
    # Use "frame" instead of "slice" to match with SAM2 implementation
    output_dict = {
        "cond_frame_outputs": {},
        "non_cond_frame_outputs": {},
    }
    for slice_id in processing_order:
        flat_idxs = data.index_to_flat_batch(slice_id)
        is_cond_slice = slice_id in init_cond_set
        # Get image features for the current slice
        current_vision_feats = [x[:, flat_idxs] for x in vision_feats]
        current_vision_pos_embeds = [
            x[:, flat_idxs] for x in vision_pos_embeds
        ]
        current_out = self.track_step(
            frame_idx=slice_id,
            is_init_cond_frame=is_cond_slice,
            current_vision_feats=current_vision_feats,
            current_vision_pos_embeds=current_vision_pos_embeds,
            feat_sizes=feat_sizes,
            point_inputs=backbone_out["box_inputs_per_slice"].get(
                slice_id, None
            ),
            mask_inputs=backbone_out["mask_inputs_per_slice"].get(
                slice_id, None
            ),
            output_dict=output_dict,
            num_frames=num_slices,
        )
        bucket = (
            "cond_frame_outputs" if is_cond_slice else "non_cond_frame_outputs"
        )
        output_dict[bucket][slice_id] = current_out

    if return_dict:
        return output_dict

    # turn 'output_dict' into a batched tensor for loss function (expects [B, D, H, W] output)
    all_slice_outputs = {
        **output_dict["cond_frame_outputs"],
        **output_dict["non_cond_frame_outputs"],
    }
    # Concatenate in ascending slice order. The dict itself is in processing
    # order (conditioning slices first), which would misalign the depth axis
    # with the labels whenever the conditioning slice is not slice 0.
    pred_output = [
        # Upsample to original size (from low-res masks)
        F.interpolate(
            all_slice_outputs[slice_id]["pred_masks"],
            scale_factor=4,
            mode="bilinear",
            align_corners=False,
        )
        for slice_id in sorted(all_slice_outputs)
    ]
    return torch.cat(pred_output, dim=1)
def track_step(
    self,
    frame_idx: int,
    is_init_cond_frame: bool,
    current_vision_feats: Tensor | list[Tensor],
    current_vision_pos_embeds: Tensor | list[Tensor],
    feat_sizes: Tensor | list[tuple],
    point_inputs: Tensor | None,
    mask_inputs: Tensor | None,
    output_dict: dict[str, Any],
    num_frames: int,
    track_in_reverse: bool = False,
    run_mem_encoder: bool = True,
    prev_sam_mask_logits: Any | None = None,
) -> dict[str, Any]:
    """Segment one slice and encode its prediction into memory.

    ``point_inputs`` carries the box prompt for this slice (it is stored
    under ``"box_inputs"`` by ``_track_step``). The arguments
    ``track_in_reverse``, ``run_mem_encoder`` and ``prev_sam_mask_logits``
    are accepted but not forwarded: tracking always runs forward, without
    previous mask logits, and with the memory encoder enabled.

    Returns:
        The per-slice output dict from ``_track_step`` with ``"pred_masks"``
        and ``"obj_ptr"`` added, further updated by
        ``_encode_memory_in_output``.
    """
    # Run the tracking step for the current slice
    current_out, sam_outputs, _, _ = self._track_step(
        frame_idx=frame_idx,
        is_init_cond_frame=is_init_cond_frame,
        current_vision_feats=current_vision_feats,
        current_vision_pos_embeds=current_vision_pos_embeds,
        feat_sizes=feat_sizes,
        point_inputs=point_inputs,
        mask_inputs=mask_inputs,
        output_dict=output_dict,
        num_frames=num_frames,
        track_in_reverse=False,
        prev_sam_mask_logits=None,
    )

    # Only save essential outputs to reduce memory usage
    multimask_logits = sam_outputs[0]
    low_res_masks, high_res_masks, obj_ptr, object_score_logits = sam_outputs[3:]

    # Combine multimask outputs into one mask by taking the max
    if multimask_logits is not None:
        low_res_masks = multimask_logits.max(dim=1, keepdim=True).values
    current_out["pred_masks"] = low_res_masks
    current_out["obj_ptr"] = obj_ptr

    # Finally run the memory encoder on the predicted mask to encode
    # it into a new memory feature (that can be used in future slices)
    self._encode_memory_in_output(
        current_vision_feats,
        feat_sizes,
        None,
        True,  # run_mem_encoder
        high_res_masks,
        object_score_logits,
        current_out,
    )
    return current_out
def _track_step(
    self,
    frame_idx,
    is_init_cond_frame,
    current_vision_feats,
    current_vision_pos_embeds,
    feat_sizes,
    point_inputs,
    mask_inputs,
    output_dict,
    num_frames,
    track_in_reverse,
    prev_sam_mask_logits,
):
    """Overridden ``_track_step`` that handles box and mask prompts.

    ``point_inputs`` holds the box prompt of the current slice (or ``None``);
    it is recorded under ``"box_inputs"`` and handed to
    ``_forward_sam_heads``, which forwards it as ``boxes``.

    Returns:
        A tuple ``(current_out, sam_outputs, high_res_features, pix_feat)``:
        the dict recording the prompts, the outputs of the SAM heads (or of
        ``_use_mask_as_output``), the reshaped high-resolution features (or
        ``None``) and the pixel features fed to the heads.
    """
    # The prompts are recorded as-is; no point prompts are used here.
    current_out = {
        "point_inputs": None,
        "box_inputs": point_inputs,
        "mask_inputs": mask_inputs,
    }
    # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
    # (all feature levels except the last one)
    if len(current_vision_feats) > 1:
        high_res_features = [
            x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
            for x, s in zip(
                current_vision_feats[:-1], feat_sizes[:-1], strict=False
            )
        ]
    else:
        high_res_features = None
    if (
        mask_inputs is not None
        and self.use_mask_input_as_output_without_sam
    ):
        # When use_mask_input_as_output_without_sam=True, we directly output the mask input
        # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
        pix_feat = current_vision_feats[-1].permute(1, 2, 0)
        pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
        sam_outputs = self._use_mask_as_output(
            pix_feat, high_res_features, mask_inputs
        )
    else:
        # fuse the visual feature with previous memory features in the memory bank
        # (only the last feature level is used here)
        pix_feat = self._prepare_memory_conditioned_features(
            frame_idx=frame_idx,
            is_init_cond_frame=is_init_cond_frame,
            current_vision_feats=current_vision_feats[-1:],
            current_vision_pos_embeds=current_vision_pos_embeds[-1:],
            feat_sizes=feat_sizes[-1:],
            output_dict=output_dict,
            num_frames=num_frames,
            track_in_reverse=track_in_reverse,
        )
        # apply SAM-style segmentation head
        # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
        # e.g. in demo where such logits come from earlier interaction instead of correction sampling
        # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
        if prev_sam_mask_logits is not None:
            assert point_inputs is not None and mask_inputs is None
            mask_inputs = prev_sam_mask_logits
        # `None` is passed in place of point inputs, since only box and mask
        # prompts are used in this override.
        multimask_output = self._use_multimask(is_init_cond_frame, None)
        sam_outputs = self._forward_sam_heads(
            backbone_features=pix_feat,
            point_inputs=point_inputs,
            mask_inputs=mask_inputs,
            high_res_features=high_res_features,
            multimask_output=multimask_output,
        )
    return current_out, sam_outputs, high_res_features, pix_feat
def _forward_sam_heads(
    self,
    backbone_features,
    point_inputs=None,
    mask_inputs=None,
    high_res_features=None,
    multimask_output=False,
):
    """Forward the SAM prompt encoder and mask decoder, overridden to use box and mask prompts.

    Args:
        backbone_features: Image embeddings; asserted to have
            ``sam_prompt_embed_dim`` channels and a square spatial size of
            ``sam_image_embedding_size``.
        point_inputs: Box prompts, passed to the prompt encoder as ``boxes``
            (may be ``None``).
        mask_inputs: Optional mask prompt with shape ``(B, 1, H, W)``. It is
            resized to the prompt encoder's ``mask_input_size`` if needed and
            converted to float when it is not a floating-point tensor.
        high_res_features: Optional high-resolution features for the mask
            decoder.
        multimask_output: Whether the decoder predicts multiple masks.

    Returns:
        A tuple ``(low_res_multimasks, high_res_multimasks, ious,
        low_res_masks, high_res_masks, obj_ptr, object_score_logits)``.
    """
    B = backbone_features.size(0)
    device = backbone_features.device
    assert backbone_features.size(1) == self.sam_prompt_embed_dim
    assert backbone_features.size(2) == self.sam_image_embedding_size
    assert backbone_features.size(3) == self.sam_image_embedding_size

    # a) Handle point prompts by padding with an empty point with label -1
    sam_point_coords = torch.zeros(B, 1, 2, device=device)
    sam_point_labels = -1 * torch.ones(B, 1, device=device)

    # b) Handle mask prompts
    if mask_inputs is not None:
        # If mask_inputs is provided, downsize it into low-res mask input if needed
        # and feed it as a dense mask prompt into the SAM mask encoder
        assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (
            B,
            1,
        )
        if (
            mask_inputs.shape[-2:]
            != self.sam_prompt_encoder.mask_input_size
        ):
            sam_mask_prompt = F.interpolate(
                mask_inputs.float(),
                size=self.sam_prompt_encoder.mask_input_size,
                align_corners=False,
                mode="bilinear",
                antialias=True,  # use antialias for downsampling
            )
        elif mask_inputs.is_floating_point():
            sam_mask_prompt = mask_inputs
        else:
            # Binarized (bool/integer) prompts must be cast here as well,
            # exactly as the resize branch above does; floating-point
            # prompts keep their dtype.
            sam_mask_prompt = mask_inputs.float()
    else:
        # Otherwise, simply feed None (and SAM's prompt encoder will add
        # a learned `no_mask_embed` to indicate no mask input in this case).
        sam_mask_prompt = None

    sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
        points=(sam_point_coords, sam_point_labels),
        boxes=point_inputs,
        masks=sam_mask_prompt,
    )
    (
        low_res_multimasks,
        ious,
        sam_output_tokens,
        object_score_logits,
    ) = self.sam_mask_decoder(
        image_embeddings=backbone_features,
        image_pe=self.sam_prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=multimask_output,
        repeat_image=False,  # the image is already batched
        high_res_features=high_res_features,
    )
    if self.pred_obj_scores:
        is_obj_appearing = object_score_logits > 0
        # Mask used for spatial memories is always a *hard* choice between obj and no obj,
        # consistent with the actual mask prediction
        low_res_multimasks = torch.where(
            is_obj_appearing[:, None, None],
            low_res_multimasks,
            NO_OBJ_SCORE,
        )

    # convert masks from possibly bfloat16 (or float16) to float32
    # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
    low_res_multimasks = low_res_multimasks.float()
    high_res_multimasks = F.interpolate(
        low_res_multimasks,
        size=(self.image_size, self.image_size),
        mode="bilinear",
        align_corners=False,
    )

    sam_output_token = sam_output_tokens[:, 0]
    if multimask_output:
        # take the best mask prediction (with the highest IoU estimation)
        best_iou_inds = torch.argmax(ious, dim=-1)
        batch_inds = torch.arange(B, device=device)
        low_res_masks = low_res_multimasks[
            batch_inds, best_iou_inds
        ].unsqueeze(1)
        high_res_masks = high_res_multimasks[
            batch_inds, best_iou_inds
        ].unsqueeze(1)
        if sam_output_tokens.size(1) > 1:
            sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
    else:
        low_res_masks, high_res_masks = (
            low_res_multimasks,
            high_res_multimasks,
        )

    # Extract object pointer from the SAM output token (with occlusion handling)
    obj_ptr = self.obj_ptr_proj(sam_output_token)
    if self.pred_obj_scores:
        # Allow *soft* no obj ptr, unlike for masks
        if self.soft_no_obj_ptr:
            lambda_is_obj_appearing = object_score_logits.sigmoid()
        else:
            lambda_is_obj_appearing = is_obj_appearing.float()  # type: ignore
        if self.fixed_no_obj_ptr:
            obj_ptr = lambda_is_obj_appearing * obj_ptr
        obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr

    return (
        low_res_multimasks,
        high_res_multimasks,
        ious,
        low_res_masks,
        high_res_masks,
        obj_ptr,
        object_score_logits,
    )
#### Model Creation and Loading ####
[docs]
def create_sam_model_from_weights(cfg: BaseModelConfig, sam_dir: Path) -> SAM2:
    """Create a SAM2 model from the pre-trained weights selected by ``cfg.name``.

    The weights and config are looked up (and downloaded if missing) in
    ``sam_dir``. The SAM2 model config is retargeted to
    ``cryovit.models.sam2.SAM2Train``, the wrapper is instantiated from
    ``cfg``, the checkpoint is loaded, LoRA is applied to the mask decoder and
    the optimizers are configured.

    Args:
        cfg: Model config; ``cfg.name`` selects the weights and
            ``cfg.custom_kwargs`` is passed to the wrapper.
        sam_dir: Local directory holding (or receiving) the SAM2 files.

    Returns:
        The instantiated model with pre-trained weights loaded.

    Raises:
        AssertionError: If ``cfg.name`` is not an available model.
        RuntimeError: If the checkpoint has missing or unexpected keys.
    """
    configs = _download_model_weights(sam_dir)
    assert (
        cfg.name in configs
    ), f"Model {cfg.name} was not found in available SAMv2 models. Available models are {configs.keys()}."
    file_paths = configs[cfg.name]

    # Merge configs together
    model_cfg_path = file_paths["config"]
    model_cfg = OmegaConf.load(model_cfg_path)["model"]  # type: ignore
    model_cfg._target_ = (
        "cryovit.models.sam2.SAM2Train"  # Use cryovit SAM2 as target
    )
    model_cfg.image_size = (
        512  # Set image size to 512 (crop size for training)
    )
    model_cfg.use_mask_input_as_output_without_sam = (
        False  # use sam memory and mask decoder
    )
    model_cfg._partial_ = True
    model = instantiate(
        cfg, sam_model=model_cfg, custom_kwargs=cfg.custom_kwargs
    )

    sd = torch.load(
        file_paths["weights"], map_location="cpu", weights_only=True
    )["model"]
    missing_keys, unexpected_keys = model.load_sam_state_dict(sd)
    # Loading is non-strict, so mismatches are reported here; fail with a
    # message that names the offending keys rather than an empty error.
    if missing_keys:
        logging.error(missing_keys)
        raise RuntimeError(
            f"Missing keys when loading SAM2 weights from {file_paths['weights']}: {missing_keys}"
        )
    if unexpected_keys:
        logging.error(unexpected_keys)
        raise RuntimeError(
            f"Unexpected keys when loading SAM2 weights from {file_paths['weights']}: {unexpected_keys}"
        )

    model.model._apply_lora_to_mask_decoder()  # Apply LoRA after loading weights
    model.configure_optimizers()  # Configure optimizers after setting requires_grad
    return model
def _download_model_weights(sam_dir: Path) -> dict[str, dict[str, Path]]:
    """Ensure the SAMv2 and Medical-SAMv2 files exist locally and return their paths.

    For each model, the Hugging Face repo is downloaded into ``sam_dir`` via
    ``huggingface_hub.snapshot_download`` unless both its weights file and its
    config file are already present there.

    Returns:
        ``{"SAM2": {...}, "MedSAM": {...}}`` where each inner dict maps
        ``"config"`` and ``"weights"`` to paths inside ``sam_dir``.
    """
    resolved: dict[str, dict[str, Path]] = {}
    sources = (("SAM2", sam2_model), ("MedSAM", medical_sam2_model))
    for model_key, (repo_id, file_names) in sources:
        local_files = {
            kind: sam_dir / file_name for kind, file_name in file_names.items()
        }
        already_present = (
            local_files["weights"].exists() and local_files["config"].exists()
        )
        if not already_present:
            from huggingface_hub import snapshot_download

            snapshot_download(
                repo_id=repo_id, repo_type="model", local_dir=sam_dir
            )
        resolved[model_key] = local_files
    return resolved