Fast MDCT/IMDCT Based on Forward FFT
Doing audio, especially audio coding, and looking for a fast and easy to use C code of Modified Discrete Cosine Transform (MDCT) and its inverse IMDCT? Here it is. It works well for me. And I hope it works well for you.
A special flavor of the code is that both the MDCT and IMDCT are based on forward FFT. This is possible because the MDCT/IMDCT are in fact the shifted DCT-IV/IDCT-IV, whose transformation matrix is symmetric and orthogonal (thus, the DCT-IV and the IDCT-IV are identical up to scaling). Some small code space and data space might be saved this way. No big deal generally, but it might be worth the effort on embedded systems.
Another special flavor of the code is that the MDCT and IMDCT here are normalized, which means, globally, the power of time samples is equal to that of spectral coefficients, provided that a window satisfying the Princen-Bradley condition of perfect reconstruction (e.g., sine and KBD) is applied. However, temporal power and spectral power of a single transform are not necessarily equal, due to the intrinsic property (and allure) of Time Domain Aliasing Cancellation (TDAC).
The following MDCT and IMDCT code contains two files: mdct.h and mdct.c. To use the code in your project, you just need to
- First, initialize a 'mdct_plan' with size M (the number of MDCT coefficients)
mdct_plan* m_plan = mdct_init(M);
- Then, call mdct/imdct as many times as required
mdct(freq, time, m_plan); // out: freq[0,...,M-1]; in: time[0,...,2M-1]
or
imdct(time, freq, m_plan); // out: time[0,...,2M-1]; in: freq[0,...,M-1]
- Finally, free up the 'mdct_plan'
mdct_free(m_plan);
All of the above is exemplified in the following "mdct_test.c". To build 'mdct_test' in double precision float, run
gcc -o mdct_test -O2 mdct_test.c mdct.c -lfftw3 -lm;
to build 'mdct_test' in single precision float, run
gcc -o mdct_test -O2 -DSINGLE_PRECISION mdct_test.c mdct.c -lfftw3f -lm
The MDCT is computed in three steps:
- Pre-twiddling: folding and arranging the 2M-point time input into M/2-point complex input;
- Forward FFT: M/2-point complex;
- Post-twiddling: unfolding the resulting M/2-point complex FFT spectrum into the M-point MDCT spectrum.
Happy coding:)
/******** begin of mdct.h ******** */
// Public interface of the FFT-based MDCT/IMDCT.
// Note: the include guard avoids a leading double underscore, which is
// reserved for the implementation.
#ifndef MDCT_H
#define MDCT_H
#include <fftw3.h>
#ifdef __cplusplus
extern "C" {
#endif

// Precision selection: define SINGLE_PRECISION at compile time to build
// against the float flavor of FFTW (fftwf_*); otherwise double is used.
#ifdef SINGLE_PRECISION
typedef float FLOAT;
typedef fftwf_complex FFTW_COMPLEX;
typedef fftwf_plan FFTW_PLAN;
#else // DOUBLE_PRECISION
typedef double FLOAT;
typedef fftw_complex FFTW_COMPLEX;
typedef fftw_plan FFTW_PLAN;
#endif // SINGLE_PRECISION

// State shared by mdct() and imdct(); created by mdct_init() and released
// by mdct_free().
typedef struct {
    int M;                 // MDCT spectrum size (number of bins)
    FLOAT* twiddle;        // twiddle factors: M values stored as interleaved (cos, sin) pairs
    FFTW_COMPLEX* fft_in;  // fft workspace, input (M/2 complex values)
    FFTW_COMPLEX* fft_out; // fft workspace, output (M/2 complex values)
    FFTW_PLAN fft_plan;    // fft configuration (forward, M/2 points, fft_in -> fft_out)
} mdct_plan;

// Allocate and set up a plan for M bins. Returns NULL on failure.
mdct_plan* mdct_init(int M); // MDCT spectrum size (number of bins)

// Release a plan obtained from mdct_init(); a NULL plan is ignored.
void mdct_free(mdct_plan* m_plan);

// Forward transform. out: mdct_line[0,...,M-1]; in: time_signal[0,...,2M-1]
void mdct(FLOAT* mdct_line, FLOAT* time_signal, mdct_plan* m_plan);

// Inverse transform. out: time_signal[0,...,2M-1]; in: mdct_line[0,...,M-1]
void imdct(FLOAT* time_signal, FLOAT* mdct_line, mdct_plan* m_plan);

#ifdef __cplusplus
}
#endif
#endif // MDCT_H
/******** end of mdct.h ******** */
/******** begin of mdct.c ******** */
#include "mdct.h"
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/*
 * Precision-neutral aliases for the FFTW entry points used in this file.
 * MDCT_FFTW(name) pastes the proper library prefix onto `name`:
 * "fftwf_" when SINGLE_PRECISION is defined, "fftw_" otherwise.
 */
