/* 
  Native entries for the Fft class

  @author Jeff Schoen
  @version $Id: Fft.c,v 1.8 2010/01/28 13:09:45 jgs Exp $
*/

#include "native.h"

#include "libnmfft.c"

/*
  JNI entry for Fft.processNative.

  Obtains the native storage of `array` through GetArray, hands it to
  nm_fft as fft_cmplx data together with `size` and `flags`, and then
  releases it through ReleaseArray.
*/
JNIEXPORT void JNICALL Java_nxm_sys_libm_Fft_processNative (JNIEnv *env, jclass clazz,
	jobject array, jint size, jint flags) {
  jfloat *data = GetArray(array);
  fft_cmplx *samples = (fft_cmplx*)data;

  nm_fft (samples, size, flags);
  ReleaseArray(array,data);
}

/*
  JNI entry for Fft.initPlan.

  Allocates an FFT_Plan on the C heap and initializes it with fft_init
  using `size` and `flags`.

  Returns the plan pointer packed into a jlong, or 0 when the allocation
  fails or fft_init does not leave the plan marked FFT_FLAGS_VALIDPLAN.
  A non-zero handle must eventually be released with freePlan.
*/
JNIEXPORT jlong JNICALL Java_nxm_sys_libm_Fft_initPlan (JNIEnv *env, jclass clazz, 
	jint size, jint flags) {
  FFT_Plan *plan = (FFT_Plan *)malloc(sizeof *plan);
  /* Out of memory: report failure the same way as an invalid plan. */
  if (plan == NULL) {
    return (jlong)0;
  }
  fft_init(plan,size,flags);
  if (plan->validflag != FFT_FLAGS_VALIDPLAN) {
    free(plan);
    plan = NULL;
  }
  return (jlong)(jpointer)plan;
}
/*
  JNI entry for Fft.workPlan.

  Runs fft_work on the native storage of `array` using the plan whose
  pointer is packed in `planp`.

  A zero handle (the value initPlan returns on failure) is ignored: the
  call returns without touching the array.
*/
JNIEXPORT void JNICALL Java_nxm_sys_libm_Fft_workPlan (JNIEnv *env, jclass clazz, 
	jlong planp, jobject array) {
  FFT_Plan *plan = (FFT_Plan *)(jpointer)planp;
  jfloat *buf;
  /* Check before acquiring the array so there is nothing to release. */
  if (plan == NULL) {
    return;
  }
  buf = GetArray(array);
  fft_work((fft_cmplx *)buf,plan);
  ReleaseArray(array,buf);
}
/*
  JNI entry for Fft.freePlan.

  Releases the plan whose pointer is packed in `planp`: fft_free is
  called on it first, then the FFT_Plan structure itself is freed.
  A zero handle (the value initPlan returns on failure) is accepted
  and ignored.

  Always returns 0.
*/
JNIEXPORT jlong JNICALL Java_nxm_sys_libm_Fft_freePlan (JNIEnv *env, jclass clazz, 
	jlong planp) {
  FFT_Plan *plan = (FFT_Plan *)(jpointer)planp;
  if (plan != NULL) {
    fft_free(plan);
    free(plan);
  }
  return (jlong)(jpointer)NULL;
}

#ifdef _CUDA

#include <cuda.h>
#include <cufft.h>
#include <driver_types.h>
/* Complex element type of the CUDA buffer (CUDA float2). */
typedef float2 Complex;
/* State kept for one cuFFT plan; filled in by initPlanCU. */
typedef struct {
  int type;            /* CUFFT_C2C, CUFFT_C2R or CUFFT_R2C, selected from the flags */
  int dir;             /* CUFFT_INVERSE or CUFFT_FORWARD; passed to cufftExecC2C */
  int alloc;           /* byte size requested from cudaMalloc for cubuf */
  int xfer;            /* byte count copied to and from cubuf by workPlanCU */
  cufftHandle handle;  /* plan handle created by cufftPlan1d */
  Complex *cubuf;      /* buffer obtained from cudaMalloc; transforms run in place on it */
} CUDA_Plan;

