diff --git a/train.py b/train.py index f6d274e..87fad5e 100644 --- a/train.py +++ b/train.py @@ -4,6 +4,8 @@ import torch from matplotlib import pyplot from mpl_toolkits.axisartist.axislines import SubplotZero +torch.manual_seed(42) + # Number of symbols to learn order = 4