# segmentation.py
# -*- coding: utf-8 -*-
"""
GEPARD - Gepard-Enabled PARticle Detection
Copyright (C) 2018  Lars Bittrich and Josef Brandt, Leibniz-Institut für 
Polymerforschung Dresden e. V. <bittrich-lars@ipfdd.de>    

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program, see COPYING.  
If not, see <https://www.gnu.org/licenses/>.
"""
import numpy as np
import cv2
cv2.useOptimized()
from time import time
from scipy.interpolate import InterpolatedUnivariateSpline
from scipy import ndimage as ndi
from skimage.feature import peak_local_max
from skimage.morphology import watershed
import skfuzzy as fuzz
import random
from PyQt5 import QtCore

def closeHolesOfSubImage(subimg):
    """
    Fills all enclosed holes of a binary uint8 image (callers pass 0/255 images).

    A one-pixel black frame is added first, so that a flood fill started in the
    top-left corner reaches every background pixel connected to the image border.
    Whatever is still black after the fill is enclosed by foreground and gets filled.
    :return: image of the original size with the holes set to 255
    """
    framed = cv2.copyMakeBorder(subimg, 1, 1, 1, 1, 0)
    rows, cols = framed.shape[:2]
    # cv2.floodFill requires a mask that is 2 pixels larger than the image
    floodMask = np.zeros((rows + 2, cols + 2), np.uint8)
    outside = framed.copy()
    cv2.floodFill(outside, floodMask, (0, 0), 255)
    filled = framed | cv2.bitwise_not(outside)
    return filled[1:-1, 1:-1]


class Parameter(object):
    """
    A Parameter for driving the image segmentation. All Parameters are initialized in the Segmentation Class.
    The DetectionView-Widget reads these parameters and creates and connects the necessary items in the ui.

    :param name: name of the parameter; the Segmentation exposes its value under this attribute name
    :param dtype: type of the value (int, float, a bool type, np.ndarray, or None for pure preview entries)
    :param value: current value
    :param minval: lower bound of the allowed range (stored in valrange)
    :param maxval: upper bound of the allowed range (stored in valrange)
    :param decimals: presumably the number of decimals shown in the ui - TODO confirm against DetectionView
    :param stepsize: presumably the increment of the ui control - TODO confirm against DetectionView
    :param helptext: descriptive text for the ui
    :param show: flag read by the ui; presumably whether a preview of this step can be shown - TODO confirm
    :param linkedParameter: name of another Parameter this one is linked to
        (e.g. an activation checkbox links to the parameter it enables)
    """

    def __init__(self, name, dtype, value=None, minval=None, maxval=None,
                 decimals=0, stepsize=1, helptext=None, show=False, linkedParameter=None):
        self.name = name
        self.dtype = dtype
        self.value = value
        self.valrange = (minval, maxval)
        self.decimals = decimals
        self.stepsize = stepsize
        self.helptext = helptext
        self.show = show
        self.linkedParameter = linkedParameter

    def __repr__(self):
        return f"{type(self).__name__}(name={self.name!r}, dtype={self.dtype!r}, value={self.value!r})"


class MeasurementPoint(object):
    """
    Image coordinate (x, y) of a point chosen for measuring the particle
    with the index particleIndex.
    """

    def __init__(self, particleIndex, x, y):
        self.particleIndex, self.x, self.y = particleIndex, x, y

