//
//  Functors.swift
//  Plinth
//
//  Created by Daniel Clelland on 20/04/22.
//

import Foundation
import Accelerate
import Numerics

// MARK: - Matrix

extension Matrix {

    /// Passes this matrix's element array to `function` and returns whatever it produces.
    ///
    /// - Parameter function: A closure that receives `elements`.
    /// - Returns: The value returned by `function`.
    @inlinable public func fmap<A>(_ function: ([Scalar]) -> A) -> A {
        return function(elements)
    }

    /// Builds a matrix with the same `shape` whose elements are `function(elements)`.
    ///
    /// - Parameter function: A closure that maps the element array to a new element array.
    ///   NOTE(review): presumably it must return as many elements as `shape` requires;
    ///   the `Matrix` initialiser is not visible here, so confirm how a mismatch is handled.
    /// - Returns: A `Matrix<A>` with this matrix's shape and the mapped elements.
    @inlinable public func fmap<A>(_ function: ([Scalar]) -> [A]) -> Matrix<A> {
        return Matrix<A>(shape: shape, elements: function(elements))
    }

    /// Returns a copy of this matrix after `function` has mutated the copy's elements in place.
    ///
    /// - Parameter function: A closure that receives the copy's element array as `inout`.
    /// - Returns: The mutated copy; `self` is not modified.
    @inlinable public func fmap(_ function: (inout [Scalar]) -> ()) -> Matrix {
        var output = self
        function(&output.elements)
        return output
    }

    /// Creates a zero-filled `Matrix<A>` of the same shape and lets `function` fill it.
    ///
    /// - Parameter function: A closure that receives this matrix's elements and,
    ///   as `inout`, the element array of the zero-filled output.
    /// - Returns: The output matrix after `function` has run.
    @inlinable public func fmap<A>(_ function: ([Scalar], inout [A]) -> ()) -> Matrix<A> where A: Numeric {
        var output = Matrix<A>.zeros(shape: shape)
        function(elements, &output.elements)
        return output
    }

}

// MARK: - ComplexMatrix

extension ComplexMatrix {

    /// Maps the real and imaginary parts independently through `Matrix.fmap`.
    ///
    /// - Parameters:
    ///   - realFunction: Applied to the elements of `real`.
    ///   - imaginaryFunction: Applied to the elements of `imaginary`.
    /// - Returns: A `ComplexMatrix<A>` assembled from the two mapped parts.
    @inlinable public func fmap<A>(real realFunction: ([Scalar]) -> [A], imaginary imaginaryFunction: ([Scalar]) -> [A]) -> ComplexMatrix<A> {
        return ComplexMatrix<A>(real: real.fmap(realFunction), imaginary: imaginary.fmap(imaginaryFunction))
    }

    /// Maps the real and imaginary parts independently, each into a zero-filled output
    /// supplied to the closure as `inout`.
    ///
    /// - Parameters:
    ///   - realFunction: Receives the elements of `real` and the output's real elements.
    ///   - imaginaryFunction: Receives the elements of `imaginary` and the output's imaginary elements.
    /// - Returns: A `ComplexMatrix<A>` assembled from the two mapped parts.
    @inlinable public func fmap<A>(real realFunction: ([Scalar], inout [A]) -> (), imaginary imaginaryFunction: ([Scalar], inout [A]) -> ()) -> ComplexMatrix<A> where A: Numeric {
        return ComplexMatrix<A>(real: real.fmap(realFunction), imaginary: imaginary.fmap(imaginaryFunction))
    }

}

// MARK: - ComplexMatrix (Float)

extension ComplexMatrix where Scalar == Float {

    /// Produces a real matrix from this complex matrix via a split-complex closure.
    ///
    /// - Parameter function: Receives a `DSPSplitComplex` over a copy of this matrix
    ///   and, as `inout`, the element array of a zero-filled output of the same shape.
    /// - Returns: The output matrix after `function` has run.
    @inlinable public func fmap(_ function: (DSPSplitComplex, inout [Scalar]) -> ()) -> Matrix {
        // A mutable copy is needed because `withUnsafeMutableSplitComplexVector` is `mutating`.
        var input = self
        var output = Matrix.zeros(shape: shape)
        input.withUnsafeMutableSplitComplexVector { inputVector in
            function(inputVector, &output.elements)
        }
        return output
    }

}

extension ComplexMatrix where Scalar == Float {

    /// Returns a copy of this matrix after `function` has been handed a split-complex
    /// view over the copy's storage.
    ///
    /// - Parameter function: Receives the `DSPSplitComplex` as `inout`.
    /// - Returns: The copy; `self` is not modified.
    @inlinable public func fmap(_ function: (inout DSPSplitComplex) -> ()) -> ComplexMatrix {
        var input = self
        input.withUnsafeMutableSplitComplexVector { inputVector in
            function(&inputVector)
        }
        return input
    }

