Commit 4eec270e authored by Clement Vachet: Initial commit
#!/usr/bin/env python3
import os
import sys
import argparse
import time
# Select CUDA device (set before importing torch so only this GPU is visible)
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import torchio as tio
import imageio
import math
from network import *
import utils
import dataset
# --------------------
# Model - FC layers
dict_fc_features = {
# Phase1- concatenation on 3rd layer
'Phase1': [2048,512,256,64],
'Phase2': [128,64,32],
}
# MC Dropout
mc_dropout = True
mc_passes = 50
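# With MC dropout enabled, dropout layers stay active at inference time and
# each batch is forwarded mc_passes times; the spread across passes (std, CV)
# is used below as a per-pixel uncertainty estimate.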
# --------------------
def arg_parser():
parser = argparse.ArgumentParser(description='Inference')
required = parser.add_argument_group('Required')
required.add_argument('--input', type=str, required=True,
help='Combined TIFF file (multi-layer)')
required.add_argument('--network', type=str, required=True,
help='pytorch neural network')
required.add_argument('--output', type=str, required=True,
help='Image prediction (2D TIFF file)')
options = parser.add_argument_group('Options')
options.add_argument('--tile_size', type=int, default=15,
help='tile size')
options.add_argument('--adjacent_tiles_dim', type=int, default=1,
help='adjacent tiles dim (e.g. 3, 5)')
options.add_argument('--bs', type=int, default=5000,
help='Batch size (default 5000)')
options.add_argument('--output_median', type=str,
help='Image output - median for MCDropout (2D TIFF file)')
options.add_argument('--output_cv', type=str,
help='Image output - Coefficient of Variation for MCDropout (2D TIFF file)')
return parser
def apply_dropout(m):
if m.__class__.__name__.startswith('Dropout'):
print('\t\t Enabling MC dropout!')
m.train()
#MAIN
def main(args=None):
args = arg_parser().parse_args(args)
InputFile = args.input
ModelName = args.network
OutputFile = args.output
OutputFile_median = args.output_median
OutputFile_CV = args.output_cv
TileSize = args.tile_size
AdjacentTilesDim = args.adjacent_tiles_dim
bs = args.bs
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('\n--------------------')
since1 = time.time()
# TorchIO subject
print('\nGenerating TIO subject...')
Subject = tio.Subject(
Combined = tio.ScalarImage(InputFile),
)
# Initialize variables
InputFile_Shape = Subject['Combined'].shape
NbTiles_H = InputFile_Shape[1] // TileSize
NbTiles_W = InputFile_Shape[2] // TileSize
NbImageLayers = InputFile_Shape[3]
NbCorrLayers = NbImageLayers - 4  # last 4 layers: TargetDisp, GroundTruth, Confidence, DispLMA
InputDepth = NbCorrLayers
print('InputFile_Shape: ', InputFile_Shape)
print('NbTiles_H: ', NbTiles_H)
print('NbTiles_W: ', NbTiles_W)
print('NbImageLayers: ', NbImageLayers)
print('InputDepth: ', InputDepth)
# GridSampler
print('\nGenerating Grid Sampler...')
patch_size, patch_overlap, padding_mode = dataset.initialize_gridsampler_variables(NbImageLayers, TileSize, AdjacentTilesDim, padding_mode=None)
print('patch_size: ',patch_size)
print('patch_overlap: ',patch_overlap)
print('padding_mode: ',padding_mode)
grid_sampler = tio.data.GridSampler(
subject = Subject,
patch_size = patch_size,
patch_overlap = patch_overlap,
padding_mode = padding_mode,
)
len_grid_sampler = len(grid_sampler)
print('length grid_sampler: ', len_grid_sampler)
patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=bs)
aggregator = tio.data.GridAggregator(grid_sampler, overlap_mode = 'average')
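# The GridSampler tiles the full image into patches and records their
# locations; the GridAggregator reassembles the per-patch predictions into a
# full-size tensor, averaging values wherever patches overlap.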
print('\nLoading DNN model...')
model = MyParallelNetwork(InputDepth, TileSize, AdjacentTilesDim, dict_fc_features)
model.load_state_dict(torch.load(ModelName))
print(model)
model.to(device)
model.eval()
if mc_dropout:
print('\t MC Dropout')
model.apply(apply_dropout)
print('\nPatch-based inference...')
since2 = time.time()
#model = nn.Identity().eval()
with torch.no_grad():
for patch_idx, patches_batch in enumerate(patch_loader):
print('\t patch_idx: ', patch_idx)
#print('\t\t Preparing data...')
inputs = patches_batch['Combined'][tio.DATA]
print('\t\t inputs shape: ', inputs.shape)
input1_tiles, input2_tiles_real, GroundTruth_real = dataset.prepare_data_withfiltering(inputs, NbImageLayers, NbCorrLayers, TileSize, AdjacentTilesDim)
#print('\t\t Preparing data - done -')
input1_tiles = input1_tiles.to(device)
input2_tiles_real = input2_tiles_real.to(device)
#GroundTruth_real = GroundTruth_real.to(device)
# Reducing last dimension to compute loss
#GroundTruth_real = torch.squeeze(GroundTruth_real, dim=2)
print('\t\t input1_tiles shape: ', input1_tiles.shape)
print('\t\t input2_tiles_real shape:', input2_tiles_real.shape)
if mc_dropout:
# Perform multiple inference (mc_passes)
outputs_all = torch.empty(size=(mc_passes, input1_tiles.shape[0])).to(device)
for i in range(0, mc_passes):
outputs = model(input1_tiles, input2_tiles_real)
outputs_all[i] = torch.squeeze(outputs)
# Compute mean, median, std, and CV (coefficient of variation) across MC passes
outputs_mean = torch.mean(outputs_all,0)
outputs_median = torch.median(outputs_all,0)[0]
outputs_std = torch.std(outputs_all,0)
outputs_cv = torch.div(outputs_std, torch.abs(outputs_mean))
# outputs_se = torch.div(outputs_std, math.sqrt(mc_passes))
outputs_combined = torch.stack((outputs_mean, outputs_median, outputs_cv), dim=1)
print('\t\t outputs shape: ',outputs.shape)
print('\t\t outputs device', outputs.device)
print('\t\t outputs_all shape: ', outputs_all.shape)
print('\t\t outputs_all device', outputs_all.device)
print('\t\t outputs_mean shape: ', outputs_mean.shape)
print('\t\t outputs_median shape: ', outputs_median.shape)
print('\t\t outputs_median type: ', outputs_median.type())
print('\t\t outputs_combined shape: ', outputs_combined.shape)
print('\t\t outputs_mean[:20]',outputs_mean[:20])
print('\t\t outputs_median[:20]',outputs_median[:20])
print('\t\t outputs_std[:20]',outputs_std[:20])
print('\t\t outputs_cv[:20]',outputs_cv[:20])
else:
outputs_combined = model(input1_tiles, input2_tiles_real)
print('\t\t outputs_combined device', outputs_combined.device)
print('\t\t outputs_combined shape: ', outputs_combined.shape)
# Reshape outputs to match location dimensions
outputs_combined_reshape = torch.reshape(outputs_combined,[outputs_combined.shape[0],outputs_combined.shape[1],1,1,1])
print('\t\t outputs_combined_reshape shape: ', outputs_combined_reshape.shape)
input_location = patches_batch[tio.LOCATION]
print('\t\t input_location shape: ', input_location.shape)
print('\t\t input_location type: ', input_location.dtype)
print('\t\t input_location[:20]: ', input_location[:20])
# Reshape input_location to prediction_location, to fit output image size (78,62,1)
pred_location = dataset.prediction_patch_location(input_location, TileSize, AdjacentTilesDim)
print('\t\t pred_location shape: ', pred_location.shape)
print('\t\t pred_location[:20]: ', pred_location[:20])
# Add batch with location to TorchIO aggregator
aggregator.add_batch(outputs_combined_reshape, pred_location)
# output_tensor shape [3, 1170, 930, 124]
output_tensor_combined = aggregator.get_output_tensor()
print('output_tensor_combined type: ', output_tensor_combined.dtype)
print('output_tensor_combined shape: ', output_tensor_combined.shape)
# Extract real information of interest [3, 78,62]
output_tensor_combined_real = output_tensor_combined[:,:NbTiles_H,:NbTiles_W,0]
print('output_tensor_combined_real shape: ', output_tensor_combined_real.shape)
output_combined_np = output_tensor_combined_real.numpy().squeeze()
print('output_combined_np type', output_combined_np.dtype)
print('output_combined_np shape', output_combined_np.shape)
if mc_dropout:
output_mean_np = output_combined_np[0,...]
output_median_np = output_combined_np[1,...]
output_cv_np = output_combined_np[2,...]
imageio_output_mean = np.moveaxis(output_mean_np, 0,1)
imageio_output_median = np.moveaxis(output_median_np, 0,1)
imageio_output_cv = np.moveaxis(output_cv_np, 0,1)
print('imageio_output_mean shape', imageio_output_mean.shape)
print('imageio_output_median shape', imageio_output_median.shape)
print('imageio_output_cv shape', imageio_output_cv.shape)
else:
output_np = output_combined_np
imageio_output = np.moveaxis(output_np, 0,1)
print('imageio_output shape', imageio_output.shape)
time_elapsed2 = time.time() - since2
if mc_dropout:
print('Writing output mean image via imageio...')
imageio.imwrite(OutputFile, imageio_output_mean)
print('Writing output median image via imageio...')
imageio.imwrite(OutputFile_median, imageio_output_median)
print('Writing output CV image via imageio...')
imageio.imwrite(OutputFile_CV, imageio_output_cv)
else:
print('Writing output image via imageio...')
imageio.imwrite(OutputFile, imageio_output)
time_elapsed3 = time.time() - since2
time_elapsed1 = time.time() - since1
print('--- Inference in {:.2f}s---'.format(time_elapsed2))
print('--- Inference and saving in {:.2f}s---'.format(time_elapsed3))
print('--- Total time in {:.2f}s---'.format(time_elapsed1))
if __name__ == '__main__':
sys.exit(main(sys.argv[1:]))
# -*- coding: utf-8 -*-
"""
AI analysis via parallel neural networks
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sys
import os
import copy
import yaml
import argparse
# Select CUDA device (set before importing torch so only this GPU is visible)
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import torchio as tio
# -----------
from model import *
import utils
import dataset
import sampler
# Manual seed
torch.manual_seed(42)
######################################################################
def arg_parser():
parser = argparse.ArgumentParser(description='AI analysis - DNN training')
required = parser.add_argument_group('Required')
required.add_argument('--config', type=str, required=True,
help='YAML configuration / parameter file')
options = parser.add_argument_group('Options')
options.add_argument('--verbose', action="store_true",
help='verbose mode')
return parser
def main(args=None):
args = arg_parser().parse_args(args)
config_filename = args.config
# plt.ion() # interactive mode
######################################################################
# Loading parameter file
print('\n--- Loading configuration file --- ')
with open(config_filename,'r') as yaml_file:
config_file = yaml.safe_load(yaml_file)
if args.verbose:
print('config_file', config_file)
# Defining parameters
CSVFile_train = config_file['CSVFile_train']
CSVFile_val = config_file['CSVFile_val']
model_filename = config_file['ModelName']
loss_filename = config_file['LossName']
nb_image_layers = config_file['NbImageLayers']
nb_corr_layers = config_file['NbCorrLayers']
tile_size = config_file['TileSize']
adjacent_tiles_dim = config_file['AdjacentTilesDim']
num_workers = config_file['num_workers']
samples_per_volume = config_file['samples_per_volume']
queue_length = config_file['queue_length']
data_filtering = config_file['DataFiltering']
confidence_threshold = config_file['ConfidenceThreshold']
dict_fc_features = config_file['dict_fc_features']
bs = config_file['bs']
lr = config_file['lr']
nb_epochs = config_file['nb_epochs']
# ------------------
print('\n--- Generating torchIO dataset ---')
File_list_train, TIOSubjects_list_train = dataset.GenerateTIOSubjectsList(CSVFile_train)
File_list_test, TIOSubjects_list_test = dataset.GenerateTIOSubjectsList(CSVFile_val)
# torchIO transforms
TIOtransforms = [
tio.RandomFlip(axes=('lr',)),
]
TIOtransform = tio.Compose(TIOtransforms)
# TIO dataset
TIOSubjects_dataset_train = tio.SubjectsDataset(TIOSubjects_list_train, transform=TIOtransform)
TIOSubjects_dataset_test = tio.SubjectsDataset(TIOSubjects_list_test, transform=None)
print('Training set: ', len(TIOSubjects_dataset_train), 'subjects')
print('Validation set: ', len(TIOSubjects_dataset_test), 'subjects')
# ------------------
# ------------------
# Subject visualization
if args.verbose:
print('\n--- Quality control: TIOSubject Info ---')
MyTIOSubject = TIOSubjects_dataset_train[0]
print('MySubject: ', MyTIOSubject)
print('MySubject.shape: ', MyTIOSubject.shape)
print('MySubject.spacing: ', MyTIOSubject.spacing)
print('MySubject.spatial_shape: ', MyTIOSubject.spatial_shape)
print('MySubject.spatial_shape.type: ', type(MyTIOSubject.spatial_shape))
print('MySubject history: ', MyTIOSubject.get_composed_history())
# ------------------
# - - - - - - - - - - - - - -
# Training with GridSampler
# patch_size, patch_overlap, padding_mode = dataset.initialize_gridsampler_variables(nb_image_layers, tile_size, adjacent_tiles_dim, padding_mode=None)
# print('patch_size: ',patch_size)
# print('patch_overlap: ',patch_overlap)
# print('padding_mode: ',padding_mode)
# example_grid_sampler = tio.data.GridSampler(
# subject = MyTIOSubject,
# patch_size = patch_size,
# patch_overlap = patch_overlap,
# padding_mode = padding_mode,
# )
# samples_per_volume = len(example_grid_sampler)
# queue_length = samples_per_volume * num_workers
# print('samples_per_volume', samples_per_volume)
# print('queue_length', queue_length)
# sampler_train = tio.data.GridSampler(
# patch_size = patch_size,
# patch_overlap = patch_overlap,
# padding_mode = padding_mode,
# )
# sampler_test = tio.data.GridSampler(
# patch_size = patch_size,
# patch_overlap = patch_overlap,
# padding_mode = padding_mode,
# )
# - - - - - - - - - - - - - -
# - - - - - - - - - - - - - -
# Training with UniformSampler
print('\n--- Initializing patch sampling variables ---')
patch_size, patch_overlap, padding_mode = dataset.initialize_uniformsampler_variables(nb_image_layers, tile_size, adjacent_tiles_dim, padding_mode=None)
if args.verbose:
print('patch_size: ',patch_size)
print('patch_overlap: ',patch_overlap)
print('padding_mode: ',padding_mode)
print('samples_per_volume', samples_per_volume)
print('queue_length', queue_length)
sampler_train = sampler.MyUniformSampler(
patch_size = patch_size,
tile_size = tile_size,
)
sampler_test = sampler.MyUniformSampler(
patch_size = patch_size,
tile_size = tile_size,
)
patches_queue_train = tio.Queue(
subjects_dataset = TIOSubjects_dataset_train,
max_length = queue_length,
samples_per_volume = samples_per_volume,
sampler = sampler_train,
num_workers = num_workers,
shuffle_subjects = True,
shuffle_patches = True,
)
patches_queue_test = tio.Queue(
subjects_dataset = TIOSubjects_dataset_test,
max_length = queue_length,
samples_per_volume = samples_per_volume,
sampler = sampler_test,
num_workers = num_workers,
shuffle_subjects = True,
shuffle_patches = True,
)
patches_loader_train = DataLoader(
patches_queue_train,
batch_size = bs,
shuffle = True,
num_workers = 0, # must be 0: the tio.Queue already uses its own worker processes
)
patches_loader_test = DataLoader(
patches_queue_test,
batch_size = bs,
shuffle = False,
num_workers = 0, # must be 0: the tio.Queue already uses its own worker processes
)
# Dictionary for patch data loaders
patches_loader_dict = {}
patches_loader_dict['train'] = patches_loader_train
patches_loader_dict['val'] = patches_loader_test
# ----------------------
# Visualize input data
writer = SummaryWriter('tensorboard/MyNetwork')
# Get a batch of validation data for quality control
print('\n--- Quality control: patch inputs ---')
patches_batch = next(iter(patches_loader_dict['val']))
inputs = patches_batch['Combined'][tio.DATA]
locations = patches_batch[tio.LOCATION]
# Variable initialization needed for TensorBoard
input_Corr_tiles, input_TargetDisp_tiles_real, GroundTruth_real = dataset.prepare_data_withfiltering(inputs, nb_image_layers, nb_corr_layers, tile_size, adjacent_tiles_dim, data_filtering, confidence_threshold)
if args.verbose:
print('\ninput_Corr_tiles.shape: ', input_Corr_tiles.shape)
print('input_TargetDisp_tiles_real.shape: ', input_TargetDisp_tiles_real.shape)
print('GroundTruth_real.shape: ', GroundTruth_real.shape)
######################################################################
# Neural network - training
# ----------------------
#
# Create a neural network model and start training / testing.
#
# ----------------------
# Create model
print('\n--- Creating neural network architecture ---')
model_ft = Model(writer, nb_image_layers, nb_corr_layers, tile_size, adjacent_tiles_dim, model_filename, dict_fc_features, loss_filename, data_filtering, confidence_threshold)
# Tensorboard - add graph
writer.add_graph(model_ft.model, [input_Corr_tiles.to(model_ft.device), input_TargetDisp_tiles_real.to(model_ft.device)])
writer.close()
# ----------------------
# Train and evaluate
print('\n--- DNN training ---')
model_ft.train_model(dataloaders=patches_loader_dict, lr=lr, nb_epochs=nb_epochs)
# ----------------------
# Evaluate on validation data
print('\n--- DNN testing ---')
model_ft.test_model(dataloaders=patches_loader_dict)
# plt.ioff()
# plt.show()
if __name__ == "__main__":
main()
# - - - - - - - - -
# Inference on real validation data - CSV file
# Environment 1
python3 AI_Inference_CSV.py --config ./Config_Files/AI_Inference_Config_Tiles1x1.yaml --verbose

# DNN training - one log per tile configuration
python3 AI_Training.py --config ./Config_Files/AI_Training_Config_Tiles1x1.yaml > AI_Training_Tiles1x1_MCDropout.log
python3 AI_Training.py --config ./Config_Files/AI_Training_Config_Tiles3x3.yaml > AI_Training_Tiles3x3_MCDropout.log
python3 AI_Training.py --config ./Config_Files/AI_Training_Config_Tiles5x5.yaml > AI_Training_Tiles5x5_MCDropout.log
---
# Input files
CSVFile: '../Example_CSV/Data_Example_val.csv'
ModelName: './pytorch_IRTPNet_Tiles1x1_WithFiltering0.0_MCDropout.h5'
# Output files
OutputFolder: 'CNN_Output_WithFiltering_MCDropout/'
OutputSuffix: 'Pred_IRTPNet_Tiles1x1_WithFiltering0.0_MCDropout'
# Data parameters
NbImageLayers: 124
NbCorrLayers: 120
TileSize: 15
AdjacentTilesDim: 1 # 1 for 1x1, 3 for 3x3, 5 for 5x5 adjacent tiles
# Neural network parameters
# Model - FC layers
dict_fc_features:
Phase1: [2048,512,256,64]
Phase2: [128,64,32]
# Batch size
bs: 5000
# MC Dropout
MCDropout: True
MCPasses: 40
---
# Input files
CSVFile: '../Example_CSV/Data_Example_val.csv'
ModelName: './pytorch_IRTPNet_Tiles3x3_WithFiltering0.0_MCDropout.h5'
# Output files
OutputFolder: 'CNN_Output_WithFiltering_MCDropout/'
OutputSuffix: 'Pred_IRTPNet_Tiles3x3_WithFiltering0.0_MCDropout'
# Data parameters
NbImageLayers: 124
NbCorrLayers: 120
TileSize: 15
AdjacentTilesDim: 3 # 1 for 1x1, 3 for 3x3, 5 for 5x5 adjacent tiles
# Neural network parameters
# Model - FC layers
dict_fc_features:
Phase1: [2048,512,256,64]
Phase2: [128,64,32]
# Batch size
bs: 5000
# MC Dropout
MCDropout: True
MCPasses: 40
---
# Input files
CSVFile: '../Example_CSV/Data_Example_val.csv'
ModelName: './pytorch_IRTPNet_Tiles5x5_WithFiltering0.0_MCDropout.h5'
# Output files
OutputFolder: 'CNN_Output_WithFiltering_MCDropout/'
OutputSuffix: 'Pred_IRTPNet_Tiles5x5_WithFiltering0.0_MCDropout'
# Data parameters
NbImageLayers: 124
NbCorrLayers: 120
TileSize: 15
AdjacentTilesDim: 5 # 1 for 1x1, 3 for 3x3, 5 for 5x5 adjacent tiles
# Neural network parameters
# Model - FC layers
dict_fc_features:
Phase1: [2048,512,256,64]
Phase2: [128,64,32]
# Batch size
bs: 5000
# MC Dropout
MCDropout: True
MCPasses: 40
---
# Input files
CSVFile_train: '../Example_CSV/Data_Example_train.csv'
CSVFile_val: '../Example_CSV/Data_Example_val.csv'
# Output files
ModelName: './pytorch_IRTPNet_Tiles1x1_WithFiltering0.0_MCDropout.h5'
LossName: './Loss_IRTPNet_Tiles1x1_WithFiltering0.0_MCDropout.png'
# Data parameters
NbImageLayers: 124
NbCorrLayers: 120
TileSize: 15
AdjacentTilesDim: 1 # 1 for 1x1, 3 for 3x3, 5 for 5x5 adjacent tiles
# Data sampling parameters
num_workers: 6
samples_per_volume: 1000
queue_length: 6000 # samples_per_volume * num_workers
# Data filtering parameters
DataFiltering: True
ConfidenceThreshold: 0.0
# Neural network parameters
# Model - FC layers
dict_fc_features:
Phase1: [2048,512,256,64]
Phase2: [128,64,32]
# Batch size
bs: 500
# Learning rate
lr: 1.0e-3
# Number Epochs
nb_epochs: 15
---
# Input files
CSVFile_train: '../Example_CSV/Data_Example_train.csv'
CSVFile_val: '../Example_CSV/Data_Example_val.csv'
# Output files
ModelName: './pytorch_IRTPNet_Tiles3x3_WithFiltering0.0_MCDropout.h5'
LossName: './Loss_IRTPNet_Tiles3x3_WithFiltering0.0_MCDropout.png'
# Data parameters
NbImageLayers: 124
NbCorrLayers: 120
TileSize: 15
AdjacentTilesDim: 3 # 1 for 1x1, 3 for 3x3, 5 for 5x5 adjacent tiles
# Data sampling parameters
num_workers: 6
samples_per_volume: 1000
queue_length: 6000 # samples_per_volume * num_workers
# Data filtering parameters
DataFiltering: True
ConfidenceThreshold: 0.0
# Neural network parameters
# Model - FC layers
dict_fc_features:
Phase1: [2048,512,256,64]
Phase2: [128,64,32]
# Batch size
bs: 500
# Learning rate
lr: 1.0e-3
# Number Epochs
nb_epochs: 15
---
# Input files
CSVFile_train: '../Example_CSV/Data_Example_train.csv'
CSVFile_val: '../Example_CSV/Data_Example_val.csv'
# Output files
ModelName: './pytorch_IRTPNet_Tiles5x5_WithFiltering0.0_MCDropout.h5'
LossName: './Loss_IRTPNet_Tiles5x5_WithFiltering0.0_MCDropout.png'
# Data parameters
NbImageLayers: 124
NbCorrLayers: 120
TileSize: 15
AdjacentTilesDim: 5 # 1 for 1x1, 3 for 3x3, 5 for 5x5 adjacent tiles
# Data sampling parameters
num_workers: 6
samples_per_volume: 1000
queue_length: 6000 # samples_per_volume * num_workers
# Data filtering parameters
DataFiltering: True
ConfidenceThreshold: 0.0
# Neural network parameters
# Model - FC layers
dict_fc_features:
Phase1: [2048,512,256,64]
Phase2: [128,64,32]
# Batch size
bs: 500
# Learning rate
lr: 1.0e-3
# Number Epochs
nb_epochs: 15
name: Elphel_chimera_conda3_Test
channels:
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- ca-certificates=2022.4.26=h06a4308_0
- certifi=2022.6.15=py39h06a4308_0
- ld_impl_linux-64=2.38=h1181459_1
- libffi=3.3=he6710b0_2
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libstdcxx-ng=11.2.0=h1234567_1
- ncurses=6.3=h5eee18b_3
- openssl=1.1.1q=h7f8727e_0
- pip=22.1.2=py39h06a4308_0
- python=3.9.12=h12debd9_1
- readline=8.1.2=h7f8727e_1
- setuptools=61.2.0=py39h06a4308_0
- sqlite=3.38.5=hc218d9a_0
- tk=8.6.12=h1ccaba5_0
- tzdata=2022a=hda174b7_0
- wheel=0.37.1=pyhd3eb1b0_0
- xz=5.2.5=h7f8727e_1
- zlib=1.2.12=h7f8727e_2
- pip:
- --extra-index-url https://download.pytorch.org/whl/cu113
- absl-py==1.2.0
- cachetools==5.2.0
- charset-normalizer==2.1.0
- click==8.1.3
- cycler==0.11.0
- deprecated==1.2.13
- fonttools==4.34.4
- google-auth==2.9.1
- google-auth-oauthlib==0.4.6
- grpcio==1.47.0
- humanize==4.2.3
- idna==3.3
- imageio==2.19.5
- importlib-metadata==4.12.0
- joblib==1.1.0
- kiwisolver==1.4.4
- markdown==3.4.1
- matplotlib==3.5.2
- nibabel==4.0.1
- numpy==1.23.1
- oauthlib==3.2.0
- packaging==21.3
- pandas==1.4.3
- pillow==9.2.0
- protobuf==3.19.4
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pyparsing==3.0.9
- python-dateutil==2.8.2
- pytz==2022.1
- pyyaml==6.0
- requests==2.28.1
- requests-oauthlib==1.3.1
- rsa==4.9
- scikit-learn==1.1.1
- scipy==1.8.1
- seaborn==0.11.2
- simpleitk==2.1.1.2
- six==1.16.0
- tensorboard==2.9.1
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.1
- threadpoolctl==3.1.0
- torch==1.12.0+cu113
- torchaudio==0.12.0+cu113
- torchio==0.18.83
- torchvision==0.13.0+cu113
- tqdm==4.64.0
- typing-extensions==4.3.0
- urllib3==1.26.10
- werkzeug==2.1.2
- wrapt==1.14.1
- zipp==3.8.1
Combined
/data/CNN_Input/001_001-ML-AUX-RND-DOFFS0.014_Combined_WithDispLMA.tiff
/data/CNN_Input/001_002-ML-AUX-RND-DOFFS0.014_Combined_WithDispLMA.tiff
/data/CNN_Input/001_003-ML-AUX-RND-DOFFS0.014_Combined_WithDispLMA.tiff
/data/CNN_Input/001_004-ML-AUX-RND-DOFFS0.014_Combined_WithDispLMA.tiff
/data/CNN_Input/001_005-ML-AUX-RND-DOFFS0.014_Combined_WithDispLMA.tiff
Combined
/data/CNN_Input/001_006-ML-AUX-RND-DOFFS0.014_Combined_WithDispLMA.tiff
/data/CNN_Input/001_007-ML-AUX-RND-DOFFS0.014_Combined_WithDispLMA.tiff
/data/CNN_Input/001_008-ML-AUX-RND-DOFFS0.014_Combined_WithDispLMA.tiff
/data/CNN_Input/001_009-ML-AUX-RND-DOFFS0.014_Combined_WithDispLMA.tiff
/data/CNN_Input/001_010-ML-AUX-RND-DOFFS0.014_Combined_WithDispLMA.tiff
# README file for package installation
## 1. Installation and requirements
Installation using conda
```
> conda env create --name <env_name> --file <conda_environment.yaml>
> conda activate <env_name>
OR
> source activate <env_name>
```
General libraries being used:
- python 3.9
- pillow, numpy, pandas, matplotlib
- pytorch, torchio, imageio, tensorboard
- scikit-learn, yaml
## 2. Updating CSV files to define your dataset
Update your CSV files to point to your training and validation datasets; each CSV needs a `Combined` column listing the input TIFF files (see the sketch below).
- Please see "Example_CSV/Data_Example_train.csv"
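A minimal sketch of the expected layout (illustrative paths), with a single `Combined` column of multi-layer TIFF files:
```
Combined
/data/CNN_Input/001_001-ML-AUX-RND-DOFFS0.014_Combined_WithDispLMA.tiff
/data/CNN_Input/001_002-ML-AUX-RND-DOFFS0.014_Combined_WithDispLMA.tiff
```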
## 3. Running the package
### 3.1 DNN training using configuration file
Command line:
```
python3 AI_Training.py --config ./Config_Files/AI_Training_Config.yaml
```
### 3.2 DNN inference using configuration file
Command line:
```
python3 AI_Inference_CSV.py --config ./Config_Files/AI_Inference_Config.yaml --verbose
```
import sys  # needed for sys.exit() on invalid adjacent_tiles_dim
import pandas as pd
import torchio as tio
import torch
import numpy as np
# Generate list of torchIO subjects from CSV file
def GenerateTIOSubjectsList(CSVFile):
df = pd.read_csv(CSVFile, sep=',')
File_list = df['Combined'].tolist()
TIOSubjects_list = []
for idx in range(len(File_list)):
TIOSubject = tio.Subject(
Combined = tio.ScalarImage(File_list[idx]),
)
TIOSubjects_list.append(TIOSubject)
return File_list, TIOSubjects_list
# split ROI into tiles (3x3, or 5x5)
# initial shape [bs=2000,1,45,45,120]
# new shape for 3x3: [bs=2000,9,15,15,120]
def roi_split(t_roi, tile_size, adjacent_tiles_dim):
# t_roi_shape = [B,C,H,W,D]
t_roi_shape = t_roi.shape
#print('\t t_roi_shape: ',t_roi_shape)
# Remove channel layer
a = torch.squeeze(t_roi)
# Reshape tensor [bs=2000,3,15,3,15,120]
b = a.reshape(t_roi_shape[0],adjacent_tiles_dim,tile_size,adjacent_tiles_dim,tile_size,-1)
# Swap axes [bs=2000,3,3,15,15,120]
c = b.swapaxes(2,3)
# Combine channels [bs=2000,3x3,15,15,120]
d = c.reshape(t_roi_shape[0],adjacent_tiles_dim*adjacent_tiles_dim,tile_size,tile_size,-1)
# Remove last dimension of size 1 when needed (D=1 for TargetDisparity)
e = torch.squeeze(d, dim=4)
return e
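# Shape sketch (values from the comment above; illustrative, not executed):
#   t = torch.rand(2000, 1, 45, 45, 120)
#   roi_split(t, tile_size=15, adjacent_tiles_dim=3).shape
#   -> torch.Size([2000, 9, 15, 15, 120])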
# Prepare data as multiple inputs to network (tensors)
# Perform data filtering if enabled, using Confidence and DispLMA maps
def prepare_data_withfiltering(t_input, nb_image_layers, nb_corr_layers, tile_size, adjacent_tiles_dim, is_filtering=0, confidence_threshold=0.0):
t_input_Corr = t_input[:,:,:,:,0:nb_corr_layers]
t_input_TargetDisp = t_input[:,:,:,:,-4]
t_GroundTruth = t_input[:,:,:,:,-3]
t_Confidence = t_input[:,:,:,:,-2]
t_DispLMA = t_input[:,:,:,:,-1]
# print('t_input.type: ', t_input.type())
# # torch.Size([2000, 1, 45, 45, 122])
# print('t_input.shape: ', t_input.shape)
# print('t_input_Corr.shape: ', t_input_Corr.shape)
# print('t_input_TargetDisp.shape: ', t_input_TargetDisp.shape)
# print('t_GroundTruth.shape: ', t_GroundTruth.shape)
# Generate tiles when needed
if (adjacent_tiles_dim == 1):
t_input_Corr_tiles = t_input_Corr
t_input_TargetDisp_tiles = t_input_TargetDisp
t_GroundTruth_tiles = t_GroundTruth
t_Confidence_tiles = t_Confidence
t_DispLMA_tiles = t_DispLMA
else:
# Split t_input into neighboring tiles
t_input_Corr_tiles = roi_split(t_input_Corr, tile_size, adjacent_tiles_dim)
# # torch.Size([2000, 9, 15, 15, 120])
# print('t_input_Corr_tiles.shape: ', t_input_Corr_tiles.shape)
t_input_TargetDisp_tiles = roi_split(t_input_TargetDisp, tile_size, adjacent_tiles_dim)
# print('t_input_TargetDisp_tiles.shape: ', t_input_TargetDisp_tiles.shape)
# # torch.Size([2000, 9, 15, 15])
t_GroundTruth_tiles = roi_split(t_GroundTruth, tile_size, adjacent_tiles_dim)
t_Confidence_tiles = roi_split(t_Confidence, tile_size, adjacent_tiles_dim)
t_DispLMA_tiles = roi_split(t_DispLMA, tile_size, adjacent_tiles_dim)
# Generate input_TargetDisp_tiles_real, t_GroundTruth_tiles_real, scaling back to 62x78 pixels
t_input_TargetDisp_tiles_real = t_input_TargetDisp_tiles[:,:,::tile_size,::tile_size]
t_GroundTruth_tiles_real = t_GroundTruth_tiles[:,:,::tile_size,::tile_size]
t_Confidence_tiles_real = t_Confidence_tiles[:,:,::tile_size,::tile_size]
t_DispLMA_tiles_real = t_DispLMA_tiles[:,:,::tile_size,::tile_size]
# # torch.Size([2000, 9, 1, 1])
# print('\nt_input_TargetDisp_tiles_real.shape: ', t_input_TargetDisp_tiles_real.shape)
# print('t_GroundTruth_tiles_real.shape: ', t_GroundTruth_tiles_real.shape)
# print('t_Confidence_tiles_real.shape: ', t_Confidence_tiles_real.shape)
# print('t_DispLMA_tiles_real.shape: ', t_DispLMA_tiles_real.shape)
# - - - - - - - -
# Data filtering
# print('Data filtering...')
if is_filtering:
t_DispLMA_tiles_real_Mask = ~torch.isnan(t_DispLMA_tiles_real)
t_Confidence_tiles_real_Mask = torch.where(t_Confidence_tiles_real >= confidence_threshold, 1, 0)
t_Mask = torch.logical_and(t_DispLMA_tiles_real_Mask, t_Confidence_tiles_real_Mask)
t_Mask = torch.squeeze(t_Mask)
if (adjacent_tiles_dim != 1):
t_Mask = torch.all(t_Mask, dim=1)
# print('t_Mask.shape: ', t_Mask.shape)
# print('t_Mask[:20]: ', t_Mask[:20,...])
t_input_Corr_tiles_filtered = t_input_Corr_tiles[t_Mask]
t_input_TargetDisp_tiles_real_filtered = t_input_TargetDisp_tiles_real[t_Mask]
t_GroundTruth_tiles_real_filtered = t_GroundTruth_tiles_real[t_Mask]
# print('t_input_Corr_tiles_filtered.shape: ', t_input_Corr_tiles_filtered.shape)
# print('t_input_TargetDisp_tiles_real_filtered.shape: ', t_input_TargetDisp_tiles_real_filtered.shape)
else:
t_input_Corr_tiles_filtered = t_input_Corr_tiles
t_input_TargetDisp_tiles_real_filtered = t_input_TargetDisp_tiles_real
t_GroundTruth_tiles_real_filtered = t_GroundTruth_tiles_real
# Define center tile for GroundTruth and TargetDisp maps
if (adjacent_tiles_dim == 3):
IndexCenterTile = 4
elif (adjacent_tiles_dim == 5):
IndexCenterTile = 12
else:
IndexCenterTile = 0
t_input_TargetDisp_real_filtered_center = t_input_TargetDisp_tiles_real_filtered[:,IndexCenterTile,...]
t_GroundTruth_real_filtered_center = t_GroundTruth_tiles_real_filtered[:,IndexCenterTile,...]
# print('t_GroundTruth_tiles.shape: ', t_GroundTruth_tiles.shape)
# print('t_GroundTruth_tiles_real.shape: ', t_GroundTruth_tiles_real.shape)
# print('t_GroundTruth_tiles_real_filtered.shape: ', t_GroundTruth_tiles_real_filtered.shape)
# print('t_GroundTruth_real_filtered_center.shape: ', t_GroundTruth_real_filtered_center.shape)
# print('t_input_TargetDisp_tiles.shape: ', t_input_TargetDisp_tiles.shape)
# print('t_input_TargetDisp_tiles_real.shape: ', t_input_TargetDisp_tiles_real.shape)
# print('t_input_TargetDisp_tiles_real_filtered.shape: ', t_input_TargetDisp_tiles_real_filtered.shape)
# print('t_input_TargetDisp_real_filtered_center.shape: ', t_input_TargetDisp_real_filtered_center.shape)
return t_input_Corr_tiles_filtered, t_input_TargetDisp_real_filtered_center, t_GroundTruth_real_filtered_center
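# Usage sketch (hypothetical batch, 3x3 tiles, filtering disabled):
#   x = torch.rand(2000, 1, 45, 45, 124)
#   corr, disp, gt = prepare_data_withfiltering(x, 124, 120, 15, 3)
#   corr.shape -> torch.Size([2000, 9, 15, 15, 120])
#   disp.shape and gt.shape -> torch.Size([2000, 1, 1]) (center tile only)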
# Initialize TorchIO GridSampler variables
# Generate patch_overlap based on adjacent_tiles_dim
def initialize_gridsampler_variables(nb_image_layers, tile_size, adjacent_tiles_dim, padding_mode=None):
# Define patch_size
patch_size = (adjacent_tiles_dim * tile_size, adjacent_tiles_dim * tile_size, nb_image_layers)
# Define padding_mode
#padding_mode = 'symmetric'
# Define patch_overlap
if (adjacent_tiles_dim == 1):
patch_overlap = (0,0,0)
elif (adjacent_tiles_dim == 3):
# patch_overlap = (30,30,0)
patch_overlap = (2*tile_size,2*tile_size,0)
elif (adjacent_tiles_dim == 5):
# patch_overlap = (60,60,0)
patch_overlap = (4*tile_size,4*tile_size,0)
else:
print("Error initialize_gridsampler_variables - adjacent_tiles_dim")
sys.exit()
# print('patch_size: ',patch_size)
# print('patch_overlap: ',patch_overlap)
# print('padding_mode: ',padding_mode)
return patch_size, patch_overlap, padding_mode
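# Example (illustrative): 124 image layers, 15-px tiles, 3x3 neighborhood:
#   initialize_gridsampler_variables(124, 15, 3)
#   -> patch_size (45, 45, 124), patch_overlap (30, 30, 0), padding_mode None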
# Initialize TorchIO uniform Sampler variables
# patch_overlap = (0,0,0) # Not directly used
# patch overlap is generated by the random locations
def initialize_uniformsampler_variables(nb_image_layers, tile_size, adjacent_tiles_dim, padding_mode=None):
# Define patch_size
patch_size = (adjacent_tiles_dim * tile_size, adjacent_tiles_dim * tile_size, nb_image_layers)
# Define patch_overlap
patch_overlap = (0,0,0)
# Define padding_mode
#padding_mode = 'symmetric'
# print('patch_size: ',patch_size)
# print('patch_overlap: ',patch_overlap)
# print('padding_mode: ',padding_mode)
return patch_size, patch_overlap, padding_mode
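# Example (illustrative): a 5x5 neighborhood gives 75-px patches, no overlap:
#   initialize_uniformsampler_variables(124, 15, 5)
#   -> patch_size (75, 75, 124), patch_overlap (0, 0, 0), padding_mode None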
# Generate TorchIO aggregator patch_location for prediction
# Example - Input patch location for Tiles 5x5 = [ 0, 0, 0, 75, 75, 122]
# Example - Output patch location for Tiles5x5 = [ 2, 2, 0, 3, 3, 1]
# - Use CenterTile location
# - Divide by TileSize
# - Depth = 1
def prediction_patch_location(input_location, tile_size, adjacent_tiles_dim):
if (adjacent_tiles_dim == 1):
output_location = input_location
elif (adjacent_tiles_dim == 3):
#CenterTile_Update = torch.tensor([15,15,0,-15,-15,0], dtype=torch.int64)
CenterTile_Update = torch.tensor([tile_size,tile_size,0,-tile_size,-tile_size,0], dtype=torch.int64)
output_location = input_location + CenterTile_Update[None,:]
elif (adjacent_tiles_dim == 5):
#CenterTile_Update = torch.tensor([30,30,0,-30,-30,0], dtype=torch.int64)
CenterTile_Update = torch.tensor([2*tile_size,2*tile_size,0,-2*tile_size,-2*tile_size,0], dtype=torch.int64)
output_location = input_location + CenterTile_Update[None,:]
else:
print("Error prediction_patch_location - adjacent_tiles_dim")
sys.exit()
# print('\t\t output_location shape: ', output_location.shape)
# print('\t\t output_location: ', output_location)
# Divide by tile_size
output_location = torch.div(output_location, tile_size, rounding_mode='floor')
# Update depth to 1 (from 3D volume to 2D image)
output_location[:,-1]=1
# print('\t\t output_location shape: ', output_location.shape)
# print('\t\t output_location: ', output_location)
return output_location
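# Worked example (from the comment above, tile_size=15, 5x5 tiles):
#   loc = torch.tensor([[0, 0, 0, 75, 75, 122]])
#   prediction_patch_location(loc, 15, 5)
#   -> tensor([[2, 2, 0, 3, 3, 1]])  # center tile, in tile units, depth 1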
import torch
import torch.nn as nn
import torch.nn.functional as F
class MySubNetworkPhase1(nn.Module):
def __init__(self, nb_image_layers, tile_size, adjacent_tiles_dim, list_fc_features):
super().__init__()
self.nb_image_layers = nb_image_layers
self.tile_size = tile_size
self.adjacent_tiles_dim = adjacent_tiles_dim
#print('SUBNETPHASE1 - nb_image_layers',self.nb_image_layers)
#print('SUBNETPHASE1 - adjacent_tiles_dim',self.adjacent_tiles_dim)
self.list_fc_features = list_fc_features
self.flatten1 = nn.Flatten()
nb_input_features = self.nb_image_layers * self.tile_size * self.tile_size # 120*15*15 #27000
#print('SUBNETPHASE1 - nb_input_features',nb_input_features)
self.fc1 = nn.Linear(nb_input_features, self.list_fc_features[0])
self.BN1 = torch.nn.BatchNorm1d(self.list_fc_features[0])
#self.drop1 = nn.Dropout(p=0.25)
self.fc2 = nn.Linear(self.list_fc_features[0], self.list_fc_features[1])
self.BN2 = torch.nn.BatchNorm1d(self.list_fc_features[1])
#self.drop2 = nn.Dropout(p=0.25)
self.fc3 = nn.Linear(self.list_fc_features[1], self.list_fc_features[2])
self.BN3 = torch.nn.BatchNorm1d(self.list_fc_features[2])
#self.drop3 = nn.Dropout(p=0.25)
self.fc4 = nn.Linear(self.list_fc_features[2], self.list_fc_features[3])
#self.BN4 = torch.nn.BatchNorm1d(self.list_fc_features[3])
#self.drop4 = nn.Dropout(p=0.25)
def forward(self, x1):
x1 = self.flatten1(x1)
x = self.fc1(x1)
x = F.relu(x)
x = self.BN1(x)
#x = self.drop1(x)
x = self.fc2(x)
x = F.relu(x)
x = self.BN2(x)
#x = self.drop2(x)
x = self.fc3(x)
x = F.relu(x)
x = self.BN3(x)
#x = self.drop3(x)
x = self.fc4(x)
out = F.relu(x)
#out = self.BN4(out)
#out = self.drop4(out)
return out
# sub-Network - Phase2
class MySubNetworkPhase2(nn.Module):
def __init__(self, fc_inputfeatures, list_fc_features):
super().__init__()
# Input features after concatenation e.g. [3x3x64] or [5x5x64]
self.fc_inputfeatures = fc_inputfeatures
#list_nbfeatures = [128,64,32]
self.list_fc_features = list_fc_features
self.BN0 = torch.nn.BatchNorm1d(self.fc_inputfeatures)
self.drop0 = nn.Dropout(p=0.25)
self.fc1 = nn.Linear(self.fc_inputfeatures, self.list_fc_features[0])
self.BN1 = torch.nn.BatchNorm1d(self.list_fc_features[0])
self.drop1 = nn.Dropout(p=0.20)
self.fc2 = nn.Linear(self.list_fc_features[0], self.list_fc_features[1])
self.BN2 = torch.nn.BatchNorm1d(self.list_fc_features[1])
self.drop2 = nn.Dropout(p=0.10)
self.fc3 = nn.Linear(self.list_fc_features[1], self.list_fc_features[2])
self.BN3 = torch.nn.BatchNorm1d(self.list_fc_features[2])
#self.drop3 = nn.Dropout(p=0.25)
self.fc4 = nn.Linear(self.list_fc_features[2], 1)
self.flatten2 = nn.Flatten()
def forward(self, x, x2):
x = self.BN0(x)
x = self.drop0(x)
x = self.fc1(x)
x = F.relu(x)
x = self.BN1(x)
x = self.drop1(x)
x = self.fc2(x)
x = F.relu(x)
x = self.BN2(x)
x = self.drop2(x)
x = self.fc3(x)
x = F.relu(x)
x = self.BN3(x)
#x = self.drop3(x)
x = self.fc4(x)
x2 = self.flatten2(x2)
out = torch.add(x,x2)
return out
# Parallel network for neighboring tiles (e.g. 3x3, 5x5)
# - Phase1 : dynamic parallel sub-networks with FC layers (e.g. 3x3 or 5x5)
# - Concatenating 3x3 or 5x5 features
# - Phase 2 - single sub-network with FC layers
class MyParallelNetwork(nn.Module):
def __init__(self, nb_image_layers, tile_size, adjacent_tiles_dim, dict_fc_features):
super().__init__()
self.nb_image_layers = nb_image_layers
self.tile_size = tile_size
self.adjacent_tiles_dim = adjacent_tiles_dim
# Number of sub-networks: adjacent_tiles_dim x adjacent_tiles_dim (e.g. 3x3 or 5x5)
self.nb_subnetworks = self.adjacent_tiles_dim * self.adjacent_tiles_dim
# FC features
self.dict_fc_features = dict_fc_features
self.Phase2_InputFeatures = self.adjacent_tiles_dim * self.adjacent_tiles_dim * self.dict_fc_features['Phase1'][-1]
#print('NETWORK - nb_image_layers',self.nb_image_layers)
#print('NETWORK - adjacent_tiles_dim',self.adjacent_tiles_dim)
#print('NETWORK - nb_subnetworks',self.nb_subnetworks)
# define ModuleList of subnetworks
self.Phase1_subnetworks = nn.ModuleList([MySubNetworkPhase1(self.nb_image_layers, self.tile_size, self.adjacent_tiles_dim, self.dict_fc_features['Phase1']) for i in range(self.nb_subnetworks)])
self.Phase2_net = MySubNetworkPhase2(self.Phase2_InputFeatures, self.dict_fc_features['Phase2'])
def forward(self, x1, x2):
# Phase 1 - Parallel subnets
# x1 & x2 = list of 5x5 neighboring tiles
outputs_subnetworks = [Phase1_net(x1[:,i,...]) for i, Phase1_net in enumerate(self.Phase1_subnetworks)]
#print('NETWORK - len outputs_subnetworks',len(outputs_subnetworks))
#print('NETWORK - outputs_subnetworks[0] shape',outputs_subnetworks[0].shape)
# Concatenating outputs of subnets
out_Phase1 = torch.cat((outputs_subnetworks), dim = 1)
#print('NETWORK - out_Phase1 shape',out_Phase1.shape)
# Phase 2 - FC layers
out_Phase2 = self.Phase2_net(out_Phase1, x2)
return out_Phase2
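# Minimal shape check (a sketch; illustrative values taken from the example
# configs, with InputDepth=120 correlation layers as in the inference script):
if __name__ == '__main__':
    dict_fc_features = {'Phase1': [2048, 512, 256, 64], 'Phase2': [128, 64, 32]}
    net = MyParallelNetwork(120, 15, 3, dict_fc_features)
    x1 = torch.rand(8, 9, 15, 15, 120)  # 3x3 neighborhood of correlation tiles
    x2 = torch.rand(8, 1, 1)            # center-tile target disparity
    print(net(x1, x2).shape)            # expected: torch.Size([8, 1])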
import torch
from torchio.data.subject import Subject
from torchio.data.sampler import RandomSampler
from typing import Generator
import numpy as np
class MyUniformSampler(RandomSampler):
"""Randomly extract patches from a volume with uniform probability.
Args:
patch_size: See :class:`~torchio.data.PatchSampler`.
"""
def __init__(self, patch_size, tile_size):
super().__init__(patch_size)
self.tile_size = tile_size
def get_probability_map(self, subject: Subject) -> torch.Tensor:
return torch.ones(1, *subject.spatial_shape)
def _generate_patches(
self,
subject: Subject,
num_patches: int = None,
) -> Generator[Subject, None, None]:
valid_range = subject.spatial_shape - self.patch_size
patches_left = num_patches if num_patches is not None else True
# Random location using tile_size (multiple of tile_size)
while patches_left:
index_ini = [
self.tile_size * torch.randint(x//self.tile_size + 1, (1,)).item()
for x in valid_range
]
index_ini_array = np.asarray(index_ini)
yield self.extract_patch(subject, index_ini_array)
if num_patches is not None:
patches_left -= 1
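# Usage sketch (hypothetical subject): patch origins are snapped to the tile
# grid, i.e. index_ini is always a multiple of tile_size:
#   sampler = MyUniformSampler(patch_size=(45, 45, 124), tile_size=15)
#   patch = next(sampler(subject))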
import numpy as np
import matplotlib.pyplot as plt
import torch
def imshow(inp, title=None):
"""Imshow for Tensor."""
# inp = inp.numpy().transpose((1, 2, 0))
# mean = np.array([0.485, 0.456, 0.406])
# std = np.array([0.229, 0.224, 0.225])
# inp = std * inp + mean
# inp = np.clip(inp, 0, 1)
print('imshow inp.shape: ',inp.shape)
# Transpose axes for matplotlib, from CHW (Channel, Height, Width) to HWC
inp = inp.numpy().transpose((1, 2, 0))
print('imshow inp.shape: ',inp.shape)
fig, ax = plt.subplots()
plt.imshow(inp,cmap='gray')
if title is not None:
plt.title(title)
plt.pause(0.001) # pause a bit so that plots are updated