Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/csharp/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,25 @@ public void SetActiveAdapter(Adapters adapters, string adapterName)
StringUtils.ToUtf8(adapterName)));
}

/// <summary>
/// Switches guidance (constrained decoding) for this generator to the requested state.
/// The native result is handed to Result.VerifySuccess, which reports any error.
/// </summary>
/// <param name="enable">true turns guidance on; false turns it off</param>
public void ToggleGuidance(bool enable)
{
    var nativeResult = NativeMethods.OgaGeneratorToggleGuidance(_generatorHandle, enable);
    Result.VerifySuccess(nativeResult);
}

/// <summary>
/// Queries the native generator for its current guidance (constrained decoding) state.
/// </summary>
/// <returns>true when guidance is enabled; otherwise false</returns>
public bool IsGuidanceEnabled()
{
    bool guidanceEnabled = NativeMethods.OgaGeneratorIsGuidanceEnabled(_generatorHandle);
    return guidanceEnabled;
}

~Generator()
{
Dispose(false);
Expand Down
6 changes: 6 additions & 0 deletions src/csharp/NativeMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,12 @@ internal class NativeLib
byte[] /* const char* */ data,
bool /* boolean */ enable_ff_tokens);

// The native signature takes a C++ bool (one byte). Without an explicit MarshalAs the
// runtime marshals System.Boolean as a four-byte Win32 BOOL, so the size is stated here.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaGeneratorToggleGuidance(IntPtr /* OgaGenerator* */ generator,
                                                                        [MarshalAs(UnmanagedType.I1)] bool /* boolean */ enable);

// The native function returns a C++ bool (one byte). The default marshalling treats the
// return value as a four-byte Win32 BOOL and would read undefined upper bytes, which can
// turn a native 'false' into a managed 'true'. Marshal the return value as I1 instead.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
[return: MarshalAs(UnmanagedType.I1)]
public static extern bool OgaGeneratorIsGuidanceEnabled(IntPtr /* OgaGenerator* */ generator);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern void OgaDestroyGenerator(IntPtr /* OgaGenerator* */ generator);

Expand Down
12 changes: 10 additions & 2 deletions src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,14 @@ GeneratorParams::GeneratorParams(const Model& model)
}
}

void Generator::ToggleGuidance(bool enable) {
guidance_enabled_ = enable;
}

// Reports the current value of the guidance-enabled flag.
bool Generator::IsGuidanceEnabled() {
  const bool enabled = guidance_enabled_;
  return enabled;
}

