"""
Whale Hunter AI Algorithm - Universal Version with Forex Support
تشخیص حرکت نهنگ‌ها و تولید سیگنال خرید/فروش برای تمام دارایی‌ها
"""

from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd


class WhaleHunterAI:
    """Whale-hunter algorithm: smart-money entry/exit signals for any asset.

    Universal version with forex support. A short/long moving-average
    crossover is confirmed by volume, VWAP, MA-slope, ADX-regime and
    RSI-reclaim filters; most filters can be toggled through ``config``.
    """

    # Minimum number of rows `analyze` needs before computing indicators.
    MIN_BARS = 30

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Create an analyzer with default settings, optionally overridden.

        Args:
            config: Keys to override in the default settings dict below.
        """
        self.config: Dict[str, Any] = {
            'rsi_length': 14,
            'rsi_overbought': 70,
            'rsi_oversold': 30,
            'ma_short': 9,
            'ma_long': 21,
            # Spike multiplier vs. a rolling average (volume, or bar range for forex).
            'volume_threshold': 1.5,
            'volume_ma_length': 50,   # window of the ratio-based volume test
            'range_ma_length': 20,    # window of the forex bar-range proxy
            'volume_z_length': 20,
            'volume_z_min': 1.2,
            'use_vwap': True,
            'use_slope': True,
            'slope_length': 10,
            'use_adx': True,
            'adx_length': 14,
            'adx_trend_min': 20,
            'use_rsi_reclaim': True,
            'reclaim_low': 35,
            'reclaim_high': 65
        }
        if config:
            self.config.update(config)

    def is_forex(self, symbol: str) -> bool:
        """Return True for Yahoo-style forex tickers such as 'EURUSD=X'.

        Any symbol containing '=X' (case-insensitive) is treated as forex.
        """
        return '=X' in symbol.upper()

    def calculate_rsi(self, prices: pd.Series, length: int) -> pd.Series:
        """RSI using simple rolling means of gains and losses (not Wilder smoothing).

        Returns NaN where both average gain and loss are zero (0/0).
        """
        delta = prices.diff()
        gain = delta.where(delta > 0, 0).rolling(length).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(length).mean()
        rs = gain / loss
        rsi = 100 - (100 / (1 + rs))
        return rsi

    def calculate_ma(self, prices: pd.Series, length: int) -> pd.Series:
        """Simple moving average over ``length`` rows."""
        return prices.rolling(length).mean()

    def calculate_vwap(self, df: pd.DataFrame) -> pd.Series:
        """Cumulative VWAP from the first row of ``df`` (not reset per session).

        Volume is clipped to at least 1, so zero-volume bars (e.g. forex
        feeds) still count with equal weight and the division is never 0/0.
        """
        typical_price = (df['High'] + df['Low'] + df['Close']) / 3
        volume = df['Volume'].clip(lower=1)
        cumulative_tp_vol = (typical_price * volume).cumsum()
        cumulative_vol = volume.cumsum()
        vwap = cumulative_tp_vol / cumulative_vol
        return vwap

    def calculate_adx(self, df: pd.DataFrame, length: int) -> Dict[str, pd.Series]:
        """ADX with +DI/-DI, using simple rolling means of TR and directional moves.

        Returns:
            Dict with 'adx', 'plus_di' and 'minus_di' series aligned to ``df``.
        """
        high = df['High']
        low = df['Low']
        close = df['Close']

        # True range: the largest of the bar range and the gaps from the previous close.
        tr1 = high - low
        tr2 = (high - close.shift()).abs()
        tr3 = (low - close.shift()).abs()
        tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
        atr = tr.rolling(length).mean()

        up_move = high - high.shift()
        down_move = low.shift() - low

        plus_dm = up_move.where((up_move > down_move) & (up_move > 0), 0)
        minus_dm = down_move.where((down_move > up_move) & (down_move > 0), 0)

        plus_di = 100 * plus_dm.rolling(length).mean() / atr
        minus_di = 100 * minus_dm.rolling(length).mean() / atr

        dx = 100 * (plus_di - minus_di).abs() / (plus_di + minus_di)
        adx = dx.rolling(length).mean()

        return {'adx': adx, 'plus_di': plus_di, 'minus_di': minus_di}

    def calculate_volume_zscore(self, volume: pd.Series, length: int) -> pd.Series:
        """Rolling z-score of volume; undefined values (warm-up, 0/0) become 0."""
        vol_mean = volume.rolling(length).mean()
        vol_std = volume.rolling(length).std()
        vol_z = (volume - vol_mean) / vol_std
        return vol_z.fillna(0)

    def detect_crossover(self, series1: pd.Series, series2: pd.Series) -> pd.Series:
        """True on bars where ``series1`` moves from <= ``series2`` to above it."""
        return (series1 > series2) & (series1.shift(1) <= series2.shift(1))

    def detect_crossunder(self, series1: pd.Series, series2: pd.Series) -> pd.Series:
        """True on bars where ``series1`` moves from >= ``series2`` to below it."""
        return (series1 < series2) & (series1.shift(1) >= series2.shift(1))

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _flatten_columns(df: pd.DataFrame) -> pd.DataFrame:
        """Flatten two-level columns (e.g. yfinance's (Price, Ticker)) to OHLCV names.

        Picks the column level that contains 'Close'. Frames that are already
        flat, or whose levels lack 'Close', are returned unchanged. Assumes a
        single ticker; with several tickers the column names would repeat.
        """
        if not isinstance(df.columns, pd.MultiIndex):
            return df
        for level in range(df.columns.nlevels):
            names = df.columns.get_level_values(level)
            if 'Close' in names:
                flat = df.copy()  # never mutate the caller's frame
                flat.columns = names
                return flat
        return df

    @staticmethod
    def _last_bool(series: pd.Series) -> bool:
        """Latest value of a boolean series as a plain bool (False if empty)."""
        return bool(series.iloc[-1]) if len(series) > 0 else False

    @staticmethod
    def _last_float(series: pd.Series, default: float) -> float:
        """Latest value as a plain float, or ``default`` if empty/NaN/infinite.

        Keeps NaN (e.g. RSI/ADX on perfectly flat prices) out of the result,
        which would otherwise be invalid JSON.
        """
        if len(series) == 0:
            return float(default)
        value = float(series.iloc[-1])
        return value if np.isfinite(value) else float(default)

    def _insufficient_data_result(self) -> Dict[str, Any]:
        """Neutral result returned when there are too few rows to analyze."""
        return {
            'signal': {'buy': False, 'sell': False, 'smart_money_in': False, 'smart_money_out': False},
            'indicators': {
                'rsi': 50,
                'adx': 20,
                'trend': 'insufficient_data',
                # Same keys as the normal path so callers can index uniformly.
                'rsi_status': 'neutral',
                'ma_short': 0,
                'ma_long': 0,
                'is_trending': False,
                'high_volume': False,
            },
            'error': 'داده کافی برای تحلیل وجود ندارد',
            'config': dict(self.config),
        }

    def _high_volume_mask(self, df: pd.DataFrame, forex: bool) -> pd.Series:
        """Boolean series flagging unusually active bars."""
        cfg = self.config
        if forex:
            # Reported volume is not used for forex; a bar range that is wide
            # relative to price and to its rolling average stands in for a spike.
            price_range = (df['High'] - df['Low']) / df['Close']
            price_range_ma = price_range.rolling(cfg['range_ma_length']).mean()
            return price_range > price_range_ma * cfg['volume_threshold']

        volume = df['Volume']
        avg_volume = volume.rolling(cfg['volume_ma_length']).mean()
        high_volume_legacy = volume > (cfg['volume_threshold'] * avg_volume)
        vol_zscore = self.calculate_volume_zscore(volume, cfg['volume_z_length'])
        high_volume_z = vol_zscore >= cfg['volume_z_min']
        # Prefer the z-score test; use the ratio test only if the z-score never fires.
        return high_volume_z if high_volume_z.any() else high_volume_legacy

    def _vwap_filters(self, df: pd.DataFrame) -> Tuple[Any, Any]:
        """(long_ok, short_ok): close above/below VWAP, or True when disabled."""
        if not self.config['use_vwap']:
            return True, True
        try:
            vwap = self.calculate_vwap(df)
        except Exception:
            # Best effort: a VWAP failure disables the filter instead of aborting.
            return True, True
        close = df['Close']
        return close > vwap, close < vwap

    def _slope_filters(self, ma_short: pd.Series, ma_long: pd.Series) -> Tuple[Any, Any]:
        """(long_ok, short_ok) from MA change over ``slope_length`` bars, or True when disabled."""
        length = self.config['slope_length']
        if not (self.config['use_slope'] and len(ma_short) > length):
            return True, True
        slope_short = ma_short - ma_short.shift(length)
        slope_long = ma_long - ma_long.shift(length)
        slope_long_ok = (slope_short > 0) & (slope_long >= 0)
        slope_short_ok = (slope_short < 0) & (slope_long <= 0)
        return slope_long_ok, slope_short_ok

    def _adx_regime(self, df: pd.DataFrame) -> Tuple[pd.Series, pd.Series, Any]:
        """(adx, is_trending, regime_ok).

        When ADX is disabled, ``df`` is too short, or the calculation fails,
        returns a constant ADX of 20, is_trending all False and regime_ok True.
        """
        length = self.config['adx_length']
        neutral = (pd.Series(20, index=df.index), pd.Series(False, index=df.index), True)
        if not (self.config['use_adx'] and len(df) > length):
            return neutral
        try:
            adx = self.calculate_adx(df, length)['adx']
        except Exception:
            # Best effort: fall back to the neutral regime rather than aborting.
            return neutral
        is_trending = adx >= self.config['adx_trend_min']
        return adx, is_trending, is_trending

    def _rsi_reclaim(self, rsi: pd.Series, is_oversold: pd.Series,
                     is_overbought: pd.Series) -> Tuple[pd.Series, pd.Series]:
        """(buy, sell) RSI confirmation series.

        With ``use_rsi_reclaim``: buy when RSI crosses back above
        ``reclaim_low`` this bar, and sell when it crosses back below
        ``reclaim_high``. Otherwise the plain oversold/overbought masks are used.
        """
        if not self.config['use_rsi_reclaim']:
            return is_oversold, is_overbought
        low = self.config['reclaim_low']
        high = self.config['reclaim_high']
        buy = (rsi.rolling(3).min() < low) & (rsi > low) & (rsi.shift(1) <= low)
        sell = (rsi.rolling(3).max() > high) & (rsi < high) & (rsi.shift(1) >= high)
        return buy, sell

    # --------------------------------------------------------------- public API

    def analyze(self, df: pd.DataFrame, symbol: str = "BTC-USD") -> Dict[str, Any]:
        """Run all indicators and filters on ``df`` and report the latest bar.

        Args:
            df: OHLCV frame with 'Open', 'High', 'Low', 'Close' and 'Volume'
                columns, oldest row first. Two-level (yfinance-style) columns
                for a single ticker are flattened automatically.
            symbol: Ticker; used only to decide whether forex handling applies.

        Returns:
            Dict with 'signal' (buy/sell/smart_money_in/smart_money_out bools),
            'indicators' (latest RSI, MAs, ADX and filter states) and 'config'
            (a copy of the settings). With fewer than ``MIN_BARS`` rows, an
            'error' message and neutral defaults are returned instead.
        """
        if df is None or df.empty or len(df) < self.MIN_BARS:
            return self._insufficient_data_result()

        df = self._flatten_columns(df)
        cfg = self.config
        close = df['Close']

        # RSI
        rsi = self.calculate_rsi(close, cfg['rsi_length'])
        is_overbought = rsi > cfg['rsi_overbought']
        is_oversold = rsi < cfg['rsi_oversold']

        # Moving-average crossover (the raw trigger)
        ma_short = self.calculate_ma(close, cfg['ma_short'])
        ma_long = self.calculate_ma(close, cfg['ma_long'])
        buy_signal_raw = self.detect_crossover(ma_short, ma_long)
        sell_signal_raw = self.detect_crossunder(ma_short, ma_long)

        # Confirmation filters
        high_volume = self._high_volume_mask(df, self.is_forex(symbol))
        vwap_long_ok, vwap_short_ok = self._vwap_filters(df)
        slope_long_ok, slope_short_ok = self._slope_filters(ma_short, ma_long)
        adx, is_trending, regime_ok = self._adx_regime(df)
        rsi_reclaim_buy, rsi_reclaim_sell = self._rsi_reclaim(rsi, is_oversold, is_overbought)

        # Final signals: trigger AND every filter on the same bar
        buy_final = buy_signal_raw & high_volume & vwap_long_ok & slope_long_ok & regime_ok & rsi_reclaim_buy
        sell_final = sell_signal_raw & high_volume & vwap_short_ok & slope_short_ok & regime_ok & rsi_reclaim_sell

        # Smart money: active bar with a bullish/bearish candle body
        smart_money_in = high_volume & (close > df['Open'])
        smart_money_out = high_volume & (close < df['Open'])

        if self._last_bool(is_overbought):
            rsi_status = 'overbought'
        elif self._last_bool(is_oversold):
            rsi_status = 'oversold'
        else:
            rsi_status = 'neutral'

        return {
            'signal': {
                'buy': self._last_bool(buy_final),
                'sell': self._last_bool(sell_final),
                'smart_money_in': self._last_bool(smart_money_in),
                'smart_money_out': self._last_bool(smart_money_out)
            },
            'indicators': {
                'rsi': self._last_float(rsi, 50),
                'rsi_status': rsi_status,
                'ma_short': self._last_float(ma_short, 0),
                'ma_long': self._last_float(ma_long, 0),
                'adx': self._last_float(adx, 20),
                'is_trending': self._last_bool(is_trending),
                'high_volume': self._last_bool(high_volume)
            },
            # A copy, so callers cannot mutate this analyzer's settings.
            'config': dict(self.config)
        }

    def get_signal_summary(self, df: pd.DataFrame, symbol: str = "BTC-USD") -> str:
        """One-line human-readable summary of ``analyze``'s result.

        Priority: error, buy, sell, smart money in, smart money out, none.
        """
        result = self.analyze(df, symbol)

        if result.get('error'):
            return f"⚠️ {result['error']}"
        elif result['signal']['buy']:
            return "🟢 **سیگنال خرید (BUY)** - نهنگ‌ها در حال ورود هستند!"
        elif result['signal']['sell']:
            return "🔴 **سیگنال فروش (SELL)** - نهنگ‌ها در حال خروج هستند!"
        elif result['signal']['smart_money_in']:
            return "🐋 **ورود پول هوشمند** - حجم غیرعادی با کندل صعودی"
        elif result['signal']['smart_money_out']:
            return "🐋 **خروج پول هوشمند** - حجم غیرعادی با کندل نزولی"
        else:
            return "⚪ **بدون سیگنال** - منتظر حرکت بعدی باشیم"


# نمونه استفاده
if __name__ == "__main__":
    # Demo: fetch a week of hourly bars for a few asset classes and print signals.
    import yfinance as yf
    whale = WhaleHunterAI()

    symbols = ['BTC-USD', 'EURUSD=X', 'GC=F', '^GSPC']
    for symbol in symbols:
        print(f"\n=== {symbol} ===")
        data = yf.download(symbol, period="7d", interval="1h")
        # yfinance may return two-level (Price, Ticker) columns even for a
        # single symbol; flatten them so data['Close'] is a Series.
        if isinstance(data.columns, pd.MultiIndex):
            data.columns = data.columns.get_level_values(0)
        result = whale.analyze(data, symbol)
        print(f"Signal: {result['signal']}")
        print(f"ADX: {result['indicators']['adx']}")