diff --git a/CHANGELOG.md b/CHANGELOG.md
index b0d05e954..c32407055 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 ## [Unreleased]
 
 ### Added
+- Batches retrieval of logits from the GPU when the `--n-best` flag is specified.
 - Local/global sharding with MPI training via `--sharding local`
 - fp16 support for factors.
 - Correct training with fp16 via `--fp16`.
diff --git a/src/tensors/gpu/algorithm.cu b/src/tensors/gpu/algorithm.cu
index 926e9deb5..ca6ae2107 100644
--- a/src/tensors/gpu/algorithm.cu
+++ b/src/tensors/gpu/algorithm.cu
@@ -150,5 +150,33 @@ template void swap_ranges(Ptr<Backend>, float*, float*, float*);
 template void swap_ranges(Ptr<Backend>, double*, double*, double*);
 // clang-format on
 
+template <typename T>
+__global__ void ggatherIndices(float* d_out, T* d_in, size_t* indices, size_t indicesToGather) {
+  int index = threadIdx.x + blockDim.x * blockIdx.x;
+  if(index < indicesToGather) {
+    d_out[index] = static_cast<float>(d_in[indices[index]]);
+  }
+}
+
+void gatherIndices(Ptr<Backend> backend, float* d_out, float* d_in, size_t* d_indices, size_t indices_size) {
+  CUDA_CHECK(cudaSetDevice(backend->getDeviceId().no));
+  int threadsPerBlock = std::min(MAX_THREADS, (int)indices_size);
+  int blocks = (indices_size + threadsPerBlock - 1) / threadsPerBlock;
+  ggatherIndices<<<blocks, threadsPerBlock>>>(d_out, d_in, d_indices, indices_size);
+  CUDA_CHECK(cudaStreamSynchronize(0));
+}
+
+void gatherIndices(Ptr<Backend> backend, float* d_out, float16* d_in, size_t* d_indices, size_t indices_size) {
+#if COMPILE_FP16
+  CUDA_CHECK(cudaSetDevice(backend->getDeviceId().no));
+  int threadsPerBlock = std::min(MAX_THREADS, (int)indices_size);
+  int blocks = (indices_size + threadsPerBlock - 1) / threadsPerBlock;
+  ggatherIndices<<<blocks, threadsPerBlock>>>(d_out, (__half*)d_in, d_indices, indices_size);
+  CUDA_CHECK(cudaStreamSynchronize(0));
+#else
+  ABORT("FP16 not supported with current hardware or CUDA version");
+#endif
+}
+
 } // namespace gpu
 } // namespace marian
diff --git a/src/tensors/gpu/algorithm.h b/src/tensors/gpu/algorithm.h
index 84f8b41ea..1ec45a1af 100644
--- a/src/tensors/gpu/algorithm.h
+++ b/src/tensors/gpu/algorithm.h
@@ -1,6 +1,12 @@
+ /* Part of this file was contributed by NVIDIA under license:
+  * Copyright (C) 2020 NVIDIA Corporation
+  * SPDX-License-Identifier: MIT
+  */
+
 #pragma once
 
 #include "tensors/backend.h"
+#include "common/types.h"
 
 namespace marian {
 namespace gpu {
@@ -17,5 +23,9 @@ void setSparse(Ptr<Backend> backend,
                const std::vector<size_t>&,
                const std::vector<float>&,
                float*);
+
+void gatherIndices(Ptr<Backend> backend, float* d_out, float* d_in, size_t* d_indices, size_t indices_size);
+
+void gatherIndices(Ptr<Backend> backend, float* d_out, float16* d_in, size_t* d_indices, size_t indices_size);
 } // namespace gpu
 } // namespace marian
diff --git a/src/tensors/tensor.h b/src/tensors/tensor.h
index 10c3e7f19..b1be028e6 100755
--- a/src/tensors/tensor.h
+++ b/src/tensors/tensor.h
@@ -1,3 +1,8 @@
+ /* Part of this file was contributed by NVIDIA under license:
+  * Copyright (C) 2020 NVIDIA Corporation
+  * SPDX-License-Identifier: MIT
+  */
+
 #pragma once
 
 #include "common/definitions.h"
@@ -109,6 +114,39 @@ class TensorBase {
     return TensorBase::New(mem, Shape{1, (int)size}, type(), backend_);
   }
 
+  void gatherFromIndices(Tensor gatheredResults, Tensor flattenedIndices) {
+    ABORT_IF((flattenedIndices->type() != Type::uint64),
+             "Type of indices must be uint64");
+
+    ABORT_IF(gatheredResults->size() < flattenedIndices->size(),
+             "The result tensor is too small to hold all of the indexed values.");
+
+    ABORT_IF(gatheredResults->type() != Type::float32,
+             "The type of the result tensor must be float32.");
+
+    if(backend_->getDeviceId().type == DeviceType::cpu) {
+      float* gatheredResultsPtr = gatheredResults->data();
+      size_t* flattenedIndicesPtr = flattenedIndices->data<size_t>();
+      float* dataToGather = data();
+
+      for(int i = 0; i < flattenedIndices->size(); ++i) {
+        gatheredResultsPtr[i] = dataToGather[flattenedIndicesPtr[i]];
+      }
+    }
+#ifdef CUDA_FOUND
+    else {
+      if(type_ == Type::float32) {
+        return gpu::gatherIndices(backend_, gatheredResults->data(), data(), flattenedIndices->data<size_t>(), flattenedIndices->size());
+      } else if(type_ == Type::float16) {
+        return gpu::gatherIndices(backend_, gatheredResults->data(), data<float16>(), flattenedIndices->data<size_t>(), flattenedIndices->size());
+      } else {
+        ABORT("INVALID TYPE FOR OP");
+      }
+    }
+#endif
+
+  }
+
   // @TODO: review if we can eliminate GPU-specific code here,
   // potentially by moving this to non-class members.
   template <typename T>
diff --git a/src/translator/beam_search.cpp b/src/translator/beam_search.cpp
index 5c1989a68..b05e68aea 100755
--- a/src/translator/beam_search.cpp
+++ b/src/translator/beam_search.cpp
@@ -1,4 +1,10 @@
+ /* Part of this file was contributed by NVIDIA under license:
+  * Copyright (C) 2020 NVIDIA Corporation
+  * SPDX-License-Identifier: MIT
+  */
+
 #include "translator/beam_search.h"
+#include "tensors/tensor_allocator.h"
 
 #include "data/factored_vocab.h"
 #include "translator/helpers.h"
@@ -40,6 +46,12 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
     }
   }
 
+  // Hold the flattened logit indices for each state so we can batch retrieval later. Additionally, store the original batch index so we can update the hypotheses in the new beams.
+  std::vector<size_t> origBatchIndices;
+  std::vector<size_t> oldBeamHypIndices;
+  std::vector<size_t> newBeamHypIndices;
+  std::vector<std::vector<uint64_t>> flattenedLogitIndices(states.size());
+
   for(size_t i = 0; i < nBestKeys.size(); ++i) { // [currentDimBatch, beamSize] flattened
     // Keys encode batchIdx, beamHypIdx, and word index in the entire beam.
     // They can be between 0 and (vocabSize * nBestBeamSize * batchSize)-1.
@@ -123,23 +135,22 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
 
       // Set score breakdown for n-best lists
       if(options_->get<bool>("n-best")) {
-        auto breakDown = beam[beamHypIdx]->getScoreBreakdown();
         ABORT_IF(factoredVocab && factorGroup > 0 && !factoredVocab->canExpandFactoredWord(word, factorGroup),
                  "A word without this factor snuck through to here??");
-        breakDown.resize(states.size(), 0); // at start, this is empty, so this will set the initial score to 0
-        for(size_t j = 0; j < states.size(); ++j) {
+        for(uint64_t j = 0; j < states.size(); ++j) {
           auto lval = states[j]->getLogProbs().getFactoredLogitsTensor(factorGroup); // [maxBeamSize, 1, currentDimBatch, dimFactorVocab]
           // The flatting happens based on actual (current) batch size and batch index computed with batch-pruning as we are looking into the pruned tensor
-          size_t flattenedLogitIndex = (beamHypIdx * currentDimBatch + currentBatchIdx) * vocabSize + wordIdx; // (beam idx, batch idx, word idx); note: beam and batch are transposed, compared to 'key'
+          uint64_t flattenedLogitIndex = (beamHypIdx * currentDimBatch + currentBatchIdx) * vocabSize + wordIdx; // (beam idx, batch idx, word idx); note: beam and batch are transposed, compared to 'key'
 
           // @TODO: use a function on shape() to index, or new method val->at({i1, i2, i3, i4}) with broadcasting
           ABORT_IF(lval->shape() != Shape({(int)nBestBeamSize, 1, (int)currentDimBatch, (int)vocabSize})
                    && (beamHypIdx == 0 && lval->shape() != Shape({1, 1, (int)currentDimBatch, (int)vocabSize})),
                    "Unexpected shape of logits?? {} != {}", lval->shape(), Shape({(int)nBestBeamSize, 1, (int)currentDimBatch, (int)vocabSize}));
-
-          breakDown[j] += lval->get(flattenedLogitIndex);
+          flattenedLogitIndices[j].push_back(flattenedLogitIndex);
         }
-        hyp->setScoreBreakdown(breakDown);
+        newBeamHypIndices.push_back(newBeam.size());
+        origBatchIndices.push_back(origBatchIdx);
+        oldBeamHypIndices.push_back(beamHypIdx);
       }
 
       // Set alignments
@@ -151,6 +162,36 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
     newBeam.push_back(hyp);
   }
 
+  // We need to set the score breakdown outside of the main loop to batch requests. This avoids issuing several 4-byte memcpys when using the GPU backend.
+  if(options_->get<bool>("n-best")) {
+    Tensor indices;
+    Tensor logitsTensor;
+    allocator_->allocate(indices, {(int)flattenedLogitIndices[0].size()}, Type::uint64);
+    allocator_->allocate(logitsTensor, indices->shape(), Type::float32);
+    std::vector<float> logits(flattenedLogitIndices[0].size());
+
+    for(size_t state = 0; state < states.size(); ++state) {
+      auto lval = states[state]->getLogProbs().getFactoredLogitsTensor(factorGroup); // [maxBeamSize, 1, currentDimBatch, dimFactorVocab]
+      indices->set(flattenedLogitIndices[state]);
+      lval->gatherFromIndices(logitsTensor, indices);
+      logitsTensor->get(logits);
+
+      for(int i = 0; i < flattenedLogitIndices[state].size(); ++i) {
+        const auto originalBatchIdx = origBatchIndices[i];
+        const auto beamHypIdx = oldBeamHypIndices[i];
+        const auto& beam = beams[originalBatchIdx];
+        auto& newBeam = newBeams[originalBatchIdx];
+
+        auto breakDown = beam[beamHypIdx]->getScoreBreakdown();
+        breakDown.resize(states.size(), 0); // at start, this is empty, so this will set the initial score to 0
+        breakDown[state] += logits[i];
+        newBeam[newBeamHypIndices[i]]->setScoreBreakdown(breakDown);
+      }
+    }
+    allocator_->free(indices);
+    allocator_->free(logitsTensor);
+  }
+
   // if factored vocab and this is not the first factor, we need to
   // also propagate factored hypotheses that do not get expanded in this step because they don't have this factor
   if (factorGroup > 0) {
@@ -261,6 +302,7 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
   const auto trgUnkId = trgVocab_->getUnkId();
 
   auto getNBestList = createGetNBestListFn(beamSize_, origDimBatch, graph->getDeviceId());
+  allocator_ = graph->getTensorAllocator();
 
   for(auto scorer : scorers_) {
     scorer->clear(graph);
diff --git a/src/translator/beam_search.h b/src/translator/beam_search.h
index e2de7d243..9ac11db70 100644
--- a/src/translator/beam_search.h
+++ b/src/translator/beam_search.h
@@ -1,3 +1,8 @@
+ /* Part of this file was contributed by NVIDIA under license:
+  * Copyright (C) 2020 NVIDIA Corporation
+  * SPDX-License-Identifier: MIT
+  */
+
 #pragma once
 
 #include "marian.h"
@@ -6,12 +11,14 @@
 
 namespace marian {
 
+class TensorAllocator;
 class BeamSearch {
 private:
   Ptr<Options> options_;
   std::vector<Ptr<Scorer>> scorers_;
   size_t beamSize_;
   Ptr<Vocab> trgVocab_;
+  Ptr<TensorAllocator> allocator_;
 
   const float INVALID_PATH_SCORE;
   const bool PURGE_BATCH = true; // @TODO: diagnostic, to-be-removed once confirmed there are no issues.
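
Reviewer note (not part of the patch): the core idea of the change is to replace the per-element lval->get(flattenedLogitIndex) calls, each of which triggers a tiny device-to-host copy on the GPU backend, with one upload of all indices, one gather kernel, and one bulk copy of the gathered logits back to the host. Below is a minimal standalone CUDA sketch of that access pattern, using made-up buffer sizes and names and no Marian dependencies, for anyone who wants to try the idea in isolation.

// Illustrative sketch only; names (gatherKernel, dLogits, dIdx) and sizes are invented,
// and error handling is omitted for brevity.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

__global__ void gatherKernel(float* out, const float* in, const size_t* idx, size_t n) {
  size_t i = threadIdx.x + (size_t)blockDim.x * blockIdx.x;
  if(i < n)
    out[i] = in[idx[i]];  // one gathered value per thread
}

int main() {
  const size_t vocabLike = 1000, nIndices = 8;
  std::vector<float> hostLogits(vocabLike);
  for(size_t i = 0; i < vocabLike; ++i) hostLogits[i] = 0.001f * i;
  std::vector<size_t> hostIdx = {3, 42, 7, 999, 0, 512, 256, 128};

  float *dLogits, *dOut;
  size_t* dIdx;
  cudaMalloc(&dLogits, vocabLike * sizeof(float));
  cudaMalloc(&dOut, nIndices * sizeof(float));
  cudaMalloc(&dIdx, nIndices * sizeof(size_t));
  cudaMemcpy(dLogits, hostLogits.data(), vocabLike * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dIdx, hostIdx.data(), nIndices * sizeof(size_t), cudaMemcpyHostToDevice);

  // One index upload, one kernel launch, and one bulk copy back replace
  // nIndices separate 4-byte device-to-host transfers.
  int threads = 256;
  int blocks = (int)((nIndices + threads - 1) / threads);
  gatherKernel<<<blocks, threads>>>(dOut, dLogits, dIdx, nIndices);

  std::vector<float> gathered(nIndices);
  cudaMemcpy(gathered.data(), dOut, nIndices * sizeof(float), cudaMemcpyDeviceToHost);
  for(size_t i = 0; i < nIndices; ++i)
    printf("logits[%zu] = %f\n", hostIdx[i], gathered[i]);

  cudaFree(dLogits); cudaFree(dOut); cudaFree(dIdx);
  return 0;
}

The trade-off mirrors the patched toHyps: the number of kernel launches grows with the number of scorer states, but the number of host/device round trips per hypothesis drops to a constant.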