dataset.py 12.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
# -*- coding: utf-8 -*-
"""
GEPARD - Gepard-Enabled PARticle Detection
Copyright (C) 2018  Lars Bittrich and Josef Brandt, Leibniz-Institut für 
Polymerforschung Dresden e. V. <bittrich-lars@ipfdd.de>    

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program, see COPYING.  
If not, see <https://www.gnu.org/licenses/>.
"""
import os
import pickle
import numpy as np
import cv2
25 26
import sys
from .helperfunctions import cv2imread_fix, cv2imwrite_fix
27
from copy import copy
28 29 30 31 32
from .analysis.particleContainer import ParticleContainer
from .legacyConvert import legacyConversion, currentVersion
# for legacy pickle import the old module name dataset must be found 
# (no relative import)
from . import dataset
JosefBrandt's avatar
JosefBrandt committed
33
from . import analysis
34
sys.modules['dataset'] = dataset
JosefBrandt's avatar
JosefBrandt committed
35
sys.modules['analysis'] = analysis
36 37 38 39 40 41 42 43 44 45 46

def loadData(fname):
    """Load a pickled DataSet from *fname* and upgrade it to the current version.

    The unpickled object's attributes are copied onto a freshly constructed
    DataSet, so attributes that DataSet.__init__ defines but the file lacks
    keep their default values. A file without a ``version`` attribute is
    treated as version 0; any version below ``currentVersion`` is passed
    through ``legacyConversion``.

    Security note: ``pickle.load`` can execute arbitrary code while loading.
    Only open project files from trusted sources.

    Args:
        fname: path of the ``.pkl`` project file.

    Returns:
        DataSet: the loaded data set, with ``fname``/``path``/``name`` set
        from *fname* and ``readin`` set to True.
    """
    with open(fname, "rb") as fp:
        ds = pickle.load(fp)
    # point the object at the location it was actually loaded from
    ds.fname = fname
    ds.readin = True
    ds.updatePath()

    retds = DataSet(fname)
    # default for files written before a version attribute existed;
    # overwritten below if the pickle carries its own version
    retds.version = 0
    retds.__dict__.update(ds.__dict__)
    if retds.version < currentVersion:
        legacyConversion(retds)
    return retds

def saveData(dataset, fname):
    """Pickle *dataset* to *fname* (highest pickle protocol).

    ``zvalimg`` is rather large, so it is stored separately as a tif file
    (see ``DataSet.saveZvalImg``). Only the marker string ``"saved"`` goes
    into the pickle. The in-memory image is restored afterwards, even if
    pickling fails.

    Args:
        dataset: the object to save; must have a ``zvalimg`` attribute.
        fname: destination path of the pickle file.
    """
    zvalimg = dataset.zvalimg
    if zvalimg is not None:
        dataset.zvalimg = "saved"
    try:
        with open(fname, "wb") as fp:
            pickle.dump(dataset, fp, protocol=-1)
    finally:
        # never leave the live object holding the "saved" placeholder
        dataset.zvalimg = zvalimg
60

61
def arrayCompare(a1, a2):
    """Return True if the numpy arrays *a1* and *a2* hold equal data.

    Arrays of different shape are never equal. For inexact (floating or
    complex) dtypes, NaN counts as equal to NaN, but only when both arrays
    have NaN at exactly the same positions. All other entries must compare
    equal with ``==``.

    Args:
        a1, a2: numpy arrays to compare.

    Returns:
        bool
    """
    if a1.shape != a2.shape:
        return False
    if not np.issubdtype(a1.dtype, np.inexact):
        return bool(np.all(a1 == a2))

    nan1 = np.isnan(a1)
    # a2 may have a dtype that cannot hold NaN at all (e.g. integers)
    if np.issubdtype(a2.dtype, np.inexact):
        nan2 = np.isnan(a2)
    else:
        nan2 = np.zeros(a2.shape, dtype=bool)
    # NaN in one array but a real value in the other is a difference
    if not np.array_equal(nan1, nan2):
        return False
    valid = ~nan1
    return bool(np.all(a1[valid] == a2[valid]))

