We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 9a32441 + 787d93f commit d6810a8Copy full SHA for d6810a8
tensordiffeq/utils.py
@@ -30,12 +30,8 @@ def get_weights(model):
30
31
32
def get_sizes(layer_sizes):
33
- sizes_w = []
34
- sizes_b = []
35
- for i, width in enumerate(layer_sizes):
36
- if i != 1:
37
- sizes_w.append(int(width * layer_sizes[1]))
38
- sizes_b.append(int(width if i != 0 else layer_sizes[1]))
+ sizes_w = [layer_sizes[i] * layer_sizes[i - 1] for i in range(len(layer_sizes)) if i != 0]
+ sizes_b = layer_sizes[1:]
39
return sizes_w, sizes_b
40
41
0 commit comments