15
15
class OFDMDemodulator (Layer ):
16
16
# pylint: disable=line-too-long
17
17
r"""
18
- OFDMDemodulator(fft_size, l_min, cyclic_prefix_length, **kwargs)
18
+ OFDMDemodulator(fft_size, l_min, cyclic_prefix_length=0, cyclic_prefix_length_first_symbol=None, symbols_per_block=1 , **kwargs)
19
19
20
20
Computes the frequency-domain representation of an OFDM waveform
21
21
with cyclic prefix removal.
22
22
23
+ When only `cyclic_prefix_length` is given then a cyclic prefix of length
24
+ `cyclic_prefix_length` is removed from each symbol. When additionally
25
+ `cyclic_prefix_length_first_symbol` and `symbols_per_block` are given then
26
+ the length of the cyclic prefix is `cyclic_prefix_length_first_symbol` for
27
+ the first symbol of each block and `cyclic_prefix_length` for the
28
+ remaining symbols. For LTE one block corresponds to one slot (i.e., 7
29
+ symbols). For 5G NR one block corresponds to one half subframe and the
30
+ number of symbols depends on the numerology.
31
+
23
32
The demodulator assumes that the input sequence is generated by the
24
33
:class:`~sionna.channel.TimeChannel`. For a single pair of antennas,
25
34
the received signal sequence is given as:
@@ -49,7 +58,7 @@ class OFDMDemodulator(Layer):
49
58
each subcarrier by :math:`e^{\frac{-j2\pi k L_\text{min}}{N}}`.
50
59
This is a very important step to enable channel estimation with
51
60
sparse pilot patterns that needs to interpolate the channel frequency
52
- response accross subcarriers. It also ensures that the
61
+ response across subcarriers. It also ensures that the
53
62
channel frequency response `seen` by the time-domain channel
54
63
is close to the :class:`~sionna.channel.OFDMChannel`.
55
64
@@ -64,8 +73,16 @@ class OFDMDemodulator(Layer):
64
73
`cir_to_time_channel` function.
65
74
66
75
cyclic_prefix_length : int
67
- Integer indicating the length of the cyclic prefix that
68
- is prepended to each OFDM symbol.
76
+ Integer indicating the length of the cyclic prefix that it prepended
77
+ to each OFDM symbol (except for the first symbol of each block if
78
+ `cyclic_prefix_length_first_symbol` and `symbols per block` is given).
79
+
80
+ cyclic_prefix_length_first_symbol : int
81
+ Integer indicating the length of the cyclic prefix that it prepended
82
+ to the first OFDM symbol of each block.
83
+
84
+ symbols_per_block : int
85
+ Integer indicating the number of symbols per block.
69
86
70
87
Input
71
88
-----
@@ -80,19 +97,25 @@ class OFDMDemodulator(Layer):
80
97
two dimension.
81
98
"""
82
99
83
- def __init__ (self , fft_size , l_min , cyclic_prefix_length = 0 , ** kwargs ):
100
+ def __init__ (self , fft_size , l_min , cyclic_prefix_length = 0 ,
101
+ cyclic_prefix_length_first_symbol = None , symbols_per_block = 1 ,
102
+ ** kwargs ):
84
103
super ().__init__ (** kwargs )
85
104
self .fft_size = fft_size
86
105
self .l_min = l_min
87
106
self .cyclic_prefix_length = cyclic_prefix_length
107
+ self .cyclic_prefix_length_first_symbol = (
108
+ cyclic_prefix_length_first_symbol )
109
+ self .symbols_per_block = symbols_per_block
88
110
89
111
@property
90
112
def fft_size (self ):
91
113
return self ._fft_size
92
114
93
115
@fft_size .setter
94
116
def fft_size (self , value ):
95
- assert value > 0 , "`fft_size` must be positive."
117
+ assert isinstance (value , int ) and value > 0 ,\
118
+ "`fft_size` must be a positive integer."
96
119
self ._fft_size = int (value )
97
120
98
121
@property
@@ -110,23 +133,61 @@ def cyclic_prefix_length(self):
110
133
111
134
@cyclic_prefix_length .setter
112
135
def cyclic_prefix_length (self , value ):
113
- assert value >= 0 , "`cyclic_prefix_length` must be nonnegative."
136
+ assert isinstance (value , int ) and value >= 0 ,\
137
+ "`cyclic_prefix_length` must be a nonnegative integer."
114
138
self ._cyclic_prefix_length = int (value )
115
139
116
- def build (self , input_shape ): # pylint: disable=unused-argument
140
+ @property
141
+ def cyclic_prefix_length_first_symbol (self ):
142
+ if self ._cyclic_prefix_length_first_symbol is None :
143
+ return self ._cyclic_prefix_length
144
+ else :
145
+ return self ._cyclic_prefix_length_first_symbol
146
+
147
+ @cyclic_prefix_length_first_symbol .setter
148
+ def cyclic_prefix_length_first_symbol (self , value ):
149
+ assert (value is None or isinstance (value , int ) and
150
+ value >= self ._cyclic_prefix_length ),\
151
+ ("`cyclic_prefix_length_first_symbol` must be integer and " +
152
+ "larger or equal to `cyclic_prefix_length`." )
153
+ self ._cyclic_prefix_length_first_symbol = value
154
+
155
+ @property
156
+ def symbols_per_block (self ):
157
+ return self ._symbols_per_block
158
+
159
+ @symbols_per_block .setter
160
+ def symbols_per_block (self , value ):
161
+ assert isinstance (value , int ) and value >= 1 ,\
162
+ "`symbols_per_block` must be a positive integer."
163
+ self ._symbols_per_block = value
164
+
165
+ def build (self , input_shape ):
166
+ num_samples = input_shape [- 1 ]
167
+
117
168
tmp = - 2 * PI * tf .cast (self .l_min , tf .float32 ) \
118
169
/ tf .cast (self .fft_size , tf .float32 ) \
119
170
* tf .range (self .fft_size , dtype = tf .float32 )
120
171
self ._phase_compensation = tf .exp (tf .complex (0. , tmp ))
121
172
122
- # Compute number of elements that will be truncated
123
- self ._rest = np .mod (input_shape [- 1 ],
124
- self .fft_size + self .cyclic_prefix_length )
125
-
126
- # Compute number of full OFDM symbols to be demodulated
127
- self ._num_ofdm_symbols = np .floor_divide (
128
- input_shape [- 1 ]- self ._rest ,
129
- self .fft_size + self .cyclic_prefix_length )
173
+ self ._samples_per_block = (self .cyclic_prefix_length_first_symbol +
174
+ (self .symbols_per_block - 1 ) * self .cyclic_prefix_length +
175
+ self .symbols_per_block * self .fft_size )
176
+
177
+ # Compute number of elements that will be truncated and number of
178
+ # symbols for padding
179
+ self ._rest = num_samples % self ._samples_per_block
180
+ samples_first_symbol = (self .cyclic_prefix_length_first_symbol +
181
+ self .fft_size )
182
+ samples_other_symbols = (self .cyclic_prefix_length + self .fft_size )
183
+ if self ._rest > samples_first_symbol :
184
+ self ._rest -= samples_first_symbol
185
+ excess_symbols = self ._rest // samples_other_symbols
186
+ self ._rest -= excess_symbols * samples_other_symbols
187
+ excess_symbols += 1 # Because of first symbol in block
188
+ self ._num_pad_symbols = self .symbols_per_block - excess_symbols
189
+ else :
190
+ self ._num_pad_symbols = 0
130
191
131
192
def call (self , inputs ):
132
193
"""Demodulate OFDM waveform onto a resource grid.
@@ -139,14 +200,37 @@ def call(self, inputs):
139
200
`tf.complex64` : The demodulated inputs of shape
140
201
`[...,num_ofdm_symbols, fft_size]`.
141
202
"""
203
+ batch_dims = tf .shape (inputs )[:- 1 ]
142
204
143
205
# Cut last samples that do not fit into an OFDM symbol
144
- inputs = inputs if self ._rest == 0 else inputs [...,:- self ._rest ]
206
+ x = inputs if self ._rest == 0 else inputs [..., :- self ._rest ]
207
+
208
+ if self ._num_pad_symbols > 0 :
209
+ pad_samples = self ._num_pad_symbols * (self .fft_size +
210
+ self .cyclic_prefix_length )
211
+ padding_shape = tf .concat ([batch_dims , [pad_samples ]], axis = 0 )
212
+ padding = tf .zeros (padding_shape , dtype = x .dtype )
213
+ x = tf .concat ([x , padding ], axis = - 1 )
214
+
215
+ # Reshape input to blocks
216
+ num_blocks = tf .shape (x )[- 1 ] // self ._samples_per_block
217
+ new_shape = tf .concat ([batch_dims ,
218
+ [num_blocks , self ._samples_per_block ]], 0 )
219
+ x = tf .reshape (x , new_shape )
220
+
221
+ # Remove extra cyclic prefix from first symbol
222
+ x = x [...,(self .cyclic_prefix_length_first_symbol -
223
+ self .cyclic_prefix_length ):]
145
224
146
225
# Reshape input to separate OFDM symbols
147
- new_shape = tf .concat ([tf .shape (inputs )[:- 1 ], [self ._num_ofdm_symbols ],
226
+ new_shape = tf .concat ([batch_dims ,
227
+ [num_blocks * self .symbols_per_block ],
148
228
[self .fft_size + self .cyclic_prefix_length ]], 0 )
149
- x = tf .reshape (inputs , new_shape )
229
+ x = tf .reshape (x , new_shape )
230
+
231
+ # Remove padding
232
+ if self ._num_pad_symbols > 0 :
233
+ x = x [..., :- self ._num_pad_symbols , :]
150
234
151
235
# Remove cyclic prefix
152
236
x = x [...,self .cyclic_prefix_length :]
0 commit comments