import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
import yfinance as yf
import ta

class EnhancedMetalsTradingStrategy:
    """Multi-factor trading strategy for a primary metal instrument.

    Combines trend/volatility regime detection, technical indicators from the
    ``ta`` package and intermarket divergence against correlated instruments.
    Price data is downloaded with ``yfinance`` into ``self.data`` (symbol ->
    DataFrame).
    """

    def __init__(self, primary_symbol, correlated_symbols, start_date, end_date):
        """
        Initialise the strategy.

        :param primary_symbol: ticker of the primary metal (e.g. GC=F for gold)
        :param correlated_symbols: iterable of related tickers (e.g. SI=F for silver)
        :param start_date: passed to ``yf.download`` as ``start``
        :param end_date: passed to ``yf.download`` as ``end``
        """
        self.primary_symbol = primary_symbol
        self.correlated_symbols = correlated_symbols
        self.start_date = start_date
        self.end_date = end_date
        # symbol -> price DataFrame, filled by fetch_all_data()
        self.data = {}
        # Minimum rolling return correlation for a divergence to count as a signal.
        self.correlation_threshold = 0.7

    @staticmethod
    def _flatten_columns(df):
        """Return *df* with a single-level column index.

        Recent yfinance releases return ``(field, ticker)`` MultiIndex columns
        even for a single ticker, which would make ``df['Close']`` a DataFrame.
        In that case only the first level (the price field) is kept. The input
        frame is not modified.
        """
        if isinstance(df.columns, pd.MultiIndex):
            df = df.copy()
            df.columns = df.columns.get_level_values(0)
        return df

    def _download(self, symbol):
        """Download price data for one symbol and normalise its columns.

        :raises ValueError: if the download returned no rows
        """
        df = yf.download(symbol, start=self.start_date, end=self.end_date)
        if df.empty:
            raise ValueError(f"No price data downloaded for {symbol!r}")
        return self._flatten_columns(df)

    def fetch_all_data(self):
        """Download data for the primary and every correlated symbol into ``self.data``."""
        for symbol in [self.primary_symbol, *self.correlated_symbols]:
            self.data[symbol] = self._download(symbol)

    def analyze_market_regime(self):
        """Classify each bar of the primary symbol into a market regime.

        Adds ``LT_trend`` (200-bar EMA), ``ST_trend`` (20-bar EMA),
        ``volatility`` (20-bar rolling std of returns) and ``market_regime``
        columns to the primary DataFrame in place. Labels:

        * ``'strong_uptrend'``   - short EMA above long EMA, below-average volatility
        * ``'strong_downtrend'`` - short EMA below long EMA, below-average volatility
        * ``'choppy'``           - everything else (including bars where inputs are NaN)

        :return: Series of regime labels indexed like the primary data
        """
        df = self.data[self.primary_symbol]

        # Long- and short-term trend
        df['LT_trend'] = ta.trend.ema_indicator(df['Close'], window=200)
        df['ST_trend'] = ta.trend.ema_indicator(df['Close'], window=20)

        # Volatility
        df['volatility'] = df['Close'].pct_change().rolling(window=20).std()

        # NOTE(review): the full-sample mean uses future data, so historical
        # regime labels carry lookahead bias when used in a backtest.
        calm = df['volatility'] < df['volatility'].mean()
        df['market_regime'] = np.where(
            (df['ST_trend'] > df['LT_trend']) & calm,
            'strong_uptrend',
            np.where(
                (df['ST_trend'] < df['LT_trend']) & calm,
                'strong_downtrend',
                'choppy'
            )
        )

        return df['market_regime']

    def calculate_intermarket_signals(self):
        """Generate intermarket divergence signals against each correlated symbol.

        A bar produces a signal when the 30-bar rolling correlation of returns
        exceeds ``self.correlation_threshold`` while the primary and the
        correlated instrument moved in opposite directions on that bar. The
        signal follows the primary's own move on that same bar (+1 up, -1 down,
        0 flat); every other bar is 0.

        :return: DataFrame with one column per correlated symbol, indexed like
            the primary data
        """
        primary_returns = self.data[self.primary_symbol]['Close'].pct_change()
        # Per-bar direction. Using only the final bar's sign would leak the
        # last observation into every historical signal.
        primary_direction = np.sign(primary_returns.fillna(0)).astype(int)
        signals = {}

        for symbol in self.correlated_symbols:
            # Align on the primary calendar so the element-wise sign comparison
            # is valid even when the instruments traded on different days.
            corr_returns = (
                self.data[symbol]['Close'].pct_change().reindex(primary_returns.index)
            )
            correlation = primary_returns.rolling(window=30).corr(corr_returns)
            diverging = np.sign(primary_returns) != np.sign(corr_returns)

            signals[symbol] = np.where(
                (correlation > self.correlation_threshold) & diverging,
                primary_direction,
                0
            )

        # Keep the date index so the result aligns with the other signal frames.
        return pd.DataFrame(signals, index=primary_returns.index)

    def generate_advanced_signals(self):
        """Generate technical trading signals for the primary symbol.

        Columns (each -1/0/+1):

        * ``trend_momentum`` - RSI below 30 above the long EMA (+1), RSI above
          70 below the long EMA (-1)
        * ``volume_trend``   - above-average (20-bar) volume on an up (+1) or
          down (-1) close
        * ``divergence``     - price falls while the MACD histogram rises (+1)
          or vice versa (-1)

        :return: DataFrame indexed like the primary data
        """
        df = self.data[self.primary_symbol]
        close = df['Close']
        signals = pd.DataFrame(index=df.index)

        # Reuse the EMA from analyze_market_regime() if present; otherwise
        # compute it so this method does not depend on call order.
        if 'LT_trend' in df.columns:
            lt_trend = df['LT_trend']
        else:
            lt_trend = ta.trend.ema_indicator(close, window=200)

        # 1. Momentum with a trend filter
        rsi = ta.momentum.rsi(close)
        signals['trend_momentum'] = np.where(
            (close > lt_trend) & (rsi < 30), 1,
            np.where((close < lt_trend) & (rsi > 70), -1, 0)
        )

        # 2. Volume confirmation
        high_volume = df['Volume'] > df['Volume'].rolling(window=20).mean()
        prev_close = close.shift(1)
        signals['volume_trend'] = np.where(
            high_volume & (close > prev_close), 1,
            np.where(high_volume & (close < prev_close), -1, 0)
        )

        # 3. Price / MACD-histogram divergence
        price_change = close.diff()
        macd_change = ta.trend.macd_diff(close).diff()
        signals['divergence'] = np.where(
            (price_change < 0) & (macd_change > 0), 1,
            np.where((price_change > 0) & (macd_change < 0), -1, 0)
        )

        return signals

    def position_sizing(self, signal_strength, market_regime):
        """Compute a volatility-adjusted position size as a fraction of capital.

        size = 0.02 * signal_strength * regime_multiplier / annualised_volatility,
        clipped to [-0.2, 0.2] (at most 20% of capital per position).

        :param signal_strength: combined signal value; its sign gives direction
        :param market_regime: 'strong_uptrend', 'strong_downtrend' or 'choppy'
        :return: position size in [-0.2, 0.2]; 0.0 when volatility is zero or
            undefined (a size cannot be derived from it)
        :raises KeyError: for an unknown ``market_regime`` label
        """
        base_risk = 0.02  # 2% base risk

        # Scale size by market regime
        regime_multiplier = {
            'strong_uptrend': 1.5,
            'strong_downtrend': 1.5,
            'choppy': 0.5
        }
        multiplier = regime_multiplier[market_regime]

        # Annualised volatility of daily returns over the full history.
        # NOTE(review): full-sample statistic (lookahead in a backtest) and
        # recomputed on every call.
        volatility = self.data[self.primary_symbol]['Close'].pct_change().std() * np.sqrt(252)
        if not np.isfinite(volatility) or volatility <= 0:
            # Avoid 1/0 -> inf (max position) or NaN sizes.
            return 0.0
        volatility_adjustment = 1 / volatility

        position_size = base_risk * signal_strength * multiplier * volatility_adjustment
        return np.clip(position_size, -0.2, 0.2)

    def apply_risk_management(self, position_size, current_price, bar=None):
        """Attach ATR-based stop-loss and take-profit distances to a position.

        :param position_size: position size; returned unchanged
        :param current_price: price of the evaluated bar (currently unused)
        :param bar: optional positional index of the bar whose ATR to use;
            ``None`` (default) uses the most recent bar
        :return: dict with ``position_size``, ``stop_loss`` (2 x ATR) and
            ``take_profit`` (3 x stop_loss, i.e. 3:1 reward/risk), the latter
            two as price distances rather than price levels
        """
        prices = self.data[self.primary_symbol]
        atr = ta.volatility.average_true_range(
            prices['High'],
            prices['Low'],
            prices['Close']
        )

        # The latest ATR applied to historical bars is lookahead; pass ``bar``
        # to use the ATR available at that bar instead.
        atr_value = atr.iloc[-1 if bar is None else bar]
        stop_loss = 2 * atr_value
        take_profit = 3 * stop_loss

        return {
            'position_size': position_size,
            'stop_loss': stop_loss,
            'take_profit': take_profit
        }

