/* PipeWire AGL Cluster IPC
 *
 * Copyright © 2021 Collabora Ltd.
 *    @author Julian Bouzas <julian.bouzas@collabora.com>
 *
 * SPDX-License-Identifier: MIT
 */

#include <stdlib.h>
#include <stdio.h>
#include <stdint.h>
#include <unistd.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <sys/epoll.h>
#include <string.h>
#include <errno.h>
#include <assert.h>

#include "private.h"
#include "sender.h"

/* Number of slots in the queue of reply callbacks that are waiting for an
 * answer from the receiver (see push_sync_task() / pop_sync_task()). */
#define MAX_ASYNC_TASKS 128

/* One queued reply callback. A slot whose func is NULL is free. */
typedef struct SenderTask {
        icipc_sender_reply_func_t func; /* called with the reply buffer */
        void *data;                     /* user data passed to func */
} SenderTask;

/* State of one sender endpoint. Allocated by icipc_sender_new(), optionally
 * followed in the same allocation by user_size bytes of caller storage. */
struct icipc_sender {
        /* receiver address, built from the path given to icipc_sender_new() */
        struct sockaddr_un addr;
        /* local stream socket; -1 while no socket exists */
        int socket_fd;

        /* reply buffer, passed by reference to icipc_socket_read();
         * presumably it may be reallocated/resized there — TODO confirm */
        uint8_t *buffer_read;
        size_t buffer_size;

        /* watches socket_fd and dispatches to socket_event_received() */
        EpollThread epoll_thread;
        /* true after a successful connect, false after a disconnect
         * (explicit or detected on read) */
        bool is_connected;

        /* optional callback invoked when the receiver disconnects */
        icipc_sender_lost_conn_func_t lost_func;
        void *lost_data;
        /* set when a disconnect is detected on read; makes the next
         * icipc_sender_connect() re-create the socket and epoll thread */
        bool lost_connection;

        /* pending reply callbacks; a NULL func marks a free slot.
         * NOTE(review): written by icipc_sender_send() and by
         * socket_event_received() (presumably on the epoll thread) without
         * any lock — confirm this is safe */
        SenderTask async_tasks[MAX_ASYNC_TASKS];

        /* for subclasses: points at the user_size bytes allocated right
         * after this struct, NULL when user_size was 0 */
        void *user_data;
};

static int push_sync_task(
                struct icipc_sender *self,
                icipc_sender_reply_func_t func,
                void *data) {
        size_t i;
        for (i = MAX_ASYNC_TASKS; i > 1; i--) {
                SenderTask *curr = self->async_tasks + i - 1;
                SenderTask *next = self->async_tasks + i - 2;
                if (next->func != NULL && curr->func == NULL) {
                        curr->func = func;
                        curr->data = data;
                        return i - 1;
                } else if (i - 2 == 0 && next->func == NULL) {
                        /* empty queue */
                        next->func = func;
                        next->data = data;
                        return 0;
                }
        }
        return -1;
}

static void pop_sync_task(
                struct icipc_sender *self,
                bool trigger,
                bool all,
                const uint8_t * buffer,
                size_t size) {
        size_t i;
        for (i = 0; i < MAX_ASYNC_TASKS; i++) {
                SenderTask *task = self->async_tasks + i;
                if (task->func != NULL) {
                        if (trigger)
                                task->func(self, buffer, size, task->data);
                        task->func = NULL;
                        if (!all)
                                return;
                }
        }
}

/*
 * Epoll callback for the sender socket. It reads the receiver's reply and
 * passes it to the next pending reply callback (see pop_sync_task()).
 *
 * A read that fails or returns nothing is treated as the receiver going
 * away. In that case:
 *   - the fd is removed from epoll and the socket is shut down;
 *   - the sender is marked disconnected and lost_connection is set;
 *   - the lost-connection callback runs;
 *   - every pending reply callback is flushed with a NULL buffer.
 */
