diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/catalog/DeltaCatalog.scala b/spark/src/main/scala/org/apache/spark/sql/delta/catalog/DeltaCatalog.scala
index 8e61187283b..00653003493 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/catalog/DeltaCatalog.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/catalog/DeltaCatalog.scala
@@ -146,8 +146,19 @@ class DeltaCatalog extends DelegatingCatalogExtension
       .getOrElse(spark.sessionState.catalog.defaultTablePath(id))
     val storage = DataSource.buildStorageFormatFromOptions(writeOptions)
       .copy(locationUri = Option(loc))
-    val tableType =
-      if (location.isDefined) CatalogTableType.EXTERNAL else CatalogTableType.MANAGED
+    // PROP_IS_MANAGED_LOCATION indicates that the table location is not user-specified but
+    // system-generated. The table should be created as managed table in this case.
+    val isManagedLocation = Option(allTableProperties.get(TableCatalog.PROP_IS_MANAGED_LOCATION))
+      .exists(_.equalsIgnoreCase("true"))
+    // Note: Spark generates the table location for managed tables in
+    // `DeltaCatalog#delegate#createTable`, so `isManagedLocation` should never be true if
+    // Unity Catalog is not involved. For safety we also check `isUnityCatalog` here.
+    val respectManagedLoc = isUnityCatalog || org.apache.spark.util.Utils.isTesting
+    val tableType = if (location.isEmpty || (isManagedLocation && respectManagedLoc)) {
+      CatalogTableType.MANAGED
+    } else {
+      CatalogTableType.EXTERNAL
+    }
     val commentOpt = Option(allTableProperties.get("comment"))
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/CustomCatalogSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/CustomCatalogSuite.scala
index 3578d66c5cc..79305a230d0 100644
--- a/spark/src/test/scala/org/apache/spark/sql/delta/CustomCatalogSuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/CustomCatalogSuite.scala
@@ -296,6 +296,22 @@ class CustomCatalogSuite extends QueryTest with SharedSparkSession
       }
     }
   }
+
+  test("custom catalog that generates location for managed tables") {
+    // Reset catalog manager so that the new `spark_catalog` implementation can apply.
+    spark.sessionState.catalogManager.reset()
+    withSQLConf("spark.sql.catalog.spark_catalog" -> classOf[DummySessionCatalog].getName) {
+      withTable("t") {
+        withTempPath { path =>
+          sql(s"CREATE TABLE t (id LONG) USING delta TBLPROPERTIES (fakeLoc='$path')")
+          val t = spark.sessionState.catalogManager.v2SessionCatalog.asInstanceOf[TableCatalog]
+            .loadTable(Identifier.of(Array("default"), "t"))
+          // It should be a managed table.
+          assert(!t.properties().containsKey(TableCatalog.PROP_EXTERNAL))
+        }
+      }
+    }
+  }
 }
 
 class DummyCatalog extends TableCatalog {
@@ -396,9 +412,10 @@ class DummySessionCatalogInner extends DelegatingCatalogExtension {
 }
 
 // A dummy catalog that adds a layer between DeltaCatalog and the Spark SessionCatalog,
-// to attach additional table storage properties after the table is loaded.
+// to attach additional table storage properties after the table is loaded, and generates location
+// for managed tables.
 class DummySessionCatalog extends TableCatalog {
-  private var deltaCatalog: DelegatingCatalogExtension = null
+  private var deltaCatalog: DeltaCatalog = null
 
   override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {
     val inner = new DummySessionCatalogInner()
@@ -421,7 +438,16 @@ class DummySessionCatalog extends TableCatalog {
       schema: StructType,
       partitions: Array[Transform],
       properties: java.util.Map[String, String]): Table = {
-    deltaCatalog.createTable(ident, schema, partitions, properties)
+    if (!properties.containsKey(TableCatalog.PROP_EXTERNAL) &&
+      !properties.containsKey(TableCatalog.PROP_LOCATION)) {
+      val newProps = new java.util.HashMap[String, String]
+      newProps.putAll(properties)
+      newProps.put(TableCatalog.PROP_LOCATION, properties.get("fakeLoc"))
+      newProps.put(TableCatalog.PROP_IS_MANAGED_LOCATION, "true")
+      deltaCatalog.createTable(ident, schema, partitions, newProps)
+    } else {
+      deltaCatalog.createTable(ident, schema, partitions, properties)
+    }
   }
 
   override def alterTable(ident: Identifier, changes: TableChange*): Table = {
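For quick reference, the new table-type decision in the `DeltaCatalog` hunk above distills to the sketch below. The object and helper names are hypothetical and the parameters stand in for values the real method computes from its arguments (`location`, `allTableProperties`, `isUnityCatalog`); this is a minimal restatement of the patched logic under those assumptions, not part of the patch itself.

```scala
import org.apache.spark.sql.catalyst.catalog.CatalogTableType

// Minimal sketch (hypothetical helper): a table is MANAGED when no location
// was supplied at all, or when the location is marked system-generated via
// PROP_IS_MANAGED_LOCATION by a catalog trusted to generate it (Unity Catalog,
// or a dummy catalog under test).
object TableTypeDecisionSketch {
  def resolveTableType(
      location: Option[String],
      isManagedLocation: Boolean,
      isUnityCatalog: Boolean,
      isTesting: Boolean): CatalogTableType = {
    val respectManagedLoc = isUnityCatalog || isTesting
    if (location.isEmpty || (isManagedLocation && respectManagedLoc)) {
      CatalogTableType.MANAGED
    } else {
      CatalogTableType.EXTERNAL
    }
  }
}
```

Note the effect of the `respectManagedLoc` guard: outside Unity Catalog (or tests), a stray `PROP_IS_MANAGED_LOCATION` property cannot silently turn a table with a defined location into a managed table; such tables stay EXTERNAL, preserving the old behavior. The `DummySessionCatalog` changes in the test file exercise the trusted path by generating a location (from the test-only `fakeLoc` property) and tagging it as system-generated before delegating to `DeltaCatalog.createTable`.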