def run_enhanced_strategy(primary='GC=F', correlated=None,
                          start_date='2020-01-01', end_date='2024-01-01'):
    """
    Run the full strategy pipeline and return per-bar signals and risk parameters.

    :param primary: ticker of the primary instrument
    :param correlated: tickers used for intermarket signals; ``None`` means
        ``['SI=F', 'PL=F']``
    :param start_date: start of the download window
    :param end_date: end of the download window
    :return: DataFrame indexed like the technical signals with columns
        'combined_signal', 'position_size', 'stop_loss', 'take_profit'
    """
    # Avoid a shared mutable default argument.
    if correlated is None:
        correlated = ['SI=F', 'PL=F']

    strategy = EnhancedMetalsTradingStrategy(primary, correlated, start_date, end_date)
    strategy.fetch_all_data()

    # Market regime (also adds the LT_trend column used by the technical signals)
    market_regime = strategy.analyze_market_regime()

    # Signal generation
    intermarket_signals = strategy.calculate_intermarket_signals()
    technical_signals = strategy.generate_advanced_signals()

    # Both frames are built row-for-row from the primary data, but the
    # intermarket frame may not carry the date index. Combine positionally so
    # pandas does not align a RangeIndex against dates (an all-NaN result).
    if intermarket_signals.shape[1] == 0:
        intermarket_component = 0.0
    else:
        intermarket_component = intermarket_signals.mean(axis=1).to_numpy()

    final_signals = pd.DataFrame(index=technical_signals.index)
    # NOTE(review): the weights sum to 1.2, so |combined_signal| can exceed 1.
    final_signals['combined_signal'] = (
        technical_signals['trend_momentum'] * 0.4 +
        technical_signals['volume_trend'] * 0.3 +
        technical_signals['divergence'] * 0.3 +
        intermarket_component * 0.2
    )

    # Position sizing and risk parameters per bar
    closes = strategy.data[primary]['Close']
    risk_rows = [
        strategy.apply_risk_management(
            strategy.position_sizing(signal, regime),
            price
        )
        for signal, regime, price in zip(
            final_signals['combined_signal'], market_regime, closes
        )
    ]

    # Build whole columns at once instead of row-by-row .loc writes.
    final_signals['position_size'] = [row['position_size'] for row in risk_rows]
    final_signals['stop_loss'] = [row['stop_loss'] for row in risk_rows]
    final_signals['take_profit'] = [row['take_profit'] for row in risk_rows]

    return final_signals
