diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 1679fbf..2078479 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -81,6 +81,7 @@ jobs: VERSION=$(echo "$DRY_OUTPUT" | grep -o "The next release version is [0-9]\+\.[0-9]\+\.[0-9]\+\(-rc\.[0-9]\+\)\?" | cut -d ' ' -f6) if [ -z "$VERSION" ]; then echo "Error: Could not determine version" + echo "Output: $DRY_OUTPUT" exit 1 fi diff --git a/CHANGELOG.md b/CHANGELOG.md index 7586470..3cd8ea9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,58 @@ All notable changes to this project will be documented in this file. +## [0.9.1-rc.5](https://github.com/inference-gateway/rust-sdk/compare/0.9.1-rc.4...0.9.1-rc.5) (2025-03-21) + +### ♻️ Improvements + +* Improve example for Tool-Use ([bc647da](https://github.com/inference-gateway/rust-sdk/commit/bc647dad7686aecd63a9fa38617dfe3a29d3bdb0)) + +## [0.9.1-rc.4](https://github.com/inference-gateway/rust-sdk/compare/0.9.1-rc.3...0.9.1-rc.4) (2025-03-21) + +### ♻️ Improvements + +* Remove ToolFunctionResponse and update ToolCallResponse to use ChatCompletionMessageToolCallFunction ([e3d29bb](https://github.com/inference-gateway/rust-sdk/commit/e3d29bbd291d84708f4f37e0cdc338829b4cface)) + +### 📚 Documentation + +* Enhance README with updated response types and improved logging ([651720b](https://github.com/inference-gateway/rust-sdk/commit/651720ba50cec350c92bce4f72a353c3d9189d9b)) + +## [0.9.1-rc.3](https://github.com/inference-gateway/rust-sdk/compare/0.9.1-rc.2...0.9.1-rc.3) (2025-03-21) + +### ♻️ Improvements + +* Update API endpoints and logging for model listing ([96fec11](https://github.com/inference-gateway/rust-sdk/commit/96fec11450919679eafca4b46a436713831ccca9)) + +## [0.9.1-rc.2](https://github.com/inference-gateway/rust-sdk/compare/0.9.1-rc.1...0.9.1-rc.2) (2025-03-21) + +### ♻️ Improvements + +* Update model response structures for clarity and consistency ([5ca2eb5](https://github.com/inference-gateway/rust-sdk/commit/5ca2eb5baad76b4ad3aaf405233036cd2e0b539c)) + +### 🔧 Miscellaneous + +* Update dependencies for improved performance and security ([6079995](https://github.com/inference-gateway/rust-sdk/commit/607999577166554c32264327e20da2daa0b8b94e)) + +## [0.9.1-rc.1](https://github.com/inference-gateway/rust-sdk/compare/0.9.0...0.9.1-rc.1) (2025-03-20) + +### ♻️ Improvements + +* Add streaming response structures for chat completion ([0dd7b56](https://github.com/inference-gateway/rust-sdk/commit/0dd7b5693f14d086acdacc7356b34caf73eb7035)) +* Download the OpenAPI spec ([7701059](https://github.com/inference-gateway/rust-sdk/commit/7701059e64ae6073e73803b1ddb4fbc93dc35417)) +* Remove Google provider from the Provider enum ([5a98051](https://github.com/inference-gateway/rust-sdk/commit/5a98051437c3cc6578b7de86ef937ae4736b7184)) +* Remove nullable fields from OpenAPI schema definitions ([461750f](https://github.com/inference-gateway/rust-sdk/commit/461750facb43b70a3112f867cb689711804ca560)) +* Rename default client method to new_default for clarity ([8735999](https://github.com/inference-gateway/rust-sdk/commit/87359996264e6d50a3ef078080794d3254ed9e06)) +* Rename GenerateRequest to CreateChatCompletionRequest and remove ssevents field ([52c4008](https://github.com/inference-gateway/rust-sdk/commit/52c40089ec6c59fce0c9748f8f1c985cd81e6fab)) +* Restructure response to be compatible with OpenAI ([88da800](https://github.com/inference-gateway/rust-sdk/commit/88da80062a23a272be04c54b2d81a667004ca066)) +* Start by implementing default 
and base_url ([1afdaac](https://github.com/inference-gateway/rust-sdk/commit/1afdaac83a7cccd8405db1eeddb387dc0822260e)) +* Update JSON response format to include new fields for chat completion ([38869a1](https://github.com/inference-gateway/rust-sdk/commit/38869a19293385f5c6660adb5c3b79afddf2b906)) +* Update JSON response structure with new fields for chat completion ([340c02c](https://github.com/inference-gateway/rust-sdk/commit/340c02c006aa397b0ce8e7b512b9bf9c6e244f8a)) +* Update tool call structures for chat completion and enhance argument parsing ([c867978](https://github.com/inference-gateway/rust-sdk/commit/c86797846676a266f477d906f1e9a9a2e12f41cd)) + +### 🐛 Bug Fixes + +* Add output logging for version determination error in release workflow ([d49ba50](https://github.com/inference-gateway/rust-sdk/commit/d49ba50bd4111b28b722ae78657c6660909c0932)) + ## [0.9.0](https://github.com/inference-gateway/rust-sdk/compare/0.8.0...0.9.0) (2025-02-11) ### ✨ Features diff --git a/Cargo.lock b/Cargo.lock index 01f9c8f..2bf7fc1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -82,7 +82,7 @@ dependencies = [ "miniz_oxide", "object", "rustc-demangle", - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -297,7 +297,19 @@ checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", +] + +[[package]] +name = "getrandom" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73fea8450eea4bac3940448fb7ae50d91f034f941199fcd9d909a5a07aa455f0" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasi 0.14.2+wasi-0.2.4", ] [[package]] @@ -601,7 +613,7 @@ dependencies = [ [[package]] name = "inference-gateway-sdk" -version = "0.9.0" +version = "0.9.1-rc.5" dependencies = [ "async-stream", "futures-util", @@ -703,15 +715,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" dependencies = [ "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys 0.52.0", ] [[package]] name = "mockito" -version = "1.6.1" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "652cd6d169a36eaf9d1e6bce1a221130439a966d7f27858af66a33a66e9c4ee2" +checksum = "7760e0e418d9b7e5777c0374009ca4c93861b9066f18cb334a20ce50ab63aa48" dependencies = [ "assert-json-diff", "bytes", @@ -827,7 +839,7 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -860,7 +872,7 @@ version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" dependencies = [ - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -881,22 +893,28 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5" + [[package]] name = "rand" -version = "0.8.5" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" dependencies = [ - "libc", "rand_chacha", "rand_core", + "zerocopy 0.8.23", ] [[package]] name = "rand_chacha" -version = "0.3.1" +version = "0.9.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", "rand_core", @@ -904,11 +922,11 @@ dependencies = [ [[package]] name = "rand_core" -version = "0.6.4" +version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ - "getrandom", + "getrandom 0.3.2", ] [[package]] @@ -951,9 +969,9 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "reqwest" -version = "0.12.12" +version = "0.12.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da" +checksum = "d19c46a6fdd48bc4dab94b6103fccc55d34c67cc0ad04653aad4ea2a07cd7bbb" dependencies = [ "base64", "bytes", @@ -1003,7 +1021,7 @@ checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" dependencies = [ "cc", "cfg-if", - "getrandom", + "getrandom 0.2.15", "libc", "spin", "untrusted", @@ -1120,18 +1138,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.217" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" +checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.217" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" +checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" dependencies = [ "proc-macro2", "quote", @@ -1140,9 +1158,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.138" +version = "1.0.140" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949" +checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" dependencies = [ "itoa", "memchr", @@ -1277,7 +1295,7 @@ checksum = "9a8a559c81686f576e8cd0290cd2a24a2a9ad80c98b3478856500fcbd7acd704" dependencies = [ "cfg-if", "fastrand", - "getrandom", + "getrandom 0.2.15", "once_cell", "rustix", "windows-sys 0.59.0", @@ -1285,18 +1303,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.11" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "2.0.11" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", @@ -1315,9 +1333,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.43.0" +version = "1.44.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e" +checksum 
= "f382da615b842244d4b8738c82ed1275e6c5dd90c459a30941cd07080b06c91a" dependencies = [ "backtrace", "bytes", @@ -1482,6 +1500,15 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasi" +version = "0.14.2+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" +dependencies = [ + "wit-bindgen-rt", +] + [[package]] name = "wasm-bindgen" version = "0.2.100" @@ -1576,34 +1603,39 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "windows-link" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38" + [[package]] name = "windows-registry" -version = "0.2.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" +checksum = "4286ad90ddb45071efd1a66dfa43eb02dd0dfbae1545ad6cc3c51cf34d7e8ba3" dependencies = [ "windows-result", "windows-strings", - "windows-targets", + "windows-targets 0.53.0", ] [[package]] name = "windows-result" -version = "0.2.0" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +checksum = "c64fd11a4fd95df68efcfee5f44a294fe71b8bc6a91993e2791938abcc712252" dependencies = [ - "windows-targets", + "windows-link", ] [[package]] name = "windows-strings" -version = "0.1.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +checksum = "87fa48cc5d406560701792be122a10132491cff9d0aeb23583cc2dcafc847319" dependencies = [ - "windows-result", - "windows-targets", + "windows-link", ] [[package]] @@ -1612,7 +1644,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -1621,7 +1653,7 @@ version = "0.59.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" dependencies = [ - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -1630,14 +1662,30 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1e4c7e8ceaaf9cb7d7507c974735728ab453b67ef8f18febdd7c11fe59dca8b" +dependencies = [ + "windows_aarch64_gnullvm 0.53.0", + "windows_aarch64_msvc 0.53.0", + "windows_i686_gnu 0.53.0", + 
"windows_i686_gnullvm 0.53.0", + "windows_i686_msvc 0.53.0", + "windows_x86_64_gnu 0.53.0", + "windows_x86_64_gnullvm 0.53.0", + "windows_x86_64_msvc 0.53.0", ] [[package]] @@ -1646,48 +1694,105 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" + [[package]] name = "windows_aarch64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" + [[package]] name = "windows_i686_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" +[[package]] +name = "windows_i686_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" + [[package]] name = "windows_i686_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" + [[package]] name = "windows_i686_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" +[[package]] +name = "windows_i686_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" + [[package]] name = "windows_x86_64_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" + [[package]] name = "windows_x86_64_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" + [[package]] name = "windows_x86_64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" + +[[package]] +name = "wit-bindgen-rt" +version = "0.39.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" +dependencies = [ + "bitflags", +] + [[package]] name = "write16" version = "1.0.0" @@ -1731,7 +1836,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ "byteorder", - "zerocopy-derive", + "zerocopy-derive 0.7.35", +] + +[[package]] +name = "zerocopy" +version = "0.8.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd97444d05a4328b90e75e503a34bad781f14e28a823ad3557f0750df1ebcbc6" +dependencies = [ + "zerocopy-derive 0.8.23", ] [[package]] @@ -1745,6 +1859,17 @@ dependencies = [ "syn", ] +[[package]] +name = "zerocopy-derive" +version = "0.8.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6352c01d0edd5db859a63e2605f4ea3183ddbd15e2c4a9e7d32184df75e4f154" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "zerofrom" version = "0.1.5" diff --git a/Cargo.toml b/Cargo.toml index 58933f6..4c2f635 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "inference-gateway-sdk" -version = "0.9.0" +version = "0.9.1-rc.5" edition = "2021" description = "Rust SDK for interacting with various language models through the Inference Gateway" license = "MIT" @@ -14,12 +14,12 @@ categories = ["api-bindings", "web-programming::http-client"] [dependencies] async-stream = "0.3.6" futures-util = "0.3.31" -reqwest = { version = "0.12.12", features = ["json", "stream"] } -serde = { version = "1.0.217", features = ["derive"] } -serde_json = "1.0.138" -thiserror = "2.0.11" -tokio = { version = "1.43.0", features = ["macros", "rt-multi-thread"] } +reqwest = { version = "0.12.15", features = ["json", "stream"] } +serde = { version = "1.0.219", features = ["derive"] } +serde_json = "1.0.140" +thiserror = "2.0.12" +tokio = { version = "1.44.1", features = ["macros", "rt-multi-thread"] } [dev-dependencies] -mockito = "1.6.1" +mockito = "1.7.0" tokio = { version = "1.43.0", features = ["macros", "rt"] } diff --git a/README.md b/README.md index 4a89473..e47412e 100644 --- a/README.md +++ b/README.md @@ -27,9 +27,11 @@ Here is a full example of how to create a client and interact with the Inference ```rust use inference_gateway_sdk::{ + CreateChatCompletionResponse, GatewayError, InferenceGatewayAPI, InferenceGatewayClient, + ListModelsResponse, Message, Provider, MessageRole @@ -48,24 +50,20 @@ async fn main() -> Result<(), GatewayError> { let client = InferenceGatewayClient::new("http://localhost:8080"); // List all models and all providers - let models = client.list_models().await?; - for provider_models in models { - info!("Provider: {:?}", provider_models.provider); - for model in provider_models.models { - info!("Model: {:?}", model.name); - } + let response: ListModelsResponse = client.list_models().await?; + for model in response.data { + info!("Model: {:?}", model.id); } // List models for a specific provider - let resp = client.list_models_by_provider(Provider::Groq).await?; - let models = resp.models; - info!("Provider: {:?}", resp.provider); - for model in models { - info!("Model: {:?}", model.name); + let response: ListModelsResponse = client.list_models_by_provider(Provider::Groq).await?; + info!("Models for provider: {:?}", response.provider); + for model in response.data { + info!("Model: {:?}", model.id); } // Generate content - choose from available 
providers and models - let resp = client.generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", vec![ + let response: CreateChatCompletionResponse = client.generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", vec![ Message{ role: MessageRole::System, content: "You are an helpful assistent.".to_string() @@ -76,9 +74,10 @@ async fn main() -> Result<(), GatewayError> { } ]).await?; - log::info!("Generated from provider: {:?}", resp.provider); - log::info!("Generated response: {:?}", resp.response.role); - log::info!("Generated content: {:?}", resp.response.content); + log::info!( + "Generated content: {:?}", + response.choices[0].message.content + ); Ok(()) } @@ -93,6 +92,7 @@ use inference_gateway_sdk::{ GatewayError InferenceGatewayAPI, InferenceGatewayClient, + ListModelsResponse, Message, }; use log::info; @@ -101,13 +101,10 @@ use log::info; fn main() -> Result<(), GatewayError> { // ...Create a client - // List all models and all providers - let models = client.list_models().await?; - for provider_models in models { - info!("Provider: {:?}", provider_models.provider); - for model in provider_models.models { - info!("Model: {:?}", model.name); - } + // List models from all providers + let response: ListModelsResponse = client.list_models().await?; + for model in response.data { + info!("Model: {:?}", model.id); } // ... @@ -123,6 +120,7 @@ use inference_gateway_sdk::{ GatewayError InferenceGatewayAPI, InferenceGatewayClient, + ListModelsResponse, Provider, }; use log::info; @@ -130,11 +128,10 @@ use log::info; // ...Open main function // List models for a specific provider -let resp = client.list_models_by_provider(Provider::Ollama).await?; -let models = resp.models; -info!("Provider: {:?}", resp.provider); -for model in models { - info!("Model: {:?}", model.name); +let response: ListModelsResponse = client.list_models_by_provider(Provider::Groq).await?; +info!("Models for provider: {:?}", response.provider); +for model in response.data { + info!("Model: {:?}", model.id); } // ...Rest of the main function @@ -146,6 +143,7 @@ To generate content using a model, use the `generate_content` method: ```rust use inference_gateway_sdk::{ + CreateChatCompletionResponse, GatewayError, InferenceGatewayAPI, InferenceGatewayClient, @@ -155,20 +153,23 @@ use inference_gateway_sdk::{ }; // Generate content - choose from available providers and models -let resp = client.generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", vec![ +let response: CreateChatCompletionResponse = client.generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", vec![ Message{ role: MessageRole::System, - content: "You are an helpful assistent.".to_string() + content: "You are an helpful assistent.".to_string(), + ..Default::default() }, Message{ role: MessageRole::User, - content: "Tell me a funny joke".to_string() + content: "Tell me a funny joke".to_string(), + ..Default::default() } ]).await?; -log::info!("Generated from provider: {:?}", resp.provider); -log::info!("Generated response: {:?}", resp.response.role); -log::info!("Generated content: {:?}", resp.response.content); +log::info!( + "Generated content: {:?}", + response.choices[0].message.content +); ``` ### Streaming Content @@ -179,51 +180,78 @@ You need to add the following tiny dependencies: - `serde` with feature `derive` and `serde_json` for serialization and deserialization of the response content ```rust +use futures_util::{pin_mut, StreamExt}; use inference_gateway_sdk::{ - InferenceGatewayAPI, - 
InferenceGatewayClient, Message, MessageRole, Provider, ResponseContent + CreateChatCompletionStreamResponse, GatewayError, InferenceGatewayAPI, InferenceGatewayClient, + Message, MessageRole, Provider, }; -use futures_util::{StreamExt, pin_mut}; -// ...rest of the imports +use log::info; +use std::env; + +#[tokio::main] +async fn main() -> Result<(), GatewayError> { + if env::var("RUST_LOG").is_err() { + env::set_var("RUST_LOG", "info"); + } + env_logger::init(); -// ...main function let system_message = "You are an helpful assistent.".to_string(); let model = "deepseek-r1-distill-llama-70b"; - let messages = vec![ - Message { - role: MessageRole::System, - content: system_message, - tool_call_id: None - }, - Message { - role: MessageRole::User, - content: "Write a poem".to_string(), - tool_call_id: None - }, - ]; - let client = InferenceGatewayClient::new("http://localhost:8080"); - let stream = client.generate_content_stream(Provider::Groq, model, messages); + + let client = InferenceGatewayClient::new("http://localhost:8080/v1"); + let stream = client.generate_content_stream( + Provider::Groq, + model, + vec![ + Message { + role: MessageRole::System, + content: system_message, + ..Default::default() + }, + Message { + role: MessageRole::User, + content: "Write a poem".to_string(), + ..Default::default() + }, + ], + ); pin_mut!(stream); - let content_delta = Some("content-delta".to_string()); // Iterate over the stream of Server Sent Events while let Some(ssevent) = stream.next().await { - let resp = ssevent?; - - // Only content-delta events contains the actual tokens - // There are also events like: - // - content-start - // - content-end - // - etc.. - if resp.event != content_delta { + let ssevent = ssevent?; + + // Deserialize the event response + let generate_response_stream: CreateChatCompletionStreamResponse = + serde_json::from_str(&ssevent.data)?; + + let choice = generate_response_stream.choices.get(0); + if choice.is_none() { continue; } + let choice = choice.unwrap(); + + if let Some(usage) = generate_response_stream.usage.as_ref() { + // Get the usage metrics from the response + info!("Usage Metrics: {:?}", usage); + // Probably send them over to a metrics service + break; + } - // Deserialize the event response - let generate_response: ResponseContent = serde_json::from_str(&resp.data)?; // Print the token out as it's being sent from the server - print!("{}", generate_response.content); + if let Some(content) = choice.delta.content.as_ref() { + print!("{}", content); + } + + if let Some(finish_reason) = choice.finish_reason.as_ref() { + if finish_reason == "stop" { + info!("Finished generating content"); + break; + } + } } -// ...rest of the main function + + Ok(()) +} ``` ### Tool-Use @@ -232,21 +260,36 @@ You can pass to the generate_content function also tools, which will be availabl ```rust use inference_gateway_sdk::{ - GatewayError, - InferenceGatewayAPI, - InferenceGatewayClient, - Message, - Provider, - MessageRole, - Tool, - ToolFunction, - ToolType + FunctionObject, GatewayError, InferenceGatewayAPI, InferenceGatewayClient, Message, + MessageRole, Provider, Tool, ToolType, }; +use log::{info, warn}; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; +use std::env; + +#[tokio::main] +async fn main() -> Result<(), GatewayError> { + // Configure logging + if env::var("RUST_LOG").is_err() { + env::set_var("RUST_LOG", "info"); + } + env_logger::init(); + + // API endpoint - store as a variable so we can reuse it + let api_endpoint = 
"http://localhost:8080/v1"; + + // Initialize the API client + let client = InferenceGatewayClient::new(api_endpoint); -let tools = vec![ - Tool { + // Define the model and provider + let provider = Provider::Groq; + let model = "deepseek-r1-distill-llama-70b"; + + // Define the weather tool + let tools = vec![Tool { r#type: ToolType::Function, - function: ToolFunction { + function: FunctionObject { name: "get_current_weather".to_string(), description: "Get the weather for a location".to_string(), parameters: json!({ @@ -260,32 +303,127 @@ let tools = vec![ "required": ["location"] }), }, - }, -]; -let resp = client.with_tools(Some(tools)).generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", vec![ -Message { - role: MessageRole::System, - content: "You are an helpful assistent.".to_string(), - ..Default::default() -}, -Message { - role: MessageRole::User, - content: "What is the current weather in Berlin?".to_string(), - ..Default::default() -} -]).await?; + }]; -for tool_call in resp.response.tool_calls { - log::info!("Tool Call Requested by the LLM: {:?}", tool_call); - // Make the function call with the parameters requested by the LLM + // Create initial conversation + let initial_messages = vec![ + Message { + role: MessageRole::System, + content: "You are a helpful assistant that can check the weather.".to_string(), + ..Default::default() + }, + Message { + role: MessageRole::User, + content: "What is the current weather in Berlin?".to_string(), + ..Default::default() + }, + ]; - let message = Message { - role: MessageRole::Tool, - content: "The content from the tool".to_string(), - tool_call_id: Some(tool_call.id) // the tool call id so the LLM can reference it + // Make the initial API request + info!("Sending initial request to model"); + let response = client + .with_tools(Some(tools.clone())) + .generate_content(provider, model, initial_messages) + .await?; + + info!("Received response from model"); + + // Check if we have a response + let choice = match response.choices.get(0) { + Some(choice) => choice, + None => { + warn!("No choice returned"); + return Ok(()); + } }; - // Append this message to the next request + // Check for tool calls in the response + if let Some(tool_calls) = &choice.message.tool_calls { + // Create a new conversation starting with the initial messages + let mut follow_up_convo = vec![ + Message { + role: MessageRole::System, + content: "You are a helpful assistant that can check the weather.".to_string(), + ..Default::default() + }, + Message { + role: MessageRole::User, + content: "What is the current weather in Berlin?".to_string(), + ..Default::default() + }, + Message { + role: MessageRole::Assistant, + content: choice.message.content.clone(), + tool_calls: choice.message.tool_calls.clone(), + ..Default::default() + }, + ]; + + // Process each tool call + for tool_call in tool_calls { + info!("Tool Call Requested: {}", tool_call.function.name); + + if tool_call.function.name == "get_current_weather" { + // Parse arguments + let args = tool_call.function.parse_arguments()?; + + // Call our function + let weather_result = get_current_weather(args)?; + + // Add the tool response to the conversation + follow_up_convo.push(Message { + role: MessageRole::Tool, + content: weather_result, + tool_call_id: Some(tool_call.id.clone()), + ..Default::default() + }); + } + } + + // Send the follow-up request with the tool results + info!("Sending follow-up request with tool results"); + + // Create a new client for the follow-up request + let 
follow_up_client = InferenceGatewayClient::new(api_endpoint); + + let follow_up_response = follow_up_client + .with_tools(Some(tools)) + .generate_content(provider, model, follow_up_convo) + .await?; + + if let Some(choice) = follow_up_response.choices.get(0) { + info!("Final response: {}", choice.message.content); + } else { + warn!("No response in follow-up"); + } + } else { + info!("No tool calls in the response"); + info!("Model response: {}", choice.message.content); + } + + Ok(()) +} + +#[derive(Debug, Deserialize, Serialize)] +struct Weather { + location: String, +} + +fn get_current_weather(args: Value) -> Result { + // Parse the location from the arguments + let weather: Weather = serde_json::from_value(args)?; + info!( + "Getting weather function was called for {}", + weather.location + ); + + // In a real application, we would call an actual weather API here + // For this example, we'll just return a mock response + let location = weather.location; + Ok(format!( + "The weather in {} is currently sunny with a temperature of 22°C", + location + )) } ``` diff --git a/openapi.yaml b/openapi.yaml index 832f941..b7c4528 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -3,75 +3,184 @@ openapi: 3.1.0 info: title: Inference Gateway API description: | - API for interacting with various language models through the Inference Gateway. + The API for interacting with various language models and other AI services. + OpenAI, Groq, Ollama, and other providers are supported. + OpenAI compatible API for using with existing clients. + Unified API for all providers. version: 1.0.0 + license: + name: MIT + url: https://github.com/inference-gateway/inference-gateway/blob/main/LICENSE servers: - url: http://localhost:8080 +tags: + - name: Models + description: List and describe the various models available in the API. + - name: Completions + description: Generate completions from the models. + - name: Proxy + description: Proxy requests to provider endpoints. + - name: Health + description: Health check paths: - /llms: + /v1/models: get: - summary: List all language models operationId: listModels + tags: + - Models + summary: + Lists the currently available models, and provides basic information + about each one such as the owner and availability. 
security: - bearerAuth: [] - responses: - "200": - description: A list of models by provider - content: - application/json: - schema: - type: array - items: - $ref: "#/components/schemas/ListModelsResponse" - "401": - $ref: "#/components/responses/Unauthorized" - /llms/{provider}: - get: - summary: List all models for a specific provider - operationId: listModelsByProvider parameters: - name: provider - in: path - required: true + in: query + required: false schema: $ref: "#/components/schemas/Providers" - security: - - bearerAuth: [] + description: Specific provider to query (optional) responses: "200": - description: A list of models + description: List of available models content: application/json: schema: $ref: "#/components/schemas/ListModelsResponse" - "400": - $ref: "#/components/responses/BadRequest" + examples: + allProviders: + summary: Models from all providers + value: + object: "list" + data: + - id: "gpt-4o" + object: "model" + created: 1686935002 + owned_by: "openai" + - id: "llama-3.3-70b-versatile" + object: "model" + created: 1723651281 + owned_by: "groq" + - id: "claude-3-opus-20240229" + object: "model" + created: 1708905600 + owned_by: "anthropic" + - id: "command-r" + object: "model" + created: 1707868800 + owned_by: "cohere" + - id: "phi3:3.8b" + object: "model" + created: 1718441600 + owned_by: "ollama" + singleProvider: + summary: Models from a specific provider + value: + object: "list" + data: + - id: "gpt-4o" + object: "model" + created: 1686935002 + owned_by: "openai" + - id: "gpt-4-turbo" + object: "model" + created: 1687882410 + owned_by: "openai" + - id: "gpt-3.5-turbo" + object: "model" + created: 1677649963 + owned_by: "openai" "401": $ref: "#/components/responses/Unauthorized" - /llms/{provider}/generate: + "500": + $ref: "#/components/responses/InternalError" + /v1/chat/completions: post: - summary: Generate content with a specific provider's LLM - operationId: generateContent + summary: Create a chat completion + description: Creates a completion for the chat message with the specified provider + tags: + - Completions + security: + - bearerAuth: [] parameters: - name: provider - in: path - required: true + in: query + required: false schema: $ref: "#/components/schemas/Providers" - security: - - bearerAuth: [] + description: Specific provider to use (default determined by model) requestBody: + required: true content: application/json: schema: - $ref: "#/components/schemas/GenerateRequest" + type: object + required: + - model + - messages + properties: + model: + type: string + description: Model ID to use + messages: + type: array + items: + $ref: "#/components/schemas/Message" + temperature: + type: number + format: float + default: 0.7 + stream: + type: boolean + default: false + tools: + type: array + items: + type: object + max_tokens: + type: integer responses: "200": - description: Generated content + description: Successful response content: application/json: schema: - $ref: "#/components/schemas/GenerateResponse" + type: object + properties: + id: + type: string + object: + type: string + example: "chat.completion" + created: + type: integer + format: int64 + model: + type: string + choices: + type: array + items: + type: object + properties: + index: + type: integer + message: + $ref: "#/components/schemas/Message" + finish_reason: + type: string + enum: [stop, length, tool_calls, content_filter] + usage: + type: object + properties: + prompt_tokens: + type: integer + completion_tokens: + type: integer + total_tokens: + type: integer + 
text/event-stream: + schema: + type: string "400": $ref: "#/components/responses/BadRequest" "401": @@ -96,6 +205,8 @@ paths: get: summary: Proxy GET request to provider operationId: proxyGet + tags: + - Proxy responses: "200": $ref: "#/components/responses/ProviderResponse" @@ -110,6 +221,8 @@ paths: post: summary: Proxy POST request to provider operationId: proxyPost + tags: + - Proxy requestBody: $ref: "#/components/requestBodies/ProviderRequest" responses: @@ -126,6 +239,8 @@ paths: put: summary: Proxy PUT request to provider operationId: proxyPut + tags: + - Proxy requestBody: $ref: "#/components/requestBodies/ProviderRequest" responses: @@ -142,6 +257,8 @@ paths: delete: summary: Proxy DELETE request to provider operationId: proxyDelete + tags: + - Proxy responses: "200": $ref: "#/components/responses/ProviderResponse" @@ -156,6 +273,8 @@ paths: patch: summary: Proxy PATCH request to provider operationId: proxyPatch + tags: + - Proxy requestBody: $ref: "#/components/requestBodies/ProviderRequest" responses: @@ -171,7 +290,10 @@ paths: - bearerAuth: [] /health: get: + operationId: healthCheck summary: Health check + tags: + - Health responses: "200": description: Health check successful @@ -202,23 +324,32 @@ components: type: number format: float64 default: 0.7 - examples: - - openai: - summary: OpenAI chat completion request - value: - model: "gpt-3.5-turbo" - messages: - - role: "user" - content: "Hello! How can I assist you today?" - temperature: 0.7 - - anthropic: - summary: Anthropic Claude request - value: - model: "claude-3-opus-20240229" - messages: - - role: "user" - content: "Explain quantum computing" - temperature: 0.5 + examples: + openai: + summary: OpenAI chat completion request + value: + model: "gpt-3.5-turbo" + messages: + - role: "user" + content: "Hello! How can I assist you today?" + temperature: 0.7 + anthropic: + summary: Anthropic Claude request + value: + model: "claude-3-opus-20240229" + messages: + - role: "user" + content: "Explain quantum computing" + temperature: 0.5 + CreateChatCompletionRequest: + required: true + description: | + ProviderRequest depends on the specific provider and endpoint being called + If you decide to use this approach, please follow the provider-specific documentations. + content: + application/json: + schema: + $ref: "#/components/schemas/CreateChatCompletionRequest" responses: BadRequest: description: Bad request @@ -278,6 +409,13 @@ components: To enable authentication, set ENABLE_AUTH to true. When enabled, requests must include a valid JWT token in the Authorization header. 
schemas: + Endpoints: + type: object + properties: + models: + type: string + chat: + type: string Providers: type: string enum: @@ -287,31 +425,119 @@ components: - cloudflare - cohere - anthropic + x-provider-configs: + ollama: + id: "ollama" + url: "http://ollama:8080/v1" + auth_type: "none" + endpoints: + models: + name: "list_models" + method: "GET" + endpoint: "/models" + chat: + name: "chat_completions" + method: "POST" + endpoint: "/chat/completions" + anthropic: + id: "anthropic" + url: "https://api.anthropic.com/v1" + auth_type: "bearer" + endpoints: + models: + name: "list_models" + method: "GET" + endpoint: "/models" + chat: + name: "chat_completions" + method: "POST" + endpoint: "/chat/completions" + cohere: + id: "cohere" + url: "https://api.cohere.ai" + auth_type: "bearer" + endpoints: + models: + name: "list_models" + method: "GET" + endpoint: "/v1/models" + chat: + name: "chat_completions" + method: "POST" + endpoint: "/compatibility/v1/chat/completions" + groq: + id: "groq" + url: "https://api.groq.com/openai/v1" + auth_type: "bearer" + endpoints: + models: + name: "list_models" + method: "GET" + endpoint: "/models" + chat: + name: "chat_completions" + method: "POST" + endpoint: "/chat/completions" + openai: + id: "openai" + url: "https://api.openai.com/v1" + auth_type: "bearer" + endpoints: + models: + name: "list_models" + method: "GET" + endpoint: "/models" + chat: + name: "chat_completions" + method: "POST" + endpoint: "/chat/completions" + cloudflare: + id: "cloudflare" + url: "https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}/ai" + auth_type: "bearer" + endpoints: + models: + name: "list_models" + method: "GET" + endpoint: "/finetunes/public?limit=1000" + chat: + name: "chat_completions" + method: "POST" + endpoint: "/v1/chat/completions" ProviderSpecificResponse: type: object description: | Provider-specific response format. 
Examples: - OpenAI GET /v1/models response: + OpenAI GET /v1/models?provider=openai response: ```json { + "provider": "openai", + "object": "list", "data": [ { "id": "gpt-4", "object": "model", - "created": 1687882410 + "created": 1687882410, + "owned_by": "openai", + "served_by": "openai" } ] } ``` - Anthropic GET /v1/models response: + Anthropic GET /v1/models?provider=anthropic response: ```json { - "models": [ + "provider": "anthropic", + "object": "list", + "data": [ { - "name": "claude-3-opus-20240229", - "description": "Most capable model for highly complex tasks" + "id": "gpt-4", + "object": "model", + "created": 1687882410, + "owned_by": "openai", + "served_by": "openai" } ] } @@ -346,112 +572,574 @@ components: $ref: "#/components/schemas/MessageRole" content: type: string + tool_calls: + type: array + items: + $ref: "#/components/schemas/ChatCompletionMessageToolCall" + tool_call_id: + type: string + reasoning: + type: string + required: + - role + - content Model: type: object description: Common model information properties: - name: + id: + type: string + object: + type: string + created: + type: integer + format: int64 + owned_by: + type: string + served_by: type: string ListModelsResponse: type: object description: Response structure for listing models properties: provider: - $ref: "#/components/schemas/Providers" - models: + type: string + object: + type: string + data: type: array items: $ref: "#/components/schemas/Model" - Tool: - type: - type: string - name: - type: string - description: - type: string - parameters: - type: object - properties: - name: - type: string - type: - type: string - default: - type: string + default: [] + FunctionObject: + type: object + properties: + description: + type: string + description: + A description of what the function does, used by the model to + choose when and how to call the function. + name: + type: string + description: + The name of the function to be called. Must be a-z, A-Z, 0-9, or + contain underscores and dashes, with a maximum length of 64. + parameters: + $ref: "#/components/schemas/FunctionParameters" + strict: + type: boolean + default: false description: + Whether to enable strict schema adherence when generating the + function call. If set to true, the model will follow the exact + schema defined in the `parameters` field. Only a subset of JSON + Schema is supported when `strict` is `true`. Learn more about + Structured Outputs in the [function calling + guide](docs/guides/function-calling). + required: + - name + ChatCompletionTool: + type: object + properties: + type: + $ref: "#/components/schemas/ChatCompletionToolType" + function: + $ref: "#/components/schemas/FunctionObject" + required: + - type + - function + FunctionParameters: + type: object + description: >- + The parameters the functions accepts, described as a JSON Schema object. + See the [guide](/docs/guides/function-calling) for examples, and the + [JSON Schema + reference](https://json-schema.org/understanding-json-schema/) for + documentation about the format. + + Omitting `parameters` defines a function with an empty parameter list. + properties: + type: + type: string + description: The type of the parameters. Currently, only `object` is supported. + properties: + type: object + description: The properties of the parameters. + additionalProperties: + type: object + description: The schema for the parameter. 
+ additionalProperties: true + required: + type: array + items: type: string - GenerateRequest: + description: The required properties of the parameters. + additionalProperties: + type: boolean + default: false + description: Whether additional properties are allowed. + additionalProperties: true + ChatCompletionToolType: + type: string + description: The type of the tool. Currently, only `function` is supported. + enum: + - function + CompletionUsage: type: object - description: Request structure for token generation + description: Usage statistics for the completion request. + properties: + completion_tokens: + type: integer + default: 0 + format: int64 + description: Number of tokens in the generated completion. + prompt_tokens: + type: integer + default: 0 + format: int64 + description: Number of tokens in the prompt. + total_tokens: + type: integer + default: 0 + format: int64 + description: Total number of tokens used in the request (prompt + completion). required: - - model - - messages + - prompt_tokens + - completion_tokens + - total_tokens + ChatCompletionStreamOptions: + description: > + Options for streaming response. Only set this when you set `stream: + true`. + type: object + properties: + include_usage: + type: boolean + description: > + If set, an additional chunk will be streamed before the `data: + [DONE]` message. The `usage` field on this chunk shows the token + usage statistics for the entire request, and the `choices` field + will always be an empty array. All other chunks will also include a + `usage` field, but with a null value. + default: true + CreateChatCompletionRequest: + type: object properties: model: type: string + description: Model ID to use messages: + description: > + A list of messages comprising the conversation so far. type: array + minItems: 1 items: $ref: "#/components/schemas/Message" + max_tokens: + description: > + An upper bound for the number of tokens that can be generated + for a completion, including visible output tokens and reasoning tokens. + type: integer stream: + description: > + If set to true, the model response data will be streamed to the + client as it is generated using [server-sent + events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format). type: boolean default: false - description: Whether to stream tokens as they are generated in raw json - ssevents: - type: boolean - default: false - description: | - Whether to use Server-Sent Events for token generation. - When enabled, the response will be streamed as SSE with the following event types: - - message-start: Initial message event with assistant role - - stream-start: Stream initialization - - content-start: Content beginning - - content-delta: Content update with new tokens - - content-end: Content completion - - message-end: Message completion - - stream-end: Stream completion - - **Note:** Depending on the provider, some events may not be present. - max_tokens: - type: integer - description: Maximum number of tokens to generate + stream_options: + $ref: "#/components/schemas/ChatCompletionStreamOptions" tools: type: array + description: > + A list of tools the model may call. Currently, only functions + are supported as a tool. Use this to provide a list of functions + the model may generate JSON inputs for. A max of 128 functions + are supported. 
items: - $ref: "#/components/schemas/Tool" - ResponseTokens: + $ref: "#/components/schemas/ChatCompletionTool" + required: + - model + - messages + ChatCompletionMessageToolCallFunction: type: object - description: Token response structure + description: The function that the model called. properties: - role: + name: + type: string + description: The name of the function to call. + arguments: type: string + description: + The arguments to call the function with, as generated by the model + in JSON format. Note that the model does not always generate + valid JSON, and may hallucinate parameters not defined by your + function schema. Validate the arguments in your code before + calling your function. + required: + - name + - arguments + ChatCompletionMessageToolCall: + type: object + properties: + id: + type: string + description: The ID of the tool call. + type: + $ref: "#/components/schemas/ChatCompletionToolType" + function: + $ref: "#/components/schemas/ChatCompletionMessageToolCallFunction" + required: + - id + - type + - function + EventType: + type: string + enum: + - message-start + - stream-start + - content-start + - content-delta + - content-end + - message-end + - stream-end + ChatCompletionChoice: + type: object + properties: + finish_reason: + type: string + description: > + The reason the model stopped generating tokens. This will be + `stop` if the model hit a natural stop point or a provided + stop sequence, + + `length` if the maximum number of tokens specified in the + request was reached, + + `content_filter` if content was omitted due to a flag from our + content filters, + + `tool_calls` if the model called a tool. + enum: + - stop + - length + - tool_calls + - content_filter + - function_call + index: + type: integer + description: The index of the choice in the list of choices. + message: + $ref: "#/components/schemas/Message" + required: + - finish_reason + - index + - message + - logprobs + ChatCompletionStreamChoice: + type: object + required: + - delta + - finish_reason + - index + properties: + delta: + $ref: "#/components/schemas/ChatCompletionStreamResponseDelta" + logprobs: + description: Log probability information for the choice. + type: object + properties: + content: + description: A list of message content tokens with log probability information. + type: array + items: + $ref: "#/components/schemas/ChatCompletionTokenLogprob" + refusal: + description: A list of message refusal tokens with log probability information. + type: array + items: + $ref: "#/components/schemas/ChatCompletionTokenLogprob" + required: + - content + - refusal + finish_reason: + $ref: "#/components/schemas/FinishReason" + index: + type: integer + description: The index of the choice in the list of choices. + CreateChatCompletionResponse: + type: object + description: + Represents a chat completion response returned by model, based on + the provided input. + properties: + id: + type: string + description: A unique identifier for the chat completion. + choices: + type: array + description: + A list of chat completion choices. Can be more than one if `n` is + greater than 1. + items: + $ref: "#/components/schemas/ChatCompletionChoice" + created: + type: integer + description: + The Unix timestamp (in seconds) of when the chat completion was + created. model: type: string + description: The model used for the chat completion. + object: + type: string + description: The object type, which is always `chat.completion`. 
+ x-stainless-const: true + usage: + $ref: "#/components/schemas/CompletionUsage" + required: + - choices + - created + - id + - model + - object + ChatCompletionStreamResponseDelta: + type: object + description: A chat completion delta generated by streamed model responses. + properties: content: type: string + description: The contents of the chunk message. tool_calls: type: array + items: + $ref: "#/components/schemas/ChatCompletionMessageToolCallChunk" + role: + $ref: "#/components/schemas/MessageRole" + refusal: + type: string + description: The refusal message generated by the model. + ChatCompletionMessageToolCallChunk: + type: object + properties: + index: + type: integer + id: + type: string + description: The ID of the tool call. + type: + type: string + description: The type of the tool. Currently, only `function` is supported. + function: + type: object + properties: + name: + type: string + description: The name of the function to call. + arguments: + type: string + description: + The arguments to call the function with, as generated by the model + in JSON format. Note that the model does not always generate + valid JSON, and may hallucinate parameters not defined by your + function schema. Validate the arguments in your code before + calling your function. + required: + - index + ChatCompletionTokenLogprob: + type: object + properties: + token: &a1 + description: The token. + type: string + logprob: &a2 + description: + The log probability of this token, if it is within the top 20 most + likely tokens. Otherwise, the value `-9999.0` is used to signify + that the token is very unlikely. + type: number + bytes: &a3 + description: + A list of integers representing the UTF-8 bytes representation of + the token. Useful in instances where characters are represented by + multiple tokens and their byte representations must be combined to + generate the correct text representation. Can be `null` if there is + no bytes representation for the token. + type: array + items: + type: integer + top_logprobs: + description: + List of the most likely tokens and their log probability, at this + token position. In rare cases, there may be fewer than the number of + requested `top_logprobs` returned. + type: array items: type: object properties: - function: - type: object - properties: - name: - type: string - parameters: - type: object - properties: - arguments: - type: object - GenerateResponse: + token: *a1 + logprob: *a2 + bytes: *a3 + required: + - token + - logprob + - bytes + required: + - token + - logprob + - bytes + - top_logprobs + FinishReason: + type: string + description: > + The reason the model stopped generating tokens. This will be + `stop` if the model hit a natural stop point or a provided + stop sequence, + + `length` if the maximum number of tokens specified in the + request was reached, + + `content_filter` if content was omitted due to a flag from our + content filters, + + `tool_calls` if the model called a tool. + enum: + - stop + - length + - tool_calls + - content_filter + - function_call + CreateChatCompletionStreamResponse: type: object - description: Response structure for token generation + description: | + Represents a streamed chunk of a chat completion response returned + by the model, based on the provided input. properties: - provider: + id: type: string - response: - $ref: "#/components/schemas/ResponseTokens" + description: + A unique identifier for the chat completion. Each chunk has the + same ID. 
+ choices: + type: array + description: > + A list of chat completion choices. Can contain more than one + elements if `n` is greater than 1. Can also be empty for the + + last chunk if you set `stream_options: {"include_usage": true}`. + items: + $ref: "#/components/schemas/ChatCompletionStreamChoice" + created: + type: integer + description: + The Unix timestamp (in seconds) of when the chat completion was + created. Each chunk has the same timestamp. + model: + type: string + description: The model to generate the completion. + system_fingerprint: + type: string + description: > + This fingerprint represents the backend configuration that the model + runs with. + + Can be used in conjunction with the `seed` request parameter to + understand when backend changes have been made that might impact + determinism. + object: + type: string + description: The object type, which is always `chat.completion.chunk`. + usage: + $ref: "#/components/schemas/CompletionUsage" + required: + - choices + - created + - id + - model + - object + CreateCompletionResponse: + type: object + description: > + Represents a completion response from the API. Note: both the streamed + and non-streamed response objects share the same shape (unlike the chat + endpoint). + properties: + id: + type: string + description: A unique identifier for the completion. + choices: + type: array + description: + The list of completion choices the model generated for the input + prompt. + items: + type: object + required: + - finish_reason + - index + - logprobs + - text + properties: + finish_reason: + type: string + description: > + The reason the model stopped generating tokens. This will be + `stop` if the model hit a natural stop point or a provided + stop sequence, + + `length` if the maximum number of tokens specified in the + request was reached, + + or `content_filter` if content was omitted due to a flag from + our content filters. + enum: + - stop + - length + - content_filter + index: + type: integer + logprobs: + type: object + properties: + text_offset: + type: array + items: + type: integer + token_logprobs: + type: array + items: + type: number + tokens: + type: array + items: + type: string + top_logprobs: + type: array + items: + type: object + additionalProperties: + type: number + text: + type: string + created: + type: integer + description: The Unix timestamp (in seconds) of when the completion was created. + model: + type: string + description: The model used for completion. 
+ object: + type: string + description: The object type, which is always "text_completion" + enum: + - text_completion + usage: + $ref: "#/components/schemas/CompletionUsage" + required: + - id + - object + - created + - model + - choices Config: x-config: sections: @@ -567,7 +1255,7 @@ components: - name: anthropic_api_url env: "ANTHROPIC_API_URL" type: string - default: "https://api.anthropic.com" + default: "https://api.anthropic.com/v1" description: "Anthropic API URL" - name: anthropic_api_key env: "ANTHROPIC_API_KEY" @@ -577,7 +1265,7 @@ components: - name: cloudflare_api_url env: "CLOUDFLARE_API_URL" type: string - default: "https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}" + default: "https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}/ai" description: "Cloudflare API URL" - name: cloudflare_api_key env: "CLOUDFLARE_API_KEY" @@ -587,7 +1275,7 @@ components: - name: cohere_api_url env: "COHERE_API_URL" type: string - default: "https://api.cohere.com" + default: "https://api.cohere.ai" description: "Cohere API URL" - name: cohere_api_key env: "COHERE_API_KEY" @@ -597,7 +1285,7 @@ components: - name: groq_api_url env: "GROQ_API_URL" type: string - default: "https://api.groq.com" + default: "https://api.groq.com/openai/v1" description: "Groq API URL" - name: groq_api_key env: "GROQ_API_KEY" @@ -607,7 +1295,7 @@ components: - name: ollama_api_url env: "OLLAMA_API_URL" type: string - default: "http://ollama:8080" + default: "http://ollama:8080/v1" description: "Ollama API URL" - name: ollama_api_key env: "OLLAMA_API_KEY" @@ -617,481 +1305,10 @@ components: - name: openai_api_url env: "OPENAI_API_URL" type: string - default: "https://api.openai.com" + default: "https://api.openai.com/v1" description: "OpenAI API URL" - name: openai_api_key env: "OPENAI_API_KEY" type: string description: "OpenAI API Key" secret: true - x-provider-configs: - ollama: - id: "ollama" - url: "http://ollama:8080" - auth_type: "none" - endpoints: - list: - endpoint: "/api/tags" - method: "GET" - schema: - response: - type: object - properties: - models: - type: array - items: - type: object - properties: - name: - type: string - modified_at: - type: string - size: - type: integer - digest: - type: string - details: - type: object - properties: - format: - type: string - family: - type: string - families: - type: array - items: - type: string - parameter_size: - type: string - generate: - endpoint: "/api/chat" - method: "POST" - schema: - request: - type: object - properties: - model: - type: string - prompt: - type: string - stream: - type: boolean - system: - type: string - temperature: - type: number - format: float64 - default: 0.7 - tools: - type: array - items: - type: object - properties: - function: - type: object - properties: - name: - type: string - parameters: - type: object - properties: - arguments: - type: object - response: - type: object - properties: - provider: - type: string - response: - type: object - properties: - role: - type: string - model: - type: string - content: - type: string - openai: - id: "openai" - url: "https://api.openai.com" - auth_type: "bearer" - endpoints: - list: - endpoint: "/v1/models" - method: "GET" - schema: - response: - type: object - properties: - object: - type: string - data: - type: array - items: - type: object - properties: - id: - type: string - object: - type: string - created: - type: integer - format: int64 - owned_by: - type: string - permission: - type: array - items: - type: object - properties: - id: - type: string - object: - type: 
string - created: - type: integer - format: int64 - allow_create_engine: - type: boolean - allow_sampling: - type: boolean - allow_logprobs: - type: boolean - allow_search_indices: - type: boolean - allow_view: - type: boolean - allow_fine_tuning: - type: boolean - root: - type: string - parent: - type: string - generate: - endpoint: "/v1/chat/completions" - method: "POST" - schema: - request: - type: object - properties: - model: - type: string - messages: - type: array - items: - type: object - properties: - role: - type: string - content: - type: string - temperature: - type: number - format: float64 - default: 0.7 - response: - type: object - properties: - model: - type: string - choices: - type: array - items: - type: object - properties: - message: - type: object - properties: - role: - type: string - content: - type: string - groq: - id: "groq" - url: "https://api.groq.com" - auth_type: "bearer" - endpoints: - list: - endpoint: "/openai/v1/models" - method: "GET" - schema: - response: - type: object - properties: - object: - type: string - data: - type: array - items: - type: object - properties: - id: - type: string - object: - type: string - created: - type: integer - format: int64 - owned_by: - type: string - active: - type: boolean - context_window: - type: integer - public_apps: - type: object - generate: - endpoint: "/openai/v1/chat/completions" - method: "POST" - schema: - request: - type: object - properties: - model: - type: string - messages: - type: array - items: - type: object - properties: - role: - type: string - content: - type: string - temperature: - type: number - format: float64 - default: 0.7 - tools: - type: array - items: - type: object - properties: - function: - type: object - properties: - name: - type: string - parameters: - type: object - properties: - arguments: - type: object - response: - type: object - properties: - model: - type: string - choices: - type: array - items: - type: object - properties: - message: - type: object - properties: - role: - type: string - content: - type: string - cloudflare: - id: "cloudflare" - url: "https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}" - auth_type: "bearer" - endpoints: - list: - endpoint: "/ai/finetunes/public" - method: "GET" - schema: - response: - type: object - properties: - result: - type: array - items: - type: object - properties: - id: - type: string - name: - type: string - description: - type: string - created_at: - type: string - modified_at: - type: string - public: - type: integer - model: - type: string - generate: - endpoint: "/v1/chat/completions" - method: "POST" - schema: - request: - type: object - properties: - prompt: - type: string - model: - type: string - temperature: - type: number - format: float64 - default: 0.7 - response: - type: object - properties: - result: - type: object - properties: - response: - type: string - cohere: - id: "cohere" - url: "https://api.cohere.com" - auth_type: "bearer" - endpoints: - list: - endpoint: "/v1/models" - method: "GET" - schema: - response: - type: object - properties: - models: - type: array - items: - type: object - properties: - name: - type: string - endpoints: - type: array - items: - type: string - finetuned: - type: boolean - context_length: - type: number - format: float64 - tokenizer_url: - type: string - default_endpoints: - type: array - items: - type: string - next_page_token: - type: string - generate: - endpoint: "/v2/chat" - method: "POST" - schema: - request: - type: object - properties: - model: - type: string - 
messages: - type: array - items: - type: object - properties: - role: - type: string - content: - type: string - temperature: - type: number - format: float64 - default: 0.7 - response: - type: object - properties: - message: - type: object - properties: - role: - type: string - content: - type: array - items: - type: object - properties: - type: - type: string - text: - type: string - anthropic: - id: "anthropic" - url: "https://api.anthropic.com" - auth_type: "xheader" - extra_headers: - anthropic-version: "2023-06-01" - endpoints: - list: - endpoint: "/v1/models" - method: "GET" - schema: - response: - type: object - properties: - models: - type: array - items: - type: object - properties: - type: - type: string - id: - type: string - display_name: - type: string - created_at: - type: string - has_more: - type: boolean - first_id: - type: string - last_id: - type: string - generate: - endpoint: "/v1/messages" - method: "POST" - schema: - request: - type: object - properties: - model: - type: string - messages: - type: array - items: - type: object - properties: - role: - type: string - content: - type: string - temperature: - type: number - format: float64 - default: 0.7 - response: - type: object - properties: - model: - type: string - choices: - type: array - items: - type: object - properties: - message: - type: object - properties: - role: - type: string - content: - type: string diff --git a/src/lib.rs b/src/lib.rs index 6a56b6a..8fb2aff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -57,24 +57,35 @@ struct ErrorResponse { error: String, } -/// Represents a model available through a provider -#[derive(Debug, Serialize, Deserialize)] +/// Common model information +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct Model { - /// Unique identifier of the model - pub name: String, + /// The model identifier + pub id: String, + /// The object type, usually "model" + pub object: Option, + /// The Unix timestamp (in seconds) of when the model was created + pub created: Option, + /// The organization that owns the model + pub owned_by: Option, + /// The provider that serves the model + pub served_by: Option, } -/// Collection of models available from a specific provider +/// Response structure for listing models #[derive(Debug, Serialize, Deserialize)] -pub struct ProviderModels { - /// The LLM provider - pub provider: Provider, +pub struct ListModelsResponse { + /// The provider identifier + #[serde(skip_serializing_if = "Option::is_none")] + pub provider: Option, + /// Response object type, usually "list" + pub object: String, /// List of available models - pub models: Vec, + pub data: Vec, } /// Supported LLM providers -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Copy)] #[serde(rename_all = "lowercase")] pub enum Provider { #[serde(alias = "Ollama", alias = "OLLAMA")] @@ -83,8 +94,6 @@ pub enum Provider { Groq, #[serde(alias = "OpenAI", alias = "OPENAI")] OpenAI, - #[serde(alias = "Google", alias = "GOOGLE")] - Google, #[serde(alias = "Cloudflare", alias = "CLOUDFLARE")] Cloudflare, #[serde(alias = "Cohere", alias = "COHERE")] @@ -99,7 +108,6 @@ impl fmt::Display for Provider { Provider::Ollama => write!(f, "ollama"), Provider::Groq => write!(f, "groq"), Provider::OpenAI => write!(f, "openai"), - Provider::Google => write!(f, "google"), Provider::Cloudflare => write!(f, "cloudflare"), Provider::Cohere => write!(f, "cohere"), Provider::Anthropic => write!(f, "anthropic"), @@ -115,7 +123,6 @@ impl TryFrom<&str> 
for Provider { "ollama" => Ok(Self::Ollama), "groq" => Ok(Self::Groq), "openai" => Ok(Self::OpenAI), - "google" => Ok(Self::Google), "cloudflare" => Ok(Self::Cloudflare), "cohere" => Ok(Self::Cohere), "anthropic" => Ok(Self::Anthropic), @@ -148,24 +155,60 @@ impl fmt::Display for MessageRole { /// A message in a conversation with an LLM #[derive(Debug, Serialize, Deserialize, Clone, Default)] pub struct Message { - /// Role of the message sender ("system", "user" or "assistant") + /// Role of the message sender ("system", "user", "assistant" or "tool") pub role: MessageRole, /// Content of the message pub content: String, + /// The tools an LLM wants to invoke + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, /// Unique identifier of the tool call #[serde(skip_serializing_if = "Option::is_none")] pub tool_call_id: Option, + /// Reasoning behind the message + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning: Option, } -/// Tool to use for generation -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct ToolCall { - pub function: ToolFunction, +/// A tool call in a message response +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct ChatCompletionMessageToolCall { + /// Unique identifier of the tool call + pub id: String, + /// Type of the tool being called + #[serde(rename = "type")] + pub r#type: ChatCompletionToolType, + /// Function that was called + pub function: ChatCompletionMessageToolCallFunction, +} + +/// Type of tool that can be called +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +pub enum ChatCompletionToolType { + /// Function tool type + #[serde(rename = "function")] + Function, +} + +/// Function details in a tool call +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct ChatCompletionMessageToolCallFunction { + /// Name of the function to call + pub name: String, + /// Arguments to the function in JSON string format + pub arguments: String, +} + +// Add this helper method to make argument access more convenient +impl ChatCompletionMessageToolCallFunction { + pub fn parse_arguments(&self) -> Result { + serde_json::from_str(&self.arguments) + } } /// Tool function to call #[derive(Debug, Serialize, Deserialize, Clone)] -pub struct ToolFunction { +pub struct FunctionObject { pub name: String, pub description: String, pub parameters: Value, @@ -182,18 +225,16 @@ pub enum ToolType { #[derive(Debug, Serialize, Deserialize, Clone)] pub struct Tool { pub r#type: ToolType, - pub function: ToolFunction, + pub function: FunctionObject, } /// Request payload for generating content #[derive(Debug, Serialize)] -struct GenerateRequest { +struct CreateChatCompletionRequest { /// Name of the model model: String, /// Conversation history and prompt messages: Vec, - /// Enable Server-Sent Events (SSE) streaming - ssevents: bool, /// Enable streaming of responses stream: bool, /// Optional tools to use for generation @@ -204,46 +245,91 @@ struct GenerateRequest { max_tokens: Option, } -/// Function details in a tool call response -#[derive(Debug, Deserialize, Clone)] -pub struct ToolFunctionResponse { - /// Name of the function that the LLM wants to call - pub name: String, - /// The arguments that the LLM wants to pass to the function - pub arguments: Value, -} - /// A tool call in the response #[derive(Debug, Deserialize, Clone)] pub struct ToolCallResponse { /// Unique identifier of the tool call pub id: String, /// Type of tool that was called + #[serde(rename = "type")] pub r#type: ToolType, /// Function 
that the LLM wants to call - pub function: ToolFunctionResponse, + pub function: ChatCompletionMessageToolCallFunction, } -/// The content of the response #[derive(Debug, Deserialize, Clone)] -pub struct ResponseContent { - /// Role of the responder (typically "assistant") - pub role: MessageRole, - /// Model that generated the response +pub struct ChatCompletionChoice { + pub finish_reason: String, + pub message: Message, + pub index: i32, +} + +/// The response from generating content +#[derive(Debug, Deserialize, Clone)] +pub struct CreateChatCompletionResponse { + pub id: String, + pub choices: Vec, + pub created: i64, pub model: String, - /// Generated content - pub content: String, - /// Tool calls made by the model + pub object: String, +} + +/// The response from streaming content generation +#[derive(Debug, Deserialize, Clone)] +pub struct CreateChatCompletionStreamResponse { + /// A unique identifier for the chat completion. Each chunk has the same ID. + pub id: String, + /// A list of chat completion choices. Can contain more than one element if `n` is greater than 1. + pub choices: Vec, + /// The Unix timestamp (in seconds) of when the chat completion was created. + pub created: i64, + /// The model used to generate the completion. + pub model: String, + /// This fingerprint represents the backend configuration that the model runs with. + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option, + /// The object type, which is always "chat.completion.chunk". + pub object: String, + /// Usage statistics for the completion request. + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, +} + +/// Choice in a streaming completion response +#[derive(Debug, Deserialize, Clone)] +pub struct ChatCompletionStreamChoice { + /// The delta content for this streaming chunk + pub delta: ChatCompletionStreamDelta, + /// Index of the choice in the choices array + pub index: i32, + /// The reason the model stopped generating tokens + #[serde(skip_serializing_if = "Option::is_none")] + pub finish_reason: Option, +} + +/// Delta content for streaming responses +#[derive(Debug, Deserialize, Clone)] +pub struct ChatCompletionStreamDelta { + /// Role of the message sender + #[serde(skip_serializing_if = "Option::is_none")] + pub role: Option, + /// Content of the message delta + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + /// Tool calls for this delta + #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, } -/// The response from generating content +/// Usage statistics for the completion request #[derive(Debug, Deserialize, Clone)] -pub struct GenerateResponse { - /// Provider that generated the response - pub provider: Provider, - /// Content of the response - pub response: ResponseContent, +pub struct CompletionUsage { + /// Number of tokens in the generated completion + pub completion_tokens: i64, + /// Number of tokens in the prompt + pub prompt_tokens: i64, + /// Total number of tokens used in the request (prompt + completion) + pub total_tokens: i64, } /// Client for interacting with the Inference Gateway API @@ -277,8 +363,7 @@ pub trait InferenceGatewayAPI { /// /// # Returns /// A list of models available from all providers - fn list_models(&self) - -> impl Future, GatewayError>> + Send; + fn list_models(&self) -> impl Future> + Send; /// Lists available models by a specific provider /// @@ -296,7 +381,7 @@ pub trait InferenceGatewayAPI { fn list_models_by_provider( &self, provider: Provider, - 
) -> impl Future> + Send; + ) -> impl Future> + Send; /// Generates content using a specified model /// @@ -319,7 +404,7 @@ pub trait InferenceGatewayAPI { provider: Provider, model: &str, messages: Vec, - ) -> impl Future> + Send; + ) -> impl Future> + Send; /// Stream content generation directly using the backend SSE stream. /// @@ -356,6 +441,26 @@ impl InferenceGatewayClient { } } + /// Creates a new client instance with default configuration + /// pointing to http://localhost:8080/v1 + pub fn new_default() -> Self { + let base_url = std::env::var("INFERENCE_GATEWAY_URL") + .unwrap_or_else(|_| "http://localhost:8080/v1".to_string()); + + Self { + base_url, + client: Client::new(), + token: None, + tools: None, + max_tokens: None, + } + } + + /// Returns the base URL this client is configured to use + pub fn base_url(&self) -> &str { + &self.base_url + } + /// Provides tools to use for generation /// /// # Arguments @@ -394,8 +499,8 @@ impl InferenceGatewayClient { } impl InferenceGatewayAPI for InferenceGatewayClient { - async fn list_models(&self) -> Result, GatewayError> { - let url = format!("{}/llms", self.base_url); + async fn list_models(&self) -> Result { + let url = format!("{}/models", self.base_url); let mut request = self.client.get(&url); if let Some(token) = &self.token { request = request.bearer_auth(token); @@ -403,7 +508,10 @@ impl InferenceGatewayAPI for InferenceGatewayClient { let response = request.send().await?; match response.status() { - StatusCode::OK => Ok(response.json().await?), + StatusCode::OK => { + let json_response: ListModelsResponse = response.json().await?; + Ok(json_response) + } StatusCode::UNAUTHORIZED => { let error: ErrorResponse = response.json().await?; Err(GatewayError::Unauthorized(error.error)) @@ -426,8 +534,8 @@ impl InferenceGatewayAPI for InferenceGatewayClient { async fn list_models_by_provider( &self, provider: Provider, - ) -> Result { - let url = format!("{}/llms/{}", self.base_url, provider); + ) -> Result { + let url = format!("{}/models?provider={}", self.base_url, provider); let mut request = self.client.get(&url); if let Some(token) = &self.token { request = self.client.get(&url).bearer_auth(token); @@ -435,7 +543,10 @@ impl InferenceGatewayAPI for InferenceGatewayClient { let response = request.send().await?; match response.status() { - StatusCode::OK => Ok(response.json().await?), + StatusCode::OK => { + let json_response: ListModelsResponse = response.json().await?; + Ok(json_response) + } StatusCode::UNAUTHORIZED => { let error: ErrorResponse = response.json().await?; Err(GatewayError::Unauthorized(error.error)) @@ -460,18 +571,17 @@ impl InferenceGatewayAPI for InferenceGatewayClient { provider: Provider, model: &str, messages: Vec, - ) -> Result { - let url = format!("{}/llms/{}/generate", self.base_url, provider); + ) -> Result { + let url = format!("{}/chat/completions?provider={}", self.base_url, provider); let mut request = self.client.post(&url); if let Some(token) = &self.token { request = request.bearer_auth(token); } - let request_payload = GenerateRequest { + let request_payload = CreateChatCompletionRequest { model: model.to_string(), messages, stream: false, - ssevents: false, tools: self.tools.clone(), max_tokens: self.max_tokens, }; @@ -509,16 +619,15 @@ impl InferenceGatewayAPI for InferenceGatewayClient { let client = self.client.clone(); let base_url = self.base_url.clone(); let url = format!( - "{}/llms/{}/generate", + "{}/chat/completions?provider={}", base_url, provider.to_string().to_lowercase() ); - 
let request = GenerateRequest { + let request = CreateChatCompletionRequest { model: model.to_string(), messages, stream: true, - ssevents: true, tools: None, max_tokens: None, }; @@ -568,8 +677,9 @@ impl InferenceGatewayAPI for InferenceGatewayClient { #[cfg(test)] mod tests { use crate::{ - GatewayError, GenerateRequest, GenerateResponse, InferenceGatewayAPI, - InferenceGatewayClient, Message, MessageRole, Provider, Tool, ToolFunction, ToolType, + CreateChatCompletionRequest, CreateChatCompletionResponse, + CreateChatCompletionStreamResponse, FunctionObject, GatewayError, InferenceGatewayAPI, + InferenceGatewayClient, Message, MessageRole, Provider, Tool, ToolType, }; use futures_util::{pin_mut, StreamExt}; use mockito::{Matcher, Server}; @@ -581,7 +691,6 @@ mod tests { (Provider::Ollama, "ollama"), (Provider::Groq, "groq"), (Provider::OpenAI, "openai"), - (Provider::Google, "google"), (Provider::Cloudflare, "cloudflare"), (Provider::Cohere, "cohere"), (Provider::Anthropic, "anthropic"), @@ -599,7 +708,6 @@ mod tests { ("\"ollama\"", Provider::Ollama), ("\"groq\"", Provider::Groq), ("\"openai\"", Provider::OpenAI), - ("\"google\"", Provider::Google), ("\"cloudflare\"", Provider::Cloudflare), ("\"cohere\"", Provider::Cohere), ("\"anthropic\"", Provider::Anthropic), @@ -617,6 +725,7 @@ mod tests { role: MessageRole::Tool, content: "The weather is sunny".to_string(), tool_call_id: Some("call_123".to_string()), + ..Default::default() }; let serialized = serde_json::to_string(&message_with_tool).unwrap(); @@ -651,7 +760,6 @@ mod tests { (Provider::Ollama, "ollama"), (Provider::Groq, "groq"), (Provider::OpenAI, "openai"), - (Provider::Google, "google"), (Provider::Cloudflare, "cloudflare"), (Provider::Cohere, "cohere"), (Provider::Anthropic, "anthropic"), @@ -664,7 +772,7 @@ mod tests { #[test] fn test_generate_request_serialization() { - let request_payload = GenerateRequest { + let request_payload = CreateChatCompletionRequest { model: "llama3.2:1b".to_string(), messages: vec![ Message { @@ -679,10 +787,9 @@ mod tests { }, ], stream: false, - ssevents: false, tools: Some(vec![Tool { r#type: ToolType::Function, - function: ToolFunction { + function: FunctionObject { name: "get_current_weather".to_string(), description: "Get the current weather of a city".to_string(), parameters: json!({ @@ -714,7 +821,6 @@ mod tests { } ], "stream": false, - "ssevents": false, "tools": [ { "type": "function", @@ -746,29 +852,36 @@ mod tests { async fn test_authentication_header() -> Result<(), GatewayError> { let mut server = Server::new_async().await; + let mock_response = r#"{ + "object": "list", + "data": [] + }"#; + let mock_with_auth = server - .mock("GET", "/llms") + .mock("GET", "/v1/models") .match_header("authorization", "Bearer test-token") .with_status(200) .with_header("content-type", "application/json") - .with_body("[]") + .with_body(mock_response) .expect(1) .create(); - let client = InferenceGatewayClient::new(&server.url()).with_token("test-token"); + let base_url = format!("{}/v1", server.url()); + let client = InferenceGatewayClient::new(&base_url).with_token("test-token"); client.list_models().await?; mock_with_auth.assert(); let mock_without_auth = server - .mock("GET", "/llms") + .mock("GET", "/v1/models") .match_header("authorization", Matcher::Missing) .with_status(200) .with_header("content-type", "application/json") - .with_body("[]") + .with_body(mock_response) .expect(1) .create(); - let client = InferenceGatewayClient::new(&server.url()); + let base_url = format!("{}/v1", 
server.url()); + let client = InferenceGatewayClient::new(&base_url); client.list_models().await?; mock_without_auth.assert(); @@ -784,13 +897,14 @@ mod tests { }"#; let mock = server - .mock("GET", "/llms") + .mock("GET", "/v1/models") .with_status(401) .with_header("content-type", "application/json") .with_body(raw_json_response) .create(); - let client = InferenceGatewayClient::new(&server.url()); + let base_url = format!("{}/v1", server.url()); + let client = InferenceGatewayClient::new(&base_url); let error = client.list_models().await.unwrap_err(); assert!(matches!(error, GatewayError::Unauthorized(_))); @@ -806,27 +920,34 @@ mod tests { async fn test_list_models() -> Result<(), GatewayError> { let mut server = Server::new_async().await; - let raw_response_json = r#"[ - { - "provider": "ollama", - "models": [ - {"name": "llama2"} - ] - } - ]"#; + let raw_response_json = r#"{ + "object": "list", + "data": [ + { + "id": "llama2", + "object": "model", + "created": 1630000001, + "owned_by": "ollama", + "served_by": "ollama" + } + ] + }"#; let mock = server - .mock("GET", "/llms") + .mock("GET", "/v1/models") .with_status(200) .with_header("content-type", "application/json") .with_body(raw_response_json) .create(); - let client = InferenceGatewayClient::new(&server.url()); - let models = client.list_models().await?; + let base_url = format!("{}/v1", server.url()); + let client = InferenceGatewayClient::new(&base_url); + let response = client.list_models().await?; - assert_eq!(models.len(), 1); - assert_eq!(models[0].models[0].name, "llama2"); + assert!(response.provider.is_none()); + assert_eq!(response.object, "list"); + assert_eq!(response.data.len(), 1); + assert_eq!(response.data[0].id, "llama2"); mock.assert(); Ok(()) @@ -838,23 +959,32 @@ mod tests { let raw_json_response = r#"{ "provider":"ollama", - "models": [{ - "name": "llama2" - }] + "object":"list", + "data": [ + { + "id": "llama2", + "object": "model", + "created": 1630000001, + "owned_by": "ollama", + "served_by": "ollama" + } + ] }"#; let mock = server - .mock("GET", "/llms/ollama") + .mock("GET", "/v1/models?provider=ollama") .with_status(200) .with_header("content-type", "application/json") .with_body(raw_json_response) .create(); - let client = InferenceGatewayClient::new(&server.url()); - let models = client.list_models_by_provider(Provider::Ollama).await?; + let base_url = format!("{}/v1", server.url()); + let client = InferenceGatewayClient::new(&base_url); + let response = client.list_models_by_provider(Provider::Ollama).await?; - assert_eq!(models.provider, Provider::Ollama); - assert_eq!(models.models[0].name, "llama2"); + assert!(response.provider.is_some()); + assert_eq!(response.provider, Some(Provider::Ollama)); + assert_eq!(response.data[0].id, "llama2"); mock.assert(); Ok(()) @@ -865,22 +995,32 @@ mod tests { let mut server = Server::new_async().await; let raw_json_response = r#"{ - "provider":"ollama", - "response":{ - "role":"assistant", - "model":"llama2", - "content":"Hellloooo" - } + "id": "chatcmpl-456", + "object": "chat.completion", + "created": 1630000001, + "model": "mixtral-8x7b", + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": "Hellloooo" + } + } + ] }"#; let mock = server - .mock("POST", "/llms/ollama/generate") + .mock("POST", "/v1/chat/completions?provider=ollama") .with_status(200) .with_header("content-type", "application/json") .with_body(raw_json_response) .create(); - let client = InferenceGatewayClient::new(&server.url()); + 
let base_url = format!("{}/v1", server.url()); + let client = InferenceGatewayClient::new(&base_url); + let messages = vec![Message { role: MessageRole::User, content: "Hello".to_string(), @@ -890,10 +1030,8 @@ mod tests { .generate_content(Provider::Ollama, "llama2", messages) .await?; - assert_eq!(response.provider, Provider::Ollama); - assert_eq!(response.response.role, MessageRole::Assistant); - assert_eq!(response.response.model, "llama2"); - assert_eq!(response.response.content, "Hellloooo"); + assert_eq!(response.choices[0].message.role, MessageRole::Assistant); + assert_eq!(response.choices[0].message.content, "Hellloooo"); mock.assert(); Ok(()) @@ -904,24 +1042,33 @@ mod tests { let mut server = Server::new_async().await; let raw_json = r#"{ - "provider": "groq", - "response": { - "role": "assistant", + "id": "chatcmpl-456", + "object": "chat.completion", + "created": 1630000001, "model": "mixtral-8x7b", - "content": "Hello" - } + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": "Hello" + } + } + ] }"#; let mock = server - .mock("POST", "/llms/groq/generate") + .mock("POST", "/v1/chat/completions?provider=groq") .with_status(200) .with_header("content-type", "application/json") .with_body(raw_json) .create(); - let client = InferenceGatewayClient::new(&server.url()); + let base_url = format!("{}/v1", server.url()); + let client = InferenceGatewayClient::new(&base_url); - let direct_parse: Result = serde_json::from_str(raw_json); + let direct_parse: Result = serde_json::from_str(raw_json); assert!( direct_parse.is_ok(), "Direct JSON parse failed: {:?}", @@ -938,10 +1085,8 @@ mod tests { .generate_content(Provider::Groq, "mixtral-8x7b", messages) .await?; - assert_eq!(response.provider, Provider::Groq); - assert_eq!(response.response.role, MessageRole::Assistant); - assert_eq!(response.response.model, "mixtral-8x7b"); - assert_eq!(response.response.content, "Hello"); + assert_eq!(response.choices[0].message.role, MessageRole::Assistant); + assert_eq!(response.choices[0].message.content, "Hello"); mock.assert(); Ok(()) @@ -956,13 +1101,14 @@ mod tests { }"#; let mock = server - .mock("POST", "/llms/groq/generate") + .mock("POST", "/v1/chat/completions?provider=groq") .with_status(400) .with_header("content-type", "application/json") .with_body(raw_json_response) .create(); - let client = InferenceGatewayClient::new(&server.url()); + let base_url = format!("{}/v1", server.url()); + let client = InferenceGatewayClient::new(&base_url); let messages = vec![Message { role: MessageRole::User, content: "Hello".to_string(), @@ -987,13 +1133,14 @@ mod tests { let mut server: mockito::ServerGuard = Server::new_async().await; let unauthorized_mock = server - .mock("GET", "/llms") + .mock("GET", "/v1/models") .with_status(401) .with_header("content-type", "application/json") .with_body(r#"{"error":"Invalid token"}"#) .create(); - let client = InferenceGatewayClient::new(&server.url()); + let base_url = format!("{}/v1", server.url()); + let client = InferenceGatewayClient::new(&base_url); match client.list_models().await { Err(GatewayError::Unauthorized(msg)) => assert_eq!(msg, "Invalid token"), _ => panic!("Expected Unauthorized error"), @@ -1001,7 +1148,7 @@ mod tests { unauthorized_mock.assert(); let bad_request_mock = server - .mock("GET", "/llms") + .mock("GET", "/v1/models") .with_status(400) .with_header("content-type", "application/json") .with_body(r#"{"error":"Invalid provider"}"#) @@ -1014,7 +1161,7 @@ mod tests { 
bad_request_mock.assert(); let internal_error_mock = server - .mock("GET", "/llms") + .mock("GET", "/v1/models") .with_status(500) .with_header("content-type", "application/json") .with_body(r#"{"error":"Internal server error occurred"}"#) @@ -1036,22 +1183,32 @@ mod tests { let mut server = Server::new_async().await; let raw_json = r#"{ - "provider": "Groq", - "response": { - "role": "assistant", - "model": "mixtral-8x7b", - "content": "Hello" - } + "id": "chatcmpl-456", + "object": "chat.completion", + "created": 1630000001, + "model": "mixtral-8x7b", + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": "Hello" + } + } + ] }"#; let mock = server - .mock("POST", "/llms/groq/generate") + .mock("POST", "/v1/chat/completions?provider=groq") .with_status(200) .with_header("content-type", "application/json") .with_body(raw_json) .create(); - let client = InferenceGatewayClient::new(&server.url()); + let base_url = format!("{}/v1", server.url()); + let client = InferenceGatewayClient::new(&base_url); + let messages = vec![Message { role: MessageRole::User, content: "Hello".to_string(), @@ -1062,8 +1219,10 @@ mod tests { .generate_content(Provider::Groq, "mixtral-8x7b", messages) .await?; - assert_eq!(response.provider, Provider::Groq); - assert_eq!(response.response.content, "Hello"); + assert_eq!(response.choices[0].message.role, MessageRole::Assistant); + assert_eq!(response.choices[0].message.content, "Hello"); + assert_eq!(response.model, "mixtral-8x7b"); + assert_eq!(response.object, "chat.completion"); mock.assert(); Ok(()) @@ -1074,13 +1233,15 @@ mod tests { let mut server = Server::new_async().await; let mock = server - .mock("POST", "/llms/groq/generate") + .mock("POST", "/v1/chat/completions?provider=groq") .with_status(200) .with_header("content-type", "text/event-stream") .with_chunked_body(move |writer| -> std::io::Result<()> { let events = vec![ - format!("event: {}\ndata: {}\n\n", r#"message"#, r#"{"provider":"groq","response":{"role":"assistant","model":"mixtral-8x7b","content":"Hello"}}"#), - format!("event: {}\ndata: {}\n\n", r#"message"#, r#"{"provider":"groq","response":{"role":"assistant","model":"mixtral-8x7b","content":"World"}}"#), + format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}"#), + format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268191,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{"role":"assistant","content":" World"},"finish_reason":null}]}"#), + format!("data: {}\n\n", r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268192,"model":"mixtral-8x7b","system_fingerprint":"fp_","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":17,"completion_tokens":40,"total_tokens":57}}"#), + format!("data: [DONE]\n\n") ]; for event in events { writer.write_all(event.as_bytes())?; @@ -1089,7 +1250,9 @@ mod tests { }) .create(); - let client = InferenceGatewayClient::new(&server.url()); + let base_url = format!("{}/v1", server.url()); + let client = InferenceGatewayClient::new(&base_url); + let messages = vec![Message { role: MessageRole::User, content: "Test message".to_string(), @@ -1100,23 +1263,24 @@ mod tests { pin_mut!(stream); while let Some(result) = stream.next().await { let result = 
result?; - let generate_response: GenerateResponse = - serde_json::from_str(&result.data).expect("Failed to parse GenerateResponse"); - - assert_eq!(result.event, Some("message".to_string())); - assert_eq!(generate_response.provider, Provider::Groq); - assert!(matches!( - generate_response.response.role, - MessageRole::Assistant - )); - assert!(matches!( - generate_response.response.model.as_str(), - "mixtral-8x7b" - )); - assert!(matches!( - generate_response.response.content.as_str(), - "Hello" | "World" - )); + let generate_response: CreateChatCompletionStreamResponse = + serde_json::from_str(&result.data) + .expect("Failed to parse CreateChatCompletionResponse"); + + if generate_response.choices[0].finish_reason.is_some() { + assert_eq!( + generate_response.choices[0].finish_reason.as_ref().unwrap(), + "stop" + ); + break; + } + + if let Some(content) = &generate_response.choices[0].delta.content { + assert!(matches!(content.as_str(), "Hello" | " World")); + } + if let Some(role) = &generate_response.choices[0].delta.role { + assert_eq!(role, &MessageRole::Assistant); + } } mock.assert(); @@ -1128,7 +1292,7 @@ mod tests { let mut server = Server::new_async().await; let mock = server - .mock("POST", "/llms/groq/generate") + .mock("POST", "/v1/chat/completions?provider=groq") .with_status(400) .with_header("content-type", "application/json") .with_chunked_body(move |writer| -> std::io::Result<()> { @@ -1144,7 +1308,9 @@ mod tests { .expect_at_least(1) .create(); - let client = InferenceGatewayClient::new(&server.url()); + let base_url = format!("{}/v1", server.url()); + let client = InferenceGatewayClient::new(&base_url); + let messages = vec![Message { role: MessageRole::User, content: "Test message".to_string(), @@ -1171,28 +1337,34 @@ mod tests { let mut server = Server::new_async().await; let raw_json_response = r#"{ - "provider": "groq", - "response": { - "role": "assistant", - "model": "deepseek-r1-distill-llama-70b", - "content": "Let me check the weather for you.", - "tool_calls": [ - { - "id": "1234", - "type": "function", - "function": { - "name": "get_weather", - "arguments": { - "location": "London" + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1630000000, + "model": "deepseek-r1-distill-llama-70b", + "choices": [ + { + "index": 0, + "finish_reason": "tool_calls", + "message": { + "role": "assistant", + "content": "Let me check the weather for you.", + "tool_calls": [ + { + "id": "1234", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"location\": \"London\"}" + } } - } + ] } - ] - } + } + ] }"#; let mock = server - .mock("POST", "/llms/groq/generate") + .mock("POST", "/v1/chat/completions?provider=groq") .with_status(200) .with_header("content-type", "application/json") .with_body(raw_json_response) @@ -1200,7 +1372,7 @@ mod tests { let tools = vec![Tool { r#type: ToolType::Function, - function: ToolFunction { + function: FunctionObject { name: "get_weather".to_string(), description: "Get the weather for a location".to_string(), parameters: json!({ @@ -1216,7 +1388,9 @@ mod tests { }, }]; - let client = InferenceGatewayClient::new(&server.url()).with_tools(Some(tools)); + let base_url = format!("{}/v1", server.url()); + let client = InferenceGatewayClient::new(&base_url).with_tools(Some(tools)); + let messages = vec![Message { role: MessageRole::User, content: "What's the weather in London?".to_string(), @@ -1227,20 +1401,21 @@ mod tests { .generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", messages) 
.await?; - assert_eq!(response.provider, Provider::Groq); - assert_eq!(response.response.role, MessageRole::Assistant); - assert_eq!(response.response.model, "deepseek-r1-distill-llama-70b"); + assert_eq!(response.choices[0].message.role, MessageRole::Assistant); assert_eq!( - response.response.content, + response.choices[0].message.content, "Let me check the weather for you." ); - let tool_calls = response.response.tool_calls.unwrap(); + let tool_calls = response.choices[0].message.tool_calls.as_ref().unwrap(); assert_eq!(tool_calls.len(), 1); assert_eq!(tool_calls[0].function.name, "get_weather"); - let params = &tool_calls[0].function.arguments; - assert_eq!(params["location"], "London"); + let params = tool_calls[0] + .function + .parse_arguments() + .expect("Failed to parse function arguments"); + assert_eq!(params["location"].as_str().unwrap(), "London"); mock.assert(); Ok(()) @@ -1251,22 +1426,32 @@ mod tests { let mut server = Server::new_async().await; let raw_json_response = r#"{ - "provider": "openai", - "response": { - "role": "assistant", - "model": "deepseek-r1-distill-llama-70b", - "content": "Hello!" - } + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1630000000, + "model": "gpt-4", + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": "Hello!" + } + } + ] }"#; let mock = server - .mock("POST", "/llms/openai/generate") + .mock("POST", "/v1/chat/completions?provider=openai") .with_status(200) .with_header("content-type", "application/json") .with_body(raw_json_response) .create(); - let client = InferenceGatewayClient::new(&server.url()); + let base_url = format!("{}/v1", server.url()); + let client = InferenceGatewayClient::new(&base_url); + let messages = vec![Message { role: MessageRole::User, content: "Hi".to_string(), @@ -1277,7 +1462,10 @@ mod tests { .generate_content(Provider::OpenAI, "gpt-4", messages) .await?; - assert!(response.response.tool_calls.is_none()); + assert_eq!(response.model, "gpt-4"); + assert_eq!(response.choices[0].message.content, "Hello!"); + assert_eq!(response.choices[0].message.role, MessageRole::Assistant); + assert!(response.choices[0].message.tool_calls.is_none()); mock.assert(); Ok(()) @@ -1287,27 +1475,6 @@ mod tests { async fn test_generate_content_with_tools_payload() -> Result<(), GatewayError> { let mut server = Server::new_async().await; - let raw_json_response = r#"{ - "provider": "groq", - "response": { - "role": "assistant", - "model": "deepseek-r1-distill-llama-70b", - "content": "Let me check the weather for you", - "tool_calls": [ - { - "id": "1234", - "type": "function", - "function": { - "name": "get_current_weather", - "arguments": { - "city": "Toronto" - } - } - } - ] - } - }"#; - let raw_request_body = r#"{ "model": "deepseek-r1-distill-llama-70b", "messages": [ @@ -1321,7 +1488,6 @@ mod tests { } ], "stream": false, - "ssevents": false, "tools": [ { "type": "function", @@ -1343,8 +1509,35 @@ mod tests { ] }"#; + let raw_json_response = r#"{ + "id": "1234", + "object": "chat.completion", + "created": 1630000000, + "model": "deepseek-r1-distill-llama-70b", + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": "Let me check the weather for you", + "tool_calls": [ + { + "id": "1234", + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": "{\"city\": \"Toronto\"}" + } + } + ] + } + } + ] + }"#; + let mock = server - .mock("POST", "/llms/groq/generate") + 
.mock("POST", "/v1/chat/completions?provider=groq") .with_status(200) .with_header("content-type", "application/json") .match_body(mockito::Matcher::JsonString(raw_request_body.to_string())) @@ -1353,7 +1546,7 @@ mod tests { let tools = vec![Tool { r#type: ToolType::Function, - function: ToolFunction { + function: FunctionObject { name: "get_current_weather".to_string(), description: "Get the current weather of a city".to_string(), parameters: json!({ @@ -1368,7 +1561,9 @@ mod tests { }), }, }]; - let client = InferenceGatewayClient::new(&server.url()); + + let base_url = format!("{}/v1", server.url()); + let client = InferenceGatewayClient::new(&base_url); let messages = vec![ Message { @@ -1388,13 +1583,20 @@ mod tests { .generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", messages) .await?; - assert_eq!(response.response.role, MessageRole::Assistant); - assert_eq!(response.response.model, "deepseek-r1-distill-llama-70b"); + assert_eq!(response.choices[0].message.role, MessageRole::Assistant); assert_eq!( - response.response.content, + response.choices[0].message.content, "Let me check the weather for you" ); - assert_eq!(response.response.tool_calls.unwrap().len(), 1); + assert_eq!( + response.choices[0] + .message + .tool_calls + .as_ref() + .unwrap() + .len(), + 1 + ); mock.assert(); Ok(()) @@ -1405,16 +1607,24 @@ mod tests { let mut server = Server::new_async().await; let raw_json_response = r#"{ - "provider": "groq", - "response": { - "role": "assistant", - "model": "mixtral-8x7b", - "content": "Here's a poem with 100 tokens..." - } + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1630000000, + "model": "mixtral-8x7b", + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": "Here's a poem with 100 tokens..." + } + } + ] }"#; let mock = server - .mock("POST", "/llms/groq/generate") + .mock("POST", "/v1/chat/completions?provider=groq") .with_status(200) .with_header("content-type", "application/json") .match_body(mockito::Matcher::JsonString( @@ -1422,7 +1632,6 @@ mod tests { "model": "mixtral-8x7b", "messages": [{"role":"user","content":"Write a poem"}], "stream": false, - "ssevents": false, "max_tokens": 100 }"# .to_string(), @@ -1430,7 +1639,8 @@ mod tests { .with_body(raw_json_response) .create(); - let client = InferenceGatewayClient::new(&server.url()).with_max_tokens(Some(100)); + let base_url = format!("{}/v1", server.url()); + let client = InferenceGatewayClient::new(&base_url).with_max_tokens(Some(100)); let messages = vec![Message { role: MessageRole::User, @@ -1442,11 +1652,13 @@ mod tests { .generate_content(Provider::Groq, "mixtral-8x7b", messages) .await?; - assert_eq!(response.provider, Provider::Groq); assert_eq!( - response.response.content, + response.choices[0].message.content, "Here's a poem with 100 tokens..." 
); + assert_eq!(response.model, "mixtral-8x7b"); + assert_eq!(response.created, 1630000000); + assert_eq!(response.object, "chat.completion"); mock.assert(); Ok(()) @@ -1465,4 +1677,26 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_client_base_url_configuration() -> Result<(), GatewayError> { + let mut custom_url_server = Server::new_async().await; + + let custom_url_mock = custom_url_server + .mock("GET", "/health") + .with_status(200) + .create(); + + let custom_client = InferenceGatewayClient::new(&custom_url_server.url()); + let is_healthy = custom_client.health_check().await?; + assert!(is_healthy); + custom_url_mock.assert(); + + let default_client = InferenceGatewayClient::new_default(); + + let default_url = "http://localhost:8080/v1"; + assert_eq!(default_client.base_url(), default_url); + + Ok(()) + } }
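For orientation, a minimal usage sketch of the reworked client surface follows (it is not part of the patch itself). The diff moves the SDK from the bespoke `/llms` routes and `GenerateResponse` to OpenAI-compatible endpoints (`GET {base_url}/models`, `POST {base_url}/chat/completions?provider=...`) and response types (`ListModelsResponse`, `CreateChatCompletionResponse`), with tool-call arguments delivered as a JSON string and parsed via `ChatCompletionMessageToolCallFunction::parse_arguments()`. The crate path `inference_gateway_sdk`, the `tokio` runtime, the `serde_json` dependency, and a gateway listening at the default `http://localhost:8080/v1` (or `INFERENCE_GATEWAY_URL`) are assumptions for illustration only.

```rust
// Minimal sketch of the reworked, OpenAI-compatible client surface.
// Assumptions: the crate is imported as `inference_gateway_sdk`, `tokio` and
// `serde_json` are available, and a gateway is reachable at the default
// http://localhost:8080/v1 (or wherever INFERENCE_GATEWAY_URL points).
use inference_gateway_sdk::{
    FunctionObject, GatewayError, InferenceGatewayAPI, InferenceGatewayClient, Message,
    MessageRole, Provider, Tool, ToolType,
};
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), GatewayError> {
    // `new_default()` honours INFERENCE_GATEWAY_URL and falls back to
    // http://localhost:8080/v1; tools are registered up front via the builder.
    let client = InferenceGatewayClient::new_default().with_tools(Some(vec![Tool {
        r#type: ToolType::Function,
        function: FunctionObject {
            name: "get_weather".to_string(),
            description: "Get the weather for a location".to_string(),
            parameters: json!({
                "type": "object",
                "properties": { "location": { "type": "string" } },
                "required": ["location"]
            }),
        },
    }]));

    // GET {base_url}/models now returns a ListModelsResponse (`object` + `data`).
    let models = client.list_models().await?;
    println!("{} models available", models.data.len());

    // POST {base_url}/chat/completions?provider=groq returns an OpenAI-style
    // CreateChatCompletionResponse with a `choices` array.
    let messages = vec![Message {
        role: MessageRole::User,
        content: "What's the weather in London?".to_string(),
        ..Default::default()
    }];
    let response = client
        .generate_content(Provider::Groq, "deepseek-r1-distill-llama-70b", messages)
        .await?;

    let message = &response.choices[0].message;
    println!("assistant: {}", message.content);

    // Tool-call arguments now arrive as a JSON string; `parse_arguments()`
    // deserializes them, since the model may emit invalid JSON.
    if let Some(tool_calls) = &message.tool_calls {
        for call in tool_calls {
            match call.function.parse_arguments() {
                Ok(args) => println!("{} called with {}", call.function.name, args),
                Err(err) => eprintln!("invalid tool arguments: {err}"),
            }
        }
    }

    Ok(())
}
```

Keeping `arguments` as a raw string mirrors the upstream schema, which notes that models may emit invalid JSON or hallucinate parameters; validation therefore stays in caller code, with `parse_arguments()` as the single convenience hook.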