    /// Produces a complex matrix from this one via a split-complex input/output closure.
    ///
    /// - Parameter function: Receives a `DSPSplitComplex` over a copy of this matrix
    ///   and, as `inout`, a `DSPSplitComplex` over a zero-filled output of the same shape.
    /// - Returns: The output matrix after `function` has run.
    @inlinable public func fmap(_ function: (DSPSplitComplex, inout DSPSplitComplex) -> ()) -> ComplexMatrix {
        // A mutable copy is needed because `withUnsafeMutableSplitComplexVector` is `mutating`.
        var input = self
        var output = ComplexMatrix.zeros(shape: shape)
        input.withUnsafeMutableSplitComplexVector { inputVector in
            output.withUnsafeMutableSplitComplexVector { outputVector in
                function(inputVector, &outputVector)
            }
        }
        return output
    }

}

extension ComplexMatrix where Scalar == Float {

    /// Calls `body` with a `DSPSplitComplex` built from the base addresses of the
    /// `real` and `imaginary` mutable buffers.
    ///
    /// - Parameter body: Receives the split-complex value as `inout`; errors it throws are rethrown.
    /// - Returns: The value returned by `body`.
    @inlinable public mutating func withUnsafeMutableSplitComplexVector<Result>(_ body: (inout DSPSplitComplex) throws -> Result) rethrows -> Result {
        return try real.withUnsafeMutableBufferPointer { realPointer in
            return try imaginary.withUnsafeMutableBufferPointer { imaginaryPointer in
                // NOTE(review): the force-unwraps trap if either buffer has no base address
                // (possibly the case for empty storage) — confirm empty matrices never reach here.
                var split = DSPSplitComplex(realp: realPointer.baseAddress!, imagp: imaginaryPointer.baseAddress!)
                return try body(&split)
            }
        }
    }

}

// MARK: - ComplexMatrix (Double)

extension ComplexMatrix where Scalar == Double {

    /// Produces a real matrix from this complex matrix via a split-complex closure.
    ///
    /// - Parameter function: Receives a `DSPDoubleSplitComplex` over a copy of this matrix
    ///   and, as `inout`, the element array of a zero-filled output of the same shape.
    /// - Returns: The output matrix after `function` has run.
    @inlinable public func fmap(_ function: (DSPDoubleSplitComplex, inout [Scalar]) -> ()) -> Matrix {
        // A mutable copy is needed because `withUnsafeMutableSplitComplexVector` is `mutating`.
        var input = self
        var output = Matrix.zeros(shape: shape)
        input.withUnsafeMutableSplitComplexVector { inputVector in
            function(inputVector, &output.elements)
        }
        return output
    }

}

extension ComplexMatrix where Scalar == Double {

    /// Returns a copy of this matrix after `function` has been handed a split-complex
    /// view over the copy's storage.
    ///
    /// - Parameter function: Receives the `DSPDoubleSplitComplex` as `inout`.
    /// - Returns: The copy; `self` is not modified.
    @inlinable public func fmap(_ function: (inout DSPDoubleSplitComplex) -> ()) -> ComplexMatrix {
        var input = self
        input.withUnsafeMutableSplitComplexVector { inputVector in
            function(&inputVector)
        }
        return input
    }

    /// Produces a complex matrix from this one via a split-complex input/output closure.
    ///
    /// - Parameter function: Receives a `DSPDoubleSplitComplex` over a copy of this matrix
    ///   and, as `inout`, a `DSPDoubleSplitComplex` over a zero-filled output of the same shape.
    /// - Returns: The output matrix after `function` has run.
    @inlinable public func fmap(_ function: (DSPDoubleSplitComplex, inout DSPDoubleSplitComplex) -> ()) -> ComplexMatrix {
        // A mutable copy is needed because `withUnsafeMutableSplitComplexVector` is `mutating`.
        var input = self
        var output = ComplexMatrix.zeros(shape: shape)
        input.withUnsafeMutableSplitComplexVector { inputVector in
            output.withUnsafeMutableSplitComplexVector { outputVector in
                function(inputVector, &outputVector)
            }
        }
        return output
    }

}

extension ComplexMatrix where Scalar == Double {

    /// Calls `body` with a `DSPDoubleSplitComplex` built from the base addresses of the
    /// `real` and `imaginary` mutable buffers.
    ///
    /// - Parameter body: Receives the split-complex value as `inout`; errors it throws are rethrown.
    /// - Returns: The value returned by `body`.
    @inlinable public mutating func withUnsafeMutableSplitComplexVector<Result>(_ body: (inout DSPDoubleSplitComplex) throws -> Result) rethrows -> Result {
        return try real.withUnsafeMutableBufferPointer { realPointer in
            return try imaginary.withUnsafeMutableBufferPointer { imaginaryPointer in
                // NOTE(review): the force-unwraps trap if either buffer has no base address
                // (possibly the case for empty storage) — confirm empty matrices never reach here.
                var split = DSPDoubleSplitComplex(realp: realPointer.baseAddress!, imagp: imaginaryPointer.baseAddress!)
                return try body(&split)
            }
        }
    }

}