Skip to content

Commit 7f046a5

Browse files
authored
Address security issue of loading arbitrary files as external data (#26776)
### Description Verify that external data references in a TensorProto specify a data location that is under the model directory structure; reject absolute paths and paths that escape the model directory. Make the validation function available to bridge-based EPs. Expose ExternalDataInfo via the bridge to EPs that choose to handle the data themselves. ### Motivation and Context This is a security concern: without this validation a model can cause arbitrary files to be loaded as external data.
1 parent b3780dc commit 7f046a5

File tree

11 files changed

+300
-6
lines changed

11 files changed

+300
-6
lines changed

onnxruntime/core/framework/tensorprotoutils.cc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,27 @@ Status TensorProtoWithExternalDataToTensorProto(
306306
return Status::OK();
307307
}
308308

309+
/// Checks that an external data `location` read from a TensorProto cannot reference
/// a file outside of the model directory.
///
/// @param base_dir  Directory that contains the model. When empty, only the
///                  rooted/absolute path check is performed.
/// @param location  Path taken from the TensorProto external data entry.
/// @return Status::OK() when `location` is acceptable. A failure Status when `location`
///         is absolute or rooted, when a path cannot be resolved, or when the resolved
///         path is not located under `base_dir`.
Status ValidateExternalDataPath(const std::filesystem::path& base_dir,
                                const std::filesystem::path& location) {
  // Reject absolute paths. has_root_path() is used rather than is_absolute() so that the
  // Windows drive-relative ("C:data.bin") and root-relative ("\data.bin") forms are refused
  // as well: they are not "absolute", yet they are not interpreted relative to the model
  // directory either. On POSIX both predicates are equivalent.
  ORT_RETURN_IF(location.has_root_path(),
                "Absolute paths not allowed for external data location");

  if (base_dir.empty()) {
    return Status::OK();
  }

  // The error_code overloads are used so that file system errors are reported through
  // Status instead of escaping as std::filesystem::filesystem_error.
  std::error_code ec;

  // Resolve and verify the path stays within the model directory.
  auto base_canonical = std::filesystem::weakly_canonical(base_dir, ec);
  ORT_RETURN_IF(ec, "Failed to resolve model directory: ", base_dir, ". Error: ", ec.message());

  // A base directory that does not exist and was spelled with a trailing separator keeps an
  // empty last component after normalization. Drop it, otherwise the component-wise prefix
  // comparison below would reject every location.
  if (!base_canonical.has_filename() && base_canonical.has_relative_path()) {
    base_canonical = base_canonical.parent_path();
  }

  // If a symlink exists, it resolves to the target path;
  // so if the symlink points outside the directory it is caught by the prefix check.
  const auto resolved = std::filesystem::weakly_canonical(base_dir / location, ec);
  ORT_RETURN_IF(ec, "Failed to resolve external data path: ", location, ". Error: ", ec.message());

  // The resolved path must start with every component of the base directory.
  const auto base_end = base_canonical.end();
  const auto first_difference = std::mismatch(base_canonical.begin(), base_end,
                                              resolved.begin(), resolved.end())
                                    .first;
  ORT_RETURN_IF(first_difference != base_end,
                "External data path: ", location, " escapes model directory: ", base_dir);

  return Status::OK();
}
329+
309330
Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto,
310331
const std::filesystem::path& tensor_proto_dir,
311332
std::basic_string<ORTCHAR_T>& external_file_path,

onnxruntime/core/framework/tensorprotoutils.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,18 @@ Status TensorProtoWithExternalDataToTensorProto(
521521
const std::filesystem::path& model_path,
522522
ONNX_NAMESPACE::TensorProto& new_tensor_proto);
523523

524+
/// <summary>
525+
/// Makes sure that the 'location' specified in the external data resolves to a path under 'base_dir'.
526+
/// If `base_dir` is empty, the function only ensures that `location` is not an absolute path.
527+
/// </summary>
528+
/// <param name="base_dir">directory that contains the model; may be empty</param>
529+
/// <param name="location">location string retrieved from TensorProto external data that is not
530+
/// an in-memory tag</param>
531+
/// <returns>A failure status if `location` is absolute, or if the resolved full path is not under
532+
/// the model directory or one of its subdirectories; OK otherwise</returns>
533+
Status ValidateExternalDataPath(const std::filesystem::path& base_dir,
534+
const std::filesystem::path& location);
535+
524536
#endif // !defined(SHARED_PROVIDER)
525537

526538
inline bool HasType(const ONNX_NAMESPACE::AttributeProto& at_proto) {

onnxruntime/core/graph/graph.cc

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3724,9 +3724,14 @@ Status Graph::ConvertInitializersIntoOrtValues() {
37243724
std::vector<Graph*> all_subgraphs;
37253725
FindAllSubgraphs(all_subgraphs);
37263726

3727+
const auto& model_path = GetModel().ModelPath();
3728+
PathString model_dir;
3729+
if (!model_path.empty()) {
3730+
ORT_RETURN_IF_ERROR(GetDirNameFromFilePath(model_path, model_dir));
3731+
}
3732+
37273733
auto put_weights_maybe_in_memory_func = [&](Graph& graph) -> Status {
37283734
// if we have any initializers that are not in memory, put them there.
3729-
const auto& model_path = graph.ModelPath();
37303735
auto& graph_proto = *graph.graph_proto_;
37313736
for (int i = 0, lim = graph_proto.initializer_size(); i < lim; ++i) {
37323737
auto& tensor_proto = *graph_proto.mutable_initializer(i);
@@ -3744,9 +3749,18 @@ Status Graph::ConvertInitializersIntoOrtValues() {
37443749
"The model contains initializers with arbitrary in-memory references.",
37453750
"This is an invalid model.");
37463751
}
3752+
} else {
3753+
// Validate external data location
3754+
std::unique_ptr<onnxruntime::ExternalDataInfo> external_data_info;
3755+
ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info));
3756+
const auto& location = external_data_info->GetRelPath();
3757+
auto st = utils::ValidateExternalDataPath(model_dir, location);
3758+
if (!st.IsOK()) {
3759+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
3760+
"External data path validation failed for initializer: ", tensor_proto.name(),
3761+
". Error: ", st.ErrorMessage());
3762+
}
37473763
}
3748-
// ignore data on disk, that will be loaded either by EP or at session_state finalize
3749-
// ignore valid in-memory references
37503764
continue;
37513765
}
37523766

