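Equity chart (created with Highcharts 12.1.2), Jul 2023 – May 2025; x-axis ticks: Jul 2023, Sep 2023, Nov 2023, Jan 2024, Mar 2024, May 2024, Jul 2024, Sep 2024, Nov 2024, Jan 2025, Mar 2025, May 2025. Panel y-axis ticks: 0 / 2M / 4M; -20 / -10 / 0; 0 / 250 / 500; -0.5 / 0 / 0.5; 0 / 500k / 1,000k; 0 / 100G / 200G; 90 / 100 / 110.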
Overall Statistics
Total Orders
55678
Average Win
0.09%
Average Loss
-0.11%
Compounding Annual Return
70.961%
Drawdown
22.400%
Expectancy
0.033
Start Equity
1000000
End Equity
2572066.96
Net Profit
157.207%
Sharpe Ratio
1.98
Sortino Ratio
2.382
Probabilistic Sharpe Ratio
90.557%
Loss Rate
45%
Win Rate
55%
Profit-Loss Ratio
0.86
Alpha
0.431
Beta
-0.087
Annual Standard Deviation
0.216
Annual Variance
0.047
Information Ratio
1.573
Tracking Error
0.247
Treynor Ratio
-4.918
Total Fees
₮0.00
Estimated Strategy Capacity
₮320000.00
Lowest Capacity Asset
BTCUSDT 18N
Portfolio Turnover
8464.84%
#region imports
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Drawing;
using System.Globalization;
using System.Linq;
using System.Text;
using Newtonsoft.Json;
using QuantConnect;
using QuantConnect.Algorithm;
using QuantConnect.Algorithm.Framework;
using QuantConnect.Algorithm.Framework.Alphas;
using QuantConnect.Algorithm.Framework.Execution;
using QuantConnect.Algorithm.Framework.Portfolio;
using QuantConnect.Algorithm.Framework.Portfolio.SignalExports;
using QuantConnect.Algorithm.Framework.Risk;
using QuantConnect.Algorithm.Framework.Selection;
using QuantConnect.Algorithm.Selection;
using QuantConnect.Api;
using QuantConnect.Benchmarks;
using QuantConnect.Brokerages;
using QuantConnect.Commands;
using QuantConnect.Configuration;
using QuantConnect.Data;
using QuantConnect.Data.Auxiliary;
using QuantConnect.Data.Consolidators;
using QuantConnect.Data.Custom;
using QuantConnect.Data.Custom.IconicTypes;
using QuantConnect.Data.Fundamental;
using QuantConnect.Data.Market;
using QuantConnect.Data.Shortable;
using QuantConnect.Data.UniverseSelection;
using QuantConnect.DataSource;
using QuantConnect.Indicators;
using QuantConnect.Interfaces;
using QuantConnect.Notifications;
using QuantConnect.Orders;
using QuantConnect.Orders.Fees;
using QuantConnect.Orders.Fills;
using QuantConnect.Orders.OptionExercise;
using QuantConnect.Orders.Slippage;
using QuantConnect.Orders.TimeInForces;
using QuantConnect.Parameters;
using QuantConnect.Python;
using QuantConnect.Scheduling;
using QuantConnect.Securities;
using QuantConnect.Securities.Crypto;
using QuantConnect.Securities.CryptoFuture;
using QuantConnect.Securities.Equity;
using QuantConnect.Securities.Forex;
using QuantConnect.Securities.Future;
using QuantConnect.Securities.IndexOption;
using QuantConnect.Securities.Interfaces;
using QuantConnect.Securities.Option;
using QuantConnect.Securities.Positions;
using QuantConnect.Securities.Volatility;
using QuantConnect.Statistics;
using QuantConnect.Storage;
using QuantConnect.Util;
using Calendar = QuantConnect.Data.Consolidators.Calendar;
using QCAlgorithmFramework = QuantConnect.Algorithm.QCAlgorithm;
using QCAlgorithmFrameworkBridge = QuantConnect.Algorithm.QCAlgorithm;
#endregion

namespace QuantConnect.Algorithm.CSharp
{

    /// <summary>
    /// Minute-resolution BTCUSDT strategy that goes long or short on the output of a linear
    /// (logistic) model whose parameters are read from the ObjectStore.
    /// </summary>
    /// <remarks>
    /// Per bar: features -> model probability -> state-machine adjustment -> percentile lookup in a
    /// threshold array. Extreme percentiles open a position. A position closes on an opposite
    /// signal, a holding-time limit, an exit-percentile band, take-profit or stop-loss.
    /// </remarks>
    public class CryptoExampleStrategyPublic : QCAlgorithm
    {
        /// <summary>ObjectStore key of the JSON-serialized <see cref="ModelParams"/>.</summary>
        public static string ModelParamsFileName = "baseline_model_params.json";

        /// <summary>ObjectStore key of the JSON array of probabilities used for percentile lookup.</summary>
        public static string ThresholdArrayFileName = "baseline_threshold_array.json";

        /// <summary>
        /// Linear-model parameters read from JSON. <see cref="Coefficients"/>, <see cref="Center"/>
        /// and <see cref="Scale"/> are indexed in the same order as <see cref="FeatureCols"/>.
        /// </summary>
        /// <remarks>
        /// Presumably exported from an sklearn pipeline (RobustScaler + logistic regression); TODO confirm.
        /// </remarks>
        public class ModelParams
        {
            /// <summary>Feature names in model order; see CalculateFeatures for the supported names.</summary>
            [JsonProperty("feature_cols")]
            public string[] FeatureCols { get; set; }

            /// <summary>One weight per feature, applied to the scaled feature value.</summary>
            [JsonProperty("coefficients")]
            public decimal[] Coefficients { get; set; }

            /// <summary>Bias term that starts the logit sum.</summary>
            [JsonProperty("intercept")]
            public decimal Intercept { get; set; }

            /// <summary>Per-feature value subtracted before scaling.</summary>
            [JsonProperty("center")]
            public decimal[] Center { get; set; }

            /// <summary>Per-feature divisor applied after centering.</summary>
            [JsonProperty("scale")]
            public decimal[] Scale { get; set; }
        }

        /// <summary>
        /// Fixed-capacity circular buffer of timestamped items. Once full, each <see cref="Add"/>
        /// overwrites the oldest entry. Index 0 always refers to the oldest retained item.
        /// </summary>
        private class RingBuffer<T>
        {
            private readonly T[] _buffer;
            private readonly DateTime[] _timestamps;
            private readonly int _size;
            private int _currentIndex; // slot the next Add writes to
            private int _count; // number of valid items, capped at _size

            /// <param name="size">Capacity; must be positive (a zero size would make every index computation divide by zero).</param>
            /// <exception cref="ArgumentOutOfRangeException"><paramref name="size"/> is zero or negative.</exception>
            public RingBuffer(int size)
            {
                if (size <= 0)
                {
                    throw new ArgumentOutOfRangeException(nameof(size), "Ring buffer size must be positive.");
                }

                _size = size;
                _buffer = new T[size];
                _timestamps = new DateTime[size];
                _currentIndex = 0;
                _count = 0;
            }

            /// <summary>Appends an item, overwriting the oldest one when the buffer is full.</summary>
            public void Add(DateTime timestamp, T item)
            {
                _timestamps[_currentIndex] = timestamp;
                _buffer[_currentIndex] = item;
                _currentIndex = (_currentIndex + 1) % _size;
                if (_count < _size)
                {
                    _count++;
                }
            }

            /// <summary>Number of items currently held (at most the capacity).</summary>
            public int Count => _count;

            /// <summary>Item at chronological <paramref name="index"/> (0 = oldest).</summary>
            /// <exception cref="IndexOutOfRangeException">Index is outside [0, Count).</exception>
            public T GetByIndex(int index)
            {
                return _buffer[PhysicalIndex(index)];
            }

            /// <summary>Timestamp at chronological <paramref name="index"/> (0 = oldest).</summary>
            /// <exception cref="IndexOutOfRangeException">Index is outside [0, Count).</exception>
            public DateTime GetTimestampByIndex(int index)
            {
                return _timestamps[PhysicalIndex(index)];
            }

            /// <summary>Most recently added item.</summary>
            /// <exception cref="InvalidOperationException">The buffer is empty.</exception>
            public T GetLatest()
            {
                return _buffer[LatestSlot()];
            }

            /// <summary>Timestamp of the most recently added item.</summary>
            /// <exception cref="InvalidOperationException">The buffer is empty.</exception>
            public DateTime GetLatestTimestamp()
            {
                return _timestamps[LatestSlot()];
            }

            /// <summary>Snapshot of all items, oldest first.</summary>
            public List<T> GetItems()
            {
                List<T> result = new List<T>(_count);
                for (int i = 0; i < _count; i++)
                {
                    result.Add(_buffer[Slot(i)]);
                }
                return result;
            }

            /// <summary>Snapshot of all (timestamp, item) pairs, oldest first.</summary>
            public List<KeyValuePair<DateTime, T>> GetAllWithTimestamps()
            {
                List<KeyValuePair<DateTime, T>> result = new List<KeyValuePair<DateTime, T>>(
                    _count
                );
                for (int i = 0; i < _count; i++)
                {
                    int slot = Slot(i);
                    result.Add(new KeyValuePair<DateTime, T>(_timestamps[slot], _buffer[slot]));
                }
                return result;
            }

            // Maps a chronological index to its array slot without bounds checking.
            private int Slot(int index) => (_currentIndex - _count + index + _size) % _size;

            // Bounds-checked version of Slot used by the public indexers.
            private int PhysicalIndex(int index)
            {
                if (index < 0 || index >= _count)
                {
                    throw new IndexOutOfRangeException();
                }
                return Slot(index);
            }

            // Slot of the newest item; throws when nothing has been added yet.
            private int LatestSlot()
            {
                if (_count == 0)
                {
                    throw new InvalidOperationException("Buffer is empty");
                }
                return (_currentIndex - 1 + _size) % _size;
            }
        }

        /// <summary>
        /// Trust level in the model. AdjustPredictionByState defines how each state transforms the
        /// raw prediction.
        /// </summary>
        private enum ModelState
        {
            /// <summary>Prediction used unchanged.</summary>
            Normal,
            /// <summary>Prediction pulled halfway toward 0.5.</summary>
            Suspicious,
            /// <summary>Prediction inverted (1 - p).</summary>
            Reversed,
            /// <summary>Prediction replaced by a constant 0.5.</summary>
            HighlyUnreliable,
        }

        private Symbol _btcusdt;
        private ModelParams _modelParams;
        private decimal[] _thresholdArr; // ascending probabilities used for percentile lookup
        private bool _modelLoaded = false; // OnData does nothing until valid model params are loaded

        private decimal _positionSize = 0.98m; // target portfolio fraction passed to SetHoldings
        private decimal _leverage = 1.0m; // applied once, in AddCrypto

        // Width of the entry/exit percentile bands, split evenly between the two tails (see OnData).
        private decimal _enterPositionThreshold = 0.04m;
        private decimal _exitPositionThreshold = 0.60m;

        // Both limits are compared against PnL expressed in PERCENT (e.g. 0.01m == 0.01%, 1m == 1%).
        // NOTE(review): confirm 0.01m was meant as 0.01% and not as a 1% fraction.
        private decimal _takeProfitTarget = 0.01m;
        private decimal _stopLossLevel = 1m;
        private decimal _modelReverseThreshold = 1m; // currently unused
        private ModelState _currentModelState = ModelState.Normal;
        private int _consecutiveLosses = 0;
        private int _consecutiveWins = 0;
        private int _consecutiveLossesThreshold = 2;
        private int _consecutiveWinsThreshold = 2;
        private DateTime _stateTransitionTime;
        private decimal _stateTransitionPrice; // currently unused

        private DateTime _positionEntryTime;
        private bool _inLongPosition = false;
        private bool _inShortPosition = false;
        private decimal _entryPrice = 0m; // bar close at entry, not the order fill price
        private int _positionHoldingWindow = 10; // minutes
        private int _earlyProfitMinHoldingTime = 1; // minutes before take-profit may trigger

        private RingBuffer<decimal> _predictionHistory; // raw model outputs, written but not read here
        private int _maxPredictionHistory = 60;

        private HashSet<DateTime> _testDays = new HashSet<DateTime>();
        private List<TradeRecord> _tradeRecords = new List<TradeRecord>();

        /// <summary>One round-trip trade; exit fields are filled in by ClosePosition.</summary>
        private class TradeRecord
        {
            public DateTime EntryTime { get; set; }
            public DateTime ExitTime { get; set; }
            public decimal EntryPrice { get; set; }
            public decimal ExitPrice { get; set; }
            public string Direction { get; set; }
            public decimal PnL { get; set; } // percent of entry price
            public string ExitReason { get; set; }
            public ModelState ModelStateAtEntry { get; set; }
            public decimal OriginalPrediction { get; set; }
            public decimal AdjustedPrediction { get; set; }
        }

        /// <summary>
        /// Configures the backtest window, account, BTCUSDT subscription (zero fees), and loads the
        /// model parameters and threshold array from the ObjectStore.
        /// </summary>
        public override void Initialize()
        {
            SetStartDate(2023, 7, 1);
            // SetEndDate(2023, 10, 1);
            // The algorithm runs in UTC (SetTimeZone below), so take "now" from the UTC clock rather
            // than the host's local time.
            SetEndDate(DateTime.UtcNow);
            SetAccountCurrency("USDT");
            SetCash(1_000_000);
            SetBrokerageModel(new DefaultBrokerageModel());
            SetTimeZone(TimeZones.Utc);

            var security = AddCrypto(
                "BTCUSDT",
                Resolution.Minute,
                LiveMode ? null : Market.Binance,
                fillForward: true,
                leverage: _leverage
            );
            security.SetFeeModel(new ConstantFeeModel(0.0m));
            _btcusdt = security.Symbol;
            _predictionHistory = new RingBuffer<decimal>(_maxPredictionHistory);

            // Reload model every 00:00 UTC
            // Schedule.On(
            //     DateRules.EveryDay("BTCUSDT"),
            //     TimeRules.At(new TimeSpan(00, 00, 00)),
            //     LoadModelParameters
            // );
            // Initialize test days
            // InitializeTestDays();
            // // Reset state machine at start of each day
            // Schedule.On(
            //     DateRules.EveryDay("BTCUSDT"),
            //     TimeRules.At(00, 00, 01), // Just after midnight
            //     ResetStateMachine
            // );
            // Liquidate at the start of each day
            // Schedule.On(
            //     DateRules.EveryDay("BTCUSDT"),
            //     TimeRules.At(00, 00, 05), // Just after midnight
            //     CheckAndLiquidateForNonTestDays
            // );

            ResetStateMachine();
            LoadModelParameters();
            LoadThresholdArray();

            // Live mode only trims the position size. Leverage stays at _leverage because it was
            // already applied in AddCrypto above, and the 2x override below is disabled.
            if (LiveMode)
            {
                _positionSize = 0.95m;
                // _leverage = 2.0m;
            }
            // TODO: Maybe buy some futures for hedging.
        }

        /// <summary>
        /// Returns the state machine to <see cref="ModelState.Normal"/> and clears both streak counters.
        /// </summary>
        private void ResetStateMachine()
        {
            if (_currentModelState != ModelState.Normal)
            {
                Log(
                    $"Resetting state machine. Previous state: {_currentModelState}, Consecutive losses: {_consecutiveLosses}, Consecutive wins: {_consecutiveWins}"
                );
            }

            _currentModelState = ModelState.Normal;
            _consecutiveLosses = 0;
            _consecutiveWins = 0;

            Log(
                $"State machine reset for {Time.Date:yyyy-MM-dd}. Now in {_currentModelState} state."
            );
        }

        /// <summary>
        /// Fills <see cref="_testDays"/> from a hard-coded list of date strings (currently empty).
        /// </summary>
        private void InitializeTestDays()
        {
            string[] testDaysStrings = new string[] { };

            foreach (string dateStr in testDaysStrings)
            {
                // Invariant culture so the same date strings parse identically on any host locale.
                DateTime date = DateTime.Parse(dateStr, CultureInfo.InvariantCulture);
                _testDays.Add(date.Date); // Store just the date part, no time
            }

            Log($"Initialized {_testDays.Count} test days for trading");
        }

        /// <summary>
        /// In backtests, liquidates BTCUSDT on any day not listed in <see cref="_testDays"/>.
        /// Live mode never liquidates here.
        /// </summary>
        private void CheckAndLiquidateForNonTestDays()
        {
            DateTime currentDate = Time.Date;

            // Check if the current day is a test day
            if (!LiveMode && !_testDays.Contains(currentDate))
            {
                // If not a test day, liquidate all positions
                if (Portfolio.Invested)
                {
                    Liquidate(_btcusdt);
                    _inLongPosition = false;
                    _inShortPosition = false;
                    Log($"Not a test day: Liquidated all positions on {currentDate:yyyy-MM-dd}");
                }
            }
            else
            {
                // Log($"Test day: Trading enabled for {currentDate:yyyy-MM-dd}");
                // in live mode or test day, do not liquidate
                Log($"LiveMode or test day: Trading enabled for {currentDate:yyyy-MM-dd}");
            }
        }

        /// <summary>
        /// Per-bar handler. Scores the latest BTCUSDT bar, manages any open position, and opens a new
        /// one when flat and the percentile signal is extreme enough.
        /// </summary>
        /// <param name="slice">Current data slice; slices without a BTCUSDT bar are ignored.</param>
        public override void OnData(Slice slice)
        {
            // NOTE(review): this and the logs below fire on every minute bar; consider Debug() or
            // throttling if log quotas become a problem.
            Log($"[OnData] - {Time} - Before Check {_btcusdt}, _modelLoaded {_modelLoaded}");
            // Single lookup instead of ContainsKey followed by two indexer reads.
            if (!slice.Bars.TryGetValue(_btcusdt, out var bar) || !_modelLoaded)
                return;
            Log($"[OnData] - {Time} - After Check: {bar}");
            // Only trade on test days when LiveMode == false
            // if (!LiveMode && !_testDays.Contains(Time.Date))
            //     return;

            decimal[] features = CalculateFeatures(bar);
            decimal originalPredictProb = PredictProbability(features);
            decimal adjustedPredictProb = AdjustPredictionByState(originalPredictProb);
            decimal percentile = GetProbabilityPercentile(adjustedPredictProb);
            // History records the raw model output, not the state-adjusted one.
            _predictionHistory.Add(Time, originalPredictProb);

            Log(
                $"Time: {Time}, Price: {bar.Close}, Original Prediction: {originalPredictProb:F4}, "
                    + $"Adjusted Prediction: {adjustedPredictProb:F4}, Percentile: {percentile:P2}, State: {_currentModelState}"
            );

            // Entry bands. With _enterPositionThreshold = 0.04, percentile >= 0.98 goes long and
            // percentile <= 0.02 goes short.
            bool shouldBeLong = percentile >= (1m - _enterPositionThreshold / 2m);
            bool shouldBeShort = percentile <= (_enterPositionThreshold / 2m);

            // Exit bands. With _exitPositionThreshold = 0.60, a long exits at percentile <= 0.30 and a
            // short exits at percentile >= 0.70.
            bool shouldExitLong = percentile <= (_exitPositionThreshold / 2m);
            bool shouldExitShort = percentile >= (1m - _exitPositionThreshold / 2m);

            bool holdingTimeElapsed = false;
            bool earlyProfitTimeElapsed = false;
            decimal currentPnlPercent = 0m;

            if (_inLongPosition || _inShortPosition)
            {
                TimeSpan holdingTime = Time - _positionEntryTime;
                holdingTimeElapsed = holdingTime.TotalMinutes >= _positionHoldingWindow;
                earlyProfitTimeElapsed = holdingTime.TotalMinutes >= _earlyProfitMinHoldingTime;
                currentPnlPercent = PnlPercent(_inLongPosition, bar.Close);
                if (holdingTimeElapsed)
                {
                    Log($"Position holding window of {_positionHoldingWindow} minutes elapsed");
                }
            }

            // PnL is in percent, so both limits are percent values too.
            bool takeProfitTriggered =
                earlyProfitTimeElapsed && currentPnlPercent >= _takeProfitTarget;
            bool stopLossTriggered = currentPnlPercent <= -_stopLossLevel;

            if (_inLongPosition)
            {
                ExitIfNeeded(
                    "LONG",
                    shouldBeShort,
                    shouldExitLong,
                    holdingTimeElapsed,
                    takeProfitTriggered,
                    stopLossTriggered,
                    currentPnlPercent,
                    bar.Close,
                    originalPredictProb,
                    adjustedPredictProb
                );
            }
            else if (_inShortPosition)
            {
                ExitIfNeeded(
                    "SHORT",
                    shouldBeLong,
                    shouldExitShort,
                    holdingTimeElapsed,
                    takeProfitTriggered,
                    stopLossTriggered,
                    currentPnlPercent,
                    bar.Close,
                    originalPredictProb,
                    adjustedPredictProb
                );
            }

            // Enter a new position when flat. This includes the bar on which a position was just
            // closed, so an "Opposite signal" exit flips the position immediately.
            if (!_inLongPosition && !_inShortPosition)
            {
                if (shouldBeLong)
                {
                    EnterLong(bar.Close, originalPredictProb, adjustedPredictProb);
                }
                else if (shouldBeShort)
                {
                    EnterShort(bar.Close, originalPredictProb, adjustedPredictProb);
                }
            }
        }

        /// <summary>
        /// Closes the open position when any exit condition holds: opposite signal, holding time
        /// elapsed, exit band reached, take-profit, or stop-loss. Only stop-loss and take-profit
        /// exits feed the win/loss state machine.
        /// </summary>
        /// <param name="positionType">"LONG" or "SHORT"; forwarded to ClosePosition.</param>
        /// <param name="oppositeSignal">True when the entry signal for the other side fired.</param>
        /// <param name="exitThresholdReached">True when the percentile entered this side's exit band.</param>
        private void ExitIfNeeded(
            string positionType,
            bool oppositeSignal,
            bool exitThresholdReached,
            bool holdingTimeElapsed,
            bool takeProfitTriggered,
            bool stopLossTriggered,
            decimal currentPnlPercent,
            decimal price,
            decimal originalPrediction,
            decimal adjustedPrediction
        )
        {
            if (
                !(
                    oppositeSignal
                    || holdingTimeElapsed
                    || exitThresholdReached
                    || takeProfitTriggered
                    || stopLossTriggered
                )
            )
            {
                return;
            }

            // When several conditions hold, the first one in this chain is reported.
            string reason =
                oppositeSignal ? "Opposite signal"
                : holdingTimeElapsed ? "Holding time elapsed"
                : takeProfitTriggered ? $"Take profit target hit: {currentPnlPercent:F2}%"
                : stopLossTriggered ? $"Stop loss triggered: {currentPnlPercent:F2}%"
                : "Exit threshold reached";

            ClosePosition(positionType, price, reason, originalPrediction, adjustedPrediction);

            // Check if stop loss should trigger state machine transition
            if (stopLossTriggered)
            {
                UpdateStateMachineOnLoss();
            }
            else if (takeProfitTriggered)
            {
                UpdateStateMachineOnWin();
            }
        }

        /// <summary>
        /// Transforms the raw model probability according to the current <see cref="ModelState"/>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): in HighlyUnreliable the constant 0.5 maps to about the 50th percentile with
        /// the default threshold array, so no new entries occur. Wins are only counted on exits, so
        /// the state is then left only via ResetStateMachine, whose daily schedule is commented out
        /// in Initialize.
        /// </remarks>
        private decimal AdjustPredictionByState(decimal originalPrediction)
        {
            switch (_currentModelState)
            {
                case ModelState.Normal:
                    // No adjustment needed
                    return originalPrediction;
                case ModelState.Suspicious:
                    // Reduce confidence by moving prediction toward 0.5
                    return 0.5m + (originalPrediction - 0.5m) * 0.5m;
                case ModelState.Reversed:
                    // Invert the prediction (1-p)
                    return 1m - originalPrediction;
                case ModelState.HighlyUnreliable:
                    // Just return 0.5 (no clear signal)
                    return 0.5m;
                default:
                    return originalPrediction;
            }
        }

        /// <summary>
        /// Records a stop-loss exit and downgrades trust in the model when the loss streak is long
        /// enough. The streak is not reset on a transition, so with the default threshold of 2 the
        /// downgrades happen at 2, 4 and 6 consecutive losses.
        /// </summary>
        private void UpdateStateMachineOnLoss()
        {
            _consecutiveLosses++;
            _consecutiveWins = 0;
            // Transition state machine based on consecutive losses
            switch (_currentModelState)
            {
                case ModelState.Normal:
                    if (_consecutiveLosses >= _consecutiveLossesThreshold)
                    {
                        _currentModelState = ModelState.Suspicious;
                        _stateTransitionTime = Time;
                        Log(
                            $"State transition: Normal -> Suspicious after {_consecutiveLosses} consecutive losses"
                        );
                    }
                    break;
                case ModelState.Suspicious:
                    if (_consecutiveLosses >= _consecutiveLossesThreshold * 2)
                    {
                        _currentModelState = ModelState.Reversed;
                        _stateTransitionTime = Time;
                        Log(
                            $"State transition: Suspicious -> Reversed after {_consecutiveLosses} consecutive losses"
                        );
                    }
                    break;
                case ModelState.Reversed:
                    if (_consecutiveLosses >= _consecutiveLossesThreshold * 3)
                    {
                        _currentModelState = ModelState.HighlyUnreliable;
                        _stateTransitionTime = Time;
                        Log(
                            $"State transition: Reversed -> HighlyUnreliable after {_consecutiveLosses} consecutive losses"
                        );
                    }
                    break;
            }
        }

        /// <summary>
        /// Records a take-profit exit and restores trust in the model when the win streak is long
        /// enough. The streak is not reset on a transition, so with the default threshold of 2 the
        /// upgrades happen at 2, 4 and 6 consecutive wins.
        /// </summary>
        private void UpdateStateMachineOnWin()
        {
            _consecutiveWins++;
            _consecutiveLosses = 0;
            // Transition state machine based on consecutive wins
            switch (_currentModelState)
            {
                case ModelState.HighlyUnreliable:
                    if (_consecutiveWins >= _consecutiveWinsThreshold)
                    {
                        _currentModelState = ModelState.Reversed;
                        _stateTransitionTime = Time;
                        Log(
                            $"State transition: HighlyUnreliable -> Reversed after {_consecutiveWins} consecutive wins"
                        );
                    }
                    break;
                case ModelState.Reversed:
                    if (_consecutiveWins >= _consecutiveWinsThreshold * 2)
                    {
                        _currentModelState = ModelState.Suspicious;
                        _stateTransitionTime = Time;
                        Log(
                            $"State transition: Reversed -> Suspicious after {_consecutiveWins} consecutive wins"
                        );
                    }
                    break;
                case ModelState.Suspicious:
                    if (_consecutiveWins >= _consecutiveWinsThreshold * 3)
                    {
                        _currentModelState = ModelState.Normal;
                        _stateTransitionTime = Time;
                        Log(
                            $"State transition: Suspicious -> Normal after {_consecutiveWins} consecutive wins"
                        );
                    }
                    break;
            }
        }

        /// <summary>Opens a long position of <see cref="_positionSize"/> and records the trade.</summary>
        private void EnterLong(
            decimal price,
            decimal originalPrediction,
            decimal adjustedPrediction
        )
        {
            EnterPosition(true, price, originalPrediction, adjustedPrediction);
        }

        /// <summary>Opens a short position of <see cref="_positionSize"/> and records the trade.</summary>
        private void EnterShort(
            decimal price,
            decimal originalPrediction,
            decimal adjustedPrediction
        )
        {
            EnterPosition(false, price, originalPrediction, adjustedPrediction);
        }

        // Shared implementation of EnterLong/EnterShort: orders, updates position state, logs, records.
        private void EnterPosition(
            bool isLong,
            decimal price,
            decimal originalPrediction,
            decimal adjustedPrediction
        )
        {
            string direction = isLong ? "LONG" : "SHORT";
            SetHoldings(_btcusdt, isLong ? _positionSize : -_positionSize);
            _inLongPosition = isLong;
            _inShortPosition = !isLong;
            _positionEntryTime = Time;
            _entryPrice = price;
            Log(
                $"ENTERED {direction} at {Time}, Price: {price}, Position Size: {_positionSize}, Model State: {_currentModelState}"
            );
            _tradeRecords.Add(
                new TradeRecord
                {
                    EntryTime = Time,
                    EntryPrice = price,
                    Direction = direction,
                    ModelStateAtEntry = _currentModelState,
                    OriginalPrediction = originalPrediction,
                    AdjustedPrediction = adjustedPrediction,
                }
            );
        }

        /// <summary>
        /// Liquidates BTCUSDT, clears the matching position flag, and completes the latest trade record.
        /// </summary>
        /// <param name="positionType">"LONG" selects long PnL; any other value is treated as short.</param>
        /// <param name="price">Price used for PnL and the record (the bar close, not the fill).</param>
        /// <param name="originalPrediction">Currently unused; kept for call-site symmetry.</param>
        /// <param name="adjustedPrediction">Currently unused; kept for call-site symmetry.</param>
        private void ClosePosition(
            string positionType,
            decimal price,
            string reason,
            decimal originalPrediction,
            decimal adjustedPrediction
        )
        {
            Liquidate(_btcusdt);
            bool wasLong = positionType == "LONG";
            decimal pnl = PnlPercent(wasLong, price);
            if (wasLong)
            {
                _inLongPosition = false;
            }
            else
            {
                _inShortPosition = false;
            }
            Log(
                $"EXITED {positionType} at {Time}, Price: {price}, PnL: {pnl:F2}%, Reason: {reason}, Model State: {_currentModelState}"
            );
            if (_tradeRecords.Count > 0)
            {
                var lastTrade = _tradeRecords[_tradeRecords.Count - 1];
                lastTrade.ExitTime = Time;
                lastTrade.ExitPrice = price;
                lastTrade.PnL = pnl;
                lastTrade.ExitReason = reason;
            }
        }

        // PnL of the current position at `price`, in percent of _entryPrice (positive = profit).
        private decimal PnlPercent(bool isLong, decimal price)
        {
            decimal move = isLong ? price - _entryPrice : _entryPrice - price;
            return move / _entryPrice * 100m;
        }

        /// <summary>
        /// Builds the feature vector in <see cref="ModelParams.FeatureCols"/> order. Supported names:
        /// close_open_ratio, high_low_ratio, and day_pct (fraction of the UTC day elapsed at the
        /// algorithm's current Time). Unknown names are logged and set to 0.
        /// </summary>
        private decimal[] CalculateFeatures(TradeBar bar)
        {
            decimal[] features = new decimal[_modelParams.FeatureCols.Length];

            int hour = Time.Hour;
            int minute = Time.Minute;
            decimal dayPct = (hour * 60 + minute) / (24m * 60m);

            for (int i = 0; i < _modelParams.FeatureCols.Length; i++)
            {
                switch (_modelParams.FeatureCols[i])
                {
                    case "close_open_ratio":
                        features[i] = bar.Close / bar.Open;
                        break;
                    case "high_low_ratio":
                        features[i] = bar.High / bar.Low;
                        break;
                    case "day_pct":
                        features[i] = dayPct;
                        break;
                    default:
                        Log($"Unknown feature: {_modelParams.FeatureCols[i]}");
                        features[i] = 0;
                        break;
                }
            }
            return features;
        }

        /// <summary>
        /// Reads and validates <see cref="ModelParams"/> from the ObjectStore. On any failure the
        /// previously loaded parameters (if any) and <see cref="_modelLoaded"/> are left unchanged.
        /// </summary>
        private void LoadModelParameters()
        {
            if (!ObjectStore.ContainsKey(ModelParamsFileName))
            {
                Log($"Model parameters file {ModelParamsFileName} not found.");
                return;
            }
            string jsonStr = ObjectStore.Read(ModelParamsFileName);
            try
            {
                var parsed = JsonConvert.DeserializeObject<ModelParams>(jsonStr);
                // Reject structurally broken params up front. Otherwise PredictProbability would
                // throw (NRE / IndexOutOfRange) on every bar after _modelLoaded is set.
                string problem = ValidateModelParams(parsed);
                if (problem != null)
                {
                    Log($"Invalid model parameters in {ModelParamsFileName}: {problem}");
                    return;
                }
                _modelParams = parsed;
                var formattedJson = JsonConvert.SerializeObject(_modelParams, Formatting.Indented);
                Log($"Model parameters loaded:\n{formattedJson}");
                _modelLoaded = true;
            }
            catch (Exception ex)
            {
                Log($"Error deserializing JSON: {ex.Message}");
            }
        }

        // Returns a description of the first structural problem, or null when the params are usable.
        private static string ValidateModelParams(ModelParams p)
        {
            if (p == null)
            {
                return "document deserialized to null";
            }
            if (p.FeatureCols == null)
            {
                return "feature_cols is missing";
            }
            int n = p.FeatureCols.Length;
            if (p.Coefficients == null || p.Coefficients.Length != n)
            {
                return $"coefficients must have {n} entries";
            }
            if (p.Center == null || p.Center.Length != n)
            {
                return $"center must have {n} entries";
            }
            if (p.Scale == null || p.Scale.Length != n)
            {
                return $"scale must have {n} entries";
            }
            return null;
        }

        /// <summary>
        /// Loads the percentile threshold array from the ObjectStore. The default array is used when
        /// the file is missing, unparsable, or has fewer than two values.
        /// </summary>
        private void LoadThresholdArray()
        {
            if (!ObjectStore.ContainsKey(ThresholdArrayFileName))
            {
                Log($"Threshold array file {ThresholdArrayFileName} not found.");
                InitializeDefaultThresholdArray();
                return;
            }
            string jsonStr = ObjectStore.Read(ThresholdArrayFileName);
            try
            {
                var loaded = JsonConvert.DeserializeObject<decimal[]>(jsonStr);
                // Interpolation in GetProbabilityPercentile needs at least two points.
                if (loaded == null || loaded.Length < 2)
                {
                    Log($"Threshold array in {ThresholdArrayFileName} has fewer than 2 values.");
                    InitializeDefaultThresholdArray();
                    return;
                }
                // Array.BinarySearch in GetProbabilityPercentile requires ascending order.
                Array.Sort(loaded);
                _thresholdArr = loaded;
                Log($"Threshold array loaded with {_thresholdArr.Length} values.");
            }
            catch (Exception ex)
            {
                Log($"Error deserializing threshold array JSON: {ex.Message}");
                InitializeDefaultThresholdArray();
            }
        }

        /// <summary>
        /// Builds a fallback threshold array of 200 ascending values (about 0.5% percentile
        /// resolution).
        /// </summary>
        /// <remarks>
        /// Entry i is the logistic approximation of the standard normal CDF,
        /// 1 / (1 + exp(-1.702 z)), with z = (x - 0.5) / 0.15 and x = i / 199. A probability
        /// lookup therefore returns roughly the x at which this approximate N(0.5, 0.15²) CDF
        /// equals that probability.
        /// </remarks>
        private void InitializeDefaultThresholdArray()
        {
            int arraySize = 200;
            _thresholdArr = new decimal[arraySize];
            double mean = 0.5;
            double stdDev = 0.15;

            for (int i = 0; i < arraySize; i++)
            {
                double x = (double)i / (arraySize - 1);
                double z = (x - mean) / stdDev;
                // 1.702 is the classic scale factor making the logistic curve track the normal CDF.
                double probability = 1.0 / (1.0 + Math.Exp(-z * 1.702));

                _thresholdArr[i] = (decimal)probability;
            }
            // Already ascending by construction; sorting keeps the BinarySearch precondition explicit.
            Array.Sort(_thresholdArr);
            Log(
                $"Initialized default threshold array with {arraySize} Gaussian-distributed values."
            );
        }

        /// <summary>
        /// Logistic-regression probability: sigmoid(intercept + Σ coef[i] · (x[i] - center[i]) / scale[i]).
        /// </summary>
        /// <param name="features">Feature vector aligned with <see cref="ModelParams.FeatureCols"/>.</param>
        /// <returns>Probability in [0, 1].</returns>
        private decimal PredictProbability(decimal[] features)
        {
            decimal logit = _modelParams.Intercept;
            for (int i = 0; i < features.Length; i++)
            {
                // sklearn RobustScaler equivalent. sklearn replaces zero scales with 1; mirror that so
                // a constant feature cannot cause a DivideByZeroException.
                decimal scale = _modelParams.Scale[i] == 0m ? 1m : _modelParams.Scale[i];
                decimal scaled = (features[i] - _modelParams.Center[i]) / scale;
                logit += scaled * _modelParams.Coefficients[i];
            }

            double expNegLogit = Math.Exp(-(double)logit);
            // For logit below about -64, exp(-logit) leaves decimal's range (~7.9e28) and the
            // (decimal) cast would throw OverflowException. At >= 1e28 the probability is <= 1e-28,
            // i.e. 0 at decimal precision.
            if (expNegLogit >= 1e28)
            {
                return 0m;
            }
            return 1m / (1m + (decimal)expNegLogit);
        }

        /// <summary>
        /// Maps a probability to its percentile in [0, 1] within the ascending threshold array,
        /// linearly interpolating between neighbouring entries. Values outside the array clamp to 0 or 1.
        /// </summary>
        private decimal GetProbabilityPercentile(decimal probability)
        {
            // Fewer than two points cannot be interpolated (and Length - 1 would be 0 below).
            if (_thresholdArr == null || _thresholdArr.Length < 2)
            {
                InitializeDefaultThresholdArray();
            }
            int lastIndex = _thresholdArr.Length - 1;
            int index = Array.BinarySearch(_thresholdArr, probability);
            if (index >= 0)
            {
                return (decimal)index / lastIndex;
            }

            // No exact match: ~index is the position of the first element greater than probability.
            int insertPoint = ~index;
            if (insertPoint == 0)
            {
                return 0m; // Probability is lower than all values in the array
            }
            if (insertPoint >= _thresholdArr.Length)
            {
                return 1m; // Probability is higher than all values in the array
            }

            // Here lowerProb < probability < upperProb strictly, so the ratio's divisor is non-zero.
            decimal lowerProb = _thresholdArr[insertPoint - 1];
            decimal upperProb = _thresholdArr[insertPoint];
            decimal lowerPct = (decimal)(insertPoint - 1) / lastIndex;
            decimal upperPct = (decimal)insertPoint / lastIndex;
            decimal ratio = (probability - lowerProb) / (upperProb - lowerProb);
            return lowerPct + ratio * (upperPct - lowerPct);
        }
    }
}