diff --git a/constellation/util.py b/constellation/util.py index a02c2c0..bc8f649 100644 --- a/constellation/util.py +++ b/constellation/util.py @@ -1,5 +1,7 @@ import torch import matplotlib +from matplotlib.colors import ListedColormap +import seaborn def get_random_messages(count, order): @@ -61,7 +63,7 @@ def plot_constellation( ax.grid() order = len(constellation) - color_map = matplotlib.cm.Dark2 + color_map = ListedColormap(seaborn.color_palette('husl', n_colors=order)) color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order) # Extend axes symmetrically around zero so that they fit data