Skip to content

Commit 6bc104a

Browse files
Fix cyclic prefix length of first symbol for 5g NR PUSCH
Signed-off-by: Daniel Schäufele <[email protected]>
1 parent 0a4a22e commit 6bc104a

File tree

7 files changed

+287
-39
lines changed

7 files changed

+287
-39
lines changed

Diff for: sionna/nr/carrier_config.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -244,15 +244,27 @@ def kappa(self):
244244
@property
245245
def cyclic_prefix_length(self):
246246
r"""
247-
float, read-only : Cyclic prefix length
247+
float, read-only : Cyclic prefix length of all symbols except for the
248+
first symbol in each half subframe
248249
:math:`N_{\text{CP},l}^{\mu} \cdot T_{\text{c}}` [s]
249250
"""
250251
if self.cyclic_prefix=="extended":
251-
cp = 512*self.kappa*2**(-self.mu)
252+
cp = 512*self.kappa*2**(-self.mu)
252253
else:
253254
cp = 144*self.kappa*2**(-self.mu)
254-
if self.slot_number in [0, 7*2**self.mu]:
255-
cp += 16*self.kappa
255+
return cp*self.t_c
256+
257+
@property
258+
def cyclic_prefix_length_first_symbol(self):
259+
r"""
260+
float, read-only : Cyclic prefix length of first symbol in each
261+
half subframe
262+
:math:`N_{\text{CP},l}^{\mu} \cdot T_{\text{c}}` [s]
263+
"""
264+
if self.cyclic_prefix=="extended":
265+
cp = 512*self.kappa*2**(-self.mu)
266+
else:
267+
cp = 144*self.kappa*2**(-self.mu) + 16*self.kappa
256268
return cp*self.t_c
257269

258270
#-------------------#

Diff for: sionna/nr/pusch_config.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1050,8 +1050,11 @@ def check_pusch_configs(pusch_configs):
10501050
"num_cdm_groups_without_data" : pc.dmrs.num_cdm_groups_without_data
10511051
}
10521052
params["bandwidth"] = params["num_subcarriers"]*params["subcarrier_spacing"]
1053-
params["cyclic_prefix_length"] = np.ceil(carrier.cyclic_prefix_length *
1054-
params["bandwidth"])
1053+
params["cyclic_prefix_length"] = int(np.ceil(carrier.cyclic_prefix_length *
1054+
params["bandwidth"]))
1055+
params["cyclic_prefix_length_first_symbol"] =\
1056+
int(np.ceil(carrier.cyclic_prefix_length_first_symbol
1057+
* params["bandwidth"]))
10551058

10561059
for pusch_config in pusch_configs:
10571060
if params["precoding"]=="codebook":

