Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 42 additions & 55 deletions pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,70 +122,64 @@ T init(void) const {
};

/*!
* \brief AvgPoolFunction: Implementing avg pool
* \brief AccPoolFunction: Implementing accumulation pool
*
* This class inherits from the generic Poolfunction to implement accumulation Pool
*
* This class inherits from the generic Poolfunction to implement Average Pool
*
* \tparam TA Datatype of the internal accumulation in the avg pool function
* \tparam TO Datatype of the output generated by the avg pool function
* \tparam size Value used as divisor on the accumulator to generate output

* \tparam size Unused
*
*/
template<typename TA, typename TO, unsigned size>
class AvgPoolFunction : public PoolFunction<TA, TO, size> {
template <typename TA, unsigned size>
class AccPoolFunction : public PoolFunction<TA, TA, size>
{
public:
/*!
* \brief pool: computes the sum
*
* \param input Input value to be used in the avg pool function
* \param accu Accumulation value already computed in previous iterations
*/
TA pool(TA const &input, TA const &accu) const{
/*!
* \brief pool: computes the sum
*
* \param input Input value to be used in the avg pool function
* \param accu Accumulation value already computed in previous iterations
*/
TA pool(TA const &input, TA const &accu) const
{
#pragma HLS inline
return comp::add<TA, TA, TA>()(input,accu);
return comp::add<TA, TA, TA>()(input, accu);
}
/*!
* \brief activate: compute the output of the avg pooling algorithm
*
* \param accu Accumulation value already computed in previous iterations
*/
TO activate(TA const &accu) const {
/*!
* \brief activate: compute the output of the max pooling algorithm
*
* \param accu Accumulation value already computed in previous iterations
*/
TA activate(TA const &accu) const
{
#pragma HLS inline
return (accu/size);
return accu;
}
};

/*!
* \brief AccPoolFunction: Implementing accumulation pool
* \brief AvgPoolFunction: Implementing avg pool
*
* This class inherits from the generic Poolfunction to implement accumulation Pool
* This class inherits from the generic Poolfunction to implement Average Pool
*
* \tparam TA Datatype of the internal accumulation in the avg pool function

* \tparam size Unused
* \tparam TO Datatype of the output generated by the avg pool function
* \tparam size Value used as divisor on the accumulator to generate output
*
*/
template<typename TA, unsigned size>
class AccPoolFunction : public PoolFunction<TA, TA, size> {
template <typename TA, typename TO, unsigned size>
class AvgPoolFunction : public AccPoolFunction<TA, size>
{
public:
/*!
* \brief pool: computes the sum
*
* \param input Input value to be used in the avg pool function
* \param accu Accumulation value already computed in previous iterations
*/
TA pool(TA const &input, TA const &accu) const{
#pragma HLS inline
return comp::add<TA, TA, TA>()(input,accu);
}
/*!
* \brief activate: compute the output of the max pooling algorithm
* \brief activate: compute the output of the avg pooling algorithm
*
* \param accu Accumulation value already computed in previous iterations
*/
TA activate(TA const &accu) const {
*/
TO activate(TA const &accu) const {
#pragma HLS inline
return accu;
return (accu / size);
}
};

Expand All @@ -199,29 +193,22 @@ class AccPoolFunction : public PoolFunction<TA, TA, size> {
* \tparam TA Datatype of the internal accumulation in the quant avg pool function
* \tparam TO Datatype of the output generated by the quant avg pool function
* \tparam size Number of right shifts applied to generate output
* \tparam KernelSize Size of the kernel used in the quant avg pool function
*
*/
template<typename TA, typename TO, unsigned size>
class QuantAvgPoolFunction : public PoolFunction<TA, TO, size> {
template <typename TA, typename TO, unsigned size, unsigned KernelSize>
class QuantAvgPoolFunction : public AvgPoolFunction<TA, TO, KernelSize>
{
static_assert(KernelSize > 0, "KernelSize must be greater than 0");
public:
/*!
* \brief pool: computes the sum
*
* \param input Input value to be used in the avg pool function
* \param accu Accumulation value already computed in previous iterations
*/
TA pool(TA const &input, TA const &accu) const{
#pragma HLS inline
return input + accu;
}
/*!
* \brief activate: compute the output of the quant avg pooling algorithm
*
* \param accu Accumulation value already computed in previous iterations
*/
TO activate(TA const &accu) const {
#pragma HLS inline
return TO(accu>>size);
return TO(AvgPoolFunction<TA, TO, KernelSize>::activate(accu) >> size); // Right shift of Trunc Node
}
};

Expand Down