#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 22 13:58:25 2020

@author: luna

Unit tests for the private MP-count helpers of ``evaluation.ResultComparer``.
"""
import unittest
import random
import sys

# NOTE(review): machine-specific path used to make ``gepard`` importable;
# consider installing gepard or setting PYTHONPATH instead.
sys.path.append("C://Users//xbrjos//Desktop//Python")
import gepard
from gepard.analysis.particleAndMeasurement import Particle, Measurement

from evaluation import ResultComparer


class TestResultComparer(unittest.TestCase):
    """Tests for the MP-count error helpers of ``ResultComparer``."""

    def setUp(self):
        self.resultComparer = ResultComparer()

    def test_get_error_per_bin(self):
        """Check per-bin MP count errors for a full set vs. a half-sized subset."""
        binSizes = [5, 10, 20, 50, 100, 200, 500]
        # One particle size just below each upper limit in binSizes.
        particleSizes = [upperLimit - 1 for upperLimit in binSizes]
        numParticlesPerSizeFull = 20
        numParticlesPerSizeSub = 10
        fullParticles, subParticles = self._get_full_and_sub_particles(
            particleSizes, numParticlesPerSizeFull, numParticlesPerSizeSub)

        # assume everything was measured
        _, mpCountErrorsPerBin = self.resultComparer._get_mp_count_error_per_bin(
            fullParticles, subParticles, 1.)
        for binIndex, binError in enumerate(mpCountErrorsPerBin):
            with self.subTest(binIndex=binIndex, measuredFraction=1.):
                if binIndex < len(binSizes):
                    self.assertAlmostEqual(binError, 0.5)
                else:
                    # it's the last and largest bin, no particles were added there
                    self.assertAlmostEqual(binError, 0)

        # assume only 50 % was measured
        _, mpCountErrorsPerBin = self.resultComparer._get_mp_count_error_per_bin(
            fullParticles, subParticles, 0.5)
        for binIndex, binError in enumerate(mpCountErrorsPerBin):
            with self.subTest(binIndex=binIndex, measuredFraction=0.5):
                self.assertAlmostEqual(binError, 0)

    def test_get_number_of_MP_particles(self):
        """Only the MP particles of a mixed list are counted."""
        mpParticles = self._get_MP_particles(5)
        numMPParticles = len(mpParticles)
        nonMPparticles = self._get_non_MP_particles(50)
        allParticles = mpParticles + nonMPparticles

        calculatedNumMPParticles = self.resultComparer._get_number_of_MP_particles(allParticles)
        self.assertEqual(numMPParticles, calculatedNumMPParticles)

    def test_get_mp_count_error(self):
        """Check the MP count error between an original and an estimated particle list."""
        mpParticles1 = self._get_MP_particles(20)
        nonMPparticles1 = self._get_non_MP_particles(20)
        origParticles = mpParticles1 + nonMPparticles1

        # 30 MP particles estimated vs. 20 original ones
        mpParticles2 = self._get_MP_particles(30)
        estimateParticles = mpParticles2 + nonMPparticles1
        mpCountError = self.resultComparer._get_mp_count_error(origParticles, estimateParticles, 1.0)
        self.assertAlmostEqual(mpCountError, 0.5)

        # same number of MP particles in both lists
        mpParticles2 = self._get_MP_particles(20)
        estimateParticles = mpParticles2 + nonMPparticles1
        mpCountError = self.resultComparer._get_mp_count_error(origParticles, estimateParticles, 1.0)
        self.assertAlmostEqual(mpCountError, 0)

        mpCountError = self.resultComparer._get_mp_count_error(origParticles, estimateParticles, 0.5)
        self.assertAlmostEqual(mpCountError, 1.0)

    def test_get_error_from_values(self):
        """The error is symmetric for over- and underestimation."""
        cases = [
            (100, 90, 0.1),
            (100, 110, 0.1),
            (100, 50, 0.5),
            (100, 150, 0.5),
        ]
        for exact, estimate, expectedError in cases:
            with self.subTest(exact=exact, estimate=estimate):
                error = self.resultComparer._get_error_from_values(exact, estimate)
                self.assertAlmostEqual(error, expectedError)

    def _get_full_and_sub_particles(self, particleSizes, numParticlesPerSizeFull, numParticlesPerSizeSub):
        """Build full and sub lists of MP particles, a fixed count per given size.

        :param particleSizes: sizes assigned to both longSize and shortSize
        :param numParticlesPerSizeFull: particles per size in the full list
        :param numParticlesPerSizeSub: particles per size in the sub list
        :return: tuple (fullParticles, subParticles)
        """
        fullParticles = []
        subParticles = []
        for particleSize in particleSizes:
            fullParticles.extend(self._get_MP_particles_of_size(numParticlesPerSizeFull, particleSize))
            subParticles.extend(self._get_MP_particles_of_size(numParticlesPerSizeSub, particleSize))
        return fullParticles, subParticles

    def _get_MP_particles_of_size(self, numParticles, particleSize):
        """Return numParticles MP particles with longSize == shortSize == particleSize."""
        particles = self._get_MP_particles(numParticles)
        for particle in particles:
            particle.longSize = particle.shortSize = particleSize
        return particles

    def _get_MP_particles(self, numParticles):
        """Return a list of numParticles MP particles."""
        return [self._get_MP_particle() for _ in range(numParticles)]

    def _get_non_MP_particles(self, numParticles):
        """Return a list of numParticles particles without an assignment."""
        return [self._get_non_MP_particle() for _ in range(numParticles)]

    def _get_MP_particle(self):
        """Return a particle whose single measurement carries a random polymer assignment."""
        # NOTE(review): 'Poly (methyl methacrylate' lacks a closing parenthesis;
        # kept verbatim in case it must match a name used by evaluation — confirm.
        polymerNames = ['Poly (methyl methacrylate',
                        'Polyethylene',
                        'Silicone rubber',
                        'PB15',
                        'PY13',
                        'PR20']
        polymName = random.choice(polymerNames)
        newParticle = Particle()
        newMeas = Measurement()
        newMeas.setAssignment(polymName)
        newParticle.addMeasurement(newMeas)
        return newParticle

    def _get_non_MP_particle(self):
        """Return a particle with a single measurement that has no assignment set."""
        newParticle = Particle()
        newParticle.addMeasurement(Measurement())
        return newParticle


if __name__ == '__main__':
    unittest.main()