import asyncio
import websockets
import json
import os
import pytz
import time
from datetime import datetime, timedelta
from lib.config_manager import ConfigManager
import ntplib
from lib.time_utils import get_time_frame_config  # Import from shared module

config = ConfigManager()

# ==================== Color Settings ====================
# Color names that may appear in config.COIN_COLORS, mapped to their ANSI
# terminal escape sequences.
ANSI_COLORS = {
    "yellow": "\033[93m",
    "purple": "\033[95m",
    "blue": "\033[94m",
    "red": "\033[91m",
    "green": "\033[92m",
    "cyan": "\033[96m",
    "white": "\033[97m",
}
RESET_COLOR = "\033[0m"

# Per-coin escape code. Coins configured with a color name that is not in
# ANSI_COLORS are simply left out of this mapping.
coin_colors = {
    coin: ANSI_COLORS[color]
    for coin, color in config.COIN_COLORS.items()
    if color in ANSI_COLORS
}
# =====================================================

# On-disk persistence, relative to the working directory.
DATA_FILE = os.path.join("storage", "trades.json")  # recorded whale trades
STATE_FILE = os.path.join("storage", "state.json")  # CumD / CumFT bookkeeping
# Length in seconds of the first configured time frame ("FT") candle.
# 900 (15m) is only the default; connect_to_websocket() overwrites it from config.
FT_SECONDS = 900

whale_trades = []
start_time = None
# Signed running volumes per symbol: since UTC midnight (CumD) and since the
# start of the current FT candle (CumFT).
cum_d_volumes = dict.fromkeys(config.SYMBOLS, 0.0)
cum_ft_volumes = dict.fromkeys(config.SYMBOLS, 0.0)
current_ft_candle_start = dict.fromkeys(config.SYMBOLS, 0)
last_clean_time = time.time()
stored_day = None  # UTC date string ("%Y-%m-%d") the CumD totals belong to
reporting_task = None  # asyncio.Task running reporting_loop(), once started

def load_state():
    """Read the persisted volume state from STATE_FILE.

    Returns:
        The decoded JSON value, or None when the file does not exist or cannot
        be read/parsed. Errors are printed, never raised.
    """
    try:
        if not os.path.exists(STATE_FILE):
            return None
        with open(STATE_FILE, 'r') as handle:
            state = json.load(handle)
    except Exception as exc:
        print(f"⚠️ Error loading state: {str(exc)}")
        return None
    return state

def save_state():
    """Persist the current volume bookkeeping to STATE_FILE.

    The JSON is first written to ``STATE_FILE + ".tmp"`` and then moved over
    STATE_FILE with os.replace, so the real file is swapped in whole rather than
    rewritten in place. Errors are printed, never raised.
    """
    try:
        os.makedirs("storage", exist_ok=True)
        staging_path = f"{STATE_FILE}.tmp"
        snapshot = dict(
            last_update=time.time(),
            cum_d_volumes=cum_d_volumes,
            cum_ft_volumes=cum_ft_volumes,
            current_ft_candle_start=current_ft_candle_start,
            stored_day=stored_day,
        )
        with open(staging_path, 'w') as out:
            json.dump(snapshot, out)
        os.replace(staging_path, STATE_FILE)
    except Exception as exc:
        print(f"⚠️ Error saving state: {str(exc)}")

def load_existing_trades():
    """Return the whale-trade log persisted in DATA_FILE.

    An absent file yields an empty list. So does an unreadable or invalid one,
    in which case the error is printed.
    """
    try:
        if not os.path.exists(DATA_FILE):
            return []
        with open(DATA_FILE, "r") as handle:
            return json.load(handle)
    except Exception as exc:
        print(f"⚠️ Error loading trades: {str(exc)}")
        return []

def save_trades():
    """Write the in-memory ``whale_trades`` list to DATA_FILE.

    The data goes to ``DATA_FILE + ".tmp"`` first and is then moved into place
    with os.replace. Errors are printed, never raised.
    """
    try:
        os.makedirs("storage", exist_ok=True)
        staging_path = f"{DATA_FILE}.tmp"
        with open(staging_path, "w") as out:
            json.dump(whale_trades, out)
        os.replace(staging_path, DATA_FILE)
    except Exception as exc:
        print(f"⚠️ Error saving trades: {str(exc)}")

def clean_old_trades():
    """Drop trades older than ``config.MAX_TRADE_AGE_DAYS`` days.

    The cut-off is measured from get_ntp_time(). The module-level
    ``whale_trades`` list is rebound to the filtered list. When anything was
    removed, a message is printed and the log is re-saved.
    """
    global whale_trades
    cutoff = get_ntp_time() - config.MAX_TRADE_AGE_DAYS * 86400
    count_before = len(whale_trades)
    whale_trades = [trade for trade in whale_trades if trade["timestamp"] >= cutoff]

    removed = count_before - len(whale_trades)
    if removed:
        print(f"🧹 Cleaned {removed} old trades")
        save_trades()

