diff --git a/src/shared/drivers/dma/DMA.cpp b/src/shared/drivers/dma/DMA.cpp index 5edfcec61870ccd4a4eb55c962e596592195c2ea..aa2eea4e2157f8993bff58fb2a295e60dcdf884d 100644 --- a/src/shared/drivers/dma/DMA.cpp +++ b/src/shared/drivers/dma/DMA.cpp @@ -24,6 +24,7 @@ #include <kernel/logging.h> #include <utils/ClockUtils.h> +#include <utils/Debug.h> #include <map> @@ -307,7 +308,7 @@ bool DMADriver::tryChannel(DMADefs::DMAStreamId id) return streams.count(id) == 0; } -DMAStream* DMADriver::acquireStreamBlocking( +DMAStreamGuard DMADriver::acquireStreamBlocking( DMADefs::DMAStreamId id, DMADefs::Channel channel, const std::chrono::nanoseconds timeout) { @@ -327,7 +328,7 @@ DMAStream* DMADriver::acquireStreamBlocking( if (res == TimedWaitResult::Timeout) { // The timeout expired - return nullptr; + return DMAStreamGuard(nullptr); } } } @@ -337,10 +338,10 @@ DMAStream* DMADriver::acquireStreamBlocking( // if (streams.size() == 0) // RCC->AHB1ENR |= RCC_AHB1ENR_DMA1EN; - return streams[id] = new DMAStream(id, channel); + return DMAStreamGuard((streams[id] = new DMAStream(id, channel))); } -DMAStream* DMADriver::automaticAcquireStreamBlocking( +DMAStreamGuard DMADriver::automaticAcquireStreamBlocking( DMADefs::Peripherals peripheral, const std::chrono::nanoseconds timeout) { const auto availableStreams = @@ -360,7 +361,7 @@ DMAStream* DMADriver::automaticAcquireStreamBlocking( if (streams.count(id) == 0) { // Stream is free - return streams[id] = new DMAStream(id, channel); + return DMAStreamGuard(streams[id] = new DMAStream(id, channel)); } } @@ -375,7 +376,7 @@ DMAStream* DMADriver::automaticAcquireStreamBlocking( if (res == TimedWaitResult::Timeout) { // The timeout expired - return nullptr; + return DMAStreamGuard(nullptr); } } } @@ -679,4 +680,11 @@ DMAStream::DMAStream(DMADefs::DMAStreamId id, DMADefs::Channel channel) irqNumber = DMADefs::irqNumberMapping[static_cast<uint8_t>(id)]; } +DMAStream* DMAStreamGuard::operator->() +{ + D(assert((pStream != nullptr) && "pointer is null")); + + return pStream; +} + } // namespace Boardcore diff --git a/src/shared/drivers/dma/DMA.h b/src/shared/drivers/dma/DMA.h index 6fd45988bc063a8d03344a04a3c1e17b31d08c67..10398a705ec2f8d803b70ed12a0ca9e0ff6c6725 100644 --- a/src/shared/drivers/dma/DMA.h +++ b/src/shared/drivers/dma/DMA.h @@ -89,6 +89,7 @@ struct DMATransaction // Forward declaration class DMAStream; +class DMAStreamGuard; class DMADriver { @@ -109,9 +110,9 @@ public: * @return The pointer to the allocated stream if successful, nullptr if the * timeout expired. */ - DMAStream* acquireStreamBlocking(DMADefs::DMAStreamId id, - DMADefs::Channel channel, - const std::chrono::nanoseconds timeout); + DMAStreamGuard acquireStreamBlocking( + DMADefs::DMAStreamId id, DMADefs::Channel channel, + const std::chrono::nanoseconds timeout); /** * @brief Try to acquire a stream that is connected to the specified @@ -124,7 +125,7 @@ public: * * TODO: change name */ - DMAStream* automaticAcquireStreamBlocking( + DMAStreamGuard automaticAcquireStreamBlocking( DMADefs::Peripherals peripheral, const std::chrono::nanoseconds timeout); @@ -416,6 +417,9 @@ public: DMAStream& operator=(const DMAStream&) = delete; }; +/** + * @brief Simple RAII class to handle DMA streams. + */ class DMAStreamGuard { public: @@ -432,9 +436,16 @@ public: DMAStreamGuard(const DMAStreamGuard&) = delete; DMAStreamGuard& operator=(const DMAStreamGuard&) = delete; - DMAStream* operator->() { return pStream; } + DMAStreamGuard(DMAStreamGuard&&) noexcept = default; + DMAStreamGuard& operator=(DMAStreamGuard&&) noexcept = default; + + DMAStream* operator->(); - inline DMAStream* get() { return pStream; } + /** + * @return True if the stream was correctly allocated and + * is ready to use. False otherwise. + */ + inline bool isValid() { return pStream != nullptr; } private: DMAStream* pStream = nullptr; diff --git a/src/tests/drivers/test-dma-mem-to-mem.cpp b/src/tests/drivers/test-dma-mem-to-mem.cpp index fcfe2bfee6fe825200cfc28d9b70cecac934af7f..75153aef78ac9bf9212f0a231e413dcbc0c83a3a 100644 --- a/src/tests/drivers/test-dma-mem-to-mem.cpp +++ b/src/tests/drivers/test-dma-mem-to-mem.cpp @@ -33,10 +33,11 @@ void printBuffer(uint8_t *buffer, size_t size); int main() { - DMAStreamGuard stream(DMADriver::instance().automaticAcquireStreamBlocking( - DMADefs::Peripherals::PE_MEM_ONLY, std::chrono::seconds::zero())); + DMAStreamGuard stream = + DMADriver::instance().automaticAcquireStreamBlocking( + DMADefs::Peripherals::PE_MEM_ONLY, std::chrono::seconds::zero()); - if (stream.get() == nullptr) + if (!stream.isValid()) { printf("Error, cannot allocate dma stream\n"); return 0;