# Overall Statistics
# https://quantpedia.com/Screener/Details/100
from QuantConnect.Data import SubscriptionDataSource
from QuantConnect.Python import PythonData
from QuantConnect.Python import PythonQuandl
from QuantConnect.Data.Custom import *
from sklearn import datasets, linear_model
from datetime import date, timedelta, datetime
from collections import deque
import statsmodels.api as sm
from decimal import Decimal
import numpy as np


class TradeWtiBrentSpreadAlgorithm(QCAlgorithm):
    """Mean-reversion trade on the WTI - BRENT crude oil spot price spread.

    Entry: when the spread is above its 13-period SMA, go short WTI / long
    BRENT (50% each); when it is below, go long WTI / short BRENT.
    Exit: liquidate when the spread crosses the "spread EMA", defined as
    EMA(WTI, 7) - EMA(BRENT, 7).

    It also subscribes to Quandl ES1/NQ1 continuous futures and to the ES/NQ
    futures chains, and warms two indicators (esnqSMA/esnqEMA) from ES1
    settle history. Those indicators are currently only logged, not traded.
    """

    def Initialize(self):
        self.SetStartDate(2018, 6, 1)
        self.SetEndDate(2018, 12, 1)
        self.SetCash(1000000)

        # import the custom data
        self.AddData(WTI, "WTI", Resolution.Daily)
        self.AddData(BRENT, "BRENT", Resolution.Daily)

        # Quandl E-mini S&P500 / Nasdaq-100 front month futures data (daily)
        self.es1 = self.AddData(Quandl, "CHRIS/CME_ES1", Resolution.Daily)
        self.nq1 = self.AddData(Quandl, "CHRIS/CME_NQ1", Resolution.Daily)
        # specifying exchange so bars work properly
        self.Securities["CHRIS/CME_ES1"].Exchange = EquityExchange()
        self.Securities["CHRIS/CME_NQ1"].Exchange = EquityExchange()

        # Futures chains, filtered to contracts expiring within 180 days.
        ES = self.AddFuture(Futures.Indices.SP500EMini)
        ES.SetFilter(timedelta(0), timedelta(days=180))
        NQ = self.AddFuture(Futures.Indices.NASDAQ100EMini)
        NQ.SetFilter(timedelta(0), timedelta(days=180))

        # Entry signal for WTI|BRENT: SMA of the spread, updated manually in OnData.
        self.SpreadSMA = SimpleMovingAverage(13)
        self.esnqSMA = SimpleMovingAverage(13)
        self.esnqEMA = ExponentialMovingAverage(7)

        # Warm the ES/NQ indicators from recent ES1 history. Both are fed ES1
        # settle prices only, not an ES/NQ spread. NOTE(review): confirm intent.
        tradeBarHistory = self.History(["CHRIS/CME_ES1"], 30)
        if tradeBarHistory.empty:
            # Guard: .loc on an empty frame would raise KeyError.
            self.Debug("History([\"CHRIS/CME_ES1\"], 30) returned no rows")
        else:
            rowCount = len(tradeBarHistory)
            if rowCount != 30:
                self.Debug("History([\"CHRIS/CME_ES1\"], 30) returned " + str(rowCount) + " rows, expected 30")
            for index, tradeBar in tradeBarHistory.loc["CHRIS/CME_ES1"].iterrows():
                settle = tradeBar["settle"]
                self.esnqSMA.Update(index, settle)
                self.esnqEMA.Update(index, settle)

        # Exit signal: EMA(WTI) - EMA(BRENT), auto-updated by the framework.
        self.wtiema = self.EMA("WTI", 7)
        self.brentema = self.EMA("BRENT", 7)
        self.wtibrentema = IndicatorExtensions.Minus(self.wtiema, self.brentema)

        # Chart showing the raw spread together with its EMA and SMA.
        spreadPlot = Chart("Spread Plot")
        spreadPlot.AddSeries(Series("Spread", SeriesType.Line, 0))
        spreadPlot.AddSeries(Series("Spread EMA", SeriesType.Line, 0))
        spreadPlot.AddSeries(Series("Spread SMA", SeriesType.Line, 0))
        self.AddChart(spreadPlot)

    def OnData(self, data):
        """Update the spread SMA, then enter or exit the pair position.

        Only acts on slices that contain both WTI and BRENT data.
        """
        if not (data.ContainsKey("WTI") and data.ContainsKey("BRENT")):
            return

        # Compute the spread once and use the same value for plotting, the
        # indicator update, and the trading decisions.
        spread = data["WTI"].Price - data["BRENT"].Price
        self.Plot("Spread Plot", "Spread", spread)
        if self.wtibrentema.IsReady:
            # Skip plotting until it is ready; before that the value is a placeholder 0.
            self.Plot("Spread Plot", "Spread EMA", self.wtibrentema.Current.Value)

        self.SpreadSMA.Update(self.Time, spread)
        if not self.SpreadSMA.IsReady:
            return

        smaValue = self.SpreadSMA.Current.Value
        self.Plot("Spread Plot", "Spread SMA", smaValue)

        wti = self.Portfolio["WTI"]
        brent = self.Portfolio["BRENT"]

        # Entry: fade the spread back toward its SMA, unless that side is already held.
        if spread > smaValue and not (wti.IsShort and brent.IsLong):
            self.SetHoldings("WTI", -0.5)
            self.SetHoldings("BRENT", 0.5)
            self.Debug("esnqEMA worked!!! " + str(self.esnqEMA.Current.Value))
            self.Debug("esnqSMA worked!!! " + str(self.esnqSMA.Current.Value))
        elif spread < smaValue and not (wti.IsLong and brent.IsShort):
            self.SetHoldings("WTI", 0.5)
            self.SetHoldings("BRENT", -0.5)

        # Exit: close the position once the spread crosses the spread EMA.
        emaValue = self.wtibrentema.Current.Value
        if wti.IsShort and brent.IsLong and spread < emaValue:
            self.Liquidate()

        if wti.IsLong and brent.IsShort and spread > emaValue:
            self.Liquidate()
        
        