static void socket_event_received(EpollThread *t, int fd, void *data) {
        struct icipc_sender *sender = data;
        ssize_t n_read =
            icipc_socket_read(fd, &sender->buffer_read, &sender->buffer_size);

        if (n_read > 0) {
                /* deliver the reply to the queued async task */
                pop_sync_task(sender, true, false, sender->buffer_read, n_read);
                return;
        }

        if (n_read != 0)
                icipc_log_error("sender: could not read reply: %s",
                                strerror(errno));

        /* the receiver hung up (or the read failed): tear the socket down */
        epoll_ctl(t->epoll_fd, EPOLL_CTL_DEL, fd, NULL);
        shutdown(sender->socket_fd, SHUT_RDWR);
        sender->is_connected = false;
        sender->lost_connection = true;
        if (sender->lost_func != NULL)
                sender->lost_func(sender, fd, sender->lost_data);

        /* fail every request still waiting for a reply */
        pop_sync_task(sender, true, true, NULL, 0);
}

/* API */

/*
 * Creates a sender for the receiver socket identified by @path.
 *
 * @path:        socket path, expanded by icipc_construct_socket_path()
 * @buffer_size: initial size of the reply read buffer
 * @lost_func:   optional callback invoked when the receiver disconnects
 * @lost_data:   user data passed to @lost_func
 * @user_size:   bytes of extra storage allocated right after the sender,
 *               retrievable with icipc_sender_get_user_data()
 *
 * The socket and epoll thread are created here but not connected or
 * started; call icipc_sender_connect() for that.
 * Returns NULL on invalid arguments or on any allocation/socket failure.
 */
struct icipc_sender *icipc_sender_new(
                const char *path,
                size_t buffer_size,
                icipc_sender_lost_conn_func_t lost_func,
                void *lost_data,
                size_t user_size) {
        struct icipc_sender *self;
        int res;

        if (path == NULL)
                return NULL;

        /* the user area is appended to the sender; reject sizes whose sum
         * would wrap around and produce an undersized allocation */
        if (user_size > SIZE_MAX - sizeof(*self))
                return NULL;

        self = calloc(1, sizeof(*self) + user_size);
        if (self == NULL)
                return NULL;

        /* "no socket yet", so the error path knows what to release */
        self->socket_fd = -1;

        /* set address */
        self->addr.sun_family = AF_LOCAL;
        res =
            icipc_construct_socket_path(path, self->addr.sun_path,
                                        sizeof(self->addr.sun_path));
        if (res < 0)
                goto error;

        /* create socket; it is connected later by icipc_sender_connect() */
        self->socket_fd =
            socket(PF_LOCAL, SOCK_STREAM | SOCK_CLOEXEC | SOCK_NONBLOCK, 0);
        if (self->socket_fd < 0)
                goto error;

        /* alloc buffer read */
        self->buffer_size = buffer_size;
        self->buffer_read = calloc(buffer_size, sizeof(uint8_t));
        if (self->buffer_read == NULL)
                goto error;

        /* init epoll thread; icipc_sender_connect() starts it */
        if (!icipc_epoll_thread_init(&self->epoll_thread, self->socket_fd,
                                     socket_event_received, NULL, self))
                goto error;

        self->lost_func = lost_func;
        self->lost_data = lost_data;
        self->lost_connection = false;
        /* NOTE(review): the user area is only aligned to the alignment of
         * struct icipc_sender */
        if (user_size > 0)
                self->user_data = (uint8_t *) self + sizeof(*self);

        return self;

 error:
        free(self->buffer_read);        /* free(NULL) is a no-op */
        if (self->socket_fd >= 0)
                close(self->socket_fd);
        free(self);
        return NULL;
}

/*
 * Disconnects the sender if it is connected, then releases the epoll
 * thread, the read buffer, the socket and the sender itself.
 * Passing NULL is a no-op, like free().
 */
void icipc_sender_free(struct icipc_sender *self) {
        if (self == NULL)
                return;

        icipc_sender_disconnect(self);

        /* NOTE(review): after a connection loss detected on read, the epoll
         * thread is destroyed here without icipc_epoll_thread_stop(), unlike
         * in icipc_sender_connect(); confirm destroy also stops it */
        icipc_epoll_thread_destroy(&self->epoll_thread);
        free(self->buffer_read);

        /* socket_fd may be -1 if re-creating the socket failed on reconnect */
        if (self->socket_fd >= 0)
                close(self->socket_fd);
        free(self);
}