#ifdef SINGLE_PRECISION
#define MDCT_FFTW(name) fftwf_##name
#else // DOUBLE_PRECISION
#define MDCT_FFTW(name) fftw_##name
#endif // SINGLE_PRECISION

#define FFTW_MALLOC  MDCT_FFTW(malloc)
#define FFTW_FREE    MDCT_FFTW(free)
#define FFTW_PLAN_1D MDCT_FFTW(plan_dft_1d)
#define FFTW_DESTROY MDCT_FFTW(destroy_plan)
#define FFTW_EXECUTE MDCT_FFTW(execute)
/*
 * Release every resource owned by an MDCT plan, then the plan itself.
 *
 * m_plan: plan returned by mdct_init(), or NULL (in which case nothing
 *         happens).
 *
 * The function is also used by mdct_init() to unwind a partially built
 * plan, so any member may still be NULL here. The FFTW handles are
 * therefore checked before being handed back to the library; free()
 * accepts NULL by definition and needs no guard.
 */
void mdct_free(mdct_plan* m_plan)
{
    if(NULL == m_plan)
        return;

    if(m_plan->fft_plan)
        FFTW_DESTROY(m_plan->fft_plan);
    if(m_plan->fft_in)
        FFTW_FREE(m_plan->fft_in);
    if(m_plan->fft_out)
        FFTW_FREE(m_plan->fft_out);

    free(m_plan->twiddle);
    free(m_plan);
}
/*
 * Error exit helper for mdct_init(): print `msg` (a printf-style string
 * literal taking the variadic arguments) followed by the current function,
 * file and line to stderr, release the partially built plan held in the
 * local variable `m_plan`, and return NULL from the enclosing function.
 *
 * At least one variadic argument must be supplied. The body is wrapped in
 * do { ... } while(0) so that `MDCT_CLEAUP(...);` behaves as one statement,
 * including inside an unbraced if/else.
 */
#define MDCT_CLEAUP(msg, ...) \
    do { \
        fprintf(stderr, msg", %s(), %s:%d \n", \
                __VA_ARGS__, __func__, __FILE__, __LINE__); \
        mdct_free(m_plan); \
        return NULL; \
    } while(0)
/*
 * Create a plan for MDCT/IMDCT transforms with M spectral bins.
 *
 * M: number of MDCT coefficients; must be positive and even.
 *
 * Returns a newly allocated plan that the caller releases with
 * mdct_free(), or NULL after printing a diagnostic to stderr when M is
 * invalid or any allocation / FFT planning step fails.
 *
 * The plan holds M twiddle values stored as interleaved (cos, sin) pairs,
 * two FFT work buffers of M/2 complex values each, and a forward
 * M/2-point FFTW plan from fft_in to fft_out.
 */
mdct_plan* mdct_init(int M)
{
    int n, fft_size;
    FLOAT alpha, omega, scale;
    mdct_plan* m_plan = NULL;

    /* reject zero/negative sizes as well as odd ones: they would otherwise
       reach malloc() and the FFT planner with nonsensical sizes */
    if(M <= 0 || 0x00 != (M & 0x01))
        MDCT_CLEAUP(" Expect a positive, even number of MDCT coeffs, but meet %d", M);

    fft_size = M >> 1;

    /* zero-initialized so that mdct_free() can unwind a partial plan */
    m_plan = calloc(1, sizeof(*m_plan));
    if(NULL == m_plan)
        MDCT_CLEAUP(" malloc error: %s", "m_plan");

    m_plan->M = M;

    m_plan->twiddle = malloc(sizeof(FLOAT) * M);
    if(NULL == m_plan->twiddle)
        MDCT_CLEAUP(" malloc error: %s", "m_plan->twiddle");

    /* twiddle[2n] = scale*cos(pi*n/M + pi/(8M)), twiddle[2n+1] = the sine;
       the same table is applied in both the pre- and post-twiddle stage */
    alpha = M_PI / (8.f * M);
    omega = M_PI / M;
    scale = sqrt(sqrt(2.f / M));
    for(n = 0; n < fft_size; n++)
    {
        m_plan->twiddle[2*n+0] = (FLOAT) (scale * cos(omega * n + alpha));
        m_plan->twiddle[2*n+1] = (FLOAT) (scale * sin(omega * n + alpha));
    }

    m_plan->fft_in = FFTW_MALLOC(sizeof(FFTW_COMPLEX) * fft_size);
    if(NULL == m_plan->fft_in)
        MDCT_CLEAUP(" malloc error: %s", "m_plan->fft_in");

    m_plan->fft_out = FFTW_MALLOC(sizeof(FFTW_COMPLEX) * fft_size);
    if(NULL == m_plan->fft_out)
        MDCT_CLEAUP(" malloc error: %s", "m_plan->fft_out");

    /* FFTW_MEASURE may overwrite fft_in/fft_out while planning; that is
       harmless here because both are pure work buffers */
    m_plan->fft_plan = FFTW_PLAN_1D(fft_size,
                                    m_plan->fft_in,
                                    m_plan->fft_out,
                                    FFTW_FORWARD,
                                    FFTW_MEASURE);
    if(NULL == m_plan->fft_plan)
        MDCT_CLEAUP(" FFT plan creation error: %s", "m_plan->fft_plan");

    return m_plan;
}
/*
 * Forward MDCT.
 *
 * mdct_line:   output, M spectral coefficients [0,...,M-1]
 * time_signal: input, 2M time samples [0,...,2M-1]
 * m_plan:      plan created by mdct_init() for this M
 *
 * The transform runs in three stages: the 2M input samples are folded
 * into M/2 complex values and rotated by the twiddle table, a forward
 * M/2-point complex FFT is executed, and the FFT output is rotated again
 * and scattered into the M output bins.
 *
 * The plan's fft_in/fft_out buffers are used as scratch space.
 */
