
Commit aeeecef

lazily load torch in coreml parser
1 parent 4a235ff commit aeeecef

File tree

3 files changed: +115 -106 lines changed


kraken/models/_coreml.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
+"""
+kraken.models._coreml
+~~~~~~~~~~~~~~~~~~~~~
+
+Deserializers for CoreML layers with weights.
+"""
+import torch
+
+
+def _coreml_lin(spec):
+    weights = {}
+    for layer in spec:
+        if layer.WhichOneof('layer') == 'innerProduct':
+            name = layer.name.removesuffix('_lin')
+            lin = layer.innerProduct
+            weights[f'nn.{name}.lin.weight'] = torch.Tensor(lin.weights.floatValue).view(lin.outputChannels, lin.inputChannels)
+            weights[f'nn.{name}.lin.bias'] = torch.Tensor(lin.bias.floatValue)
+    return weights
+
+
+def _coreml_rnn(spec):
+    weights = {}
+    for layer in spec:
+        if (arch := layer.WhichOneof('layer')) in ['uniDirectionalLSTM', 'biDirectionalLSTM']:
+            rnn = getattr(layer, arch)
+            output_size = rnn.outputVectorSize
+            input_size = rnn.inputVectorSize
+            name = layer.name.removesuffix('_transposed')
+
+            def _deserialize_weights(params, direction):
+                # ih_matrix
+                weight_ih = torch.Tensor([params.inputGateWeightMatrix.floatValue,    # wi
+                                          params.forgetGateWeightMatrix.floatValue,   # wf
+                                          params.blockInputWeightMatrix.floatValue,   # wz/wg
+                                          params.outputGateWeightMatrix.floatValue])  # wo
+                # hh_matrix
+                weight_hh = torch.Tensor([params.inputGateRecursionMatrix.floatValue,    # wi
+                                          params.forgetGateRecursionMatrix.floatValue,   # wf
+                                          params.blockInputRecursionMatrix.floatValue,   # wz/wg
+                                          params.outputGateRecursionMatrix.floatValue])  # wo
+                weights[f'nn.{name}.layer.weight_ih_l0{"_reverse" if direction == "bwd" else ""}'] = weight_ih.view(-1, input_size)
+                weights[f'nn.{name}.layer.weight_hh_l0{"_reverse" if direction == "bwd" else ""}'] = weight_hh.view(-1, output_size)
+                biases = torch.Tensor([params.inputGateBiasVector.floatValue,           # bi
+                                       params.forgetGateBiasVector.floatValue,          # bf
+                                       params.blockInputBiasVector.floatValue,          # bz/bg
+                                       params.outputGateBiasVector.floatValue]).view(-1)  # bo
+                weights[f'nn.{name}.layer.bias_hh_l0{"_reverse" if direction == "bwd" else ""}'] = biases
+                # no ih_biases
+                weights[f'nn.{name}.layer.bias_ih_l0{"_reverse" if direction == "bwd" else ""}'] = torch.zeros_like(biases)
+
+            fwd_params = rnn.weightParams if arch == 'uniDirectionalLSTM' else rnn.weightParams[0]
+            _deserialize_weights(fwd_params, 'fwd')
+
+            # get backward weights
+            if arch == 'biDirectionalLSTM':
+                _deserialize_weights(rnn.weightParams[1], 'bwd')
+    return weights
+
+
+def _coreml_conv(spec):
+    weights = {}
+    for layer in spec:
+        if layer.WhichOneof('layer') == 'convolution':
+            name = layer.name.removesuffix('_conv')
+            conv = layer.convolution
+            in_channels = conv.kernelChannels
+            out_channels = conv.outputChannels
+            kernel_size = conv.kernelSize
+            if conv.isDeconvolution:
+                weights[f'nn.{name}.co.weight'] = torch.Tensor(conv.weights.floatValue).view(in_channels, out_channels, *kernel_size)
+            else:
+                weights[f'nn.{name}.co.weight'] = torch.Tensor(conv.weights.floatValue).view(out_channels, in_channels, *kernel_size)
+            weights[f'nn.{name}.co.bias'] = torch.Tensor(conv.bias.floatValue)
+    return weights
+
+
+def _coreml_groupnorm(spec):
+    weights = {}
+    for layer in spec:
+        if layer.WhichOneof('layer') == 'custom' and layer.custom.className == 'groupnorm':
+            gn = layer.custom
+            in_channels = gn.parameters['in_channels'].intValue
+            weights[f'nn.{layer.name}.layer.weight'] = torch.Tensor(gn.weights[0].floatValue).view(in_channels)
+            weights[f'nn.{layer.name}.layer.bias'] = torch.Tensor(gn.weights[1].floatValue).view(in_channels)
+    return weights
+
+
+def _coreml_romlp(spec):
+    weights = {}
+    return weights
+
+
+def _coreml_wav2vec2mask(spec):
+    weights = {}
+    # extract embedding parameters
+    if len(emb := [x for x in spec if x.name.endswith('_wave2vec2_emb')]):
+        emb = emb[0].embedding
+        weights['nn._wave2vec2mask.mask_emb.weight'] = torch.Tensor(emb.weights.floatValue)
+    # extract linear projection parameters
+    if len(lin := [x for x in spec if x.name.endswith('_wave2vec2_lin')]):
+        lin = lin[0].innerProduct
+        weights['nn._wave2vec2mask.project_q.weight'] = torch.Tensor(lin.weights.floatValue).view(lin.outputChannels, lin.inputChannels)
+        weights['nn._wave2vec2mask.project_q.bias'] = torch.Tensor(lin.bias.floatValue)
+    return weights
+
+
+_coreml_parsers = [_coreml_conv, _coreml_rnn, _coreml_lin, _coreml_groupnorm,
+                   _coreml_wav2vec2mask, _coreml_romlp]
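
All of the parsers above share one contract: each takes the list of layer protobufs from a CoreML spec and returns a partial PyTorch state dict, and load_coreml() in loaders.py (below) merges their outputs before calling model.load_state_dict(). The gate stacking order in _coreml_rnn (input, forget, block input, output) matches the W_ii|W_if|W_ig|W_io layout PyTorch expects in weight_ih_l0. A minimal sketch of driving the parsers by hand, assuming coremltools is installed; the model path is a placeholder:

import coremltools as ct

from kraken.models._coreml import _coreml_parsers

# Load a CoreML model and pull out the raw layer protobufs, mirroring
# what load_coreml() does ('model.mlmodel' is a hypothetical path).
mlmodel = ct.models.MLModel('model.mlmodel')
spec = mlmodel.get_spec().neuralNetwork.layers

# Each parser only touches the layer types it recognizes and returns a
# partial state dict; the union covers all weight-bearing layers.
weights = {}
for parser in _coreml_parsers:
    weights.update(parser(spec))

print(sorted(weights))  # keys like 'nn.<name>.lin.weight', 'nn.<name>.layer.weight_ih_l0', ...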

kraken/models/configs.py

Lines changed: 3 additions & 0 deletions
@@ -270,6 +270,8 @@ class TrainingConfig(Config):
         Evaluation and checkpoint saving frequency
     checkpoint_path (PathLike, defaults to `model`):
         Path prefix to save checkpoints during training.
+    weights_format (Literal[safetensors, coreml], defaults to 'safetensors'):
+        Weight format to convert the checkpoint to at the end of training.
 
     > Optimizer configuration
 
@@ -318,6 +320,7 @@ def __init__(self, **kwargs):
         self.completed_epochs = kwargs.pop('completed_epochs', 0)
         self.freq = kwargs.pop('freq', 1.0)
         self.checkpoint_path = kwargs.pop('checkpoint_path', 'model')
+        self.weights_format = kwargs.pop('weights_format', 'safetensors')
         self.optimizer = kwargs.pop('optimizer', 'AdamW')
         self.lrate = kwargs.pop('lrate', 1e-5)
         self.momentum = kwargs.pop('momentum', 0.9)
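
Since TrainingConfig.__init__ just pops keyword arguments, the new field is set like any other option. A minimal sketch; the class name, field names, and import path are taken from the diff, the usage itself is assumed:

from kraken.models.configs import TrainingConfig

# 'weights_format' selects the format the final checkpoint is converted
# to: 'safetensors' (the default) or 'coreml'.
config = TrainingConfig(checkpoint_path='model', weights_format='coreml')
assert config.weights_format == 'coreml'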

kraken/models/loaders.py

Lines changed: 4 additions & 106 deletions
@@ -5,7 +5,6 @@
 Implementation for model metadata and weight loading from various formats.
 """
 import json
-import torch
 import logging
 import importlib
 
@@ -27,108 +26,6 @@
 __all__ = ['load_models', 'load_coreml', 'load_safetensors']
 
 
-# deserializers for coreml layers with weights
-def _coreml_lin(spec):
-    weights = {}
-    for layer in spec:
-        if layer.WhichOneof('layer') == 'innerProduct':
-            name = layer.name.removesuffix('_lin')
-            lin = layer.innerProduct
-            weights[f'nn.{name}.lin.weight'] = torch.Tensor(lin.weights.floatValue).view(lin.outputChannels, lin.inputChannels)
-            weights[f'nn.{name}.lin.bias'] = torch.Tensor(lin.bias.floatValue)
-    return weights
-
-
-def _coreml_rnn(spec):
-    weights = {}
-    for layer in spec:
-        if (arch := layer.WhichOneof('layer')) in ['uniDirectionalLSTM', 'biDirectionalLSTM']:
-            rnn = getattr(layer, arch)
-            output_size = rnn.outputVectorSize
-            input_size = rnn.inputVectorSize
-            name = layer.name.removesuffix('_transposed')
-
-            def _deserialize_weights(params, direction):
-                # ih_matrix
-                weight_ih = torch.Tensor([params.inputGateWeightMatrix.floatValue,    # wi
-                                          params.forgetGateWeightMatrix.floatValue,   # wf
-                                          params.blockInputWeightMatrix.floatValue,   # wz/wg
-                                          params.outputGateWeightMatrix.floatValue])  # wo
-                # hh_matrix
-                weight_hh = torch.Tensor([params.inputGateRecursionMatrix.floatValue,    # wi
-                                          params.forgetGateRecursionMatrix.floatValue,   # wf
-                                          params.blockInputRecursionMatrix.floatValue,   # wz/wg
-                                          params.outputGateRecursionMatrix.floatValue])  # wo
-                weights[f'nn.{name}.layer.weight_ih_l0{"_reverse" if direction == "bwd" else ""}'] = weight_ih.view(-1, input_size)
-                weights[f'nn.{name}.layer.weight_hh_l0{"_reverse" if direction == "bwd" else ""}'] = weight_hh.view(-1, output_size)
-                biases = torch.Tensor([params.inputGateBiasVector.floatValue,           # bi
-                                       params.forgetGateBiasVector.floatValue,          # bf
-                                       params.blockInputBiasVector.floatValue,          # bz/bg
-                                       params.outputGateBiasVector.floatValue]).view(-1)  # bo
-                weights[f'nn.{name}.layer.bias_hh_l0{"_reverse" if direction == "bwd" else ""}'] = biases
-                # no ih_biases
-                weights[f'nn.{name}.layer.bias_ih_l0{"_reverse" if direction == "bwd" else ""}'] = torch.zeros_like(biases)
-
-            fwd_params = rnn.weightParams if arch == 'uniDirectionalLSTM' else rnn.weightParams[0]
-            _deserialize_weights(fwd_params, 'fwd')
-
-            # get backward weights
-            if arch == 'biDirectionalLSTM':
-                _deserialize_weights(rnn.weightParams[1], 'bwd')
-    return weights
-
-
-def _coreml_conv(spec):
-    weights = {}
-    for layer in spec:
-        if layer.WhichOneof('layer') == 'convolution':
-            name = layer.name.removesuffix('_conv')
-            conv = layer.convolution
-            in_channels = conv.kernelChannels
-            out_channels = conv.outputChannels
-            kernel_size = conv.kernelSize
-            if conv.isDeconvolution:
-                weights[f'nn.{name}.co.weight'] = torch.Tensor(conv.weights.floatValue).view(in_channels, out_channels, *kernel_size)
-            else:
-                weights[f'nn.{name}.co.weight'] = torch.Tensor(conv.weights.floatValue).view(out_channels, in_channels, *kernel_size)
-            weights[f'nn.{name}.co.bias'] = torch.Tensor(conv.bias.floatValue)
-    return weights
-
-
-def _coreml_groupnorm(spec):
-    weights = {}
-    for layer in spec:
-        if layer.WhichOneof('layer') == 'custom' and layer.custom.className == 'groupnorm':
-            gn = layer.custom
-            in_channels = gn.parameters['in_channels'].intValue
-            weights[f'nn.{layer.name}.layer.weight'] = torch.Tensor(gn.weights[0].floatValue).view(in_channels)
-            weights[f'nn.{layer.name}.layer.bias'] = torch.Tensor(gn.weights[1].floatValue).view(in_channels)
-    return weights
-
-
-def _coreml_romlp(spec):
-    weights = {}
-    return weights
-
-
-def _coreml_wav2vec2mask(spec):
-    weights = {}
-    # extract embedding parameters
-    if len(emb := [x for x in spec if x.name.endswith('_wave2vec2_emb')]):
-        emb = emb[0].embedding
-        weights['nn._wave2vec2mask.mask_emb.weight'] = torch.Tensor(emb.weights.floatValue)
-    # extract linear projection parameters
-    if len(lin := [x for x in spec if x.name.endswith('_wave2vec2_lin')]):
-        lin = lin[0].innerProduct
-        weights['nn._wave2vec2mask.project_q.weight'] = torch.Tensor(lin.weights.floatValue).view(lin.outputChannels, lin.inputChannels)
-        weights['nn._wave2vec2mask.project_q.bias'] = torch.Tensor(lin.bias.floatValue)
-    return weights
-
-
-_coreml_parsers = [_coreml_conv, _coreml_rnn, _coreml_lin, _coreml_groupnorm,
-                   _coreml_wav2vec2mask, _coreml_romlp]
-
-
 def load_models(path: Union[str, 'PathLike'], tasks: Optional[Sequence[_T_tasks]] = None) -> list[BaseModel]:
     """
     Tries all loaders in sequence to deserialize models found in file at path.
@@ -218,16 +115,17 @@ def load_coreml(path: Union[str, PathLike], tasks: Optional[Sequence[_T_tasks]]
     # construct state dict
     weights = {}
     spec = mlmodel.get_spec().neuralNetwork.layers
+    from ._coreml import _coreml_parsers
     for cml_parser in _coreml_parsers:
         weights.update(cml_parser(spec))
 
     model.load_state_dict(weights)
 
     # construct additional models if auxiliary layers are defined.
 
-    #if 'aux_layers' in mlmodel.user_defined_metadata:
-    #    logger.info('Deserializing auxiliary layers.')
+    # if 'aux_layers' in mlmodel.user_defined_metadata:
+    #     logger.info('Deserializing auxiliary layers.')
 
-    #    nn.aux_layers = {k: cls(v).nn.get_submodule(k) for k, v in json.loads(mlmodel.user_defined_metadata['aux_layers']).items()}
+    #     nn.aux_layers = {k: cls(v).nn.get_submodule(k) for k, v in json.loads(mlmodel.user_defined_metadata['aux_layers']).items()}
 
     return [model]
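
The effect of the move is that importing the loaders module no longer pays for torch at import time; the deferred from ._coreml import _coreml_parsers inside load_coreml() only triggers the torch import when a CoreML model is actually deserialized. A rough sanity check, assuming nothing else in the module's import chain pulls in torch:

import sys

import kraken.models.loaders  # noqa: F401

# torch should only enter sys.modules once load_coreml() runs and
# executes the deferred `from ._coreml import _coreml_parsers`.
assert 'torch' not in sys.modules, 'torch was imported eagerly'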
