Chapter 7: Computer Vision for Medical Imaging
Learning Objectives
By the end of this chapter, readers will be able to:
- Implement production-grade medical image preprocessing and augmentation pipelines that account for systematic differences in acquisition protocols, equipment quality, and imaging parameters across healthcare settings serving diverse populations, with comprehensive documentation of all transformations applied.
- Develop semantic and instance segmentation models for medical images that achieve equitable performance across patient demographics and care settings, including approaches for handling systematic variation in anatomy, pathology presentation, and image quality that correlate with patient race, age, and socioeconomic status.
- Design object detection and classification systems for diagnostic radiology that explicitly account for prevalence differences across populations, implement fairness-aware confidence thresholds, and provide calibrated uncertainty estimates stratified by demographic factors and acquisition characteristics.
- Apply self-supervised and semi-supervised learning techniques to leverage large unlabeled medical imaging datasets while ensuring learned representations do not encode spurious correlations between imaging artifacts and patient demographics that could lead to biased downstream predictions.
- Build domain adaptation and transfer learning frameworks that enable models trained on data from well-resourced settings to generalize to images from safety-net hospitals, community health centers, and resource-limited environments with different equipment and protocols.
- Evaluate medical imaging models using comprehensive fairness frameworks that stratify performance by patient demographics, care setting characteristics, equipment manufacturers, and acquisition parameters, with statistical tests for significant performance disparities and effect size quantification.
7.1 Introduction: The Promise and Peril of Medical Imaging AI
Medical imaging has emerged as perhaps the most successful application domain for deep learning in healthcare. In published evaluations, convolutional neural networks have matched or exceeded expert performance on tasks ranging from chest radiograph interpretation to diabetic retinopathy detection to histopathology slide analysis. These systems promise to democratize access to specialized imaging expertise, enabling accurate diagnoses in settings that lack subspecialty radiologists, ophthalmologists, or pathologists. The potential public health impact is enormous, particularly for underserved communities where specialist shortages are most acute.
Yet this promise remains largely unrealized, and in some cases, deployment of medical imaging AI has exacerbated rather than ameliorated healthcare disparities. A dermatology AI system trained predominantly on images of light skin performs poorly at detecting melanoma in patients with darker skin, where delayed diagnosis is already more common and mortality rates are higher. A chest X-ray interpretation model developed at an academic medical center using digital radiography systems fails when applied to portable X-ray images common in rural hospitals and nursing homes serving low-income elderly patients. A mammography screening algorithm exhibits different false positive rates across racial groups, potentially leading to differential callback rates and patient anxiety.
These failures are not inevitable consequences of the technology but rather stem from choices made throughout the development lifecycle. Training datasets that undersample certain populations teach models patterns that do not generalize. Preprocessing pipelines optimized for high-quality images from modern equipment introduce artifacts when applied to images from older scanners common in under-resourced facilities. Evaluation frameworks that report only aggregate performance metrics hide systematic disparities across patient subgroups. Deployment processes that ignore differences in clinical workflows and decision thresholds across care settings lead to differential impacts on patient care quality.
This chapter develops computer vision methods for medical imaging that explicitly center equity throughout the development process. We begin with fundamental image preprocessing and augmentation techniques, examining how seemingly technical choices affect model fairness. We then develop segmentation architectures that account for anatomical variation across populations, detection systems that handle prevalence differences without sacrificing performance for minority groups, and classification approaches that maintain calibration across diverse patient demographics and care settings. Throughout, we emphasize not just achieving high average performance but ensuring equitable outcomes across all populations who will be affected by these systems.
The medical imaging modalities we address include radiography (conventional X-rays and digital radiography), computed tomography, magnetic resonance imaging, ultrasound, nuclear medicine, and optical imaging including fundus photography and dermoscopy. Each modality presents unique technical challenges and equity considerations. The implementations provided are designed as foundations for production systems, with error handling, logging, and evaluation frameworks that surface fairness issues during development rather than after deployment; components that remain simplified, such as image registration, are marked as such in the code. The goal is to enable practitioners to build medical imaging AI systems that truly serve all patient populations equitably, fulfilling rather than undermining the technology’s promise to improve healthcare access and outcomes.
7.1.1 Sources of Bias in Medical Imaging AI
Before developing technical solutions, we must understand the mechanisms through which bias enters medical imaging systems. These mechanisms operate at multiple levels, from data collection through model deployment.
Dataset composition bias arises when training data systematically underrepresents certain patient populations. Academic medical centers that provide most publicly available imaging datasets serve different demographic groups than community hospitals and safety-net facilities. The NIH Chest X-ray dataset, while valuable for research, contains primarily images from patients at a tertiary care center whose demographics differ substantially from the broader U.S. population. When models trained on such datasets are deployed in community settings, they encounter distribution shift in both patient characteristics and image properties.
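One simple way to quantify composition bias is to compare each group's share of a training set with its share of the population the model is meant to serve. The sketch below computes such a representation ratio; the function name `representation_ratios`, the group labels, and all counts and shares are hypothetical placeholders, not figures from any real dataset.

```python
from typing import Dict


def representation_ratios(
    dataset_counts: Dict[str, int],
    population_shares: Dict[str, float],
) -> Dict[str, float]:
    """Share of the dataset divided by share of the target population, per group.

    Below 1.0: the group is under-represented; above 1.0: over-represented.
    """
    total = sum(dataset_counts.values())
    if total == 0:
        raise ValueError("dataset_counts must contain at least one image")
    ratios = {}
    for group, population_share in population_shares.items():
        if population_share <= 0:
            raise ValueError(f"population share for {group!r} must be positive")
        dataset_share = dataset_counts.get(group, 0) / total
        ratios[group] = dataset_share / population_share
    return ratios


# Hypothetical numbers for illustration only
counts = {"group_a": 800, "group_b": 150, "group_c": 50}
shares = {"group_a": 0.60, "group_b": 0.25, "group_c": 0.15}
for group, ratio in representation_ratios(counts, shares).items():
    print(f"{group}: {ratio:.2f}")
```

With these made-up numbers, `group_b` and `group_c` come out at 0.60 and 0.33, so both sit well below their population share. The ratio says nothing about image or label quality within a group, so it is a first screen rather than a full audit.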
Acquisition protocol bias reflects systematic differences in how images are captured across healthcare settings. Modern digital radiography systems with automated exposure control produce consistent image quality, while older film-screen systems and portable X-ray machines common in under-resourced settings yield more variable images. Magnetic resonance imaging protocols vary substantially across institutions in terms of field strength, sequence parameters, and contrast agent use. These technical differences correlate with patient socioeconomic status through the unequal distribution of healthcare resources.
Prevalence and presentation bias emerges from genuine biological and social epidemiological patterns. Disease prevalence varies across populations due to differential exposure to risk factors, access to preventive care, and genetic background. Disease presentation can differ, as with dermatological conditions manifesting differently on darker versus lighter skin tones. While these differences are real rather than artifactual, naively training models on datasets that do not reflect these patterns leads to poor generalization.
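Prevalence matters even when a model's sensitivity and specificity are identical across groups, because positive predictive value (PPV) depends on prevalence through Bayes' rule. A minimal sketch, using a hypothetical operating point of 90% sensitivity and 90% specificity applied to two hypothetical prevalences:

```python
def positive_predictive_value(
    sensitivity: float, specificity: float, prevalence: float
) -> float:
    """PPV from Bayes' rule: sens*prev / (sens*prev + (1 - spec)*(1 - prev))."""
    for name, value in (
        ("sensitivity", sensitivity),
        ("specificity", specificity),
        ("prevalence", prevalence),
    ):
        if not 0.0 <= value <= 1.0:
            raise ValueError(f"{name} must be in [0, 1], got {value}")
    tp_fraction = sensitivity * prevalence                  # true positives per patient
    fp_fraction = (1.0 - specificity) * (1.0 - prevalence)  # false positives per patient
    positives = tp_fraction + fp_fraction
    if positives == 0.0:
        return float("nan")  # the model never predicts positive
    return tp_fraction / positives


# Same hypothetical operating point, two hypothetical prevalences
for prevalence in (0.10, 0.01):
    ppv = positive_predictive_value(0.90, 0.90, prevalence)
    print(f"prevalence={prevalence:.2f}  PPV={ppv:.3f}")
```

Under these assumed numbers, PPV falls from 0.50 at 10% prevalence to about 0.08 at 1% prevalence, so the same threshold yields very different proportions of false alarms in the two populations.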
Annotation bias occurs when human labelers apply different standards or make systematic errors that correlate with patient characteristics. Radiologists may have differential confidence in their interpretations depending on image quality, patient history completeness, or implicit biases about which demographic groups are more likely to have certain conditions. These biases in training labels then become encoded in model predictions.
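Stratifying inter-reader agreement by patient group is one way to look for annotation problems. The sketch below computes Cohen's kappa separately within each group. The reads and group labels are hypothetical. Agreement only reveals inconsistency between readers; it cannot detect a bias that all readers share.

```python
from collections import defaultdict
from typing import Dict, Hashable, List, Sequence, Tuple


def cohens_kappa(labels_a: Sequence[Hashable], labels_b: Sequence[Hashable]) -> float:
    """Cohen's kappa: (observed agreement - chance agreement) / (1 - chance agreement)."""
    if len(labels_a) != len(labels_b) or len(labels_a) == 0:
        raise ValueError("need two non-empty label sequences of equal length")
    n = len(labels_a)
    observed = sum(a == b for a, b in zip(labels_a, labels_b)) / n
    # Chance agreement from each reader's marginal label frequencies
    categories = set(labels_a) | set(labels_b)
    expected = sum(
        (list(labels_a).count(c) / n) * (list(labels_b).count(c) / n)
        for c in categories
    )
    if expected == 1.0:
        return float("nan")  # both readers used one identical label; kappa undefined
    return (observed - expected) / (1.0 - expected)


def kappa_by_group(
    labels_a: Sequence[Hashable],
    labels_b: Sequence[Hashable],
    groups: Sequence[Hashable],
) -> Dict[Hashable, float]:
    """Compute Cohen's kappa separately within each patient group."""
    per_group: Dict[Hashable, Tuple[List[Hashable], List[Hashable]]] = defaultdict(
        lambda: ([], [])
    )
    for a, b, g in zip(labels_a, labels_b, groups):
        per_group[g][0].append(a)
        per_group[g][1].append(b)
    return {g: cohens_kappa(a, b) for g, (a, b) in per_group.items()}


# Hypothetical binary reads (1 = finding present) from two readers
reader_1 = [1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0]
reader_2 = [1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1]
groups = ["x"] * 6 + ["y"] * 6
print(kappa_by_group(reader_1, reader_2, groups))
```

Here group `x` shows perfect agreement (kappa 1.0), while group `y` shows agreement no better than chance (kappa 0.0). That result would flag group `y`'s labels for adjudication.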
Calibration bias manifests as systematic miscalibration of prediction confidence across groups. A model might output well-calibrated probabilities for the majority population but overconfident or underconfident predictions for underrepresented groups. This matters clinically because downstream decision thresholds assume proper calibration.
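Calibration can be checked per group with a binned expected calibration error (ECE). The sketch below uses equal-width bins. The function names, the bin count, and the synthetic data are all assumptions for illustration: one group is perfectly calibrated, and the other has its probabilities artificially pushed toward 0 and 1.

```python
from typing import Dict

import numpy as np


def expected_calibration_error(
    probs: np.ndarray, labels: np.ndarray, n_bins: int = 10
) -> float:
    """Binned ECE: bin-size-weighted mean of |observed rate - mean predicted prob|."""
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels, dtype=float)
    if probs.shape != labels.shape or probs.size == 0:
        raise ValueError("probs and labels must be non-empty and the same shape")
    # Equal-width bins [0, 1/n), ..., [(n-1)/n, 1]; digitize on the interior edges
    interior_edges = np.linspace(0.0, 1.0, n_bins + 1)[1:-1]
    bin_ids = np.digitize(probs, interior_edges)
    ece = 0.0
    for b in range(n_bins):
        in_bin = bin_ids == b
        if in_bin.any():
            gap = abs(labels[in_bin].mean() - probs[in_bin].mean())
            ece += in_bin.mean() * gap  # in_bin.mean() = fraction of samples in bin
    return float(ece)


def ece_by_group(probs, labels, groups, n_bins: int = 10) -> Dict[str, float]:
    """ECE computed separately within each group."""
    probs, labels, groups = np.asarray(probs), np.asarray(labels), np.asarray(groups)
    return {
        str(g): expected_calibration_error(probs[groups == g], labels[groups == g], n_bins)
        for g in np.unique(groups)
    }


# Synthetic data: outcomes drawn from true risks; group "b" gets overconfident scores
rng = np.random.default_rng(0)
n = 5000
true_risk = rng.uniform(0.0, 1.0, size=2 * n)
labels = rng.uniform(size=2 * n) < true_risk
groups = np.array(["a"] * n + ["b"] * n)
probs = true_risk.copy()
probs[n:] = np.clip(0.5 + 1.8 * (true_risk[n:] - 0.5), 0.0, 1.0)
print(ece_by_group(probs, labels, groups))
```

On this synthetic data, the overconfident group shows a clearly larger ECE than the calibrated one. Binned ECE is sensitive to the number of bins and to small group sizes, so estimates for small subgroups are best reported with uncertainty, for example bootstrap intervals.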
Understanding these mechanisms enables us to address them through appropriate technical interventions at each stage of the development pipeline.
7.2 Medical Image Preprocessing with Equity Considerations
Medical image preprocessing transforms raw images from acquisition devices into formats suitable for model input. These transformations profoundly affect model fairness because systematic differences in image properties across patient populations can be preserved, amplified, or mitigated depending on preprocessing choices.
7.2.1 Intensity Normalization Across Acquisition Protocols
Medical images exhibit wide variation in intensity distributions depending on acquisition parameters, equipment manufacturer, and institutional protocols. Simple min-max scaling to a fixed range can produce very different results when applied to images with different dynamic ranges. Consider two chest X-rays, one from a modern digital radiography system with 12-bit dynamic range and one from an older portable system with 8-bit range. Min-max scaling treats them identically, but the information content and noise characteristics differ substantially.
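The 12-bit versus 8-bit example can be made concrete. Min-max scaling maps both images to [0, 1], but it cannot restore grey levels that the 8-bit system never recorded. The sketch substitutes a hypothetical smooth intensity ramp for a real radiograph and covers only quantization, not noise.

```python
import numpy as np


def minmax_scale(image: np.ndarray) -> np.ndarray:
    """Scale an image to [0, 1] using its own minimum and maximum."""
    image = image.astype(np.float64)
    lo, hi = image.min(), image.max()
    if hi == lo:
        return np.zeros_like(image)
    return (image - lo) / (hi - lo)


# Hypothetical smooth intensity ramp standing in for a radiograph
ramp = np.linspace(0.0, 1.0, 512 * 512).reshape(512, 512)
image_12bit = np.round(ramp * 4095).astype(np.uint16)  # up to 4096 grey levels
image_8bit = np.round(ramp * 255).astype(np.uint8)     # up to 256 grey levels

for name, img in (("12-bit", image_12bit), ("8-bit", image_8bit)):
    scaled = minmax_scale(img)
    levels = np.unique(scaled)
    print(
        f"{name}: range [{scaled.min():.1f}, {scaled.max():.1f}], "
        f"{levels.size} levels, smallest step {np.diff(levels).min():.5f}"
    )
```

Both scaled images span exactly [0, 1]. The 12-bit image keeps 4,096 distinct levels with a step of about 0.00024, while the 8-bit image has only 256 levels with a step of about 0.0039.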
We implement adaptive intensity normalization that accounts for acquisition characteristics while preserving clinically relevant information:
"""
Adaptive Medical Image Normalization
Implements intensity normalization strategies that account for systematic
differences in acquisition protocols and equipment while preserving
diagnostic information content and maintaining fairness across settings.
"""
import logging
from typing import Dict, Optional, Tuple, Union

import numpy as np
from skimage import exposure

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class AdaptiveNormalizer:
    """
    Adaptive intensity normalization for medical images.
    Provides multiple normalization strategies appropriate for different
    imaging modalities and acquisition conditions. Maintains metadata
    about normalization applied to enable inverse transforms and
    fairness auditing.
    """

    def __init__(
        self,
        method: str = 'adaptive_histogram',
        clip_limit: float = 0.01,
        preserve_range: bool = True,
        log_transform: bool = False
    ):
        """
        Initialize adaptive normalizer.
        Args:
            method: Normalization method ('adaptive_histogram', 'percentile',
                'zscore', 'robust_zscore')
            clip_limit: Clipping limit for adaptive histogram equalization
            preserve_range: For the 'percentile' method, keep the clipped
                image in original intensity units instead of mapping to [0, 1]
            log_transform: Apply log transform before normalization for
                wide dynamic range images
        """
        self.method = method
        self.clip_limit = clip_limit
        self.preserve_range = preserve_range
        self.log_transform = log_transform
        self.normalization_stats = {}
        logger.info(
            f"Initialized {method} normalizer "
            f"(clip={clip_limit}, preserve_range={preserve_range})"
        )

    def normalize(
        self,
        image: np.ndarray,
        mask: Optional[np.ndarray] = None,
        metadata: Optional[Dict] = None
    ) -> Tuple[np.ndarray, Dict]:
        """
        Normalize medical image with metadata tracking.
        Args:
            image: Input image array (2D or 3D)
            mask: Optional binary mask indicating valid regions
            metadata: Optional dict with acquisition metadata
        Returns:
            Tuple of (normalized image, normalization metadata)
        """
        if image.ndim not in (2, 3):
            raise ValueError(
                f"Expected 2D or 3D image, got shape {image.shape}"
            )
        # Work in floating point so integer inputs are not truncated
        image = np.asarray(image, dtype=np.float64)
        # Valid pixels: finite values inside the mask (if provided)
        valid = np.isfinite(image)
        if mask is not None:
            valid &= mask > 0
        if not valid.any():
            logger.warning("No valid pixels found in image")
            return image, {'method': 'identity', 'error': 'no_valid_pixels'}
        # Apply log transform for wide dynamic range. The whole image is
        # transformed once so the ROI pixels and the image share one shift.
        if self.log_transform:
            image = self._safe_log_transform(image)
        roi_pixels = image[valid]
        # Compute normalization based on method
        if self.method == 'adaptive_histogram':
            normalized, stats = self._adaptive_histogram_eq(
                image, mask, roi_pixels
            )
        elif self.method == 'percentile':
            normalized, stats = self._percentile_normalize(
                image, roi_pixels
            )
        elif self.method == 'zscore':
            normalized, stats = self._zscore_normalize(
                image, roi_pixels
            )
        elif self.method == 'robust_zscore':
            normalized, stats = self._robust_zscore_normalize(
                image, roi_pixels
            )
        else:
            raise ValueError(f"Unknown normalization method: {self.method}")
        stats['log_transform'] = self.log_transform
        # Add acquisition metadata if provided
        if metadata is not None:
            stats['acquisition_metadata'] = metadata
        return normalized, stats

    def _safe_log_transform(self, x: np.ndarray) -> np.ndarray:
        """Log transform after shifting the minimum finite value to log(1) = 0."""
        finite_min = x[np.isfinite(x)].min()
        return np.log(x - finite_min + 1.0)

    def _adaptive_histogram_eq(
        self,
        image: np.ndarray,
        mask: Optional[np.ndarray],
        roi_pixels: np.ndarray
    ) -> Tuple[np.ndarray, Dict]:
        """
        Adaptive histogram equalization (CLAHE) with local contrast enhancement.
        This approach enhances local contrast while adapting to the
        global intensity distribution, making it robust to varying
        acquisition parameters. The ROI sets the intensity range; CLAHE
        itself is applied to the full image.
        """
        # Number of bins from the number of distinct ROI intensities (2 to 256)
        n_bins = int(np.clip(np.sqrt(np.unique(roi_pixels).size), 2, 256))
        # equalize_adapthist expects float input in [0, 1]: rescale using the
        # ROI range and set non-finite pixels to 0
        lo, hi = float(roi_pixels.min()), float(roi_pixels.max())
        scaled = np.clip((image - lo) / (hi - lo + 1e-8), 0.0, 1.0)
        scaled = np.nan_to_num(scaled, nan=0.0)
        if image.ndim == 2:
            # 2D adaptive histogram equalization
            normalized = exposure.equalize_adapthist(
                scaled,
                clip_limit=self.clip_limit,
                nbins=n_bins
            )
        else:
            # Apply per-slice for 3D images into a float buffer
            normalized = np.zeros(image.shape, dtype=np.float64)
            for i in range(image.shape[0]):
                normalized[i] = exposure.equalize_adapthist(
                    scaled[i],
                    clip_limit=self.clip_limit,
                    nbins=n_bins
                )
        stats = {
            'method': 'adaptive_histogram',
            'clip_limit': self.clip_limit,
            'n_bins': n_bins,
            'original_range': (lo, hi),
            'normalized_range': (float(normalized.min()), float(normalized.max()))
        }
        return normalized, stats

    def _percentile_normalize(
        self,
        image: np.ndarray,
        roi_pixels: np.ndarray
    ) -> Tuple[np.ndarray, Dict]:
        """
        Percentile-based normalization robust to outliers.
        Clips to the 1st and 99th ROI percentiles. If preserve_range is
        True the clipped image stays in original units; otherwise it is
        rescaled to the [0, 1] range.
        """
        p1, p99 = np.percentile(roi_pixels, [1, 99])
        normalized = np.clip(image, p1, p99)
        if not self.preserve_range:
            normalized = (normalized - p1) / (p99 - p1 + 1e-8)
        stats = {
            'method': 'percentile',
            'p1': float(p1),
            'p99': float(p99),
            'preserve_range': self.preserve_range
        }
        return normalized, stats

    def _zscore_normalize(
        self,
        image: np.ndarray,
        roi_pixels: np.ndarray
    ) -> Tuple[np.ndarray, Dict]:
        """Standard z-score normalization."""
        mean = roi_pixels.mean()
        std = roi_pixels.std()
        normalized = (image - mean) / (std + 1e-8)
        stats = {
            'method': 'zscore',
            'mean': float(mean),
            'std': float(std)
        }
        return normalized, stats

    def _robust_zscore_normalize(
        self,
        image: np.ndarray,
        roi_pixels: np.ndarray
    ) -> Tuple[np.ndarray, Dict]:
        """
        Robust z-score using median and MAD.
        More robust to outliers than standard z-score, important
        for images with artifacts or extreme values.
        """
        median = np.median(roi_pixels)
        mad = np.median(np.abs(roi_pixels - median))
        # Convert MAD to a standard deviation equivalent (normal data)
        mad_std = 1.4826 * mad
        normalized = (image - median) / (mad_std + 1e-8)
        stats = {
            'method': 'robust_zscore',
            'median': float(median),
            'mad': float(mad),
            'mad_std': float(mad_std)
        }
        return normalized, stats

    def denormalize(
        self,
        normalized_image: np.ndarray,
        normalization_stats: Dict
    ) -> np.ndarray:
        """
        Reverse normalization using stored statistics.
        Important for interpreting model outputs in original
        intensity space and for clinical validation. The log transform,
        if applied, is not inverted.
        """
        method = normalization_stats['method']
        if normalization_stats.get('log_transform'):
            logger.warning("Log transform is not inverted by denormalize")
        if method == 'percentile':
            if normalization_stats.get('preserve_range', self.preserve_range):
                # Already in original units; values clipped at p1/p99
                # cannot be recovered
                return normalized_image.copy()
            p1 = normalization_stats['p1']
            p99 = normalization_stats['p99']
            return normalized_image * (p99 - p1 + 1e-8) + p1
        elif method == 'zscore':
            mean = normalization_stats['mean']
            std = normalization_stats['std']
            return normalized_image * (std + 1e-8) + mean
        elif method == 'robust_zscore':
            median = normalization_stats['median']
            mad_std = normalization_stats['mad_std']
            return normalized_image * (mad_std + 1e-8) + median
        else:
            logger.warning(
                f"Denormalization not implemented for {method}"
            )
            return normalized_image


class MultiSiteNormalizer:
    """
    Normalization that accounts for systematic site-level differences.
    Learns site-specific normalization parameters from the pooled pixels
    of each site during training and applies them based on the site
    identifier. Unknown sites fall back to per-image normalization.
    Enables training on multi-site data while maintaining fairness.
    """
    SUPPORTED_METHODS = ('zscore', 'robust_zscore', 'percentile')

    def __init__(self, base_method: str = 'robust_zscore'):
        """
        Initialize multi-site normalizer.
        Args:
            base_method: Base normalization method ('zscore',
                'robust_zscore', or 'percentile')
        """
        if base_method not in self.SUPPORTED_METHODS:
            raise ValueError(
                f"Site-level fitting supports {self.SUPPORTED_METHODS}, "
                f"got {base_method}"
            )
        self.base_method = base_method
        self.site_stats: Dict[str, Dict] = {}
        # preserve_range=False so percentile output is [0, 1] in both paths
        self.global_normalizer = AdaptiveNormalizer(
            method=base_method, preserve_range=False
        )

    def fit_site(
        self,
        site_id: str,
        images: np.ndarray,
        masks: Optional[np.ndarray] = None
    ) -> None:
        """
        Learn normalization parameters for a specific site.
        Args:
            site_id: Unique identifier for acquisition site
            images: Array of images from this site
            masks: Optional masks for each image
        """
        # Pool valid ROI pixels across all images from the site
        all_roi_pixels = []
        for i, image in enumerate(images):
            roi = image[masks[i] > 0] if masks is not None else image.ravel()
            all_roi_pixels.append(roi[np.isfinite(roi)])
        pixels = np.concatenate(all_roi_pixels).astype(np.float64)
        if pixels.size == 0:
            raise ValueError(f"No valid pixels for site {site_id}")
        # Fit pooled statistics for the base method
        if self.base_method == 'zscore':
            stats = {'mean': float(pixels.mean()), 'std': float(pixels.std())}
        elif self.base_method == 'robust_zscore':
            median = float(np.median(pixels))
            mad = float(np.median(np.abs(pixels - median)))
            stats = {'median': median, 'mad': mad, 'mad_std': 1.4826 * mad}
        else:
            p1, p99 = np.percentile(pixels, [1, 99])
            stats = {'p1': float(p1), 'p99': float(p99),
                     'preserve_range': False}
        stats.update({
            'method': self.base_method,
            'site_id': site_id,
            'n_images': len(images),
            'n_pixels': int(pixels.size)
        })
        self.site_stats[site_id] = stats
        logger.info(
            f"Learned normalization for site {site_id} "
            f"({len(images)} images, {pixels.size} pixels)"
        )

    def normalize(
        self,
        image: np.ndarray,
        site_id: Optional[str] = None,
        mask: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, Dict]:
        """
        Normalize image using site-specific or global parameters.
        Args:
            image: Input image
            site_id: Site identifier (uses global if None or unknown)
            mask: Optional ROI mask (used only by the global fallback)
        Returns:
            Normalized image and metadata
        """
        if site_id is None or site_id not in self.site_stats:
            logger.debug("Using global (per-image) normalization")
            return self.global_normalizer.normalize(image, mask=mask)
        logger.debug(f"Using site-specific normalization for {site_id}")
        stats = self.site_stats[site_id]
        image = np.asarray(image, dtype=np.float64)
        if self.base_method == 'zscore':
            normalized = (image - stats['mean']) / (stats['std'] + 1e-8)
        elif self.base_method == 'robust_zscore':
            normalized = (image - stats['median']) / (stats['mad_std'] + 1e-8)
        else:
            p1, p99 = stats['p1'], stats['p99']
            normalized = (np.clip(image, p1, p99) - p1) / (p99 - p1 + 1e-8)
        return normalized, dict(stats)


def compute_fairness_metrics_for_normalization(
    images_by_group: Dict[str, np.ndarray],
    normalizer: Union[AdaptiveNormalizer, MultiSiteNormalizer],
    masks_by_group: Optional[Dict[str, np.ndarray]] = None
) -> Dict:
    """
    Evaluate normalization fairness across demographic groups.
    Assesses whether normalization affects different groups differently
    in ways that could impact downstream model performance. Uses the
    statistics reported by the 'zscore', 'robust_zscore', and
    'percentile' methods.
    Args:
        images_by_group: Dict mapping group labels to image arrays
        normalizer: Normalizer to evaluate
        masks_by_group: Optional masks for each group
    Returns:
        Dictionary of fairness metrics across groups
    """
    results = {}
    for group_name, images in images_by_group.items():
        if len(images) == 0:
            logger.warning(f"Group {group_name} has no images; skipping")
            continue
        masks = masks_by_group.get(group_name) if masks_by_group else None
        group_stats = []
        for i, image in enumerate(images):
            mask = masks[i] if masks is not None else None
            _, stats = normalizer.normalize(image, mask=mask)
            group_stats.append(stats)
        # Compute aggregate statistics for group
        if 'mean' in group_stats[0]:
            means = [s['mean'] for s in group_stats]
            stds = [s['std'] for s in group_stats]
            results[group_name] = {
                'mean_of_means': np.mean(means),
                'std_of_means': np.std(means),
                'mean_of_stds': np.mean(stds),
                'std_of_stds': np.std(stds),
                'n_images': len(images)
            }
        elif 'median' in group_stats[0]:
            medians = [s['median'] for s in group_stats]
            mad_stds = [s['mad_std'] for s in group_stats]
            results[group_name] = {
                'mean_of_medians': np.mean(medians),
                'std_of_medians': np.std(medians),
                'mean_of_mad_stds': np.mean(mad_stds),
                'std_of_mad_stds': np.std(mad_stds),
                'n_images': len(images)
            }
        elif 'p1' in group_stats[0]:
            p1s = [s['p1'] for s in group_stats]
            p99s = [s['p99'] for s in group_stats]
            results[group_name] = {
                'mean_p1': np.mean(p1s),
                'std_p1': np.std(p1s),
                'mean_p99': np.mean(p99s),
                'std_p99': np.std(p99s),
                'n_images': len(images)
            }
        else:
            logger.warning(
                f"No comparable statistics for group {group_name} "
                f"(method {group_stats[0].get('method')})"
            )
    # Compute disparity metrics across groups (location/scale methods)
    group_names = list(results.keys())
    if len(group_names) > 1:
        first = results[group_names[0]]
        if 'mean_of_means' in first:
            location_key, scale_key = 'mean_of_means', 'mean_of_stds'
        elif 'mean_of_medians' in first:
            location_key, scale_key = 'mean_of_medians', 'mean_of_mad_stds'
        else:
            location_key = scale_key = None
        if location_key is not None:
            locations = [results[g][location_key] for g in group_names]
            scales = [results[g][scale_key] for g in group_names]
            results['disparity_metrics'] = {
                'location_statistic': location_key,
                'mean_range': max(locations) - min(locations),
                'mean_coefficient_of_variation':
                    np.std(locations) / (abs(np.mean(locations)) + 1e-8),
                'std_range': max(scales) - min(scales),
                'std_coefficient_of_variation':
                    np.std(scales) / (abs(np.mean(scales)) + 1e-8)
            }
    return results
This normalization framework explicitly tracks how preprocessing affects different patient groups and acquisition settings. The adaptive approaches can handle the heterogeneity in medical imaging data without forcing all images into a single distribution that may be inappropriate for some subpopulations.
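The difference between the standard and robust z-score methods shows up clearly when an image contains a small, very bright artifact such as a burned-in marker. The standalone functions below re-implement the two formulas used by `_zscore_normalize` and `_robust_zscore_normalize`. The helper `tissue_shift`, the tissue intensities, and the marker value are hypothetical.

```python
import numpy as np


def zscore(image: np.ndarray, roi: np.ndarray) -> np.ndarray:
    """Standard z-score using ROI mean and standard deviation."""
    return (image - roi.mean()) / (roi.std() + 1e-8)


def robust_zscore(image: np.ndarray, roi: np.ndarray) -> np.ndarray:
    """Robust z-score using ROI median and MAD scaled to a std equivalent."""
    median = np.median(roi)
    mad_std = 1.4826 * np.median(np.abs(roi - median))
    return (image - median) / (mad_std + 1e-8)


def tissue_shift(normalize, clean: np.ndarray, marked: np.ndarray) -> float:
    """Mean absolute change in normalized values below the marker rows."""
    before = normalize(clean, clean.ravel())[16:, :]
    after = normalize(marked, marked.ravel())[16:, :]
    return float(np.abs(before - after).mean())


rng = np.random.default_rng(42)
clean = rng.normal(loc=100.0, scale=10.0, size=(256, 256))  # hypothetical tissue
marked = clean.copy()
marked[:16, :16] = 4000.0  # hypothetical bright marker, 256 of 65,536 pixels

for name, fn in (("z-score", zscore), ("robust z-score", robust_zscore)):
    print(f"{name}: mean shift in tissue = {tissue_shift(fn, clean, marked):.3f}")
```

The marker covers only about 0.4% of pixels, yet the standard z-score shifts the normalized tissue values substantially. The robust version barely moves them, because the median and MAD are almost unaffected by a small fraction of extreme pixels.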
7.2.2 Spatial Preprocessing and Anatomical Standardization
Beyond intensity normalization, spatial preprocessing including resampling, registration, and anatomical standardization can introduce or mitigate fairness issues. Anatomical structures vary systematically across populations due to both biological factors (age, sex, ancestry) and measurement factors (patient positioning, field of view).
We implement anatomical standardization approaches that account for this variation:
"""
Anatomical Standardization with Population Awareness
Implements spatial preprocessing that accounts for systematic anatomical
variation across demographics while maintaining diagnostic information.
"""
import logging
from typing import Dict, Optional, Tuple

import numpy as np
from scipy import ndimage
from skimage import transform

logger = logging.getLogger(__name__)


class AnatomicalStandardizer:
    """
    Standardize medical images to consistent anatomical reference frame.
    Handles systematic anatomical variation across age, sex, and ancestry
    while preserving pathological features and diagnostic information.
    """

    def __init__(
        self,
        target_shape: Tuple[int, ...],
        target_spacing: Optional[Tuple[float, ...]] = None,
        preserve_aspect_ratio: bool = True,
        align_to_template: bool = False
    ):
        """
        Initialize anatomical standardizer.
        Args:
            target_shape: Desired output shape
            target_spacing: Target voxel spacing in mm
            preserve_aspect_ratio: Whether to maintain aspect ratio
            align_to_template: Whether to register to anatomical template
        """
        self.target_shape = target_shape
        self.target_spacing = target_spacing
        self.preserve_aspect_ratio = preserve_aspect_ratio
        self.align_to_template = align_to_template
        self.templates = {}  # Population-specific templates

    def register_template(
        self,
        template_id: str,
        template_image: np.ndarray,
        metadata: Optional[Dict] = None
    ) -> None:
        """
        Register an anatomical template for specific population.
        Args:
            template_id: Identifier (e.g., 'adult_male', 'pediatric_female')
            template_image: Template image array
            metadata: Optional metadata about template population
        """
        self.templates[template_id] = {
            'image': template_image,
            'metadata': metadata or {}
        }
        logger.info(
            f"Registered anatomical template: {template_id} "
            f"(shape {template_image.shape})"
        )

    def standardize(
        self,
        image: np.ndarray,
        spacing: Optional[Tuple[float, ...]] = None,
        template_id: Optional[str] = None,
        landmarks: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, Dict]:
        """
        Standardize image to consistent anatomical frame.
        Args:
            image: Input image
            spacing: Current voxel spacing in mm
            template_id: Which template to use for alignment
            landmarks: Optional anatomical landmarks for alignment
        Returns:
            Standardized image and transformation metadata
        """
        transform_params = {
            'original_shape': image.shape,
            'original_spacing': spacing
        }
        # Resample to target spacing if needed
        if spacing is not None and self.target_spacing is not None:
            image, resample_params = self._resample_to_spacing(
                image, spacing, self.target_spacing
            )
            transform_params['resample'] = resample_params
        # Resize to target shape
        if self.preserve_aspect_ratio:
            image, resize_params = self._resize_preserve_aspect(
                image, self.target_shape
            )
        else:
            image, resize_params = self._resize_direct(
                image, self.target_shape
            )
        transform_params['resize'] = resize_params
        # Align to anatomical template if requested
        if self.align_to_template and template_id is not None:
            if template_id not in self.templates:
                logger.warning(
                    f"Template {template_id} not found, skipping alignment"
                )
            else:
                image, alignment_params = self._align_to_template(
                    image,
                    self.templates[template_id]['image'],
                    landmarks
                )
                transform_params['alignment'] = alignment_params
        transform_params['final_shape'] = image.shape
        return image, transform_params

    def _resample_to_spacing(
        self,
        image: np.ndarray,
        current_spacing: Tuple[float, ...],
        target_spacing: Tuple[float, ...]
    ) -> Tuple[np.ndarray, Dict]:
        """Resample image to target voxel spacing (mm per axis)."""
        current_spacing = np.asarray(current_spacing, dtype=float)
        target_spacing = np.asarray(target_spacing, dtype=float)
        if current_spacing.size != image.ndim or target_spacing.size != image.ndim:
            raise ValueError("Spacing must have one entry per image axis")
        # Zoom factor per axis: coarser current spacing -> more output voxels
        scale_factors = current_spacing / target_spacing
        # Linear interpolation (bilinear in 2D, trilinear in 3D)
        resampled = ndimage.zoom(
            image,
            scale_factors,
            order=1,
            mode='constant',
            cval=float(image.min())
        )
        params = {
            'scale_factors': scale_factors.tolist(),
            'output_shape': tuple(resampled.shape)
        }
        return resampled, params

    def _resize_preserve_aspect(
        self,
        image: np.ndarray,
        target_shape: Tuple[int, ...]
    ) -> Tuple[np.ndarray, Dict]:
        """
        Resize while preserving aspect ratio through padding/cropping.
        Critical for maintaining anatomical proportions across patients
        of different sizes and ages.
        """
        if image.ndim != len(target_shape):
            raise ValueError(
                f"Image has {image.ndim} dims but target_shape has "
                f"{len(target_shape)}"
            )
        current_shape = np.array(image.shape, dtype=float)
        # One scale factor for all axes so the image fits within the target
        scale_factor = float(np.min(np.asarray(target_shape) / current_shape))
        # Intermediate shape after scaling (at least 1 voxel per axis)
        scaled_shape = tuple(
            max(1, int(np.round(s * scale_factor))) for s in current_shape
        )
        # skimage.transform.resize handles 2D and 3D arrays directly
        resized = transform.resize(
            image,
            scaled_shape,
            order=1,
            mode='constant',
            cval=float(image.min()),
            preserve_range=True,
            anti_aliasing=True
        )
        # Pad or crop to exact target shape
        output = self._pad_or_crop(resized, target_shape)
        params = {
            'scale_factor': scale_factor,
            'scaled_shape': scaled_shape,
            'preserved_aspect_ratio': True
        }
        return output, params

    def _resize_direct(
        self,
        image: np.ndarray,
        target_shape: Tuple[int, ...]
    ) -> Tuple[np.ndarray, Dict]:
        """Direct resize without preserving aspect ratio."""
        resized = transform.resize(
            image,
            target_shape,
            order=1,
            mode='constant',
            cval=float(image.min()),
            preserve_range=True,
            anti_aliasing=True
        )
        params = {
            'preserved_aspect_ratio': False
        }
        return resized, params

    def _pad_or_crop(
        self,
        image: np.ndarray,
        target_shape: Tuple[int, ...]
    ) -> np.ndarray:
        """Center-pad (with the image minimum) or center-crop to target shape."""
        target_shape = tuple(int(t) for t in target_shape)
        output = np.full(target_shape, image.min(), dtype=image.dtype)
        out_slices, in_slices = [], []
        for current, target in zip(image.shape, target_shape):
            extent = min(current, target)            # voxels copied on this axis
            out_start = max(0, (target - current) // 2)  # padding offset
            in_start = max(0, (current - target) // 2)   # cropping offset
            out_slices.append(slice(out_start, out_start + extent))
            in_slices.append(slice(in_start, in_start + extent))
        output[tuple(out_slices)] = image[tuple(in_slices)]
        return output

    def _align_to_template(
        self,
        image: np.ndarray,
        template: np.ndarray,
        landmarks: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, Dict]:
        """
        Align image to anatomical template using registration.
        Uses landmark-based alignment if provided, otherwise
        intensity-based registration.
        """
        if landmarks is not None:
            # Landmark-based alignment
            aligned, params = self._landmark_registration(
                image, template, landmarks
            )
        else:
            # Intensity-based rigid registration
            aligned, params = self._intensity_registration(
                image, template
            )
        return aligned, params

    def _landmark_registration(
        self,
        image: np.ndarray,
        template: np.ndarray,
        landmarks: np.ndarray
    ) -> Tuple[np.ndarray, Dict]:
        """
        Register using anatomical landmarks.
        More robust than intensity-based registration for images with
        artifacts or systematically different intensity distributions.
        Placeholder: no transform is applied yet; production code would
        estimate an affine transform robustly from the landmarks.
        """
        params = {
            'method': 'landmark',
            'n_landmarks': len(landmarks),
            'applied': False
        }
        return image, params

    def _intensity_registration(
        self,
        image: np.ndarray,
        template: np.ndarray
    ) -> Tuple[np.ndarray, Dict]:
        """
        Intensity-based rigid registration.
        Placeholder: no transform is applied yet; production code would
        use a dedicated registration library.
        """
        params = {
            'method': 'intensity',
            'registration_metric': 'mutual_information',
            'applied': False
        }
        return image, params
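The anatomical standardization framework is intended to help models generalize across patients with different body habitus, ages, and anatomical variants while maintaining diagnostic accuracy. The two registration methods are still simplified placeholders: they record metadata but return the image unchanged.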
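The value of aspect-ratio preservation can be shown with a synthetic elongated structure. The sketch compares a direct resize with a scale-then-pad resize, in the spirit of `_resize_direct` and `_resize_preserve_aspect`. It uses a hypothetical numpy-only nearest-neighbour resize (`nn_resize`) instead of `skimage.transform.resize` to stay dependency-free, and a rectangle stands in for real anatomy.

```python
from typing import Tuple

import numpy as np


def nn_resize(image: np.ndarray, out_shape: Tuple[int, int]) -> np.ndarray:
    """Nearest-neighbour 2D resize using index mapping (numpy only)."""
    rows = (np.arange(out_shape[0]) * image.shape[0] / out_shape[0]).astype(int)
    cols = (np.arange(out_shape[1]) * image.shape[1] / out_shape[1]).astype(int)
    return image[np.ix_(rows, cols)]


def resize_with_padding(image: np.ndarray, target: Tuple[int, int]) -> np.ndarray:
    """Scale by one factor so the image fits the target, then center-pad with zeros."""
    scale = min(target[0] / image.shape[0], target[1] / image.shape[1])
    scaled_shape = (
        max(1, round(image.shape[0] * scale)),
        max(1, round(image.shape[1] * scale)),
    )
    scaled = nn_resize(image, scaled_shape)
    out = np.zeros(target, dtype=image.dtype)
    top = (target[0] - scaled.shape[0]) // 2
    left = (target[1] - scaled.shape[1]) // 2
    out[top:top + scaled.shape[0], left:left + scaled.shape[1]] = scaled
    return out


def height_width_ratio(mask: np.ndarray) -> float:
    """Height / width of the bounding box of nonzero pixels."""
    rows = np.flatnonzero(mask.any(axis=1))
    cols = np.flatnonzero(mask.any(axis=0))
    return (rows[-1] - rows[0] + 1) / (cols[-1] - cols[0] + 1)


# Hypothetical 200 x 100 image containing an elongated 80 x 40 structure
image = np.zeros((200, 100), dtype=np.uint8)
image[60:140, 30:70] = 1
target = (128, 128)

print(f"original: {height_width_ratio(image):.2f}")
print(f"direct:   {height_width_ratio(nn_resize(image, target)):.2f}")
print(f"padded:   {height_width_ratio(resize_with_padding(image, target)):.2f}")
```

The structure starts with a height-to-width ratio of 2.0. The direct resize squashes it to 1.0, while the padded resize keeps it close to 2.0, so any shape-based measurement taken after direct resizing would be distorted.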
7.3 Data Augmentation for Fairness in Medical Imaging
Data augmentation is essential for training robust deep learning models, but standard augmentation strategies developed for natural images can be inappropriate or even harmful for medical imaging. We must design augmentation approaches that increase model robustness to clinically irrelevant variations while preserving diagnostic features and maintaining fairness across patient populations.
7.3.1 Physics-Informed Augmentation
Medical images are governed by the physics of their acquisition modalities. Data augmentation should respect these physical constraints while simulating realistic acquisition variations that the model will encounter across different healthcare settings.
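One physically grounded variation for radiography is quantum (Poisson) noise, which grows in relative terms as the photon count falls. The minimal sketch below assumes that intensities in [0, 1] are proportional to detected photons, using a hypothetical `full_dose_photons` calibration constant. It ignores electronic noise, scatter, and the logarithmic processing many systems apply, and it is not the module's own `_augment_xray` method.

```python
from typing import Optional

import numpy as np


def simulate_dose_reduction(
    image: np.ndarray,
    dose_fraction: float,
    full_dose_photons: float = 1000.0,
    rng: Optional[np.random.Generator] = None,
) -> np.ndarray:
    """Approximate a lower-dose acquisition by resampling Poisson photon counts.

    Assumes intensities in [0, 1] are proportional to detected photons, with
    `full_dose_photons` photons at intensity 1.0 (a hypothetical calibration).
    """
    if not 0.0 < dose_fraction <= 1.0:
        raise ValueError("dose_fraction must be in (0, 1]")
    rng = rng if rng is not None else np.random.default_rng()
    photons_at_unit_intensity = full_dose_photons * dose_fraction
    expected_counts = np.clip(image, 0.0, 1.0) * photons_at_unit_intensity
    noisy_counts = rng.poisson(expected_counts)
    # Rescale counts back to the original intensity scale
    return noisy_counts / photons_at_unit_intensity


rng = np.random.default_rng(0)
flat_field = np.full((256, 256), 0.5)  # hypothetical uniform region
for fraction in (1.0, 0.25):
    noisy = simulate_dose_reduction(flat_field, fraction, rng=rng)
    print(f"dose x{fraction:.2f}: mean={noisy.mean():.3f}, std={noisy.std():.4f}")
```

Quartering the dose roughly doubles the noise standard deviation, because Poisson noise scales with the square root of the photon count. Applied to a real image, the function adds noise on top of the noise already present, so it approximates a low-dose acquisition rather than reproducing one.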
"""
Physics-Informed Medical Image Augmentation
Implements augmentation strategies that respect physics of medical imaging
while simulating realistic variations in acquisition parameters and equipment
that correlate with healthcare setting and patient demographics.
"""
import logging
from typing import Callable, List, Optional, Tuple

import numpy as np
import torch
from scipy import ndimage
from skimage import filters, transform

logger = logging.getLogger(__name__)


class PhysicsInformedAugmenter:
    """
    Medical image augmentation that simulates realistic acquisition variations.
    Includes modality-specific augmentations for:
    - Radiography: noise, scatter, exposure variation
    - CT: beam hardening, metal artifacts
    - MRI: motion, intensity inhomogeneity, Gibbs ringing
    - Ultrasound: speckle noise, shadowing, attenuation
    """

    def __init__(
        self,
        modality: str,
        augmentation_strength: str = 'moderate',
        ensure_fairness: bool = True
    ):
        """
        Initialize physics-informed augmenter.
        Args:
            modality: Imaging modality ('xray', 'ct', 'mri', 'ultrasound')
            augmentation_strength: 'mild', 'moderate', or 'aggressive'
            ensure_fairness: Whether to track augmentation across groups
        """
        # Define strength levels
        strength_scales = {
            'mild': 0.3,
            'moderate': 0.6,
            'aggressive': 0.9
        }
        if augmentation_strength not in strength_scales:
            raise ValueError(
                f"augmentation_strength must be one of "
                f"{sorted(strength_scales)}, got {augmentation_strength!r}"
            )
        self.modality = modality.lower()
        self.augmentation_strength = augmentation_strength
        self.ensure_fairness = ensure_fairness
        self.strength_scale = strength_scales[augmentation_strength]
        # Track augmentation statistics if ensuring fairness
        self.augmentation_stats = {} if ensure_fairness else None
        logger.info(
            f"Initialized {modality} augmenter "
            f"(strength={augmentation_strength})"
        )

    def augment(
        self,
        image: np.ndarray,
        group_id: Optional[str] = None,
        seed: Optional[int] = None
    ) -> np.ndarray:
        """
        Apply physics-informed augmentation to medical image.
        Args:
            image: Input image array
            group_id: Optional demographic group for fairness tracking
            seed: Random seed for reproducibility
        Returns:
            Augmented image
        """
        if seed is not None:
            # Seeds NumPy's global RNG, which also affects any other code
            # drawing from np.random
            np.random.seed(seed)
        # Apply modality-specific augmentations
        if self.modality == 'xray':
            augmented = self._augment_xray(image)
        elif self.modality == 'ct':
            augmented = self._augment_ct(image)
        elif self.modality == 'mri':
            augmented = self._augment_mri(image)
        elif self.modality == 'ultrasound':
            augmented = self._augment_ultrasound(image)
        else:
            logger.warning(f"Unknown modality {self.modality}, no augmentation")
            augmented = image
        # Track augmentation if ensuring fairness
        if self.ensure_fairness and group_id is not None:
            self._track_augmentation(augmented, group_id)
        return augmented

    def _augment_xray(self, image: np.ndarray) -> np.ndarray:
"""
Augment X-ray image with realistic acquisition variations.
Simulates variations in:
- Exposure (kVp, mAs)
- Scatter radiation
- Detector noise
- Grid artifacts
"""
augmented = image.copy()
# Exposure variation (simulates different kVp/mAs settings)
if np.random.rand() < 0.7:
exposure_factor = np.random.uniform(
1.0 - 0.3 * self.strength_scale,
1.0 + 0.3 * self.strength_scale
)
augmented = augmented * exposure_factor
# Scatter simulation (adds low-frequency background)
if np.random.rand() < 0.5:
scatter_intensity = self.strength_scale * 0.2
kernel_size = int(min(image.shape) * 0.1)
kernel_size = kernel_size if kernel_size % 2 == 1 else kernel_size + 1
scatter_map = ndimage.gaussian_filter(
np.random.randn(*image.shape),
sigma=kernel_size / 3
)
scatter_map = scatter_map / np.abs(scatter_map).max()
augmented = augmented + scatter_intensity * image.mean() * scatter_map
# Detector noise (Poisson + Gaussian)
if np.random.rand() < 0.8:
# Poisson noise (signal-dependent)
noise_scale = self.strength_scale * 0.1
poisson_noise = np.random.poisson(
np.maximum(augmented / noise_scale, 0)
) * noise_scale - augmented
# Gaussian noise (electronic)
gaussian_noise = np.random.normal(
0,
image.std() * 0.05 * self.strength_scale,
image.shape
)
augmented = augmented + 0.7 * poisson_noise + 0.3 * gaussian_noise
# Grid artifacts (anti-scatter grid)
if np.random.rand() < 0.3:
grid_period = np.random.randint(20, 40)
grid_amplitude = image.mean() * 0.05 * self.strength_scale
x = np.arange(image.shape[1])
grid_pattern = grid_amplitude * np.sin(2 * np.pi * x / grid_period)
augmented = augmented + grid_pattern[np.newaxis, :]
return augmented
def _augment_ct(self, image: np.ndarray) -> np.ndarray:
"""
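Augment CT image with simplified artifact models.
Simulates:
- Beam hardening (radial cupping)
- Metal-like streak artifacts (intensity offsets assume Hounsfield units)
- Quantum noise from reduced photon counts (Gaussian approximation)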
"""
augmented = image.copy()
# Beam hardening (cupping artifact)
if np.random.rand() < 0.5:
center = np.array(image.shape) / 2
y, x = np.ogrid[:image.shape[0], :image.shape[1]]
distances = np.sqrt((x - center[1])**2 + (y - center[0])**2)
max_distance = np.sqrt(center[0]**2 + center[1]**2)
cupping_strength = self.strength_scale * 0.15
cupping = 1.0 - cupping_strength * (distances / max_distance)**2
augmented = augmented * cupping
# Metal artifacts (streak artifacts)
if np.random.rand() < 0.3:
num_streaks = np.random.randint(2, 6)
for _ in range(num_streaks):
angle = np.random.uniform(0, np.pi)
width = np.random.randint(1, 3)
intensity = np.random.uniform(
-100 * self.strength_scale,
100 * self.strength_scale
)
# Create streak
streak = np.zeros_like(image)
center = np.array(image.shape) / 2
length = int(max(image.shape) * 0.8)
for i in range(-length // 2, length // 2):
x = int(center[1] + i * np.cos(angle))
y = int(center[0] + i * np.sin(angle))
if 0 <= y < image.shape[0] and 0 <= x < image.shape[1]:
streak[
max(0, y-width):min(image.shape[0], y+width+1),
max(0, x-width):min(image.shape[1], x+width+1)
] = intensity
augmented = augmented + streak
# Quantum noise (Gaussian approximation of low-dose, photon-starved acquisitions)
if np.random.rand() < 0.6:
noise_std = image.std() * 0.1 * self.strength_scale
augmented = augmented + np.random.normal(0, noise_std, image.shape)
return augmented
def _augment_mri(self, image: np.ndarray) -> np.ndarray:
"""
Augment MRI image with realistic artifacts.
Simulates:
- Intensity inhomogeneity (bias field)
- Motion artifacts
- Gibbs ringing
- Rician noise
"""
augmented = image.copy()
# Bias field (intensity inhomogeneity)
if np.random.rand() < 0.7:
# Generate smooth bias field
low_res_shape = tuple(s // 4 for s in image.shape)
bias_field = np.random.randn(*low_res_shape)
# Upsample to full resolution
bias_field = ndimage.zoom(
bias_field,
tuple(s / lr for s, lr in zip(image.shape, low_res_shape)),
order=3
)
# Normalize and scale
bias_field = (bias_field - bias_field.mean()) / bias_field.std()
bias_strength = self.strength_scale * 0.3
bias_field = np.exp(bias_strength * bias_field)
augmented = augmented * bias_field
# Motion artifacts
if np.random.rand() < 0.4:
# Approximate rigid motion as ghosting by blending in shifted copies of the
# current augmented image; an image-domain proxy, not a k-space model
num_motion_events = np.random.randint(1, 4)
for _ in range(num_motion_events):
shift = np.random.randint(-5, 6, size=augmented.ndim)
motion_artifact = ndimage.shift(
augmented,
shift,
mode='constant',
cval=augmented.mean()
)
blend_weight = 0.3 * self.strength_scale
augmented = (1 - blend_weight) * augmented + \
blend_weight * motion_artifact
# Gibbs ringing
if np.random.rand() < 0.5:
# Apply truncation in k-space
fft = np.fft.fft2(augmented)
fft_shifted = np.fft.fftshift(fft)
# Truncate high frequencies
truncation_factor = 1.0 - 0.2 * self.strength_scale
mask_size = tuple(int(s * truncation_factor) for s in image.shape)
mask = np.zeros_like(fft_shifted)
start = tuple((s - ms) // 2 for s, ms in zip(image.shape, mask_size))
end = tuple(st + ms for st, ms in zip(start, mask_size))
mask[start[0]:end[0], start[1]:end[1]] = 1
fft_truncated = fft_shifted * mask
augmented = np.real(np.fft.ifft2(np.fft.ifftshift(fft_truncated)))
# Rician noise
if np.random.rand() < 0.6:
noise_std = image.std() * 0.05 * self.strength_scale
noise_real = np.random.normal(0, noise_std, image.shape)
noise_imag = np.random.normal(0, noise_std, image.shape)
# Rician distribution
augmented = np.sqrt((augmented + noise_real)**2 + noise_imag**2)
return augmented
def _augment_ultrasound(self, image: np.ndarray) -> np.ndarray:
"""
Augment ultrasound image with realistic artifacts.
Simulates:
- Speckle noise
- Attenuation
- Shadowing
"""
augmented = image.copy()
# Speckle noise (multiplicative)
if np.random.rand() < 0.8:
speckle_std = self.strength_scale * 0.3
speckle = np.random.normal(1.0, speckle_std, image.shape)
augmented = augmented * speckle
# Attenuation (depth-dependent signal loss)
if np.random.rand() < 0.7:
depth_axis = 0 # Assume depth is first axis
attenuation_coef = self.strength_scale * 0.02
depth_profile = np.exp(
-attenuation_coef * np.arange(image.shape[depth_axis])
)
# Broadcast to full image shape
attenuation_map = depth_profile.reshape(-1, 1)
if len(image.shape) == 3:
attenuation_map = attenuation_map.reshape(-1, 1, 1)
augmented = augmented * attenuation_map
# Shadowing (reduced signal behind strongly attenuating structures)
if np.random.rand() < 0.4:
# Create random shadow regions
num_shadows = np.random.randint(1, 4)
for _ in range(num_shadows):
shadow_start = np.random.randint(0, image.shape[0] // 2)
shadow_width = np.random.randint(
image.shape[1] // 10,
image.shape[1] // 4
)
shadow_center = np.random.randint(
shadow_width,
image.shape[1] - shadow_width
)
shadow_mask = np.zeros(image.shape[1])
shadow_mask[
shadow_center - shadow_width:shadow_center + shadow_width
] = 1
# Apply shadow with depth-dependent effect
shadow_strength = 0.5 * self.strength_scale
for d in range(shadow_start, image.shape[0]):
depth_factor = (d - shadow_start) / (image.shape[0] - shadow_start)
augmented[d] = augmented[d] * (
1 - shadow_strength * depth_factor * shadow_mask
)
return augmented
def _track_augmentation(
self,
augmented_image: np.ndarray,
group_id: str
) -> None:
"""Track augmentation statistics for fairness monitoring."""
if group_id not in self.augmentation_stats:
self.augmentation_stats[group_id] = {
'count': 0,
'mean_intensity': [],
'std_intensity': []
}
self.augmentation_stats[group_id]['count'] += 1
self.augmentation_stats[group_id]['mean_intensity'].append(
float(augmented_image.mean())
)
self.augmentation_stats[group_id]['std_intensity'].append(
float(augmented_image.std())
)
def get_fairness_report(self) -> Dict:
"""Generate report on augmentation fairness across groups."""
if not self.ensure_fairness or not self.augmentation_stats:
return {}
report = {}
for group_id, stats in self.augmentation_stats.items():
report[group_id] = {
'n_augmented': stats['count'],
'mean_intensity_avg': np.mean(stats['mean_intensity']),
'mean_intensity_std': np.std(stats['mean_intensity']),
'std_intensity_avg': np.mean(stats['std_intensity']),
'std_intensity_std': np.std(stats['std_intensity'])
}
# Compute disparity metrics
if len(report) > 1:
groups = list(report.keys())
mean_avgs = [report[g]['mean_intensity_avg'] for g in groups]
std_avgs = [report[g]['std_intensity_avg'] for g in groups]
report['disparity'] = {
'mean_intensity_range': max(mean_avgs) - min(mean_avgs),
'std_intensity_range': max(std_avgs) - min(std_avgs),
'mean_intensity_cv': np.std(mean_avgs) / (abs(np.mean(mean_avgs)) + 1e-8),
'std_intensity_cv': np.std(std_avgs) / (abs(np.mean(std_avgs)) + 1e-8)
}
return report
class GeometricAugmenter:
"""
Geometric augmentation for medical images with anatomical constraints.
Unlike natural images, medical images have anatomical constraints that
must be respected. Not all rotations, flips, and deformations are
anatomically plausible.
"""
def __init__(
self,
rotation_range: float = 10.0,
translation_range: float = 0.1,
scaling_range: Tuple[float, float] = (0.9, 1.1),
allow_horizontal_flip: bool = False, # Often not anatomically valid
allow_vertical_flip: bool = False,
elastic_deformation: bool = True
):
"""
Initialize geometric augmenter.
Args:
rotation_range: Maximum rotation in degrees
translation_range: Maximum translation as fraction of image size
scaling_range: (min_scale, max_scale)
allow_horizontal_flip: Whether horizontal flip is anatomically valid
allow_vertical_flip: Whether vertical flip is anatomically valid
elastic_deformation: Whether to apply elastic deformations
"""
self.rotation_range = rotation_range
self.translation_range = translation_range
self.scaling_range = scaling_range
self.allow_horizontal_flip = allow_horizontal_flip
self.allow_vertical_flip = allow_vertical_flip
self.elastic_deformation = elastic_deformation
def augment(
self,
image: np.ndarray,
mask: Optional[np.ndarray] = None,
seed: Optional[int] = None
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""
Apply geometric augmentation to image and optional mask.
Args:
image: Input image
mask: Optional segmentation mask to transform consistently
seed: Random seed
Returns:
Tuple of (augmented image, augmented mask)
"""
if seed is not None:
np.random.seed(seed)
augmented_image = image.copy()
augmented_mask = mask.copy() if mask is not None else None
# Rotation
if self.rotation_range > 0:
angle = np.random.uniform(-self.rotation_range, self.rotation_range)
augmented_image = ndimage.rotate(
augmented_image,
angle,
reshape=False,
mode='constant',
cval=image.min()
)
if augmented_mask is not None:
augmented_mask = ndimage.rotate(
augmented_mask,
angle,
reshape=False,
order=0, # Nearest neighbor for masks
mode='constant',
cval=0
)
# Translation
if self.translation_range > 0:
max_shift = tuple(
int(s * self.translation_range) for s in image.shape
)
shifts = tuple(
np.random.randint(-ms, ms + 1) for ms in max_shift
)
augmented_image = ndimage.shift(
augmented_image,
shifts,
mode='constant',
cval=image.min()
)
if augmented_mask is not None:
augmented_mask = ndimage.shift(
augmented_mask,
shifts,
order=0,
mode='constant',
cval=0
)
# Scaling
if self.scaling_range != (1.0, 1.0):
scale = np.random.uniform(*self.scaling_range)
# Zoom and then crop/pad back to original size
zoomed = ndimage.zoom(
augmented_image,
scale,
order=1,
mode='constant',
cval=image.min()
)
# Crop or pad to original size
augmented_image = self._crop_or_pad_to_shape(
zoomed,
image.shape,
fill_value=image.min()
)
if augmented_mask is not None:
zoomed_mask = ndimage.zoom(
augmented_mask,
scale,
order=0,
mode='constant',
cval=0
)
augmented_mask = self._crop_or_pad_to_shape(
zoomed_mask,
mask.shape,
fill_value=0
)
# Horizontal flip (np.flip returns a negative-stride view; copy so torch.from_numpy works)
if self.allow_horizontal_flip and np.random.rand() < 0.5:
augmented_image = np.flip(augmented_image, axis=1).copy()
if augmented_mask is not None:
augmented_mask = np.flip(augmented_mask, axis=1).copy()
# Vertical flip
if self.allow_vertical_flip and np.random.rand() < 0.5:
augmented_image = np.flip(augmented_image, axis=0).copy()
if augmented_mask is not None:
augmented_mask = np.flip(augmented_mask, axis=0).copy()
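# Elastic deformation (image and mask must share one displacement field)
if self.elastic_deformation and np.random.rand() < 0.5:
# Save the RNG state so the mask replays the same random displacements
rng_state = np.random.get_state()
augmented_image = self._elastic_transform(augmented_image)
if augmented_mask is not None:
np.random.set_state(rng_state)
augmented_mask = self._elastic_transform(
augmented_mask,
order=0
)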
return augmented_image, augmented_mask
def _crop_or_pad_to_shape(
self,
array: np.ndarray,
target_shape: Tuple[int, ...],
fill_value: float = 0
) -> np.ndarray:
"""Crop or pad array to target shape."""
output = np.full(target_shape, fill_value, dtype=array.dtype)
# Compute slices for centering
starts = [max(0, (t - c) // 2) for t, c in zip(target_shape, array.shape)]
ends = [s + min(c, t) for s, c, t in zip(starts, array.shape, target_shape)]
crop_starts = [max(0, (c - t) // 2) for c, t in zip(array.shape, target_shape)]
crop_ends = [cs + (e - s) for cs, e, s in zip(crop_starts, ends, starts)]
if len(target_shape) == 2:
output[starts[0]:ends[0], starts[1]:ends[1]] = \
array[crop_starts[0]:crop_ends[0], crop_starts[1]:crop_ends[1]]
else:
output[
starts[0]:ends[0],
starts[1]:ends[1],
starts[2]:ends[2]
] = array[
crop_starts[0]:crop_ends[0],
crop_starts[1]:crop_ends[1],
crop_starts[2]:crop_ends[2]
]
return output
def _elastic_transform(
self,
image: np.ndarray,
alpha: float = 30,
sigma: float = 5,
order: int = 1
) -> np.ndarray:
"""
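Apply elastic deformation to simulate anatomical variation.
Approximates smooth, local soft-tissue deformation with a Gaussian-smoothed
random displacement field: alpha scales the displacement, sigma its smoothness.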
"""
shape = image.shape
# Generate random displacement fields
dx = ndimage.gaussian_filter(
np.random.randn(*shape),
sigma
) * alpha
dy = ndimage.gaussian_filter(
np.random.randn(*shape),
sigma
) * alpha
# Create coordinate arrays
if len(shape) == 2:
y, x = np.meshgrid(
np.arange(shape[0]),
np.arange(shape[1]),
indexing='ij'
)
indices = (y + dy, x + dx)
else:
z, y, x = np.meshgrid(
np.arange(shape[0]),
np.arange(shape[1]),
np.arange(shape[2]),
indexing='ij'
)
dz = ndimage.gaussian_filter(
np.random.randn(*shape),
sigma
) * alpha
indices = (z + dz, y + dy, x + dx)
# Apply transformation
transformed = ndimage.map_coordinates(
image,
indices,
order=order,
mode='constant',
cval=image.min() if order > 0 else 0
)
return transformed
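Together, these augmenters pair physics-informed acquisition perturbations with anatomically constrained geometric transforms. The aim is to increase model robustness to clinically irrelevant variation without violating imaging physics or anatomy. Because the artifact models are simplified approximations, their strengths are best tuned by comparing augmented images against real images from the target care settings.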
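`get_fairness_report` compares groups through averages of per-image statistics. Two groups can share a mean yet differ in spread or shape, so a distributional comparison may be a useful complement. The sketch below assumes per-image mean intensities have already been collected per group, for example from `augmentation_stats[group_id]['mean_intensity']`. It applies SciPy's two-sample Kolmogorov-Smirnov test to each pair of groups. The helper name `compare_group_distributions` and the simulated values are hypothetical.

```python
from itertools import combinations
from typing import Dict, List, Tuple

import numpy as np
from scipy.stats import ks_2samp


def compare_group_distributions(
    per_group_values: Dict[str, List[float]]
) -> Dict[Tuple[str, str], Tuple[float, float]]:
    """Run a two-sample KS test for every pair of groups.

    Returns {(group_a, group_b): (ks_statistic, p_value)}. With many groups,
    the p-values should be corrected for multiple comparisons.
    """
    results = {}
    for group_a, group_b in combinations(sorted(per_group_values), 2):
        test = ks_2samp(per_group_values[group_a], per_group_values[group_b])
        results[(group_a, group_b)] = (float(test.statistic), float(test.pvalue))
    return results


rng = np.random.default_rng(0)
stats = {
    "site_a": rng.normal(0.50, 0.05, 300).tolist(),
    "site_b": rng.normal(0.50, 0.05, 300).tolist(),
    "site_c": rng.normal(0.60, 0.05, 300).tolist(),  # shifted distribution
}
for pair, (statistic, p_value) in compare_group_distributions(stats).items():
    print(pair, round(statistic, 3), p_value)
```

A large KS statistic flags groups whose augmented inputs look systematically different. That difference is worth investigating before training, even when the averages in the fairness report agree.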
7.4 Segmentation with Fairness Constraints
Semantic segmentation assigns a class label to each pixel in an image, enabling localization and quantification of anatomical structures and pathological regions. In medical imaging, segmentation is fundamental to applications ranging from tumor volume estimation to organ morphometry to surgical planning. However, segmentation models can exhibit systematic performance disparities when anatomical structure varies across patient demographics or when image quality differs by care setting.
7.4.1 U-Net Architecture with Equity Considerations
The U-Net architecture has become the standard for medical image segmentation due to its ability to combine low-level localization information with high-level semantic understanding through skip connections. We implement a U-Net variant with explicit fairness considerations:
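Decoder channel bookkeeping is an easy place for U-Net implementations to go wrong, because each stage's input must equal the previous stage's output. A plain-Python trace can check the `upconv_blocks` and `decoder_blocks` construction in `FairUNet` before the model is instantiated. The helper `unet_decoder_channels` is hypothetical and independent of the class. It assumes encoder stage *i* outputs `base_channels * 2**i` channels and the bottleneck outputs `base_channels * 2**depth`.

```python
from typing import List, Tuple


def unet_decoder_channels(base_channels: int, depth: int) -> List[Tuple[int, int, int]]:
    """Trace (upconv_in, upconv_out, decoder_block_in) for each decoder stage.

    Assumes encoder stage i outputs base_channels * 2**i channels and the
    bottleneck outputs base_channels * 2**depth channels.
    """
    encoder_out = [base_channels * 2 ** i for i in range(depth)]
    current = base_channels * 2 ** depth  # bottleneck output channels
    plan = []
    for i in range(depth):
        upconv_out = base_channels * 2 ** (depth - i - 1)
        skip = encoder_out[-(i + 1)]  # matching encoder feature map
        plan.append((current, upconv_out, upconv_out + skip))
        current = upconv_out  # decoder block maps back to upconv_out channels
    return plan


for stage, channels in enumerate(unet_decoder_channels(base_channels=64, depth=4)):
    print(stage, channels)
# 0 (1024, 512, 1024)
# 1 (512, 256, 512)
# 2 (256, 128, 256)
# 3 (128, 64, 128)
```

The input to decoder stage *i* works out to `base_channels * 2**(depth - i)`. The last stage ends at `base_channels` channels, which is what the final 1x1 convolution expects.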
"""
Fair Medical Image Segmentation with U-Net
Implements U-Net architecture with fairness-aware training and
comprehensive evaluation across demographic groups and care settings.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import logging
logger = logging.getLogger(__name__)
class FairUNet(nn.Module):
"""
U-Net for medical image segmentation.
Includes:
- Standard U-Net encoder-decoder with skip connections
- Optional group normalization, whose per-sample statistics do not depend on
batch composition (e.g., the mix of sites or scanners in a batch)
- Dropout for regularization
Fairness-aware training and stratified evaluation are handled separately by
FairnessAwareSegmentationLoss and evaluate_segmentation_fairness.
"""
def __init__(
self,
in_channels: int = 1,
out_channels: int = 2,
base_channels: int = 64,
depth: int = 4,
use_group_norm: bool = True,
dropout_rate: float = 0.1
):
"""
Initialize Fair U-Net.
Args:
in_channels: Number of input channels
out_channels: Number of output segmentation classes
base_channels: Number of channels in first layer
depth: Depth of U-Net (number of downsampling steps)
use_group_norm: Use group norm instead of batch norm
dropout_rate: Dropout rate for regularization
"""
super(FairUNet, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.base_channels = base_channels
self.depth = depth
self.dropout_rate = dropout_rate
# Encoder path
self.encoder_blocks = nn.ModuleList()
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
in_ch = in_channels
for i in range(depth):
out_ch = base_channels * (2 ** i)
self.encoder_blocks.append(
self._make_encoder_block(
in_ch,
out_ch,
use_group_norm
)
)
in_ch = out_ch
# Bottleneck
bottleneck_ch = base_channels * (2 ** depth)
self.bottleneck = self._make_encoder_block(
in_ch,
bottleneck_ch,
use_group_norm
)
# Decoder path
self.decoder_blocks = nn.ModuleList()
self.upconv_blocks = nn.ModuleList()
for i in range(depth):
in_ch = base_channels * (2 ** (depth - i))  # equals bottleneck_ch when i == 0
out_ch = base_channels * (2 ** (depth - i - 1))
self.upconv_blocks.append(
nn.ConvTranspose2d(
in_ch,
out_ch,
kernel_size=2,
stride=2
)
)
self.decoder_blocks.append(
self._make_decoder_block(
out_ch * 2, # Concatenated with skip connection
out_ch,
use_group_norm
)
)
# Final convolution
self.final_conv = nn.Conv2d(
base_channels,
out_channels,
kernel_size=1
)
# Dropout
self.dropout = nn.Dropout2d(dropout_rate)
logger.info(
f"Initialized Fair U-Net: "
f"depth={depth}, base_ch={base_channels}, "
f"in_ch={in_channels}, out_ch={out_channels}"
)
def _make_encoder_block(
self,
in_channels: int,
out_channels: int,
use_group_norm: bool = True
) -> nn.Module:
"""Create encoder block with two convolutions."""
if use_group_norm:
norm1 = nn.GroupNorm(num_groups=min(32, out_channels), num_channels=out_channels)
norm2 = nn.GroupNorm(num_groups=min(32, out_channels), num_channels=out_channels)
else:
norm1 = nn.BatchNorm2d(out_channels)
norm2 = nn.BatchNorm2d(out_channels)
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
norm1,
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
norm2,
nn.ReLU(inplace=True)
)
def _make_decoder_block(
self,
in_channels: int,
out_channels: int,
use_group_norm: bool = True
) -> nn.Module:
"""Create decoder block with two convolutions."""
return self._make_encoder_block(in_channels, out_channels, use_group_norm)
def forward(
self,
x: torch.Tensor,
return_features: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
"""
Forward pass through U-Net.
Args:
x: Input tensor of shape (batch, in_channels, height, width)
return_features: Whether to return intermediate features
Returns:
Segmentation logits (and optionally feature maps)
"""
# Encoder path with skip connections
skip_connections = []
for encoder_block in self.encoder_blocks:
x = encoder_block(x)
skip_connections.append(x)
x = self.pool(x)
x = self.dropout(x)
# Bottleneck
x = self.bottleneck(x)
# Decoder path
for i, (upconv, decoder_block) in enumerate(
zip(self.upconv_blocks, self.decoder_blocks)
):
x = upconv(x)
# Get corresponding skip connection
skip = skip_connections[-(i + 1)]
# Handle size mismatch due to odd dimensions
if x.shape != skip.shape:
x = F.interpolate(
x,
size=skip.shape[2:],
mode='bilinear',
align_corners=True
)
# Concatenate skip connection
x = torch.cat([x, skip], dim=1)
# Decoder block
x = decoder_block(x)
x = self.dropout(x)
# Final convolution
logits = self.final_conv(x)
if return_features:
return logits, skip_connections
else:
return logits
class FairnessAwareSegmentationLoss(nn.Module):
"""
Loss function for fair medical image segmentation.
Combines standard segmentation loss (Dice + Cross Entropy) with
fairness regularization that penalizes performance disparities
across protected groups.
"""
def __init__(
self,
num_classes: int,
fairness_weight: float = 0.1,
class_weights: Optional[torch.Tensor] = None
):
"""
Initialize fairness-aware segmentation loss.
Args:
num_classes: Number of segmentation classes
fairness_weight: Weight for fairness regularization term
class_weights: Optional class weights for imbalanced data
"""
super(FairnessAwareSegmentationLoss, self).__init__()
self.num_classes = num_classes
self.fairness_weight = fairness_weight
if class_weights is not None:
self.register_buffer('class_weights', class_weights)
else:
self.class_weights = None
def forward(
self,
predictions: torch.Tensor,
targets: torch.Tensor,
group_ids: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Dict[str, float]]:
"""
Compute fairness-aware segmentation loss.
Args:
predictions: Model predictions (batch, classes, height, width)
targets: Ground truth (batch, height, width)
group_ids: Optional group identifiers (batch,)
Returns:
Total loss and dictionary of loss components
"""
# Standard segmentation loss
dice_loss = self._dice_loss(predictions, targets)
ce_loss = self._cross_entropy_loss(predictions, targets)
seg_loss = 0.5 * dice_loss + 0.5 * ce_loss
loss_dict = {
'dice_loss': dice_loss.item(),
'ce_loss': ce_loss.item(),
'seg_loss': seg_loss.item()
}
total_loss = seg_loss
# Add fairness regularization if group IDs provided
if group_ids is not None and self.fairness_weight > 0:
fairness_loss = self._fairness_regularization(
predictions,
targets,
group_ids
)
total_loss = total_loss + self.fairness_weight * fairness_loss
loss_dict['fairness_loss'] = fairness_loss.item()
loss_dict['total_loss'] = total_loss.item()
return total_loss, loss_dict
def _dice_loss(
self,
predictions: torch.Tensor,
targets: torch.Tensor,
smooth: float = 1e-6
) -> torch.Tensor:
"""
Compute soft Dice loss.
Dice coefficient = 2 * |A ∩ B| / (|A| + |B|), computed per class on softmax
probabilities pooled over the batch.
Dice loss = 1 - mean (optionally class-weighted) Dice coefficient.
"""
# Convert predictions to probabilities
probs = F.softmax(predictions, dim=1)
# One-hot encode targets
targets_one_hot = F.one_hot(
targets.long(),
num_classes=self.num_classes
).permute(0, 3, 1, 2).float()
# Compute Dice for each class
dice_scores = []
for c in range(self.num_classes):
pred_c = probs[:, c, :, :]
target_c = targets_one_hot[:, c, :, :]
intersection = (pred_c * target_c).sum()
union = pred_c.sum() + target_c.sum()
dice = (2.0 * intersection + smooth) / (union + smooth)
dice_scores.append(dice)
# Average across classes (optionally weighted)
if self.class_weights is not None:
dice_loss = 1 - sum(
w * d for w, d in zip(self.class_weights, dice_scores)
) / self.class_weights.sum()
else:
dice_loss = 1 - sum(dice_scores) / len(dice_scores)
return dice_loss
def _cross_entropy_loss(
self,
predictions: torch.Tensor,
targets: torch.Tensor
) -> torch.Tensor:
"""Compute cross-entropy loss."""
return F.cross_entropy(
predictions,
targets.long(),
weight=self.class_weights,
reduction='mean'
)
def _fairness_regularization(
self,
predictions: torch.Tensor,
targets: torch.Tensor,
group_ids: torch.Tensor
) -> torch.Tensor:
"""
Compute fairness regularization term.
Penalizes variance in Dice scores across demographic groups.
"""
unique_groups = group_ids.unique()
group_dice_scores = []
for group in unique_groups:
group_mask = group_ids == group
if group_mask.sum() == 0:
continue
group_preds = predictions[group_mask]
group_targets = targets[group_mask]
# Compute Dice for this group
probs = F.softmax(group_preds, dim=1)
targets_one_hot = F.one_hot(
group_targets.long(),
num_classes=self.num_classes
).permute(0, 3, 1, 2).float()
dice = 0
for c in range(self.num_classes):
pred_c = probs[:, c, :, :]
target_c = targets_one_hot[:, c, :, :]
intersection = (pred_c * target_c).sum()
union = pred_c.sum() + target_c.sum()
dice += (2.0 * intersection + 1e-6) / (union + 1e-6)
dice /= self.num_classes
group_dice_scores.append(dice)
if len(group_dice_scores) < 2:
return torch.tensor(0.0, device=predictions.device)
# Compute variance of Dice scores across groups
group_dice_tensor = torch.stack(group_dice_scores)
fairness_loss = group_dice_tensor.var()
return fairness_loss
def evaluate_segmentation_fairness(
model: nn.Module,
dataloader: torch.utils.data.DataLoader,
group_variable: str,
device: str = 'cuda'
) -> Dict:
"""
Evaluate segmentation model fairness across demographic groups.
Args:
model: Trained segmentation model
dataloader: DataLoader with demographic metadata
group_variable: Name of grouping variable to stratify by
device: Device for computation
Returns:
Dictionary of fairness metrics stratified by group
"""
model.eval()
model.to(device)
group_metrics = {}
with torch.no_grad():
for batch in dataloader:
images = batch['image'].to(device)
masks = batch['mask'].to(device)
metadata = batch['metadata']
# Get predictions
logits = model(images)
predictions = torch.argmax(logits, dim=1)
# Group by demographic variable
for i in range(len(images)):
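group = metadata[group_variable][i]
# Default collation can turn numeric metadata into tensors; use plain values as keys
if torch.is_tensor(group):
group = group.item()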
if group not in group_metrics:
group_metrics[group] = {
'dice_scores': [],
'iou_scores': [],
'n_examples': 0
}
pred = predictions[i].cpu().numpy()
target = masks[i].cpu().numpy()
# Compute metrics
dice = compute_dice(pred, target)
iou = compute_iou(pred, target)
group_metrics[group]['dice_scores'].append(dice)
group_metrics[group]['iou_scores'].append(iou)
group_metrics[group]['n_examples'] += 1
# Aggregate metrics per group
results = {}
for group, metrics in group_metrics.items():
results[group] = {
'n_examples': metrics['n_examples'],
'dice_mean': np.mean(metrics['dice_scores']),
'dice_std': np.std(metrics['dice_scores']),
'iou_mean': np.mean(metrics['iou_scores']),
'iou_std': np.std(metrics['iou_scores'])
}
# Compute disparity metrics
if len(results) > 1:
groups = list(results.keys())
dice_means = [results[g]['dice_mean'] for g in groups]
iou_means = [results[g]['iou_mean'] for g in groups]
results['disparity'] = {
'dice_range': max(dice_means) - min(dice_means),
'dice_ratio': max(dice_means) / (min(dice_means) + 1e-8),
'iou_range': max(iou_means) - min(iou_means),
'iou_ratio': max(iou_means) / (min(iou_means) + 1e-8)
}
return results
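def compute_dice(pred: np.ndarray, target: np.ndarray) -> float:
"""Binary (foreground vs. background) Dice; any nonzero label counts as foreground."""
pred = np.asarray(pred) > 0
target = np.asarray(target) > 0
denominator = pred.sum() + target.sum()
if denominator == 0:
return 1.0  # Both masks empty: treat as perfect agreement
return float(2.0 * np.logical_and(pred, target).sum() / denominator)
def compute_iou(pred: np.ndarray, target: np.ndarray) -> float:
"""Binary (foreground vs. background) Intersection over Union."""
pred = np.asarray(pred) > 0
target = np.asarray(target) > 0
union = np.logical_or(pred, target).sum()
if union == 0:
return 1.0  # Both masks empty: treat as perfect agreement
return float(np.logical_and(pred, target).sum() / union)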
This segmentation framework provides the foundation for building fair medical image segmentation systems with comprehensive evaluation across demographic groups and care settings.
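`evaluate_segmentation_fairness` reports the range and ratio of group means but not their uncertainty. Small subgroups can produce sizable gaps by chance. One way to attach uncertainty to a gap is a percentile bootstrap over per-patient scores, such as the per-group `dice_scores` lists collected during evaluation. The helper `bootstrap_gap_ci` below is a hypothetical sketch run on simulated scores. It assumes one score per independent patient.

```python
import numpy as np


def bootstrap_gap_ci(scores_a, scores_b, n_boot=2000, alpha=0.05, seed=0):
    """Percentile bootstrap CI for mean(scores_a) - mean(scores_b).

    Hypothetical helper. Assumes one score per independent patient; with several
    images per patient, resample patients instead of images.
    """
    rng = np.random.default_rng(seed)
    a = np.asarray(scores_a, dtype=float)
    b = np.asarray(scores_b, dtype=float)
    gaps = np.empty(n_boot)
    for k in range(n_boot):
        # Resample each group independently, with replacement, at its own size
        gaps[k] = rng.choice(a, a.size).mean() - rng.choice(b, b.size).mean()
    lower, upper = np.quantile(gaps, [alpha / 2, 1 - alpha / 2])
    return float(a.mean() - b.mean()), (float(lower), float(upper))


rng = np.random.default_rng(1)
dice_group_a = np.clip(rng.normal(0.85, 0.05, 200), 0, 1)
dice_group_b = np.clip(rng.normal(0.78, 0.08, 60), 0, 1)  # smaller, noisier group
gap, (low, high) = bootstrap_gap_ci(dice_group_a, dice_group_b)
print(f"Dice gap {gap:.3f}, 95% CI [{low:.3f}, {high:.3f}]")
```

An interval that excludes zero suggests the gap is unlikely to be sampling noise alone. The width of the interval also shows how much a small subgroup limits what can be concluded.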
7.5 Detection and Classification with Fairness
Object detection localizes instances of objects within images, while classification assigns categorical labels to entire images or regions. In medical imaging, detection tasks include identifying lesions, anatomical landmarks, or medical devices, while classification tasks include diagnosis from imaging studies.
7.5.1 Multi-Task Learning for Fair Classification
We implement a multi-task learning framework that jointly optimizes for diagnostic accuracy and fairness across demographic groups:
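The core of adversarial debiasing is a single training step. In it, a demographic head learns to predict group membership while a gradient reversal layer sends the negated gradient into the shared encoder. The sketch below shows that step on toy linear layers with random data. `GradReverse`, `adversarial_step`, and the layer sizes are illustrative assumptions, not the classes defined in this section.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; scales gradients by -alpha in the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.alpha, None


def adversarial_step(encoder, diag_head, group_head, optimizer, x, y, g,
                     alpha=1.0, adv_weight=1.0):
    """One training step: learn the diagnosis while hiding group membership."""
    optimizer.zero_grad()
    z = encoder(x)
    diag_loss = F.cross_entropy(diag_head(z), y)
    # group_head minimizes group_loss; the reversed gradient pushes the encoder
    # toward features from which the group is harder to predict
    group_loss = F.cross_entropy(group_head(GradReverse.apply(z, alpha)), g)
    (diag_loss + adv_weight * group_loss).backward()
    optimizer.step()
    return diag_loss.item(), group_loss.item()


torch.manual_seed(0)
encoder, diag_head, group_head = nn.Linear(16, 8), nn.Linear(8, 2), nn.Linear(8, 2)
params = [*encoder.parameters(), *diag_head.parameters(), *group_head.parameters()]
optimizer = torch.optim.SGD(params, lr=0.1)
x = torch.randn(32, 16)
y = torch.randint(0, 2, (32,))  # diagnostic labels
g = torch.randint(0, 2, (32,))  # demographic group labels
print(adversarial_step(encoder, diag_head, group_head, optimizer, x, y, g))
```

Because a single backward pass updates both players, `alpha` and `adv_weight` control how strongly the encoder is pushed away from group-predictive features relative to the diagnostic objective.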
"""
Fair Multi-Task Medical Image Classification
Implements classification with explicit fairness objectives using
multi-task learning and adversarial debiasing approaches.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple
import logging
logger = logging.getLogger(__name__)
class FairMultiTaskClassifier(nn.Module):
"""
Multi-task classifier with fairness constraints.
Learns diagnostic task while preventing the feature representation
from encoding protected demographic attributes that could lead to
biased predictions.
"""
def __init__(
self,
backbone: str = 'resnet50',
num_classes: int = 2,
num_groups: int = 2,
hidden_dim: int = 512,
use_adversarial: bool = True,
pretrained: bool = True
):
"""
Initialize fair multi-task classifier.
Args:
backbone: Feature extraction backbone
num_classes: Number of diagnostic classes
num_groups: Number of demographic groups
hidden_dim: Hidden dimension for classifiers
use_adversarial: Whether to use adversarial debiasing
pretrained: Use ImageNet pre-trained weights
"""
super(FairMultiTaskClassifier, self).__init__()
self.num_classes = num_classes
self.num_groups = num_groups
self.use_adversarial = use_adversarial
# Feature extractor
if backbone == 'resnet50':
from torchvision import models
base_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT if pretrained else None)
self.feature_extractor = nn.Sequential(
*list(base_model.children())[:-1]
)
feature_dim = 2048
elif backbone == 'efficientnet_b0':
from torchvision import models
base_model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT if pretrained else None)
self.feature_extractor = nn.Sequential(
*list(base_model.children())[:-1]
)
feature_dim = 1280
else:
raise ValueError(f"Unknown backbone: {backbone}")
# Diagnostic classifier
self.diagnostic_classifier = nn.Sequential(
nn.Linear(feature_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(hidden_dim, num_classes)
)
# Adversarial demographic predictor (if using adversarial debiasing)
if use_adversarial:
self.demographic_predictor = nn.Sequential(
nn.Linear(feature_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(hidden_dim, num_groups)
)
# Gradient reversal via the functional wrapper (autograd.Function
# subclasses are invoked with .apply, not instantiated)
self.gradient_reversal = gradient_reversal_layer
logger.info(
f"Initialized Fair Multi-Task Classifier: "
f"backbone={backbone}, classes={num_classes}, "
f"groups={num_groups}, adversarial={use_adversarial}"
)
def forward(
self,
x: torch.Tensor,
alpha: float = 1.0
) -> Dict[str, torch.Tensor]:
"""
Forward pass through multi-task classifier.
Args:
x: Input images (batch, channels, height, width)
alpha: Gradient reversal strength for adversarial training
Returns:
Dictionary with diagnostic and demographic predictions
"""
# Extract features
features = self.feature_extractor(x)
features = features.view(features.size(0), -1)
# Diagnostic prediction
diagnostic_logits = self.diagnostic_classifier(features)
outputs = {
'diagnostic_logits': diagnostic_logits,
'features': features
}
# Demographic prediction with gradient reversal
if self.use_adversarial:
reversed_features = self.gradient_reversal(features, alpha)
demographic_logits = self.demographic_predictor(reversed_features)
outputs['demographic_logits'] = demographic_logits
return outputs
class GradientReversalLayer(torch.autograd.Function):
    """
    Gradient reversal layer for adversarial training.
    During the forward pass it acts as the identity. During the backward
    pass it multiplies incoming gradients by -alpha, so the shared features
    are optimized to be predictive of diagnosis but NOT predictive of
    demographic group, while the demographic head itself is still trained
    normally to predict the group.
    Invoke it as GradientReversalLayer.apply(x, alpha) or through
    gradient_reversal_layer(); autograd Functions are not meant to be instantiated.
    """
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)
    @staticmethod
    def backward(ctx, grad_output):
        # Reverse and scale the gradient; alpha is a non-tensor input, so it gets None
        output = grad_output.neg() * ctx.alpha
        return output, None
def gradient_reversal_layer(x, alpha=1.0):
    """Functional interface to gradient reversal layer."""
    return GradientReversalLayer.apply(x, alpha)
class FairMultiTaskLoss(nn.Module):
    """
    Loss function for fair multi-task learning.
    Balances diagnostic accuracy with fairness objectives including:
    - Demographic parity (equalizing positive rates across groups)
    - Equalized odds (equalizing TPR and FPR across groups)
    - Calibration fairness (equalizing calibration across groups)
    The fairness terms assume binary diagnosis (column 1 of the logits is
    the positive class) and are estimated from each mini-batch.
    """
    def __init__(
        self,
        fairness_criterion: str = 'demographic_parity',
        fairness_weight: float = 0.1,
        adversarial_weight: float = 1.0
    ):
        """
        Initialize fair multi-task loss.
        Args:
            fairness_criterion: 'demographic_parity', 'equalized_odds', or 'calibration'
            fairness_weight: Weight for fairness regularization
            adversarial_weight: Weight for adversarial demographic prediction
        """
        super().__init__()
        self.fairness_criterion = fairness_criterion
        self.fairness_weight = fairness_weight
        self.adversarial_weight = adversarial_weight
        self.diagnostic_loss_fn = nn.CrossEntropyLoss()
        self.demographic_loss_fn = nn.CrossEntropyLoss()
    def forward(
        self,
        outputs: Dict[str, torch.Tensor],
        diagnostic_labels: torch.Tensor,
        demographic_labels: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute fair multi-task loss.
        Args:
            outputs: Model outputs dictionary
            diagnostic_labels: Ground truth diagnostic labels (class indices)
            demographic_labels: Ground truth demographic group labels (class indices)
        Returns:
            Total loss and dictionary of loss components
        """
        # Diagnostic classification loss
        diagnostic_loss = self.diagnostic_loss_fn(
            outputs['diagnostic_logits'],
            diagnostic_labels
        )
        loss_dict = {
            'diagnostic_loss': diagnostic_loss.item()
        }
        total_loss = diagnostic_loss
        # Adversarial demographic prediction loss. It is added with a positive
        # weight: the demographic head learns to predict the group, while the
        # gradient reversal layer flips this gradient before it reaches the
        # shared features, pushing them to hide group membership.
        if 'demographic_logits' in outputs and self.adversarial_weight > 0:
            demographic_loss = self.demographic_loss_fn(
                outputs['demographic_logits'],
                demographic_labels
            )
            total_loss = total_loss + self.adversarial_weight * demographic_loss
            loss_dict['demographic_loss'] = demographic_loss.item()
        # Fairness regularization
        if self.fairness_weight > 0:
            fairness_loss = self._compute_fairness_loss(
                outputs['diagnostic_logits'],
                diagnostic_labels,
                demographic_labels
            )
            total_loss = total_loss + self.fairness_weight * fairness_loss
            loss_dict['fairness_loss'] = fairness_loss.item()
        loss_dict['total_loss'] = total_loss.item()
        return total_loss, loss_dict
    def _compute_fairness_loss(
        self,
        logits: torch.Tensor,
        diagnostic_labels: torch.Tensor,
        demographic_labels: torch.Tensor
    ) -> torch.Tensor:
        """Compute fairness regularization term based on criterion."""
        if self.fairness_criterion == 'demographic_parity':
            return self._demographic_parity_loss(logits, demographic_labels)
        elif self.fairness_criterion == 'equalized_odds':
            return self._equalized_odds_loss(logits, diagnostic_labels, demographic_labels)
        elif self.fairness_criterion == 'calibration':
            return self._calibration_fairness_loss(logits, diagnostic_labels, demographic_labels)
        else:
            # Fail loudly: silently returning 0 would disable the fairness term on a typo
            raise ValueError(f"Unknown fairness criterion: {self.fairness_criterion}")
    def _demographic_parity_loss(
        self,
        logits: torch.Tensor,
        demographic_labels: torch.Tensor
    ) -> torch.Tensor:
        """
        Enforce demographic parity: P(Y_hat = 1 | A = a) approximately equal
        for every demographic group a.
        Penalizes the variance, across groups, of the mean predicted
        probability of the positive class (a differentiable stand-in for
        the positive prediction rate).
        """
        probs = F.softmax(logits, dim=1)[:, 1]  # Probability of positive class
        unique_groups = demographic_labels.unique()
        if len(unique_groups) < 2:
            return torch.tensor(0.0, device=logits.device)
        group_pos_rates = []
        for group in unique_groups:
            group_mask = demographic_labels == group
            if group_mask.sum() > 0:
                group_pos_rate = probs[group_mask].mean()
                group_pos_rates.append(group_pos_rate)
        if len(group_pos_rates) < 2:
            return torch.tensor(0.0, device=logits.device)
        # Variance of positive rates across groups
        group_pos_rates = torch.stack(group_pos_rates)
        return group_pos_rates.var()
    def _equalized_odds_loss(
        self,
        logits: torch.Tensor,
        diagnostic_labels: torch.Tensor,
        demographic_labels: torch.Tensor
    ) -> torch.Tensor:
        """
        Enforce equalized odds: TPR and FPR equal across groups.
        More stringent than demographic parity: requires both true
        positive and false positive rates to be equalized.
        Hard 0/1 predictions (probs > 0.5) have zero gradient, so a penalty
        built from them would never influence training. Instead, "soft"
        rates use the predicted probabilities in place of predictions.
        """
        probs = F.softmax(logits, dim=1)[:, 1]
        labels = diagnostic_labels.float()
        unique_groups = demographic_labels.unique()
        if len(unique_groups) < 2:
            return torch.tensor(0.0, device=logits.device)
        group_tprs = []
        group_fprs = []
        for group in unique_groups:
            group_mask = demographic_labels == group
            group_probs = probs[group_mask]
            group_labels = labels[group_mask]
            n_pos = group_labels.sum()
            n_neg = (1 - group_labels).sum()
            # TPR and FPR are undefined unless the group has both classes
            if n_pos == 0 or n_neg == 0:
                continue
            # Soft true positive rate: mean positive-class probability among positives
            tpr = (group_probs * group_labels).sum() / n_pos
            # Soft false positive rate: mean positive-class probability among negatives
            fpr = (group_probs * (1 - group_labels)).sum() / n_neg
            group_tprs.append(tpr)
            group_fprs.append(fpr)
        if len(group_tprs) < 2:
            return torch.tensor(0.0, device=logits.device)
        group_tprs = torch.stack(group_tprs)
        group_fprs = torch.stack(group_fprs)
        # Penalize variance in both TPR and FPR across groups
        return group_tprs.var() + group_fprs.var()
    def _calibration_fairness_loss(
        self,
        logits: torch.Tensor,
        diagnostic_labels: torch.Tensor,
        demographic_labels: torch.Tensor
    ) -> torch.Tensor:
        """
        Enforce calibration fairness across groups.
        Computes a binned expected calibration error (ECE) for each group
        and penalizes its variance across groups, so that predicted
        probabilities are similarly well calibrated for every group.
        """
        probs = F.softmax(logits, dim=1)[:, 1]
        unique_groups = demographic_labels.unique()
        if len(unique_groups) < 2:
            return torch.tensor(0.0, device=logits.device)
        n_bins = 10
        bin_boundaries = torch.linspace(0, 1, n_bins + 1, device=probs.device)
        group_calibration_errors = []
        for group in unique_groups:
            group_mask = demographic_labels == group
            if group_mask.sum() < 10:  # Need sufficient samples for stable binning
                continue
            group_probs = probs[group_mask]
            group_labels = diagnostic_labels[group_mask].float()
            group_error = torch.tensor(0.0, device=probs.device)
            for i in range(n_bins):
                upper = bin_boundaries[i + 1]
                if i == n_bins - 1:
                    in_upper = group_probs <= upper  # Last bin includes p = 1.0
                else:
                    in_upper = group_probs < upper
                bin_mask = (group_probs >= bin_boundaries[i]) & in_upper
                if bin_mask.sum() > 0:
                    bin_mean_prob = group_probs[bin_mask].mean()
                    bin_mean_label = group_labels[bin_mask].mean()
                    # Weight each bin by its share of the group's samples (standard ECE)
                    bin_weight = bin_mask.float().mean()
                    group_error = group_error + bin_weight * (bin_mean_prob - bin_mean_label).abs()
            group_calibration_errors.append(group_error)
        if len(group_calibration_errors) < 2:
            return torch.tensor(0.0, device=logits.device)
        # Penalize variance in calibration error across groups
        group_calibration_errors = torch.stack(group_calibration_errors)
        return group_calibration_errors.var()
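This multi-task learning framework combines two complementary mechanisms for training classifiers that aim for high diagnostic accuracy with smaller disparities across demographic groups: adversarial debiasing, which discourages the shared features from encoding group membership, and explicit fairness penalties on the diagnostic outputs. Both terms compete with the diagnostic loss, so fairness_weight, adversarial_weight, and alpha should be tuned on validation data and the outcome confirmed with stratified evaluation. The fairness penalties as written assume binary classification and are estimated per mini-batch, so each batch needs enough samples from every group for the group statistics to be meaningful.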
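A fairness penalty can only shape training if it has a nonzero gradient with respect to the model's outputs. True and false positive rates computed from predictions thresholded at 0.5 are piecewise constant in the predicted probabilities, so their gradient is zero almost everywhere; substituting the probabilities themselves for 0/1 predictions ("soft" rates) is one common differentiable relaxation. The minimal sketch below is a standard-library illustration with hypothetical helper names (`group_rates`, `equalized_odds_penalty`, `finite_diff`) and made-up toy data. It compares the two versions of the variance-based equalized-odds penalty and estimates each one's derivative by finite differences rather than autograd, so it runs without PyTorch.

```python
import statistics
from typing import List, Sequence, Tuple


def group_rates(
    probs: Sequence[float], labels: Sequence[int], groups: Sequence[int], soft: bool
) -> Tuple[List[float], List[float]]:
    """Per-group TPR and FPR (hypothetical helper).

    soft=False thresholds probabilities at 0.5; soft=True uses the
    probabilities directly, giving a differentiable relaxation.
    """
    tprs, fprs = [], []
    for g in sorted(set(groups)):
        idx = [i for i, gi in enumerate(groups) if gi == g]
        preds = [probs[i] if soft else float(probs[i] > 0.5) for i in idx]
        pos = [p for p, i in zip(preds, idx) if labels[i] == 1]
        neg = [p for p, i in zip(preds, idx) if labels[i] == 0]
        if not pos or not neg:
            continue  # TPR/FPR undefined without both classes in the group
        tprs.append(sum(pos) / len(pos))
        fprs.append(sum(neg) / len(neg))
    return tprs, fprs


def equalized_odds_penalty(probs, labels, groups, soft: bool) -> float:
    """Variance of TPR plus variance of FPR across groups."""
    tprs, fprs = group_rates(probs, labels, groups, soft)
    if len(tprs) < 2:
        return 0.0
    # statistics.variance is the sample (n - 1) variance, like torch.Tensor.var()
    return statistics.variance(tprs) + statistics.variance(fprs)


def finite_diff(probs, labels, groups, soft: bool, i: int, eps: float = 1e-4) -> float:
    """Forward-difference estimate of d(penalty)/d(probs[i])."""
    bumped = list(probs)
    bumped[i] += eps
    return (
        equalized_odds_penalty(bumped, labels, groups, soft)
        - equalized_odds_penalty(probs, labels, groups, soft)
    ) / eps


# Toy batch: two demographic groups, binary labels (made-up values).
probs = [0.9, 0.6, 0.2, 0.3, 0.7, 0.4, 0.1, 0.55]
labels = [1, 1, 0, 0, 1, 1, 0, 0]
groups = [0, 0, 0, 0, 1, 1, 1, 1]

hard_penalty = equalized_odds_penalty(probs, labels, groups, soft=False)
soft_penalty = equalized_odds_penalty(probs, labels, groups, soft=True)
# Nudge the probability of a positive case in group 0 (index 1).
hard_grad = finite_diff(probs, labels, groups, soft=False, i=1)
soft_grad = finite_diff(probs, labels, groups, soft=True, i=1)
print(f"hard: penalty={hard_penalty:.4f}, grad={hard_grad:.4f}")
print(f"soft: penalty={soft_penalty:.4f}, grad={soft_grad:.4f}")
```

On this toy batch the thresholded penalty is 0.25, but nudging a probability leaves it unchanged (estimated gradient exactly 0), whereas the soft penalty (about 0.0228) responds with a derivative of about 0.1. Because soft rates are only a training-time relaxation, fairness should still be reported with thresholded TPR and FPR at evaluation time.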
7.6 Conclusion
Computer vision for medical imaging holds immense promise for improving healthcare access and outcomes, but realizing this promise requires explicit attention to fairness throughout the development lifecycle. This chapter has developed comprehensive approaches for medical image preprocessing, augmentation, segmentation, detection, and classification that explicitly account for systematic differences in image acquisition, anatomical variation, and disease presentation across patient demographics and healthcare settings. The implementations provided give practitioners a starting point for building computer vision systems that aim for equitable performance across the populations they serve, with evaluation frameworks designed to surface disparities during development rather than after deployment.
The path forward requires sustained commitment to fairness as a first-class objective alongside traditional performance metrics. Medical imaging AI systems must be validated not just on aggregate metrics but with stratified evaluation across demographic factors, care settings, and equipment types. Training datasets must be actively diversified to represent the full spectrum of patients who will be affected by these systems. Preprocessing and augmentation strategies must account for systematic differences in image characteristics that correlate with patient socioeconomic status because healthcare resources, including imaging equipment, are unequally distributed. Models must be developed with explicit fairness constraints that discourage them from learning spurious correlations between imaging artifacts and patient outcomes.
As medical imaging AI becomes increasingly integrated into clinical workflows, the equity considerations developed in this chapter become ever more critical. Systems that perform poorly for certain patient populations perpetuate and potentially amplify existing healthcare disparities. By centering fairness from the outset, we can build computer vision technologies that truly democratize access to high-quality diagnostic imaging interpretation and contribute to, rather than undermine, health equity.
Bibliography
Adamson, A. S., & Smith, A. (2018). Machine learning and health care disparities in dermatology. JAMA Dermatology, 154(11), 1247-1248. https://doi.org/10.1001/jamadermatol.2018.2348
Badgeley, M. A., Zech, J. R., Oakden-Rayner, L., Glicksberg, B. S., Liu, M., Gale, W., … & Oermann, E. K. (2019). Deep learning predicts hip fracture using confounding patient and healthcare variables. npj Digital Medicine, 2(1), 31. https://doi.org/10.1038/s41746-019-0105-1
Beam, A. L., & Kohane, I. S. (2018). Big data and machine learning in health care. JAMA, 319(13), 1317-1318. https://doi.org/10.1001/jama.2017.18391
Chen, I. Y., Pierson, E., Rose, S., Joshi, S., Ferryman, K., & Ghassemi, M. (2021). Ethical machine learning in healthcare. Annual Review of Biomedical Data Science, 4, 123-144. https://doi.org/10.1146/annurev-biodatasci-092820-114757
Chen, R. J., Lu, M. Y., Chen, T. Y., Williamson, D. F., & Mahmood, F. (2021). Synthetic data in machine learning for medicine and healthcare. Nature Biomedical Engineering, 5(6), 493-497. https://doi.org/10.1038/s41551-021-00751-8
Daneshjou, R., Vodrahalli, K., Novoa, R. A., Jenkins, M., Liang, W., Rotemberg, V., … & Zou, J. (2022). Disparities in dermatology AI performance on a diverse, curated clinical image dataset. Science Advances, 8(32), eabq6147. https://doi.org/10.1126/sciadv.abq6147
Diao, J. A., Wang, J. K., Chui, W. F., Mountain, V., Gullapally, S. C., Srinivasan, R., … & Fuchs, T. J. (2021). Human-interpretable image features derived from densely mapped cancer pathology slides predict diverse molecular phenotypes. Nature Communications, 12(1), 1613. https://doi.org/10.1038/s41467-021-21896-9
Esteva, A., Kuprel, B., Novoa, R. A., Ko, J., Swetter, S. M., Blau, H. M., & Thrun, S. (2017). Dermatologist-level classification of skin cancer with deep neural networks. Nature, 542(7639), 115-118. https://doi.org/10.1038/nature21056
Futoma, J., Simons, M., Panch, T., Doshi-Velez, F., & Celi, L. A. (2020). The myth of generalisability in clinical research and machine learning in health care. The Lancet Digital Health, 2(9), e489-e492. https://doi.org/10.1016/S2589-7500(20)30186-2
Gichoya, J. W., Banerjee, I., Bhimireddy, A. R., Burns, J. L., Celi, L. A., Chen, L. C., … & Purkayastha, S. (2022). AI recognition of patient race in medical imaging: a modelling study. The Lancet Digital Health, 4(6), e406-e414. https://doi.org/10.1016/S2589-7500(22)00063-2
Gulshan, V., Peng, L., Coram, M., Stumpe, M. C., Wu, D., Narayanaswamy, A., … & Webster, D. R. (2016). Development and validation of a deep learning algorithm for detection of diabetic retinopathy in retinal fundus photographs. JAMA, 316(22), 2402-2410. https://doi.org/10.1001/jama.2016.17216
He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 770-778. https://doi.org/10.1109/CVPR.2016.90
Huang, G., Liu, Z., Van Der Maaten, L., & Weinberger, K. Q. (2017). Densely connected convolutional networks. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 4700-4708. https://doi.org/10.1109/CVPR.2017.243
Irvin, J., Rajpurkar, P., Ko, M., Yu, Y., Ciurea-Ilcus, S., Chute, C., … & Ng, A. Y. (2019). CheXpert: A large chest radiograph dataset with uncertainty labels and expert comparison. Proceedings of the AAAI Conference on Artificial Intelligence, 33(01), 590-597. https://doi.org/10.1609/aaai.v33i01.3301590
Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nature Methods, 18(2), 203-211. https://doi.org/10.1038/s41592-020-01008-z
Johnson, A. E., Pollard, T. J., Shen, L., Lehman, L. W. H., Feng, M., Ghassemi, M., … & Mark, R. G. (2016). MIMIC-III, a freely accessible critical care database. Scientific Data, 3(1), 160035. https://doi.org/10.1038/sdata.2016.35
Kline, A., Wang, H., Li, Y., Dennis, S., Hutch, M., Xu, Z., … & Somai, M. (2022). Multimodal machine learning in precision health: A scoping review. npj Digital Medicine, 5(1), 171. https://doi.org/10.1038/s41746-022-00712-8
Liu, X., Faes, L., Kale, A. U., Wagner, S. K., Fu, D. J., Bruynseels, A., … & Denniston, A. K. (2019). A comparison of deep learning performance against health-care professionals in detecting diseases from medical imaging: a systematic review and meta-analysis. The Lancet Digital Health, 1(6), e271-e297. https://doi.org/10.1016/S2589-7500(19)30123-2
Long, J., Shelhamer, E., & Darrell, T. (2015). Fully convolutional networks for semantic segmentation. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 3431-3440. https://doi.org/10.1109/CVPR.2015.7298965
Madry, A., Makelov, A., Schmidt, L., Tsipras, D., & Vladu, A. (2018). Towards deep learning models resistant to adversarial attacks. Proceedings of the 6th International Conference on Learning Representations. https://arxiv.org/abs/1706.06083
McKinney, S. M., Sieniek, M., Godbole, V., Godwin, J., Antropova, N., Ashrafian, H., … & Shetty, S. (2020). International evaluation of an AI system for breast cancer screening. Nature, 577(7788), 89-94. https://doi.org/10.1038/s41586-019-1799-6
Mehrabi, N., Morstatter, F., Saxena, N., Lerman, K., & Galstyan, A. (2021). A survey on bias and fairness in machine learning. ACM Computing Surveys, 54(6), 1-35. https://doi.org/10.1145/3457607
Milletari, F., Navab, N., & Ahmadi, S. A. (2016). V-net: Fully convolutional neural networks for volumetric medical image segmentation. Proceedings of the 2016 Fourth International Conference on 3D Vision, 565-571. https://doi.org/10.1109/3DV.2016.79
Oakden-Rayner, L., Dunnmon, J., Carneiro, G., & Ré, C. (2020). Hidden stratification causes clinically meaningful failures in machine learning for medical imaging. Proceedings of the ACM Conference on Health, Inference, and Learning, 151-159. https://doi.org/10.1145/3368555.3384468
Obermeyer, Z., Powers, B., Vogeli, C., & Mullainathan, S. (2019). Dissecting racial bias in an algorithm used to manage the health of populations. Science, 366(6464), 447-453. https://doi.org/10.1126/science.aax2342
Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., … & Chintala, S. (2019). PyTorch: An imperative style, high-performance deep learning library. Advances in Neural Information Processing Systems, 32, 8026-8037. https://proceedings.neurips.cc/paper/2019/file/bdbca288fee7f92f2bfa9f7012727740-Paper.pdf
Rajkomar, A., Hardt, M., Howell, M. D., Corrado, G., & Chin, M. H. (2018). Ensuring fairness in machine learning to advance health equity. Annals of Internal Medicine, 169(12), 866-872. https://doi.org/10.7326/M18-1990
Rajpurkar, P., Irvin, J., Zhu, K., Yang, B., Mehta, H., Duan, T., … & Ng, A. Y. (2017). CheXNet: Radiologist-level pneumonia detection on chest X-rays with deep learning. arXiv preprint arXiv:1711.05225. https://arxiv.org/abs/1711.05225
Ronneberger, O., Fischer, P., & Brox, T. (2015). U-Net: Convolutional networks for biomedical image segmentation. Proceedings of the International Conference on Medical Image Computing and Computer-Assisted Intervention, 234-241. https://doi.org/10.1007/978-3-319-24574-4_28
Russakovsky, O., Deng, J., Su, H., Krause, J., Satheesh, S., Ma, S., … & Fei-Fei, L. (2015). ImageNet large scale visual recognition challenge. International Journal of Computer Vision, 115(3), 211-252. https://doi.org/10.1007/s11263-015-0816-y
Selvaraju, R. R., Cogswell, M., Das, A., Vedantam, R., Parikh, D., & Batra, D. (2017). Grad-CAM: Visual explanations from deep networks via gradient-based localization. Proceedings of the IEEE International Conference on Computer Vision, 618-626. https://doi.org/10.1109/ICCV.2017.74
Seyyed-Kalantari, L., Zhang, H., McDermott, M. B., Chen, I. Y., & Ghassemi, M. (2021). Underdiagnosis bias of artificial intelligence algorithms applied to chest radiographs in under-served patient populations. Nature Medicine, 27(12), 2176-2182. https://doi.org/10.1038/s41591-021-01595-0
Shorten, C., & Khoshgoftaar, T. M. (2019). A survey on image data augmentation for deep learning. Journal of Big Data, 6(1), 1-48. https://doi.org/10.1186/s40537-019-0197-0
Tan, M., & Le, Q. (2019). EfficientNet: Rethinking model scaling for convolutional neural networks. Proceedings of the 36th International Conference on Machine Learning, 97, 6105-6114. http://proceedings.mlr.press/v97/tan19a.html
Wang, X., Peng, Y., Lu, L., Lu, Z., Bagheri, M., & Summers, R. M. (2017). ChestX-ray8: Hospital-scale chest X-ray database and benchmarks on weakly-supervised classification and localization of common thorax diseases. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2097-2106. https://doi.org/10.1109/CVPR.2017.369
Winkler, J. K., Fink, C., Toberer, F., Enk, A., Deinlein, T., Hofmann-Wellenhof, R., … & Haenssle, H. A. (2019). Association between surgical skin markings in dermoscopic images and diagnostic performance of a deep learning convolutional neural network for melanoma recognition. JAMA Dermatology, 155(10), 1135-1141. https://doi.org/10.1001/jamadermatol.2019.1735
Zech, J. R., Badgeley, M. A., Liu, M., Costa, A. B., Titano, J. J., & Oermann, E. K. (2018). Variable generalization performance of a deep learning model to detect pneumonia in chest radiographs: A cross-sectional study. PLOS Medicine, 15(11), e1002683. https://doi.org/10.1371/journal.pmed.1002683
Zhang, H., Dullerud, N., Roth, K., Oakden-Rayner, L., Pfohl, S., & Ghassemi, M. (2022). Improving the fairness of chest X-ray classifiers. Proceedings of the Conference on Health, Inference, and Learning, 204-233. https://proceedings.mlr.press/v174/zhang22c.html
Zhou, Z., Rahman Siddiquee, M. M., Tajbakhsh, N., & Liang, J. (2018). UNet++: A nested U-Net architecture for medical image segmentation. Deep Learning in Medical Image Analysis and Multimodal Learning for Clinical Decision Support, 3-11. https://doi.org/10.1007/978-3-030-00889-5_1