def listCompare(l1, l2):
    """Return True if the lists/tuples *l1* and *l2* hold equal elements.

    Elements are compared pairwise. Numpy arrays go through arrayCompare,
    nested lists/tuples are compared recursively (list and tuple count as
    interchangeable), and everything else uses ``!=``. Two NaN scalars count
    as equal; any other unequal pair makes the lists differ.

    Args:
        l1, l2: sequences to compare.

    Returns:
        bool
    """
    def isNan(value):
        # only float-like scalars can be NaN; strings, None, dicts etc. cannot
        return isinstance(value, (float, np.floating)) and bool(np.isnan(value))

    if len(l1) != len(l2):
        return False
    for l1i, l2i in zip(l1, l2):
        if isinstance(l1i, np.ndarray):
            if not isinstance(l2i, np.ndarray) or not arrayCompare(l1i, l2i):
                return False
        elif isinstance(l1i, (list, tuple)):
            if not isinstance(l2i, (list, tuple)) or not listCompare(l1i, l2i):
                return False
        elif l1i != l2i and not (isNan(l1i) and isNan(l2i)):
            return False
    return True

def recursiveDictCompare(d1, d2):
    """Compare two (possibly nested) dicts, e.g. the ``__dict__`` of two DataSets.

    Only the keys of *d1* are checked. A key of *d1* missing from *d2* makes
    the comparison fail, but extra keys in *d2* are ignored.

    Values are compared by type:

    * ndarray    -> arrayCompare
    * dict       -> recursive call
    * list/tuple -> listCompare
    * otherwise  -> ``!=``. A difference is tolerated when either side is
      None, and two NaN scalars count as equal.

    Every visited key is printed together with the value types, and a
    message is printed for missing keys and differing values.

    Returns:
        bool: True when no difference was found.
    """
    def isNan(value):
        # only float-like scalars can be NaN
        return isinstance(value, (float, np.floating)) and bool(np.isnan(value))

    for key in d1:
        if key not in d2:
            print("key missing in d2:", key, flush=True)
            return False
        a = d1[key]
        b = d2[key]
        print(key, type(a), type(b), flush=True)
        if isinstance(a, np.ndarray):
            if not isinstance(b, np.ndarray) or not arrayCompare(a, b):
                print("data is different!", a, b)
                return False
        elif isinstance(a, dict):
            if not isinstance(b, dict):
                print("data is different!", a, b)
                return False
            if not recursiveDictCompare(a, b):
                return False
        elif isinstance(a, (list, tuple)):
            if not isinstance(b, (list, tuple)) or not listCompare(a, b):
                print("data is different!", a, b)
                return False
        elif a != b:
            # NOTE(review): a None on either side is deliberately not treated
            # as a difference (presumably unset attributes) - confirm intent
            if a is None or b is None:
                continue
            if isNan(a) and isNan(b):
                continue
            print("data is different!", a, b)
            return False
    return True

