diff --git a/.dockerignore b/.dockerignore index 7797741..96a401f 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,3 +1,5 @@ +**/tmp* +test_data .ruff_cache .tox *.egg-info diff --git a/superpixel_classification/SuperpixelClassification/SuperpixelClassification.xml b/superpixel_classification/SuperpixelClassification/SuperpixelClassification.xml index ea0c301..b5c570c 100644 --- a/superpixel_classification/SuperpixelClassification/SuperpixelClassification.xml +++ b/superpixel_classification/SuperpixelClassification/SuperpixelClassification.xml @@ -61,7 +61,7 @@ Superpixel parameters gensuperpixels - generate-superpxiels + generate-superpixels If an image does not have an annotation with superpixels, generate one true diff --git a/superpixel_classification/SuperpixelClassification/SuperpixelClassificationBase.py b/superpixel_classification/SuperpixelClassification/SuperpixelClassificationBase.py index a694d59..f708aab 100644 --- a/superpixel_classification/SuperpixelClassification/SuperpixelClassificationBase.py +++ b/superpixel_classification/SuperpixelClassification/SuperpixelClassificationBase.py @@ -416,12 +416,15 @@ def trainModelAddItem(self, gc, record, item, annotrec, elem, feature, item['name'], annotrec['annotation']['name'], annotrec['_id'], annotrec['_version'])) featurePath = os.path.join(record['tempdir'], feature['name']) gc.downloadFile(feature['_id'], featurePath) + print(f"Downloaded '{feature['_id']}' to '{featurePath}'") + skipped_excluded = 0 with h5py.File(featurePath, 'r') as ffptr: fds = ffptr['images'] for idx, labelnum in enumerate(elem['values']): if labelnum and labelnum < len(elem['categories']): labelname = elem['categories'][labelnum]['label'] if labelname in excludeLabelList: + skipped_excluded += 1 continue if labelname not in record['groups']: record['groups'][labelname] = elem['categories'][labelnum] @@ -449,6 +452,7 @@ record['lastlog'] = time.time() print(record['ds'].shape, 
record['counts'], '%5.3f' % (time.time() - record['starttime'])) + print(f"Skipped {skipped_excluded} samples with labels that were excluded") def trainModel(self, gc, folderId, annotationName, features, modelFolderId, batchSize, epochs, trainingSplit, randomInput, labelList, @@ -506,11 +509,11 @@ def trainModel(self, gc, folderId, annotationName, features, modelFolderId, for attempt in Retrying(stop=stop_after_attempt(self.uploadRetries)): with attempt: modelFile = gc.uploadFileToFolder(modelFolderId, modelPath) - print('Saved model') + print(f'Saved model to {modelFolderId}') for attempt in Retrying(stop=stop_after_attempt(self.uploadRetries)): with attempt: modTrainingFile = gc.uploadFileToFolder(modelFolderId, modTrainingPath) - print('Saved modTraining') + print(f'Saved modTraining to {modelFolderId}') return modelFile, modTrainingFile def predictLabelsForItem(self, gc, annotationName, annotationFolderId, tempdir, model, item, @@ -734,7 +737,7 @@ def predictLabels(self, gc, folderId, annotationName, features, modelFolderId, modelFile = next(gc.listFile(item['_id'], limit=1)) break if not modelFile: - print('No model file found') + print(f'No model file found in {modelFolderId}') return print(modelFile['name'], item) modelPath = os.path.join(tempdir, modelFile['name']) @@ -747,7 +750,7 @@ def predictLabels(self, gc, folderId, annotationName, features, modelFolderId, modTrainingFile = next(gc.listFile(item['_id'], limit=1)) break if not modTrainingFile: - print('No modTraining file found') + print(f'No modTraining file found in {modelFolderId}') return print(modTrainingFile['name'], item) modTrainingPath = os.path.join(tempdir, modTrainingFile['name']) @@ -797,16 +800,21 @@ def main(self, args): gc, args.images, args.annotationName, args.radius, args.magnification, args.annotationDir, args.numWorkers, prog) + print("Creating features...") features = self.createFeatures( gc, args.images, args.annotationName, args.features, args.patchSize, args.numWorkers, prog) + 
print("Done creating features...") if args.train: self.trainModel( gc, args.images, args.annotationName, features, args.modeldir, args.batchSize, args.epochs, args.split, args.randominput, args.labels, args.exclude, prog) + print("Done training...") self.predictLabels( gc, args.images, args.annotationName, features, args.modeldir, args.annotationDir, args.heatmaps, args.radius, args.magnification, args.certainty, args.batchSize, prog) + print("Done predicting labels...") + print("Done, exiting") diff --git a/superpixel_classification/SuperpixelClassification/SuperpixelClassificationTensorflow.py b/superpixel_classification/SuperpixelClassification/SuperpixelClassificationTensorflow.py index 0af02d8..958ab42 100644 --- a/superpixel_classification/SuperpixelClassification/SuperpixelClassificationTensorflow.py +++ b/superpixel_classification/SuperpixelClassification/SuperpixelClassificationTensorflow.py @@ -3,6 +3,7 @@ from typing import Optional import h5py +import numpy as np import tensorflow as tf from SuperpixelClassificationBase import SuperpixelClassificationBase diff --git a/superpixel_classification/SuperpixelClassification/SuperpixelClassificationTorch.py b/superpixel_classification/SuperpixelClassification/SuperpixelClassificationTorch.py index 672584c..1bcf297 100644 --- a/superpixel_classification/SuperpixelClassification/SuperpixelClassificationTorch.py +++ b/superpixel_classification/SuperpixelClassification/SuperpixelClassificationTorch.py @@ -171,6 +171,7 @@ def trainModelDetails( val_ds: torch.utils.data.TensorDataset train_dl: torch.utils.data.DataLoader val_dl: torch.utils.data.DataLoader + prog.message('Loading features for model training') train_arg1 = torch.from_numpy(record['ds'][train_indices].transpose((0, 3, 2, 1))) train_arg2 = torch.from_numpy(record['labelds'][train_indices]) val_arg1 = torch.from_numpy(record['ds'][val_indices].transpose((0, 3, 2, 1))) @@ -410,9 +411,15 @@ def cacheOptimalBatchSize(self, batchSize: int, model: 
torch.nn.Module, training return batchSize def loadModel(self, modelPath): - model = torch.load(modelPath) - model.eval() - return model + self.add_safe_globals() + try: + model = torch.load(modelPath, weights_only=False) + model.eval() + return model + except Exception: + print(f"Unable to load {modelPath}") + raise + def saveModel(self, model, modelPath): torch.save(model, modelPath) diff --git a/tools/inspect_image_feature_file.py b/tools/inspect_image_feature_file.py new file mode 100644 index 0000000..a93d911 --- /dev/null +++ b/tools/inspect_image_feature_file.py @@ -0,0 +1,37 @@ +''' +This script will open a feature file (.h5) and show a 3x3 grid of images. +This tool is useful if you suspect that features are not extracted properly, for example due to erroneous mask values/indexing. +''' + +import h5py +import matplotlib.pyplot as plt +import numpy as np +import sys + +if len(sys.argv) > 1: + feature_file = sys.argv[1] +else: + feature_file = "features.h5" + +# open the file +with h5py.File(feature_file, "r") as f: + # get the images dataset + images = f["images"] + # get the first 9 images + images = images[:9] + # reshape the images to 3x3 + #images = np.reshape(images, (3,3,100,100,3)) + # transpose the images to 3x3 + #images = np.transpose(images, (0,2,1,3,4)) + # flatten the images to 9x100x100x3 + #images = np.reshape(images, (9,100,100,3)) + + # hide axis from pyplot + plt.axis('off') + + # plot the images + for i in range(9): + plt.subplot(3,3,i+1) + plt.imshow(images[i]) + print(f"Image {i+1} is {images[i].shape}") +plt.show()