dataset.py 15.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
# -*- coding: utf-8 -*-
"""
GEPARD - Gepard-Enabled PARticle Detection
Copyright (C) 2018  Lars Bittrich and Josef Brandt, Leibniz-Institut für 
Polymerforschung Dresden e. V. <bittrich-lars@ipfdd.de>    

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program, see COPYING.  
If not, see <https://www.gnu.org/licenses/>.
"""
import os
import pickle
import numpy as np
import cv2
from helperfunctions import cv2imread_fix, cv2imwrite_fix
from copy import copy

# Version number of the DataSet pickle format. loadData() runs
# DataSet.legacyConversion() on files with a smaller version number.
currentversion = 2

def loadData(fname):
    """Load a pickled DataSet from *fname* and bring it to the current format.

    The attributes of the unpickled object are copied onto a freshly
    constructed DataSet, so attributes introduced in newer versions keep their
    default values when an older file does not contain them. Files without a
    ``version`` attribute are treated as version 0 and converted via
    ``legacyConversion``.

    NOTE: this uses ``pickle.load``; only open project files from trusted
    sources.

    :param fname: path of the pickled project file
    :returns: the loaded DataSet
    """
    with open(fname, "rb") as fp:
        ds = pickle.load(fp)
    ds.fname = fname
    ds.readin = True
    ds.updatePath()
    retds = DataSet(fname)
    retds.version = 0  # files written before versioning have no 'version' attribute
    retds.__dict__.update(ds.__dict__)
    if retds.version < currentversion:
        retds.legacyConversion()
    # saveData() replaces the z-value image by the marker string "saved" (the
    # image lives in a separate tif file). Check this also after a legacy
    # conversion: the conversion does not reload the image itself, so the
    # marker string would otherwise remain in place of the image.
    zvalimg = retds.zvalimg
    if isinstance(zvalimg, str) and zvalimg == "saved":
        retds.loadZvalImg()
    return retds

def saveData(dataset, fname):
    """Pickle *dataset* to *fname* without embedding its z-value image.

    The z-value image is rather large, so it is kept in a separate tif file
    (written by ``DataSet.saveZvalImg``) and the pickle only stores the marker
    string ``"saved"`` in its place. The in-memory ``dataset.zvalimg`` is
    restored afterwards, also when pickling fails.

    :param dataset: object with a ``zvalimg`` attribute (normally a DataSet)
    :param fname: destination path of the pickle file
    """
    zvalimg = dataset.zvalimg
    if zvalimg is not None:
        dataset.zvalimg = "saved"
    try:
        with open(fname, "wb") as fp:
            pickle.dump(dataset, fp, protocol=-1)
    finally:
        # always put the real image back; otherwise a failing dump would leave
        # the live dataset holding the placeholder string instead of the image
        dataset.zvalimg = zvalimg

def arrayCompare(a1, a2):
    """Return True if the arrays *a1* and *a2* have equal shape and values.

    NaN entries count as equal to each other, but only when both arrays hold
    NaN at exactly the same positions, so the comparison is symmetric.
    """
    if a1.shape != a2.shape:
        return False
    if not np.issubdtype(a1.dtype, np.inexact):
        # integer/bool/string/object arrays cannot contain NaN
        return bool(np.all(a1 == a2))
    nan1 = np.isnan(a1)
    if not np.any(nan1):
        return bool(np.all(a1 == a2))
    # a1 contains NaN: a2 must contain NaN at the very same positions
    if not np.issubdtype(a2.dtype, np.inexact):
        return False
    if not np.array_equal(nan1, np.isnan(a2)):
        return False
    return bool(np.all(a1[~nan1] == a2[~nan1]))

def _isNaN(value):
    """Return True if *value* is a NaN scalar; False for values np.isnan rejects."""
    try:
        return bool(np.isnan(value))
    except (TypeError, ValueError):
        # strings, None and other non-numeric values are never NaN
        return False


def listCompare(l1, l2):
    """Compare two lists/tuples element-wise, recursing into nested sequences.

    NumPy arrays are compared with arrayCompare; two NaN scalars count as
    equal. Returns True if no difference was found.
    """
    if len(l1) != len(l2):
        return False
    for l1i, l2i in zip(l1, l2):
        if isinstance(l1i, np.ndarray):
            if not isinstance(l2i, np.ndarray) or not arrayCompare(l1i, l2i):
                return False
        elif isinstance(l1i, (list, tuple)):
            if not isinstance(l2i, (list, tuple)) or not listCompare(l1i, l2i):
                return False
        elif l1i != l2i and not (_isNaN(l1i) and _isNaN(l2i)):
            return False
    return True

