import numpy as np
import torch
import torch.nn as nn


def channel_OSNR():
    """Return the optical signal-to-noise ratio (OSNR) of the link, in dB.

    All link parameters are hard-coded below. The computation derives a
    per-channel power ``p_ch`` from an ASE noise power and a nonlinear
    interference term, then returns ``10*log10(2*p_ch / (3*p_ase))``
    (about 47.7 dB with the current constants).

    NOTE(review): several constants and formulas look physically suspect
    (see inline notes). They are intentionally left unchanged because the
    returned OSNR directly sets the noise level used by ``GaussianChannel``.
    Correcting any of them shifts that level by orders of magnitude. Any fix
    should be made together with re-tuning the ``scale`` factor in
    ``Channel_Noise_Model``.

    Returns:
        numpy.float64: the OSNR in dB.
    """
    sys_rate = 32e9          # Sys_rate: symbol rate
    roll_off = 0.05          # r: presumably the pulse roll-off factor
    beta_2 = 16.48e-6        # B_2 ("Dispersion")
    gamma = 1.3e3            # Gam ("Non_linear_index")
    # NOTE(review): an attenuation of 1e20 is unusual; confirm units.
    alpha = 10**20           # Alpha ("Loss")
    n_spans = 20             # N_s ("Span_count")
    # NOTE(review): 10e5 == 1e6, although the original comment said "(km)".
    span_length = 10e5       # L_s ("Span_length")
    noise_figure = 10**0.5   # 5 dB expressed as a linear factor
    h = 6.6261e-34           # Planck constant
    # NOTE(review): 299792458 is the speed of light in m/s; a photon energy
    # h*v would normally use the optical carrier frequency. Confirm.
    v = 299792458
    b_wdm = sys_rate * (1 + roll_off)
    b_n = 0.1                # B_N: presumably a noise bandwidth (units unconfirmed)

    # ASE power of one span, then accumulated over all spans.
    p_ase_span = h * v * b_n * (alpha * span_length * noise_figure - 1)
    p_ase = p_ase_span * n_spans

    # NOTE(review): operator precedence makes this 1 - exp(-alpha*L)/(2*alpha),
    # not (1 - exp(-alpha*L)) / (2*alpha). With the constants above, exp()
    # underflows to 0.0, so l_eff == 1.0.
    l_eff = 1 - np.exp(-alpha * span_length) / 2 / alpha
    asinh_term = np.arcsinh((np.pi**2 / 3) * beta_2 * l_eff * b_wdm**2)

    eps = 0.3 * np.log(1 + (6 / span_length) * (l_eff / asinh_term))
    b = p_ase_span / (
        2 * (n_spans**eps) * b_n * (gamma**2) * l_eff * asinh_term
    )
    p_ch = sys_rate * (((27 * np.pi * beta_2 / 16) * b) ** (1 / 3))

    osnr = 2 * p_ch / 3 / p_ase
    return 10 * np.log10(osnr)


def Const_Points_Pow(IQ):
    """Return the average power of a set of points, in dB.

    Args:
        IQ: tensor whose dim 1 holds the components of each point (for
            example an (N, 2) tensor of I/Q pairs).

    Returns:
        0-dim tensor ``10*log10(mean(||IQ||_dim1 ** 2))``. An all-zero input
        gives ``-inf``.
    """
    mean_power = (torch.norm(IQ, dim=1) ** 2).mean()
    return 10 * torch.log10(mean_power)


def Pow_Noise(CH_OSNR, CPP):
    """Return the linear noise power implied by a signal power and an OSNR.

    Args:
        CH_OSNR: channel OSNR in dB.
        CPP: signal power in dB (e.g. the output of ``Const_Points_Pow``).

    Returns:
        ``10 ** ((CPP - CH_OSNR) / 10)``, the noise power on a linear scale.
        ``GaussianChannel`` uses it as the noise variance.
    """
    noise_power_db = CPP - CH_OSNR
    return 10 ** (noise_power_db / 10)


def Channel_Noise_Model(NV, S, scale=5000):
    """Draw zero-mean Gaussian noise to add to each vector.

    Args:
        NV: noise variance (number or tensor). It is multiplied by ``scale``
            before the square root is taken to obtain the standard deviation.
        S: sample shape (sequence of ints or ``torch.Size``).
        scale: empirical multiplier on ``NV``. The default 5000 matches the
            previously hard-coded value.

    Returns:
        Tensor of shape ``S + NV.shape`` with the dtype/device of the computed
        standard deviation, detached from the autograd graph.

    Raises:
        ValueError: if the standard deviation contains NaN (e.g. a negative
            or NaN variance).
    """
    with torch.no_grad():
        std = torch.sqrt(torch.as_tensor(NV) * scale)
        if torch.isnan(std).any():
            raise ValueError("noise variance must be a non-negative number")
        shape = torch.Size(S) + std.shape
        # Same call Normal(0, std).sample(S) makes internally, so draws are
        # identical for std > 0. std == 0 (all-zero input power) now yields
        # zero noise instead of Normal's strict-positivity ValueError.
        return torch.normal(
            torch.zeros_like(std).expand(shape), std.expand(shape)
        )


class GaussianChannel(nn.Module):
    """Additive Gaussian noise channel whose noise power tracks the input power.

    For an input with average power P dB (from ``Const_Points_Pow``), the
    noise variance is ``5000 * 10 ** ((P - osnr_db) / 10)``.
    """

    def __init__(self, osnr_db=None):
        """Create the channel.

        Args:
            osnr_db: OSNR in dB. Defaults to ``channel_OSNR()``, which is
                constant, so it is computed once here rather than on every
                forward pass.
        """
        super().__init__()
        self.osnr_db = channel_OSNR() if osnr_db is None else osnr_db

    def forward(self, x):
        """Return ``x`` plus Gaussian noise of the same shape.

        Args:
            x: tensor of points whose power is measured along dim 1;
                presumably (N, 2) I/Q pairs -- TODO confirm other layouts.

        Returns:
            ``x + noise``. The noise is detached, so gradients w.r.t. ``x``
            pass through unchanged. An empty ``x`` is returned as a copy.
        """
        if x.numel() == 0:
            # The mean power of an empty batch is NaN; nothing to perturb.
            return x.clone()
        # The noise is detached anyway, so avoid building an unused graph.
        with torch.no_grad():
            noise_variance = Pow_Noise(self.osnr_db, Const_Points_Pow(x))
        noise = Channel_Noise_Model(noise_variance, x.shape)
        return x + noise