/*
 * Connects to the receiver and starts the epoll thread that watches for
 * replies. Returns true on success or if already connected.
 *
 * After socket_event_received() has detected a connection loss, the old
 * socket has been shut down. In that case this function:
 *   - tears down the epoll thread;
 *   - closes the old descriptor;
 *   - creates a fresh socket and epoll thread before connecting.
 *
 * NOTE(review): if a reconnect attempt fails below, lost_connection stays
 * true and the next attempt stops/destroys the epoll thread again. Confirm
 * that icipc_epoll_thread_stop/destroy tolerate that.
 */
bool icipc_sender_connect(struct icipc_sender *self) {
        if (icipc_sender_is_connected(self))
                return true;

        /* if connection was lost, re-init epoll thread with new socket */
        if (self->lost_connection) {
                icipc_epoll_thread_stop(&self->epoll_thread);
                icipc_epoll_thread_destroy(&self->epoll_thread);

                /* release the dead socket instead of leaking its descriptor */
                if (self->socket_fd >= 0) {
                        close(self->socket_fd);
                        self->socket_fd = -1;
                }

                self->socket_fd =
                    socket(PF_LOCAL, SOCK_STREAM | SOCK_CLOEXEC | SOCK_NONBLOCK,
                           0);
                if (self->socket_fd < 0)
                        return false;
                if (!icipc_epoll_thread_init
                    (&self->epoll_thread, self->socket_fd,
                     socket_event_received, NULL, self)) {
                        close(self->socket_fd);
                        /* don't leave a stale descriptor for free() to close */
                        self->socket_fd = -1;
                        return false;
                }
                self->lost_connection = false;
        }

        /* connect */
        if (connect(self->socket_fd, (struct sockaddr *)&self->addr,
                    sizeof(self->addr)) == 0 &&
            icipc_epoll_thread_start(&self->epoll_thread)) {
                self->is_connected = true;
                return true;
        }

        return false;
}

/*
 * Stops the epoll thread, shuts the socket down and marks the sender as
 * disconnected. Does nothing when the sender is not connected.
 * Pending reply callbacks in async_tasks are left untouched.
 *
 * NOTE(review): the socket is only shut down, not closed or re-created, and
 * lost_connection is not set. A later icipc_sender_connect() therefore
 * reuses the same socket; confirm that reconnecting after an explicit
 * disconnect works.
 */
void icipc_sender_disconnect(struct icipc_sender *self) {
        if (!icipc_sender_is_connected(self))
                return;

        icipc_epoll_thread_stop(&self->epoll_thread);
        shutdown(self->socket_fd, SHUT_RDWR);
        self->is_connected = false;
}

/* Reports whether the sender is currently connected: true after a
 * successful icipc_sender_connect(), false again after an explicit
 * disconnect or a disconnect detected while reading replies. */
bool icipc_sender_is_connected(struct icipc_sender *self) {
        const bool connected = self->is_connected;
        return connected;
}

/*
 * Sends @buffer (@size bytes) to the receiver.
 *
 * When @func is given, it is queued before the write so the handler is in
 * place when the reply arrives. It is later called with the reply, or with
 * a NULL buffer if the connection is lost.
 *
 * Returns false if any of the following holds:
 *   - the buffer is NULL or @size is 0;
 *   - the sender is not connected;
 *   - the reply queue is full;
 *   - the write fails. In this case the queued handler is dropped again.
 */
bool icipc_sender_send(
                struct icipc_sender *self,
                const uint8_t * buffer,
                size_t size,
                icipc_sender_reply_func_t func,
                void *data) {
        int slot = -1;

        if (buffer == NULL || size == 0 || !icipc_sender_is_connected(self))
                return false;

        if (func != NULL) {
                slot = push_sync_task(self, func, data);
                if (slot < 0)
                        return false;
        }

        if (icipc_socket_write(self->socket_fd, buffer, size) > 0)
                return true;

        /* nothing was sent: forget the handler queued above */
        if (slot >= 0)
                self->async_tasks[slot].func = NULL;
        return false;
}

/* Returns the extra storage requested through user_size in
 * icipc_sender_new(), or NULL when user_size was 0. */
void *icipc_sender_get_user_data(struct icipc_sender *self) {
        void *user = self->user_data;
        return user;
}