Commit 4eec270e authored by Clement Vachet

Initial commit

#!/usr/bin/env python3
import os
import sys
import argparse
import time
# Select CUDA device before importing torch (GPU index 2)
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import torchio as tio
import imageio
import math
from network import *
import utils
import dataset
# --------------------
# Model - fully connected (FC) layers
dict_fc_features = {
    # Phase 1 - concatenation on 3rd layer
    'Phase1': [2048, 512, 256, 64],
    'Phase2': [128, 64, 32],
}

# MC Dropout settings
mc_dropout = True
mc_passes = 50
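
# Note on the MC dropout scheme implemented below: the model is first put in
# eval() mode, its dropout layers are then switched back to train() mode via
# apply_dropout(), and inference is repeated mc_passes times; per-tile mean,
# median, and coefficient of variation (std / |mean|) are aggregated over the passes.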
# --------------------
def arg_parser():
    parser = argparse.ArgumentParser(description='Inference - patch-based prediction with optional MC dropout')
    required = parser.add_argument_group('Required')
    required.add_argument('--input', type=str, required=True,
                          help='Combined TIFF file (multi-layer)')
    required.add_argument('--network', type=str, required=True,
                          help='PyTorch neural network (trained model weights)')
    required.add_argument('--output', type=str, required=True,
                          help='Image prediction (2D TIFF file)')
    options = parser.add_argument_group('Options')
    options.add_argument('--tile_size', type=int, default=15,
                         help='Tile size (default 15)')
    options.add_argument('--adjacent_tiles_dim', type=int, default=1,
                         help='Adjacent tiles dimension (e.g. 1, 3, 5)')
    options.add_argument('--bs', type=int, default=5000,
                         help='Batch size (default 5000)')
    options.add_argument('--output_median', type=str,
                         help='Image output - median for MC dropout (2D TIFF file)')
    options.add_argument('--output_cv', type=str,
                         help='Image output - coefficient of variation for MC dropout (2D TIFF file)')
    return parser

def apply_dropout(m):
    """Re-enable dropout layers at inference time (used for MC dropout)."""
    if m.__class__.__name__.startswith('Dropout'):
        print('\t\t Enabling MC dropout!')
        m.train()

# MAIN
def main(args=None):
    args = arg_parser().parse_args(args)

    InputFile = args.input
    ModelName = args.network
    OutputFile = args.output
    OutputFile_median = args.output_median
    OutputFile_CV = args.output_cv
    TileSize = args.tile_size
    AdjacentTilesDim = args.adjacent_tiles_dim
    bs = args.bs

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print('\n--------------------')
    since1 = time.time()

    # TorchIO subject
    print('\nGenerating TIO subject...')
    Subject = tio.Subject(
        Combined=tio.ScalarImage(InputFile),
    )

    # Initialize variables
    InputFile_Shape = Subject['Combined'].shape
    NbTiles_H = InputFile_Shape[1] // TileSize
    NbTiles_W = InputFile_Shape[2] // TileSize
    NbImageLayers = InputFile_Shape[3]
    NbCorrLayers = NbImageLayers - 4
    InputDepth = NbCorrLayers
    print('InputFile_Shape: ', InputFile_Shape)
    print('NbTiles_H: ', NbTiles_H)
    print('NbTiles_W: ', NbTiles_W)
    print('NbImageLayers: ', NbImageLayers)
    print('InputDepth: ', InputDepth)
    # GridSampler
    print('\nGenerating Grid Sampler...')
    patch_size, patch_overlap, padding_mode = dataset.initialize_gridsampler_variables(
        NbImageLayers, TileSize, AdjacentTilesDim, padding_mode=None)
    print('patch_size: ', patch_size)
    print('patch_overlap: ', patch_overlap)
    print('padding_mode: ', padding_mode)

    grid_sampler = tio.data.GridSampler(
        subject=Subject,
        patch_size=patch_size,
        patch_overlap=patch_overlap,
        padding_mode=padding_mode,
    )
    len_grid_sampler = len(grid_sampler)
    print('length grid_sampler', len_grid_sampler)

    patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=bs)
    aggregator = tio.data.GridAggregator(grid_sampler, overlap_mode='average')
    print('\nLoading DNN model...')
    model = MyParallelNetwork(InputDepth, TileSize, AdjacentTilesDim, dict_fc_features)
    # map_location ensures the checkpoint loads on CPU-only machines as well
    model.load_state_dict(torch.load(ModelName, map_location=device))
    print(model)
    model.to(device)
    model.eval()
    if mc_dropout:
        print('\t MC Dropout')
        # eval() disables dropout; selectively switch dropout layers back to train mode
        model.apply(apply_dropout)

    print('\nPatch-based inference...')
    since2 = time.time()
    # model = nn.Identity().eval()
    with torch.no_grad():
        for patch_idx, patches_batch in enumerate(patch_loader):
            print('\t patch_idx: ', patch_idx)

            # print('\t\t Preparing data...')
            inputs = patches_batch['Combined'][tio.DATA]
            print('\t\t inputs shape: ', inputs.shape)
            input1_tiles, input2_tiles_real, GroundTruth_real = dataset.prepare_data_withfiltering(
                inputs, NbImageLayers, NbCorrLayers, TileSize, AdjacentTilesDim)
            # print('\t\t Preparing data - done -')

            input1_tiles = input1_tiles.to(device)
            input2_tiles_real = input2_tiles_real.to(device)
            # GroundTruth_real = GroundTruth_real.to(device)
            # Reducing last dimension to compute loss
            # GroundTruth_real = torch.squeeze(GroundTruth_real, dim=2)
            print('\t\t input1_tiles shape: ', input1_tiles.shape)
            print('\t\t input2_tiles_real shape:', input2_tiles_real.shape)

            if mc_dropout:
                # Perform multiple inference passes (mc_passes)
                outputs_all = torch.empty(size=(mc_passes, input1_tiles.shape[0])).to(device)
                for i in range(0, mc_passes):
                    outputs = model(input1_tiles, input2_tiles_real)
                    outputs_all[i] = torch.squeeze(outputs)

                # Compute mean, median, std, and CV (coefficient of variation = std / |mean|)
                outputs_mean = torch.mean(outputs_all, 0)
                outputs_median = torch.median(outputs_all, 0)[0]
                outputs_std = torch.std(outputs_all, 0)
                outputs_cv = torch.div(outputs_std, torch.abs(outputs_mean))
                # outputs_se = torch.div(outputs_std, math.sqrt(mc_passes))
                outputs_combined = torch.stack((outputs_mean, outputs_median, outputs_cv), dim=1)

                print('\t\t outputs shape: ', outputs.shape)
                print('\t\t outputs device', outputs.device)
                print('\t\t outputs_all shape: ', outputs_all.shape)
                print('\t\t outputs_all device', outputs_all.device)
                print('\t\t outputs_mean shape: ', outputs_mean.shape)
                print('\t\t outputs_median shape: ', outputs_median.shape)
                print('\t\t outputs_median type: ', outputs_median.type())
                print('\t\t outputs_combined shape: ', outputs_combined.shape)
                print('\t\t outputs_mean[:20]', outputs_mean[:20])
                print('\t\t outputs_median[:20]', outputs_median[:20])
                print('\t\t outputs_std[:20]', outputs_std[:20])
                print('\t\t outputs_cv[:20]', outputs_cv[:20])
            else:
                outputs_combined = model(input1_tiles, input2_tiles_real)

            print('\t\t outputs_combined device', outputs_combined.device)
            print('\t\t outputs_combined shape: ', outputs_combined.shape)

            # Reshape outputs to match location dimensions
            outputs_combined_reshape = torch.reshape(
                outputs_combined, [outputs_combined.shape[0], outputs_combined.shape[1], 1, 1, 1])
            print('\t\t outputs_combined_reshape shape: ', outputs_combined_reshape.shape)

            input_location = patches_batch[tio.LOCATION]
            print('\t\t input_location shape: ', input_location.shape)
            print('\t\t input_location type: ', input_location.dtype)
            print('\t\t input_location[:20]: ', input_location[:20])

            # Reshape input_location to prediction_location, to fit output image size (78, 62, 1)
            pred_location = dataset.prediction_patch_location(input_location, TileSize, AdjacentTilesDim)
            print('\t\t pred_location shape: ', pred_location.shape)
            print('\t\t pred_location[:20]: ', pred_location[:20])

            # Add batch with location to TorchIO aggregator
            aggregator.add_batch(outputs_combined_reshape, pred_location)
    # output_tensor shape: [3, 1170, 930, 124]
    output_tensor_combined = aggregator.get_output_tensor()
    print('output_tensor_combined type: ', output_tensor_combined.dtype)
    print('output_tensor_combined shape: ', output_tensor_combined.shape)

    # Extract the real information of interest: [3, 78, 62]
    output_tensor_combined_real = output_tensor_combined[:, :NbTiles_H, :NbTiles_W, 0]
    print('output_tensor_combined_real shape: ', output_tensor_combined_real.shape)

    output_combined_np = output_tensor_combined_real.numpy().squeeze()
    print('output_combined_np type', output_combined_np.dtype)
    print('output_combined_np shape', output_combined_np.shape)

    if mc_dropout:
        output_mean_np = output_combined_np[0, ...]
        output_median_np = output_combined_np[1, ...]
        output_cv_np = output_combined_np[2, ...]
        imageio_output_mean = np.moveaxis(output_mean_np, 0, 1)
        imageio_output_median = np.moveaxis(output_median_np, 0, 1)
        imageio_output_cv = np.moveaxis(output_cv_np, 0, 1)
        print('imageio_output_mean shape', imageio_output_mean.shape)
        print('imageio_output_median shape', imageio_output_median.shape)
        print('imageio_output_cv shape', imageio_output_cv.shape)
    else:
        output_np = output_combined_np
        imageio_output = np.moveaxis(output_np, 0, 1)
        print('imageio_output shape', imageio_output.shape)
    time_elapsed2 = time.time() - since2

    if mc_dropout:
        print('Writing output mean image via imageio...')
        imageio.imwrite(OutputFile, imageio_output_mean)
        print('Writing output median image via imageio...')
        imageio.imwrite(OutputFile_median, imageio_output_median)
        print('Writing output CV image via imageio...')
        imageio.imwrite(OutputFile_CV, imageio_output_cv)
    else:
        print('Writing output image via imageio...')
        imageio.imwrite(OutputFile, imageio_output)

    time_elapsed3 = time.time() - since2
    time_elapsed1 = time.time() - since1
    print('--- Inference in {:.2f}s ---'.format(time_elapsed2))
    print('--- Inference and saving in {:.2f}s ---'.format(time_elapsed3))
    print('--- Total time in {:.2f}s ---'.format(time_elapsed1))

if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))
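
# - - - - - - - - -
# A minimal example invocation of the patch-based inference script above. The
# script file name (AI_Inference.py) and the file paths are assumptions for
# illustration; with mc_dropout = True in the script, --output_median and
# --output_cv must both be provided.
python3 AI_Inference.py \
    --input Combined_MultiLayer.tif \
    --network pytorch_IRTPNet_Tiles1x1_WithFiltering0.0_MCDropout.h5 \
    --output Pred_mean.tif \
    --output_median Pred_median.tif \
    --output_cv Pred_CV.tif \
    --tile_size 15 --adjacent_tiles_dim 1 --bs 5000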
# -*- coding: utf-8 -*-
"""
AI analysis via parallel neural networks
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sys
import os
import copy
import yaml
import argparse
# Select CUDA device before importing torch (GPU index 2)
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import torchio as tio
# -----------
from model import *
import utils
import dataset
import sampler
# Manual seed
torch.manual_seed(42)
######################################################################

def arg_parser():
    parser = argparse.ArgumentParser(description='AI analysis - DNN training')
    required = parser.add_argument_group('Required')
    required.add_argument('--config', type=str, required=True,
                          help='YAML configuration / parameter file')
    options = parser.add_argument_group('Options')
    options.add_argument('--verbose', action="store_true",
                         help='verbose mode')
    return parser

def main(args=None):
    args = arg_parser().parse_args(args)
    config_filename = args.config
    # plt.ion()  # interactive mode

    ######################################################################
    # Loading parameter file
    print('\n--- Loading configuration file ---')
    with open(config_filename, 'r') as yaml_file:
        config_file = yaml.safe_load(yaml_file)
    if args.verbose:
        print('config_file', config_file)

    # Defining parameters
    CSVFile_train = config_file['CSVFile_train']
    CSVFile_val = config_file['CSVFile_val']
    model_filename = config_file['ModelName']
    loss_filename = config_file['LossName']
    nb_image_layers = config_file['NbImageLayers']
    nb_corr_layers = config_file['NbCorrLayers']
    tile_size = config_file['TileSize']
    adjacent_tiles_dim = config_file['AdjacentTilesDim']
    num_workers = config_file['num_workers']
    samples_per_volume = config_file['samples_per_volume']
    queue_length = config_file['queue_length']
    data_filtering = config_file['DataFiltering']
    confidence_threshold = config_file['ConfidenceThreshold']
    dict_fc_features = config_file['dict_fc_features']
    bs = config_file['bs']
    lr = config_file['lr']
    nb_epochs = config_file['nb_epochs']
    # ------------------
    print('\n--- Generating torchIO dataset ---')
    File_list_train, TIOSubjects_list_train = dataset.GenerateTIOSubjectsList(CSVFile_train)
    File_list_test, TIOSubjects_list_test = dataset.GenerateTIOSubjectsList(CSVFile_val)

    # torchIO transforms
    TIOtransforms = [
        tio.RandomFlip(axes=('lr')),
    ]
    TIOtransform = tio.Compose(TIOtransforms)

    # TIO datasets
    TIOSubjects_dataset_train = tio.SubjectsDataset(TIOSubjects_list_train, transform=TIOtransform)
    TIOSubjects_dataset_test = tio.SubjectsDataset(TIOSubjects_list_test, transform=None)
    print('Training set: ', len(TIOSubjects_dataset_train), 'subjects')
    print('Validation set: ', len(TIOSubjects_dataset_test), 'subjects')
    # ------------------
    # Subject visualization (quality control)
    if args.verbose:
        print('\n--- Quality control: TIOSubject Info ---')
        MyTIOSubject = TIOSubjects_dataset_train[0]
        print('MySubject: ', MyTIOSubject)
        print('MySubject.shape: ', MyTIOSubject.shape)
        print('MySubject.spacing: ', MyTIOSubject.spacing)
        print('MySubject.spatial_shape: ', MyTIOSubject.spatial_shape)
        print('MySubject.spatial_shape.type: ', type(MyTIOSubject.spatial_shape))
        print('MySubject history: ', MyTIOSubject.get_composed_history())
    # ------------------
    # - - - - - - - - - - - - - -
    # Training with GridSampler (alternative, kept commented out)
    # patch_size, patch_overlap, padding_mode = dataset.initialize_gridsampler_variables(
    #     nb_image_layers, tile_size, adjacent_tiles_dim, padding_mode=None)
    # print('patch_size: ', patch_size)
    # print('patch_overlap: ', patch_overlap)
    # print('padding_mode: ', padding_mode)
    # example_grid_sampler = tio.data.GridSampler(
    #     subject=MyTIOSubject,
    #     patch_size=patch_size,
    #     patch_overlap=patch_overlap,
    #     padding_mode=padding_mode,
    # )
    # samples_per_volume = len(example_grid_sampler)
    # queue_length = samples_per_volume * num_workers
    # print('samples_per_volume', samples_per_volume)
    # print('queue_length', queue_length)
    # sampler_train = tio.data.GridSampler(
    #     patch_size=patch_size,
    #     patch_overlap=patch_overlap,
    #     padding_mode=padding_mode,
    # )
    # sampler_test = tio.data.GridSampler(
    #     patch_size=patch_size,
    #     patch_overlap=patch_overlap,
    #     padding_mode=padding_mode,
    # )
    # - - - - - - - - - - - - - -
    # - - - - - - - - - - - - - -
    # Training with UniformSampler
    print('\n--- Initializing patch sampling variables ---')
    patch_size, patch_overlap, padding_mode = dataset.initialize_uniformsampler_variables(
        nb_image_layers, tile_size, adjacent_tiles_dim, padding_mode=None)
    if args.verbose:
        print('patch_size: ', patch_size)
        print('patch_overlap: ', patch_overlap)
        print('padding_mode: ', padding_mode)
        print('samples_per_volume', samples_per_volume)
        print('queue_length', queue_length)

    sampler_train = sampler.MyUniformSampler(
        patch_size=patch_size,
        tile_size=tile_size,
    )
    sampler_test = sampler.MyUniformSampler(
        patch_size=patch_size,
        tile_size=tile_size,
    )

    patches_queue_train = tio.Queue(
        subjects_dataset=TIOSubjects_dataset_train,
        max_length=queue_length,
        samples_per_volume=samples_per_volume,
        sampler=sampler_train,
        num_workers=num_workers,
        shuffle_subjects=True,
        shuffle_patches=True,
    )
    patches_queue_test = tio.Queue(
        subjects_dataset=TIOSubjects_dataset_test,
        max_length=queue_length,
        samples_per_volume=samples_per_volume,
        sampler=sampler_test,
        num_workers=num_workers,
        shuffle_subjects=True,
        shuffle_patches=True,
    )

    patches_loader_train = DataLoader(
        patches_queue_train,
        batch_size=bs,
        shuffle=True,
        num_workers=0,  # this must be 0 (the tio.Queue handles multiprocessing)
    )
    patches_loader_test = DataLoader(
        patches_queue_test,
        batch_size=bs,
        shuffle=False,
        num_workers=0,  # this must be 0 (the tio.Queue handles multiprocessing)
    )

    # Dictionary of patch data loaders
    patches_loader_dict = {}
    patches_loader_dict['train'] = patches_loader_train
    patches_loader_dict['val'] = patches_loader_test

    # ----------------------
    # Visualize input data
    writer = SummaryWriter('tensorboard/MyNetwork')

    # Get a batch of training data
    print('\n--- Quality control: patch inputs ---')
    patches_batch = next(iter(patches_loader_dict['val']))
    inputs = patches_batch['Combined'][tio.DATA]
    locations = patches_batch[tio.LOCATION]

    # Variable initialization needed for TensorBoard graph
    input_Corr_tiles, input_TargetDisp_tiles_real, GroundTruth_real = dataset.prepare_data_withfiltering(
        inputs, nb_image_layers, nb_corr_layers, tile_size, adjacent_tiles_dim,
        data_filtering, confidence_threshold)
    if args.verbose:
        print('\ninput_Corr_tiles.shape: ', input_Corr_tiles.shape)
        print('input_TargetDisp_tiles_real.shape: ', input_TargetDisp_tiles_real.shape)
        print('GroundTruth_real.shape: ', GroundTruth_real.shape)
    ######################################################################
    # Neural network - training
    # ----------------------
    # Create a neural network model and start training / testing.
    # ----------------------

    # Create model
    print('\n--- Creating neural network architecture ---')
    model_ft = Model(writer, nb_image_layers, nb_corr_layers, tile_size, adjacent_tiles_dim,
                     model_filename, dict_fc_features, loss_filename, data_filtering, confidence_threshold)

    # TensorBoard - add graph
    writer.add_graph(model_ft.model, [input_Corr_tiles.to(model_ft.device),
                                      input_TargetDisp_tiles_real.to(model_ft.device)])
    writer.close()

    # ----------------------
    # Train and evaluate
    print('\n--- DNN training ---')
    model_ft.train_model(dataloaders=patches_loader_dict, lr=lr, nb_epochs=nb_epochs)

    # ----------------------
    # Evaluate on validation data
    print('\n--- DNN testing ---')
    model_ft.test_model(dataloaders=patches_loader_dict)

    # plt.ioff()
    # plt.show()


if __name__ == "__main__":
    main()
# - - - - - - - - -
# Environment 1
# Inference on real validation data - CSV file
python3 AI_Inference_CSV.py --config ./Config_Files/AI_Inference_Config_Tiles1x1.yaml --verbose
# DNN training (1x1, 3x3, 5x5 adjacent tiles), with logs redirected to files
python3 AI_Training.py --config ./Config_Files/AI_Training_Config_Tiles1x1.yaml > AI_Training_Tiles1x1_MCDropout.log
python3 AI_Training.py --config ./Config_Files/AI_Training_Config_Tiles3x3.yaml > AI_Training_Tiles3x3_MCDropout.log
python3 AI_Training.py --config ./Config_Files/AI_Training_Config_Tiles5x5.yaml > AI_Training_Tiles5x5_MCDropout.log
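
# - - - - - - - - -
# A minimal sketch of how the YAML config files below are consumed, mirroring
# the yaml.safe_load pattern used in AI_Training.py above; the config path and
# the variable names here are illustrative assumptions.
import yaml

with open('./Config_Files/AI_Inference_Config_Tiles1x1.yaml', 'r') as yaml_file:
    config = yaml.safe_load(yaml_file)

model_name = config['ModelName']                 # trained network weights (.h5)
tile_size = config['TileSize']                   # 15-pixel tiles
adjacent_tiles_dim = config['AdjacentTilesDim']  # 1, 3, or 5
mc_passes = config['MCPasses']                   # number of MC dropout passes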
---
# Input files
CSVFile: '../Example_CSV/Data_Example_val.csv'
ModelName: './pytorch_IRTPNet_Tiles1x1_WithFiltering0.0_MCDropout.h5'
# Output files
OutputFolder: 'CNN_Output_WithFiltering_MCDropout/'
OutputSuffix: 'Pred_IRTPNet_Tiles1x1_WithFiltering0.0_MCDropout'
# Data parameters
NbImageLayers: 124
NbCorrLayers: 120
TileSize: 15
AdjacentTilesDim: 1 # 1 for 1x1, 3 for 3x3, 5 for 5x5 adjacent tiles
# Neural network parameters
# Model - FC layers
dict_fc_features:
Phase1: [2048,512,256,64]
Phase2: [128,64,32]
# Batch size
bs: 5000
# MC Dropout
MCDropout: True
MCPasses: 40
---
# Input files
CSVFile: '../Example_CSV/Data_Example_val.csv'
ModelName: './pytorch_IRTPNet_Tiles3x3_WithFiltering0.0_MCDropout.h5'
# Output files
OutputFolder: 'CNN_Output_WithFiltering_MCDropout/'
OutputSuffix: 'Pred_IRTPNet_Tiles3x3_WithFiltering0.0_MCDropout'
# Data parameters
NbImageLayers: 124
NbCorrLayers: 120
TileSize: 15
AdjacentTilesDim: 3 # 1 for 1x1, 3 for 3x3, 5 for 5x5 adjacent tiles
# Neural network parameters
# Model - FC layers
dict_fc_features:
Phase1: [2048,512,256,64]
Phase2: [128,64,32]
# Batch size
bs: 5000
# MC Dropout
MCDropout: True
MCPasses: 40
---
# Input files
CSVFile: '../Example_CSV/Data_Example_val.csv'
ModelName: './pytorch_IRTPNet_Tiles5x5_WithFiltering0.0_MCDropout.h5'
# Output files
OutputFolder: 'CNN_Output_WithFiltering_MCDropout/'
OutputSuffix: 'Pred_IRTPNet_Tiles5x5_WithFiltering0.0_MCDropout'
# Data parameters
NbImageLayers: 124
NbCorrLayers: 120
TileSize: 15
AdjacentTilesDim: 5 # 1 for 1x1, 3 for 3x3, 5 for 5x5 adjacent tiles
# Neural network parameters
# Model - FC layers
dict_fc_features:
Phase1: [2048,512,256,64]
Phase2: [128,64,32]
# Batch size
bs: 5000
# MC Dropout
MCDropout: True
MCPasses: 40
---
# Input files
CSVFile_train: '../Example_CSV/Data_Example_train.csv'
CSVFile_val: '../Example_CSV/Data_Example_val.csv'
# Output files
ModelName: './pytorch_IRTPNet_Tiles1x1_WithFiltering0.0_MCDropout.h5'
LossName: './Loss_IRTPNet_Tiles1x1_WithFiltering0.0_MCDropout.png'
# Data parameters
NbImageLayers: 124
NbCorrLayers: 120
TileSize: 15
AdjacentTilesDim: 1 # 1 for 1x1, 3 for 3x3, 5 for 5x5 adjacent tiles
# Data sampling parameters
num_workers: 6
samples_per_volume: 1000
queue_length: 6000 # samples_per_volume * num_workers
# Data filtering parameters
DataFiltering: True
ConfidenceThreshold: 0.0
# Neural network parameters
# Model - FC layers
dict_fc_features:
Phase1: [2048,512,256,64]
Phase2: [128,64,32]
# Batch size
bs: 500
# Learning rate
lr: 1.0e-3
# Number Epochs
nb_epochs: 15
---
# Input files
CSVFile_train: '../Example_CSV/Data_Example_train.csv'
CSVFile_val: '../Example_CSV/Data_Example_val.csv'
# Output files
ModelName: './pytorch_IRTPNet_Tiles3x3_WithFiltering0.0_MCDropout.h5'
LossName: './Loss_IRTPNet_Tiles3x3_WithFiltering0.0_MCDropout.png'
# Data parameters
NbImageLayers: 124
NbCorrLayers: 120
TileSize: 15
AdjacentTilesDim: 3 # 1 for 1x1, 3 for 3x3, 5 for 5x5 adjacent tiles
# Data sampling parameters
num_workers: 6
samples_per_volume: 1000
queue_length: 6000 # samples_per_volume * num_workers
# Data filtering parameters
DataFiltering: True
ConfidenceThreshold: 0.0
# Neural network parameters
# Model - FC layers
dict_fc_features:
Phase1: [2048,512,256,64]
Phase2: [128,64,32]
# Batch size
bs: 500
# Learning rate
lr: 1.0e-3
# Number Epochs
nb_epochs: 15
---
# Input files
CSVFile_train: '../Example_CSV/Data_Example_train.csv'
CSVFile_val: '../Example_CSV/Data_Example_val.csv'
# Output files
ModelName: './pytorch_IRTPNet_Tiles5x5_WithFiltering0.0_MCDropout.h5'
LossName: './Loss_IRTPNet_Tiles5x5_WithFiltering0.0_MCDropout.png'
# Data parameters
NbImageLayers: 124
NbCorrLayers: 120
TileSize: 15
AdjacentTilesDim: 5 # 1 for 1x1, 3 for 3x3, 5 for 5x5 adjacent tiles
# Data sampling parameters
num_workers: 6
samples_per_volume: 1000
queue_length: 6000 # samples_per_volume * num_workers
# Data filtering parameters
DataFiltering: True