/* PipeWire AGL Cluster IPC
 *
 * Copyright © 2021 Collabora Ltd.
 *    @author Julian Bouzas <julian.bouzas@collabora.com>
 *
 * SPDX-License-Identifier: MIT
 */

#include <stdlib.h>
#include <stdio.h>
#include <stdint.h>
#include <unistd.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <string.h>
#include <pthread.h>
#include <time.h>
#include <errno.h>
#include <assert.h>

#include "private.h"

#define MAX_POLL_EVENTS 128
#define MAX_LOG_MESSAGE 1024

/* log */

/* Single-letter tag printed in front of every log line, indexed by the
 * level of the message being logged. */
const char *icipc_logger_level_text[] = {
  [ICIPC_LOG_LEVEL_INFO] = "I",
  [ICIPC_LOG_LEVEL_WARN] = "W",
  [ICIPC_LOG_LEVEL_ERROR] = "E",
};

/* Logger configuration shared by all icipc_log* calls in this file. */
struct icipc_logger {
  /* Maximum level that gets printed: a message is emitted only when
   * level >= the message's level (see icipc_logv). */
  enum icipc_log_level level;
};

static const struct icipc_logger *
icipc_log_get_instance (void)
{
  static struct icipc_logger logger_ = { 0, };
  static struct icipc_logger* instance_ = NULL;

  if (instance_ == NULL) {
    char * val_str = NULL;
    enum icipc_log_level val = 0;

    /* default to error */
    logger_.level = ICIPC_LOG_LEVEL_WARN;

    /* get level from env */
    val_str = getenv ("ICIPC_DEBUG");
    if (val_str && sscanf (val_str, "%u", &val) == 1 &&
        val >= ICIPC_LOG_LEVEL_NONE)
      logger_.level = val;

    instance_ = &logger_;
  }

  return instance_;
}

/* Prints one log line to stderr as "[<tag>][<sec>.<nsec>] <message>" if
 * the configured logger level admits `level`.
 *
 * level: severity of this message; must be > 0 because it indexes
 *        icipc_logger_level_text (NOTE(review): assumes level 0 is the
 *        "none" level with no tag — confirm in private.h).
 * fmt, args: printf-style format and its arguments; the formatted text is
 *        truncated to MAX_LOG_MESSAGE - 1 characters. */
void
icipc_logv (enum icipc_log_level level, const char *fmt, va_list args)
{
  const struct icipc_logger *logger = icipc_log_get_instance ();
  assert (logger);

  if (logger->level < level)
    return;

  assert (level > 0);

  char msg[MAX_LOG_MESSAGE];
  struct timespec now;
  clock_gettime (CLOCK_REALTIME, &now);
  vsnprintf (msg, sizeof msg, fmt, args);

  /* time_t and tv_nsec are signed; cast to fixed printf types. The
   * fractional part is nanoseconds, so zero-pad it to 9 digits. */
  fprintf (stderr, "[%s][%lld.%09ld] %s\n", icipc_logger_level_text[level],
      (long long) now.tv_sec, (long) now.tv_nsec, msg);
}

/* Variadic front-end for icipc_logv (): collects the arguments that follow
 * fmt into a va_list and forwards them unchanged together with level. */
void
icipc_log (enum icipc_log_level level, const char *fmt, ...)
{
  va_list ap;

  va_start (ap, fmt);
  icipc_logv (level, fmt, ap);
  va_end (ap);
}

/* socket */

/* Writes the `size` bytes at `buffer` to `fd`. It retries on EINTR and
 * resumes after short writes.
 *
 * Returns the number of bytes written. This is less than `size` only when
 * fd is non-blocking and would block (EAGAIN/EWOULDBLOCK), or when write()
 * makes no progress. Returns -1 with errno set on any other error. */
