Commit 79594c9c authored by JosefBrandt's avatar JosefBrandt

Bugfix in handling invalid particles in detection.

parent ff9142ea
...@@ -143,11 +143,13 @@ class ParticleContainer(object): ...@@ -143,11 +143,13 @@ class ParticleContainer(object):
def getParticleOfIndex(self, index): def getParticleOfIndex(self, index):
try: try:
particle = self.particles[index] particle = self.particles[index]
assert particle.index == index, f'particle.index ({particle.index}) does not match requested index in particleList ({index})'
except: except:
print('failed getting particle') print('failed getting particle')
print('requested Index:', index) print('requested Index:', index)
print('len particles', len(self.particles)) print('len particles', len(self.particles))
assert particle.index == index, f'particle.index ({particle.index}) does match requested index in particleList ({index})' raise
return particle return particle
def getParticleIndexContainingSpecIndex(self, index): def getParticleIndexContainingSpecIndex(self, index):
......
...@@ -264,8 +264,8 @@ class ImageView(QtWidgets.QLabel): ...@@ -264,8 +264,8 @@ class ImageView(QtWidgets.QLabel):
painter.drawPolygon(c) painter.drawPolygon(c)
painter.setPen(QtCore.Qt.red) painter.setPen(QtCore.Qt.red)
painter.setBrush(QtCore.Qt.red) painter.setBrush(QtCore.Qt.red)
for p in self.measpoints: for points in self.measpoints:
for point in self.measpoints[p]: for point in points:
painter.drawEllipse(point.x-2, point.y-2, 5, 5) painter.drawEllipse(point.x-2, point.y-2, 5, 5)
if self.showseedpoints: if self.showseedpoints:
...@@ -667,6 +667,7 @@ class ParticleDetectionView(QtWidgets.QWidget): ...@@ -667,6 +667,7 @@ class ParticleDetectionView(QtWidgets.QWidget):
self.dataset.mode = "opticalscan" self.dataset.mode = "opticalscan"
self.dataset.save() self.dataset.save()
self.imageUpdate.emit(self.view.microscopeMode) self.imageUpdate.emit(self.view.microscopeMode)
self.view.resetParticleContours()
@QtCore.pyqtSlot() @QtCore.pyqtSlot()
def cancelThread(self): def cancelThread(self):
...@@ -770,17 +771,19 @@ class ParticleDetectionView(QtWidgets.QWidget): ...@@ -770,17 +771,19 @@ class ParticleDetectionView(QtWidgets.QWidget):
def applyResultsToDataset(self, measurementPoints, contours): def applyResultsToDataset(self, measurementPoints, contours):
self.dataset.ramanscandone = False self.dataset.ramanscandone = False
particlestats = self.getParticleStats(contours) particlestats, invalidParticleIndices = self.getParticleStats(contours)
contours = self.removeInvalidContours(contours, invalidParticleIndices)
measurementPoints = self.removeInvalidMeasurementPoints(measurementPoints, invalidParticleIndices)
particleContainer = self.dataset.particleContainer particleContainer = self.dataset.particleContainer
numParticles = len(contours) numParticles = len(contours)
particleContainer.initializeParticles(numParticles) particleContainer.initializeParticles(numParticles)
particleContainer.setParticleContours(contours) particleContainer.setParticleContours(contours)
particleContainer.setParticleStats(particlestats) particleContainer.setParticleStats(particlestats)
particleContainer.clearMeasurements() particleContainer.clearMeasurements()
for particleIndex in measurementPoints.keys(): for particleIndex, measPoints in enumerate(measurementPoints):
measPoints = measurementPoints[particleIndex]
for index, point in enumerate(measPoints): for index, point in enumerate(measPoints):
curParticle = particleContainer.getParticleOfIndex(particleIndex) curParticle = particleContainer.getParticleOfIndex(particleIndex)
indexOfNewMeas = particleContainer.addEmptyMeasurement() indexOfNewMeas = particleContainer.addEmptyMeasurement()
...@@ -794,15 +797,25 @@ class ParticleDetectionView(QtWidgets.QWidget): ...@@ -794,15 +797,25 @@ class ParticleDetectionView(QtWidgets.QWidget):
def getParticleStats(self, contours): def getParticleStats(self, contours):
particlestats = [] particlestats = []
zvalimg = loadZValImageFromDataset(self.dataset) zvalimg = loadZValImageFromDataset(self.dataset)
for contour in contours: invalidParticleIndices = []
for contourIndex, contour in enumerate(contours):
try: try:
stats = getParticleStatsWithPixelScale(contour, self.dataset, fullimage=self.img, zimg=zvalimg) stats = getParticleStatsWithPixelScale(contour, self.dataset, fullimage=self.img, zimg=zvalimg)
particlestats.append(stats)
except InvalidParticleError: except InvalidParticleError:
print('invalid contour in detection, skipping partile. Contour is:', contour) print('invalid contour in detection, skipping particle. Contour is:', contour)
invalidParticleIndices.append(contourIndex)
continue continue
particlestats.append(stats)
return particlestats
return particlestats, invalidParticleIndices
def removeInvalidContours(self, contours, invalidParticleIndices):
validContours = [cnt for index, cnt in enumerate(contours) if index not in invalidParticleIndices]
return validContours
def removeInvalidMeasurementPoints(self, measurementPoints, invalidParticleIndices):
validMeasPoints = [points for index, points in enumerate(measurementPoints) if index not in invalidParticleIndices]
return validMeasPoints
def updateSeedsInSampleview(self): def updateSeedsInSampleview(self):
self.view.updateSeedPointMarkers() self.view.updateSeedPointMarkers()
......
...@@ -30,7 +30,6 @@ import skfuzzy as fuzz ...@@ -30,7 +30,6 @@ import skfuzzy as fuzz
import random import random
from PyQt5 import QtCore from PyQt5 import QtCore
def closeHolesOfSubImage(subimg): def closeHolesOfSubImage(subimg):
subimg = cv2.copyMakeBorder(subimg, 1, 1, 1, 1, 0) subimg = cv2.copyMakeBorder(subimg, 1, 1, 1, 1, 0)
im_floodfill = subimg.copy() im_floodfill = subimg.copy()
...@@ -248,7 +247,7 @@ class Segmentation(QtCore.QObject): ...@@ -248,7 +247,7 @@ class Segmentation(QtCore.QObject):
self.detectionState.emit(f'DO: maxVal={n-1}') self.detectionState.emit(f'DO: maxVal={n-1}')
del thresh del thresh
measurementPoints = {} measurementPoints = []
finalcontours = [] finalcontours = []
particleIndex = 0 particleIndex = 0
...@@ -262,7 +261,7 @@ class Segmentation(QtCore.QObject): ...@@ -262,7 +261,7 @@ class Segmentation(QtCore.QObject):
for label in range(1, n): for label in range(1, n):
area = stats[label, cv2.CC_STAT_AREA] area = stats[label, cv2.CC_STAT_AREA]
if minArea < area < maxArea: if minArea < area:
up = stats[label, cv2.CC_STAT_TOP] up = stats[label, cv2.CC_STAT_TOP]
left = stats[label, cv2.CC_STAT_LEFT] left = stats[label, cv2.CC_STAT_LEFT]
width = stats[label, cv2.CC_STAT_WIDTH] width = stats[label, cv2.CC_STAT_WIDTH]
...@@ -336,7 +335,7 @@ class Segmentation(QtCore.QObject): ...@@ -336,7 +335,7 @@ class Segmentation(QtCore.QObject):
for cnt in tmpcontours: for cnt in tmpcontours:
contourArea = cv2.contourArea(cnt) * scaleFactor**2 contourArea = cv2.contourArea(cnt) * scaleFactor**2
if contourArea >= minArea: if minArea <= contourArea <= maxArea:
tmplabel = markers[cnt[0,0,1],cnt[0,0,0]] tmplabel = markers[cnt[0,0,1],cnt[0,0,0]]
if tmplabel ==0: if tmplabel ==0:
continue continue
...@@ -362,11 +361,11 @@ class Segmentation(QtCore.QObject): ...@@ -362,11 +361,11 @@ class Segmentation(QtCore.QObject):
cnt[i][0][1] += up cnt[i][0][1] += up
finalcontours.append(cnt) finalcontours.append(cnt)
measurementPoints[particleIndex] = [] measurementPoints.append([])
for index in range(0, len(x)): for index in range(0, len(x)):
newMeasPoint = MeasurementPoint(particleIndex, x[index] + x0 + left, y[index] + y0 + up) newMeasPoint = MeasurementPoint(particleIndex, x[index] + x0 + left, y[index] + y0 + up)
measurementPoints[particleIndex].append(newMeasPoint) measurementPoints[-1].append(newMeasPoint)
particleIndex += 1 particleIndex += 1
...@@ -389,11 +388,7 @@ class Segmentation(QtCore.QObject): ...@@ -389,11 +388,7 @@ class Segmentation(QtCore.QObject):
if self.measurefrac < 1.0: if self.measurefrac < 1.0:
nMeasurementsDesired = int(np.round(self.measurefrac * len(measurementPoints))) nMeasurementsDesired = int(np.round(self.measurefrac * len(measurementPoints)))
print(f'selecting {nMeasurementsDesired} of {len(measurementPoints)} measuring spots') print(f'selecting {nMeasurementsDesired} of {len(measurementPoints)} measuring spots')
partIndicesToMeasure = random.sample(measurementPoints.keys(), nMeasurementsDesired) measurementPoints = random.sample(measurementPoints, nMeasurementsDesired)
newMeasPoints = {}
for index in partIndicesToMeasure:
newMeasPoints[index] = measurementPoints[index]
measurementPoints = newMeasPoints
total_time = time()-t0 total_time = time()-t0
print('segmentation took', total_time, 'seconds') print('segmentation took', total_time, 'seconds')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment