diff --git a/constellation/util.py b/constellation/util.py index fd2713b..eba9789 100644 --- a/constellation/util.py +++ b/constellation/util.py @@ -1,3 +1,4 @@ +import itertools import torch import matplotlib from matplotlib.colors import ListedColormap @@ -153,3 +154,22 @@ def plot_constellation( alpha=0.3, zorder=8 ) + + +def product_dict(**kwargs): + """ + Compute cartesian product of a set of parameters. + + >>> list(product_dict(first=[1, 2, 3], second=['a', 'b'])) + [{'first': 1, 'second': 'a'}, + {'first': 1, 'second': 'b'}, + {'first': 2, 'second': 'a'}, + {'first': 2, 'second': 'b'}, + {'first': 3, 'second': 'a'}, + {'first': 3, 'second': 'b'}] + """ + keys = kwargs.keys() + vals = kwargs.values() + + for instance in itertools.product(*vals): + yield dict(zip(keys, instance))