void GeneratorParams::SetGuidance(std::string_view type, std::string_view data, bool enable_ff_tokens = false) {
guidance_type = type;
guidance_data = data;
Expand Down Expand Up @@ -413,7 +421,7 @@ void Generator::ComputeLogits(DeviceSpan<int32_t> next_tokens) {
if (computed_logits_)
throw std::runtime_error("ComputeLogits called again without calling AppendTokens or GenerateNextToken first");

if (last_action_ == Action::generated && guidance_logits_processor_) {
if (last_action_ == Action::generated && guidance_logits_processor_ && guidance_enabled_) {
auto next_tokens_span = next_tokens.CopyDeviceToCpu();
guidance_logits_processor_->CommitTokens(next_tokens_span);
}
Expand Down Expand Up @@ -511,7 +519,7 @@ void Generator::GenerateNextToken() {
search_->AppendTokens(next_tokens);
ComputeLogits(next_tokens);
}
if (guidance_logits_processor_) {
if (guidance_logits_processor_ && guidance_enabled_) {
auto logits = GetLogits();
guidance_logits_processor_->ProcessLogits(logits);
}
Expand Down
3 changes: 3 additions & 0 deletions src/generators.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ struct Generator : LeakChecked<Generator> {
void AppendTokens(cpu_span<const int32_t> input_ids);
void GenerateNextToken();
void RewindToLength(size_t new_length); // Rewind state to new_length
void ToggleGuidance(bool enable);  // Set whether guidance (constrained decoding) is applied; assigns the state, does not flip it
bool IsGuidanceEnabled();  // Current value of guidance_enabled_
DeviceSpan<float> GetLogits();
void SetLogits(DeviceSpan<float> logits);
void SetRuntimeOption(const char* key, const char* value);
Expand All @@ -116,6 +118,7 @@ struct Generator : LeakChecked<Generator> {

bool computed_logits_{}; // Set to true in ComputeLogits() and false after appending a token to ensure a 1 to 1 call ratio
bool set_extra_inputs_{true}; // Set to false once SetExtraInputs() is called once
bool guidance_enabled_{true}; // Track whether guidance is enabled

private:
DeviceSpan<int32_t> AllocateInputIdsOnDevice(cpu_span<const int32_t> input_ids);
Expand Down
8 changes: 8 additions & 0 deletions src/ort_genai.h
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,14 @@ struct OgaGenerator : OgaAbstract {
OgaCheckResult(OgaCreateGenerator(&model, &params, &p));
return std::unique_ptr<OgaGenerator>(p);
}

// Sets the guidance (constrained decoding) state through the C API and passes the
// returned result to OgaCheckResult.
void ToggleGuidance(bool enable) {
  auto* result = OgaGeneratorToggleGuidance(this, enable);
  OgaCheckResult(result);
}

// Asks the C API whether guidance (constrained decoding) is currently enabled.
bool IsGuidanceEnabled() {
  const bool enabled = OgaGeneratorIsGuidanceEnabled(this);
  return enabled;
}

bool IsDone() {
return OgaGenerator_IsDone(this);
Expand Down
11 changes: 11 additions & 0 deletions src/ort_genai_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,17 @@ OgaResult* OGA_API_CALL OgaGeneratorParamsSetGuidance(OgaGeneratorParams* params
OGA_CATCH
}

// C API entry point: calls ToggleGuidance(enable) on the given generator.
// Returns nullptr once the call has completed.
// NOTE(review): OGA_TRY / OGA_CATCH presumably translate a thrown exception into a
// non-null OgaResult — confirm against the macro definitions.
// The generator pointer is dereferenced without a null check.
OgaResult* OGA_API_CALL OgaGeneratorToggleGuidance(OgaGenerator* generator, bool enable) {
OGA_TRY
generator->ToggleGuidance(enable);
return nullptr;
OGA_CATCH
}

// C API entry point: reports whether guidance is enabled by asking the generator directly.
// The result is returned as a plain bool; no OgaResult is produced.
bool OGA_API_CALL OgaGeneratorIsGuidanceEnabled(OgaGenerator* generator) {
  const bool enabled = generator->IsGuidanceEnabled();
  return enabled;
}

OgaResult* OgaCreateGenerator(const OgaModel* model, const OgaGeneratorParams* params, OgaGenerator** out) {
OGA_TRY
*out = ReturnUnique<OgaGenerator>(CreateGenerator(*model, *params));
Expand Down
16 changes: 16 additions & 0 deletions src/ort_genai_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,22 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGenerator(const OgaModel* model, con
*/
OGA_EXPORT void OGA_API_CALL OgaDestroyGenerator(OgaGenerator* generator);

/**
 * \brief Returns true if guidance (constrained decoding) is enabled for the generator.
 * \param[in] generator The generator to check if guidance is enabled.
 * \return True if guidance is enabled, false otherwise.
 */
OGA_EXPORT bool OGA_API_CALL OgaGeneratorIsGuidanceEnabled(OgaGenerator* generator);

/**
 * \brief Enables or disables guidance (constrained decoding) for the generator.
 *        The guidance state is set to the given value; it is not flipped.
 * \param[in] generator The generator whose guidance state is set.
 * \param[in] enable True to enable guidance, false to disable.
 * \return OgaResult containing the error message if setting the guidance state failed.
 */
OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorToggleGuidance(OgaGenerator* generator, bool enable);

/**
* \brief Returns true if the generator has finished generating all the sequences.
* \param[in] generator The generator to check if it is done with generating all sequences.
Expand Down
Loading