def recursiveDictCompare(d1, d2):
    """Compare two (nested) dicts value by value and report the first difference.

    Only the keys of *d1* are checked, so extra keys in *d2* are ignored.
    NumPy arrays, nested dicts and lists/tuples are compared with the
    dedicated helpers. For all other values a difference is tolerated when
    either side is None.

    :returns: True if no difference was found, else False
    """
    for key in d1:
        if key not in d2:
            print("key missing in d2:", key, flush=True)
            return False
        a = d1[key]
        b = d2[key]
        if isinstance(a, np.ndarray):
            if not isinstance(b, np.ndarray) or not arrayCompare(a, b):
                print("data is different!", a, b)
                return False
        elif isinstance(a, dict):
            if not isinstance(b, dict):
                print("data is different!", a, b)
                return False
            if not recursiveDictCompare(a, b):
                return False
        elif isinstance(a, (list, tuple)):
            if not isinstance(b, (list, tuple)) or not listCompare(a, b):
                print("data is different!", a, b)
                return False
        elif isinstance(b, np.ndarray):
            # 'a != b' would broadcast to an array here and raise when used as
            # a condition, so a non-array vs. array pair is handled explicitly
            if a is not None:
                print("data is different!", a, b)
                return False
        elif a != b:
            if (a is not None) and (b is not None):
                print("data is different!", a, b)
                return False
    return True