ssize_t
icipc_socket_write (int fd, const uint8_t *buffer, size_t size)
{
  size_t total_written = 0;

  assert (fd >= 0);
  assert (buffer != NULL);
  assert (size > 0);

  while (total_written < size) {
    /* keep write()'s result signed so -1 is recognised as an error, and
     * continue from where the previous (possibly short) write stopped */
    ssize_t n = write (fd, buffer + total_written, size - total_written);
    if (n < 0) {
      if (errno == EINTR)
        continue;
      if (errno == EAGAIN || errno == EWOULDBLOCK)
        break;
      return -1;
    }
    if (n == 0)
      break;  /* no progress possible; avoid spinning forever */
    total_written += (size_t) n;
  }

  return (ssize_t) total_written;
}

/* Reads everything currently available from `fd` into `*buffer`. When the
 * buffer fills up, it doubles the buffer (updating *buffer and *max_size)
 * and reads again.
 *
 * Returns:
 *   - the number of bytes stored at *buffer. This is the data read before
 *     the fd would block (EAGAIN/EWOULDBLOCK), or before end of stream.
 *   - 0 when end of stream is hit before any byte is read.
 *   - -1 with errno set on a read error, or on allocation failure (ENOMEM).
 *
 * On failure, *buffer and *max_size still describe a valid allocation that
 * the caller owns. */
ssize_t
icipc_socket_read (int fd, uint8_t **buffer, size_t *max_size)
{
  size_t offset = 0;

  assert (buffer);
  assert (*buffer);
  assert (max_size);
  assert (*max_size > 0);

  for (;;) {
    size_t space = *max_size - offset;
    ssize_t n = read (fd, *buffer + offset, space);

    /* end of stream: hand back whatever was read before it (0 if nothing) */
    if (n == 0)
      return (ssize_t) offset;

    if (n < 0) {
      if (errno == EINTR)
        continue;
      if (errno == EAGAIN || errno == EWOULDBLOCK)
        return (ssize_t) offset;
      return -1;
    }

    offset += (size_t) n;
    if (offset < *max_size)
      return (ssize_t) offset;

    /* the buffer is completely full, so more data may be pending:
     * double the capacity and read again */
    if (*max_size > SIZE_MAX / 2) {
      errno = ENOMEM;
      return -1;
    }
    uint8_t *grown = realloc (*buffer, *max_size * 2);
    if (grown == NULL) {
      /* the original allocation is still valid and owned by the caller */
      errno = ENOMEM;
      return -1;
    }
    *buffer = grown;
    *max_size *= 2;
  }
}

/* epoll thread */

/* Prepares `self` for icipc_epoll_thread_start (). It creates a
 * non-blocking eventfd (used to request the thread to stop) and an epoll
 * instance watching both socket_fd and that eventfd for EPOLLIN.
 *
 * sock_func:  called for events on socket_fd (may be NULL).
 * other_func: called for events on any other fd in the epoll set (may be NULL).
 * data:       opaque pointer passed back to both callbacks.
 *
 * Returns true on success. On failure it returns false with errno from the
 * failing call, closes any fd it created, and leaves event_fd and epoll_fd
 * at -1 so that a later icipc_epoll_thread_destroy () cannot close stale
 * descriptors. */
bool
icipc_epoll_thread_init (struct epoll_thread *self,
                         int socket_fd,
                         icipc_epoll_thread_event_funct_t sock_func,
                         icipc_epoll_thread_event_funct_t other_func,
                         void *data)
{
  struct epoll_event event;
  int saved_errno;

  self->socket_fd = socket_fd;
  self->event_fd = -1;
  self->epoll_fd = -1;

  /* create event fd */
  self->event_fd = eventfd (0, EFD_CLOEXEC | EFD_NONBLOCK);
  if (self->event_fd == -1)
    goto error;

  /* create epoll fd */
  self->epoll_fd = epoll_create1 (EPOLL_CLOEXEC);
  if (self->epoll_fd == -1)
    goto error;

  /* data.fd only covers part of the epoll_data union. Zero the whole event
   * so the kernel is never handed uninitialised bytes. */
  memset (&event, 0, sizeof event);

  /* poll socket fd */
  event.events = EPOLLIN;
  event.data.fd = self->socket_fd;
  if (epoll_ctl (self->epoll_fd, EPOLL_CTL_ADD, self->socket_fd, &event) != 0)
    goto error;

  /* poll event fd */
  event.events = EPOLLIN;
  event.data.fd = self->event_fd;
  if (epoll_ctl (self->epoll_fd, EPOLL_CTL_ADD, self->event_fd, &event) != 0)
    goto error;

  self->socket_event_func = sock_func;
  self->other_event_func = other_func;
  self->event_data = data;
  return true;

error:
  /* close() may overwrite errno; keep the cause of the failure */
  saved_errno = errno;
  if (self->epoll_fd != -1) {
    close (self->epoll_fd);
    self->epoll_fd = -1;
  }
  if (self->event_fd != -1) {
    close (self->event_fd);
    self->event_fd = -1;
  }
  errno = saved_errno;
  return false;
}

