#!/usr/local/bin/python3.11

import argparse
import os
import subprocess
import sys

import numpy
import PIL
import PIL.Image
import scipy.io.wavfile
import scipy.signal

class APT_signal(object):
    """Load an APT recording from a .wav file and decode it into an image.

    The constructor normalises the audio to a single channel sampled at
    ``SAMPLE_RATE`` (shelling out to the external ``sox`` tool when a
    conversion is needed); ``decode`` then demodulates the signal and
    arranges it into rows of ``LINE_WIDTH`` pixels.
    """

    # The audio should get resampled to 20.8 kHz if not there natively
    SAMPLE_RATE = 20800

    # Number of samples kept per image row once a sync position is found
    LINE_WIDTH = 2080

    def __init__(self, wavIN):
        """Read ``wavIN`` and store it as mono audio at ``SAMPLE_RATE``.

        Parameters:
            wavIN: path to the input .wav file.

        Raises:
            subprocess.CalledProcessError: if a ``sox`` conversion fails.
            FileNotFoundError: if ``sox`` is needed but not installed.

        Intermediate files are written next to the input (``*_MONO.wav``,
        ``*_goodRATE.wav``) and are removed again even when a step fails.
        """
        # Grab the raw .wav
        (rate, self.signal) = scipy.io.wavfile.read(wavIN)

        # Strip the real extension, whatever its length
        base = os.path.splitext(wavIN)[0]
        current = wavIN      # file that the next sox step should read
        temp_files = []      # everything we create and must clean up

        try:
            # Enforce mono
            if self.signal.ndim > 1 and self.signal.shape[1] > 1:
                n_channels = self.signal.shape[1]
                print("Supplied signal has {} channels, downmixing to mono...".format(n_channels))
                mono_path = base + '_MONO.wav'
                # sox channels are 1-based; list *all* of them so that they
                # are mixed into the single output channel
                channels = ",".join(str(c) for c in range(1, n_channels + 1))
                temp_files.append(mono_path)
                subprocess.run(
                    ["sox", current, mono_path, "remix", channels],
                    check=True,
                )
                (rate, self.signal) = scipy.io.wavfile.read(mono_path)
                current = mono_path
                print("Done.\n")

            # Enforce samplerate
            if rate != self.SAMPLE_RATE:
                print("Supplied signal has samplerate {} Hz, downsampling to {} Hz...".format(rate, self.SAMPLE_RATE))
                rate_path = base + '_goodRATE.wav'
                temp_files.append(rate_path)
                subprocess.run(
                    ["sox", current, "-r", str(self.SAMPLE_RATE), rate_path],
                    check=True,
                )
                (rate, self.signal) = scipy.io.wavfile.read(rate_path)
                print("Done.\n")
        finally:
            # Remove any files we generated along the way
            for path in temp_files:
                try:
                    os.remove(path)
                except FileNotFoundError:
                    # sox failed before creating the file; nothing to remove
                    pass

    def decode(self, outfile=None):
        """Demodulate the stored signal, build the image and display it.

        Parameters:
            outfile: optional path; when given, the image is also saved there.

        Returns:
            The 2-D ``uint8`` pixel matrix (rows x ``LINE_WIDTH``).

        Raises:
            ValueError: if no complete scanline could be extracted.
        """
        # Take the Hilbert transform
        print("Taking Hilbert transform...")
        analytic = scipy.signal.hilbert(self.signal)
        print("Done.\n")

        # Median filter the envelope (magnitude of the analytic signal)
        print("Taking median filter...")
        envelope = scipy.signal.medfilt(numpy.abs(analytic), 5)
        print("Done.\n")

        # Calculate how many whole groups of five samples we have
        print("Calculating necessary truncation...")
        n_groups = len(envelope) // 5
        print("Done.\n")

        # Truncate the filtered signal to a multiple of five elements
        print("Truncating signal...")
        truncated = envelope[:n_groups * 5]
        print("Done.\n")

        # Reshape the truncated filtered signal to have five columns
        print("Reshaping signal...")
        grouped = truncated.reshape((n_groups, 5))
        print("Done.\n")

        # Digitize the middle sample of every group of five (5x decimation)
        print("Digitizing signal...")
        digitized = self._digitize(grouped[:, 2])
        print("Done.\n")

        # Sync the scanlines
        print("Synchronizing scanlines...")
        matrix = self._reshape(digitized)
        print("Done.\n")

        if matrix.shape[0] == 0:
            raise ValueError("No complete scanlines found in the signal")

        # Create image with data
        print("Forming image...")
        image = PIL.Image.fromarray(matrix)
        if outfile is not None:
            image.save(outfile)
        print("Done.\n")
        image.show()
        return matrix

    def _digitize(self, signal, plow=0.5, phigh=99.5):
        """Scale ``signal`` to 8-bit luminance values.

        The ``plow`` and ``phigh`` percentiles of the signal are mapped to 0
        and 255 respectively; values outside that range are clipped.

        Parameters:
            signal: 1-D array-like of amplitudes.
            plow: percentile mapped to black.
            phigh: percentile mapped to white.

        Returns:
            ``numpy.uint8`` array of the same shape as ``signal``. An empty,
            constant or NaN-ranged input yields all zeros.
        """
        signal = numpy.asarray(signal, dtype=float)
        if signal.size == 0:
            return numpy.zeros(signal.shape, dtype=numpy.uint8)

        # Calculate low and high reach of signal
        (low, high) = numpy.percentile(signal, (plow, phigh))
        delta = high - low

        # No usable dynamic range (also catches NaN): avoid dividing by zero
        if not delta > 0:
            return numpy.zeros(signal.shape, dtype=numpy.uint8)

        # Normalize the signal to px luminance values, discretize
        data = numpy.round(255 * (signal - low) / delta)
        return numpy.clip(data, 0, 255).astype(numpy.uint8)

    def _reshape(self, signal):
        """Cut the digitized signal into image rows aligned on sync frames.

        Parameters:
            signal: 1-D array of pixel values in the range 0..255.

        Returns:
            2-D array with one row of ``LINE_WIDTH`` samples per detected
            sync position. The final candidate position is always discarded,
            as are rows that would run past the end of the signal. If nothing
            qualifies, the result has shape ``(0, LINE_WIDTH)``.
        """
        signal = numpy.asarray(signal)

        # We are searching for a sync frame, which will appear as seven pulses
        # with some black pixels. Both the pattern and the signal are shifted
        # down by 128 to get a meaningful (signed) correlation; a signed dtype
        # is required so that uint8 pixels do not wrap around.
        sync = numpy.array([0, 128, 255, 128] * 7 + [0] * 7, dtype=numpy.int64) - 128
        centred = signal.astype(numpy.int64) - 128

        # There is a minimum distance between peaks, probably more than 2000
        mindistance = 2000

        # Track maximum correlations found: (index, val)
        peaks = [(0, 0)]

        n_positions = len(centred) - len(sync)
        if n_positions > 0:
            # Correlation of the pattern against every window of the signal
            correlations = numpy.correlate(centred, sync, mode="valid")[:n_positions]
            for i, corr in enumerate(correlations.tolist()):
                # If previous peak is too far, we keep it but add this value as new
                if i - peaks[-1][0] > mindistance:
                    peaks.append((i, corr))
                elif corr > peaks[-1][1]:
                    peaks[-1] = (i, corr)

        # Create image matrix, starting each line at the peaks; only complete
        # rows are kept so that the result is a rectangular array
        width = self.LINE_WIDTH
        rows = [
            signal[start:start + width]
            for (start, _) in peaks[:-1]
            if start + width <= len(signal)
        ]
        if not rows:
            return numpy.empty((0, width), dtype=signal.dtype)
        return numpy.array(rows)

def build_parser():
    """Construct the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``do_test`` (bool),
        ``decode_file`` (str or None) and ``out_file`` (str or None).
    """
    parser = argparse.ArgumentParser()
    # Add test option
    parser.add_argument(
        "-t", "--test",
        help="test program functionality",
        action="store_true",
        dest="do_test",
    )
    # Add decode option
    parser.add_argument(
        "-d", "--decode",
        help="decode a .wav APT recording",
        action="store",
        default=None,
        dest="decode_file",
    )
    # Add output option
    parser.add_argument(
        "-o", "--output",
        help="specify output location",
        action="store",
        default=None,
        dest="out_file",
    )
    return parser


def main(argv=None):
    """Run the command-line interface.

    Parameters:
        argv: argument list to parse; ``None`` means use ``sys.argv[1:]``.
    """
    arguments = build_parser().parse_args(argv)

    # Here we check for the main flag options, in order of precedence

    # -t, --test
    if arguments.do_test:
        print("This doesn't actually do anything yet. Woo.")
    # -d, --decode
    elif arguments.decode_file is not None:
        apt = APT_signal(arguments.decode_file)
        apt.decode(arguments.out_file)
    # No flags
    else:
        print("No arguments passed.")


if __name__ == '__main__':
    main()
