|
5 | 5 | Implementation for model metadata and weight loading from various formats. |
6 | 6 | """ |
7 | 7 | import json |
8 | | -import torch |
9 | 8 | import logging |
10 | 9 | import importlib |
11 | 10 |
|
|
27 | 26 | __all__ = ['load_models', 'load_coreml', 'load_safetensors'] |
28 | 27 |
|
29 | 28 |
|
30 | | -# deserializers for coreml layers with weights |
31 | | -def _coreml_lin(spec): |
32 | | - weights = {} |
33 | | - for layer in spec: |
34 | | - if layer.WhichOneof('layer') == 'innerProduct': |
35 | | - name = layer.name.removesuffix('_lin') |
36 | | - lin = layer.innerProduct |
37 | | - weights[f'nn.{name}.lin.weight'] = torch.Tensor(lin.weights.floatValue).view(lin.outputChannels, lin.inputChannels) |
38 | | - weights[f'nn.{name}.lin.bias'] = torch.Tensor(lin.bias.floatValue) |
39 | | - return weights |
40 | | - |
41 | | - |
42 | | -def _coreml_rnn(spec): |
43 | | - weights = {} |
44 | | - for layer in spec: |
45 | | - if (arch := layer.WhichOneof('layer')) in ['uniDirectionalLSTM', 'biDirectionalLSTM']: |
46 | | - rnn = getattr(layer, arch) |
47 | | - output_size = rnn.outputVectorSize |
48 | | - input_size = rnn.inputVectorSize |
49 | | - name = layer.name.removesuffix('_transposed') |
50 | | - |
51 | | - def _deserialize_weights(params, direction): |
52 | | - # ih_matrix |
53 | | - weight_ih = torch.Tensor([params.inputGateWeightMatrix.floatValue, # wi |
54 | | - params.forgetGateWeightMatrix.floatValue, # wf |
55 | | - params.blockInputWeightMatrix.floatValue, # wz/wg |
56 | | - params.outputGateWeightMatrix.floatValue]) # wo |
57 | | - # hh_matrix |
58 | | - weight_hh = torch.Tensor([params.inputGateRecursionMatrix.floatValue, # wi |
59 | | - params.forgetGateRecursionMatrix.floatValue, # wf |
60 | | - params.blockInputRecursionMatrix.floatValue, # wz/wg |
61 | | - params.outputGateRecursionMatrix.floatValue]) # wo |
62 | | - weights[f'nn.{name}.layer.weight_ih_l0{"_reverse" if direction == "bwd" else ""}'] = weight_ih.view(-1, input_size) |
63 | | - weights[f'nn.{name}.layer.weight_hh_l0{"_reverse" if direction == "bwd" else ""}'] = weight_hh.view(-1, output_size) |
64 | | - biases = torch.Tensor([params.inputGateBiasVector.floatValue, # bi |
65 | | - params.forgetGateBiasVector.floatValue, # bf |
66 | | - params.blockInputBiasVector.floatValue, # bz/bg |
67 | | - params.outputGateBiasVector.floatValue]).view(-1) # bo |
68 | | - weights[f'nn.{name}.layer.bias_hh_l0{"_reverse" if direction == "bwd" else ""}'] = biases |
69 | | - # no ih_biases |
70 | | - weights[f'nn.{name}.layer.bias_ih_l0{"_reverse" if direction == "bwd" else ""}'] = torch.zeros_like(biases) |
71 | | - |
72 | | - fwd_params = rnn.weightParams if arch == 'uniDirectionalLSTM' else rnn.weightParams[0] |
73 | | - _deserialize_weights(fwd_params, 'fwd') |
74 | | - |
75 | | - # get backward weights |
76 | | - if arch == 'biDirectionalLSTM': |
77 | | - _deserialize_weights(rnn.weightParams[1], 'bwd') |
78 | | - return weights |
79 | | - |
80 | | - |
81 | | -def _coreml_conv(spec): |
82 | | - weights = {} |
83 | | - for layer in spec: |
84 | | - if layer.WhichOneof('layer') == 'convolution': |
85 | | - name = layer.name.removesuffix('_conv') |
86 | | - conv = layer.convolution |
87 | | - in_channels = conv.kernelChannels |
88 | | - out_channels = conv.outputChannels |
89 | | - kernel_size = conv.kernelSize |
90 | | - if conv.isDeconvolution: |
91 | | - weights[f'nn.{name}.co.weight'] = torch.Tensor(conv.weights.floatValue).view(in_channels, out_channels, *kernel_size) |
92 | | - else: |
93 | | - weights[f'nn.{name}.co.weight'] = torch.Tensor(conv.weights.floatValue).view(out_channels, in_channels, *kernel_size) |
94 | | - weights[f'nn.{name}.co.bias'] = torch.Tensor(conv.bias.floatValue) |
95 | | - return weights |
96 | | - |
97 | | - |
98 | | -def _coreml_groupnorm(spec): |
99 | | - weights = {} |
100 | | - for layer in spec: |
101 | | - if layer.WhichOneof('layer') == 'custom' and layer.custom.className == 'groupnorm': |
102 | | - gn = layer.custom |
103 | | - in_channels = gn.parameters['in_channels'].intValue |
104 | | - weights[f'nn.{layer.name}.layer.weight'] = torch.Tensor(gn.weights[0].floatValue).view(in_channels) |
105 | | - weights[f'nn.{layer.name}.layer.bias'] = torch.Tensor(gn.weights[1].floatValue).view(in_channels) |
106 | | - return weights |
107 | | - |
108 | | - |
109 | | -def _coreml_romlp(spec): |
110 | | - weights = {} |
111 | | - return weights |
112 | | - |
113 | | - |
114 | | -def _coreml_wav2vec2mask(spec): |
115 | | - weights = {} |
116 | | - # extract embedding parameters |
117 | | - if len(emb := [x for x in spec if x.name.endswith('_wave2vec2_emb')]): |
118 | | - emb = emb[0].embedding |
119 | | - weights['nn._wave2vec2mask.mask_emb.weight'] = torch.Tensor(emb.weights.floatValue) |
120 | | - # extract linear projection parameters |
121 | | - if len(lin := [x for x in spec if x.name.endswith('_wave2vec2_lin')]): |
122 | | - lin = lin[0].innerProduct |
123 | | - weights['nn._wave2vec2mask.project_q.weight'] = torch.Tensor(lin.weights.floatValue).view(lin.outputChannels, lin.inputChannels) |
124 | | - weights['nn._wave2vec2mask.project_q.bias'] = torch.Tensor(lin.bias.floatValue) |
125 | | - return weights |
126 | | - |
127 | | - |
# All CoreML layer deserializers. Each one takes the layer list and returns a
# partial state dict.
_coreml_parsers = [
    _coreml_conv,
    _coreml_rnn,
    _coreml_lin,
    _coreml_groupnorm,
    _coreml_wav2vec2mask,
    _coreml_romlp,
]
130 | | - |
131 | | - |
132 | 29 | def load_models(path: Union[str, 'PathLike'], tasks: Optional[Sequence[_T_tasks]] = None) -> list[BaseModel]: |
133 | 30 | """ |
134 | 31 | Tries all loaders in sequence to deserialize models found in file at path. |
@@ -218,16 +115,17 @@ def load_coreml(path: Union[str, PathLike], tasks: Optional[Sequence[_T_tasks]] |
218 | 115 | # construct state dict |
219 | 116 | weights = {} |
220 | 117 | spec = mlmodel.get_spec().neuralNetwork.layers |
| 118 | + from ._coreml import _coreml_parsers |
221 | 119 | for cml_parser in _coreml_parsers: |
222 | 120 | weights.update(cml_parser(spec)) |
223 | 121 |
|
224 | 122 | model.load_state_dict(weights) |
225 | 123 |
|
226 | 124 | # construct additional models if auxiliary layers are defined. |
227 | 125 |
|
228 | | - #if 'aux_layers' in mlmodel.user_defined_metadata: |
229 | | - # logger.info('Deserializing auxiliary layers.') |
| 126 | + # if 'aux_layers' in mlmodel.user_defined_metadata: |
| 127 | + # logger.info('Deserializing auxiliary layers.') |
230 | 128 |
|
231 | | - # nn.aux_layers = {k: cls(v).nn.get_submodule(k) for k, v in json.loads(mlmodel.user_defined_metadata['aux_layers']).items()} |
| 129 | + # nn.aux_layers = {k: cls(v).nn.get_submodule(k) for k, v in json.loads(mlmodel.user_defined_metadata['aux_layers']).items()} |
232 | 130 |
|
233 | 131 | return [model] |
0 commit comments