#!/usr/bin/env python3

import argparse
import os
import socket
import sys
import time

# Pause, in seconds, before the first retry in wait().
initial_interval = 0.02
# Upper bound, in seconds, on the pause between retries; wait() multiplies the
# pause by 1.5 after every failed attempt until it reaches this value.
max_interval = 0.5


def is_fd_open(fd):
    """Return True if ``os.fstat`` accepts *fd*, False if it raises OSError."""
    try:
        os.fstat(fd)
    except OSError:
        # fstat rejected the descriptor, so treat it as not open.
        return False
    else:
        return True


def write_error(ex):
    """Write the text of *ex* to an error stream.

    File descriptor 3 is used when it is open (as it is during bats tests);
    otherwise the message goes to stderr (file descriptor 2).
    """
    target_fd = 3 if is_fd_open(3) else 2
    message = str(ex).encode()
    os.write(target_fd, message)


def wait(host, port, timeout):
    """Block until a TCP connection to ``host:port`` can be established.

    Failed attempts are retried after a pause that starts at
    ``initial_interval`` and grows by a factor of 1.5 per attempt, capped at
    ``max_interval``. Both the individual connection attempts and the pauses
    are bounded by the time left until the overall deadline, so the call does
    not run substantially longer than ``timeout``.

    Args:
        host: Host name or address to connect to.
        port: TCP port number.
        timeout: Overall time budget in seconds.

    Raises:
        TimeoutError: If no connection succeeded within ``timeout`` seconds.
            The last connection error is chained as the cause.
    """
    # Smallest timeout handed to a single attempt. It has to stay positive:
    # a timeout of 0 would put the socket into non-blocking mode and make the
    # connect fail immediately even when the port is open.
    min_attempt_timeout = 0.01

    deadline = time.perf_counter() + timeout
    current_interval = initial_interval
    while True:
        # Give each attempt only the time that is left, not the full budget,
        # so a late attempt that hangs cannot double the total wait.
        attempt_timeout = max(deadline - time.perf_counter(), min_attempt_timeout)
        try:
            with socket.create_connection((host, port), timeout=attempt_timeout):
                return
        except OSError as ex:
            remaining = deadline - time.perf_counter()
            if remaining <= 0:
                raise TimeoutError(f'Timeout waiting for {host}:{port} after {timeout}s') from ex
            # Do not sleep past the deadline; the next iteration then makes
            # one last short attempt before giving up.
            time.sleep(min(current_interval, remaining))
            current_interval = min(current_interval * 1.5, max_interval)


def main(argv):
    """Parse command-line arguments and wait for the requested port.

    Args:
        argv: Command-line arguments, without the program name.

    Exits with status 1 if the port did not accept a connection within the
    timeout; the timeout message is written via ``write_error`` unless
    ``--quiet`` is given. Exits with status 2 (argparse usage error) if the
    arguments are invalid, including a port outside 0-65535.
    """
    parser = argparse.ArgumentParser(description="Check if a port is open.")
    parser.add_argument("port", type=int, help="Port number to check")
    parser.add_argument("--host", type=str, default="localhost", help="Host to check")
    parser.add_argument("-t", "--timeout", type=float, default=10.0, help="Timeout duration in seconds")
    parser.add_argument("-q", "--quiet", action="store_true", help="Enable quiet mode")
    args = parser.parse_args(argv)

    # The socket layer raises OverflowError (not OSError) for ports outside
    # this range, which would surface as a traceback; report it as a usage
    # error instead.
    if not 0 <= args.port <= 65535:
        parser.error(f"port must be in the range 0-65535, got {args.port}")

    try:
        wait(args.host, args.port, args.timeout)
    except TimeoutError as ex:
        if not args.quiet:
            write_error(ex)
        sys.exit(1)


# When run as a script, hand the command-line arguments (without the program
# name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])
