//
//  FFT.swift
//  Plinth
//
//  Created by Daniel Clelland on 26/04/22.
//

import Foundation
import Accelerate

// MARK: - Setup

/// Namespace for the vDSP FFT setup objects used by the `fft`/`ifft` methods in this file.
///
/// `Scalar` selects the precision: the `Float` and `Double` extensions below call the
/// single- and double-precision vDSP setup functions respectively. A setup is an opaque
/// pointer owned by the caller and must be released with `destroySetup(_:)`.
public enum FFT<Scalar> {

    /// Opaque vDSP FFT setup handle (the result of `vDSP_create_fftsetup` / `vDSP_create_fftsetupD`).
    public typealias Setup = OpaquePointer

}

extension FFT {

    /// Returns `log2(length)` for a positive power-of-two `length`.
    ///
    /// Radix-2 vDSP transforms only operate on power-of-two lengths. Deriving the exponent
    /// from a floating-point `log2` and truncating it would silently round a length such
    /// as 1000 down to 512 (producing a wrongly-sized transform), and would trap with an
    /// unhelpful conversion error for a length of 0, so the length is validated here.
    ///
    /// - Parameters:
    ///   - length: Number of elements along one dimension.
    ///   - dimension: Human-readable name of the dimension, used only in the failure message.
    /// - Returns: The base-2 exponent of `length`.
    /// - Precondition: `length` is a positive power of two.
    fileprivate static func log2Length(_ length: Int, dimension: String) -> vDSP_Length {
        precondition(length > 0 && length.nonzeroBitCount == 1,
                     "FFT requires a positive power-of-two \(dimension), got \(length)")
        return vDSP_Length(length.trailingZeroBitCount)
    }

}

extension FFT where Scalar == Float {

    /// Creates a single-precision radix-2 FFT setup for matrices of the given shape.
    ///
    /// The setup is sized from `shape.length`.
    /// NOTE(review): presumably `length` is the larger of `rows` and `columns`, which is
    /// what a 2D transform over this shape needs — confirm against `Shape`.
    ///
    /// - Parameter shape: Shape of the matrices the setup will be used with.
    /// - Returns: A setup to pass to the `setup:` parameter of `fft`/`ifft`; release it with
    ///   `destroySetup(_:)` when done.
    /// - Precondition: `shape.length` is a positive power of two.
    public static func createSetup(shape: Shape) -> Setup {
        let log2N = log2Length(shape.length, dimension: "length")
        guard let setup = vDSP_create_fftsetup(log2N, FFTRadix(kFFTRadix2)) else {
            fatalError("vDSP_create_fftsetup failed for log2N = \(log2N)")
        }
        return setup
    }

    /// Releases a setup previously returned by `createSetup(shape:)`.
    ///
    /// - Parameter setup: The setup to destroy; it must not be used afterwards.
    public static func destroySetup(_ setup: Setup) {
        vDSP_destroy_fftsetup(setup)
    }

}

extension FFT where Scalar == Double {

    /// Creates a double-precision radix-2 FFT setup for matrices of the given shape.
    ///
    /// The setup is sized from `shape.length`.
    /// NOTE(review): presumably `length` is the larger of `rows` and `columns`, which is
    /// what a 2D transform over this shape needs — confirm against `Shape`.
    ///
    /// - Parameter shape: Shape of the matrices the setup will be used with.
    /// - Returns: A setup to pass to the `setup:` parameter of `fft`/`ifft`; release it with
    ///   `destroySetup(_:)` when done.
    /// - Precondition: `shape.length` is a positive power of two.
    public static func createSetup(shape: Shape) -> Setup {
        let log2N = log2Length(shape.length, dimension: "length")
        guard let setup = vDSP_create_fftsetupD(log2N, FFTRadix(kFFTRadix2)) else {
            fatalError("vDSP_create_fftsetupD failed for log2N = \(log2N)")
        }
        return setup
    }

    /// Releases a setup previously returned by `createSetup(shape:)`.
    ///
    /// - Parameter setup: The setup to destroy; it must not be used afterwards.
    public static func destroySetup(_ setup: Setup) {
        vDSP_destroy_fftsetupD(setup)
    }

}

// MARK: - Direction

// NOTE(review): FFTDirection appears to be a C integer typedef in Accelerate, so these
// constants are added to that integer type as a whole — confirm this is intended.
extension FFTDirection {

    /// Forward transform direction (`kFFTDirection_Forward`).
    public static let forward = FFTDirection(kFFTDirection_Forward)

    /// Inverse transform direction (`kFFTDirection_Inverse`).
    ///
    /// The raw inverse transform is not normalised here; `ifft` divides by `shape.count`.
    public static let inverse = FFTDirection(kFFTDirection_Inverse)

}

// MARK: - Single precision

extension Matrix where Scalar == Float {