# Clock-correction cache for get_ntp_time(). NTP servers are queried at most
# once per _NTP_RESYNC_SECONDS. In between, the local clock plus the last
# measured offset is returned, so callers running on the asyncio event loop do
# not block on a network round-trip for every call.
_NTP_SERVERS = ("pool.ntp.org", "time.google.com")
_NTP_RESYNC_SECONDS = 300
_NTP_TIMEOUT_SECONDS = 2
_ntp_offset = 0.0  # seconds to add to time.time(); 0.0 until a sync succeeds
_ntp_last_sync = None  # time.monotonic() of the last sync attempt

def _query_ntp_offset():
    """Return the clock offset reported by the first reachable NTP server, or None."""
    try:
        client = ntplib.NTPClient()
    except Exception:
        return None
    for server in _NTP_SERVERS:
        try:
            return client.request(server, timeout=_NTP_TIMEOUT_SECONDS).offset
        except Exception:
            # Unreachable server or bad reply: try the next one.
            continue
    return None

def get_ntp_time():
    """Return the current Unix timestamp in seconds, corrected by NTP when possible.

    The NTP offset is refreshed at most every _NTP_RESYNC_SECONDS. If no server
    answers, the previous offset is kept. If none has ever succeeded, the offset
    is 0.0, so the result falls back to the local clock.
    """
    global _ntp_offset, _ntp_last_sync
    now_mono = time.monotonic()
    if _ntp_last_sync is None or now_mono - _ntp_last_sync >= _NTP_RESYNC_SECONDS:
        # Record the attempt before querying so a failing NTP setup is retried
        # only once per resync window, not on every call.
        _ntp_last_sync = now_mono
        offset = _query_ntp_offset()
        if offset is not None:
            _ntp_offset = offset
    return time.time() + _ntp_offset

def calculate_cumulative_volumes(trades, symbol, day_start, ft_start):
    """Sum the signed volume of *symbol*'s trades since two cut-off timestamps.

    A trade with ``side == "B"`` contributes ``+volume``. Any other side
    contributes ``-volume``.

    Args:
        trades: iterable of trade dicts with "coin", "volume", "side" and
            "timestamp" keys.
        symbol: coin whose trades are summed; other coins are ignored.
        day_start: trades with ``timestamp >= day_start`` count toward the
            first total.
        ft_start: trades with ``timestamp >= ft_start`` count toward the
            second total.

    Returns:
        Tuple ``(cum_d, cum_ft)`` of floats.
    """
    daily_total = 0.0
    candle_total = 0.0

    matching = (t for t in trades if t["coin"] == symbol)
    for trade in matching:
        amount = trade["volume"]
        if trade["side"] != "B":
            amount = -amount

        stamp = trade["timestamp"]
        if stamp >= day_start:
            daily_total += amount
        if stamp >= ft_start:
            candle_total += amount

    return daily_total, candle_total

