-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathscikit_image.py
117 lines (97 loc) · 4.11 KB
/
scikit_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import types
from generic_utils import has_arg
from skimage.segmentation import felzenszwalb, slic, quickshift
class BaseWrapper(object):
    """Base class for LIME Scikit-Image wrappers.

    Stores a target callable together with the keyword parameters that
    should be forwarded to it, and offers helpers to validate and filter
    those parameters against a function's arguments.

    Args:
        target_fn: callable function or class instance. May be None, in
            which case the wrapper itself must be callable for parameter
            validation to work.
        target_params: dict, parameters to pass to the target_fn, i.e. the
            parameters required to instantiate/call the desired
            Scikit-Image class/model.
    """

    def __init__(self, target_fn=None, **target_params):
        self.target_fn = target_fn
        self.target_params = target_params

    def _check_params(self, parameters):
        """Checks for mistakes in 'parameters'.

        Every name in 'parameters' is tested with `has_arg` against the
        callable this wrapper targets.

        Args:
            parameters: dict (or other iterable of names), parameters to be
                checked

        Raises:
            ValueError: if any parameter is not a valid argument for the
                target function
            TypeError: if 'parameters' is a string or is not iterable, or if
                no target_fn is defined and the wrapper is not callable
        """
        if self.target_fn is None:
            # No explicit target: validate against the wrapper's own
            # __call__ when it defines one.
            if not callable(self):
                raise TypeError(
                    'invalid argument: tested object is not callable, '
                    'please provide a valid target_fn')
            fn = self.__call__
        elif isinstance(self.target_fn,
                        (types.FunctionType, types.MethodType)):
            fn = self.target_fn
        else:
            # Any other object (e.g. a class instance) is inspected through
            # its __call__ attribute.
            fn = self.target_fn.__call__

        # A string is iterable, but iterating it would check single
        # characters as parameter names, so it is rejected explicitly.
        if isinstance(parameters, str):
            raise TypeError('invalid argument: list or dictionnary expected')

        for p in parameters:
            if not has_arg(fn, p):
                raise ValueError('{} is not a valid parameter'.format(p))

    def set_params(self, **params):
        """Sets the parameters of this estimator.

        The given parameters replace (they are not merged into) the
        previously stored 'target_params'.

        Args:
            **params: Dictionary of parameter names mapped to their values.

        Raises:
            ValueError: if any parameter is not a valid argument
                for the target function
        """
        self._check_params(params)
        self.target_params = params

    def filter_params(self, fn, override=None):
        """Filters `target_params` and returns those in `fn`'s arguments.

        Args:
            fn: arbitrary function
            override: dict, values to override target_params. These entries
                are added to the result without being checked against `fn`.

        Returns:
            result: dict, dictionary containing variables
                in both target_params and fn's arguments, updated with
                'override'.
        """
        result = {
            name: value
            for name, value in self.target_params.items()
            if has_arg(fn, name)
        }
        result.update(override or {})
        return result
class SegmentationAlgorithm(BaseWrapper):
    """Image segmentation function based on a Scikit-Image implementation
    and a set of provided parameters.

    Args:
        algo_type: string, segmentation algorithm among the following:
            'quickshift', 'slic', 'felzenszwalb'
        target_params: dict, algorithm parameters (valid model parameters
            as defined in Scikit-Image documentation). Parameters that the
            selected algorithm does not accept are dropped.

    Raises:
        ValueError: if 'algo_type' is not one of the supported algorithms.
    """

    def __init__(self, algo_type, **target_params):
        self.algo_type = algo_type
        if algo_type == 'quickshift':
            segmentation_fn = quickshift
        elif algo_type == 'felzenszwalb':
            segmentation_fn = felzenszwalb
        elif algo_type == 'slic':
            segmentation_fn = slic
        else:
            # Fail at construction time: otherwise the instance would have
            # no target_fn and only break later, inside __call__.
            raise ValueError(
                "algo_type must be one of 'quickshift', 'slic', "
                "'felzenszwalb'; got {!r}".format(algo_type))

        BaseWrapper.__init__(self, segmentation_fn, **target_params)
        # Keep only the parameters accepted by the selected algorithm.
        kwargs = self.filter_params(segmentation_fn)
        self.set_params(**kwargs)

    def __call__(self, *args):
        """Runs the selected segmentation algorithm.

        Args:
            *args: the first positional argument is passed to the
                segmentation function; any further positional arguments
                are ignored.

        Returns:
            Whatever the wrapped segmentation function returns when called
            with the first argument and the stored 'target_params'.

        Raises:
            IndexError: if called without any positional argument.
        """
        return self.target_fn(args[0], **self.target_params)