Overall Statistics
Total Trades
811
Average Win
2.10%
Average Loss
-0.44%
Compounding Annual Return
91.735%
Drawdown
29.200%
Expectancy
1.538
Net Profit
1277.319%
Sharpe Ratio
2.565
Probabilistic Sharpe Ratio
99.144%
Loss Rate
56%
Win Rate
44%
Profit-Loss Ratio
4.73
Alpha
0.608
Beta
0.035
Annual Standard Deviation
0.239
Annual Variance
0.057
Information Ratio
1.695
Tracking Error
0.288
Treynor Ratio
17.438
Total Fees
$9072.05
from datetime import timedelta

class order_ticket:
    """Book-keeping record for one open long position and its trailing stop.

    Besides the stop-loss state, the ticket stores the indicator snapshot taken
    at entry (see ``set_training_data``) so that, once the position is closed,
    ``get_training_set`` can turn it into a (features, target) training example.
    """

    # Static-target tiers, checked from the largest gain down so that every tier
    # is reachable: (close/entry multiplier that must be exceeded, fraction of
    # max_target used as the target).
    STATIC_TARGET_TIERS = (
        (1.10, 1),
        (1.06, 1 / 3),
        (1.03, 1 / 10),
        (1.00, 1 / 20),
    )

    def __init__(self, symbol, entry_price, quantity, stopLoss_percent, timestamp):
        """Create a ticket for a freshly opened position.

        Args:
            symbol: ticker the position was opened on.
            entry_price: price at which the position was opened.
            quantity: quantity bought.
            stopLoss_percent: stop distance below price as a fraction (0.12 = 12%).
            timestamp: entry time; also the start of the timeout clock.
        """
        self.symbol = symbol
        self.entry_price = entry_price
        self.quantity = quantity
        self.stopLoss_percent = stopLoss_percent
        self.stopLoss_price = round(self.entry_price * (1 - self.stopLoss_percent), 5)
        self.last_update = timestamp

    def set_training_data(self, fast_ema, slow_ema, vslow_ema, rsi, macd_current, macd_signal, macd_delta_percent):
        """Store the indicator snapshot taken at entry (must be called before get_training_set)."""
        #Training Data
        self.fast_ema = fast_ema
        self.slow_ema = slow_ema
        self.vslow_ema = vslow_ema
        self.rsi = rsi
        self.macd_current = macd_current
        self.macd_signal = macd_signal
        self.macd_delta_percent = macd_delta_percent

    def get_training_set(self, close, max_target = 3, static = True):
        """Build the (features, target) training example for this ticket.

        Features: [entry price, fast ema, slow ema, very slow ema,
        fast/slow ema ratio, slow/very slow ema ratio, rsi, macd,
        macd signal, macd delta percent].

        Args:
            close: price at which the position is being closed.
            max_target: target value for the best static tier.
            static: if True, map the gain onto STATIC_TARGET_TIERS; otherwise
                return the largest whole-percent gain (1%..20%) as a fraction.

        Returns:
            (train_set, y) where y is 0 when the close is not above entry.
        """
        train_set = [self.entry_price, self.fast_ema, self.slow_ema, self.vslow_ema,
                     self.fast_ema / self.slow_ema, self.slow_ema / self.vslow_ema, self.rsi,
                     self.macd_current, self.macd_signal, self.macd_delta_percent]

        y = 0
        if static:
            # Tiers are ordered from the largest gain down; first match wins.
            for multiplier, fraction in self.STATIC_TARGET_TIERS:
                if self.is_winning(close, multiplier):
                    y = max_target * fraction
                    break
        else:
            # Check 1% .. 20% (inclusive) and keep the largest level exceeded.
            for percent in range(1, 21):
                if not self.is_winning(close, 1 + percent / 100):
                    break
                y = percent / 100
        return train_set, y

    def reached_stop_limit(self, close):
        """True when close is below the current stop price (never when the stop is 0)."""
        return self.stopLoss_price != 0 and close < self.stopLoss_price

    def is_losing(self, close):
        """True when close is below the entry price (never when the stop is 0)."""
        return self.stopLoss_price != 0 and close < self.entry_price

    def is_winning(self, close, tolerance = 1.05):
        """True when close exceeds entry_price * tolerance (never when the stop is 0)."""
        return self.stopLoss_price != 0 and close > self.entry_price * tolerance

    def reached_timeout(self, timestamp, timeout = 4):
        """True when more than `timeout` days passed since entry or the last stop raise."""
        return self.last_update + timedelta(days=timeout) < timestamp

    def update_stop_limit(self, close_price, timestamp):
        """Raise the trailing stop to follow close_price; never lower it.

        Returns True (and resets the timeout clock to `timestamp`) when the
        stop was raised, False otherwise.
        """
        new_stop_price = round(close_price * (1 - self.stopLoss_percent), 5)
        if new_stop_price <= self.stopLoss_price:
            return False
        self.stopLoss_price = new_stop_price
        self.last_update = timestamp
        return True