113 114 115 116
class DataSet(object):
    """State of one GEPARD project: optical-scan geometry, particle-detection
    parameters/results and Raman-scan settings.

    Instances are persisted as a pickle with saveData()/loadData(). Auxiliary
    image files live next to it (see the get*Name() methods).
    """

    def __init__(self, fname, newProject=False):
        """Create a data set for the project file *fname*.

        Args:
            fname: path of the project ``.pkl`` file.
            newProject: if True, *fname* is turned into
                ``<dir>/<name>/<name>.pkl`` via newProject(). That call
                creates the directory, or loads an existing project file
                found there.
        """
        self.fname = fname
        self.version = currentVersion

        # parameters specifically for optical scan
        self.lastpos = None
        self.maxdim = None
        self.pixelscale_df = None  # µm / pixel --> scale of DARK FIELD camera (used for image stitching)
        self.pixelscale_bf = None  # µm / pixel of BRIGHT FIELD camera (set to same as dark field, if both use the same camera)
        self.imagedim_bf = None  # width, height, angle of BRIGHT FIELD camera
        self.imagedim_df = None  # width, height, angle of DARK FIELD camera (set to same as bright field, if both use the same camera)
        self.imagescanMode = 'df'  # was the fullimage acquired in dark- or brightfield?
        self.fitpoints = []   # manually adjusted positions acquired to define the specimen geometry
        self.fitindices = []  # which of the five positions in the ui are already known
        self.boundary = []    # scan boundary computed by a circle around the fitpoints + manual adjustments
        self.grid = []        # scan grid positions for optical scan
        self.zpositions = []  # z-positions for optical scan
        self.heightmap = None
        self.zvalimg = None

        self.coordinatetransform = None  # if imported from an external source, the coordinate system may be rotated
        self.signx = 1.
        self.signy = -1.

        # parameters specifically for raman scan
        self.pshift = None  # shift of raman scan position relative to image center
        self.coordOffset = [0, 0]  # offset of entire coordinate system
        self.seedpoints = np.array([])
        self.seeddeletepoints = np.array([])
        self.detectParams = {'points': np.array([[50, 0], [100, 200], [200, 255]]),
                             'contrastcurve': True,
                             'blurRadius': 9,
                             'threshold': 0.2,
                             'maxholebrightness': 0.5,
                             'minparticlearea': 20,
                             'minparticledistance': 20,
                             'measurefrac': 1,
                             'compactness': 0.1,
                             'seedRad': 3}

        self.particleContainer = ParticleContainer(self)
        self.particleDetectionDone = False
        self.ramanscandone = False

        self.resultParams = {'minHQI': 5}
        self.colorSeed = 'default'
        self.resultsUploadedToSQL = []

        self.readin = True  # a value that is always set to True at loadData
                            # and marks that the coordinate system might have changed in the meantime
        self.mode = "prepare"
        if newProject:
            self.fname = self.newProject(fname)
        self.updatePath()

    def __eq__(self, other):
        """Compare the attribute dicts of two data sets via recursiveDictCompare.

        Comparing with a non-DataSet returns NotImplemented, so Python falls
        back to identity instead of raising AttributeError.
        """
        if not isinstance(other, DataSet):
            return NotImplemented
        return recursiveDictCompare(self.__dict__, other.__dict__)

    def getPixelScale(self, mode=None):
        """Return the pixel scale (µm / pixel) for *mode* ('df' or 'bf').

        *mode* defaults to imagescanMode. Any mode other than 'df' yields the
        bright-field scale.
        """
        if mode is None:
            mode = self.imagescanMode
        return self.pixelscale_df if mode == 'df' else self.pixelscale_bf

    def _zvalImgIsSaved(self):
        """True if zvalimg holds the "saved" placeholder written by saveData.

        The isinstance check matters: ``ndarray == 'saved'`` is not a plain
        bool on recent numpy versions.
        """
        return isinstance(self.zvalimg, str) and self.zvalimg == 'saved'

    def getZvalImg(self):
        """Return the z-value image, loading it from disk if only the placeholder is present."""
        if self._zvalImgIsSaved():
            self.loadZvalImg()
        return self.zvalimg

    def saveZvalImg(self):
        """Write the in-memory z-value image to getZvalImageName().

        Does nothing if there is no image or only the "saved" placeholder.
        """
        if self.zvalimg is not None and not self._zvalImgIsSaved():
            cv2imwrite_fix(self.getZvalImageName(), self.zvalimg)

    def loadZvalImg(self):
        """Load the z-value image (grayscale) from getZvalImageName() into zvalimg.

        Raises:
            FileNotFoundError: if the image file does not exist.
        """
        imgname = self.getZvalImageName()
        if not os.path.exists(imgname):
            raise FileNotFoundError(f"z-value image not found: {imgname}")
        self.zvalimg = cv2imread_fix(imgname, cv2.IMREAD_GRAYSCALE)
        if self.zvalimg is None:
            # presumably the read failed; report which file was involved
            print(imgname)

    def getZval(self, pixelpos):
        """Return the z-position at pixel position *pixelpos* (x, y) of the full image.

        The gray value zp of the z-value image is mapped linearly:
        0 -> zpositions[0] and 255 -> zpositions[-1]. Out-of-range pixel
        indices are reported and clamped to the image border.
        """
        zvalimg = self.getZvalImg()
        assert zvalimg is not None
        i, j = int(round(pixelpos[1])), int(round(pixelpos[0]))
        rows, cols = zvalimg.shape[0], zvalimg.shape[1]
        # clamp both ends: negative indices would otherwise wrap to the opposite edge
        if not 0 <= i < rows:
            print('error in getZval:', zvalimg.shape, i, j)
            i = min(max(i, 0), rows - 1)
        if not 0 <= j < cols:
            print('error in getZval:', zvalimg.shape, i, j)
            j = min(max(j, 0), cols - 1)
        zp = zvalimg[i, j]
        z0, z1 = self.zpositions[0], self.zpositions[-1]
        return zp/255.*(z1 - z0) + z0

    def mapHeight(self, x, y):
        """Evaluate the plane ``heightmap[0]*x + heightmap[1]*y + heightmap[2]`` at (x, y)."""
        assert not self.readin
        assert self.heightmap is not None
        return self.heightmap[0]*x + self.heightmap[1]*y + self.heightmap[2]

    def _cameraGeometry(self, mode, caller):
        """Return (imagedim, pixelscale) of the dark-field ('df') or bright-field ('bf') camera.

        Raises:
            ValueError: for any other *mode*; the message names *caller*.
        """
        if mode == 'df':
            return self.imagedim_df, self.pixelscale_df
        if mode == 'bf':
            return self.imagedim_bf, self.pixelscale_bf
        raise ValueError(f'{caller} mode: {mode} not understood')

    def mapToPixel(self, p, mode='df', force=False):
        """Convert a position *p* (x, y[, z]) into pixel coordinates of the full image.

        Args:
            p: position in the same units as lastpos. A z component is only
                used when coordinatetransform is set (it defaults to 0).
            mode: 'df' or 'bf', selecting which camera's image dimensions and
                pixel scale to use.
            force: skip the assertion that readin is False.

        Returns:
            (x, y) pixel coordinates.

        Raises:
            ValueError: for an unknown *mode*.
        """
        if not force:
            assert not self.readin
        p0 = copy(self.lastpos)

        if self.coordinatetransform is not None:
            z = 0. if len(p) < 3 else p[2]
            T, pc = self.coordinatetransform
            p = np.dot(np.array([p[0], p[1], z]) - pc, T.T)

        # NOTE(review): unlike mapToLength, coordOffset is not applied here, so
        # the two are not exact inverses when coordOffset != [0, 0] - confirm intent
        imagedim, pixelscale = self._cameraGeometry(mode, 'mapToPixel')
        # p0 becomes the image corner corresponding to pixel (0, 0)
        p0[0] -= self.signx*imagedim[0]/2
        p0[1] -= self.signy*imagedim[1]/2
        x = self.signx*(p[0] - p0[0])/pixelscale
        y = self.signy*(p[1] - p0[1])/pixelscale
        return x, y

    def mapToLength(self, pixelpos, mode='df', force=False, returnz=False):
        """Convert pixel coordinates of the full image into length coordinates.

        The result is in the units of lastpos/pixelscale (pixelscale is
        documented as µm / pixel). coordOffset is added to the reference
        position.

        Args:
            pixelpos: (x, y) pixel position.
            mode: 'df' or 'bf', selecting which camera's image dimensions and
                pixel scale to use.
            force: skip the assertion that readin is False.
            returnz: also return a z value computed from the height plane
                plus the z-value image (None if no z-value image is set).

        Returns:
            (x, y), or (x, y, z) when *returnz* is True.

        Raises:
            ValueError: for an unknown *mode*.
        """
        if not force:
            assert not self.readin
        p0 = copy(self.lastpos)
        p0[0] += self.coordOffset[0]
        p0[1] += self.coordOffset[1]

        imagedim, pixelscale = self._cameraGeometry(mode, 'mapToLength')
        # p0 becomes the image corner corresponding to pixel (0, 0)
        p0[0] -= self.signx*imagedim[0]/2
        p0[1] -= self.signy*imagedim[1]/2
        x = self.signx*pixelpos[0]*pixelscale + p0[0]
        y = p0[1] + self.signy*pixelpos[1]*pixelscale

        z = None
        # a coordinate transform needs z even if the caller did not request it
        if (returnz and self.zvalimg is not None) or self.coordinatetransform is not None:
            z = self.mapHeight(x, y)
            z += self.getZval(pixelpos)

        if self.coordinatetransform is not None:
            T, pc = self.coordinatetransform
            x, y, z = np.dot(np.array([x, y, z]), T) + pc

        if returnz:
            return x, y, z
        return x, y

    def mapToLengthRaman(self, pixelpos, microscopeMode='df', noz=False):
        """Like mapToLength(..., returnz=True), but with x/y shifted by pshift.

        *noz* is accepted for interface compatibility but is not used.
        """
        p0x, p0y, z = self.mapToLength(pixelpos, mode=microscopeMode, returnz=True)
        x, y = p0x + self.pshift[0], p0y + self.pshift[1]
        return x, y, z

    def newProject(self, fname):
        """Return the project file path ``<dir>/<name>/<name>.pkl`` for *fname*.

        The directory ``<dir>/<name>`` is created if missing. If it already
        exists and contains the project file, that project is loaded into
        this instance.
        """
        path = os.path.split(fname)[0]
        name = os.path.splitext(os.path.basename(fname))[0]
        newpath = os.path.join(path, name)
        fname = os.path.join(newpath, name + ".pkl")
        if not os.path.exists(newpath):
            os.mkdir(newpath)        # for new projects a directory will be created
        elif os.path.exists(fname):  # if this project is already there, load it instead
            self.__dict__.update(loadData(fname).__dict__)
        return fname

    def getScanPath(self):
        """Return the ``scanimages`` directory inside the project path, creating it if needed."""
        scandir = os.path.join(self.path, "scanimages")
        os.makedirs(scandir, exist_ok=True)
        return scandir

    def updatePath(self):
        """Derive ``path`` (directory) and ``name`` (base name without extension) from ``fname``."""
        self.path = os.path.split(self.fname)[0]
        self.name = os.path.splitext(os.path.basename(self.fname))[0]

    def getSpectraFileName(self):
        """Path of the spectra ``.npy`` file inside the project directory."""
        return os.path.join(self.path, 'spectra.npy')

    def getImageName(self):
        """Path of the full (stitched) image tif inside the project directory."""
        return os.path.join(self.path, 'fullimage.tif')

    def getZvalImageName(self):
        """Path of the z-value image tif inside the project directory."""
        return os.path.join(self.path, "zvalues.tif")

    def getLegacyImageName(self):
        """Path of the full image under its older png name."""
        return os.path.join(self.path, "fullimage.png")

    def getBackgroundImageName(self):
        """Path of the background bmp inside the project directory."""
        return os.path.join(self.path, "background.bmp")

    def getTmpImageName(self):
        """Path of the temporary bmp inside the project directory."""
        return os.path.join(self.path, "tmp.bmp")

    def save(self):
        """Pickle this data set to its own ``fname``."""
        saveData(self, self.fname)

    def saveBackup(self):
        """Save a copy as ``<name>_backup_<n>.pkl`` next to ``fname``.

        Uses the first free n (starting at 0).

        Returns:
            str: the backup file name (without directory).
        """
        directory = os.path.dirname(self.fname)
        inc = 0
        while True:
            filename = self.name + '_backup_' + str(inc) + '.pkl'
            path = os.path.join(directory, filename)
            if not os.path.exists(path):
                saveData(self, path)
                return filename
            inc += 1