
How to do SQL Server impersonation with Entity Framework Core?

I need to do SQL Server impersonation in an existing EF Core project, and I have it working (sort of). Currently, whenever any DbSet property on the DbContext is accessed, the getter explicitly calls a function that checks the context's current connection state and opens the connection if necessary.

My idea was that whenever the connection is opened through this method, I revert any impersonation already in place (I assume this is necessary since the connection is shared?). To do so, I simply send a DbCommand with REVERT as its command text. On its own, this works fine.

After the REVERT, I check whether the current request requires impersonation. If it does, I run EXECUTE AS USER = @disguise. This also works, sort of.
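The open-time sequence described above can be sketched as follows. This is a minimal illustration, assuming the statements are sent through a delegate such as the context's own ExecuteAsync helper; the class ImpersonationSetup and its ApplyAsync method are hypothetical names, not part of EF Core or SqlClient.

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

// Hypothetical helper (not an EF Core API): the open-time sequence described above.
public static class ImpersonationSetup
{
    // executeSql: runs one T-SQL statement with optional parameters, e.g.
    //   (sql, p) => context.ExecuteAsync(sql, p)
    // disguiseUser: database user to impersonate, or null when none is needed.
    public static async Task ApplyAsync(
        Func<string, Dictionary<string, object>, Task> executeSql,
        string disguiseUser)
    {
        // Step 1: drop any impersonation still active on this session.
        await executeSql("REVERT", null);

        // Step 2: impersonate only when the current request needs it.
        if (!string.IsNullOrEmpty(disguiseUser))
        {
            // The user name is passed as a parameter, not concatenated into the SQL text.
            await executeSql(
                "EXECUTE AS USER = @disguise",
                new Dictionary<string, object> { { "@disguise", disguiseUser } });
        }
    }
}
```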

Within the same request there is no issue, and subsequent queries do seem to run as the impersonated user. However, most of the time (but not always?), the very first query on the next request (it doesn't seem to matter which query) fails with the following error:

Cannot continue the execution because the session is in the kill state.
A severe error occurred on the current command. The results, if any, should be discarded.

When I reverse the order, though, running the EXECUTE AS ... statement before the REVERT, no error occurs at all (although, of course, no statement then actually runs in the impersonated context I need). So I don't think impersonating is a problem on its own. The only difference I can think of is that, in the failing case, I'm letting EF do all of its own background work in the impersonated context.

Does anyone understand why this error occurs? My best guess is that something EF does when returning a connection to the connection pool, or when taking one back out of the pool, doesn't play nicely with the impersonated context.

More directly, I'm looking for a way to get EF to play nicely with impersonation, or else for further troubleshooting steps I can take to investigate.
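One troubleshooting step worth trying rests on a hypothesis the post does not confirm: because the error only appears on the first query of the next request, the session may be going back to the pool while still running under EXECUTE AS. SqlClient resets a pooled session when it is reused, so if that reset fails, the failure would surface before the REVERT-on-open ever runs. Running REVERT before the connection is closed tests that idea. A minimal sketch; RevertThenCloseAsync and its delegate parameters are hypothetical, and in the DbContext above the call would have to happen before the context is disposed.

```csharp
using System;
using System.Threading.Tasks;

// Hypothetical helper for the experiment: restore the session's original
// security context before the connection is released.
public static class ImpersonationCleanup
{
    // executeSql: runs one T-SQL statement, e.g. sql => context.ExecuteAsync(sql)
    // closeConnection: closes the connection, e.g.
    //   () => { context.CloseConnection(); return Task.CompletedTask; }
    public static async Task RevertThenCloseAsync(
        bool impersonationActive,
        Func<string, Task> executeSql,
        Func<Task> closeConnection)
    {
        try
        {
            if (impersonationActive)
            {
                // Undo EXECUTE AS while this request still owns the session.
                await executeSql("REVERT");
            }
        }
        finally
        {
            // Always close the connection, even if REVERT threw.
            await closeConnection();
        }
    }
}
```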

EDIT: adding my DbContext class here for reference

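using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Data.SqlClient; // EF Core 3.0+; EF Core 2.x uses System.Data.SqlClient
using Microsoft.EntityFrameworkCore;

namespace CM.App.Models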
{
    public class AppDataContext : DbContext
    {
        private DataContextUser dataContextUser;
        private bool impersonationSet = false;

        public bool HasAdminAccess() => Execute("SELECT IS_ROLEMEMBER('CM_Admin')", (row) => row.GetInt32(0) == 1).Single();

        private DbSet<Report> reports;

        public DbSet<Report> Reports
        {
            get
            {
                OpenConnection();
                return reports;
            }
            set
            {
                this.reports = value;
            }
        }

        private DbSet<Action> actions;
        public DbSet<Action> Actions
        {
            get
            {
                OpenConnection();
                return actions;
            }
            set
            {
                this.actions = value;
            }
        }

        private DbSet<UserSettings> userSettings;
        public DbSet<UserSettings> UserSettings
        {
            get
            {
                OpenConnection();
                return userSettings;
            }
            set
            {
                this.userSettings = value;
            }
        }

        private DbSet<UserStaticReportConfiguration> userStaticReportConfigurations;
        public DbSet<UserStaticReportConfiguration> UserStaticReportConfigurations
        {
            get
            {
                OpenConnection();
                return userStaticReportConfigurations;
            }
            set
            {
                this.userStaticReportConfigurations = value;
            }
        }

        public AppDataContext() { }
        public AppDataContext(DbContextOptions<AppDataContext> options) : base(options) { }

        protected override void OnModelCreating(ModelBuilder modelBuilder)
        {
            modelBuilder.Entity<UserStaticReportConfiguration>()
                .HasIndex(usrc => new { usrc.DbUserName, usrc.Key }).IsUnique();
        }

        public void SetDataContextUser(DataContextUser dataContextUser)
        {
            this.dataContextUser = dataContextUser;
            impersonationSet = false;
            OpenConnection();
        }

        private DataContextException getDataContextException(SqlException sqlException)
        {
            string message = string.Empty;

            foreach (object objErr in sqlException.Errors)
            {
                SqlError err = objErr as SqlError;
                if (message.Length > 0)
                    message += "\n";
                message += err.Message;
            }

            if (string.IsNullOrEmpty(message))
                message = sqlException.Message;

            return DataContextException.GetDataContextException(message);
        }

        public async Task OpenConnectionAsync()
        {
            switch (Database.GetDbConnection().State)
            {
                case ConnectionState.Closed:
                case ConnectionState.Broken:
                case ConnectionState.Connecting:
                    impersonationSet = false;
                    try
                    {
                        await Database.OpenConnectionAsync();
                    }
                    catch (SqlException sqlException)
                    {
                        throw getDataContextException(sqlException);
                    }
                    break;
                default:
                    return;
            }

            if (!impersonationSet && dataContextUser != null && dataContextUser.IsFacade)
            {
                impersonationSet = true;
                await ExecuteAsync("EXECUTE AS USER = @disguise", new Dictionary<string, object> { { "@disguise", dataContextUser.DbUsername } });
            }

            return;
        }

        public void OpenConnection() =>
            OpenConnectionAsync().Wait();
        public void CloseConnection() =>
            Database.CloseConnection();

        public async Task<int> ExecuteAsync(string commandText, Dictionary<string, object> parameters = null, Action<DbCommand> prep = null)
        {
            if (commandText == null)
                throw new ArgumentNullException(nameof(commandText));

            using (DbCommand cmd = Database.GetDbConnection().CreateCommand())
            {
                cmd.CommandText = commandText;
                if (parameters != null)
                {
                    foreach (KeyValuePair<string, object> kvp in parameters)
                    {
                        DbParameter param = cmd.CreateParameter();
                        param.ParameterName = kvp.Key;
                        param.Value = kvp.Value ?? DBNull.Value;
                        cmd.Parameters.Add(param);
                    }
                }
                if (prep != null)
                    prep(cmd);
                await OpenConnectionAsync();
                try
                {
                    return await cmd.ExecuteNonQueryAsync();
                }
                catch (SqlException sqlException)
                {
                    throw getDataContextException(sqlException);
                }
            }
        }
        public async Task<List<T>> ExecuteAsync<T>(string commandText, Func<DbDataReader, T> rowReader, Dictionary<string, object> parameters = null, Action<DbCommand> prep = null)
        {
            if (commandText == null)
                throw new ArgumentNullException(nameof(commandText));
            if (rowReader == null)
                throw new ArgumentNullException(nameof(rowReader));
            List<T> ret = null;
            using (DbCommand cmd = Database.GetDbConnection().CreateCommand())
            {
                cmd.CommandText = commandText;
                if (parameters != null)
                {
                    foreach (KeyValuePair<string, object> kvp in parameters)
                    {
                        DbParameter param = cmd.CreateParameter();
                        param.ParameterName = kvp.Key;
                        param.Value = kvp.Value ?? DBNull.Value;
                        cmd.Parameters.Add(param);
                    }
                }
                if (prep != null)
                    prep(cmd);

                await OpenConnectionAsync();
                try
                {
                    using (DbDataReader reader = await cmd.ExecuteReaderAsync())
                    {
                        ret = new List<T>();
                        while (await reader.ReadAsync())
                        {
                            ret.Add(rowReader(reader));
                        }
                    }
                }
                catch (SqlException sqlException)
                {
                    throw getDataContextException(sqlException);
                }
            }

            return ret;
        }
        public int Execute(string commandText, Dictionary<string, object> parameters = null, Action<DbCommand> prep = null)
        {
            try
            {
                return ExecuteAsync(commandText, parameters, prep: prep).Result;
            }
            catch (AggregateException ex)
            {
                if (ex.InnerExceptions.FirstOrDefault() is DataContextException && !ex.InnerExceptions.Skip(1).Any())
                    throw ex.InnerExceptions.Single();
                throw; // rethrow without resetting the stack trace
            }
        }
        public List<T> Execute<T>(string commandText, Func<DbDataReader, T> rowReader, Dictionary<string, object> parameters = null, Action<DbCommand> prep = null)
        {
            try
            {
                return ExecuteAsync(commandText, rowReader, parameters, prep).Result;
            }
            catch (AggregateException ex)
            {
                if (ex.InnerExceptions.FirstOrDefault() is DataContextException && !ex.InnerExceptions.Skip(1).Any())
                    throw ex.InnerExceptions.Single();
                throw; // rethrow without resetting the stack trace
            }
        }

    }
}

EDIT 2: additional context

Our users have usernames and passwords to sign in to a web service, which then uses those credentials as the actual connection credentials for a backend database. Technical users can connect to the database directly if they wish, but the service provides a friendlier view of the data for non-technical users. Impersonation is used by admins who need to help users understand why they are seeing the data they see. Row-level security (RLS) restricts what data each user can see, so we felt SQL Server-level impersonation would give the truest view of what another user would see.

OK, one solution I found, although I don't like it, is to keep any connections that use impersonation out of the connection pool. I didn't realize this was possible; I found it in the EF Core source code.

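Basically, I just added SqlConnection.ClearPool((SqlConnection)Database.GetDbConnection()); before my EXECUTE AS ... statement and removed the REVERT. ClearPool empties the pool associated with that connection, and connections still in use, including this one, are discarded instead of being returned to the pool when they are closed. Because the impersonated connection is no longer shared, it can't end up in the wrong context later.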
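A possible alternative with a similar effect, which the author did not report trying, is to give impersonating sessions a connection string with SqlClient's Pooling keyword set to false, so those connections are never pooled in the first place (unlike ClearPool, this leaves any other pooled connections alone). A minimal sketch using the provider-neutral DbConnectionStringBuilder; WithoutPooling is a hypothetical helper name.

```csharp
using System.Data.Common;

// Hypothetical helper: copy a connection string and turn SqlClient connection pooling off.
public static class ImpersonationConnectionStrings
{
    public static string WithoutPooling(string connectionString)
    {
        var builder = new DbConnectionStringBuilder { ConnectionString = connectionString };
        // "Pooling" is a SqlClient connection-string keyword; with it set to false,
        // closing the connection ends the session instead of returning it to a pool.
        builder["Pooling"] = false;
        return builder.ConnectionString;
    }
}
```

The result could then be passed to UseSqlServer when building the DbContext options for an impersonating request; how the application decides which connection string to use is an assumption left open here.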

The only thing I don't like about this is that I never found out what the real underlying problem was that EF (or the underlying provider) had with my impersonated connection state.

