#!/usr/bin/env python3
import argparse
import json
import os
import sqlite3
from dataclasses import asdict, dataclass, is_dataclass
from datetime import datetime
from sqlite3 import Connection
from typing import List, Optional


@dataclass
class WeightEntry:
    """A single weight record, mirroring one row of the ``weights`` table."""

    id: Optional[int]  # database row id; None for entries not stored yet
    weight: float  # displayed with a "kg" suffix by __str__
    time: datetime  # timestamp of the record

    def __str__(self):
        # Renders as e.g. "2024-01-02 |  72.5 kg" (weight right-aligned, width 5).
        return f"{self.time:%Y-%m-%d} | {self.weight:5.1f} kg"


def default_serializer(obj):
    """``json.dumps(default=...)`` hook for datetimes and dataclass instances.

    Returns an ISO-8601 string for ``datetime`` objects and a plain dict (via
    ``dataclasses.asdict``) for dataclass instances.

    Raises:
        TypeError: for any other object, as the ``json`` module expects from a
            ``default`` hook.
    """
    if isinstance(obj, datetime):
        return obj.isoformat()
    # asdict() only accepts dataclass *instances*. A hasattr(obj, "__dict__")
    # test would also admit plain objects and dataclass classes, for which
    # asdict() raises its own unrelated TypeError.
    if is_dataclass(obj) and not isinstance(obj, type):
        return asdict(obj)
    raise TypeError(f"Type {type(obj)} not serializable")


parser = argparse.ArgumentParser(description="Vikt weight tracker")
parser.add_argument("-a", type=float, required=False, help="Add a new weight.")
parser.add_argument("-l", action="store_true", required=False, help="List all weights.")
# No default value: leaving args.db as None lets the program distinguish an
# omitted --db from an explicit one; the database path otherwise comes from
# the config file.
parser.add_argument(
    "--db",
    default=None,
    help="Database file path (default: db_path from the config file)",
)
parser.add_argument("--import-json", help="Import from json")
args = parser.parse_args()

# Schema for the weights table. "IF NOT EXISTS" makes this safe to execute on
# every start-up. add_weight() fills `time` with datetime.isoformat() text;
# the CURRENT_TIMESTAMP fallback would instead store "YYYY-MM-DD HH:MM:SS" (UTC).
table = (
    "CREATE TABLE IF NOT EXISTS weights (\n"
    "    id INTEGER PRIMARY KEY,\n"
    "    weight REAL NOT NULL,\n"
    "    time TEXT UNIQUE NOT NULL DEFAULT (CURRENT_TIMESTAMP)\n"
    ");"
)


def db_exec(db_conn: Connection, query):
    """Execute one SQL statement on *db_conn* and commit right away."""
    db_conn.cursor().execute(query)
    db_conn.commit()


def add_weight(
    db_conn: Connection, weight: float, datet: Optional[datetime] = None
):
    """Insert a weight record into the database.

    Args:
        db_conn: open connection that has the ``weights`` table.
        weight: the weight value to store.
        datet: timestamp of the record, stored as ``datet.isoformat()``.
            Defaults to the current local, timezone-aware time at the moment
            of the call.

    The row is inserted with ``INSERT OR IGNORE``, so a record whose ``time``
    text already exists (the column is UNIQUE) is silently skipped. The
    change is committed before returning.
    """
    # Resolve "now" per call: a datetime.now() default in the signature is
    # evaluated once at definition time and would reuse that same timestamp.
    if datet is None:
        datet = datetime.now().astimezone()
    cursor = db_conn.cursor()
    cursor.execute(
        "INSERT OR IGNORE INTO weights (weight, time) VALUES (?, ?)",
        (
            weight,
            datet.isoformat(),
        ),
    )
    db_conn.commit()


def get_weights(db_conn: Connection) -> List[WeightEntry]:
    """Load every stored weight, ordered by the ``time`` text ascending.

    Columns are read positionally from ``SELECT *`` as (id, weight, time), so
    this relies on the column order of the ``weights`` table. The ``time``
    text is parsed with ``datetime.fromisoformat``.
    """
    cur = db_conn.cursor()
    cur.execute("SELECT * FROM weights ORDER BY time ASC")
    fetched = cur.fetchall()
    # commit() after a read-only query; harmless.
    db_conn.commit()

    entries: List[WeightEntry] = []
    for row in fetched:
        stamp = datetime.fromisoformat(row[2])
        entries.append(WeightEntry(id=row[0], weight=row[1], time=stamp))
    return entries


def import_file(file_json: str) -> List[WeightEntry]:
    """Read weight entries from a JSON file.

    The file is expected to look like
    ``{"entries": [{"weight": 72.5, "date": "2024-01-02T08:00:00"}, ...]}``,
    where each ``date`` is accepted by ``datetime.fromisoformat``. Returned
    entries have ``id=None`` because they have not been stored yet.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if the file is not valid JSON (``json.JSONDecodeError``),
            a ``date`` is not ISO-8601, or a ``weight`` is not numeric.
        KeyError: if ``entries``, ``weight`` or ``date`` is missing.
    """
    # JSON text is UTF-8 by specification; don't depend on the locale default.
    with open(file_json, "r", encoding="utf-8") as f:
        document = json.load(f)
    return [
        WeightEntry(
            None, float(entry["weight"]), datetime.fromisoformat(entry["date"])
        )
        for entry in document["entries"]
    ]