onnxruntime/core/providers/shared_library/provider_api.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,11 @@ inline bool HasExternalDataInMemory(const ONNX_NAMESPACE::TensorProto& ten_proto
441441
return g_host->Utils__HasExternalDataInMemory(ten_proto);
442442
}
443443

444+
// Provider-side entry point for external data path validation.
// The check itself is performed by the host; this function only forwards both
// arguments through the provider bridge and returns the host's Status unchanged.
inline Status ValidateExternalDataPath(const std::filesystem::path& base_dir,
                                       const std::filesystem::path& location) {
  auto& host = *g_host;
  return host.Utils__ValidateExternalDataPath(base_dir, location);
}
448+
444449
} // namespace utils
445450

446451
namespace graph_utils {

onnxruntime/core/providers/shared_library/provider_interfaces.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ struct ProviderHost;
3838
struct ProviderHostCPU;
3939

4040
class ExternalDataInfo;
41+
4142
class PhiloxGenerator;
4243
using ProviderType = const std::string&;
4344
class RandomGenerator;
@@ -999,6 +1000,9 @@ struct ProviderHost {
9991000

10001001
virtual bool Utils__HasExternalDataInMemory(const ONNX_NAMESPACE::TensorProto& ten_proto) = 0;
10011002

1003+
// Bridge entry for utils::ValidateExternalDataPath: validates the external data `location`
// against the model directory `base_path` and returns the resulting Status.
virtual Status Utils__ValidateExternalDataPath(const std::filesystem::path& base_path,
1004+
const std::filesystem::path& location) = 0;
1005+
10021006
// Model
10031007
virtual std::unique_ptr<Model> Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path,
10041008
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
@@ -1136,6 +1140,15 @@ struct ProviderHost {
11361140

11371141
virtual Status GraphUtils__ConvertInMemoryDataToInline(Graph& graph, const std::string& name) = 0;
11381142

1143+
// ExternalDataInfo
// Bridge entries used by the provider-side ExternalDataInfo wrapper: one for destroying an
// instance, one per accessor (relative path, offset, length, checksum), and a factory that
// builds an instance from the external_data entries of a TensorProto.
1144+
virtual void ExternalDataInfo__operator_delete(ExternalDataInfo*) = 0;
1145+
virtual const PathString& ExternalDataInfo__GetRelPath(const ExternalDataInfo*) const = 0;
1146+
virtual int64_t ExternalDataInfo__GetOffset(const ExternalDataInfo*) const = 0;
1147+
virtual size_t ExternalDataInfo__GetLength(const ExternalDataInfo*) const = 0;
1148+
virtual const std::string& ExternalDataInfo__GetChecksum(const ExternalDataInfo*) const = 0;
1149+
virtual Status ExternalDataInfo__Create(const ONNX_NAMESPACE::StringStringEntryProtos& input,
1150+
std::unique_ptr<ExternalDataInfo>& out) = 0;
1151+
11391152
// Initializer
11401153
virtual Initializer* Initializer__constructor(ONNX_NAMESPACE::TensorProto_DataType data_type,
11411154
std::string_view name,

onnxruntime/core/providers/shared_library/provider_wrappedtypes.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,6 +1201,39 @@ struct ConstGraphNodes final {
12011201
PROVIDER_DISALLOW_ALL(ConstGraphNodes)
12021202
};
12031203

1204+
class ExternalDataInfo {
1205+
public:
1206+
static void operator delete(void* p) {
1207+
g_host->ExternalDataInfo__operator_delete(reinterpret_cast<ExternalDataInfo*>(p));
1208+
}
1209+
1210+
const PathString& GetRelPath() const {
1211+
return g_host->ExternalDataInfo__GetRelPath(this);
1212+
}
1213+
1214+
int64_t GetOffset() const {
1215+
return g_host->ExternalDataInfo__GetOffset(this);
1216+
}
1217+
1218+
size_t GetLength() const {
1219+
return g_host->ExternalDataInfo__GetLength(this);
1220+
}
1221+
1222+
const std::string& GetChecksum() const {
1223+
return g_host->ExternalDataInfo__GetChecksum(this);
1224+
}
1225+
1226+
static Status Create(
1227+
const ONNX_NAMESPACE::StringStringEntryProtos& input,
1228+
std::unique_ptr<ExternalDataInfo>& out) {
1229+
return g_host->ExternalDataInfo__Create(input, out);
1230+
}
1231+
1232+
ExternalDataInfo() = delete;
1233+
ExternalDataInfo(const ExternalDataInfo&) = delete;
1234+
ExternalDataInfo& operator=(const ExternalDataInfo& v) = delete;
1235+
};
1236+
12041237
class Initializer {
12051238
public:
12061239
Initializer(ONNX_NAMESPACE::TensorProto_DataType data_type,

onnxruntime/core/session/provider_bridge_ort.cc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "core/framework/run_options.h"
3030
#include "core/framework/sparse_utils.h"
3131
#include "core/framework/tensorprotoutils.h"
32+
#include "core/framework/tensor_external_data_info.h"
3233
#include "core/framework/TensorSeq.h"
3334
#include "core/graph/constants.h"
3435
#include "core/graph/graph_proto_serializer.h"
@@ -1281,6 +1282,11 @@ struct ProviderHostImpl : ProviderHost {
12811282
return onnxruntime::utils::HasExternalDataInMemory(ten_proto);
12821283
}
12831284

1285+
// Host-side implementation of the bridge entry: delegates to
// onnxruntime::utils::ValidateExternalDataPath and returns its Status as is.
Status Utils__ValidateExternalDataPath(const std::filesystem::path& base_path,
                                       const std::filesystem::path& location) override {
  namespace ort_utils = onnxruntime::utils;
  return ort_utils::ValidateExternalDataPath(base_path, location);
}
1289+
12841290
// Model (wrapped)
12851291
std::unique_ptr<Model> Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path,
12861292
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
@@ -1487,6 +1493,25 @@ struct ProviderHostImpl : ProviderHost {
14871493
graph_utils::MakeInitializerCopyIfNotExist(src_graph, dst_graph, name, load_in_memory);
14881494
}
14891495

1496+
// ExternalDataInfo (wrapped)
1497+
void ExternalDataInfo__operator_delete(ExternalDataInfo* p) override { delete p; }
1498+
const PathString& ExternalDataInfo__GetRelPath(const ExternalDataInfo* p) const override {
1499+
return p->GetRelPath();
1500+
}
1501+
int64_t ExternalDataInfo__GetOffset(const ExternalDataInfo* p) const override {
1502+
return narrow<int64_t>(p->GetOffset());
1503+
}
1504+
size_t ExternalDataInfo__GetLength(const ExternalDataInfo* p) const override {
1505+
return p->GetLength();
1506+
}
1507+
const std::string& ExternalDataInfo__GetChecksum(const ExternalDataInfo* p) const override {
1508+
return p->GetChecksum();
1509+
}
1510+
Status ExternalDataInfo__Create(const ONNX_NAMESPACE::StringStringEntryProtos& input,
1511+
std::unique_ptr<ExternalDataInfo>& out) override {
1512+
return ExternalDataInfo::Create(input, out);
1513+
}
1514+
14901515
// Initializer (wrapped)
14911516
Initializer* Initializer__constructor(ONNX_NAMESPACE::TensorProto_DataType data_type,
14921517
std::string_view name,

onnxruntime/test/framework/tensorutils_test.cc

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include <cstdint>
1414
#include <limits>
15+
#include <fstream>
1516

1617
#include "gtest/gtest.h"
1718
#include "gmock/gmock.h"
@@ -502,5 +503,88 @@ TEST(TensorProtoUtilsTest, ConstantTensorProtoWithExternalData) {
502503
TestConstantNodeConversionWithExternalData<float>(TensorProto_DataType_FLOAT);
503504
TestConstantNodeConversionWithExternalData<double>(TensorProto_DataType_DOUBLE);
504505
}
506+
507+
// Test fixture for creating temporary directories and files for path validation tests.
508+
class PathValidationTest : public ::testing::Test {
509+
protected:
510+
void SetUp() override {
511+
// Create a temporary directory for the tests.
512+
base_dir_ = std::filesystem::temp_directory_path() / "PathValidationTest";
513+
outside_dir_ = std::filesystem::temp_directory_path() / "outside";
514+
std::filesystem::create_directories(base_dir_);
515+
std::filesystem::create_directories(outside_dir_);
516+
}
517+
518+
void TearDown() override {
519+
// Clean up the temporary directory.
520+
std::filesystem::remove_all(base_dir_);
521+
std::filesystem::remove_all(outside_dir_);
522+
}
523+
524+
std::filesystem::path base_dir_;
525+
std::filesystem::path outside_dir_;
526+
};
527+
528+
// Test cases for ValidateExternalDataPath.
529+
TEST_F(PathValidationTest, ValidateExternalDataPath) {
530+
// Valid relative path.
531+
ASSERT_STATUS_OK(utils::ValidateExternalDataPath(base_dir_, "data.bin"));
532+
533+
// Empty base directory.
534+
ASSERT_STATUS_OK(utils::ValidateExternalDataPath("", "data.bin"));
535+
536+
// Empty location.
537+
// Only validate it is not an absolute path.
538+
ASSERT_TRUE(utils::ValidateExternalDataPath(base_dir_, "").IsOK());
539+
540+
// Path with ".." that escapes the base directory.
541+
ASSERT_FALSE(utils::ValidateExternalDataPath(base_dir_, "../data.bin").IsOK());
542+
543+
// Absolute path.
544+
#ifdef _WIN32
545+
ASSERT_FALSE(utils::ValidateExternalDataPath(base_dir_, "C:\\data.bin").IsOK());
546+
ASSERT_FALSE(utils::ValidateExternalDataPath("", "C:\\data.bin").IsOK());
547+
#else
548+
ASSERT_FALSE(utils::ValidateExternalDataPath(base_dir_, "/data.bin").IsOK());
549+
ASSERT_FALSE(utils::ValidateExternalDataPath("", "/data.bin").IsOK());
550+
#endif // Absolute path.
551+
552+
// Windows vs Unix path separators.
553+
ASSERT_STATUS_OK(utils::ValidateExternalDataPath(base_dir_, "sub/data.bin"));
554+
ASSERT_STATUS_OK(utils::ValidateExternalDataPath(base_dir_, "sub\\data.bin"));
555+
556+
// Base directory does not exist.
557+
ASSERT_STATUS_OK(utils::ValidateExternalDataPath("non_existent_dir", "data.bin"));
558+
}
559+
560+
// A symlink located in the base directory whose target is also inside the base directory
// must pass validation. The test is skipped when the symlink cannot be created.
TEST_F(PathValidationTest, ValidateExternalDataPathWithSymlinkInside) {
  const auto link_target = base_dir_ / "target.bin";
  const auto link_path = base_dir_ / "link.bin";

  try {
    // Create an empty target file, then the link pointing at it.
    std::ofstream{link_target};
    std::filesystem::create_symlink(link_target, link_path);
  } catch (const std::exception& e) {
    GTEST_SKIP() << "Skipping symlink tests since symlink creation is not supported in this environment. Exception: "
                 << e.what();
  }

  ASSERT_STATUS_OK(utils::ValidateExternalDataPath(base_dir_, "link.bin"));
}
573+
574+
// A symlink located in the base directory whose target is outside of the base directory
// must be rejected. The test is skipped when the symlink cannot be created.
TEST_F(PathValidationTest, ValidateExternalDataPathWithSymlinkOutside) {
  const auto outside_target = outside_dir_ / "outside.bin";
  const auto link_in_base = base_dir_ / "outside_link.bin";

  try {
    // Create an empty file outside of the base directory, then link to it from inside.
    std::ofstream{outside_target};
    std::filesystem::create_symlink(outside_target, link_in_base);
  } catch (const std::exception& e) {
    GTEST_SKIP() << "Skipping symlink tests since symlink creation is not supported in this environment. Exception: " << e.what();
  }

  const auto status = utils::ValidateExternalDataPath(base_dir_, "outside_link.bin");
  ASSERT_FALSE(status.IsOK());
}
588+
505589
} // namespace test
506590
} // namespace onnxruntime

onnxruntime/test/shared_lib/test_inference.cc

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4762,9 +4762,9 @@ TEST(CApiTest, custom_cast) {
47624762
custom_op_domain, nullptr);
47634763
}
47644764

4765-
TEST(CApiTest, ModelWithMaliciousExternalDataShouldFailToLoad) {
4765+
TEST(CApiTest, ModelWithMaliciousExternalDataInMemoryShouldFailToLoad) {
47664766
// Attempt to create an ORT session with the malicious model
4767-
// This should fail due to the invalid external data reference
4767+
// This should fail due to the invalid external in-memory reference
47684768
constexpr const ORTCHAR_T* model_path = TSTR("testdata/test_evil_weights.onnx");
47694769

47704770
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
@@ -4785,7 +4785,7 @@ TEST(CApiTest, ModelWithMaliciousExternalDataShouldFailToLoad) {
47854785
}
47864786

47874787
// Verify that loading the model failed
4788-
EXPECT_TRUE(exception_thrown) << "Expected model loading to fail due to malicious external data";
4788+
EXPECT_TRUE(exception_thrown) << "Expected model loading to fail due to malicious in-memory data";
47894789

47904790
// Verify that the exception message indicates security or external data issues
47914791
EXPECT_TRUE(exception_message.find("in-memory") != std::string::npos ||
@@ -4794,3 +4794,36 @@ TEST(CApiTest, ModelWithMaliciousExternalDataShouldFailToLoad) {
47944794
exception_message.find("model") != std::string::npos)
47954795
<< "Exception message should indicate external data or security issue. Got: " << exception_message;
47964796
}
4797+
4798+
// Creating a session must fail for a model whose initializer references an external file
// that is not under the model directory structure, i.e. ../../../../etc/passwd
TEST(CApiTest, ModelWithExternalDataOutsideModelDirectoryShouldFailToLoad) {
  constexpr const ORTCHAR_T* model_path = TSTR("testdata/test_arbitrary_external_file.onnx");

  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
  Ort::SessionOptions session_options;

  bool exception_thrown = false;
  std::string exception_message;

  try {
    // This should throw an exception due to the malicious external data location.
    Ort::Session session(env, model_path, session_options);
  } catch (const std::exception& e) {
    // Ort::Exception derives from std::exception, so a single handler covers both.
    exception_thrown = true;
    exception_message = e.what();
  }

  // Verify that loading the model failed
  EXPECT_TRUE(exception_thrown) << "Expected model loading to fail due to malicious external data";

  // Verify that the failure comes from the external data path validation. Generic words such
  // as "model" are deliberately not accepted: they would match almost any load error and
  // make this check meaningless.
  EXPECT_TRUE(exception_message.find("escapes model directory") != std::string::npos ||
              exception_message.find("External data path validation failed") != std::string::npos)
      << "Exception message should indicate that the external data path was rejected. Got: "
      << exception_message;
}
504 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)