Abstract: Multi-modal three-dimensional (3D) medical imaging data, derived from ultrasound, magnetic resonance imaging (MRI), and computed tomography (CT), provide a widely adopted approach for non-invasive anatomical visualization. However, accurate modeling depends on surface reconstruction and frame-to-frame interpolation, where traditional methods often struggle with image noise and incomplete information between sparse frames. To address these challenges, we present MedGS, a novel framework based on Gaussian Splatting (GS) designed for high-fidelity 3D anatomical reconstruction. Uniquely, MedGS employs a multi-task architecture that simultaneously performs frame interpolation and segmentation using a unified geometric representation. By coupling these tasks, the model leverages dense signals from image synthesis to regularize the geometry, enabling high-quality surface extraction even from a limited number of input frames. Specifically, medical data are modeled as Folded-Gaussians with dual color attributes, supported by an In-Between Frame Regularization (IBFR) mechanism. Experimental results demonstrate that MedGS offers more efficient training than implicit neural representations and enhances robustness to noise.
Follow the steps below to set up the project environment.
- CUDA-ready GPU with Compute Capability 7.0+
- CUDA toolkit 12 for PyTorch extensions (we used 12.4)
Create and activate a Python virtual environment using Python 3.8.
python3.8 -m venv env
source env/bin/activate
Install the PyTorch framework and torchvision for deep learning tasks.
pip3 install torch torchvision
Install the necessary submodules for Gaussian rasterization and k-nearest neighbors.
pip3 install submodules/diff-gaussian-rasterization
pip3 install submodules/simple-knn
Install all other dependencies listed in the requirements.txt file.
pip3 install -r requirements.txt
The training script supports three pipelines:
- `img`: photometric/image reconstruction training (default)
- `seg`: binary segmentation training
- `joint`: joint image + segmentation training with shared geometry and dual heads
python3 train.py -s <img_dataset_dir> -m <output_dir>
python3 train.py -s <seg_dataset_dir> -m <output_dir> --pipeline seg
python3 train.py \
-s <img_dataset_dir> \
-m <output_dir> \
--pipeline joint \
--seg_source_path <seg_dataset_dir>
Use this mode to refine only the segmentation head from a checkpoint. This is useful if you previously trained an image-only model.
python3 train.py \
-s <seg_dataset_dir> \
-m <output_dir> \
--pipeline seg \
--seg_head_only \
--start_checkpoint <output_dir>/chkpntXXXXX.pth
Before training, convert your data into individual frames (0000.png, 0001.png, ...).
Each dataset root should look like:
<data_root>
├── original/
│ ├── 0000.png
│ ├── 0001.png
│ └── ...
└── mirror/
- `original/` contains the input frames.
- `mirror/` is used by the camera pipeline.
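The naming and layout above can be scaffolded with a small helper. This is an illustrative sketch, not part of the repository; `frame_name` and `prepare_root` are hypothetical names:

```python
from pathlib import Path

def frame_name(index: int) -> str:
    # Zero-pad the frame index to four digits: 0000.png, 0001.png, ...
    return f"{index:04d}.png"

def prepare_root(data_root: str) -> None:
    # Create the original/ and mirror/ subfolders expected under <data_root>.
    root = Path(data_root)
    (root / "original").mkdir(parents=True, exist_ok=True)
    (root / "mirror").mkdir(parents=True, exist_ok=True)
```

For example, `frame_name(12)` returns `0012.png`, matching the expected frame ordering.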
For `--pipeline joint`, you must provide:
- one dataset root for images via `-s`
- one dataset root for masks via `--seg_source_path`
Both datasets must have:
- the same number of frames
- identical ordering / indexing (`0000.png`, `0001.png`, ...)
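These consistency requirements can be checked before launching a joint run. A minimal sketch, assuming both roots follow the `original/` layout described above (`check_joint_datasets` is a hypothetical helper, not a repository function):

```python
from pathlib import Path

def check_joint_datasets(img_root, seg_root):
    # Collect sorted frame names from both dataset roots.
    img_frames = sorted(p.name for p in Path(img_root, "original").glob("*.png"))
    seg_frames = sorted(p.name for p in Path(seg_root, "original").glob("*.png"))
    problems = []
    if len(img_frames) != len(seg_frames):
        problems.append(f"frame count mismatch: {len(img_frames)} vs {len(seg_frames)}")
    # Frame indices must line up one-to-one.
    for a, b in zip(img_frames, seg_frames):
        if a != b:
            problems.append(f"index mismatch: {a} vs {b}")
    return problems  # an empty list means the datasets are consistent
```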
- `--pipeline {img,seg,joint}`: Select training mode.
- `--seg_head_only`: Valid with `--pipeline seg`. Freezes geometry and image head, trains only the segmentation head.
- `--seg_source_path <path>`: Required for `--pipeline joint`. Path to the segmentation dataset root.
- `--lambda_img <float>`: Weight of the image loss in joint training (default: `1.0`).
- `--lambda_seg <float>`: Weight of the segmentation loss in joint training (default: `1.0`; you can get good results using `2` or `3`).
- `--start_checkpoint <path>`: Resume from checkpoint. In joint mode, image-only checkpoints are also supported (the segmentation head is initialized and the optimizer is rebuilt).
- `--save_xyz`: Save Gaussian xyz positions periodically to `<output_dir>/xyz/`.
- `--random_background`: Randomize background during training (useful for transparent-background data).
- `--poly_degree <int>`: Polynomial degree of folded Gaussians.
- `--batch_size <int>`: Batch size (default: `3`).
- `--test_iterations`, `--save_iterations`, `--checkpoint_iterations`: Control evaluation, full saves, and checkpoint saves.
Render test views from a trained model.
The renderer supports:
- `img`: render image head output
- `seg`: render segmentation output
- `both`: render both image and segmentation outputs
python3 render.py --model_path <model_dir> --interp <interp> --pipeline <img|seg|both>
Render image output:
python3 render.py --model_path <model_dir> --pipeline img
Render segmentation output:
python3 render.py --model_path <model_dir> --pipeline seg
Render both outputs:
python3 render.py --model_path <model_dir> --pipeline both
Render a specific checkpoint:
python3 render.py --model_path <model_dir> --iteration 30000
Render the latest checkpoint automatically (--iteration -1, default):
python3 render.py --model_path <model_dir> --iteration -1
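When scripting around the renderer, the `--iteration -1` behavior can be reproduced by hand. This is a sketch, assuming checkpoints are named `chkpnt<iteration>.pth` inside the model directory (`latest_checkpoint` is a hypothetical helper):

```python
import re
from pathlib import Path

def latest_checkpoint(model_dir):
    # Pick the chkpnt*.pth file with the highest iteration number,
    # mimicking what --iteration -1 selects.
    best, best_iter = None, -1
    for p in Path(model_dir).glob("chkpnt*.pth"):
        m = re.fullmatch(r"chkpnt(\d+)\.pth", p.name)
        if m and int(m.group(1)) > best_iter:
            best_iter = int(m.group(1))
            best = p
    return best  # None if no checkpoint is found
```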
Reduce memory usage by rendering in chunks:
python3 render.py --model_path <model_dir> --pipeline both --chunks 4
- `--model_path <path>`: Path to the training output directory.
- `--iteration <int>`: Checkpoint iteration to load. `-1` selects the latest `chkpnt*.pth`.
- `--interp <int>`: Interpolation multiplier (default: `1`).
- `--pipeline {img,seg,both}`: Output type(s) to render.
- `--chunks <int>`: Split rendering across chunks to reduce memory usage.
- `--extension <str>`: Output file extension (default: `.png`).
- `--mask_path <path>` / `--generate_points_path <path>`: Optional advanced rendering controls.
Rendered images are saved to:
- `<model_dir>/render_img/` for image renders
- `<model_dir>/render_mask/` for segmentation renders
Files are named like:
00000_0.png
00001_0.png
...
If --interp > 1, each frame can produce multiple outputs:
00000_0.png
00000_1.png
...
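The naming pattern can be enumerated as follows; `render_names` is a hypothetical helper, and the assumption that the second index simply counts the interpolated outputs per frame is inferred from the pattern above:

```python
def render_names(num_frames, interp):
    # One output per (frame, sub-frame) pair, named <frame:05d>_<k>.png.
    return [f"{i:05d}_{k}.png" for i in range(num_frames) for k in range(interp)]

# render_names(2, 2) -> ["00000_0.png", "00000_1.png", "00001_0.png", "00001_1.png"]
```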
If you render with --pipeline seg and the checkpoint does not contain a dedicated segmentation head, rendering falls back to the image head.
Build a 3D mesh (.ply) from rendered segmentation frames.
The mesh script uses marching cubes. If a NIfTI file is present in a case folder, its voxel spacing can be used.
--input should point to a directory where each subfolder is one case/model.
If your mesh pipeline expects rendered masks, point it to the segmentation render folder (render_mask/) instead of the old render/ path.
python3 slices_to_ply.py \
--input <input_root> \
--output <out_dir> \
--thresh 150
- `--input`: parent directory with case subfolders
- `--output`: destination directory for meshes (`<case>.ply`)
- `--thresh`: iso-level for marching cubes (PNG intensity scale)
- `--inter`: interpolation scale (if supported by your `slices_to_ply.py`)
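The expected input layout can be summarized with a small sketch; `plan_outputs` is a hypothetical helper that only mirrors the one-subfolder-per-case convention and is not part of `slices_to_ply.py`:

```python
from pathlib import Path

def plan_outputs(input_root, out_dir):
    # Map each case subfolder under input_root to its mesh path <out_dir>/<case>.ply.
    cases = sorted(p for p in Path(input_root).iterdir() if p.is_dir())
    return {c.name: str(Path(out_dir) / f"{c.name}.ply") for c in cases}
```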