void mdct(FLOAT* mdct_line, FLOAT* time_signal, mdct_plan* m_plan)
{
    const int bins  = m_plan->M;
    const int half  = bins >> 1;   /* M/2  */
    const int half3 = 3 * half;    /* 3M/2 */
    const int half5 = 5 * half;    /* 5M/2 */
    const FLOAT* tw = m_plan->twiddle; /* tw[k] = cos term, tw[k+1] = sin term */
    FLOAT* buf = (FLOAT*) m_plan->fft_in; /* interleaved re/im view */
    FLOAT re, im, c, s;
    int k;

    /* odd/even folding and pre-twiddle, first part (k < M/2) */
    for(k = 0; k < half; k += 2)
    {
        re = time_signal[half3-1-k] + time_signal[half3+k];
        im = time_signal[half+k] - time_signal[half-1-k];
        c = tw[k];
        s = tw[k+1];
        buf[k]   = re * c + im * s;
        buf[k+1] = im * c - re * s;
    }

    /* second part continues from where the first loop stopped */
    for(; k < bins; k += 2)
    {
        re = time_signal[half3-1-k] - time_signal[-half+k];
        im = time_signal[half+k] + time_signal[half5-1-k];
        c = tw[k];
        s = tw[k+1];
        buf[k]   = re * c + im * s;
        buf[k+1] = im * c - re * s;
    }

    /* complex FFT of size M/2: fft_in -> fft_out */
    FFTW_EXECUTE(m_plan->fft_plan);

    /* post-twiddle: each complex FFT value yields one bin from the low end
       and one from the high end of the spectrum */
    buf = (FLOAT*) m_plan->fft_out;
    for(k = 0; k < bins; k += 2)
    {
        re = buf[k];
        im = buf[k+1];
        c = tw[k];
        s = tw[k+1];
        mdct_line[k]        = - re * c - im * s;
        mdct_line[bins-1-k] = - re * s + im * c;
    }
}
/*
 * Inverse MDCT.
 *
 * time_signal: output, 2M time samples [0,...,2M-1]
 * mdct_line:   input, M spectral coefficients [0,...,M-1]
 * m_plan:      plan created by mdct_init() for this M
 *
 * Mirrors mdct(): the M coefficients are packed into M/2 complex values
 * and rotated by the twiddle table, the same forward M/2-point FFT is
 * executed, and the result is rotated again and expanded (with the
 * appropriate sign symmetries) into the 2M output samples.
 *
 * The plan's fft_in/fft_out buffers are used as scratch space.
 */
