diff --git a/.dockerignore b/.dockerignore
index 7797741..96a401f 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -1,3 +1,5 @@
+**/tmp*
+test_data
.ruff_cache
.tox
*.egg-info
diff --git a/superpixel_classification/SuperpixelClassification/SuperpixelClassification.xml b/superpixel_classification/SuperpixelClassification/SuperpixelClassification.xml
index ea0c301..b5c570c 100644
--- a/superpixel_classification/SuperpixelClassification/SuperpixelClassification.xml
+++ b/superpixel_classification/SuperpixelClassification/SuperpixelClassification.xml
@@ -61,7 +61,7 @@
Superpixel parameters
gensuperpixels
- generate-superpxiels
+ generate-superpixels
If an image does not have an annotation with superpixels, generate one
true
diff --git a/superpixel_classification/SuperpixelClassification/SuperpixelClassificationBase.py b/superpixel_classification/SuperpixelClassification/SuperpixelClassificationBase.py
index a694d59..f708aab 100644
--- a/superpixel_classification/SuperpixelClassification/SuperpixelClassificationBase.py
+++ b/superpixel_classification/SuperpixelClassification/SuperpixelClassificationBase.py
@@ -416,12 +416,15 @@ def trainModelAddItem(self, gc, record, item, annotrec, elem, feature,
item['name'], annotrec['annotation']['name'], annotrec['_id'], annotrec['_version']))
featurePath = os.path.join(record['tempdir'], feature['name'])
gc.downloadFile(feature['_id'], featurePath)
+ print(f"Downloaded '{feature['_id']}' to '{featurePath}'")
+ skipped_excluded = 0
with h5py.File(featurePath, 'r') as ffptr:
fds = ffptr['images']
for idx, labelnum in enumerate(elem['values']):
if labelnum and labelnum < len(elem['categories']):
labelname = elem['categories'][labelnum]['label']
if labelname in excludeLabelList:
+ skipped_excluded += 1
continue
if labelname not in record['groups']:
record['groups'][labelname] = elem['categories'][labelnum]
@@ -449,6 +451,7 @@ def trainModelAddItem(self, gc, record, item, annotrec, elem, feature,
record['lastlog'] = time.time()
print(record['ds'].shape, record['counts'],
'%5.3f' % (time.time() - record['starttime']))
+ print(f"Skipped {skipped_excluded} samples with labels that were excluded")
def trainModel(self, gc, folderId, annotationName, features, modelFolderId,
batchSize, epochs, trainingSplit, randomInput, labelList,
@@ -506,11 +509,11 @@ def trainModel(self, gc, folderId, annotationName, features, modelFolderId,
for attempt in Retrying(stop=stop_after_attempt(self.uploadRetries)):
with attempt:
modelFile = gc.uploadFileToFolder(modelFolderId, modelPath)
- print('Saved model')
+ print(f'Saved model to {modelFolderId}')
for attempt in Retrying(stop=stop_after_attempt(self.uploadRetries)):
with attempt:
modTrainingFile = gc.uploadFileToFolder(modelFolderId, modTrainingPath)
- print('Saved modTraining')
+ print(f'Saved modTraining to {modelFolderId}')
return modelFile, modTrainingFile
def predictLabelsForItem(self, gc, annotationName, annotationFolderId, tempdir, model, item,
@@ -734,7 +737,7 @@ def predictLabels(self, gc, folderId, annotationName, features, modelFolderId,
modelFile = next(gc.listFile(item['_id'], limit=1))
break
if not modelFile:
- print('No model file found')
+ print(f'No model file found in {modelFolderId}')
return
print(modelFile['name'], item)
modelPath = os.path.join(tempdir, modelFile['name'])
@@ -747,7 +750,7 @@ def predictLabels(self, gc, folderId, annotationName, features, modelFolderId,
modTrainingFile = next(gc.listFile(item['_id'], limit=1))
break
if not modTrainingFile:
- print('No modTraining file found')
+ print(f'No modTraining file found in {modelFolderId}')
return
print(modTrainingFile['name'], item)
modTrainingPath = os.path.join(tempdir, modTrainingFile['name'])
@@ -797,16 +800,21 @@ def main(self, args):
gc, args.images, args.annotationName, args.radius, args.magnification,
args.annotationDir, args.numWorkers, prog)
+ print("Creating features...")
features = self.createFeatures(
gc, args.images, args.annotationName, args.features, args.patchSize,
args.numWorkers, prog)
+ print("Done creating features...")
if args.train:
self.trainModel(
gc, args.images, args.annotationName, features, args.modeldir, args.batchSize,
args.epochs, args.split, args.randominput, args.labels, args.exclude, prog)
+ print("Done training...")
self.predictLabels(
gc, args.images, args.annotationName, features, args.modeldir, args.annotationDir,
args.heatmaps, args.radius, args.magnification, args.certainty, args.batchSize,
prog)
+ print("Done predicting labels...")
+ print("Done, exiting")
diff --git a/superpixel_classification/SuperpixelClassification/SuperpixelClassificationTensorflow.py b/superpixel_classification/SuperpixelClassification/SuperpixelClassificationTensorflow.py
index 0af02d8..958ab42 100644
--- a/superpixel_classification/SuperpixelClassification/SuperpixelClassificationTensorflow.py
+++ b/superpixel_classification/SuperpixelClassification/SuperpixelClassificationTensorflow.py
@@ -3,6 +3,7 @@
from typing import Optional
import h5py
+import numpy as np
import tensorflow as tf
from SuperpixelClassificationBase import SuperpixelClassificationBase
diff --git a/superpixel_classification/SuperpixelClassification/SuperpixelClassificationTorch.py b/superpixel_classification/SuperpixelClassification/SuperpixelClassificationTorch.py
index 672584c..1bcf297 100644
--- a/superpixel_classification/SuperpixelClassification/SuperpixelClassificationTorch.py
+++ b/superpixel_classification/SuperpixelClassification/SuperpixelClassificationTorch.py
@@ -171,6 +171,7 @@ def trainModelDetails(
val_ds: torch.utils.data.TensorDataset
train_dl: torch.utils.data.DataLoader
val_dl: torch.utils.data.DataLoader
+ prog.message('Loading features for model training')
train_arg1 = torch.from_numpy(record['ds'][train_indices].transpose((0, 3, 2, 1)))
train_arg2 = torch.from_numpy(record['labelds'][train_indices])
val_arg1 = torch.from_numpy(record['ds'][val_indices].transpose((0, 3, 2, 1)))
@@ -410,9 +411,15 @@ def cacheOptimalBatchSize(self, batchSize: int, model: torch.nn.Module, training
return batchSize
def loadModel(self, modelPath):
- model = torch.load(modelPath)
- model.eval()
- return model
+ self.add_safe_globals()
+ try:
+ model = torch.load(modelPath, weights_only=False)
+ model.eval()
+ return model
+ except Exception as e:
+ print(f"Unable to load {modelPath}: {e}")
+ raise
+
def saveModel(self, model, modelPath):
torch.save(model, modelPath)
diff --git a/tools/inspect_image_feature_file.py b/tools/inspect_image_feature_file.py
new file mode 100644
index 0000000..a93d911
--- /dev/null
+++ b/tools/inspect_image_feature_file.py
@@ -0,0 +1,37 @@
+'''
+This script will open a feature file (.h5) and show a 3x3 grid of images.
+This tool is useful if you suspect that features are not extracted properly, for example due to erroneous mask values/indexing.
+'''
+
+import h5py
+import matplotlib.pyplot as plt
+import numpy as np
+import sys
+
+if len(sys.argv) > 1:
+ feature_file = sys.argv[1]
+else:
+ feature_file = "features.h5"
+
+# open the file
+with h5py.File(feature_file, "r") as f:
+ # get the images dataset
+ images = f["images"]
+ # get the first 9 images
+ images = images[:9]
+ # reshape the images to 3x3
+ #images = np.reshape(images, (3,3,100,100,3))
+ # transpose the images to 3x3
+ #images = np.transpose(images, (0,2,1,3,4))
+ # flatten the images to 9x100x100x3
+ #images = np.reshape(images, (9,100,100,3))
+
+ # hide axis from pyplot
+ plt.axis('off')
+
+ # plot the images into one 3x3 grid, then display once
+ for i in range(len(images)):
+ plt.subplot(3,3,i+1)
+ plt.imshow(images[i])
+ print(f"Image {i+1} is {images[i].shape}")
+ plt.show()