/* Thread body started by icipc_epoll_thread_start (). It waits on
 * self->epoll_fd and dispatches each ready fd to the matching callback. It
 * returns when a value of 1 is read from self->event_fd, or when
 * epoll_wait fails with an error that retrying cannot fix. */
static void *
epoll_thread_run (void *data)
{
  struct epoll_thread *self = data;
  bool done = false;

  while (!done) {
    /* wait for events */
    struct epoll_event ep[MAX_POLL_EVENTS];
    int n = epoll_wait (self->epoll_fd, ep, MAX_POLL_EVENTS, -1);
    if (n < 0) {
      /* interrupted by a signal: nothing went wrong, just wait again */
      if (errno == EINTR)
        continue;
      /* other epoll_wait errors (EBADF, EFAULT, EINVAL) recur on every
       * retry, so looping would only spin and flood the log */
      icipc_log_error ("epoll_thread: failed to wait for event: %s",
          strerror (errno));
      break;
    }

    for (int i = 0; i < n; i++) {
      int fd = ep[i].data.fd;

      if (fd == self->socket_fd) {
        /* socket fd */
        if (self->socket_event_func)
          self->socket_event_func (self, fd, self->event_data);
      } else if (fd == self->event_fd) {
        /* event fd: icipc_epoll_thread_stop () writes 1 here */
        uint64_t stop = 0;
        ssize_t res = read (fd, &stop, sizeof (uint64_t));
        if (res == sizeof (uint64_t) && stop == 1)
          done = true;
      } else {
        /* any other fd added to epoll_fd outside this file */
        if (self->other_event_func)
          self->other_event_func (self, fd, self->event_data);
      }
    }
  }

  return NULL;
}

/* Launches epoll_thread_run () on a new pthread, which is stored in
 * self->thread. Returns whether pthread_create succeeded. */
bool
icipc_epoll_thread_start (struct epoll_thread *self)
{
  int rc = pthread_create (&self->thread, NULL, epoll_thread_run, self);

  return rc == 0;
}

/* Asks the epoll thread to exit by writing 1 to its eventfd, then joins it.
 * If the stop request cannot be written, the failure is logged and the
 * thread is not joined, because joining would block forever. */
void
icipc_epoll_thread_stop (struct epoll_thread *self)
{
  const uint64_t value = 1;
  ssize_t res = write (self->event_fd, &value, sizeof (uint64_t));

  if (res != (ssize_t) sizeof (uint64_t)) {
    icipc_log_error ("epoll_thread: failed to request stop: %s",
        strerror (errno));
    return;
  }

  pthread_join (self->thread, NULL);
}

/* Releases the epoll and event fds created by icipc_epoll_thread_init ().
 * Each fd is reset to -1 after closing, so calling this twice (or after a
 * failed init) never closes a descriptor number that may have been reused
 * by unrelated code. The socket fd is not closed here. */
void
icipc_epoll_thread_destroy (struct epoll_thread *self)
{
  if (self->epoll_fd != -1) {
    close (self->epoll_fd);
    self->epoll_fd = -1;
  }
  if (self->event_fd != -1) {
    close (self->event_fd);
    self->event_fd = -1;
  }
}