from System.Drawing import Color

from order_ticket import order_ticket
from model import crypto_raptor

import numpy as np

import json
from json import JSONEncoder

import tensorflow as tf

class CryptoRaptor(QCAlgorithm):
    """Long-only crypto strategy combining daily EMA/RSI/MACD entry signals with
    a trailing stop loss and a per-symbol Keras model that sizes each entry.

    Scheduled jobs registered in Initialize:
      * chart                 - every 12 hours, plot price and EMAs
      * trade                 - every 12 hours, close positions on stop / timeout
      * trail_positions       - every `trailing_stop_loss_period` hours, raise stops
      * observe               - daily at 01:00, look for new entries
      * train_quantity_models - Mondays at 02:30 (via Train), refit size models
    """

    def Initialize(self):
        #Backtesting Settings
        self.SetStartDate(2017, 1, 1)
        self.SetCash(10000)

        #Time Period Settings
        self.short_period = 11
        self.long_period = 25
        self.vlong_period = 80
        self.macd_signal_period = 9
        self.trailing_stop_loss_period = 2
        # NOTE(review): block_period is not read anywhere in this class;
        # trade() blocks losing symbols for a hard-coded 12 hours.
        self.block_period = 24

        #Indicator Settings
        self.ema_long_buy_signal = 1.01
        self.ema_short_buy_signal = 0.97
        self.macd_buy_signal = 0.0025
        self.rsi_buy_signal = 30

        #Buy Strategy & TSSL Settings
        self.trading_limit = 100000
        # GetParameter yields no value when the parameter is not configured;
        # fall back to a neutral multiplier instead of failing inside float().
        multiplier_param = self.GetParameter('Quantity Target Multiplier')
        self.quantity_target_multiplier = float(multiplier_param) if multiplier_param else 1.0
        self.max_quantity_per_trade = 3
        self.min_quantity_per_trade = 0.2
        self.buy_long_stopLossPercent = 0.12
        self.buy_short_stopLossPercent = 0.06

        #AI Settings
        self.input_dimension = 10
        self.batch_size = 16
        self.create_new_model = False
        self.overwrite_saved_model = False

        #Crypto trading pairs
        self.tickers = ['BTCUSD', 'NEOUSD', 'ETCUSD', 'LTCUSD'] #, 'EOSUSD', 'ETPUSD', 'TRXUSD', 'XTZUSD', 'VSYUSD']

        #Algorithm Variables
        self.ticket_per_symbol = {}
        self.blocked_until_per_symbol = {}

        self.rsi_per_symbol = {}
        self.short_ema_per_symbol = {}
        self.long_ema_per_symbol = {}
        self.vlong_ema_per_symbol = {}
        self.macd_per_symbol = {}
        self.bb_per_symbol = {}

        self.training_data_per_symbol = {}
        self.model_per_symbol = {}

        #self.SetWarmUp(self.vlong_period)
        #self.SetAlpha(CryptoRaptorAlphaModel(self.tickers))
        self.SetBrokerageModel(BrokerageName.AlphaStreams)
        #self.SetBrokerageModel(BrokerageName.Bitfinex, AccountType.Cash)

        for ticker in self.tickers:
            self.AddCrypto(ticker, Resolution.Hour, Market.Bitfinex)

            #Load models
            if not self.create_new_model:
                self.model_per_symbol[ticker] = self.load_quantity_model(ticker)
            else:
                self.model_per_symbol[ticker] = crypto_raptor(ticker, self.input_dimension, self.batch_size, False, None, None)

            #Init empty training data list for each symbol [x training data, y training data]
            self.training_data_per_symbol[ticker] = [[], []]

            #Add the RSI Indicator
            rsi = self.RSI(ticker, self.long_period,  MovingAverageType.Simple, Resolution.Daily)
            if self.warmup_indicator(ticker, rsi, self.long_period):
                self.Debug("RSI Indicator for {0} is ready!".format(ticker))
            self.rsi_per_symbol[ticker] = rsi

            #Add the EMA Short Indicator
            short_ema = self.EMA(ticker, self.short_period, Resolution.Daily)
            if self.warmup_indicator(ticker, short_ema, self.short_period):
                self.Debug("EMA ({1}) Indicator for {0} is ready!".format(ticker, self.short_period))
            self.short_ema_per_symbol[ticker] = short_ema

            #Add the EMA Long indicator
            long_ema = self.EMA(ticker, self.long_period, Resolution.Daily)
            if self.warmup_indicator(ticker, long_ema, self.long_period):
                self.Debug("EMA ({1}) Indicator for {0} is ready!".format(ticker, self.long_period))
            self.long_ema_per_symbol[ticker] = long_ema

            #Add the EMA Very Long Indicator
            vlong_ema = self.EMA(ticker, self.vlong_period, Resolution.Daily)
            if self.warmup_indicator(ticker, vlong_ema, self.vlong_period):
                self.Debug("EMA ({1}) Indicator for {0} is ready!".format(ticker, self.vlong_period))
            self.vlong_ema_per_symbol[ticker] = vlong_ema

            # daily MACD(short_period, long_period) with a macd_signal_period signal
            macd = self.MACD(ticker, self.short_period, self.long_period, self.macd_signal_period, MovingAverageType.Exponential, Resolution.Daily)
            if self.warmup_indicator(ticker, macd, self.long_period):
                self.Debug("MACD Indicator for {0} is ready!".format(ticker))
            self.macd_per_symbol[ticker] = macd

            stockPlot = Chart('{0} Trade Plot'.format(ticker))
            # Trade Plot chart: buy/sell markers, the three EMAs and the price
            stockPlot.AddSeries(Series('Buy', SeriesType.Scatter, '$', Color.Green))
            stockPlot.AddSeries(Series('Sell', SeriesType.Scatter, '$', Color.Red))

            stockPlot.AddSeries(Series('EMA Short', SeriesType.Line, '$', Color.White))
            stockPlot.AddSeries(Series('EMA Long', SeriesType.Line, '$', Color.Yellow))
            stockPlot.AddSeries(Series('EMA vLong', SeriesType.Line, '$', Color.Orange))
            stockPlot.AddSeries(Series('Price', SeriesType.Line, '$', Color.Blue))
            self.AddChart(stockPlot)

        self.Schedule.On(self.DateRules.EveryDay(), \
            self.TimeRules.Every(TimeSpan.FromHours(12)), \
            Action(self.chart))

        self.Schedule.On(self.DateRules.EveryDay(), \
            self.TimeRules.Every(TimeSpan.FromHours(12)), \
            Action(self.trade))

        self.Schedule.On(self.DateRules.EveryDay(), \
            self.TimeRules.Every(TimeSpan.FromHours(self.trailing_stop_loss_period)), \
            Action(self.trail_positions))

        self.Schedule.On(self.DateRules.EveryDay(), \
            self.TimeRules.At(1, 0), \
            Action(self.observe))

        self.Train(self.DateRules.Every(DayOfWeek.Monday), \
            self.TimeRules.At(2, 30), \
            Action(self.train_quantity_models))

    def warmup_indicator(self, symbol, indicator, warmup_duration):
        """Feed `warmup_duration` daily closes into indicator; True if any history was found."""
        history = self.History([symbol], warmup_duration, Resolution.Daily)
        if history.empty:
            return False
        for time, row in history.loc[symbol].iterrows():
            indicator.Update(time, row["close"])
        return True

    def is_blocked(self, symbol):
        """True while symbol is blocked; an expired block is removed on the first check after it ends."""
        if symbol not in self.blocked_until_per_symbol:
            return False
        if self.Time > self.blocked_until_per_symbol[symbol]:
            del self.blocked_until_per_symbol[symbol]
            return False
        return True

    def block_symbol(self, symbol, block_timespan):
        """Block new entries on symbol for `block_timespan` hours from now."""
        self.blocked_until_per_symbol[symbol] = self.Time + timedelta(hours=block_timespan)

    def OnEndOfAlgorithm(self):
        if self.overwrite_saved_model:
            for symbol, model in self.model_per_symbol.items():
                self.save_quantity_model(symbol, model)

    def chart(self):
        """Plot the close price and every ready EMA on each symbol's Trade Plot chart."""
        for symbol in self.tickers:
            chart_name = '{0} Trade Plot'.format(symbol)
            self.Plot(chart_name, 'Price', self.Securities[symbol].Close)

            for series_name, ema_per_symbol in (('EMA Short', self.short_ema_per_symbol),
                                                ('EMA Long', self.long_ema_per_symbol),
                                                ('EMA vLong', self.vlong_ema_per_symbol)):
                ema = ema_per_symbol.get(symbol)
                if ema is not None and ema.IsReady:
                    self.Plot(chart_name, series_name, ema.Current.Value)

    def trail_positions(self):
        """Raise the trailing stop of every invested symbol to follow the latest close."""
        for symbol in self.tickers:
            close_price = self.Securities[symbol].Close

            if close_price <= 0:
                self.Debug("Error occured: Close Price for {0} is {1}".format(symbol, close_price))
                continue

            if self.Portfolio[symbol].Invested:
                self.UpdateTicket(symbol, close_price)

    def trade(self):
        """Close positions that hit their stop loss, or that timed out.

        Timeouts count from entry or the last stop raise: losing positions are
        closed after 4 days (and the symbol is blocked for 12 hours), winning
        positions after 2 days.
        """
        # Snapshot the tickets: _close_position deletes from ticket_per_symbol.
        for order in list(self.ticket_per_symbol.values()):
            close_price = self.Securities[order.symbol].Close
            if close_price <= 0:
                self.Debug("Error occured: Close Price for {0} is {1}".format(order.symbol, close_price))
                continue

            #Kill Position if reached stop loss
            if order.reached_stop_limit(close_price):
                self._close_position(order, close_price, "Stop Loss")
            #Timeout of losing Position after 4 days
            elif order.is_losing(close_price) and order.reached_timeout(self.Time, 4):
                self._close_position(order, close_price, "Timeout", block_hours=12)
            #Timeout of winning Position after 2 days
            elif order.is_winning(close_price) and order.reached_timeout(self.Time, 2):
                self._close_position(order, close_price, "Timeout")

    def _close_position(self, order, close_price, reason, block_hours=None):
        """Record the ticket as training data, drop it, and sell what is held.

        reason is appended to the debug line; block_hours, when given, blocks
        new entries on the symbol for that many hours.
        """
        self.add_ticket_to_training_data(order.symbol, close_price, order)
        if block_hours is not None:
            self.block_symbol(order.symbol, block_hours)

        #Kill Position
        del self.ticket_per_symbol[order.symbol]
        self.EmitInsights(Insight.Price(order.symbol, timedelta(self.short_period), InsightDirection.Flat))
        # Never sell more than is actually held: a ticket can outlive a buy that
        # was rejected, and selling its full quantity would open a short.
        quantity = min(order.quantity, self.Portfolio[order.symbol].Quantity)
        if quantity > 0:
            self.Sell(order.symbol, quantity)
        self.Debug("Order for {0} reached {1}".format(order.symbol, reason))

    def add_ticket_to_training_data(self, symbol, close_price, order):
        """Append the (features, target) example of a closing ticket to symbol's training data."""
        train_set, y_target = order.get_training_set(close_price, self.max_quantity_per_trade)
        x_data, y_data = self.training_data_per_symbol[symbol]
        x_data.append(train_set)
        y_data.append(y_target)

    def observe(self):
        """Daily entry scan.

        For every symbol that is not invested, has no open ticket and is not
        blocked, evaluate the EMA / RSI / MACD signals and, when one of the buy
        combinations fires, open a long position sized by the symbol's model.
        """
        if self.IsWarmingUp:
            self.Debug("Crypto Raptor is warming up")
            return
        for symbol in self.tickers:
            close_price = self.Securities[symbol].Close

            if close_price <= 0:
                self.Debug("Error occured: Close Price for {0} is {1}".format(symbol, close_price))
                continue

            if self.Portfolio[symbol].Invested or symbol in self.ticket_per_symbol:
                continue

            if self.is_blocked(symbol):
                continue

            ema_short = self.short_ema_per_symbol.get(symbol)
            ema_long = self.long_ema_per_symbol.get(symbol)
            ema_vlong = self.vlong_ema_per_symbol.get(symbol)
            rsi = self.rsi_per_symbol.get(symbol)
            macd = self.macd_per_symbol.get(symbol)
            # Without every indicator the model input cannot be built; skip
            # rather than reuse values left over from another symbol.
            if any(ind is None for ind in (ema_short, ema_long, ema_vlong, rsi, macd)):
                continue

            #Calculate the EMA Indicators
            if not ema_short.IsReady:
                self.Debug("EMA ({1}) Indicator for {0} is not ready!".format(symbol, self.short_period))
                continue
            if not ema_long.IsReady:
                self.Debug("EMA ({1}) Indicator for {0} is not ready!".format(symbol, self.long_period))
                continue
            if not ema_vlong.IsReady:
                self.Debug("EMA ({1}) Indicator for {0} is not ready!".format(symbol, self.vlong_period))
                continue
            ema_short_value = ema_short.Current.Value
            ema_long_value = ema_long.Current.Value
            ema_vlong_value = ema_vlong.Current.Value

            short_long_ratio = ema_short_value / ema_long_value
            long_vlong_ratio = ema_long_value / ema_vlong_value
            ema_buy_signal = short_long_ratio > self.ema_long_buy_signal and long_vlong_ratio > 1.0
            ema_short_buy_signal = short_long_ratio < self.ema_short_buy_signal and long_vlong_ratio < 1.0

            #Calculate the RSI Indicator
            if not rsi.IsReady:
                self.Debug("RSI Indicator for {0} is not ready!".format(symbol))
                continue
            rsi_value = rsi.Current.Value
            rsi_buy_signal = 0 < rsi_value < self.rsi_buy_signal

            #Calculate the MACD Indicator
            if not macd.IsReady:
                self.Debug("MACD Indicator for {0} is not ready!".format(symbol))
                continue
            macd_value = macd.Current.Value
            macd_signal_value = macd.Signal.Current.Value
            signalDeltaPercent = (macd_value - macd_signal_value) / macd.Fast.Current.Value
            macd_buy_signal = signalDeltaPercent > self.macd_buy_signal

            #Compare Indicators. The last value is passed as the 7th positional
            #argument of Insight.Price (the weight slot in LEAN's signature; confirm).
            if ema_buy_signal and (macd_buy_signal or rsi_buy_signal):
                stop_loss, insight_weight = self.buy_long_stopLossPercent, 0.25
            elif ema_short_buy_signal and macd_buy_signal:
                stop_loss, insight_weight = self.buy_short_stopLossPercent, 0.10
            else:
                self.Debug("raptor for {0} is waiting for better buy signals...".format(symbol))
                continue

            #Ask model for the 'best' quantity (only once a buy signal fired)
            quantity_prediction = self.get_quantity_prediction(symbol, close_price, ema_short_value, ema_long_value, ema_vlong_value, rsi_value, macd_value, macd_signal_value, signalDeltaPercent)
            self.Debug("{0} models quantity prediction: {1}".format(symbol, quantity_prediction))
            # A diverged model can emit NaN/inf; treat it like an untrained model (0).
            if not np.isfinite(quantity_prediction):
                quantity_prediction = 0

            if quantity_prediction < -self.max_quantity_per_trade:
                self.Debug(f'{symbol} raptor cancel the buy')
                continue

            # NOTE(review): the scaled prediction is compared with unscaled
            # min/max bounds and the clamped bound is then scaled again; with a
            # multiplier other than 1 this mixes units. Behaviour kept; confirm intent.
            target_quantity = quantity_prediction * self.quantity_target_multiplier
            if target_quantity < self.min_quantity_per_trade:
                target_quantity = self.min_quantity_per_trade * self.quantity_target_multiplier
            elif target_quantity > self.max_quantity_per_trade:
                target_quantity = self.max_quantity_per_trade * self.quantity_target_multiplier

            if not self.OpenBuy(symbol, close_price, stop_loss, target_quantity):
                continue
            # Up insight over short_period days (timedelta's first positional arg is days)
            self.EmitInsights(Insight.Price(symbol, timedelta(self.short_period), InsightDirection.Up, None, None, None, insight_weight))
            self.ticket_per_symbol[symbol].set_training_data(ema_short_value, ema_long_value, ema_vlong_value, rsi_value, macd_value, macd_signal_value, signalDeltaPercent)

    def OpenBuy(self, symbol, close_price, stop_loss, target_quantity):
        """Buy symbol for an equal-weight share scaled by target_quantity and open a ticket.

        Returns True when an order was submitted and a ticket created, False
        when the computed quantity rounds to zero (nothing is submitted).
        """
        # Calculate the fee adjusted quantity of shares with given buying power
        target = (1 / len(self.tickers)) * target_quantity
        quantity = round(self.CalculateOrderQuantity(symbol, target), 1)

        if quantity * close_price > self.trading_limit:
            quantity = self.trading_limit / close_price

        if quantity <= 0:
            self.Debug("Skip Buy Order for {0}: computed quantity is {1}".format(symbol, quantity))
            return False

        #Calculate the stop price
        stop_price = round(close_price * (1 - stop_loss), 5)

        self.Debug("Open Buy Order with {0} {1} on {2}. SL: {3}".format(quantity, symbol, close_price, stop_price))
        self.Plot('{0} Trade Plot'.format(symbol), 'Buy', close_price)

        self.Buy(symbol, quantity)
        self.ticket_per_symbol[symbol] = order_ticket(symbol, close_price, quantity, stop_loss, self.Time)
        return True

    def UpdateTicket(self, symbol, close_price):
        """Try to raise symbol's trailing stop; log when it moved."""
        buy_ticket = self.ticket_per_symbol.get(symbol)
        if buy_ticket is None:
            return

        old_stop_price = buy_ticket.stopLoss_price
        if buy_ticket.update_stop_limit(close_price, self.Time):
            self.Debug("{0}: Stop Loss for {1} Buy Order updated from {2} to {3}".format(self.Time, symbol, old_stop_price, buy_ticket.stopLoss_price))

    def save_quantity_model(self, symbol, model):
        """Persist model's config and JSON-encoded weights in the ObjectStore."""
        model_config, encoded_weights = model.get_model_config_and_weights()

        self.ObjectStore.Save(f'{symbol}_modelv2_weights', str(encoded_weights))
        self.ObjectStore.Save(f'{symbol}_modelv2', model_config)
        self.Debug(f'Model for {symbol} sucessfully saved in the ObjectStore')

    def load_quantity_model(self, symbol):
        """Return symbol's model, restored from the ObjectStore if saved, else a fresh one."""
        if self.ObjectStore.ContainsKey(f'{symbol}_modelv2'):
            modelStr = self.ObjectStore.Read(f'{symbol}_modelv2')
            config = json.loads(modelStr)['config']
            decoded_weights = json.loads(self.ObjectStore.Read(f"{symbol}_modelv2_weights"))
            # Keras set_weights expects one array per weight tensor. The tensors
            # have different shapes, so np.asarray over the whole list would
            # build a ragged object array (an error on recent NumPy) instead.
            weights = [np.asarray(w) for w in decoded_weights]
            model = crypto_raptor(symbol, self.input_dimension, self.batch_size, True, config, weights)
            self.Debug("Model for {0} loaded from Object Store.".format(symbol))
        else:
            #Create default
            self.Debug("Create default Model for {0}.".format(symbol))
            model = crypto_raptor(symbol, self.input_dimension, self.batch_size, False, None, None)

        tf.keras.backend.clear_session()
        return model

    def train_quantity_models(self):
        """Weekly job: refit every symbol's model on its collected training data."""
        for symbol in self.tickers:
            self.process_training_data(symbol, self.Securities[symbol].Close)

    def process_training_data(self, symbol, close):
        """Train symbol's model once at least batch_size examples were collected."""
        if symbol not in self.training_data_per_symbol:
            return
        x_train_list, y_train_list = self.training_data_per_symbol[symbol]

        if len(x_train_list) < self.batch_size:
            self.Debug("Need more examples to start training.")
            return

        self.model_per_symbol[symbol].process_training_data(symbol, close, x_train_list, y_train_list)

    #Ask model for quantity
    def get_quantity_prediction(self, symbol, close_price, fast_ema, long_ema, vlong_ema, rsi, macd_current, macd_signal, macd_delta):
        """Return the model's size prediction for the given indicator snapshot (0 without a model).

        The feature order matches order_ticket.get_training_set.
        """
        model = self.model_per_symbol.get(symbol)
        if model is None:
            return 0

        features = [close_price, fast_ema, long_ema, vlong_ema,
                    fast_ema / long_ema, long_ema / vlong_ema, rsi,
                    macd_current, macd_signal, macd_delta]

        # batch_size x n input whose last row holds the features; other rows stay zero.
        test = np.zeros([self.batch_size, len(features)])
        test[-1, :] = features

        return model.get_quantity_prediction(symbol, test)
