Skip to content

Commit accd0b7

Browse files
Refactor code for 5G standard-compliant cyclic prefix length
1 parent 6bc104a commit accd0b7

File tree

8 files changed

+146
-240
lines changed

8 files changed

+146
-240
lines changed

sionna/nr/carrier_config.py

+14-17
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
"""
77
# pylint: disable=line-too-long
88

9+
import numpy as np
10+
911
from .config import Config
1012

1113
class CarrierConfig(Config):
@@ -244,27 +246,22 @@ def kappa(self):
244246
@property
245247
def cyclic_prefix_length(self):
246248
r"""
247-
float, read-only : Cyclic prefix length of all symbols except for the
248-
first symbol in each half subframe
249-
:math:`N_{\text{CP},l}^{\mu} \cdot T_{\text{c}}` [s]
250-
"""
251-
if self.cyclic_prefix=="extended":
252-
cp = 512*self.kappa*2**(-self.mu)
253-
else:
254-
cp = 144*self.kappa*2**(-self.mu)
255-
return cp*self.t_c
256-
257-
@property
258-
def cyclic_prefix_length_first_symbol(self):
259-
r"""
260-
float, read-only : Cyclic prefix length of first symbol in each
261-
half subframe
249+
np.ndarray[float], read-only : Vector of cyclic prefix length of all
250+
symbols in the current slot as defined in Section 5.3.1 [3GPP38211]_
262251
:math:`N_{\text{CP},l}^{\mu} \cdot T_{\text{c}}` [s]
263252
"""
253+
cp = np.zeros(self.num_symbols_per_slot)
264254
if self.cyclic_prefix=="extended":
265-
cp = 512*self.kappa*2**(-self.mu)
255+
cp[:] = 512*self.kappa*2**(-self.mu)
266256
else:
267-
cp = 144*self.kappa*2**(-self.mu) + 16*self.kappa
257+
cp[:] = 144*self.kappa*2**(-self.mu)
258+
259+
# Extend cyclic prefix for l=0 or l=7*2^\mu
260+
long_cp_period = 7 * 2 ** self.mu
261+
l_start = self.slot_number * self.num_symbols_per_slot
262+
for i in range(l_start % long_cp_period,
263+
self.num_symbols_per_slot, long_cp_period):
264+
cp[i] += 16*self.kappa
268265
return cp*self.t_c
269266

270267
#-------------------#

sionna/nr/pusch_config.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -1050,11 +1050,8 @@ def check_pusch_configs(pusch_configs):
10501050
"num_cdm_groups_without_data" : pc.dmrs.num_cdm_groups_without_data
10511051
}
10521052
params["bandwidth"] = params["num_subcarriers"]*params["subcarrier_spacing"]
1053-
params["cyclic_prefix_length"] = int(np.ceil(carrier.cyclic_prefix_length *
1054-
params["bandwidth"]))
1055-
params["cyclic_prefix_length_first_symbol"] =\
1056-
int(np.ceil(carrier.cyclic_prefix_length_first_symbol
1057-
* params["bandwidth"]))
1053+
params["cyclic_prefix_length"] = np.ceil(carrier.cyclic_prefix_length *
1054+
params["bandwidth"]).astype(int)
10581055

10591056
for pusch_config in pusch_configs:
10601057
if params["precoding"]=="codebook":

sionna/nr/pusch_receiver.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -156,16 +156,10 @@ def __init__(self,
156156
assert l_min is not None, \
157157
"l_min must be provided for input_domain==time"
158158
self._l_min = l_min
159-
symbols_per_block = (
160-
pusch_transmitter._carrier_config.num_slots_per_subframe *
161-
pusch_transmitter._carrier_config.num_symbols_per_slot // 2)
162159
self._ofdm_demodulator = OFDMDemodulator(
163160
fft_size=pusch_transmitter._num_subcarriers,
164161
l_min=self._l_min,
165-
cyclic_prefix_length=pusch_transmitter._cyclic_prefix_length,
166-
cyclic_prefix_length_first_symbol=
167-
pusch_transmitter._cyclic_prefix_length_first_symbol,
168-
symbols_per_block=symbols_per_block)
162+
cyclic_prefix_length=pusch_transmitter._cyclic_prefix_length)
169163

170164
# Use or create default ChannelEstimator
171165
self._perfect_csi = False

sionna/nr/pusch_transmitter.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,9 @@ def __init__(self,
168168
subcarrier_spacing=self._subcarrier_spacing,
169169
num_tx=self._num_tx,
170170
num_streams_per_tx=self._num_layers,
171-
cyclic_prefix_length=self._cyclic_prefix_length,
171+
# TODO: pass vector of cyclic prefix lengths
172+
# (requires rewrite of channel simulation code)
173+
cyclic_prefix_length=self._cyclic_prefix_length[1],
172174
pilot_pattern=self._pilot_pattern,
173175
dtype=dtype)
174176

@@ -183,13 +185,7 @@ def __init__(self,
183185

184186
# (Optionally) Create OFDMModulator
185187
if self._output_domain=="time":
186-
symbols_per_block = (self._carrier_config.num_slots_per_subframe *
187-
self._carrier_config.num_symbols_per_slot // 2)
188-
self._ofdm_modulator = OFDMModulator(
189-
cyclic_prefix_length=self._cyclic_prefix_length,
190-
cyclic_prefix_length_first_symbol=
191-
self._cyclic_prefix_length_first_symbol,
192-
symbols_per_block=symbols_per_block)
188+
self._ofdm_modulator = OFDMModulator(cyclic_prefix_length=self._cyclic_prefix_length)
193189

194190
#########################################
195191
# Public methods and properties

sionna/ofdm/demodulator.py

+33-96
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,11 @@
1515
class OFDMDemodulator(Layer):
1616
# pylint: disable=line-too-long
1717
r"""
18-
OFDMDemodulator(fft_size, l_min, cyclic_prefix_length=0, cyclic_prefix_length_first_symbol=None, symbols_per_block=1, **kwargs)
18+
OFDMDemodulator(fft_size, l_min, cyclic_prefix_length=0, **kwargs)
1919
2020
Computes the frequency-domain representation of an OFDM waveform
2121
with cyclic prefix removal.
2222
23-
When only `cyclic_prefix_length` is given then a cyclic prefix of length
24-
`cyclic_prefix_length` is removed from each symbol. When additionally
25-
`cyclic_prefix_length_first_symbol` and `symbols_per_block` are given then
26-
the length of the cyclic prefix is `cyclic_prefix_length_first_symbol` for
27-
the first symbol of each block and `cyclic_prefix_length` for the
28-
remaining symbols. For LTE one block corresponds to one slot (i.e., 7
29-
symbols). For 5G NR one block corresponds to one half subframe and the
30-
number of symbols depends on the numerology.
31-
3223
The demodulator assumes that the input sequence is generated by the
3324
:class:`~sionna.channel.TimeChannel`. For a single pair of antennas,
3425
the received signal sequence is given as:
@@ -72,17 +63,9 @@ class OFDMDemodulator(Layer):
7263
impulse response. It should be the same value as that used by the
7364
`cir_to_time_channel` function.
7465
75-
cyclic_prefix_length : int
76-
Integer indicating the length of the cyclic prefix that it prepended
77-
to each OFDM symbol (except for the first symbol of each block if
78-
`cyclic_prefix_length_first_symbol` and `symbols per block` is given).
79-
80-
cyclic_prefix_length_first_symbol : int
81-
Integer indicating the length of the cyclic prefix that it prepended
82-
to the first OFDM symbol of each block.
83-
84-
symbols_per_block : int
85-
Integer indicating the number of symbols per block.
66+
cyclic_prefix_length : int or list[int] or np.ndarray[int]
67+
Integer or vector of integers indicating the length of the cyclic
68+
prefix that it prepended to each OFDM symbol.
8669
8770
Input
8871
-----
@@ -98,15 +81,11 @@ class OFDMDemodulator(Layer):
9881
"""
9982

10083
def __init__(self, fft_size, l_min, cyclic_prefix_length=0,
101-
cyclic_prefix_length_first_symbol=None, symbols_per_block=1,
10284
**kwargs):
10385
super().__init__(**kwargs)
10486
self.fft_size = fft_size
10587
self.l_min = l_min
10688
self.cyclic_prefix_length = cyclic_prefix_length
107-
self.cyclic_prefix_length_first_symbol =(
108-
cyclic_prefix_length_first_symbol)
109-
self.symbols_per_block = symbols_per_block
11089

11190
@property
11291
def fft_size(self):
@@ -133,34 +112,18 @@ def cyclic_prefix_length(self):
133112

134113
@cyclic_prefix_length.setter
135114
def cyclic_prefix_length(self, value):
136-
assert isinstance(value, int) and value >=0,\
137-
"`cyclic_prefix_length` must be a nonnegative integer."
138-
self._cyclic_prefix_length = int(value)
139-
140-
@property
141-
def cyclic_prefix_length_first_symbol(self):
142-
if self._cyclic_prefix_length_first_symbol is None:
143-
return self._cyclic_prefix_length
115+
if isinstance(value, list):
116+
value = np.array(value)
117+
if isinstance(value, np.ndarray):
118+
assert (np.issubdtype(value.dtype, np.integer) and
119+
value.ndim == 1 and np.all(value >= 0)),\
120+
("`cyclic_prefix_length` must be a 1D array with"
121+
" only nonnegative integers.")
144122
else:
145-
return self._cyclic_prefix_length_first_symbol
123+
assert isinstance(value, int) and value >=0,\
124+
"`cyclic_prefix_length` must be a nonnegative integer."
125+
self._cyclic_prefix_length = value
146126

147-
@cyclic_prefix_length_first_symbol.setter
148-
def cyclic_prefix_length_first_symbol(self, value):
149-
assert (value is None or isinstance(value, int) and
150-
value >= self._cyclic_prefix_length),\
151-
("`cyclic_prefix_length_first_symbol` must be integer and " +
152-
"larger or equal to `cyclic_prefix_length`.")
153-
self._cyclic_prefix_length_first_symbol = value
154-
155-
@property
156-
def symbols_per_block(self):
157-
return self._symbols_per_block
158-
159-
@symbols_per_block.setter
160-
def symbols_per_block(self, value):
161-
assert isinstance(value, int) and value >= 1,\
162-
"`symbols_per_block` must be a positive integer."
163-
self._symbols_per_block = value
164127

165128
def build(self, input_shape):
166129
num_samples = input_shape[-1]
@@ -170,24 +133,23 @@ def build(self, input_shape):
170133
* tf.range(self.fft_size, dtype=tf.float32)
171134
self._phase_compensation = tf.exp(tf.complex(0., tmp))
172135

173-
self._samples_per_block = (self.cyclic_prefix_length_first_symbol +
174-
(self.symbols_per_block - 1) * self.cyclic_prefix_length +
175-
self.symbols_per_block * self.fft_size)
176-
177-
# Compute number of elements that will be truncated and number of
178-
# symbols for padding
179-
self._rest = num_samples % self._samples_per_block
180-
samples_first_symbol = (self.cyclic_prefix_length_first_symbol +
181-
self.fft_size)
182-
samples_other_symbols = (self.cyclic_prefix_length + self.fft_size)
183-
if self._rest > samples_first_symbol:
184-
self._rest -= samples_first_symbol
185-
excess_symbols = self._rest // samples_other_symbols
186-
self._rest -= excess_symbols * samples_other_symbols
187-
excess_symbols += 1 # Because of first symbol in block
188-
self._num_pad_symbols = self.symbols_per_block - excess_symbols
136+
if isinstance(self.cyclic_prefix_length, int):
137+
self._num_ofdm_symbols = (input_shape[-1] //
138+
(self.fft_size + self.cyclic_prefix_length))
139+
self.cyclic_prefix_length = np.full(self._num_ofdm_symbols,
140+
self.cyclic_prefix_length)
189141
else:
190-
self._num_pad_symbols = 0
142+
self._num_ofdm_symbols = self.cyclic_prefix_length.shape[0]
143+
144+
symbol_ends = tf.math.cumsum(self.cyclic_prefix_length + self.fft_size)
145+
assert num_samples >= symbol_ends[-1],\
146+
"shape(inputs)[-1] must be larger or equal than samples per slot"
147+
148+
gather_idx = []
149+
for i in range(self._num_ofdm_symbols):
150+
gather_idx.append(tf.range(symbol_ends[i] - self.fft_size,
151+
symbol_ends[i]))
152+
self._gather_idx = tf.concat(gather_idx, 0)
191153

192154
def call(self, inputs):
193155
"""Demodulate OFDM waveform onto a resource grid.
@@ -202,39 +164,14 @@ def call(self, inputs):
202164
"""
203165
batch_dims = tf.shape(inputs)[:-1]
204166

205-
# Cut last samples that do not fit into an OFDM symbol
206-
x = inputs if self._rest == 0 else inputs[..., :-self._rest]
207-
208-
if self._num_pad_symbols > 0:
209-
pad_samples = self._num_pad_symbols * (self.fft_size +
210-
self.cyclic_prefix_length)
211-
padding_shape = tf.concat([batch_dims, [pad_samples]], axis=0)
212-
padding = tf.zeros(padding_shape, dtype=x.dtype)
213-
x = tf.concat([x, padding], axis=-1)
214-
215-
# Reshape input to blocks
216-
num_blocks = tf.shape(x)[-1] // self._samples_per_block
217-
new_shape = tf.concat([batch_dims,
218-
[num_blocks, self._samples_per_block]], 0)
219-
x = tf.reshape(x, new_shape)
220-
221-
# Remove extra cyclic prefix from first symbol
222-
x = x[...,(self.cyclic_prefix_length_first_symbol -
223-
self.cyclic_prefix_length):]
167+
# Gather samples that do not belong to cyclic prefix
168+
x = tf.gather(inputs, self._gather_idx, axis=-1)
224169

225170
# Reshape input to separate OFDM symbols
226171
new_shape = tf.concat([batch_dims,
227-
[num_blocks * self.symbols_per_block],
228-
[self.fft_size + self.cyclic_prefix_length]], 0)
172+
[self._num_ofdm_symbols, self.fft_size]], 0)
229173
x = tf.reshape(x, new_shape)
230174

231-
# Remove padding
232-
if self._num_pad_symbols > 0:
233-
x = x[..., :-self._num_pad_symbols, :]
234-
235-
# Remove cyclic prefix
236-
x = x[...,self.cyclic_prefix_length:]
237-
238175
# Compute FFT
239176
x = fft(x)
240177

0 commit comments

Comments
 (0)