import MetalPerformanceShadersGraph

extension MPSGraph {
    /// Entry point for ONNX -> MPSGraph op mapping.
    ///
    /// Translates a single ONNX node into MPSGraph tensors, registering the
    /// resulting tensors in `tensors` and any raw initializers in `constants`.
    ///
    /// - Returns: `true` if the node's op type is supported, `false` otherwise.
    func onnx(
        node: Onnx_NodeProto,
        optimizedForMPS: Bool,
        tensorsDataType: MPSDataType,
        tensors: inout [String: MPSGraphTensor],
        constants: inout [String: Onnx_TensorProto]
    ) throws -> Bool {
        let output: MPSGraphTensor

        switch node.opType {
        case "Add":
            output = try arithmetic(op: .add, node, tensors)
        case "Sub":
            output = try arithmetic(op: .sub, node, tensors)
        case "Mul":
            output = try arithmetic(op: .mul, node, tensors)
        case "Div":
            output = try arithmetic(op: .div, node, tensors)
        case "Sqrt":
            output = try sqrt(node, tensors)
        case "Exp":
            output = try exp(node, tensors)
        case "Log":
            output = try log(node, tensors)
        case "Floor":
            output = try floor(node, tensors)
        case "Less":
            output = try less(node, tensors)
        case "Greater":
            output = try greater(node, tensors)
        case "Where":
            output = try whereOp(node, tensors)
        case "BatchNormalization":
            output = try batchNorm(node, tensors)
        case "InstanceNormalization":
            output = try instanceNorm(node, tensors)
        case "custom_group_norm":
            // ONNX does not support group norm out of the box
            output = try groupNorm(node, tensors, constants)
        case "Concat":
            output = try concat(node, tensors)
        case "Conv":
            output = try conv(node, tensors, swizzled: optimizedForMPS)
        case "FusedConv":
            output = try fusedConv(node, tensors, swizzled: optimizedForMPS)
        case "ConvTranspose":
            output = try convTranspose(node, tensors)
        case "Gemm", "MatMul":
            output = try gemm(node, tensors)
        case "GlobalAveragePool":
            output = try globalPool(.avg, node, tensors)
        case "AveragePool":
            output = try pool(.avg, node, tensors)
        case "MaxPool":
            output = try pool(.max, node, tensors)
        case "Pad":
            output = try pad(node, tensors, constants)
        case "Reshape":
            output = try reshape(node, tensors, constants)
        case "Squeeze":
            output = try squeeze(node, tensors, constants)
        case "Unsqueeze":
            output = try unsqueeze(node, tensors, constants)
        case "Shape":
            output = try shape(node, tensors)
        case "Relu":
            output = try relu(node, tensors)
        case "PRelu", "LeakyRelu":
            output = try prelu(node, tensors, constants)
        case "Elu":
            output = try elu(node, tensors)
        case "Sigmoid":
            output = try sigmoid(node, tensors)
        case "HardSigmoid":
            output = try hardSigmoid(node, tensors)
        case "Upsample", "Resize":
            output = try resize(node, tensors, constants)
        case "Tanh":
            output = try tanh(node, tensors)
        case "Softmax":
            output = try softmax(node, tensors)
        case "Flatten":
            output = try flatten(node, tensors)
        case "Transpose":
            output = try permute(node, tensors)
        case "Slice":
            output = try slice(node, tensors, constants)
        case "ReduceMean":
            output = try reduceMean(node, tensors, constants)
        case "ReduceSum":
            output = try reduceSum(node, tensors, constants)
        case "ReduceL2":
            output = try reduceL2(node, tensors, constants)
        case "Dropout":
            output = try dropout(node, tensors, constants)
        case "DepthToSpace":
            output = try depthToSpace(node, tensors)
        case "Constant":
            guard let value = node.attr("value") else {
                throw OnnxError.invalidInput(node.name)
            }
            // stash the raw initializer so ops that consume static values
            // (Reshape, Pad, Slice, ...) can read it at build time
            node.output.forEach { constants[$0] = value.t }
            output = try constant(value.t, targetDataType: tensorsDataType)
        case "Cast":
            // Cast is handled as a passthrough of the input tensor
            output = try passthrough(node, tensors)
        case "Clip":
            output = try clip(node, tensors)
        case "Pow":
            output = try pow(node, tensors)
        case "Tile":
            output = try tile(node, tensors, constants)
        case "Gather":
            output = try gather(node, tensors, constants)
        case "GatherElements":
            output = try gatherElements(node, tensors, constants)
        case "Expand":
            output = try expand(node, tensors, constants)
        case "Neg":
            output = try neg(node, tensors)
        case "Split":
            // Split has multiple outputs, so register them directly and return early
            try split(node, tensors).forEach { tensors[$0.0] = $0.1 }
            return true
        default:
            return false
        }

        node.output.forEach { tensors[$0] = output }

        return true
    }
}
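
// MARK: - Usage sketch

// A minimal, illustrative sketch of how the `onnx(node:...)` entry point above
// might be driven while walking a decoded ONNX model. The `Onnx_ModelProto`
// accessors (`model.graph.node`) follow the SwiftProtobuf bindings generated
// from onnx.proto; placeholder creation for graph inputs is elided and is an
// assumption about the surrounding code, not something this file defines.
private func buildGraphSketch(
    from model: Onnx_ModelProto,
    on mpsGraph: MPSGraph
) throws -> [String: MPSGraphTensor] {
    // Graph inputs would normally be registered as placeholders here first.
    var tensors: [String: MPSGraphTensor] = [:]
    var constants: [String: Onnx_TensorProto] = [:]

    for node in model.graph.node {
        // `onnx` returns false for op types its switch does not recognize.
        let handled = try mpsGraph.onnx(
            node: node,
            optimizedForMPS: false,
            tensorsDataType: .float32,
            tensors: &tensors,
            constants: &constants
        )
        precondition(handled, "unsupported ONNX op: \(node.opType)")
    }

    return tensors
}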