def read_or_create_config():
    """Load ``~/.config/viktconfig.json``, creating it with defaults if absent.

    On first run the defaults are written to disk, a notice is printed, and
    the defaults are returned. Otherwise the file's settings are returned
    layered over the defaults, so a config that omits ``db_path`` still works.

    Returns:
        dict: the configuration; always has a ``db_path`` key (the default
        value starts with ``~`` and is expanded by the caller).
    """
    default_config = {"db_path": "~/Documents/vikt.db"}
    config_dir = os.path.expanduser("~/.config/")
    config_path = os.path.join(config_dir, "viktconfig.json")
    os.makedirs(config_dir, exist_ok=True)

    if not os.path.exists(config_path):
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(default_config, f, indent=2)
        print(f"Config file created with defaults at {config_path}")
        return default_config

    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    # File values win; defaults fill in any keys the user left out.
    return {**default_config, **config}


# TODO: Use this (not referenced anywhere yet).
# Average weight per calendar month. The month is the first seven characters
# ("YYYY-MM") of the stored timestamp, i.e. the month in which the weight was
# recorded. strftime('%Y-%m', time) would first convert timestamps that carry
# a UTC offset to UTC, which can move entries made shortly after local
# midnight on the 1st into the previous month.
monthly_avg = """
SELECT
    substr(time, 1, 7) AS month,
    AVG(weight) AS avg_weight
FROM weights
GROUP BY month
ORDER BY month;
"""

# Average weight over the last 30 days. datetime(time) normalises every stored
# timestamp (ISO-8601 with "T" separator and optional UTC offset, or the
# "YYYY-MM-DD HH:MM:SS" form of CURRENT_TIMESTAMP) to UTC "YYYY-MM-DD HH:MM:SS",
# the same form datetime('now', ...) returns, so both sides of the comparison
# are like strings. Comparing the raw column would mix local time with UTC and
# "T" with " " separators. Timestamps without an offset are taken as UTC.
avg30d = """
SELECT AVG(weight) AS avg_weight_30d
FROM weights
WHERE datetime(time) >= datetime('now', '-30 days');
"""


def _print_stats(conn: Connection) -> None:
    """Print summary statistics for the weights table; prints nothing if it is empty."""
    crs = conn.cursor()

    # Aggregates return a single row of NULLs on an empty table, and the
    # "Latest" query returns no row at all; both cases are skipped. The checks
    # use `is not None` rather than truthiness so legitimate 0.0 values (e.g.
    # the P2P delta when only one weight is stored) are still printed.
    row = crs.execute("SELECT MIN(weight), MAX(weight) FROM weights;").fetchone()
    if row and row[0] is not None and row[1] is not None:
        print(f"Minimum: {row[0]:.1f}")
        print(f"Maximum: {row[1]:.1f}")

    single_value_stats = (
        ("P2P Delta", "SELECT MAX(weight) - MIN(weight) FROM weights;"),
        ("Total average", "SELECT AVG(weight) FROM weights;"),
        ("30 day average", avg30d),
        ("Latest", "SELECT weight FROM weights ORDER BY time DESC LIMIT 1;"),
    )
    for label, query in single_value_stats:
        row = crs.execute(query).fetchone()
        if row and row[0] is not None:
            print(f"{label}: {row[0]:.1f}")


def main():
    """Run the CLI: open the database, apply the requested actions, then report.

    Steps, in order: import entries from ``--import-json`` if given; add the
    ``-a`` weight if given; then either list every entry (``-l``) or print
    summary statistics. The connection is closed even if a step raises.
    """
    conf = read_or_create_config()

    # --db overrides the config file only when given explicitly. argparse
    # substitutes the option's declared default when it is omitted, so that
    # default (or None) is treated as "not given".
    db_path = args.db
    if db_path is None or db_path == parser.get_default("db"):
        db_path = conf["db_path"]
    db_file = os.path.expanduser(db_path)

    # sqlite3.connect() does not create missing parent directories (such as
    # ~/Documents from the default config), so create them first.
    db_dir = os.path.dirname(db_file)
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)

    conn = sqlite3.connect(db_file)
    try:
        db_exec(conn, "PRAGMA foreign_keys = ON;")
        db_exec(conn, table)

        if args.import_json:
            for w in import_file(args.import_json):
                add_weight(conn, w.weight, w.time)

        if args.a is not None:
            add_weight(conn, args.a)
            print(f"Weight {args.a} added to the database.")

        if args.l:
            for w in get_weights(conn):
                print(w)
        else:
            _print_stats(conn)
    finally:
        conn.close()


if __name__ == "__main__":
    # Script entry point. main() returns None, so the exit status is 0 unless
    # an exception escapes.
    raise SystemExit(main())