import tensorflow as tf
from tensorflow.python.keras.backend import set_session
import keras
from keras.models import Sequential
from keras.layers import Dense, Activation,LSTM,Dropout,Embedding,Flatten
from keras.optimizers import SGD, Adam
from keras.utils.generic_utils import serialize_keras_object


import numpy as np

import json
from json import JSONEncoder


class crypto_raptor():
    """Small per-symbol Keras regressor predicting a position-size multiplier.

    Each instance captures the TF default graph and the Keras session current
    at construction time and enters both around every Keras call.
    """

    def __init__(self, symbol, input_dimension, batch_size, from_config = False, model_config = None, weights = None):
        """Build a fresh model, or restore one from a saved config and weights.

        Args:
            symbol: ticker this model belongs to.
            input_dimension: number of input features (fresh model only).
            batch_size: slice size used when training.
            from_config: restore from model_config/weights instead of building.
            model_config: Sequential config dict (restore only).
            weights: list of weight arrays for set_weights (restore only).
        """
        self.symbol = symbol
        self.batch_size = batch_size
        self.graph = tf.get_default_graph()
        self.session = keras.backend.get_session()
        self.session.run(tf.global_variables_initializer())

        # A freshly built model predicts 0 until it has been trained once.
        self.untrained = True

        if from_config == False:
            self.create_model(input_dimension)
        else:
            self.create_model_from_config(model_config, weights)
            self.untrained = False

    def get_model_config_and_weights(self):
        """Return (serialized model JSON string, JSON-encoded list of weight arrays)."""
        with self.graph.as_default():
            with self.session.as_default():
                weights = self.model.get_weights()
                encoded_weights = json.dumps(weights, cls=NumpyArrayEncoder)
                modelStr = json.dumps(serialize_keras_object(self.model))
        return modelStr, encoded_weights

    def create_model_from_config(self, model_config, weights):
        """Rebuild the Sequential model from its config, load weights and compile it."""
        with self.graph.as_default():
            with self.session.as_default():
                self.model = Sequential.from_config(model_config)
                self.model.set_weights(weights)
                sgd = SGD(lr = 0.01, clipnorm=1.0)   # learning rate = 0.01
                # choose loss function and optimizing method
                self.model.compile(loss='mse', optimizer=sgd)

    def create_model(self, input_dimension):
        """Build and compile a 32-16-1 Dense regressor trained with MSE / SGD."""
        with self.graph.as_default():
            with self.session.as_default():
                self.model = Sequential()
                self.model.add(Dense(32, input_dim=input_dimension, activation='relu'))
                self.model.add(Dense(16, activation='relu'))
                self.model.add(Dense(1))

                sgd = SGD(lr = 0.01, clipnorm=1.0)   # learning rate = 0.01
                # choose loss function and optimizing method
                self.model.compile(loss='mse', optimizer=sgd)

    def process_training_data(self, symbol, close, x_train_list, y_train_list):
        """Fit the model on consecutive batch_size slices of the collected examples.

        Slices with NaN in features or targets are skipped; each clean slice is
        fitted for 100 epochs on the first training call, 10 afterwards. The
        model is only marked trained when at least one slice was fitted.
        """
        num_examples = len(x_train_list)
        training_duration = 100 if self.untrained else 10

        trained_any = False
        with self.graph.as_default():
            with self.session.as_default():
                # +1 so the final complete slice (ending exactly at num_examples)
                # is used too; with exactly batch_size examples there is one slice.
                for end in range(self.batch_size, num_examples + 1, self.batch_size):
                    x_train = np.array(x_train_list[end - self.batch_size:end], dtype=float)
                    y_train = np.array(y_train_list[end - self.batch_size:end], dtype=float)

                    if np.isnan(x_train).any() or np.isnan(y_train).any():
                        continue

                    self.model.fit(x_train, y_train, epochs=training_duration)
                    trained_any = True
        tf.keras.backend.clear_session()
        if trained_any:
            self.untrained = False

    #Ask model for quantity
    def get_quantity_prediction(self, symbol, pred_set):
        """Return the prediction for the last row of pred_set, rounded to 3 decimals (0 if untrained)."""
        if self.untrained:
            return 0

        # Only the last row's prediction is used; the Dense stack built here
        # maps rows independently, so predicting just that row gives the same value.
        last_row = np.asarray(pred_set)[-1:]
        with self.graph.as_default():
            with self.session.as_default():
                prediction = self.model.predict(last_row, batch_size=1,  verbose=1)
        tf.keras.backend.clear_session()

        return round(prediction[-1][0], 3)
    
            
class NumpyArrayEncoder(JSONEncoder):
    """JSON encoder that writes NumPy arrays as (nested) Python lists.

    Every other non-serialisable object is handed to the stock encoder,
    which raises TypeError for it.
    """

    def default(self, obj):
        if not isinstance(obj, np.ndarray):
            return JSONEncoder.default(self, obj)
        return obj.tolist()