@@ -13,6 +13,8 @@ pub enum Sampling {
1313 TopK { k : usize , temperature : f64 } ,
1414 TopP { p : f64 , temperature : f64 } ,
1515 TopKThenTopP { k : usize , p : f64 , temperature : f64 } ,
16+ // Note that the rng is not used for the Gumbel-Softmax sampling.
17+ GumbelSoftmax { temperature : f64 } ,
1618}
1719
1820pub struct LogitsProcessor {
@@ -49,6 +51,11 @@ impl LogitsProcessor {
4951 Ok ( next_token)
5052 }
5153
54+ fn sample_gumbel_softmax ( & mut self , logits : & Tensor , temperature : f64 ) -> Result < u32 > {
55+ let sampled = candle_nn:: sampling:: gumbel_softmax ( logits, temperature, candle:: D :: Minus1 ) ?;
56+ sampled. to_vec0 :: < u32 > ( )
57+ }
58+
5259 fn sample_multinomial ( & mut self , prs : & Vec < f32 > ) -> Result < u32 > {
5360 let distr = rand:: distr:: weighted:: WeightedIndex :: new ( prs) . map_err ( Error :: wrap) ?;
5461 let next_token = distr. sample ( & mut self . rng ) as u32 ;
@@ -127,6 +134,9 @@ impl LogitsProcessor {
127134
128135 let next_token = match & self . sampling {
129136 Sampling :: ArgMax => self . sample_argmax ( logits) ?,
137+ Sampling :: GumbelSoftmax { temperature } => {
138+ self . sample_gumbel_softmax ( & logits, * temperature) ?
139+ }
130140 Sampling :: All { temperature } => {
131141 let prs = prs ( * temperature) ?;
132142 self . sample_multinomial ( & prs) ?
0 commit comments