Commit 213419d2 authored by Josef Brandt's avatar Josef Brandt

Extract mpCountError vs fraction data

parent 39031448
......@@ -8,6 +8,8 @@ Created on Wed Jan 22 13:57:28 2020
import pickle
import sys
import os
import numpy as np
from helpers import ParticleBinSorter
import methods as meth
import geometricMethods as gmeth
......@@ -22,7 +24,7 @@ def get_name_from_directory(dirPath: str) -> str:
class TotalResults(object):
methods: list = [meth.RandomSampling, meth.SizeBinFractioning, gmeth.CrossBoxSubSampling,
gmeth.SpiralBoxSubsampling]
measuredFreactions: list = [0.1, 0.3, 0.5, 0.9]
measuredFreactions: list = [0.05, 0.1, 0.15, 0.2, 0.3, 0.5, 0.9]
def __init__(self):
super(TotalResults, self).__init__()
......@@ -49,16 +51,41 @@ class TotalResults(object):
Updates all samples with all methods and all fractions
:return:
"""
for sample in self.sampleResults:
for index, sample in enumerate(self.sampleResults):
sample.load_dataset()
for fraction in self.measuredFreactions:
possibleMethods = self._get_methods_for_fraction(sample.dataset, fraction)
for curMethod in possibleMethods:
print(f'updating {sample.sampleName} with {curMethod.label} at fraction {fraction}')
# print(f'updating {sample.sampleName} with {curMethod.label} at fraction {fraction}')
sample.update_result_with_method(curMethod)
print(f'processed {index+1} of {len(self.sampleResults)} samples')
def get_error_vs_fraction_data(self) -> dict:
return {}
"""
Returns Dict: Key: Method Label, Value: (Dict: Key:Measured Fraction, Value: averaged MPCountError over all samples)
:return:
"""
result: dict = {}
for sample in self.sampleResults:
for res in sample.results:
res: SubsamplingResult = res
label: str = res.method.label
frac: float = res.method.fraction
error: float = res.mpCountError
if label not in result.keys():
result[label] = {frac: [error]}
elif frac not in result[label].keys():
result[label][frac] = [error]
else:
result[label][frac].append(error)
for method in result.keys():
methodRes: dict = result[method]
for fraction in methodRes.keys():
methodRes[fraction] = np.mean(methodRes[fraction])
return result
def _get_methods_for_fraction(self, dataset: dataset.DataSet, fraction: float) -> list:
"""
......@@ -145,7 +172,7 @@ class SampleResult(object):
"""
isPresent: bool = False
for result in self.results:
if type(result.method) == type(method) and result.fraction == method.fraction:
if method.equals(result.method):
isPresent = True
break
return isPresent
......@@ -177,8 +204,9 @@ class SubsamplingResult(object):
fraction: float = self.method.fraction
self.mpCountError = self._get_mp_count_error(origParticles, subParticles, fraction)
print(f'{self.origParticleCount} particles, thereof {self.subSampledParticleCount} measured, error: {self.mpCountError}')
# print(f'{self.origParticleCount} particles, thereof {self.subSampledParticleCount} measured, error: {self.mpCountError}')
self.mpCountErrorPerBin = self._get_mp_count_error_per_bin(origParticles, subParticles, fraction)
# print(f'method {self.method.label} updated, result is {self.mpCountError}')
def _get_mp_count_error_per_bin(self, allParticles: list, subParticles: list, fractionMeasured: float) -> tuple:
binSorter = ParticleBinSorter()
......
import numpy as np
import matplotlib.pyplot as plt
import time
import sys
sys.path.append("C://Users//xbrjos//Desktop//Python")
from gepard import dataset
import gepardevaluation
from methods import RandomSampling, SizeBinFractioning
from geometricMethods import BoxSelectionCreator
from helpers import ParticleBinSorter
from evaluation import TotalResults, SampleResult
from input_output import get_pkls_from_directory, get_attributes_from_foldername
......@@ -27,7 +19,20 @@ for folder in pklsInFolders.keys():
for attr in get_attributes_from_foldername(folder):
newSampleResult.set_attribute(attr)
t0 = time.time()
results.update_all()
print('updating all took', time.time()-t0, 'seconds')
errorPerFraction: dict = results.get_error_vs_fraction_data()
plt.clf()
for methodLabel in errorPerFraction.keys():
fractions: list = list(errorPerFraction[methodLabel].keys())
errors: list = list(errorPerFraction[methodLabel].values())
plt.plot(fractions, errors, label=methodLabel)
plt.xscale('log')
plt.xlabel('measured fraction')
plt.ylabel('mpCountError')
plt.legend()
plt.show()
print('done')
\ No newline at end of file
......@@ -94,7 +94,56 @@ class TestTotalResults(unittest.TestCase):
self.assertFalse(containsMethod(methods, gmeth.SpiralBoxSubsampling(dset, desiredFraction)))
def test_get_error_vs_fraction_data(self):
pass
firstSample: SampleResult = self.totalResults.add_sample('sample1.pkl')
secondSample: SampleResult = self.totalResults.add_sample('sample2.pkl')
firstMethod: meth.RandomSampling = meth.RandomSampling(None, 0.1)
firstResult: SubsamplingResult = SubsamplingResult(firstMethod)
firstResult.mpCountError = 0.8
secondMethod: gmeth.CrossBoxSubSampling = gmeth.CrossBoxSubSampling(None, 0.1)
secondMethod.numBoxesAcross = 3
secondResult: SubsamplingResult = SubsamplingResult(secondMethod)
secondResult.mpCountError = 0.6
thirdMethod: gmeth.CrossBoxSubSampling = gmeth.CrossBoxSubSampling(None, 0.1)
thirdMethod.numBoxesAcross = 5
self.assertEqual(thirdMethod.fraction, 0.1)
thirdResult: SubsamplingResult = SubsamplingResult(thirdMethod)
thirdResult.mpCountError = 0.4
thirdMethod2: gmeth.CrossBoxSubSampling = gmeth.CrossBoxSubSampling(None, 0.1)
thirdMethod2.numBoxesAcross = 5
self.assertEqual(thirdMethod2.fraction, 0.1)
thirdResult2: SubsamplingResult = SubsamplingResult(thirdMethod)
thirdResult2.mpCountError = 0.8
thirdMethod3: gmeth.CrossBoxSubSampling = gmeth.CrossBoxSubSampling(None, 0.2)
thirdMethod3.numBoxesAcross = 5
self.assertEqual(thirdMethod3.fraction, 0.2)
thirdResult3: SubsamplingResult = SubsamplingResult(thirdMethod3)
thirdResult3.mpCountError = 0.5
firstSample.results = [firstResult, secondResult, thirdResult, thirdResult3]
secondSample.results = [firstResult, secondResult, thirdResult2, thirdResult3]
resultDict: dict = self.totalResults.get_error_vs_fraction_data()
self.assertEqual(list(resultDict.keys()), [firstMethod.label, secondMethod.label, thirdMethod.label])
for i in range(3):
res: dict = list(resultDict.values())[i]
if i == 0:
self.assertEqual(list(res.keys()), [0.1])
self.assertAlmostEqual(res[0.1], 0.8)
if i == 1:
self.assertEqual(list(res.keys()), [0.1])
self.assertAlmostEqual(res[0.1], 0.6)
if i == 2:
self.assertEqual(list(res.keys()), [0.1, 0.2])
self.assertAlmostEqual(res[0.1], 0.6) # i.e., mean([0.4, 0.8])
self.assertAlmostEqual(res[0.2], 0.5)
# if i == 3:
# self.assertEqual(list(res.keys()), [0.1, 0.2])
# self.assertAlmostEqual(res[0.1], )
class TestSampleResult(unittest.TestCase):
......
......@@ -34,6 +34,7 @@ class TestBoxSelector(unittest.TestCase):
self.assertEqual(newTopLefts[2], (5, -25))
self.assertEqual(newTopLefts[3], (10, 0))
class TestSelectCrossBoxes(unittest.TestCase):
def setUp(self) -> None:
self.crossBoxSubsampler = CrossBoxSubSampling(None)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment