Skip to content

Commit b4d7fbb

Browse files
authored
Support returning last hidden state (#2461)
1 parent 80da6a5 commit b4d7fbb

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

torchtune/modules/transformer.py

+3
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,9 @@ def forward(
653653
input_pos=input_pos,
654654
)
655655

656+
if len(self.layers) in self.output_hidden_states:
657+
hidden.append(h)
658+
656659
# shape: [b, seq_len, out_dim]
657660
output = self.unembed(h)
658661

0 commit comments

Comments
 (0)