Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion MoltenVK/MoltenVK/GPUObjects/MVKPipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ class MVKGraphicsPipeline : public MVKPipeline {
MVKSmallVector<MVKZeroDivisorVertexBinding> _zeroDivisorVertexBindings;
MVKSmallVector<MVKShaderStage> _stagesUsingPhysicalStorageBufferAddressesCapability;
MVKSmallVector<uint32_t, kMVKDefaultAttachmentCount> _colorAttachmentLocations;
std::unordered_map<uint32_t, bool> _fragmentOutputIsFloat; // location -> true if shader outputs float
std::unordered_map<uint32_t, id<MTLRenderPipelineState>> _multiviewMTLPipelineStates;
MVKStaticBitSet<kMVKMaxBufferCount> _vkVertexBuffers;
MVKStaticBitSet<kMVKMaxBufferCount> _mtlVertexBuffers;
Expand Down Expand Up @@ -547,9 +548,10 @@ class MVKRenderPipelineCompiler : public MVKMetalCompiler {

#pragma mark Construction

MVKRenderPipelineCompiler(MVKVulkanAPIDeviceObject* owner) : MVKMetalCompiler(owner) {
MVKRenderPipelineCompiler(MVKVulkanAPIDeviceObject* owner, bool suppressErrors = false) : MVKMetalCompiler(owner) {
_compilerType = "Render pipeline";
_pPerformanceTracker = &getPerformanceStats().shaderCompilation.pipelineCompile;
_suppressErrors = suppressErrors;
}

~MVKRenderPipelineCompiler() override;
Expand Down
71 changes: 71 additions & 0 deletions MoltenVK/MoltenVK/GPUObjects/MVKPipeline.mm
Original file line number Diff line number Diff line change
Expand Up @@ -1669,6 +1669,62 @@ static void addCommonImplicitBuffersToShaderConfig(SPIRVToMSLConversionConfigura
}
addPrevStageOutputToShaderConversionConfig(shaderConfig, shaderOutputs);

// Add fragment output format info so SPIRV-Cross generates correct output types
// for integer render targets (e.g. RGBA16Uint needs uint4, not float4)
{
const VkPipelineRenderingCreateInfo* pRendInfo = getRenderingCreateInfo(pCreateInfo);
if (_isRasterizingColor && pCreateInfo->pColorBlendState) {
for (uint32_t caIdx = 0; caIdx < pCreateInfo->pColorBlendState->attachmentCount; caIdx++) {
uint32_t caLoc = _colorAttachmentLocations[caIdx];
if (caLoc == VK_ATTACHMENT_UNUSED) { continue; }
VkFormat vkFmt = pRendInfo->pColorAttachmentFormats[caIdx];
if (!vkFmt) { continue; }
mvk::MSLShaderInterfaceVariable fo;
auto& fosv = fo.shaderVar;
fosv.location = caLoc;
fosv.component = 0;
fosv.builtin = spv::BuiltInMax;
fosv.vecsize = 4;
switch (getPixelFormats()->getFormatType(vkFmt)) {
case kMVKFormatColorUInt8:
fosv.format = MSL_SHADER_VARIABLE_FORMAT_UINT8;
break;
case kMVKFormatColorUInt16:
fosv.format = MSL_SHADER_VARIABLE_FORMAT_UINT16;
break;
case kMVKFormatColorHalf:
case kMVKFormatColorInt16:
fosv.format = MSL_SHADER_VARIABLE_FORMAT_ANY16;
break;
default:
fosv.format = MSL_SHADER_VARIABLE_FORMAT_OTHER;
break;
}
if (fosv.format != MSL_SHADER_VARIABLE_FORMAT_OTHER) {
shaderConfig.shaderOutputs.push_back(fo);
}
}
}
}

// Reflect fragment shader outputs to know their base types (float vs uint).
{
SPIRVShaderOutputs fragOutputs;
std::string errorLog;
if (getShaderOutputs(_fragmentModule->getSPIRV(), spv::ExecutionModelFragment,
pFragmentSS->pName, fragOutputs, errorLog)) {
for (auto& fo : fragOutputs) {
if (fo.builtin == spv::BuiltInMax && fo.isUsed) {
bool isFloat = (fo.baseType == SPIRV_CROSS_NAMESPACE::SPIRType::Float ||
fo.baseType == SPIRV_CROSS_NAMESPACE::SPIRType::Half ||
fo.baseType == SPIRV_CROSS_NAMESPACE::SPIRType::Double);
_fragmentOutputIsFloat[fo.location] = isFloat;
}
}
}

}

MVKMTLFunction func = getMTLFunction(shaderConfig, pFragmentSS, pFragmentFB, _fragmentModule, "Fragment");
id<MTLFunction> mtlFunc = func.getMTLFunction();
plDesc.fragmentFunction = mtlFunc;
Expand Down Expand Up @@ -1933,6 +1989,21 @@ static void addCommonImplicitBuffersToShaderConfig(SPIRVToMSLConversionConfigura
if (caLoc == VK_ATTACHMENT_UNUSED) { continue; }

MTLPixelFormat mtlPixFmt = getPixelFormats()->getMTLPixelFormat(pRendInfo->pColorAttachmentFormats[caIdx]);

// Per-location format fix: swap integer to float only where the shader outputs float.
auto fragIt = _fragmentOutputIsFloat.find(caLoc);
if (fragIt != _fragmentOutputIsFloat.end() && fragIt->second) {
switch (mtlPixFmt) {
case MTLPixelFormatRGBA16Uint: case MTLPixelFormatRGBA16Sint: mtlPixFmt = MTLPixelFormatRGBA16Float; break;
case MTLPixelFormatRGBA32Uint: case MTLPixelFormatRGBA32Sint: mtlPixFmt = MTLPixelFormatRGBA32Float; break;
case MTLPixelFormatRG16Uint: case MTLPixelFormatRG16Sint: mtlPixFmt = MTLPixelFormatRG16Float; break;
case MTLPixelFormatR16Uint: case MTLPixelFormatR16Sint: mtlPixFmt = MTLPixelFormatR16Float; break;
case MTLPixelFormatRGBA8Uint: mtlPixFmt = MTLPixelFormatRGBA8Unorm; break;
case MTLPixelFormatRGBA8Sint: mtlPixFmt = MTLPixelFormatRGBA8Snorm; break;
default: break;
}
}

MTLRenderPipelineColorAttachmentDescriptor* colorDesc = plDesc.colorAttachments[caLoc];
colorDesc.pixelFormat = mtlPixFmt;

Expand Down
1 change: 1 addition & 0 deletions MoltenVK/MoltenVK/GPUObjects/MVKRenderPass.mm
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ - (void)setOpenGLModeEnabled:(BOOL)enabled;
false, loadOverride)) {
mtlColorAttDesc.clearColor = pixFmts->getMTLClearColor(clearValues[clrRPAttIdx].color, clrMVKRPAtt->getFormat());
}

if (isMultiview()) {
uint32_t startView = getFirstViewIndexInMetalPass(passIdx);
if (mtlColorAttDesc.texture.textureType == MTLTextureType3D)
Expand Down
1 change: 1 addition & 0 deletions MoltenVK/MoltenVK/GPUObjects/MVKSync.h
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,7 @@ class MVKMetalCompiler : public MVKBaseDeviceObject {
uint64_t _startTime = 0;
bool _isCompileDone = false;
bool _isDestroyed = false;
bool _suppressErrors = false;
std::mutex _completionLock;
std::condition_variable _blocker;
std::string _compilerType = "Unknown";
Expand Down
1 change: 1 addition & 0 deletions MoltenVK/MoltenVK/GPUObjects/MVKSync.mm
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,7 @@ VkResult mvkWaitSemaphores(MVKDevice* device,
}

void MVKMetalCompiler::handleError() {
if (_suppressErrors) return;
_owner->setConfigurationResult(reportError(VK_ERROR_INITIALIZATION_FAILED,
"%s compile failed (Error code %li):\n%s.",
_compilerType.c_str(), (long)_compileError.code,
Expand Down
Loading