Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
**/tmp*
test_data
.ruff_cache
.tox
*.egg-info
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
<description>Superpixel parameters</description>
<boolean>
<name>gensuperpixels</name>
<longflag>generate-superpxiels</longflag>
<longflag>generate-superpixels</longflag>
<description>If an image does not have an annotation with superpixels, generate one</description>
<label>Generate superpixels</label>
<default>true</default>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,12 +416,14 @@ def trainModelAddItem(self, gc, record, item, annotrec, elem, feature,
item['name'], annotrec['annotation']['name'], annotrec['_id'], annotrec['_version']))
featurePath = os.path.join(record['tempdir'], feature['name'])
gc.downloadFile(feature['_id'], featurePath)
print(f"Downloaded '{feature['_id']}' to '{featurePath}'")
with h5py.File(featurePath, 'r') as ffptr:
fds = ffptr['images']
for idx, labelnum in enumerate(elem['values']):
if labelnum and labelnum < len(elem['categories']):
labelname = elem['categories'][labelnum]['label']
if labelname in excludeLabelList:
skipped_excluded += 1
continue
if labelname not in record['groups']:
record['groups'][labelname] = elem['categories'][labelnum]
Expand Down Expand Up @@ -449,6 +451,7 @@ def trainModelAddItem(self, gc, record, item, annotrec, elem, feature,
record['lastlog'] = time.time()
print(record['ds'].shape, record['counts'],
'%5.3f' % (time.time() - record['starttime']))
print(f"Skipped {skipped_excluded} samples with labels that were excluded")

def trainModel(self, gc, folderId, annotationName, features, modelFolderId,
batchSize, epochs, trainingSplit, randomInput, labelList,
Expand Down Expand Up @@ -506,11 +509,11 @@ def trainModel(self, gc, folderId, annotationName, features, modelFolderId,
for attempt in Retrying(stop=stop_after_attempt(self.uploadRetries)):
with attempt:
modelFile = gc.uploadFileToFolder(modelFolderId, modelPath)
print('Saved model')
print(f'Saved model to {modelFolderId}')
for attempt in Retrying(stop=stop_after_attempt(self.uploadRetries)):
with attempt:
modTrainingFile = gc.uploadFileToFolder(modelFolderId, modTrainingPath)
print('Saved modTraining')
print(f'Saved modTraining to {modelFolderId}')
return modelFile, modTrainingFile

def predictLabelsForItem(self, gc, annotationName, annotationFolderId, tempdir, model, item,
Expand Down Expand Up @@ -734,7 +737,7 @@ def predictLabels(self, gc, folderId, annotationName, features, modelFolderId,
modelFile = next(gc.listFile(item['_id'], limit=1))
break
if not modelFile:
print('No model file found')
print(f'No model file found in {modelFolderId}')
return
print(modelFile['name'], item)
modelPath = os.path.join(tempdir, modelFile['name'])
Expand All @@ -747,7 +750,7 @@ def predictLabels(self, gc, folderId, annotationName, features, modelFolderId,
modTrainingFile = next(gc.listFile(item['_id'], limit=1))
break
if not modTrainingFile:
print('No modTraining file found')
print(f'No modTraining file found in {modelFolderId}')
return
print(modTrainingFile['name'], item)
modTrainingPath = os.path.join(tempdir, modTrainingFile['name'])
Expand Down Expand Up @@ -797,16 +800,21 @@ def main(self, args):
gc, args.images, args.annotationName, args.radius, args.magnification,
args.annotationDir, args.numWorkers, prog)

print("Creating features...")
features = self.createFeatures(
gc, args.images, args.annotationName, args.features, args.patchSize,
args.numWorkers, prog)
print("Done creating features...")

if args.train:
self.trainModel(
gc, args.images, args.annotationName, features, args.modeldir, args.batchSize,
args.epochs, args.split, args.randominput, args.labels, args.exclude, prog)
print("Done training...")

self.predictLabels(
gc, args.images, args.annotationName, features, args.modeldir, args.annotationDir,
args.heatmaps, args.radius, args.magnification, args.certainty, args.batchSize,
prog)
print("Done predicting labels...")
print("Done, exiting")
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional

import h5py
import numpy as np
import tensorflow as tf
from SuperpixelClassificationBase import SuperpixelClassificationBase

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def trainModelDetails(
val_ds: torch.utils.data.TensorDataset
train_dl: torch.utils.data.DataLoader
val_dl: torch.utils.data.DataLoader
prog.message('Loading features for model training')
train_arg1 = torch.from_numpy(record['ds'][train_indices].transpose((0, 3, 2, 1)))
train_arg2 = torch.from_numpy(record['labelds'][train_indices])
val_arg1 = torch.from_numpy(record['ds'][val_indices].transpose((0, 3, 2, 1)))
Expand Down Expand Up @@ -410,9 +411,15 @@ def cacheOptimalBatchSize(self, batchSize: int, model: torch.nn.Module, training
return batchSize

def loadModel(self, modelPath):
model = torch.load(modelPath)
model.eval()
return model
self.add_safe_globals()
try:
model = torch.load(modelPath, weights_only=False)
model.eval()
return model
except Exception as e:
print(f"Unable to load {modelPath}")
raise


def saveModel(self, model, modelPath):
torch.save(model, modelPath)
37 changes: 37 additions & 0 deletions tools/inspect_image_feature_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
'''
This script will open a feature file (.h5) and show a 3x3 grid of images.
This tool is useful if you suspect that features are not extracted properly,
for example due to erroneous mask values/indexing.
'''

import sys

import h5py
import matplotlib.pyplot as plt

# sys.argv[0] is the script name itself, so a real argument is only present
# when len(sys.argv) > 1.  (A `> 0` check would always be true and would
# raise IndexError when the script is run with no argument.)
if len(sys.argv) > 1:
    feature_file = sys.argv[1]
else:
    feature_file = "features.h5"

# Open the feature file read-only and take the first 9 patch images from
# the 'images' dataset (slicing handles files with fewer than 9 gracefully).
with h5py.File(feature_file, "r") as f:
    images = f["images"][:9]

    # Build the 3x3 grid.  Hide the axis on each subplot individually --
    # a single plt.axis('off') before any subplot exists only affects the
    # current (empty) axes, not the grid cells.
    for i, image in enumerate(images):
        ax = plt.subplot(3, 3, i + 1)
        ax.axis('off')
        ax.imshow(image)
        print(f"Image {i+1} is {image.shape}")

# Show the complete grid once, after all subplots are drawn.  Calling
# plt.show() inside the loop blocks on every iteration and re-displays a
# partially filled figure.
plt.show()