class WTI(PythonData):
    "Custom data class that imports the WTI spot price (dollars per barrel) from a Dropbox CSV."

    def GetSource(self, config, date, isLiveMode):
        """Return the remote CSV file that holds the WTI price series."""
        return SubscriptionDataSource("https://www.dropbox.com/s/jpie3z6j0stp97d/wti-crude-oil-prices-10-year-daily.csv?dl=1", SubscriptionTransportMedium.RemoteFile)

    def Reader(self, config, line, date, isLiveMode):
        """Parse one CSV line into a WTI data point.

        The parser accepts lines of the form ``YYYY-MM-DD,value``, for example
        ``2008-08-11,114.44``.

        Returns None, which skips the line, for blank lines, for header lines
        (second character is not a digit), and for any line that fails to parse.
        """
        # Use line[1:2] rather than line[1] so that a one-character line is
        # skipped instead of raising IndexError outside the try block.
        if not (line.strip() and line[1:2].isdigit()):
            return None
        index = WTI()
        index.Symbol = config.Symbol
        try:
            data = line.split(',')
            index.Time = datetime.strptime(data[0], "%Y-%m-%d")
            # Decimal() tolerates the trailing newline/whitespace.
            index.Value = Decimal(data[1])
        except Exception:
            # Best-effort: skip malformed lines rather than abort the feed, but
            # unlike a bare except, do not swallow KeyboardInterrupt/SystemExit.
            return None

        return index

class BRENT(PythonData):
    "Custom data class that imports the BRENT spot price (dollars per barrel) from a Dropbox CSV."

    def GetSource(self, config, date, isLiveMode):
        """Return the remote CSV file that holds the BRENT price series."""
        return SubscriptionDataSource("https://www.dropbox.com/s/w380c4n7xjmdqxl/brent-crude-oil-prices-10-year-daily.csv?dl=1", SubscriptionTransportMedium.RemoteFile)

    def Reader(self, config, line, date, isLiveMode):
        """Parse one CSV line into a BRENT data point.

        The parser accepts lines of the form ``YYYY-MM-DD,value``, for example
        ``2008-08-11,110.54``.

        Returns None, which skips the line, for blank lines, for header lines
        (second character is not a digit), and for any line that fails to parse.
        """
        # Use line[1:2] rather than line[1] so that a one-character line is
        # skipped instead of raising IndexError outside the try block.
        if not (line.strip() and line[1:2].isdigit()):
            return None
        index = BRENT()
        index.Symbol = config.Symbol
        try:
            data = line.split(',')
            index.Time = datetime.strptime(data[0], "%Y-%m-%d")
            # Decimal() tolerates the trailing newline/whitespace.
            index.Value = Decimal(data[1])
        except Exception:
            # Best-effort: skip malformed lines rather than abort the feed, but
            # unlike a bare except, do not swallow KeyboardInterrupt/SystemExit.
            return None

        return index