class Segmentation(QtCore.QObject):
    """
    Image segmentation for particle detection.

    apply2Image converts the image to gray, optionally applies CLAHE and a contrast
    curve, blurs, thresholds, closes bright holes and finally separates particles with
    a marker-based watershed on every connected component. Progress messages are
    emitted via the detectionState signal.
    """
    detectionState = QtCore.pyqtSignal(str)

    def __init__(self, dataset=None, parent=None):
        super(Segmentation, self).__init__()
        self.cancelcomputation = False
        self.parent = parent
        self.defaultParams = {'adaptiveHistEqu': False,
                              'claheTileSize': 128,
                              'contrastCurve': np.array([[50,0],[100,200],[200,255]]),
                              'activateContrastCurve': True,
                              'blurRadius': 9,
                              'activateLowThresh': True,
                              'lowThresh': 0.2,
                              'activateUpThresh': False,
                              'upThresh': 0.5,
                              'invertThresh': False,
                              'maxholebrightness': 0.5,
                              'minparticlesize': 20,
                              'enableMaxArea': False,
                              'maxparticlesize': 100000,
                              'minparticledistance': 20,
                              'closeBackground': False,
                              'fuzzycluster': False,
                              'maxComponentSize': 20000,
                              'measurefrac': 1,
                              'compactness': 0.0,
                              'seedRad': 3}
        if dataset is not None:
            self.detectParams = dataset.detectParams
            for key in self.defaultParams:
                if key not in self.detectParams:
                    self.detectParams[key] = self.defaultParams[key]
        else:
            # a copy, so that later changes of detectParams do not alter the defaults
            self.detectParams = dict(self.defaultParams)
        self.initializeParameters()

    def initializeParameters(self):
        """
        Creates self.parlist (one Parameter per segmentation setting, initialized from
        self.detectParams) and exposes every parameter value as instance attribute self.<name>.
        """
        # np.bool was removed in NumPy 1.24 (and re-added as np.bool_ in NumPy 2); use it when
        # available so that comparisons against np.bool elsewhere keep working, else builtin bool.
        boolType = getattr(np, 'bool', bool)
        dp = self.detectParams
        parlist = [Parameter("adaptiveHistEqu", boolType, dp['adaptiveHistEqu'], helptext="Adaptive histogram equalization", show=False, linkedParameter='claheTileSize'),
                   Parameter("claheTileSize", int, dp['claheTileSize'], 1, 2048, 1, 1, helptext="Tile size for adaptive histogram adjustment\nThe Image will be split into tiles with size approx. (NxN)", show=True),
                   Parameter("contrastCurve", np.ndarray, dp['contrastCurve'], helptext="Curve contrast"),
                   Parameter("activateContrastCurve", boolType, dp['activateContrastCurve'], helptext="activate Contrast curve", show=True, linkedParameter='contrastCurve'),
                   Parameter("blurRadius", int, dp['blurRadius'], 3, 99, 1, 2, helptext="Blur radius", show=True),
                   Parameter("invertThresh", boolType, dp['invertThresh'], helptext="Invert the current threshold", show=False),
                   Parameter("activateLowThresh", boolType, dp['activateLowThresh'], helptext="activate lower threshold", show=False, linkedParameter='lowThresh'),
                   Parameter("lowThresh", float, dp['lowThresh'], .01, .9, 2, .02, helptext="Lower threshold", show=True),
                   Parameter("activateUpThresh", boolType, dp['activateUpThresh'], helptext="activate upper threshold", show=False, linkedParameter='upThresh'),
                   Parameter("upThresh", float, dp['upThresh'], .01, 1.0, 2, .02, helptext="Upper threshold", show=False),
                   Parameter("maxholebrightness", float, dp['maxholebrightness'], 0, 1, 2, 0.02, helptext="Close holes brighter than..", show = True),
                   Parameter("minparticlesize", int, dp['minparticlesize'], 1, 1000, 0, 50, helptext="Min. particle size (µm)", show=False),
                   # linked to the parameter that actually exists ('maxparticlearea' does not)
                   Parameter("enableMaxArea", boolType, dp['enableMaxArea'], helptext="enable filtering for maximal particle size", show=False, linkedParameter='maxparticlesize'),
                   Parameter("maxparticlesize", int, dp['maxparticlesize'], 10, 1E9, 0, 50, helptext="Max. particle size (µm)", show=False),
                   Parameter("minparticledistance", int, dp['minparticledistance'], 5, 1000, 0, 5, helptext="Min. distance between particles", show=False),
                   Parameter("measurefrac", float, dp['measurefrac'], 0, 1, 2, stepsize = 0.05, helptext="measure fraction of particles", show=False),
                   Parameter("closeBackground", boolType, dp['closeBackground'], helptext="close holes in sure background", show=False),
                   Parameter("fuzzycluster", boolType, dp['fuzzycluster'], helptext='Enable Fuzzy Clustering', show=False),
                   Parameter("maxComponentSize", int, dp['maxComponentSize'], 100, 1E6, 0, 100, helptext='Maximum size in x or y of connected component.\nLarger components are scaled down accordingly', show=False),
                   Parameter("sure_fg", None, helptext="Show sure foreground", show=True),
                   Parameter("compactness", float, dp['compactness'], 0, 1, 2, 0.05, helptext="watershed compactness", show=False),
                   Parameter("watershed", None, helptext="Show watershed markers", show=True),
                   ]
        # Expose the values as plain instance attributes. (A property object stored in the
        # instance __dict__ is never invoked as a descriptor, so reading self.<name> used to
        # return the property object itself - always truthy - instead of the value.)
        for p in parlist:
            setattr(self, p.name, p.value)
        self.parlist = parlist

    def apply2Image(self, img, seedpoints, deletepoints, seedradius, dataset, return_step=None):
        """
        Takes an image with seedpoints and deletepoints and runs segmentation on it.

        :param img: RGB image (a 2D gray image is accepted as well)
        :param seedpoints: rows of (x, y, radius); filled circles forced to foreground
        :param deletepoints: rows of (x, y, radius); filled circles forced to background
        :param seedradius: not used, kept for interface compatibility
        :param dataset: provides getPixelScale() to convert particle sizes into areas; may be None
        :param return_step: name of a parameter whose intermediate result is returned as preview
        :return: (measurementPoints, finalcontours) for a full detection;
            (previewImage, flag) for a preview, flag being 1 only for the 'sure_fg' preview whose
            pixels are 1 = sure foreground, 2 = sure background, 3 = both;
            (None, None) if the computation was cancelled or aborted.
        :raises NotImplementedError: for an unsupported return_step
        """
        t0 = time()
        self.detectionState.emit('DO: setup')

        gray = self.convert2Gray(img)
        self.detectionState.emit('finished GrayScale')
        print("gray")

        if self.adaptiveHistEqu:
            # tileGridSize is (tiles along x, tiles along y); at least one tile per axis
            numTilesX = max(1, round(img.shape[1] / self.claheTileSize))
            numTilesY = max(1, round(img.shape[0] / self.claheTileSize))
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(numTilesX, numTilesY))
            gray = clahe.apply(gray)
            self.detectionState.emit('finished CLAHE')
        if return_step == "claheTileSize":
            return gray, 0
        print("adaptive Histogram Adjustment")
        if self.cancelcomputation:
            return None, None

        if self.activateContrastCurve:
            _, lookupTable = self.calculateHistFunction(self.contrastCurve)
            gray = lookupTable[gray]
            print("contrast curve")
            self.detectionState.emit('finished Contrast Curve')
        if self.cancelcomputation:
            return None, None
        # return even if inactive!
        if return_step == "activateContrastCurve":
            return gray, 0

        # image blur for noise-reduction, then stretch to the full 0..255 range
        blur = cv2.medianBlur(gray, self.blurRadius)
        blurMax = blur.max()
        if blurMax > 0:  # an all-black image stays black instead of dividing by zero
            blur = np.uint8(blur * (255 / blurMax))
        del gray
        if return_step == "blurRadius":
            return blur, 0
        self.detectionState.emit('finished Blurring')
        print("blur")
        if self.cancelcomputation:
            return None, None

        # thresholding
        thresh, thresholdMode = self._applyThreshold(blur)
        if thresh is None:  # no checkbox checked
            if self.parent is not None:
                self.parent.raiseWarning('No thresholding method selected!\nAborted detection..')
            print('NO THRESHOLDING SELECTED!')
            if return_step is None:
                # a full detection is aborted, signalled like a cancel
                return None, None
            return blur, 0
        if return_step in ("lowThresh", "upThresh"):
            return thresh, 0
        print(thresholdMode)
        if self.cancelcomputation:
            return None, None

        # close holes brighter than self.maxholebrightness
        thresh = self.closeBrightHoles(thresh, blur, self.maxholebrightness)
        del blur
        print("thresholded")
        self.detectionState.emit('finished thresholding')

        # modify thresh with seedpoints and deletepoints
        for p in np.int32(seedpoints):
            cv2.circle(thresh, (int(p[0]), int(p[1])), int(p[2]), 255, -1)
        for p in np.int32(deletepoints):
            cv2.circle(thresh, (int(p[0]), int(p[1])), int(p[2]), 0, -1)

        if return_step == 'maxholebrightness':
            return thresh, 0
        if self.cancelcomputation:
            return None, None
        # reject unsupported previews before doing the expensive watershed work
        if return_step not in (None, 'sure_fg', 'watershed'):
            raise NotImplementedError(f"this particular return_step: {return_step} is not implemented yet")

        minArea, maxArea = self.getMinMaxParticleArea(dataset)

        # get sure_fg
        # peak_local_max takes a min distance between peaks, so individual particles smaller
        # than that distance would be disregarded. Hence, a connected components approach.
        n, labels, stats, centroids = cv2.connectedComponentsWithStats(thresh, 8, cv2.CV_32S)
        self.detectionState.emit('finished connected components search')
        self.detectionState.emit(f'DO: maxVal={n-1}')
        del thresh

        measurementPoints = []
        finalcontours = []
        particleIndex = 0

        # integer buffers, so the bitwise composition of the sure_fg preview always works
        if return_step == "sure_fg":
            preview_surefg = np.zeros(img.shape[:2], dtype=np.int32)
            preview_surebg = np.zeros(img.shape[:2], dtype=np.int32)
        elif return_step == "watershed":
            previewImage = np.zeros(img.shape[:2], dtype=np.int32)

        for label in range(1, n):
            area = stats[label, cv2.CC_STAT_AREA]
            if minArea < area:
                up = stats[label, cv2.CC_STAT_TOP]
                left = stats[label, cv2.CC_STAT_LEFT]
                width = stats[label, cv2.CC_STAT_WIDTH]
                height = stats[label, cv2.CC_STAT_HEIGHT]

                subthresh = np.uint8(255 * (labels[up:(up+height), left:(left+width)] == label))

                # large components are processed at reduced resolution
                scaleFactor = 1.0
                if width > self.maxComponentSize or height > self.maxComponentSize:
                    scaleFactor = max(width / self.maxComponentSize, height / self.maxComponentSize)
                    subthresh = cv2.resize(subthresh, None, fx=1/scaleFactor, fy=1/scaleFactor)

                subdist = cv2.distanceTransform(subthresh, cv2.DIST_L2, 3)

                minDistance = max(1, round(self.minparticledistance / scaleFactor))
                sure_fg = self.getSureForeground(subthresh, subdist, minDistance)
                sure_bg = cv2.dilate(subthresh, np.ones((5, 5)), iterations=1)
                if self.closeBackground:
                    sure_bg = self.closeHoles(sure_bg)

                # seedpoints add to, deletepoints remove from sure_fg and sure_bg
                self._drawPointsIntoComponent(seedpoints, 1, sure_fg, sure_bg, left, up, scaleFactor)
                self._drawPointsIntoComponent(deletepoints, 0, sure_fg, sure_bg, left, up, scaleFactor)

                if self.cancelcomputation:
                    return None, None

                if return_step == "sure_fg":
                    preview_surefg = self.addToPreviewImage(sure_fg, up, left, preview_surefg)
                    preview_surebg = self.addToPreviewImage(sure_bg, up, left, preview_surebg)
                    continue

                markers = ndi.label(sure_fg)[0]
                try:
                    markers = watershed(-subdist, markers, mask=sure_bg, compactness=self.compactness, watershed_line=True)  # labels = 0 for background, 1... for particles
                except MemoryError:
                    if self.parent is not None:
                        self.parent.raiseWarning('Segmentation failed due to large connected components.\nPlease reduce maximal connected Component Size.')
                    return None, None

                if self.cancelcomputation:
                    return None, None

                if return_step == "watershed":
                    previewImage = self.addToPreviewImage(markers, up, left, previewImage)
                    continue

                contours, hierarchy = self._findContours(markers)
                if self.cancelcomputation:
                    return None, None

                # keep only outer contours (no parent in the two-level RETR_CCOMP hierarchy)
                outerContours = [] if hierarchy is None else \
                    [cnt for cnt, h in zip(contours, hierarchy[0]) if h[3] < 0]

                for cnt in outerContours:
                    contourArea = cv2.contourArea(cnt) * scaleFactor**2
                    if not (minArea <= contourArea <= maxArea):
                        continue
                    tmplabel = markers[cnt[0, 0, 1], cnt[0, 0, 0]]
                    if tmplabel == 0:
                        continue

                    x0, x1 = cnt[:, 0, 0].min(), cnt[:, 0, 0].max()
                    y0, y1 = cnt[:, 0, 1].min(), cnt[:, 0, 1].max()

                    subimg = markers[y0:y1+1, x0:x1+1].copy()
                    subimg[subimg != tmplabel] = 0
                    y, x = self.getMeasurementPoints(subimg)

                    if scaleFactor != 1:
                        # back to full resolution
                        x0 = int(round(x0 * scaleFactor))
                        y0 = int(round(y0 * scaleFactor))
                        x = [int(round(subX * scaleFactor)) for subX in x]
                        y = [int(round(subY * scaleFactor)) for subY in y]
                        cnt = np.int32(np.round(cnt * scaleFactor))

                    # component-local -> image coordinates
                    cnt[:, 0, 0] += left
                    cnt[:, 0, 1] += up

                    finalcontours.append(cnt)
                    measurementPoints.append([MeasurementPoint(particleIndex, subX + x0 + left, subY + y0 + up)
                                              for subX, subY in zip(x, y)])
                    particleIndex += 1

            self.detectionState.emit(f'DO: newVal={label}')

        if return_step == 'sure_fg':
            preview = np.zeros_like(preview_surefg)
            preview[preview_surefg != 0] |= 1
            preview[preview_surebg != 0] |= 2
            return preview, 1
        elif return_step == 'watershed':
            return np.uint8(255*(previewImage != 0)), 0

        print("particle detection took:", time()-t0, "seconds")

        if self.measurefrac < 1.0:
            nMeasurementsDesired = int(np.round(self.measurefrac * len(measurementPoints)))
            print(f'selecting {nMeasurementsDesired} of {len(measurementPoints)} measuring spots')
            measurementPoints = random.sample(measurementPoints, nMeasurementsDesired)

        total_time = time()-t0
        print('segmentation took', total_time, 'seconds')
        total_time = round(total_time, 2)
        self.detectionState.emit(f'finished particle detection after {total_time} seconds')
        return measurementPoints, finalcontours

    def _applyThreshold(self, blur):
        """
        Binarizes the blurred image according to the activated threshold checkboxes.
        :return: (thresh, description) with thresh a uint8 image of 0/255 (inverted if
            self.invertThresh), or (None, None) if no threshold is activated
        """
        if self.activateLowThresh and not self.activateUpThresh:
            thresh = cv2.threshold(blur, int(255*self.lowThresh), 255, cv2.THRESH_BINARY)[1]
            description = "lower threshold"
        elif self.activateLowThresh and self.activateUpThresh:
            lowerLimit, upperLimit = np.round(self.lowThresh*255), np.round(self.upThresh*255)
            thresh = np.zeros_like(blur)
            thresh[(blur >= lowerLimit) & (blur <= upperLimit)] = 255
            description = "between threshold"
        elif self.activateUpThresh:
            thresh = np.zeros_like(blur)
            thresh[blur <= np.round(self.upThresh*255)] = 255
            description = "upper threshold"
        else:
            return None, None
        if self.invertThresh:
            thresh = 255-thresh
        return thresh, description

    def _drawPointsIntoComponent(self, points, value, sure_fg, sure_bg, left, up, scaleFactor):
        """
        Draws filled circles with the given value into sure_fg and sure_bg of one connected
        component. The points (x, y, radius) are in full image coordinates; the component's
        bounding box starts at (left, up) and is downscaled by scaleFactor.
        """
        for p in np.int32(points):
            # translate into the component first, then scale
            x = int(round((p[0] - left) / scaleFactor))
            y = int(round((p[1] - up) / scaleFactor))
            radius = int(round(p[2] / scaleFactor))
            cv2.circle(sure_fg, (x, y), radius, value, -1)
            cv2.circle(sure_bg, (x, y), radius, value, -1)

    def _findContours(self, image):
        """
        Returns (contours, hierarchy) of cv2.findContours (RETR_CCOMP, CHAIN_APPROX_NONE);
        OpenCV 3 returns three values, OpenCV 4 two - the last two are taken in both cases.
        """
        result = cv2.findContours(image, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
        return result[-2], result[-1]

    def getMinMaxParticleArea(self, dataset):
        """
        Converts specified particle sizes into particle areas that are used for filtering detection results.
        Size is interpreted as circle-equivalent diameter: a = pi*(d/2)**2.
        Without a dataset (None) a pixel scale of 1 is used, i.e. sizes are given in pixels.
        :return: (minArea, maxArea) in pixels; maxArea is np.inf unless enableMaxArea is set
        """
        pixelscale = 1.0 if dataset is None else dataset.getPixelScale()
        minRadius = (self.minparticlesize / pixelscale) / 2
        minArea = np.pi * minRadius**2

        if self.enableMaxArea:
            maxRadius = (self.maxparticlesize / pixelscale) / 2
            maxArea = np.pi * maxRadius**2
        else:
            maxArea = np.inf

        return minArea, maxArea

    def addToPreviewImage(self, subimg, up, left, previewImage):
        """
        Adds a subimage at given position to the previewimage
        :return: the preview image as int32 array (no copy if it already is int32)
        """
        height, width = subimg.shape[0], subimg.shape[1]
        previewImage[up:up+height, left:left+width] += subimg
        return np.asarray(previewImage, dtype=np.int32)

    def setParameters(self, **kwargs):
        """
        Parameters that were set in the parameter are updated to the classes dictionary and can later be referenced to as self.key
        :return:
        """
        for key in kwargs:
            self.__dict__[key] = kwargs[key]

    def convert2Gray(self, img):
        """Returns a single-channel image: RGB input is converted, 2D input is returned unchanged."""
        if img.ndim == 2:
            return img
        return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    def calculateHist(self, gray):
        """Returns the 256-bin histogram of a gray image."""
        hist = cv2.calcHist([gray],[0],None,[256],[0,256])
        return hist

    def calculateHistFunction(self, points):
        """
        Calculates the contrast curve (also plotted in the histogram widget): a parametric
        cubic spline through the control points, extended by (-1, 0) and (256, 255), sampled
        into a 256-entry uint8 lookup table. Gray values below/above the control points map to 0/255.
        :param points: array of (input, output) control points, shape (N, 2)
        :return: (xi, arr) with xi = 0..255 and arr the uint8 lookup table
        """
        t = np.linspace(0,1,800)
        x0 = np.concatenate(([-1.],points[:,0],[256.]))
        y0 = np.concatenate(([0.],points[:,1],[255.]))
        t0 = np.concatenate(([0.],np.cumsum(np.sqrt(np.diff(x0)**2+np.diff(y0)**2))))
        t0 /= t0[-1]
        fx = InterpolatedUnivariateSpline(t0, x0, k=3)
        fy = InterpolatedUnivariateSpline(t0, y0, k=3)
        x = fx(t)
        y = fy(t)
        arr = np.zeros(256, dtype=np.uint8)
        xi = np.arange(256)
        ind = np.searchsorted(xi, x)
        valid = ind < 256
        # clip BEFORE storing into uint8: spline overshoot beyond 0..255 would otherwise wrap around
        arr[ind[valid]] = np.clip(y[valid], 0, 255)
        arr[xi>points[:,0].max()] = 255
        arr[xi<points[:,0].min()] = 0
        return xi, arr

    def closeHoles(self, thresh):
        """
        Closes holes in a binary image
        :return: new image containing every component with its holes filled
        """
        n, labels, stats, centroids = cv2.connectedComponentsWithStats(thresh, 8, cv2.CV_32S)
        newthresh = np.zeros_like(thresh)

        for label in range(1, n):
            up = stats[label, cv2.CC_STAT_TOP]
            left = stats[label, cv2.CC_STAT_LEFT]
            width = stats[label, cv2.CC_STAT_WIDTH]
            height = stats[label, cv2.CC_STAT_HEIGHT]
            subimg = np.uint8(255 * (labels[up:(up+height), left:(left+width)] == label))
            # bitwise OR: with "+=" overlapping bounding boxes overflow uint8 (255 + 255 -> 254)
            newthresh[up:(up+height), left:(left+width)] |= closeHolesOfSubImage(subimg)

        return newthresh

    def closeBrightHoles(self, thresh, grayimage, minBrightness):
        """
        Only closes holes whose mean gray value is brighter than a given minimal brightness.
        :param thresh: binary uint8 image (0/255); modified in place
        :param grayimage: gray image of the same shape, used to measure hole brightness
        :param minBrightness: relative brightness threshold in [0, 1]
        :return: thresh with the bright holes filled
        """
        n, labels, stats, centroids = cv2.connectedComponentsWithStats(thresh, 8, cv2.CV_32S)
        minBrightness = np.uint8(minBrightness * 255)
        print('num comps in brightHoles:', n)

        for label in range(1, n):
            up = stats[label, cv2.CC_STAT_TOP]
            left = stats[label, cv2.CC_STAT_LEFT]
            width = stats[label, cv2.CC_STAT_WIDTH]
            height = stats[label, cv2.CC_STAT_HEIGHT]
            subimg = np.uint8(255 * (labels[up:(up+height), left:(left+width)] == label))

            # flood fill the background from the added border; pixels still 0 afterwards are holes
            bordered = cv2.copyMakeBorder(subimg, 1, 1, 1, 1, 0)
            floodfilled = bordered.copy()
            mask = np.zeros((bordered.shape[0]+2, bordered.shape[1]+2), np.uint8)
            cv2.floodFill(floodfilled, mask, (0, 0), 255)
            holes = np.uint8(floodfilled[1:-1, 1:-1] == 0)
            if not holes.any():
                continue

            # mean brightness of every single hole (4-connected, like the flood fill)
            numHoles, holeLabels = cv2.connectedComponents(holes, connectivity=4)
            graySub = grayimage[up:(up+height), left:(left+width)]
            sums = np.bincount(holeLabels.ravel(), weights=graySub.ravel(), minlength=numHoles)
            counts = np.bincount(holeLabels.ravel(), minlength=numHoles)
            isBright = sums / np.maximum(counts, 1) > minBrightness
            isBright[0] = False  # label 0 is the component itself and the outside, not a hole
            thresh[up:(up+height), left:(left+width)][isBright[holeLabels]] = 255

        return thresh

    def getSureForeground(self, thresh, disttransform, mindistance):
        """
        Calculates sure_fg (i.e, seedpoints) for the markerbased watershed.
        The (local) maxima of the distance transform are taken as seeds; a maximum directly
        following one closer than mindistance is dropped. If self.fuzzycluster is set, fuzzy
        c-means clustering is applied to reduce the number of considered seed points.
        :param thresh: binary component image (nonzero = particle)
        :param disttransform: distance transform of thresh
        :param mindistance: minimal distance between peaks in pixels
        :return: uint8 image with 1 at the (3x3 dilated) seed points, 0 elsewhere
        """
        def simplifyByFuzzyClustering(points, maxNumPoints=100, maxNumClusters=50):
            """
            Runs fuzzy-c-means clustering on the points to reduce the number of seed points
            :return: list of [row, col] cluster centres, or the points unchanged if there are too many
            """
            newPoints = []
            numPeaks = len(points)
            if len(points) <= maxNumPoints:
                xpts = [peak[1] for peak in points]
                ypts = [peak[0] for peak in points]
                alldata = np.vstack((ypts, xpts))

                fpcs = []
                cntrs = []
                maxNumClusters = min([maxNumClusters, numPeaks])
                for ncenters in range(2, maxNumClusters):
                    cntr, u, u0, d, jm, p, fpc = fuzz.cluster.cmeans(alldata, ncenters, 2, error=0.005, maxiter=1000, init=None)
                    fpcs.append(fpc/(ncenters**0.3))   #makes larger cluster numbers less preferred
                    cntrs.append(cntr)

                bestMatchIndex = fpcs.index(max(fpcs))
                bestMatchCentres = cntrs[bestMatchIndex]
                for point in bestMatchCentres:
                    newPoints.append([int(round(point[0])), int(round(point[1]))])
                print(f'reduced {numPeaks} to {len(newPoints)} maxima')
            else:
                newPoints = points
            return newPoints

        def sortOutTooClosePoints(points, minDistance):
            """
            The points-array is taken point-by-point and each point is only taken if it is more than
            minDistance away from its predecessor in the array. This removes directly adjacent points.
            :return: list of the kept points
            """
            lastPoint = points[0]
            fewerPoints = [points[0]]
            for point in points[1:]:  # previously points[2:], which always dropped points[1]
                if np.linalg.norm(lastPoint-point) > minDistance:
                    fewerPoints.append(point)
                lastPoint = point
            return fewerPoints

        sure_fg = np.zeros_like(thresh)

        # peak_local_max(indices=False) was removed in newer scikit-image; the returned
        # coordinates are turned into the equivalent peak mask instead.
        peakCoords = np.asarray(peak_local_max(disttransform, min_distance=mindistance, exclude_border=False), dtype=np.intp)
        localMax = np.zeros(disttransform.shape, dtype=np.uint8)
        localMax[tuple(np.transpose(peakCoords))] = 1
        localMax[disttransform == np.max(disttransform)] = 1

        maxPoints = np.where(localMax == np.max(localMax))
        maxPoints = np.transpose(np.array(maxPoints))
        if len(maxPoints) > 3:
            maxPoints = sortOutTooClosePoints(maxPoints, mindistance)

        if len(maxPoints) > 3 and self.fuzzycluster:
            clusteredPoints = simplifyByFuzzyClustering(maxPoints)
            atLeastOnePointAdded = False
            for point in clusteredPoints:
                if thresh[point[0], point[1]] != 0:
                    sure_fg[point[0], point[1]] = 1
                    atLeastOnePointAdded = True

            if not atLeastOnePointAdded:
                point = maxPoints[0]
                sure_fg[point[0], point[1]] = 1

        else:
            for point in maxPoints:
                sure_fg[point[0], point[1]] = 1

        sure_fg = cv2.dilate(sure_fg, np.ones((3, 3)))
        return sure_fg

    def getMeasurementPoints(self, binParticle, numPoints=1):
        """
        Sets coordinates for later measurement points: the maximum of the particle's distance
        transform; each further point is the maximum after removing the previous one.
        :param binParticle: particle image, any nonzero value belongs to the particle
        :return: (y, x) lists of coordinates relative to binParticle
        """
        # binarize first: watershed labels may exceed 255 and would wrap to 0 as uint8
        binParticle = cv2.copyMakeBorder(np.uint8(binParticle != 0), 1, 1, 1, 1, 0)
        dist = cv2.distanceTransform(binParticle, cv2.DIST_L2, 3)
        ind = np.argmax(dist)
        # -1 compensates the added border
        y = [ind//dist.shape[1]-1]
        x = [ind%dist.shape[1]-1]
        for _ in range(numPoints-1):
            binParticle.flat[ind] = 0
            dist = cv2.distanceTransform(binParticle, cv2.DIST_L2, 3)
            ind = np.argmax(dist)
            y.append(ind//dist.shape[1]-1)
            x.append(ind%dist.shape[1]-1)
        return y, x
    

if __name__ == '__main__':
    # Scaling benchmark: segments one image, resized to increasing sizes, and reports
    # how long each detection takes.
    # Usage: python segmentation.py <image file>
    import sys

    class _PixelScaleDataset(object):
        """Minimal dataset stand-in: a pixel scale of 1 means particle sizes are given in pixels."""
        def getPixelScale(self):
            return 1.0

    if len(sys.argv) < 2:
        sys.exit('usage: python segmentation.py <image file>')
    baseImg = cv2.imread(sys.argv[1])
    if baseImg is None:
        sys.exit(f'could not read image {sys.argv[1]}')
    # cv2.imread delivers BGR, Segmentation.convert2Gray expects RGB
    baseImg = cv2.cvtColor(baseImg, cv2.COLOR_BGR2RGB)

    seg = Segmentation()
    seg.setParameters(**{parameter.name: parameter.value for parameter in seg.parlist})

    size = 25000
    stepSize = 2000
    maxSize = 40000

    sizes, times = [], []
    while size <= maxSize:
        print('newsize =', size)
        # always resize from the original image, not from the previous (already resized) one
        img = cv2.resize(baseImg, (size, size))
        tStart = time()
        try:
            points, contours = seg.apply2Image(img, np.array([]), np.array([]), 1, _PixelScaleDataset())
        except Exception:
            print('segmentation failed at size', size)
            raise
        sizes.append(size)
        times.append(time() - tStart)
        size += stepSize

    for benchSize, benchTime in zip(sizes, times):
        print(f'size {benchSize}: {benchTime:.2f} seconds')