class DataSet(object):
    """State of one GEPARD project.

    Holds the optical-scan parameters, the particle-detection settings and
    results, and the Raman-scan results. Provides the mappings between image
    pixels and stage lengths. A DataSet is persisted with saveData/loadData as
    a pickle file; the directory of that file also holds the full image
    (fullimage.tif) and the z-value image (zvalues.tif).
    """

    def __init__(self, fname, newProject=False):
        """Create a DataSet for the pickle file *fname*.

        :param fname: path of the project pickle file
        :param newProject: if True, ``newProject`` derives the project file
            name (and directory) from *fname*, loading an existing project
            there if present
        """
        self.fname = fname
        # parameters specifically for optical scan
        self.version = currentversion
        self.lastpos = None
        self.maxdim = None
        self.pixelscale_df = None # µm / pixel --> scale of DARK FIELD camera (used for image stitching)
        self.pixelscale_bf = None # µm / pixel of BRIGHT FIELD camera (set to same as dark field, if both use the same camera)
        self.imagedim_bf = None  # width, height, angle of BRIGHT FIELD camera
        self.imagedim_df = None  # width, height, angle of DARK FIELD camera (set to same as bright field, if both use the same camera)
        self.imagescanMode = 'df'    # was the fullimage acquired in dark- or brightfield?
        self.fitpoints = []   # manually adjusted positions acquired to define the specimen geometry
        self.fitindices = []  # which of the five positions in the ui are already known
        self.boundary = []    # scan boundary computed by a circle around the fitpoints + manual adjustments
        self.grid = []        # scan grid positions for optical scan
        self.zpositions = []  # z-positions for optical scan
        self.heightmap = None
        self.zvalimg = None
        
        # parameters specifically for raman scan
        self.pshift = None    # shift of raman scan position relative to image center
        self.coordOffset = [0, 0]   # offset of entire coordinate system
        self.seedpoints = np.array([])
        self.seeddeletepoints = np.array([])
        self.detectParams = {'points': np.array([[50,0],[100,200],[200,255]]),
                             'contrastcurve': True,
                             'blurRadius': 9,
                             'threshold': 0.2,
                             'maxholebrightness': 0.5,
                             'erodeconvexdefects': 0,
                             'minparticlearea': 20,
                             'minparticledistance': 20,
                             'measurefrac': 1,
                             'compactness': 0.1,
                             'seedRad': 3}
        
        self.ramanpoints = []
        self.particlecontours = []
        self.particlestats = []
        self.ramanscansortindex = None
        self.ramanscandone = False
        
        self.results = {'polymers': None,
                        'hqis': None,
                        'additives': None,
                        'additive_hqis': None}

        self.resultParams = {'minHQI': None,
                             'compHQI': None}
        self.spectraPath = None
        self.particles2spectra = None    # links idParticle to corresponding idSpectra (i.e., first measured particle (ID=0) is linked to spectra indices 0 and 1)
        self.colorSeed = 'default'
        self.resultsUploadedToSQL = []
        
        # Always set to True by loadData to mark that the coordinate system
        # might have changed in the meantime; mapHeight refuses to work and
        # mapToPixel/mapToLength require force=True while it is set.
        self.readin = True
        self.mode = "prepare"
        if newProject:
            self.fname = self.newProject(fname)
        self.updatePath()
        
    def __eq__(self, other):
        """Compare all attributes with another DataSet (see recursiveDictCompare)."""
        if not isinstance(other, DataSet):
            # e.g. 'ds == None' must not fail on a missing __dict__
            return NotImplemented
        return recursiveDictCompare(self.__dict__, other.__dict__)
        
    def getPixelScale(self, mode=None):
        """Return µm / pixel for *mode* ('df' or anything else for bright field).

        Defaults to the mode the full image was acquired in.
        """
        if mode is None:
            mode = self.imagescanMode
        return (self.pixelscale_df if mode == 'df' else self.pixelscale_bf)
        
    def saveZvalImg(self):
        """Write the z-value image to zvalues.tif in the project directory, if present."""
        if self.zvalimg is not None:
            cv2imwrite_fix(self.getZvalImageName(), self.zvalimg)
            
    def loadZvalImg(self):
        """Load the z-value image as grayscale from zvalues.tif, if the file exists."""
        if os.path.exists(self.getZvalImageName()):
            self.zvalimg = cv2imread_fix(self.getZvalImageName(), cv2.IMREAD_GRAYSCALE)
        
    def legacyConversion(self, recreatefullimage=False):
        """Upgrade a dataset loaded from an older file version to the current one.

        Version 0 -> 1: make sure fullimage.tif exists. It is copied from the
        legacy fullimage.png unless that file is missing, recreation is
        requested, or the png seems to have detection contours drawn into it
        (pure green pixels along a contour). In those cases the image is
        re-stitched from the stored scan images. The z-value image is written,
        and the obsolete 'particleimgs' attribute is dropped.

        Version 1 -> 2: split the single 'pixelscale'/'imagedim' attributes
        into their dark field and bright field variants.

        :param recreatefullimage: force re-stitching the full image in the 0 -> 1 step
        """
        if self.version==0:
            print("Converting legacy version 0 to 1")
            print("This may take some time")
            
            # local imports as these functions are only needed for the rare occasion of legacy conversion
            from opticalscan import loadAndPasteImage
            
            # try to load png and check for detection contours
            recreatefullimage = recreatefullimage or not os.path.exists(self.getLegacyImageName())
            if not recreatefullimage:
                img = cv2imread_fix(self.getLegacyImageName())
                Nc = len(self.particlecontours)
                if Nc>0:
                    contour = self.particlecontours[Nc//2]
                    contpixels = img[contour[:,0,1],contour[:,0,0]]
                    if np.all(contpixels[:,1]==255) and np.all(contpixels[:,2]==0) \
                        and np.all(contpixels[:,0]==0):
                        recreatefullimage = True
                if not recreatefullimage:
                    cv2imwrite_fix(self.getImageName(), img)
                del img
            
            if recreatefullimage:
                print("recreating fullimage from grid data")
                imgdata = None
                zvalimg = None
                Ngrid = len(self.grid)
                
                # Version-0 files store the camera geometry as 'imagedim';
                # 'imagedim_df' is only filled in by the 1 -> 2 step below and
                # is still None (DataSet default) at this point.
                imagedim = getattr(self, 'imagedim', None)
                if imagedim is None:
                    imagedim = self.imagedim_df
                width, height, rotationvalue = imagedim
                p0, p1 = self.maxdim[:2], self.maxdim[2:]
                for i in range(Ngrid):
                    print(f"Processing image {i+1} of {Ngrid}")
                    names = []
                    for k in range(len(self.zpositions)):
                        names.append(os.path.join(self.getScanPath(), f"image_{i}_{k}.bmp"))
                    p = self.grid[i]
                    imgdata, zvalimg = loadAndPasteImage(names, imgdata, zvalimg, width, 
                                                            height, rotationvalue, p0, p1, p)
                self.zvalimg = zvalimg
                cv2imwrite_fix(self.getImageName(), cv2.cvtColor(imgdata, cv2.COLOR_RGB2BGR))
                del imgdata
            self.saveZvalImg()
            if "particleimgs" in self.__dict__:
                del self.particleimgs
            
            self.version = 1
            
        if self.version == 1:
            print("Converting legacy version 1 to 2")
            if hasattr(self, 'pixelscale'):
                print('pixelscale was', self.pixelscale)
                self.pixelscale_bf = self.pixelscale
                self.pixelscale_df = self.pixelscale
                del self.pixelscale
            
            if hasattr(self, 'imagedim'):
                self.imagedim_bf = self.imagedim
                self.imagedim_df = self.imagedim
                del self.imagedim

            self.version = 2
            
        # add later conversion for higher version numbers here
        
    def getSubImage(self, img, index, draw=True):
        """Return a copy of the bounding-box region of particle *index* in *img*.

        If *draw* is True, the particle contour is drawn into the copy in
        green with 1 px line width.
        """
        contour = self.particlecontours[index]
        x0, x1 = contour[:,0,0].min(), contour[:,0,0].max()
        y0, y1 = contour[:,0,1].min(), contour[:,0,1].max()
        subimg = img[y0:y1+1,x0:x1+1].copy()
        if draw:
            cv2.drawContours(subimg, [contour], -1, (0,255,0), 1)
        return subimg
        
    def getZval(self, pixelpos):
        """Return the z-position stored in the z-value image at pixel (x, y).

        The 8-bit image value is scaled linearly from 0..255 onto the range of
        the optical-scan z-positions.
        """
        assert self.zvalimg is not None
        # int() so that NumPy float coordinates always yield integer indices
        zp = self.zvalimg[int(round(pixelpos[1])), int(round(pixelpos[0]))]
        # np.min/np.max also work while zpositions is still a plain list
        z0, z1 = np.min(self.zpositions), np.max(self.zpositions)
        return zp/255.*(z1-z0) + z0
        
    def mapHeight(self, x, y):
        """Evaluate the height-map plane a*x + b*y + c at stage position (x, y)."""
        assert not self.readin
        assert self.heightmap is not None
        return self.heightmap[0]*x + self.heightmap[1]*y + self.heightmap[2]
        
    def _getCameraGeometry(self, mode, caller):
        """Return (imagedim, pixelscale) of the 'df' or 'bf' camera.

        :raises ValueError: for any other *mode*; *caller* names the method in the message
        """
        if mode == 'df':
            return self.imagedim_df, self.pixelscale_df
        if mode == 'bf':
            return self.imagedim_bf, self.pixelscale_bf
        raise ValueError(f'{caller} mode: {mode} not understood')

    def mapToPixel(self, p, mode='df', force=False):
        """Map a stage position *p* (x, y in length units) to image pixel coordinates.

        :param mode: camera geometry to use, 'df' or 'bf'
        :param force: skip the check that the coordinate system is valid (readin)
        :raises ValueError: for an unknown *mode*
        """
        if not force:
            assert not self.readin
        imagedim, pixelscale = self._getCameraGeometry(mode, 'mapToPixel')
        # NOTE(review): unlike mapToLength, coordOffset is not applied here, so
        # the two mappings are only inverse to each other while coordOffset is
        # [0, 0] — confirm whether this is intended.
        x0 = self.lastpos[0] - imagedim[0]/2
        y0 = self.lastpos[1] + imagedim[1]/2
        return (p[0] - x0)/pixelscale, (y0 - p[1])/pixelscale
    
    def mapToLength(self, pixelpos, mode='df', force=False):
        """Map image pixel coordinates *pixelpos* (x, y) to a stage position.

        The result includes the global coordOffset.

        :param mode: camera geometry to use, 'df' or 'bf'
        :param force: skip the check that the coordinate system is valid (readin)
        :raises ValueError: for an unknown *mode*
        """
        if not force:
            assert not self.readin
        imagedim, pixelscale = self._getCameraGeometry(mode, 'mapToLength')
        x0 = self.lastpos[0] + self.coordOffset[0] - imagedim[0]/2
        y0 = self.lastpos[1] + self.coordOffset[1] + imagedim[1]/2
        return (pixelpos[0]*pixelscale + x0), (y0 - pixelpos[1]*pixelscale)
    
    def mapToLengthRaman(self, pixelpos, microscopeMode='df', noz=False):
        """Map pixel coordinates to a Raman measurement position (x, y, z).

        x and y are shifted by pshift. z is taken from the height map plus the
        z-value image, or None if *noz* is True.
        """
        p0x, p0y = self.mapToLength(pixelpos, mode = microscopeMode)
        x, y = p0x + self.pshift[0], p0y + self.pshift[1]
        z = None
        if not noz:
            z = self.mapHeight(x, y)
            z += self.getZval(pixelpos)
        return x, y, z
        
    def newProject(self, fname):
        """Return the project file path '<dir>/<name>/<name>.pkl' derived from *fname*.

        Creates the project directory if it does not exist. If it exists and
        already contains that project file, the stored project is loaded into
        this instance.
        """
        path = os.path.split(fname)[0]
        name = os.path.splitext(os.path.basename(fname))[0]
        newpath = os.path.join(path, name)
        fname = os.path.join(newpath, name + ".pkl")
        if not os.path.exists(newpath):
            os.mkdir(newpath)        # for new projects a directory will be created
        elif os.path.exists(fname):  # if this project is already there, load it instead
            self.__dict__.update(loadData(fname).__dict__)
        return fname
    
    def getScanPath(self):
        """Return the 'scanimages' subdirectory of the project, creating it if needed."""
        scandir = os.path.join(self.path, "scanimages")
        os.makedirs(scandir, exist_ok=True)
        return scandir
        
    def updatePath(self):
        """Derive the project directory (path) and base name (name) from fname."""
        self.path = os.path.split(self.fname)[0]
        self.name = os.path.splitext(os.path.basename(self.fname))[0]
        
    def getImageName(self):
        """Return the path of the stitched full image."""
        return os.path.join(self.path, 'fullimage.tif')

    def getZvalImageName(self):
        """Return the path of the z-value image."""
        return os.path.join(self.path, "zvalues.tif")
    
    def getLegacyImageName(self):
        """Return the path of the full image used by version-0 projects."""
        return os.path.join(self.path, "fullimage.png")
    
    def getLegacyDetectImageName(self):
        """Return the path of the legacy detection image."""
        return os.path.join(self.path, "detectimage.png")
    
    def getDetectImageName(self):
        raise NotImplementedError("No longer implemented due to change in API")
    
    def getTmpImageName(self):
        """Return the path of the temporary image file."""
        return os.path.join(self.path, "tmp.bmp")
    
    def saveParticleData(self):
        """Currently disabled: only prints a notice, no file is written."""
        print('Not saving ParticleData into text file...:\nThe current output format might be wrong, if multiple spectra per particle are present...')
#        if len(self.ramanscansortindex)>0:
#            data = []
#            pixelscale = (self.pixelscale_df if self.imagescanMode == 'df' else self.pixelscale_bf)
#            for i in self.ramanscansortindex:
#                data.append(list(self.ramanpoints[i])+list(self.particlestats[i]))
#            data = np.array(data)
#            data[:,0], data[:,1], z = self.mapToLengthRaman((data[:,0], data[:,1]), microscopeMode=self.imagescanMode, noz=True)
#            data[:,2:7] *= pixelscale
#            header = "x [µm], y [µm], length [µm], height [µm], length_ellipse [µm], height_ellipse [µm]"
#            if data.shape[1]>6:
#                header = header + ", area [µm^2]"
#                data[:,6] *= pixelscale
#            np.savetxt(os.path.join(self.path, "particledata.txt"), data, 
#                       header=header)
            
    def save(self):
        """Save the dataset to its own project file (fname)."""
        saveData(self, self.fname)
    
    def saveBackup(self):
        """Save a copy as '<name>_backup_<n>.pkl' using the first unused n.

        :returns: the file name (without directory) of the written backup
        """
        directory = os.path.dirname(self.fname)
        inc = 0
        while True:
            filename = self.name + '_backup_' + str(inc) + '.pkl'
            path = os.path.join(directory, filename)
            if not os.path.exists(path):
                saveData(self, path)
                return filename
            inc += 1