    /// Forward 2D FFT of this real matrix, promoted via `ComplexMatrix(real:)`.
    ///
    /// - Parameter setup: Optional reusable setup from `FFT<Float>.createSetup(shape:)`;
    ///   when `nil` a temporary setup is created and destroyed for this call.
    /// - Returns: The complex spectrum, with the same shape as `self`.
    public func fft(setup: FFT<Scalar>.Setup? = nil) -> ComplexMatrix<Scalar> {
        return ComplexMatrix(real: self).fft(setup: setup)
    }

    /// Inverse 2D FFT of this real matrix, promoted via `ComplexMatrix(real:)` and divided by `shape.count`.
    ///
    /// - Parameter setup: Optional reusable setup from `FFT<Float>.createSetup(shape:)`.
    /// - Returns: The normalised inverse transform, with the same shape as `self`.
    public func ifft(setup: FFT<Scalar>.Setup? = nil) -> ComplexMatrix<Scalar> {
        return ComplexMatrix(real: self).ifft(setup: setup)
    }

    /// Unnormalised 2D FFT of this real matrix in the given direction.
    ///
    /// - Parameters:
    ///   - direction: `.forward` or `.inverse`.
    ///   - setup: Optional reusable setup from `FFT<Float>.createSetup(shape:)`.
    /// - Returns: The raw transform output (no scaling is applied).
    public func fft(direction: FFTDirection, setup: FFT<Scalar>.Setup? = nil) -> ComplexMatrix<Scalar> {
        return ComplexMatrix(real: self).fft(direction: direction, setup: setup)
    }

}

extension ComplexMatrix where Scalar == Float {

    /// Forward 2D FFT.
    ///
    /// - Parameter setup: Optional reusable setup from `FFT<Float>.createSetup(shape:)`;
    ///   when `nil` a temporary setup is created and destroyed for this call.
    /// - Returns: The complex spectrum, with the same shape as `self`.
    /// - Precondition: `shape.rows` and `shape.columns` are positive powers of two.
    public func fft(setup: FFT<Scalar>.Setup? = nil) -> ComplexMatrix<Scalar> {
        return fft(direction: .forward, setup: setup)
    }

    /// Inverse 2D FFT, divided by `shape.count` to normalise the raw inverse transform.
    ///
    /// - Parameter setup: Optional reusable setup from `FFT<Float>.createSetup(shape:)`.
    /// - Returns: The normalised inverse transform, with the same shape as `self`.
    /// - Precondition: `shape.rows` and `shape.columns` are positive powers of two.
    public func ifft(setup: FFT<Scalar>.Setup? = nil) -> ComplexMatrix<Scalar> {
        return fft(direction: .inverse, setup: setup) / Scalar(shape.count)
    }

    /// Unnormalised 2D FFT in the given direction.
    ///
    /// A setup created once with `FFT<Float>.createSetup(shape:)` can be passed in to avoid
    /// creating and destroying one on every call. It must be large enough for this matrix;
    /// that cannot be checked from the opaque pointer.
    ///
    /// - Parameters:
    ///   - direction: `.forward` or `.inverse`.
    ///   - setup: Optional reusable setup; when `nil` a temporary one is used.
    /// - Returns: The raw transform output (no scaling is applied).
    /// - Precondition: `shape.rows` and `shape.columns` are positive powers of two.
    public func fft(direction: FFTDirection, setup: FFT<Scalar>.Setup? = nil) -> ComplexMatrix<Scalar> {
        guard let setup = setup else {
            return fftWithTemporarySetup(direction: direction)
        }
        return fft2D(direction: direction, setup: setup)
    }

    /// Runs `fft2D` with a setup created for this matrix's shape and destroyed on return.
    private func fftWithTemporarySetup(direction: FFTDirection) -> ComplexMatrix<Scalar> {
        let setup = FFT<Scalar>.createSetup(shape: shape)
        defer { FFT<Scalar>.destroySetup(setup) }
        return fft2D(direction: direction, setup: setup)
    }

    /// Applies `vDSP_fft2d_zip` to this matrix's split-complex data via `fmap`.
    ///
    /// Uses a unit element stride and a row stride of 0, which lets vDSP derive the row
    /// stride itself. N0 is the column count and N1 is the row count.
    /// NOTE(review): assumes `fmap` hands the closure a mutable split-complex view of a
    /// copy of the elements, laid out row by row — confirm against `ComplexMatrix.fmap`.
    private func fft2D(direction: FFTDirection, setup: FFT<Scalar>.Setup) -> ComplexMatrix<Scalar> {
        let log2N0 = FFT<Scalar>.log2Length(shape.columns, dimension: "column count")
        let log2N1 = FFT<Scalar>.log2Length(shape.rows, dimension: "row count")
        return fmap { inputVector in
            vDSP_fft2d_zip(setup, &inputVector, 1, 0, log2N0, log2N1, direction)
        }
    }

}

// MARK: - Double precision

extension Matrix where Scalar == Double {

