import pygatt
import time
import csv
import matplotlib.pyplot as plt
from binascii import hexlify
import pywinusb.hid as hid

class Rabboni:
    """Driver for a Rabboni IMU sensor over BLE (pygatt BGAPI) or USB HID (pywinusb).

    The most recent decoded sample is exposed as ``Accx``/``Accy``/``Accz``,
    ``Gyrx``/``Gyry``/``Gyrz`` and ``Cnt``. Every decoded sample is also
    appended to the matching ``*_list`` attribute and counted in ``data_num``.
    """

    # Size of every buffer written to the USB output report.
    _USB_CMD_LEN = 33
    # One USB sample packet: the '0e' marker byte followed by 14 payload bytes.
    _USB_PACKET_LEN = 15
    # Maximum number of bytes sent per BLE characteristic write.
    _BLE_CHUNK = 20

    def __init__(self, mode=None):
        """Create a driver for *mode* ``"BLE"`` or ``"USB"``.

        In BLE mode the pygatt BGAPI adapter is started immediately.

        Raises:
            ValueError: if *mode* is neither ``"BLE"`` nor ``"USB"``.
        """
        self.device = None
        self.mode = mode
        if self.mode == "BLE":
            self.adapter = pygatt.BGAPIBackend()
            self.adapter.start()
        elif self.mode != "USB":
            raise ValueError("Mode must be USB or BLE")
        self.vid = 0x04d9
        self.pid = 0xb564
        self.usb_temp = []         # pending USB payload bytes as 2-digit hex strings
        self.devices = []          # result of the last BLE scan()
        self.characteristics = []  # filled by discover_characteristics()
        self.Status = 0            # 1 after connect(), 0 after stop()/disconnect()
        self.Hex_data = 0
        self.Accx = 0
        self.Accy = 0
        self.Accz = 0
        self.Gyrx = 0
        self.Gyry = 0
        self.Gyrz = 0
        self.Cnt = 0

        self.data_num = 0
        self.Accx_list = []
        self.Accy_list = []
        self.Accz_list = []
        self.Gyrx_list = []
        self.Gyry_list = []
        self.Gyrz_list = []
        self.Cnt_list = []
        self.report = 0            # USB output reports, set by connect()
        self.fea_report = []       # USB feature reports, set by connect()

    # ------------------------------------------------------------------ USB helpers

    @classmethod
    def _usb_command(cls, opcode):
        """Build a USB output buffer: byte 1 = 0x02 (report ID), byte 2 = *opcode*, byte 3 = 0x0a."""
        cmd = [0x00] * cls._USB_CMD_LEN
        cmd[1] = 0x02  # Report ID
        cmd[2] = opcode
        cmd[3] = 0x0a
        return cmd

    def _usb_send(self, cmd):
        """Write *cmd* through the first USB output report; do nothing if no report is available."""
        if not self.report:
            return
        self.report[0].set_raw_data(cmd)
        self.report[0].send()

    def _usb_close(self):
        """Send the USB close command (0x33) when possible, then close the HID device."""
        self.Status = 0
        self._usb_send(self._usb_command(0x33))
        if self.device:
            self.device.close()

    @staticmethod
    def _to_payload(value):
        """Convert *value* to a ``bytearray`` for a BLE write.

        Bytes-like input is copied as is. Text is encoded one character per
        byte via ``ord``, so every character must be below 256.
        """
        if isinstance(value, (bytes, bytearray)):
            return bytearray(value)
        return bytearray(ord(ch) for ch in value)

    def _store_sample(self, value_data, cnt_hex):
        """Decode one hex-encoded sample, update the latest values and append them to the history lists.

        Args:
            value_data: hex text (str or bytes). Its first 24 characters are six
                4-digit two's-complement fields: three accelerometer axes, then
                three gyroscope axes.
            cnt_hex: hex text holding the sample counter.
        """
        self.Hex_data = value_data
        self.Accx = convert_acc(value_data[:4])
        self.Accy = convert_acc(value_data[4:8])
        self.Accz = convert_acc(value_data[8:12])
        self.Gyrx = convert_gyro(value_data[12:16])
        self.Gyry = convert_gyro(value_data[16:20])
        self.Gyrz = convert_gyro(value_data[20:24])
        self.Cnt = int(cnt_hex, 16)
        self.Accx_list.append(self.Accx)
        self.Accy_list.append(self.Accy)
        self.Accz_list.append(self.Accz)
        self.Gyrx_list.append(self.Gyrx)
        self.Gyry_list.append(self.Gyry)
        self.Gyrz_list.append(self.Gyrz)
        self.Cnt_list.append(self.Cnt)
        self.data_num += 1

    # ------------------------------------------------------------------ session control

    def stop(self):
        """End the session.

        In BLE mode this stops the adapter. In USB mode it sends the close
        command and closes the HID device; this is safe even if never connected.
        """
        if self.mode == "BLE":
            self.Status = 0
            self.adapter.stop()
        elif self.mode == "USB":
            self._usb_close()

    def scan(self, timeout=5):
        """Scan for BLE devices for *timeout* seconds.

        The result is stored in ``self.devices`` and returned. Returns None in USB mode.
        """
        if self.mode == "BLE":
            self.devices = self.adapter.scan(timeout)
            return self.devices

    def print_device(self):
        """Print the name and MAC address of every device found by the last scan (BLE only)."""
        if self.mode == "BLE":
            for dev in self.devices:
                print("Name : %s  MAC : %s" % (dev["name"], dev["address"]))

    def connect_name(self, name, devices=None):
        """Connect to the first device called *name* (BLE only).

        Args:
            name: advertised device name to match.
            devices: scan results to search; defaults to ``self.devices``.

        Returns:
            The result of ``connect()``, or None if no device matches.
        """
        if self.mode == "BLE":
            if devices is None:
                devices = self.devices
            for dev in devices:
                if dev['name'] == name:
                    return self.connect(dev['address'])
            return None

    def connect(self, address=None):
        """Open a connection to the sensor.

        In BLE mode this connects to *address* (random address type) and returns
        the pygatt device. In USB mode it opens the first HID device matching
        ``vid``/``pid``. It then sends the feature report and the 0x30 command,
        and waits 0.5 s.

        Raises:
            ValueError: in USB mode, when no matching HID device is present.
        """
        if self.mode == "BLE":
            self.device = self.adapter.connect(address, address_type=pygatt.BLEAddressType.random)
            self.Status = 1
            return self.device
        elif self.mode == "USB":
            hid_devices = hid.HidDeviceFilter(vendor_id=self.vid, product_id=self.pid).get_devices()
            if not hid_devices:
                raise ValueError("Nodevice vendor_id : %s , product_id : %s" % (hex(self.vid), hex(self.pid)))
            self.device = hid_devices[0]
            self.device.open()
            self.report = self.device.find_output_reports()
            self.fea_report = self.device.find_feature_reports()
            # Only mark the session as connected once a device was actually opened.
            self.Status = 1

            ### Report request(set feature)
            baudrate = 115200
            feature = [
                0,
                0x01,
                baudrate & 0xFF,         # baud rate, little-endian, bytes 2..5
                (baudrate >> 8) & 0xFF,  # 0xC2 for 115200 (previously hard-coded)
                baudrate >> 16,
                baudrate >> 24,
                0,
                0,
                0x08,                    # presumably 8 data bits -- confirm against device docs
            ]
            if self.fea_report:
                self.fea_report[0].set_raw_data(feature)
                self.fea_report[0].send()
            if self.report:
                self._usb_send(self._usb_command(0x30))
                time.sleep(0.5)

    def disconnect(self):
        """Disconnect from the sensor.

        In BLE mode this disconnects the device. In USB mode it sends the close
        command and closes the HID device; this is safe even if never connected.
        """
        if self.mode == "BLE":
            self.Status = 0
            self.device.disconnect()
        elif self.mode == "USB":
            self._usb_close()

    # ------------------------------------------------------------------ BLE GATT access

    def discover_characteristics(self, device=None):
        """Append every characteristic of *device* (default: the connected one) to ``self.characteristics``.

        Each entry is ``{'uuid', 'handle', 'readable'}``. A characteristic is
        marked unreadable when ``char_read`` fails with an "unable to read"
        error; any other error is re-raised. Entries are appended, so calling
        this twice lists them twice.
        """
        if self.mode == "BLE":
            if device is None:
                device = self.device
            for uuid in device.discover_characteristics():
                try:
                    device.char_read(uuid)
                    readable = True
                except Exception as e:
                    if "unable to read" not in str(e).lower():
                        raise
                    readable = False
                self.characteristics.append(
                    {'uuid': uuid, 'handle': device.get_handle(uuid), 'readable': readable})

    def print_char(self):
        """Print the characteristics collected by discover_characteristics() (BLE only)."""
        if self.mode == "BLE":
            print("====== device.discover_characteristics() =====")
            for ch in self.characteristics:
                print("Read UUID %s (handle ): %d Readable: %s"
                      % (ch['uuid'], ch['handle'], ch['readable']))

    def ble_subscribe(self, uuid, device=None, callback=None, indication=True):
        """Subscribe *callback* to notifications or indications on *uuid* (BLE only)."""
        if self.mode == "BLE":
            if device is None:
                device = self.device
            device.subscribe(uuid, callback, indication)

    def read_characteristics(self, uuid, device=None):
        """Read and return the value of characteristic *uuid* (BLE only)."""
        if self.mode == "BLE":
            if device is None:
                device = self.device
            return device.char_read(uuid)

    def read_characteristics_handle(self, handle, device=None):
        """Read and return the value at attribute *handle* (BLE only)."""
        if self.mode == "BLE":
            if device is None:
                device = self.device
            return device.char_read_handle(handle)

    def write_characteristics(self, str, uuid, device=None):
        """Write *str* to characteristic *uuid* in chunks of at most 20 bytes (BLE only).

        *str* may be text (each character must be below 256) or bytes-like.
        The parameter name is kept for backward compatibility.
        """
        if self.mode == "BLE":
            if device is None:
                device = self.device
            data = self._to_payload(str)
            for start in range(0, len(data), self._BLE_CHUNK):
                device.char_write(uuid, data[start:start + self._BLE_CHUNK])

    def write_characteristics_handle(self, str, handle, device=None):
        """Write *str* to attribute *handle* in chunks of at most 20 bytes (BLE only).

        *str* may be text (each character must be below 256) or bytes-like.
        The parameter name is kept for backward compatibility.
        """
        if self.mode == "BLE":
            if device is None:
                device = self.device
            data = self._to_payload(str)
            for start in range(0, len(data), self._BLE_CHUNK):
                device.char_write_handle(handle, data[start:start + self._BLE_CHUNK])

    # ------------------------------------------------------------------ sampling

    def rst_count(self, device=None):
        """Reset the sensor's sample counter.

        In BLE mode this writes 0x36 to characteristic fff6. In USB mode it
        sends close (0x33), the reset report, then 0x30 and 0x32 with the
        original delays between them.
        """
        if self.mode == "BLE":
            if device is None:
                device = self.device
            device.char_write("0000fff6-0000-1000-8000-00805f9b34fb", bytearray([0x36]))
        elif self.mode == "USB":
            # The reset report does not use the 0x02 report-ID layout of the other commands.
            rst_cmd = [0x00] * self._USB_CMD_LEN
            rst_cmd[1] = 0x36
            rst_cmd[2] = 0x0a
            if self.device and self.report:
                self._usb_send(self._usb_command(0x33))
                time.sleep(0.2)
                self._usb_send(rst_cmd)
                time.sleep(0.5)
                self._usb_send(self._usb_command(0x30))
                self._usb_send(self._usb_command(0x32))
                time.sleep(0.5)

    def read_callback(self, handle, value):
        """pygatt notification handler: decode one BLE sample from the raw bytes *value*."""
        if self.mode == "BLE":
            value_data = hexlify(value)
            self._store_sample(value_data, value_data[24:])

    def read_callback_usb(self, value):
        """pywinusb raw-data handler: buffer one USB report and decode any complete samples."""
        if self.mode == "USB":
            self.filter_data(value)
            self.get_data()

    def print_data(self):
        """Print the most recent sample after a 0.1 s pause."""
        time.sleep(0.1)
        print("--------------------------------")
        print("Acc_x : %f, Acc_y : %f, Acc_z : %f " % (self.Accx, self.Accy, self.Accz))
        print("Gyr_x : %f, Gyr_y : %f, Gyr_z : %f " % (self.Gyrx, self.Gyry, self.Gyrz))
        print("Count : %i" % (self.Cnt))

    def read_data(self, device=None):
        """Start streaming samples into this object.

        In BLE mode this subscribes read_callback() to characteristic 1601. In
        USB mode it installs read_callback_usb() and sends the 0x32 command.
        """
        if self.mode == "BLE":
            if device is None:
                device = self.device
            device.subscribe("00001601-0000-1000-8000-00805f9b34fb", self.read_callback)
        elif self.mode == "USB":
            if self.device and self.report:
                self.device.set_raw_data_handler(self.read_callback_usb)
                self._usb_send(self._usb_command(0x32))

    def filter_data(self, x):
        """Append the payload of one USB report to ``usb_temp`` as 2-digit lowercase hex strings.

        ``x[1]`` is the payload length and the payload starts at ``x[2]``.
        """
        payload = x[2:2 + x[1]]
        self.usb_temp.extend(format(b, "02x") for b in payload)

    def get_data(self):
        """Decode every complete sample packet waiting in ``usb_temp``.

        A packet is the marker '0e' followed by 14 bytes. Bytes 1-12 hold six
        16-bit axes and byte 13 holds the counter; byte 14 is not decoded.
        Bytes before the first marker can never start a packet, so they are
        discarded. An incomplete trailing packet is kept until more data arrives.
        """
        buf = self.usb_temp
        while True:
            try:
                start = buf.index('0e')
            except ValueError:
                buf.clear()  # no marker: nothing buffered can begin a packet
                return
            del buf[:start]
            if len(buf) < self._USB_PACKET_LEN:
                return  # wait for the rest of this packet
            value_data = "".join(buf[1:self._USB_PACKET_LEN])
            del buf[:self._USB_PACKET_LEN]
            self._store_sample(value_data, value_data[24:26])

    # ------------------------------------------------------------------ output

    def write_csv(self, data, file_name=None):
        """Write *data* as a single CSV row to ``file_name + '.csv'``.

        Raises:
            ValueError: if *file_name* is None.
        """
        if file_name is None:
            raise ValueError("Need File_name!")
        # newline='' is required by the csv module to avoid extra '\r' on Windows.
        with open(file_name + '.csv', 'w', newline='') as csvfile:
            csv.writer(csvfile).writerow(data)

    def plot_pic(self, data, file_name=None, show=True):
        """Plot *data* with matplotlib.

        The figure is saved to *file_name* when one is given, and shown when *show* is true.
        """
        plt.plot(data)
        if file_name is not None:
            plt.savefig(file_name)
        if show:
            plt.show()






