|
308 | 308 | " self._spm_processor = spm_processor\n", |
309 | 309 | "\n", |
310 | 310 | " @property\n", |
311 | | - " def pad_id(self) -\u003e int:\n", |
| 311 | + " def pad_id(self) -> int:\n", |
312 | 312 | " \"\"\"Fast access to the pad id.\"\"\"\n", |
313 | 313 | " return self._spm_processor.pad_id()\n", |
314 | 314 | "\n", |
315 | 315 | " def tokenize(self,\n", |
316 | 316 | " example: str | bytes,\n", |
317 | 317 | " prefix: str = '',\n", |
318 | 318 | " suffix: str = '',\n", |
319 | | - " add_eos: bool = True) -\u003e jax.Array:\n", |
| 319 | + " add_eos: bool = True) -> jax.Array:\n", |
320 | 320 | " \"\"\"\n", |
321 | 321 | " Tokenization function.\n", |
322 | 322 | "\n", |
|
340 | 340 | " str_tensor: tf.Tensor,\n", |
341 | 341 | " prefix: str = '',\n", |
342 | 342 | " suffix: str = '',\n", |
343 | | - " add_eos: bool = True) -\u003e tf.Tensor:\n", |
| 343 | + " add_eos: bool = True) -> tf.Tensor:\n", |
344 | 344 | " \"\"\"Tensforflow operator for the tokenize function.\"\"\"\n", |
345 | 345 | " encoded = tf.numpy_function(\n", |
346 | 346 | " self.tokenize,\n", |
|
349 | 349 | " encoded.set_shape([None])\n", |
350 | 350 | " return encoded\n", |
351 | 351 | "\n", |
352 | | - " def to_string(self, tokens: jax.Array) -\u003e str:\n", |
| 352 | + " def to_string(self, tokens: jax.Array) -> str:\n", |
353 | 353 | " \"\"\"Convert an array of tokens to a string.\"\"\"\n", |
354 | 354 | " return self._spm_processor.EncodeIds(tokens.tolist())" |
355 | 355 | ] |
|
396 | 396 | "\n", |
397 | 397 | " def _pad_up_to_max_len(\n", |
398 | 398 | " self, input_tensor: tf.Tensor, pad_value: int | bool\n", |
399 | | - " ) -\u003e tf.Tensor:\n", |
| 399 | + " ) -> tf.Tensor:\n", |
400 | 400 | " \"\"\"Pads the given tensor up to max_seq_len.\"\"\"\n", |
401 | 401 | " seq_len = tf.shape(input_tensor)[0]\n", |
402 | 402 | " to_pad = tf.maximum(0, self._max_seq_len - seq_len)\n", |
|
518 | 518 | " )\n", |
519 | 519 | " ds = ds.map(lambda x, y: self._to_training_input(x, y),\n", |
520 | 520 | " num_parallel_calls=tf.data.AUTOTUNE)\n", |
521 | | - " ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] \u003c= self._max_seq_len)\n", |
| 521 | + " ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)\n", |
522 | 522 | " ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)\n", |
523 | 523 | " return ds" |
524 | 524 | ] |
|
656 | 656 | " )\n", |
657 | 657 | " )\n", |
658 | 658 | " ds = ds.map(lambda x, y, z: self._to_training_input(x, y, z))\n", |
659 | | - " ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] \u003c= self._max_seq_len)\n", |
| 659 | + " ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)\n", |
660 | 660 | " ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)\n", |
661 | 661 | " return ds" |
662 | 662 | ] |
|
802 | 802 | " )\n", |
803 | 803 | " )\n", |
804 | 804 | " ds = ds.map(lambda x, y, z: self._to_training_input(x, y, z))\n", |
805 | | - " ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] \u003c= self._max_seq_len)\n", |
| 805 | + " ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)\n", |
806 | 806 | " ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)\n", |
807 | 807 | "\n", |
808 | 808 | " return ds" |
|
949 | 949 | " )\n", |
950 | 950 | " )\n", |
951 | 951 | " ds = ds.map(lambda x, y, z: self._to_training_input(x, y, z))\n", |
952 | | - " ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] \u003c= self._max_seq_len)\n", |
| 952 | + " ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)\n", |
953 | 953 | " ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)\n", |
954 | 954 | "\n", |
955 | 955 | " return ds" |
|
0 commit comments