    /// Forward 2D FFT of this real matrix, promoted via `ComplexMatrix(real:)`.
    ///
    /// - Parameter setup: Optional reusable setup from `FFT<Double>.createSetup(shape:)`;
    ///   when `nil` a temporary setup is created and destroyed for this call.
    /// - Returns: The complex spectrum, with the same shape as `self`.
    public func fft(setup: FFT<Scalar>.Setup? = nil) -> ComplexMatrix<Scalar> {
        return ComplexMatrix(real: self).fft(setup: setup)
    }

    /// Inverse 2D FFT of this real matrix, promoted via `ComplexMatrix(real:)` and divided by `shape.count`.
    ///
    /// - Parameter setup: Optional reusable setup from `FFT<Double>.createSetup(shape:)`.
    /// - Returns: The normalised inverse transform, with the same shape as `self`.
    public func ifft(setup: FFT<Scalar>.Setup? = nil) -> ComplexMatrix<Scalar> {
        return ComplexMatrix(real: self).ifft(setup: setup)
    }

    /// Unnormalised 2D FFT of this real matrix in the given direction.
    ///
    /// - Parameters:
    ///   - direction: `.forward` or `.inverse`.
    ///   - setup: Optional reusable setup from `FFT<Double>.createSetup(shape:)`.
    /// - Returns: The raw transform output (no scaling is applied).
    public func fft(direction: FFTDirection, setup: FFT<Scalar>.Setup? = nil) -> ComplexMatrix<Scalar> {
        return ComplexMatrix(real: self).fft(direction: direction, setup: setup)
    }

}

extension ComplexMatrix where Scalar == Double {

    /// Forward 2D FFT.
    ///
    /// - Parameter setup: Optional reusable setup from `FFT<Double>.createSetup(shape:)`;
    ///   when `nil` a temporary setup is created and destroyed for this call.
    /// - Returns: The complex spectrum, with the same shape as `self`.
    /// - Precondition: `shape.rows` and `shape.columns` are positive powers of two.
    public func fft(setup: FFT<Scalar>.Setup? = nil) -> ComplexMatrix<Scalar> {
        return fft(direction: .forward, setup: setup)
    }

    /// Inverse 2D FFT, divided by `shape.count` to normalise the raw inverse transform.
    ///
    /// - Parameter setup: Optional reusable setup from `FFT<Double>.createSetup(shape:)`.
    /// - Returns: The normalised inverse transform, with the same shape as `self`.
    /// - Precondition: `shape.rows` and `shape.columns` are positive powers of two.
    public func ifft(setup: FFT<Scalar>.Setup? = nil) -> ComplexMatrix<Scalar> {
        return fft(direction: .inverse, setup: setup) / Scalar(shape.count)
    }

    /// Unnormalised 2D FFT in the given direction.
    ///
    /// A setup created once with `FFT<Double>.createSetup(shape:)` can be passed in to avoid
    /// creating and destroying one on every call. It must be large enough for this matrix;
    /// that cannot be checked from the opaque pointer.
    ///
    /// - Parameters:
    ///   - direction: `.forward` or `.inverse`.
    ///   - setup: Optional reusable setup; when `nil` a temporary one is used.
    /// - Returns: The raw transform output (no scaling is applied).
    /// - Precondition: `shape.rows` and `shape.columns` are positive powers of two.
    public func fft(direction: FFTDirection, setup: FFT<Scalar>.Setup? = nil) -> ComplexMatrix<Scalar> {
        guard let setup = setup else {
            return fftWithTemporarySetup(direction: direction)
        }
        return fft2D(direction: direction, setup: setup)
    }

    /// Runs `fft2D` with a setup created for this matrix's shape and destroyed on return.
    private func fftWithTemporarySetup(direction: FFTDirection) -> ComplexMatrix<Scalar> {
        let setup = FFT<Scalar>.createSetup(shape: shape)
        defer { FFT<Scalar>.destroySetup(setup) }
        return fft2D(direction: direction, setup: setup)
    }

    /// Applies `vDSP_fft2d_zipD` to this matrix's split-complex data via `fmap`.
    ///
    /// Uses a unit element stride and a row stride of 0, which lets vDSP derive the row
    /// stride itself. N0 is the column count and N1 is the row count.
    /// NOTE(review): assumes `fmap` hands the closure a mutable split-complex view of a
    /// copy of the elements, laid out row by row — confirm against `ComplexMatrix.fmap`.
    private func fft2D(direction: FFTDirection, setup: FFT<Scalar>.Setup) -> ComplexMatrix<Scalar> {
        let log2N0 = FFT<Scalar>.log2Length(shape.columns, dimension: "column count")
        let log2N1 = FFT<Scalar>.log2Length(shape.rows, dimension: "row count")
        return fmap { inputVector in
            vDSP_fft2d_zipD(setup, &inputVector, 1, 0, log2N0, log2N1, direction)
        }
    }

}