/*
  JNI entry for Fft.initPlanCU.

  Allocates a CUDA_Plan, derives the transform type, direction and
  buffer sizes from `size` and `flags`, creates a 1-D cuFFT plan with a
  batch count of 1, and allocates the CUDA buffer used by workPlanCU.

  Returns the plan pointer packed into a jlong, or 0 when the host
  allocation, cufftPlan1d or cudaMalloc fails. On failure everything
  acquired so far is released. A non-zero handle must eventually be
  released with freePlanCU.
*/
JNIEXPORT jlong JNICALL Java_nxm_sys_libm_Fft_initPlanCU (JNIEnv *env, jclass clazz, 
	jint size, jint flags) {
  CUDA_Plan *plan = (CUDA_Plan *)malloc(sizeof *plan);
  if (plan == NULL) {
    return (jlong)0;
  }
  plan->dir = (flags&FFT_FLAGS_INVERSE)? CUFFT_INVERSE : CUFFT_FORWARD;
  plan->type  = (flags&FFT_FLAGS_COMPLEX)? CUFFT_C2C : (flags&FFT_FLAGS_INVERSE)? CUFFT_C2R : CUFFT_R2C;
  /* Non-complex plans allocate size/2+1 complex elements but transfer only size/2. */
  plan->alloc = sizeof(Complex) * ((flags&FFT_FLAGS_COMPLEX)?  size : size/2+1);
  plan->xfer  = sizeof(Complex) * ((flags&FFT_FLAGS_COMPLEX)?  size : size/2);
  plan->cubuf = NULL;
  if (cufftPlan1d(&plan->handle, size, plan->type, 1) != CUFFT_SUCCESS) {
    free(plan);
    return (jlong)0;
  }
  if (cudaMalloc((void**)&plan->cubuf, plan->alloc) != cudaSuccess) {
    /* The cuFFT plan already exists at this point; do not leak it. */
    cufftDestroy(plan->handle);
    free(plan);
    return (jlong)0;
  }
  return (jlong)(jpointer)plan;
}
/*
  JNI entry for Fft.workPlanCU.

  Copies plan->xfer bytes from the native storage of `array` into the
  plan's CUDA buffer, executes the cuFFT transform selected by
  plan->type in place on that buffer, and copies plan->xfer bytes back
  into the array.

  A zero handle is ignored. If the upload or the transform reports an
  error, the remaining stages are skipped so the caller's data is not
  overwritten with unrelated buffer contents. The array is released on
  every path that acquired it.
*/
JNIEXPORT void JNICALL Java_nxm_sys_libm_Fft_workPlanCU (JNIEnv *env, jclass clazz, 
	jlong planp, jobject array) {
  CUDA_Plan *plan = (CUDA_Plan *)(jpointer)planp;
  jfloat *buf;
  int ok;
  /* Check before acquiring the array so there is nothing to release. */
  if (plan == NULL) {
    return;
  }
  buf = GetArray(array);
  ok = (cudaMemcpy(plan->cubuf, (Complex*)buf, plan->xfer, cudaMemcpyHostToDevice) == cudaSuccess);
  if (ok) {
    if (plan->type == CUFFT_C2C) {
      ok = (cufftExecC2C(plan->handle, (cufftComplex *)plan->cubuf, (cufftComplex *)plan->cubuf, plan->dir) == CUFFT_SUCCESS);
    } else if (plan->type == CUFFT_R2C) {
      ok = (cufftExecR2C(plan->handle, (cufftReal *)plan->cubuf, (cufftComplex *)plan->cubuf) == CUFFT_SUCCESS);
    } else if (plan->type == CUFFT_C2R) {
      ok = (cufftExecC2R(plan->handle, (cufftComplex *)plan->cubuf, (cufftReal *)plan->cubuf) == CUFFT_SUCCESS);
    }
  }
  if (ok) {
    cudaMemcpy((Complex*)buf, plan->cubuf, plan->xfer, cudaMemcpyDeviceToHost);
  }
  ReleaseArray(array,buf);
}
/*
  JNI entry for Fft.freePlanCU.

  Releases the plan whose pointer is packed in `planp`: frees the CUDA
  buffer, destroys the cuFFT plan handle, and frees the CUDA_Plan
  structure. A zero handle is accepted and ignored.

  Always returns 0.
*/
JNIEXPORT jlong JNICALL Java_nxm_sys_libm_Fft_freePlanCU (JNIEnv *env, jclass clazz, 
	jlong planp) {
  CUDA_Plan *plan = (CUDA_Plan *)(jpointer)planp;
  if (plan != NULL) {
    cudaFree(plan->cubuf);
    cufftDestroy(plan->handle);
    free(plan);
  }
  return (jlong)(jpointer)NULL;
}

#endif
