|
| 1 | +// Copyright 2024 The XLS Authors |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#include "xls/solvers/z3_extract_counterexample.h" |
| 16 | + |
| 17 | +#include <string> |
| 18 | +#include <string_view> |
| 19 | +#include <utility> |
| 20 | +#include <vector> |
| 21 | + |
| 22 | +#include "absl/base/nullability.h" |
| 23 | +#include "absl/container/flat_hash_map.h" |
| 24 | +#include "absl/status/status.h" |
| 25 | +#include "absl/status/statusor.h" |
| 26 | +#include "absl/strings/ascii.h" |
| 27 | +#include "absl/strings/escaping.h" |
| 28 | +#include "absl/strings/match.h" |
| 29 | +#include "absl/strings/str_format.h" |
| 30 | +#include "absl/strings/str_split.h" |
| 31 | +#include "absl/types/span.h" |
| 32 | +#include "xls/common/status/ret_check.h" |
| 33 | +#include "xls/common/status/status_macros.h" |
| 34 | +#include "xls/ir/bits.h" |
| 35 | +#include "xls/ir/format_preference.h" |
| 36 | +#include "xls/ir/number_parser.h" |
| 37 | +#include "xls/ir/type.h" |
| 38 | +#include "xls/ir/value.h" |
| 39 | + |
| 40 | +namespace xls::solvers::z3 { |
| 41 | +namespace { |
| 42 | + |
| 43 | +absl::StatusOr<Value> Z3ValueToXlsValue(std::string_view z3_value_text, |
| 44 | + const BitsType& bits_type) { |
| 45 | + if (!absl::StartsWith(z3_value_text, "#x")) { |
| 46 | + return absl::InvalidArgumentError( |
| 47 | + absl::StrFormat("Expect Z3 value to start with '#x'; got: %s", |
| 48 | + absl::CEscape(z3_value_text))); |
| 49 | + } |
| 50 | + |
| 51 | + std::string_view bits_text = z3_value_text.substr(2); |
| 52 | + XLS_ASSIGN_OR_RETURN( |
| 53 | + Bits bits, ParseUnsignedNumberWithoutPrefix( |
| 54 | + bits_text, FormatPreference::kHex, bits_type.bit_count())); |
| 55 | + return Value(bits); |
| 56 | +} |
| 57 | + |
| 58 | +} // namespace |
| 59 | + |
| 60 | +absl::StatusOr<absl::flat_hash_map<std::string, Value>> ExtractCounterexample( |
| 61 | + std::string_view message, absl::Span<const IrParamSpec> params) { |
| 62 | + // Split on the opening ``` fence. |
| 63 | + std::vector<std::string_view> pieces = |
| 64 | + absl::StrSplit(message, absl::MaxSplits("```", 1)); |
| 65 | + if (pieces.size() != 2) { |
| 66 | + return absl::InvalidArgumentError( |
| 67 | + absl::StrFormat("Could not find model within solver message: `%s`", |
| 68 | + absl::CEscape(message))); |
| 69 | + } |
| 70 | + |
| 71 | + // Split against the closing ``` fence, the model data resides in the middle. |
| 72 | + pieces = absl::StrSplit(pieces[1], absl::MaxSplits("```", 1)); |
| 73 | + if (pieces.size() != 2) { |
| 74 | + return absl::InvalidArgumentError(absl::StrFormat( |
| 75 | + "Could not find model (closing fence) within solver message: `%s`", |
| 76 | + absl::CEscape(message))); |
| 77 | + } |
| 78 | + std::string_view model_text = pieces[0]; |
| 79 | + |
| 80 | + // Since the parameters can be presented out of order we build up a map from |
| 81 | + // parameter name to the spec (which includes the type) for easy lookup in |
| 82 | + // arbitrary order. |
| 83 | + absl::flat_hash_map<std::string, absl::Nonnull<const IrParamSpec*>> |
| 84 | + name_to_spec; |
| 85 | + for (const IrParamSpec& spec : params) { |
| 86 | + auto [it, inserted] = name_to_spec.insert({spec.name, &spec}); |
| 87 | + XLS_RET_CHECK(inserted); |
| 88 | + } |
| 89 | + |
| 90 | + // Accumulate values as we iterate through lines -- note that when there are |
| 91 | + // more than one parameter there will be more than one line specifying data. |
| 92 | + absl::flat_hash_map<std::string, Value> results; |
| 93 | + results.reserve(params.size()); |
| 94 | + |
| 95 | + // Go through each line, skipping empty ones, to ensure we understand all the |
| 96 | + // data the model presents. |
| 97 | + std::vector<std::string_view> model_lines = absl::StrSplit(model_text, '\n'); |
| 98 | + for (std::string_view line : model_lines) { |
| 99 | + if (line.empty()) { |
| 100 | + continue; |
| 101 | + } |
| 102 | + |
| 103 | + // On each line split the parameter name from the data so we can parse the |
| 104 | + // data out. |
| 105 | + std::vector<std::string_view> sides = |
| 106 | + absl::StrSplit(line, absl::MaxSplits(" -> ", 1)); |
| 107 | + if (sides.size() != 2) { |
| 108 | + return absl::InvalidArgumentError(absl::StrFormat( |
| 109 | + "Could not parse line in solver model: `%s`", absl::CEscape(line))); |
| 110 | + } |
| 111 | + |
| 112 | + // Ensure the parameter name matches our understanding of what parameter |
| 113 | + // we're currently processing (since we're placing it in a vector directly). |
| 114 | + std::string_view param_name = sides[0]; |
| 115 | + param_name = absl::StripAsciiWhitespace(param_name); |
| 116 | + |
| 117 | + auto it = name_to_spec.find(param_name); |
| 118 | + if (it == name_to_spec.end()) { |
| 119 | + return absl::InvalidArgumentError( |
| 120 | + absl::StrFormat("ExtractCounterexample; could not find parameter " |
| 121 | + "name from model in user-provided spec: `%s`", |
| 122 | + param_name)); |
| 123 | + } |
| 124 | + |
| 125 | + absl::Nonnull<const IrParamSpec*> spec = it->second; |
| 126 | + |
| 127 | + // Use the type to guide how we parse out the solver-presented data item -- |
| 128 | + // right now it must be a bits type, we give an unimplemented error for |
| 129 | + // unsupported types. |
| 130 | + absl::Nonnull<const Type*> type = spec->type; |
| 131 | + auto* bits_type = dynamic_cast<const BitsType*>(type); |
| 132 | + if (bits_type == nullptr) { |
| 133 | + return absl::UnimplementedError( |
| 134 | + absl::StrFormat("ExtractCounterexample; only bits-typed parameters " |
| 135 | + "are currently supported; got: %s", |
| 136 | + type->ToString())); |
| 137 | + } |
| 138 | + XLS_ASSIGN_OR_RETURN(Value value, Z3ValueToXlsValue(sides[1], *bits_type)); |
| 139 | + |
| 140 | + auto it2 = results.find(param_name); |
| 141 | + if (it2 != results.end()) { |
| 142 | + return absl::InvalidArgumentError(absl::StrFormat( |
| 143 | + "ExtractCounterexample; saw duplicate param value for `%s`: %s", |
| 144 | + param_name, absl::CEscape(model_text))); |
| 145 | + } |
| 146 | + results.emplace_hint(it2, std::string{param_name}, std::move(value)); |
| 147 | + } |
| 148 | + |
| 149 | + // Validate that we populated the right number of values for the parameter |
| 150 | + // specification. |
| 151 | + if (results.size() != params.size()) { |
| 152 | + return absl::InvalidArgumentError(absl::StrFormat( |
| 153 | + "Z3 solver model counterexample did not include all parameters: %s", |
| 154 | + absl::CEscape(model_text))); |
| 155 | + } |
| 156 | + |
| 157 | + return results; |
| 158 | +} |
| 159 | + |
| 160 | +} // namespace xls::solvers::z3 |
0 commit comments