void imdct(FLOAT* time_signal, FLOAT* mdct_line, mdct_plan* m_plan)
{
    const int bins  = m_plan->M;
    const int half  = bins >> 1;   /* M/2  */
    const int half3 = 3 * half;    /* 3M/2 */
    const int half5 = 5 * half;    /* 5M/2 */
    const FLOAT* tw = m_plan->twiddle; /* tw[k] = cos term, tw[k+1] = sin term */
    FLOAT* buf = (FLOAT*) m_plan->fft_in; /* interleaved re/im view */
    FLOAT re, im, c, s, out_re, out_im;
    int k;

    /* pre-twiddle: pair bin k with its mirror bin M-1-k */
    for(k = 0; k < bins; k += 2)
    {
        re = mdct_line[k];
        im = mdct_line[bins-1-k];
        c = tw[k];
        s = tw[k+1];
        buf[k]   = -im * s - re * c;
        buf[k+1] = -im * c + re * s;
    }

    /* complex FFT of size M/2: fft_in -> fft_out */
    FFTW_EXECUTE(m_plan->fft_plan);

    /* odd/even expanding and post-twiddle, first part (k < M/2) */
    buf = (FLOAT*) m_plan->fft_out;
    for(k = 0; k < half; k += 2)
    {
        re = buf[k];
        im = buf[k+1];
        c = tw[k];
        s = tw[k+1];
        out_re = re * c + im * s;
        out_im = re * s - im * c;
        time_signal[half3-1-k] = out_re;
        time_signal[half3+k]   = out_re;
        time_signal[half+k]    = out_im;
        time_signal[half-1-k]  = -out_im;
    }

    /* second part continues from where the first loop stopped */
    for(; k < bins; k += 2)
    {
        re = buf[k];
        im = buf[k+1];
        c = tw[k];
        s = tw[k+1];
        out_re = re * c + im * s;
        out_im = re * s - im * c;
        time_signal[half3-1-k] = out_re;
        time_signal[-half+k]   = -out_re;
        time_signal[half+k]    = out_im;
        time_signal[half5-1-k] = out_im;
    }
}
/******** end of mdct.c ******** */
/******** begin of mdct_test.c ******** */
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <sys/time.h>
#include <time.h>
#include "mdct.h"
/*
 * Timing driver for mdct()/imdct().
 *
 * Usage: <program> <MDCT_SPECTRUM_SIZE> <run_times>
 *
 * Fills a 2M-sample time buffer and an M-bin spectrum buffer with random
 * values in [-1, 1], runs each transform <run_times> times and prints the
 * average wall-clock time per call in milliseconds.
 *
 * Returns 0 on success and -1 on bad arguments or allocation failure.
 * All resources are released on every exit path past argument parsing.
 */
int main(int argc, char* argv[])
{
    int M = 0, r = 0, i;
    int status = -1;
    FLOAT* time_buf = NULL; /* not named `time`, to avoid shadowing time() */
    FLOAT* freq_buf = NULL;
    mdct_plan* m_plan = NULL;
    const char* precision = NULL;
    struct timeval t0, t1;
    long long elps;

    if(3 != argc)
    {
        fprintf(stderr, " Usage: %s <MDCT_SPECTRUM_SIZE> <run_times> \n", argv[0]);
        return -1;
    }

    /* both arguments must parse as positive integers; r is used as a divisor */
    if(1 != sscanf(argv[1], "%d", &M) || M <= 0 ||
       1 != sscanf(argv[2], "%d", &r) || r <= 0)
    {
        fprintf(stderr, " Expect positive integers for <MDCT_SPECTRUM_SIZE> and <run_times> \n");
        return -1;
    }

    if(NULL == (m_plan = mdct_init(M)))
        goto cleanup;

    /* size arithmetic in size_t so that 2 * M cannot overflow int */
    if(NULL == (time_buf = malloc(2 * (size_t) M * sizeof(FLOAT))))
        goto cleanup;
    if(NULL == (freq_buf = malloc((size_t) M * sizeof(FLOAT))))
        goto cleanup;

    for(i = 0; i < 2 * M; i++)
        time_buf[i] = 2.f * rand() / RAND_MAX - 1.f;
    for(i = 0; i < M; i++)
        freq_buf[i] = 2.f * rand() / RAND_MAX - 1.f;

    precision = (sizeof(float) == sizeof(FLOAT))?
                "single precision" : "double precision";

#if 1
    gettimeofday(&t0, NULL);
    for(i = 0; i < r; i++)
        mdct(freq_buf, time_buf, m_plan);
    gettimeofday(&t1, NULL);
    elps = (long long) (t1.tv_sec - t0.tv_sec) * 1000000 + (t1.tv_usec - t0.tv_usec);
    fprintf(stdout, "MDCT size of %d, %s, running %d times, average %.3f ms\n",
            M, precision, r, (double) elps / r / 1000.0);
#endif // 1

#if 1
    gettimeofday(&t0, NULL);
    for(i = 0; i < r; i++)
        imdct(time_buf, freq_buf, m_plan);
    gettimeofday(&t1, NULL);
    elps = (long long) (t1.tv_sec - t0.tv_sec) * 1000000 + (t1.tv_usec - t0.tv_usec);
    fprintf(stdout, "IMDCT size of %d, %s, running %d times, average %.3f ms\n",
            M, precision, r, (double) elps / r / 1000.0);
#endif // 1

#if 0
    /* optional dump of both buffers for manual inspection */
    for(i = 0; i < 2 * M; i++)
        fprintf(stdout, "%f ", time_buf[i]);
    fprintf(stdout, "\n");
    for(i = 0; i < M; i++)
        fprintf(stdout, "%f ", freq_buf[i]);
    fprintf(stdout, "\n");
#endif // 0

    status = 0;

cleanup:
    /* free(NULL) is a no-op and mdct_free() ignores a NULL plan */
    free(time_buf);
    free(freq_buf);
    mdct_free(m_plan);
    return status;
}
/******** end of mdct_test.c ******** */