diff --git a/src/Data/ConnectionHelper.cs b/src/Data/ConnectionHelper.cs index 9ef74a4..b77102e 100644 --- a/src/Data/ConnectionHelper.cs +++ b/src/Data/ConnectionHelper.cs @@ -1,8 +1,6 @@ using System; -using System.Collections.Generic; using System.Data; using System.Data.Common; -using System.Text; namespace Kros.KORM.Data { @@ -45,9 +43,6 @@ protected virtual void Dispose(bool disposing) } } - public void Dispose() - { - Dispose(true); - } + public void Dispose() => Dispose(true); } } diff --git a/src/Data/TransactionHelper.cs b/src/Data/TransactionHelper.cs index 9203044..4fd4ec7 100644 --- a/src/Data/TransactionHelper.cs +++ b/src/Data/TransactionHelper.cs @@ -1,4 +1,5 @@ -using Kros.Utils; +using Kros.KORM.Properties; +using Kros.Utils; using System; using System.Collections.Generic; using System.Data; @@ -13,34 +14,35 @@ namespace Kros.KORM.Data internal class TransactionHelper { public const IsolationLevel DefaultIsolationLevel = IsolationLevel.ReadCommitted; - private const int TIMEOUT_DEFAULT = 30; - - private readonly DbConnection _connection; - private Transaction _topTransaction; - private bool _canCommit = true; - private readonly Stack _transactions = new Stack(); + private const int DefaultCommandTimeout = 30; #region Nested types private class Transaction : ITransaction { - private readonly ConnectionHelper _connectionHelper; + private readonly DbConnection _connection; + private readonly bool _closeConnection; private readonly Lazy _transaction; private readonly TransactionHelper _transactionHelper; private bool _wasCommitOrRollback = false; - public Transaction(TransactionHelper transactionHelper, ConnectionHelper connectionHelper, IsolationLevel isolationLevel) + public Transaction( + TransactionHelper transactionHelper, + DbConnection connection, + bool closeConnection, + IsolationLevel isolationLevel) { - _connectionHelper = connectionHelper; - _transaction = new Lazy(() => connectionHelper.Connection.BeginTransaction(isolationLevel)); _transactionHelper = transactionHelper; + _connection = connection; + _closeConnection = closeConnection; + _transaction = new Lazy(() => _connection.BeginTransaction(isolationLevel)); } public void Commit() { - _wasCommitOrRollback = true; if (_transactionHelper.CanCommitTransaction) { + _wasCommitOrRollback = true; _transaction.Value.Commit(); _transactionHelper.EndTransaction(true); } @@ -53,7 +55,7 @@ public void Rollback() _transactionHelper.EndTransaction(false); } - public int CommandTimeout { get; set; } = TIMEOUT_DEFAULT; + public int CommandTimeout { get; set; } = DefaultCommandTimeout; public static implicit operator DbTransaction(Transaction transaction) => transaction?._transaction.Value; @@ -69,7 +71,10 @@ public void Dispose() { _transaction.Value.Dispose(); } - _connectionHelper.Dispose(); + if (_closeConnection) + { + _connection.Close(); + } } } @@ -87,8 +92,14 @@ public NestedTransaction(TransactionHelper transactionHelper, int timeout) public void Commit() { + _wasCommitOrRollback = true; _transactionHelper.EndTransaction(true); + } + + public void Rollback() + { _wasCommitOrRollback = true; + _transactionHelper.EndTransaction(false); } public void Dispose() @@ -99,31 +110,32 @@ public void Dispose() } } - public void Rollback() - { - _transactionHelper.EndTransaction(false); - _wasCommitOrRollback = true; - } - public int CommandTimeout { - get => _timeout; - set { } + get => DefaultCommandTimeout; + set => throw new InvalidOperationException(Resources.NestedTransactionCommandTimeoutIsReadonly); } } #endregion - public TransactionHelper(DbConnection connection) + private readonly DbConnection _connection; + private readonly bool _closeConnection; + private Transaction _topTransaction; + private bool _canCommit = true; + private readonly Stack _transactions = new Stack(); + + public TransactionHelper(DbConnection connection, bool closeConnection) { _connection = Check.NotNull(connection, nameof(connection)); + _closeConnection = closeConnection; } public ITransaction BeginTransaction(IsolationLevel isolationLevel) { if (_transactions.Count == 0) { - _topTransaction = new Transaction(this, new ConnectionHelper(_connection), isolationLevel); + _topTransaction = new Transaction(this, _connection, _closeConnection, isolationLevel); _transactions.Push(_topTransaction); _canCommit = true; } @@ -131,7 +143,6 @@ public ITransaction BeginTransaction(IsolationLevel isolationLevel) { _transactions.Push(new NestedTransaction(this, _topTransaction.CommandTimeout)); } - return _transactions.Peek(); } @@ -145,7 +156,6 @@ private void EndTransaction(bool success) { _canCommit &= success; _transactions.Pop(); - if (!_transactions.Any()) { _topTransaction = null; @@ -160,7 +170,6 @@ public DbCommand CreateCommand() cmd.Transaction = _topTransaction; cmd.CommandTimeout = _topTransaction.CommandTimeout; } - return cmd; } } diff --git a/src/Database.cs b/src/Database.cs index 93f2302..1c8f66d 100644 --- a/src/Database.cs +++ b/src/Database.cs @@ -60,8 +60,14 @@ public partial class Database : IDatabase /// /// Builder for creating instance. /// + [Obsolete("Use CreateBuilder method.")] public static IDatabaseBuilder Builder => new DatabaseBuilder(); + /// + /// Creates a builder for creating instance. + /// + public static IDatabaseBuilder CreateBuilder() => new DatabaseBuilder(); + #endregion #region Private fields diff --git a/src/IAuthTokenProvider.cs b/src/IAuthTokenProvider.cs new file mode 100644 index 0000000..e1aa50c --- /dev/null +++ b/src/IAuthTokenProvider.cs @@ -0,0 +1,14 @@ +namespace Kros.KORM +{ + /// + /// Support for token-based authentication for SQL Server. + /// + public interface IAuthTokenProvider + { + /// + /// Returns authentication token, or value, if token can not be obtained. + /// + /// Authentication token. + string GetToken(); + } +} diff --git a/src/Query/Providers/QueryProvider.cs b/src/Query/Providers/QueryProvider.cs index 990962b..d4c5074 100644 --- a/src/Query/Providers/QueryProvider.cs +++ b/src/Query/Providers/QueryProvider.cs @@ -16,6 +16,7 @@ using System.Collections.Generic; using System.Data; using System.Data.Common; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -122,7 +123,7 @@ protected QueryProvider( InitSqlExpressionVisitor(Check.NotNull(sqlGeneratorFactory, nameof(sqlGeneratorFactory))); IsExternalConnection = false; _modelBuilder = Check.NotNull(modelBuilder, nameof(modelBuilder)); - _transactionHelper = new Lazy(() => new TransactionHelper(Connection)); + _transactionHelper = new Lazy(TransactionHelperFactory); } /// @@ -147,11 +148,17 @@ protected QueryProvider( InitSqlExpressionVisitor(Check.NotNull(sqlGeneratorFactory, nameof(sqlGeneratorFactory))); IsExternalConnection = true; - _transactionHelper = new Lazy(() => new TransactionHelper(Connection)); + _transactionHelper = new Lazy(TransactionHelperFactory); } private void InitSqlExpressionVisitor(ISqlExpressionVisitorFactory sqlGeneratorFactory) - => _sqlExpressionVisitor = new Lazy(() => sqlGeneratorFactory.CreateVisitor(Connection)); + => _sqlExpressionVisitor = new Lazy(() => sqlGeneratorFactory.CreateVisitor(GetConnection())); + + private TransactionHelper TransactionHelperFactory() + { + DbConnection connection = GetConnection(); + return new TransactionHelper(connection, !connection.State.HasFlag(ConnectionState.Open)); + } #endregion @@ -190,7 +197,7 @@ public void SetParameterDbType(DbParameter parameter, string tableName, string c private TableSchema LoadTableSchema(string tableName) { IDatabaseSchemaLoader schemaLoader = GetSchemaLoader(); - TableSchema tableSchema = schemaLoader.LoadTableSchema(Connection, tableName); + TableSchema tableSchema = schemaLoader.LoadTableSchema(GetConnection(), tableName); return tableSchema ?? throw new InvalidOperationException(string.Format(Resources.QueryProviderCouldNotGetTableSchema, tableName)); } @@ -541,7 +548,7 @@ public IIdGeneratorsForDatabaseInit GetIdGeneratorsForDatabaseInit() private IDbConnection GetConnectionForIdGenerator() { - var connection = (Connection as ICloneable).Clone() as DbConnection; + var connection = (GetConnection() as ICloneable).Clone() as DbConnection; try { connection.Open(); @@ -631,24 +638,27 @@ public TResult Execute(Expression expression) /// Vráti spojenie na databázu s ktorou trieda pracuje. Ak trieda bola vytvorená iba so zadaným /// connection string-om, je vytvorené nové spojenie. /// - protected DbConnection Connection + [Obsolete("Use GetConnection() method.")] + protected DbConnection Connection => GetConnection(); + + /// + /// Returns (creates if needed) connection. + /// + /// instance. + protected virtual DbConnection GetConnection() { - get + if (_connection == null) { - if (_connection == null) - { - _connection = DbProviderFactory.CreateConnection(); - _connection.ConnectionString = _connectionSettings.ConnectionString; - } - return _connection; + _connection = DbProviderFactory.CreateConnection(); + _connection.ConnectionString = _connectionSettings.ConnectionString; } + return _connection; } - private Data.ConnectionHelper OpenConnection() - { - return new Data.ConnectionHelper(Connection); - } + private Data.ConnectionHelper OpenConnection() => new Data.ConnectionHelper(GetConnection()); + [SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", + Justification = "Query is external or generated.")] private DbCommandInfo CreateCommand(IQuery query) { DbCommand command = _transactionHelper.Value.CreateCommand(); @@ -671,6 +681,8 @@ private DbCommandInfo CreateCommand(IQuery query) protected internal void SetQueryFilter(IQuery query, ISqlExpressionVisitor sqlVisitor) => (query as IQueryBaseInternal).ApplyQueryFilter(_databaseMapper, sqlVisitor); + [SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", + Justification = "Query is external or generated.")] private DbCommand CreateCommand(string commandText, CommandParameterCollection parameters) { DbCommand command = _transactionHelper.Value.CreateCommand(); @@ -683,7 +695,6 @@ private DbCommand CreateCommand(string commandText, CommandParameterCollection p AddCommandParameter(command, parameter); } } - return command; } diff --git a/src/Query/Providers/SqlServerQueryProvider.cs b/src/Query/Providers/SqlServerQueryProvider.cs index 2394c90..c3b50d8 100644 --- a/src/Query/Providers/SqlServerQueryProvider.cs +++ b/src/Query/Providers/SqlServerQueryProvider.cs @@ -17,6 +17,8 @@ namespace Kros.KORM.Query /// public class SqlServerQueryProvider : QueryProvider { + private readonly IAuthTokenProvider _tokenProvider; + /// /// Initializes a new instance of the class. /// @@ -25,14 +27,17 @@ public class SqlServerQueryProvider : QueryProvider /// The model builder. /// The logger. /// The Database mapper. + /// Provider to support token-based authentication. public SqlServerQueryProvider( KormConnectionSettings connectionString, ISqlExpressionVisitorFactory sqlGeneratorFactory, IModelBuilder modelBuilder, ILogger logger, - IDatabaseMapper databaseMapper) + IDatabaseMapper databaseMapper, + IAuthTokenProvider tokenProvider) : base(connectionString, sqlGeneratorFactory, modelBuilder, logger, databaseMapper) { + _tokenProvider = tokenProvider; } /// @@ -43,14 +48,17 @@ public SqlServerQueryProvider( /// The model builder. /// The logger. /// The Database mapper. + /// Provider to support token-based authentication. public SqlServerQueryProvider( DbConnection connection, ISqlExpressionVisitorFactory sqlGeneratorFactory, IModelBuilder modelBuilder, ILogger logger, - IDatabaseMapper databaseMapper) + IDatabaseMapper databaseMapper, + IAuthTokenProvider tokenProvider) : base(connection, sqlGeneratorFactory, modelBuilder, logger, databaseMapper) { + _tokenProvider = tokenProvider; } /// @@ -58,6 +66,26 @@ public SqlServerQueryProvider( /// public override DbProviderFactory DbProviderFactory => SqlClientFactory.Instance; + /// + /// Returns (creates if needed) connection. If was setup in constructor, + /// it is used to set the AccessToken on connection. + /// + /// instance. + protected override DbConnection GetConnection() + { + var connection = (SqlConnection)base.GetConnection(); + SetAccessToken(connection); + return connection; + } + + private void SetAccessToken(SqlConnection connection) + { + if (_tokenProvider != null) + { + connection.AccessToken = _tokenProvider.GetToken(); + } + } + /// /// Creates instance of . /// @@ -69,11 +97,11 @@ public override IBulkInsert CreateBulkInsert() var transaction = GetCurrentTransaction(); if (IsExternalConnection || transaction != null) { - return new SqlServerBulkInsert(Connection as SqlConnection, transaction as SqlTransaction); + return new SqlServerBulkInsert(GetConnection() as SqlConnection, transaction as SqlTransaction); } else { - return new SqlServerBulkInsert(ConnectionString); + return new SqlServerBulkInsert(CreateConnection()); } } @@ -89,14 +117,22 @@ public override IBulkUpdate CreateBulkUpdate() if (IsExternalConnection || transaction != null) { - return new SqlServerBulkUpdate(Connection as SqlConnection, transaction as SqlTransaction); + return new SqlServerBulkUpdate(GetConnection() as SqlConnection, transaction as SqlTransaction); } else { - return new SqlServerBulkUpdate(ConnectionString); + return new SqlServerBulkUpdate(CreateConnection()); } } + private SqlConnection CreateConnection() + { + var connection = (SqlConnection)DbProviderFactory.CreateConnection(); + connection.ConnectionString = ConnectionString; + SetAccessToken(connection); + return connection; + } + /// /// Returns instance of . /// diff --git a/src/Query/Providers/SqlServerQueryProviderFactory.cs b/src/Query/Providers/SqlServerQueryProviderFactory.cs index f1d150d..354fff6 100644 --- a/src/Query/Providers/SqlServerQueryProviderFactory.cs +++ b/src/Query/Providers/SqlServerQueryProviderFactory.cs @@ -27,7 +27,8 @@ public IQueryProvider Create(DbConnection connection, IModelBuilder modelBuilder new SqlServerSqlExpressionVisitorFactory(databaseMapper), modelBuilder, new Logger(), - databaseMapper); + databaseMapper, + null); /// /// Creates the SqlServer query provider. @@ -47,7 +48,8 @@ public IQueryProvider Create( new SqlServerSqlExpressionVisitorFactory(databaseMapper), modelBuilder, new Logger(), - databaseMapper); + databaseMapper, + null); /// /// Registers instance of this type to . diff --git a/tests/Kros.KORM.UnitTests/CommandGenerator/CommandGeneratorShould.cs b/tests/Kros.KORM.UnitTests/CommandGenerator/CommandGeneratorShould.cs index 6e5cf66..64eda63 100644 --- a/tests/Kros.KORM.UnitTests/CommandGenerator/CommandGeneratorShould.cs +++ b/tests/Kros.KORM.UnitTests/CommandGenerator/CommandGeneratorShould.cs @@ -284,7 +284,8 @@ private IQuery CreateQuery() new SqlServerSqlExpressionVisitorFactory(new DatabaseMapper(new ConventionModelMapper())), Substitute.For(), new Logger(), - Substitute.For())); + Substitute.For(), + null)); return query; } @@ -349,7 +350,8 @@ private IQuery CreateFooIdentityQuery() new SqlServerSqlExpressionVisitorFactory(new DatabaseMapper(new ConventionModelMapper())), Substitute.For(), new Logger(), - Substitute.For())); + Substitute.For(), + null)); return query; } diff --git a/tests/Kros.KORM.UnitTests/Query/Providers/QueryProviderShould.cs b/tests/Kros.KORM.UnitTests/Query/Providers/QueryProviderShould.cs index 328c33c..de2ec3a 100644 --- a/tests/Kros.KORM.UnitTests/Query/Providers/QueryProviderShould.cs +++ b/tests/Kros.KORM.UnitTests/Query/Providers/QueryProviderShould.cs @@ -76,10 +76,7 @@ private TestQueryProvider(DbConnection externalConnection) public override DbProviderFactory DbProviderFactory => _dbProviderFactory; - public void CreateConnection() - { - var connection = Connection; - } + public void CreateConnection() => GetConnection(); public override IBulkInsert CreateBulkInsert() { @@ -511,7 +508,8 @@ private static SqlServerQueryProvider CreateQueryProvider(SqlConnection connecti Substitute.For(), new ModelBuilder(Database.DefaultModelFactory), Substitute.For(), - Substitute.For()); + Substitute.For(), + null); private static SqlServerQueryProvider CreateQueryProvider(string connectionString) => new SqlServerQueryProvider( @@ -519,7 +517,8 @@ private static SqlServerQueryProvider CreateQueryProvider(string connectionStrin Substitute.For(), new ModelBuilder(Database.DefaultModelFactory), Substitute.For(), - Substitute.For()); + Substitute.For(), + null); #endregion } diff --git a/tests/Kros.KORM.UnitTests/Query/QueryShould.cs b/tests/Kros.KORM.UnitTests/Query/QueryShould.cs index dcb9ae1..4f50cd1 100644 --- a/tests/Kros.KORM.UnitTests/Query/QueryShould.cs +++ b/tests/Kros.KORM.UnitTests/Query/QueryShould.cs @@ -272,7 +272,8 @@ private IQuery CreateQuery() new SqlServerSqlExpressionVisitorFactory(mapper), Substitute.For(), new Logger(), - Substitute.For())); + Substitute.For(), + null)); return query; }