def convert_acc(acc):
    """Convert a hex-encoded 16-bit accelerometer reading to a float.

    *acc* is hex text (str, or bytes as produced by ``hexlify``). Callers in
    this file pass 4 hex digits. The digits are read as a two's-complement
    16-bit count and scaled by 16/32768, i.e. a +/-16 full-scale range
    (presumably g -- confirm against the sensor's configured range).
    """
    signed_count = twos_comp(int(acc, 16), 16)
    return float(signed_count) * 16 / 32768
    
def convert_gyro(gyro):
    """Convert a hex-encoded 16-bit gyroscope reading to a float.

    *gyro* is hex text (str, or bytes as produced by ``hexlify``). Callers in
    this file pass 4 hex digits. The digits are read as a two's-complement
    16-bit count and scaled by 2000/32768, i.e. a +/-2000 full-scale range
    (presumably degrees per second -- confirm against the sensor's configuration).
    """
    signed_count = twos_comp(int(gyro, 16), 16)
    return float(signed_count) * 2000 / 32768

def twos_comp(val, bits):
    """Return *val* reinterpreted as a signed two's-complement number of width *bits*.

    *val* is expected to be a non-negative int that fits in *bits* bits (no
    masking is applied). When its top bit (bit ``bits - 1``) is set, 2**bits
    is subtracted; otherwise it is returned unchanged.
    """
    sign_mask = 1 << (bits - 1)
    return val - (1 << bits) if val & sign_mask else val

