Skip to content

Commit 9d319ff

Browse files
authored
Fix the keras_hub package for typecheckers and IDEs (#2222)
Mirror of keras-team/keras#21187 for Keras Hub.
1 parent 3fb6a24 commit 9d319ff

20 files changed

+779
-623
lines changed

api_gen.py

+21-23
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111

1212
import namex
1313

14-
package = "keras_hub"
14+
PACKAGE = "keras_hub"
15+
BUILD_DIR_NAME = "tmp_build_dir"
1516

1617

1718
def ignore_files(_, filenames):
@@ -20,50 +21,47 @@ def ignore_files(_, filenames):
2021

2122
def copy_source_to_build_directory(root_path):
2223
# Copy sources (`keras_hub/` directory and setup files) to build dir
23-
build_dir = os.path.join(root_path, "tmp_build_dir")
24+
build_dir = os.path.join(root_path, BUILD_DIR_NAME)
25+
build_package_dir = os.path.join(build_dir, PACKAGE)
26+
build_src_dir = os.path.join(build_package_dir, "src")
27+
root_src_dir = os.path.join(root_path, PACKAGE, "src")
2428
if os.path.exists(build_dir):
2529
shutil.rmtree(build_dir)
26-
os.mkdir(build_dir)
27-
shutil.copytree(
28-
package, os.path.join(build_dir, package), ignore=ignore_files
29-
)
30+
os.makedirs(build_package_dir)
31+
shutil.copytree(root_src_dir, build_src_dir)
3032
return build_dir
3133

3234

3335
def export_version_string(api_init_fname):
3436
with open(api_init_fname) as f:
3537
contents = f.read()
3638
with open(api_init_fname, "w") as f:
37-
contents += "from keras_hub.src.version_utils import __version__\n"
39+
contents += "from keras_hub.src.version_utils import __version__ as __version__\n" # noqa: E501
3840
f.write(contents)
3941

4042

4143
def build():
42-
# Backup the `keras_hub/__init__.py` and restore it on error in api gen.
4344
root_path = os.path.dirname(os.path.abspath(__file__))
44-
code_api_dir = os.path.join(root_path, package, "api")
45+
code_api_dir = os.path.join(root_path, PACKAGE, "api")
4546
# Create temp build dir
4647
build_dir = copy_source_to_build_directory(root_path)
47-
build_api_dir = os.path.join(build_dir, package, "api")
48-
build_init_fname = os.path.join(build_dir, package, "__init__.py")
48+
build_api_dir = os.path.join(build_dir, PACKAGE)
49+
build_src_dir = os.path.join(build_api_dir, "src")
4950
build_api_init_fname = os.path.join(build_api_dir, "__init__.py")
5051
try:
5152
os.chdir(build_dir)
52-
# Generates `keras_hub/api` directory.
53-
if os.path.exists(build_api_dir):
54-
shutil.rmtree(build_api_dir)
55-
if os.path.exists(build_init_fname):
56-
os.remove(build_init_fname)
57-
os.makedirs(build_api_dir)
58-
namex.generate_api_files(
59-
"keras_hub", code_directory="src", target_directory="api"
60-
)
61-
# Add __version__ to keras package
53+
open(build_api_init_fname, "w").close()
54+
namex.generate_api_files("keras_hub", code_directory="src")
55+
# Add __version__ to `api/`.
6256
export_version_string(build_api_init_fname)
63-
# Copy back the keras_hub/api and keras_hub/__init__.py from build dir
57+
# Copy back the keras/api from build directory
58+
if os.path.exists(build_src_dir):
59+
shutil.rmtree(build_src_dir)
6460
if os.path.exists(code_api_dir):
6561
shutil.rmtree(code_api_dir)
66-
shutil.copytree(build_api_dir, code_api_dir)
62+
shutil.copytree(
63+
build_api_dir, code_api_dir, ignore=shutil.ignore_patterns("src/")
64+
)
6765
finally:
6866
# Clean up: remove the build directory (no longer needed)
6967
shutil.rmtree(build_dir)

keras_hub/__init__.py

+6-26
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,13 @@
1-
import os
1+
# This file should NEVER be packaged! This is a hack to make "import keras_hub"
2+
# from the base of the repo import the api correctly. We'll keep it for compat.
23

3-
# sentencepiece segfaults on some version of tensorflow if tf is imported first.
4-
try:
5-
import sentencepiece
6-
except ImportError:
7-
pass
8-
9-
# Import everything from /api/ into keras.
10-
from keras_hub.api import * # noqa: F403
11-
from keras_hub.api import __version__ # Import * ignores names start with "_".
4+
import os # isort: skip
125

136
# Add everything in /api/ to the module search path.
147
__path__.append(os.path.join(os.path.dirname(__file__), "api")) # noqa: F405
158

9+
from keras_hub.api import * # noqa: F403, E402
10+
from keras_hub.api import __version__ # noqa: E402
11+
1612
# Don't pollute namespace.
1713
del os
18-
19-
20-
# Never autocomplete `.src` or `.api` on an imported keras object.
21-
def __dir__():
22-
keys = dict.fromkeys((globals().keys()))
23-
keys.pop("src")
24-
keys.pop("api")
25-
return list(keys)
26-
27-
28-
# Don't import `.src` or `.api` during `from keras import *`.
29-
__all__ = [
30-
name
31-
for name in globals().keys()
32-
if not (name.startswith("_") or name in ("src", "api"))
33-
]

keras_hub/api/__init__.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
since your modifications would be overwritten.
55
"""
66

7-
from keras_hub.api import layers
8-
from keras_hub.api import metrics
9-
from keras_hub.api import models
10-
from keras_hub.api import samplers
11-
from keras_hub.api import tokenizers
12-
from keras_hub.api import utils
13-
from keras_hub.src.utils.preset_utils import upload_preset
14-
from keras_hub.src.version_utils import __version__
15-
from keras_hub.src.version_utils import version
7+
from keras_hub import layers as layers
8+
from keras_hub import metrics as metrics
9+
from keras_hub import models as models
10+
from keras_hub import samplers as samplers
11+
from keras_hub import tokenizers as tokenizers
12+
from keras_hub import utils as utils
13+
from keras_hub.src.utils.preset_utils import upload_preset as upload_preset
14+
from keras_hub.src.version_utils import __version__ as __version__
15+
from keras_hub.src.version_utils import version as version

keras_hub/api/layers/__init__.py

+85-43
Original file line numberDiff line numberDiff line change
@@ -4,86 +4,128 @@
44
since your modifications would be overwritten.
55
"""
66

7-
from keras_hub.src.layers.modeling.alibi_bias import AlibiBias
8-
from keras_hub.src.layers.modeling.anchor_generator import AnchorGenerator
9-
from keras_hub.src.layers.modeling.box_matcher import BoxMatcher
7+
from keras_hub.src.layers.modeling.alibi_bias import AlibiBias as AlibiBias
8+
from keras_hub.src.layers.modeling.anchor_generator import (
9+
AnchorGenerator as AnchorGenerator,
10+
)
11+
from keras_hub.src.layers.modeling.box_matcher import BoxMatcher as BoxMatcher
1012
from keras_hub.src.layers.modeling.cached_multi_head_attention import (
11-
CachedMultiHeadAttention,
13+
CachedMultiHeadAttention as CachedMultiHeadAttention,
14+
)
15+
from keras_hub.src.layers.modeling.f_net_encoder import (
16+
FNetEncoder as FNetEncoder,
17+
)
18+
from keras_hub.src.layers.modeling.masked_lm_head import (
19+
MaskedLMHead as MaskedLMHead,
20+
)
21+
from keras_hub.src.layers.modeling.non_max_supression import (
22+
NonMaxSuppression as NonMaxSuppression,
23+
)
24+
from keras_hub.src.layers.modeling.position_embedding import (
25+
PositionEmbedding as PositionEmbedding,
1226
)
13-
from keras_hub.src.layers.modeling.f_net_encoder import FNetEncoder
14-
from keras_hub.src.layers.modeling.masked_lm_head import MaskedLMHead
15-
from keras_hub.src.layers.modeling.non_max_supression import NonMaxSuppression
16-
from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
1727
from keras_hub.src.layers.modeling.reversible_embedding import (
18-
ReversibleEmbedding,
28+
ReversibleEmbedding as ReversibleEmbedding,
29+
)
30+
from keras_hub.src.layers.modeling.rms_normalization import (
31+
RMSNormalization as RMSNormalization,
32+
)
33+
from keras_hub.src.layers.modeling.rotary_embedding import (
34+
RotaryEmbedding as RotaryEmbedding,
1935
)
20-
from keras_hub.src.layers.modeling.rms_normalization import RMSNormalization
21-
from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
2236
from keras_hub.src.layers.modeling.sine_position_encoding import (
23-
SinePositionEncoding,
37+
SinePositionEncoding as SinePositionEncoding,
2438
)
2539
from keras_hub.src.layers.modeling.token_and_position_embedding import (
26-
TokenAndPositionEmbedding,
40+
TokenAndPositionEmbedding as TokenAndPositionEmbedding,
41+
)
42+
from keras_hub.src.layers.modeling.transformer_decoder import (
43+
TransformerDecoder as TransformerDecoder,
44+
)
45+
from keras_hub.src.layers.modeling.transformer_encoder import (
46+
TransformerEncoder as TransformerEncoder,
47+
)
48+
from keras_hub.src.layers.preprocessing.audio_converter import (
49+
AudioConverter as AudioConverter,
50+
)
51+
from keras_hub.src.layers.preprocessing.image_converter import (
52+
ImageConverter as ImageConverter,
2753
)
28-
from keras_hub.src.layers.modeling.transformer_decoder import TransformerDecoder
29-
from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
30-
from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter
31-
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
3254
from keras_hub.src.layers.preprocessing.masked_lm_mask_generator import (
33-
MaskedLMMaskGenerator,
55+
MaskedLMMaskGenerator as MaskedLMMaskGenerator,
3456
)
3557
from keras_hub.src.layers.preprocessing.multi_segment_packer import (
36-
MultiSegmentPacker,
58+
MultiSegmentPacker as MultiSegmentPacker,
59+
)
60+
from keras_hub.src.layers.preprocessing.random_deletion import (
61+
RandomDeletion as RandomDeletion,
62+
)
63+
from keras_hub.src.layers.preprocessing.random_swap import (
64+
RandomSwap as RandomSwap,
65+
)
66+
from keras_hub.src.layers.preprocessing.start_end_packer import (
67+
StartEndPacker as StartEndPacker,
3768
)
38-
from keras_hub.src.layers.preprocessing.random_deletion import RandomDeletion
39-
from keras_hub.src.layers.preprocessing.random_swap import RandomSwap
40-
from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
4169
from keras_hub.src.models.basnet.basnet_image_converter import (
42-
BASNetImageConverter,
70+
BASNetImageConverter as BASNetImageConverter,
71+
)
72+
from keras_hub.src.models.clip.clip_image_converter import (
73+
CLIPImageConverter as CLIPImageConverter,
4374
)
44-
from keras_hub.src.models.clip.clip_image_converter import CLIPImageConverter
4575
from keras_hub.src.models.cspnet.cspnet_image_converter import (
46-
CSPNetImageConverter,
76+
CSPNetImageConverter as CSPNetImageConverter,
4777
)
4878
from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import (
49-
DeepLabV3ImageConverter,
79+
DeepLabV3ImageConverter as DeepLabV3ImageConverter,
5080
)
5181
from keras_hub.src.models.densenet.densenet_image_converter import (
52-
DenseNetImageConverter,
82+
DenseNetImageConverter as DenseNetImageConverter,
5383
)
5484
from keras_hub.src.models.efficientnet.efficientnet_image_converter import (
55-
EfficientNetImageConverter,
85+
EfficientNetImageConverter as EfficientNetImageConverter,
5686
)
5787
from keras_hub.src.models.gemma3.gemma3_image_converter import (
58-
Gemma3ImageConverter,
88+
Gemma3ImageConverter as Gemma3ImageConverter,
89+
)
90+
from keras_hub.src.models.mit.mit_image_converter import (
91+
MiTImageConverter as MiTImageConverter,
5992
)
60-
from keras_hub.src.models.mit.mit_image_converter import MiTImageConverter
6193
from keras_hub.src.models.mobilenet.mobilenet_image_converter import (
62-
MobileNetImageConverter,
94+
MobileNetImageConverter as MobileNetImageConverter,
6395
)
6496
from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
65-
PaliGemmaImageConverter,
97+
PaliGemmaImageConverter as PaliGemmaImageConverter,
6698
)
6799
from keras_hub.src.models.resnet.resnet_image_converter import (
68-
ResNetImageConverter,
100+
ResNetImageConverter as ResNetImageConverter,
69101
)
70102
from keras_hub.src.models.retinanet.retinanet_image_converter import (
71-
RetinaNetImageConverter,
103+
RetinaNetImageConverter as RetinaNetImageConverter,
104+
)
105+
from keras_hub.src.models.sam.sam_image_converter import (
106+
SAMImageConverter as SAMImageConverter,
107+
)
108+
from keras_hub.src.models.sam.sam_mask_decoder import (
109+
SAMMaskDecoder as SAMMaskDecoder,
110+
)
111+
from keras_hub.src.models.sam.sam_prompt_encoder import (
112+
SAMPromptEncoder as SAMPromptEncoder,
72113
)
73-
from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter
74-
from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
75-
from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
76114
from keras_hub.src.models.segformer.segformer_image_converter import (
77-
SegFormerImageConverter,
115+
SegFormerImageConverter as SegFormerImageConverter,
78116
)
79117
from keras_hub.src.models.siglip.siglip_image_converter import (
80-
SigLIPImageConverter,
118+
SigLIPImageConverter as SigLIPImageConverter,
119+
)
120+
from keras_hub.src.models.vgg.vgg_image_converter import (
121+
VGGImageConverter as VGGImageConverter,
122+
)
123+
from keras_hub.src.models.vit.vit_image_converter import (
124+
ViTImageConverter as ViTImageConverter,
81125
)
82-
from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter
83-
from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter
84126
from keras_hub.src.models.whisper.whisper_audio_converter import (
85-
WhisperAudioConverter,
127+
WhisperAudioConverter as WhisperAudioConverter,
86128
)
87129
from keras_hub.src.models.xception.xception_image_converter import (
88-
XceptionImageConverter,
130+
XceptionImageConverter as XceptionImageConverter,
89131
)

keras_hub/api/metrics/__init__.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
since your modifications would be overwritten.
55
"""
66

7-
from keras_hub.src.metrics.bleu import Bleu
8-
from keras_hub.src.metrics.edit_distance import EditDistance
9-
from keras_hub.src.metrics.perplexity import Perplexity
10-
from keras_hub.src.metrics.rouge_l import RougeL
11-
from keras_hub.src.metrics.rouge_n import RougeN
7+
from keras_hub.src.metrics.bleu import Bleu as Bleu
8+
from keras_hub.src.metrics.edit_distance import EditDistance as EditDistance
9+
from keras_hub.src.metrics.perplexity import Perplexity as Perplexity
10+
from keras_hub.src.metrics.rouge_l import RougeL as RougeL
11+
from keras_hub.src.metrics.rouge_n import RougeN as RougeN

0 commit comments

Comments
 (0)