Diff for: sionna/nr/pusch_receiver.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -156,10 +156,16 @@ def __init__(self,
156156
assert l_min is not None, \
157157
"l_min must be provided for input_domain==time"
158158
self._l_min = l_min
159+
symbols_per_block = (
160+
pusch_transmitter._carrier_config.num_slots_per_subframe *
161+
pusch_transmitter._carrier_config.num_symbols_per_slot // 2)
159162
self._ofdm_demodulator = OFDMDemodulator(
160163
fft_size=pusch_transmitter._num_subcarriers,
161164
l_min=self._l_min,
162-
cyclic_prefix_length=pusch_transmitter._cyclic_prefix_length)
165+
cyclic_prefix_length=pusch_transmitter._cyclic_prefix_length,
166+
cyclic_prefix_length_first_symbol=
167+
pusch_transmitter._cyclic_prefix_length_first_symbol,
168+
symbols_per_block=symbols_per_block)
163169

164170
# Use or create default ChannelEstimator
165171
self._perfect_csi = False

Diff for: sionna/nr/pusch_transmitter.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,13 @@ def __init__(self,
183183

184184
# (Optionally) Create OFDMModulator
185185
if self._output_domain=="time":
186-
self._ofdm_modulator = OFDMModulator(self._cyclic_prefix_length)
186+
symbols_per_block = (self._carrier_config.num_slots_per_subframe *
187+
self._carrier_config.num_symbols_per_slot // 2)
188+
self._ofdm_modulator = OFDMModulator(
189+
cyclic_prefix_length=self._cyclic_prefix_length,
190+
cyclic_prefix_length_first_symbol=
191+
self._cyclic_prefix_length_first_symbol,
192+
symbols_per_block=symbols_per_block)
187193

188194
#########################################
189195
# Public methods and properties

Diff for: sionna/ofdm/demodulator.py

+103-19
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,20 @@
1515
class OFDMDemodulator(Layer):
1616
# pylint: disable=line-too-long
1717
r"""
18-
OFDMDemodulator(fft_size, l_min, cyclic_prefix_length, **kwargs)
18+
OFDMDemodulator(fft_size, l_min, cyclic_prefix_length=0, cyclic_prefix_length_first_symbol=None, symbols_per_block=1, **kwargs)
1919
2020
Computes the frequency-domain representation of an OFDM waveform
2121
with cyclic prefix removal.
2222
23+
When only `cyclic_prefix_length` is given then a cyclic prefix of length
24+
`cyclic_prefix_length` is removed from each symbol. When additionally
25+
`cyclic_prefix_length_first_symbol` and `symbols_per_block` are given then
26+
the length of the cyclic prefix is `cyclic_prefix_length_first_symbol` for
27+
the first symbol of each block and `cyclic_prefix_length` for the
28+
remaining symbols. For LTE one block corresponds to one slot (i.e., 7
29+
symbols). For 5G NR one block corresponds to one half subframe and the
30+
number of symbols depends on the numerology.
31+
2332
The demodulator assumes that the input sequence is generated by the
2433
:class:`~sionna.channel.TimeChannel`. For a single pair of antennas,
2534
the received signal sequence is given as:
@@ -49,7 +58,7 @@ class OFDMDemodulator(Layer):
4958
each subcarrier by :math:`e^{\frac{-j2\pi k L_\text{min}}{N}}`.
5059
This is a very important step to enable channel estimation with
5160
sparse pilot patterns that needs to interpolate the channel frequency
52-
response accross subcarriers. It also ensures that the
61+
response across subcarriers. It also ensures that the
5362
channel frequency response `seen` by the time-domain channel
5463
is close to the :class:`~sionna.channel.OFDMChannel`.
5564
@@ -64,8 +73,16 @@ class OFDMDemodulator(Layer):
6473
`cir_to_time_channel` function.
6574
6675
cyclic_prefix_length : int
67-
Integer indicating the length of the cyclic prefix that
68-
is prepended to each OFDM symbol.
76+
Integer indicating the length of the cyclic prefix that it prepended
77+
to each OFDM symbol (except for the first symbol of each block if
78+
`cyclic_prefix_length_first_symbol` and `symbols per block` is given).
79+
80+
cyclic_prefix_length_first_symbol : int
81+
Integer indicating the length of the cyclic prefix that it prepended
82+
to the first OFDM symbol of each block.
83+
84+
symbols_per_block : int
85+
Integer indicating the number of symbols per block.
6986
7087
Input
7188
-----
@@ -80,19 +97,25 @@ class OFDMDemodulator(Layer):
8097
two dimension.
8198
"""
8299

83-
def __init__(self, fft_size, l_min, cyclic_prefix_length=0, **kwargs):
100+
def __init__(self, fft_size, l_min, cyclic_prefix_length=0,
101+
cyclic_prefix_length_first_symbol=None, symbols_per_block=1,
102+
**kwargs):
84103
super().__init__(**kwargs)
85104
self.fft_size = fft_size
86105
self.l_min = l_min
87106
self.cyclic_prefix_length = cyclic_prefix_length
107+
self.cyclic_prefix_length_first_symbol =(
108+
cyclic_prefix_length_first_symbol)
109+
self.symbols_per_block = symbols_per_block
88110

89111
@property
90112
def fft_size(self):
91113
return self._fft_size
92114

93115
@fft_size.setter
94116
def fft_size(self, value):
95-
assert value>0, "`fft_size` must be positive."
117+
assert isinstance(value, int) and value>0,\
118+
"`fft_size` must be a positive integer."
96119
self._fft_size = int(value)
97120

98121
@property
@@ -110,23 +133,61 @@ def cyclic_prefix_length(self):
110133

111134
@cyclic_prefix_length.setter
112135
def cyclic_prefix_length(self, value):
113-
assert value >=0, "`cyclic_prefix_length` must be nonnegative."
136+
assert isinstance(value, int) and value >=0,\
137+
"`cyclic_prefix_length` must be a nonnegative integer."
114138
self._cyclic_prefix_length = int(value)
115139

116-
def build(self, input_shape): # pylint: disable=unused-argument
140+
@property
141+
def cyclic_prefix_length_first_symbol(self):
142+
if self._cyclic_prefix_length_first_symbol is None:
143+
return self._cyclic_prefix_length
144+
else:
145+
return self._cyclic_prefix_length_first_symbol
146+
147+
@cyclic_prefix_length_first_symbol.setter
148+
def cyclic_prefix_length_first_symbol(self, value):
149+
assert (value is None or isinstance(value, int) and
150+
value >= self._cyclic_prefix_length),\
151+
("`cyclic_prefix_length_first_symbol` must be integer and " +
152+
"larger or equal to `cyclic_prefix_length`.")
153+
self._cyclic_prefix_length_first_symbol = value
154+
155+
@property
156+
def symbols_per_block(self):
157+
return self._symbols_per_block
158+
159+
@symbols_per_block.setter
160+
def symbols_per_block(self, value):
161+
assert isinstance(value, int) and value >= 1,\
162+
"`symbols_per_block` must be a positive integer."
163+
self._symbols_per_block = value
164+
165+
def build(self, input_shape):
166+
num_samples = input_shape[-1]
167+
117168
tmp = -2 * PI * tf.cast(self.l_min, tf.float32) \
118169
/ tf.cast(self.fft_size, tf.float32) \
119170
* tf.range(self.fft_size, dtype=tf.float32)
120171
self._phase_compensation = tf.exp(tf.complex(0., tmp))
121172

122-
# Compute number of elements that will be truncated
123-
self._rest = np.mod(input_shape[-1],
124-
self.fft_size + self.cyclic_prefix_length)
125-
126-
# Compute number of full OFDM symbols to be demodulated
127-
self._num_ofdm_symbols = np.floor_divide(
128-
input_shape[-1]-self._rest,
129-
self.fft_size + self.cyclic_prefix_length)
173+
self._samples_per_block = (self.cyclic_prefix_length_first_symbol +
174+
(self.symbols_per_block - 1) * self.cyclic_prefix_length +
175+
self.symbols_per_block * self.fft_size)
176+
177+
# Compute number of elements that will be truncated and number of
178+
# symbols for padding
179+
self._rest = num_samples % self._samples_per_block
180+
samples_first_symbol = (self.cyclic_prefix_length_first_symbol +
181+
self.fft_size)
182+
samples_other_symbols = (self.cyclic_prefix_length + self.fft_size)
183+
if self._rest > samples_first_symbol:
184+
self._rest -= samples_first_symbol
185+
excess_symbols = self._rest // samples_other_symbols
186+
self._rest -= excess_symbols * samples_other_symbols
187+
excess_symbols += 1 # Because of first symbol in block
188+
self._num_pad_symbols = self.symbols_per_block - excess_symbols
189+
else:
190+
self._num_pad_symbols = 0
130191

131192
def call(self, inputs):
132193
"""Demodulate OFDM waveform onto a resource grid.
@@ -139,14 +200,37 @@ def call(self, inputs):
139200
`tf.complex64` : The demodulated inputs of shape
140201
`[...,num_ofdm_symbols, fft_size]`.
141202
"""
203+
batch_dims = tf.shape(inputs)[:-1]
142204

143205
# Cut last samples that do not fit into an OFDM symbol
144-
inputs = inputs if self._rest==0 else inputs[...,:-self._rest]
206+
x = inputs if self._rest == 0 else inputs[..., :-self._rest]
207+
208+
if self._num_pad_symbols > 0:
209+
pad_samples = self._num_pad_symbols * (self.fft_size +
210+
self.cyclic_prefix_length)
211+
padding_shape = tf.concat([batch_dims, [pad_samples]], axis=0)
212+
padding = tf.zeros(padding_shape, dtype=x.dtype)
213+
x = tf.concat([x, padding], axis=-1)
214+
215+
# Reshape input to blocks
216+
num_blocks = tf.shape(x)[-1] // self._samples_per_block
217+
new_shape = tf.concat([batch_dims,
218+
[num_blocks, self._samples_per_block]], 0)
219+
x = tf.reshape(x, new_shape)
220+
221+
# Remove extra cyclic prefix from first symbol
222+
x = x[...,(self.cyclic_prefix_length_first_symbol -
223+
self.cyclic_prefix_length):]
145224

146225
# Reshape input to separate OFDM symbols
147-
new_shape = tf.concat([tf.shape(inputs)[:-1], [self._num_ofdm_symbols],
226+
new_shape = tf.concat([batch_dims,
227+
[num_blocks * self.symbols_per_block],
148228
[self.fft_size + self.cyclic_prefix_length]], 0)
149-
x = tf.reshape(inputs, new_shape)
229+
x = tf.reshape(x, new_shape)
230+
231+
# Remove padding
232+
if self._num_pad_symbols > 0:
233+
x = x[..., :-self._num_pad_symbols, :]
150234

151235
# Remove cyclic prefix
152236
x = x[...,self.cyclic_prefix_length:]

0 commit comments

Comments
 (0)