diff options
| author | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-06 13:26:27 -0500 |
|---|---|---|
| committer | Gertjan van den Burg <gertjanvandenburg@gmail.com> | 2019-03-06 13:26:27 -0500 |
| commit | 10a77e16a4176d6f015d408dd2d162fc922ebff4 (patch) | |
| tree | 6f3a6be596fbeed6bf80537827197344a52d9e9e /gensvm | |
| parent | update submodule (diff) | |
| download | pygensvm-10a77e16a4176d6f015d408dd2d162fc922ebff4.tar.gz pygensvm-10a77e16a4176d6f015d408dd2d162fc922ebff4.zip | |
Add predefined parameter grids
Diffstat (limited to 'gensvm')
| -rw-r--r-- | gensvm/gridsearch.py | 120 |
1 file changed, 113 insertions, 7 deletions
diff --git a/gensvm/gridsearch.py b/gensvm/gridsearch.py index 3fdfa9a..ed12b97 100644 --- a/gensvm/gridsearch.py +++ b/gensvm/gridsearch.py @@ -284,10 +284,15 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): Parameters ---------- - param_grid : dict or list of dicts - Dictionary of parameter names (strings) as keys and lists of parameter - settings to evaluate as values, or a list of such dicts. The GenSVM - model will be evaluated at all combinations of the parameters. + param_grid : string, dict, or list of dicts + If a string, it must be either 'tiny', 'small', or 'full' to load the + predefined parameter grids (see the functions :func:`load_grid_tiny`, + :func:`load_grid_small`, and :func:`load_grid_full`). + + Otherwise, a dictionary of parameter names (strings) as keys and lists + of parameter settings to evaluate as values, or a list of such dicts. + The GenSVM model will be evaluated at all combinations of the + parameters. scoring : string, callable, list/tuple, dict or None A single string (see :ref:`scoring_parameter`) or a callable (see @@ -491,7 +496,7 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): def __init__( self, - param_grid, + param_grid="tiny", scoring=None, iid=True, cv=None, @@ -501,6 +506,15 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): ): self.param_grid = param_grid + if isinstance(self.param_grid, str): + if self.param_grid == "tiny": + self.param_grid = load_grid_tiny() + elif self.param_grid == "small": + self.param_grid = load_grid_small() + elif self.param_grid == "full": + self.param_grid = load_grid_full() + else: + raise ValueError("Unknown param grid %r" % self.param_grid) _check_param_grid(self.param_grid) _validate_param_grid(self.param_grid) @@ -644,8 +658,100 @@ class GenSVMGridSearchCV(BaseEstimator, MetaEstimatorMixin): return self.best_estimator_.predict(X) -def load_default_grid(): - """Load the default parameter grid for GenSVM +def load_grid_tiny(): + """ Load a tiny 
parameter grid for the GenSVM grid search + + This function returns a parameter grid to use in the GenSVM grid search. + This grid was obtained by analyzing the experiments done for the GenSVM + paper and selecting the configurations that achieve accuracy within the + 95th percentile on over 90% of the datasets. It is a good start for a + parameter search with a reasonably high chance of achieving good + performance on most datasets. + + Note that this grid is only tested to work well in combination with the + linear kernel. + + Returns + ------- + + pg : list + List of 10 parameter configurations that are likely to perform + reasonably well. + + """ + + pg = [ + {"p": [2.0], "kappa": [5.0], "lmd": [pow(2, -16)], "weights": ["unit"]}, + {"p": [2.0], "kappa": [5.0], "lmd": [pow(2, -18)], "weights": ["unit"]}, + {"p": [2.0], "kappa": [0.5], "lmd": [pow(2, -18)], "weights": ["unit"]}, + { + "p": [2.0], + "kappa": [5.0], + "lmd": [pow(2, -18)], + "weights": ["group"], + }, + { + "p": [2.0], + "kappa": [-0.9], + "lmd": [pow(2, -18)], + "weights": ["unit"], + }, + {"p": [2.0], "kappa": [5.0], "lmd": [pow(2, -14)], "weights": ["unit"]}, + { + "p": [2.0], + "kappa": [0.5], + "lmd": [pow(2, -18)], + "weights": ["group"], + }, + { + "p": [1.5], + "kappa": [-0.9], + "lmd": [pow(2, -18)], + "weights": ["unit"], + }, + {"p": [2.0], "kappa": [0.5], "lmd": [pow(2, -16)], "weights": ["unit"]}, + { + "p": [2.0], + "kappa": [0.5], + "lmd": [pow(2, -16)], + "weights": ["group"], + }, + ] + return pg + + +def load_grid_small(): + """Load a small parameter grid for GenSVM + + This function loads a default parameter grid to use for the GenSVM + gridsearch. It contains all possible combinations of the following + parameter sets:: + + pg = { + 'p': [1.0, 1.5, 2.0], + 'lmd': [1e-8, 1e-6, 1e-4, 1e-2, 1], + 'kappa': [-0.9, 0.5, 5.0], + 'weights': ['unit', 'group'], + } + + Returns + ------- + + pg : dict + Mapping from parameters to lists of values for those parameters. 
To be + used as input for the :class:`.GenSVMGridSearchCV` class. + """ + pg = { + "p": [1.0, 1.5, 2.0], + "lmd": [1e-8, 1e-6, 1e-4, 1e-2, 1], + "kappa": [-0.9, 0.5, 5.0], + "weights": ["unit", "group"], + } + return pg + + +def load_grid_full(): + """Load the full parameter grid for GenSVM This is the parameter grid used in the GenSVM paper to run the grid search experiments. It uses a large grid for the ``lmd`` regularization parameter |