def _as_float(value, default=0.0):
    """Return ``float(value)``, or *default* when *value* is missing or malformed."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default

async def process_trade_data(data):
    """Parse one raw websocket message and record the whale trades it contains.

    Only messages on the ``"trades"`` channel with a non-empty ``"data"`` list
    are handled. A trade is recorded when its coin is in ``config.SYMBOLS`` and
    its size (``"sz"``) is at least ``config.THRESHOLDS.get(coin, 1.0)``. For
    each recorded trade this:

      * rolls the coin's first-time-frame (FT) candle when the trade falls past
        the current one,
      * adds the signed size (side ``"B"`` positive, anything else negative) to
        the coin's CumD and CumFT totals,
      * appends the trade to ``whale_trades`` and prints a colored line.

    Trades and state are written to disk once per message after processing.
    They are also written when processing stopped early on an error, so trades
    already counted are never left unsaved. Errors are printed and swallowed so
    one bad message cannot break the websocket loop.

    Args:
        data: raw JSON text received from the websocket.
    """
    global whale_trades, cum_d_volumes, cum_ft_volumes, current_ft_candle_start

    recorded = 0
    try:
        message = json.loads(data)
        if message.get("channel") != "trades" or not message.get("data"):
            return

        # Fallback timestamp for trades without a usable "time" field.
        current_time = get_ntp_time()

        for trade in message["data"]:
            coin = trade.get("coin")
            if coin not in config.SYMBOLS:
                continue

            volume = _as_float(trade.get("sz", 0))
            threshold = config.THRESHOLDS.get(coin, 1.0)
            if not volume >= threshold:  # negated >= so NaN sizes are skipped too
                continue

            try:
                # "time" is divided by 1000, i.e. treated as milliseconds.
                timestamp = trade.get("time", current_time * 1000) / 1000
            except (TypeError, ValueError):
                timestamp = current_time

            price = _as_float(trade.get("px", 0))
            side = trade.get("side", "unknown")
            signed_volume = volume if side == "B" else -volume

            # Roll the candle BEFORE accumulating, so the trade that opens a new
            # candle is counted in that candle's CumFT instead of being reset away.
            if timestamp >= current_ft_candle_start.get(coin, 0) + FT_SECONDS:
                current_ft_candle_start[coin] = timestamp - (timestamp % FT_SECONDS)
                cum_ft_volumes[coin] = 0.0

            cum_d_volumes[coin] = cum_d_volumes.get(coin, 0.0) + signed_volume
            cum_ft_volumes[coin] = cum_ft_volumes.get(coin, 0.0) + signed_volume

            whale_trades.append({
                "timestamp": timestamp,
                "volume": volume,
                "side": side,
                "price": price,
                "coin": coin
            })
            recorded += 1

            side_str = "Buy" if side == "B" else "Sell"
            color_code = coin_colors.get(coin, RESET_COLOR)

            print(
                f"{color_code}{coin:8}  {side_str:5} {volume:12.2f}  "
                f"{price:12.4f}$   "
                f"CumD:{cum_d_volumes.get(coin, 0.0):+14.2f}     "
                f"CumFT:{cum_ft_volumes.get(coin, 0.0):+14.2f}{RESET_COLOR}"
            )

    except Exception as e:
        print(f"❌ Error processing trade data: {str(e)}")
    finally:
        # save_trades() re-dumps the entire trade log, so write once per message
        # rather than once per trade.
        if recorded:
            save_trades()
            save_state()

async def reporting_loop():
    """Send a Telegram summary per symbol at every ``INTERVAL_MINUTES`` boundary.

    This runs forever. It first sleeps until the next interval boundary. Then,
    on each tick, it:

      * recounts the daily CumD totals when the UTC date has changed,
      * prunes old trades when more than 86400 s have passed since the last prune,
      * builds a summary for every symbol with recorded trades and sends it,
      * sleeps until the next interval boundary.

    Telegram send failures are counted. From the third consecutive failure on,
    each further failure adds a 60 s back-off; a successful send resets the
    count. Any other exception is printed and the tick is retried after 30 s.
    """
    global last_clean_time, stored_day, cum_d_volumes

    # Get time frame configuration
    time_frame_config = get_time_frame_config(config.TIME_FRAMES)

    interval_seconds = config.INTERVAL_MINUTES * 60
    current_time = get_ntp_time()
    next_report_time = (current_time // interval_seconds + 1) * interval_seconds
    wait_time = next_report_time - current_time

    if wait_time > 0:
        print(f"⏱ Waiting {wait_time:.1f} seconds for first report...")
        await asyncio.sleep(wait_time)

    consecutive_errors = 0

    while True:
        try:
            current_time = get_ntp_time()
            # time.gmtime avoids the deprecated datetime.utcfromtimestamp.
            current_utc_day = time.strftime("%Y-%m-%d", time.gmtime(current_time))

            if stored_day != current_utc_day:
                if stored_day is not None:
                    print(f"♻️ Resetting daily volumes for new day: {current_utc_day}")
                    # Unix time has no leap seconds, so UTC midnight is an exact
                    # multiple of 86400.
                    day_start = current_time - (current_time % 86400)
                    for symbol in config.SYMBOLS:
                        # Recount from the trade log instead of zeroing: trades that
                        # arrived between UTC midnight and this tick belong to the
                        # new day and would otherwise be lost.
                        cum_d_volumes[symbol], _ = calculate_cumulative_volumes(
                            whale_trades, symbol, day_start, float("inf")
                        )
                stored_day = current_utc_day
                save_state()

            if current_time - last_clean_time > 86400:
                clean_old_trades()
                last_clean_time = current_time

            # Group the log once instead of rescanning it for every symbol.
            trades_by_symbol = {}
            for trade in whale_trades:
                trades_by_symbol.setdefault(trade["coin"], []).append(trade)

            for symbol in config.SYMBOLS:
                symbol_trades = trades_by_symbol.get(symbol)
                if not symbol_trades:
                    continue

                # Generate summary using the shared function. The import is kept
                # local as in the original (presumably to avoid an import cycle
                # with the bot module -- TODO confirm).
                from bot.telegram_bot import generate_summary
                summary = generate_summary(
                    symbol_trades,
                    current_time,
                    symbol,
                    time_frame_config,
                    interval_seconds
                )

                if summary:
                    try:
                        from bot.telegram_bot import send_telegram_report
                        await send_telegram_report(summary)
                        consecutive_errors = 0
                        print(f"✅ Report sent successfully for {symbol}\n")
                    except Exception as e:
                        consecutive_errors += 1
                        print(f"❌ Telegram error ({consecutive_errors}): {e}")

                        if consecutive_errors >= 3:
                            print("⚠️ Too many errors, waiting 60 seconds...")
                            await asyncio.sleep(60)

            # Target the boundary after this tick started, but measure the sleep
            # from a fresh clock reading: sends and back-offs take time, and a
            # sleep computed from the stale start time would overshoot it.
            next_report_time = current_time - (current_time % interval_seconds) + interval_seconds
            wait_time = next_report_time - get_ntp_time()

            if wait_time > 0:
                await asyncio.sleep(wait_time)

        except Exception as e:
            print(f"🔥 Critical error in reporting loop: {e}")
            print("🛑 Restarting reporting loop in 30 seconds...")
            await asyncio.sleep(30)

def _rebuild_volume_state(current_time):
    """Derive the CumD / CumFT bookkeeping for *current_time* from ``whale_trades``.

    Returns:
        ``(utc_day, cum_d, cum_ft, candle_start)``:
          * ``utc_day`` is the "%Y-%m-%d" UTC date of *current_time*.
          * ``cum_d`` sums trades since UTC midnight.
          * ``cum_ft`` sums trades since the start of the current FT candle.
          * ``candle_start`` is that candle start.
        The three dicts are keyed by every symbol in ``config.SYMBOLS``.
    """
    utc_day = time.strftime("%Y-%m-%d", time.gmtime(current_time))
    # Unix time has no leap seconds, so UTC midnight is an exact multiple of
    # 86400. The previous strptime(...).timestamp() interpreted the date as
    # *local* midnight, which is wrong on any non-UTC host.
    day_start = current_time - (current_time % 86400)
    candle_start = current_time - (current_time % FT_SECONDS)

    cum_d = {}
    cum_ft = {}
    candles = {}
    for symbol in config.SYMBOLS:
        cum_d[symbol], cum_ft[symbol] = calculate_cumulative_volumes(
            whale_trades, symbol, day_start, candle_start
        )
        candles[symbol] = candle_start
    return utc_day, cum_d, cum_ft, candles

async def connect_to_websocket():
    """Rebuild volume bookkeeping, then stream Hyperliquid trades forever.

    On start-up it:

      * sets FT_SECONDS from the first configured time frame,
      * loads the persisted trade log,
      * rebuilds from that log the UTC day, CumD, CumFT and the per-symbol
        candle starts,
      * persists the resulting state.

    It then connects to the Hyperliquid websocket and starts reporting_loop()
    if it is not already running. It subscribes to the trades channel for every
    symbol and passes each message to process_trade_data(). On any connection
    error it waits 5 s and reconnects.
    """
    global whale_trades, start_time, cum_d_volumes, cum_ft_volumes, current_ft_candle_start, stored_day, FT_SECONDS
    global reporting_task

    current_time = get_ntp_time()

    # Get time frame configuration
    time_frame_config = get_time_frame_config(config.TIME_FRAMES)
    first_time_frame = config.TIME_FRAMES[0] if config.TIME_FRAMES else "15m"
    if first_time_frame in time_frame_config:
        FT_SECONDS = time_frame_config[first_time_frame][1]
    print(f"⏱ First time frame: {first_time_frame} ({FT_SECONDS} seconds)")

    whale_trades = load_existing_trades()

    # The trade log is the source of truth. The saved CumD/CumFT were always
    # overwritten by this recount anyway. A saved stored_day or candle start
    # from before a restart can be stale: an old stored_day made reporting_loop
    # wipe today's CumD, and an old candle start inflated CumFT.
    stored_day, cum_d_volumes, cum_ft_volumes, current_ft_candle_start = (
        _rebuild_volume_state(current_time)
    )
    print(f"♻️ Rebuilt volumes from {len(whale_trades)} stored trades")

    save_state()

    uri = "wss://api.hyperliquid.xyz/ws"
    while True:
        try:
            async with websockets.connect(uri, ping_interval=20, ping_timeout=15) as websocket:
                print("✅ WebSocket connected")

                if reporting_task is None or reporting_task.done():
                    reporting_task = asyncio.create_task(reporting_loop())

                for symbol in config.SYMBOLS:
                    await websocket.send(json.dumps({
                        "method": "subscribe",
                        "subscription": {"type": "trades", "coin": symbol}
                    }))

                while True:
                    try:
                        data = await asyncio.wait_for(websocket.recv(), timeout=25)
                        await process_trade_data(data)
                    except asyncio.TimeoutError:
                        # Nothing received for 25 s: send an explicit ping.
                        await websocket.ping()

        except Exception as e:
            print(f"❌ WebSocket error: {e}")
            print("♻️ Reconnecting in 5 seconds...\n")
            await asyncio.sleep(5)