package retry

import (
	"context"
	"errors"
	"fmt"
	"math"
	"os"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

// TestDoWithDataAllFailed checks that when every attempt fails with the
// default attempt count, the retrier returns the zero value (not the value the
// function produced), aggregates all ten failures into the returned error, and
// invokes OnRetry with indices summing to 45.
func TestDoWithDataAllFailed(t *testing.T) {
	var retrySum uint
	v, err := NewWithData[int](
		OnRetry(func(n uint, err error) { retrySum += n }),
		Delay(time.Nanosecond),
	).Do(
		func() (int, error) { return 7, errors.New("test") },
	)
	assert.Error(t, err)
	assert.Equal(t, 0, v)

	expectedErrorFormat := `All attempts fail:
#1: test
#2: test
#3: test
#4: test
#5: test
#6: test
#7: test
#8: test
#9: test
#10: test`
	assert.Len(t, err, 10)
	assert.Equal(t, expectedErrorFormat, err.Error(), "retry error format")
	assert.Equal(t, uint(45), retrySum, "right count of retry")
}

// TestDoFirstOk verifies that a function succeeding on its first call yields
// no error and leaves nothing accumulated by the OnRetry callback.
func TestDoFirstOk(t *testing.T) {
	var total uint
	retrier := New(
		OnRetry(func(attempt uint, _ error) { total += attempt }),
	)

	err := retrier.Do(func() error { return nil })

	assert.NoError(t, err)
	assert.Equal(t, uint(0), total, "no retry")
}

// TestDoWithDataFirstOk verifies that the data-returning retrier hands back the
// function's value when the very first call succeeds, with nothing accumulated
// by OnRetry.
func TestDoWithDataFirstOk(t *testing.T) {
	const want = 1

	var total uint
	retrier := NewWithData[int](
		OnRetry(func(attempt uint, _ error) { total += attempt }),
	)

	got, err := retrier.Do(func() (int, error) { return want, nil })

	assert.NoError(t, err)
	assert.Equal(t, want, got)
	assert.Equal(t, uint(0), total, "no retry")
}

// TestRetryIf checks that once the RetryIf predicate rejects an error the loop
// stops, and that the rejected error is still recorded in the aggregated error.
func TestRetryIf(t *testing.T) {
	var retryCount uint

	isRetryable := func(err error) bool { return err.Error() != "special" }
	attempt := func() error {
		// Switch to the non-retryable error after OnRetry has fired twice.
		if retryCount < 2 {
			return errors.New("test")
		}
		return errors.New("special")
	}

	err := New(
		OnRetry(func(uint, error) { retryCount++ }),
		RetryIf(isRetryable),
		Delay(time.Nanosecond),
	).Do(attempt)
	assert.Error(t, err)

	const expectedErrorFormat = `All attempts fail:
#1: test
#2: test
#3: special`
	assert.Len(t, err, 3)
	assert.Equal(t, expectedErrorFormat, err.Error(), "retry error format")
	assert.Equal(t, uint(2), retryCount, "right count of retry")
}

// TestRetryIf_ZeroAttempts checks that with unlimited attempts (Attempts(0)) a
// RetryIf rejection ends the loop and the rejected error is returned on its own
// rather than inside an aggregated error list. It also relates the last index
// passed to OnRetry to the number of failing calls.
func TestRetryIf_ZeroAttempts(t *testing.T) {
	var calls, lastRetryN uint

	err := New(
		OnRetry(func(n uint, _ error) { lastRetryN = n }),
		RetryIf(func(err error) bool { return err.Error() != "special" }),
		Delay(time.Nanosecond),
		Attempts(0),
	).Do(func() error {
		if calls < 2 {
			calls++
			return errors.New("test")
		}
		return errors.New("special")
	})
	assert.Error(t, err)

	assert.Equal(t, "special", err.Error(), "retry error format")
	assert.Equal(t, calls, lastRetryN+1, "right count of retry")
}

// TestZeroAttemptsWithError checks that Attempts(0) keeps retrying for as long
// as the function fails (here maxErrors times) and returns nil once it finally
// succeeds.
func TestZeroAttemptsWithError(t *testing.T) {
	const maxErrors = 999
	count := 0

	err := New(
		Attempts(0),
		// Cap every wait at 1ns so ~1000 retries finish quickly.
		MaxDelay(time.Nanosecond),
	).Do(
		func() error {
			if count < maxErrors {
				count++
				return errors.New("test")
			}

			return nil
		},
	)
	assert.NoError(t, err)

	// testify's argument order is (expected, actual).
	assert.Equal(t, maxErrors, count)
}

// TestZeroAttemptsWithoutError checks that with unlimited attempts a function
// that succeeds immediately is called exactly once.
func TestZeroAttemptsWithoutError(t *testing.T) {
	count := 0

	err := New(
		Attempts(0),
	).Do(
		func() error {
			count++

			return nil
		},
	)
	assert.NoError(t, err)

	// testify's argument order is (expected, actual).
	assert.Equal(t, 1, count)
}

// TestZeroAttemptsWithUnrecoverableError checks that even with unlimited
// attempts an Unrecoverable error is handed back unchanged instead of being
// retried forever.
func TestZeroAttemptsWithUnrecoverableError(t *testing.T) {
	fail := func() error { return Unrecoverable(assert.AnError) }

	err := New(Attempts(0), MaxDelay(time.Nanosecond)).Do(fail)

	assert.Error(t, err)
	assert.Equal(t, Unrecoverable(assert.AnError), err)
}

func TestAttemptsForError(t *testing.T) {
	count := uint(0)
	testErr := os.ErrInvalid
	attemptsForTestError := uint(3)
	err := New(
		AttemptsForError(attemptsForTestError, testErr),
		Attempts(5),
	).Do(
		func() error {
			count++
			return testErr
		},
	)
	assert.Error(t, err)
	assert.Equal(t, attemptsForTestError, count)
}

// TestDefaultSleep checks that with the default delay configuration three
// failing attempts take more than 300ms overall.
func TestDefaultSleep(t *testing.T) {
	failing := func() error { return errors.New("test") }

	began := time.Now()
	err := New(Attempts(3)).Do(failing)
	elapsed := time.Since(began)

	assert.Error(t, err)
	assert.Greater(t, elapsed, 300*time.Millisecond, "3 times default retry is longer then 300ms")
}

// TestFixedSleep checks that with DelayType(FixedDelay) three failing attempts
// finish in under 500ms overall.
func TestFixedSleep(t *testing.T) {
	failing := func() error { return errors.New("test") }

	began := time.Now()
	err := New(Attempts(3), DelayType(FixedDelay)).Do(failing)
	elapsed := time.Since(began)

	assert.Error(t, err)
	assert.Less(t, elapsed, 500*time.Millisecond, "3 times default retry is shorter then 500ms")
}

// TestLastErrorOnly checks that LastErrorOnly(true) makes Do return only the
// final attempt's error. Each attempt reports how many OnRetry calls preceded
// it, so the last of the ten default attempts reports "9".
func TestLastErrorOnly(t *testing.T) {
	var retries uint

	err := New(
		OnRetry(func(uint, error) { retries++ }),
		Delay(time.Nanosecond),
		LastErrorOnly(true),
	).Do(func() error {
		return fmt.Errorf("%d", retries)
	})

	assert.Error(t, err)
	assert.Equal(t, "9", err.Error())
}

// TestUnrecoverableError checks that an Unrecoverable error stops retrying
// after the first call and that the returned Error holds the underlying cause
// (not the Unrecoverable wrapper).
func TestUnrecoverableError(t *testing.T) {
	cause := errors.New("error")
	calls := 0

	err := New(Attempts(2)).Do(func() error {
		calls++
		return Unrecoverable(cause)
	})

	assert.Equal(t, Error{cause}, err)
	assert.Nil(t, errors.Unwrap(err))
	assert.Equal(t, 1, calls, "unrecoverable error broke the loop")
}

// TestCombineFixedDelays checks that combining two FixedDelay strategies adds
// their waits, so three failing attempts land between 400ms and 500ms overall.
// It is skipped on the macOS CI runner, whose timing is too erratic.
func TestCombineFixedDelays(t *testing.T) {
	if os.Getenv("OS") == "macos-latest" {
		t.Skip("Skipping testing in MacOS GitHub actions - too slow, duration is wrong")
	}

	failing := func() error { return errors.New("test") }
	combined := CombineDelay(FixedDelay, FixedDelay)

	began := time.Now()
	err := New(Attempts(3), DelayType(combined)).Do(failing)
	elapsed := time.Since(began)

	assert.Error(t, err)
	assert.Greater(t, elapsed, 400*time.Millisecond, "3 times combined, fixed retry is greater then 400ms")
	assert.Less(t, elapsed, 500*time.Millisecond, "3 times combined, fixed retry is less then 500ms")
}

// TestRandomDelay checks that RandomDelay with a 50ms MaxJitter keeps three
// failing attempts between 2ms and 150ms overall. It is skipped on the macOS
// CI runner, whose timing is too erratic.
func TestRandomDelay(t *testing.T) {
	if os.Getenv("OS") == "macos-latest" {
		t.Skip("Skipping testing in MacOS GitHub actions - too slow, duration is wrong")
	}

	failing := func() error { return errors.New("test") }

	began := time.Now()
	err := New(
		Attempts(3),
		DelayType(RandomDelay),
		MaxJitter(50*time.Millisecond),
	).Do(failing)
	elapsed := time.Since(began)

	assert.Error(t, err)
	assert.Greater(t, elapsed, 2*time.Millisecond, "3 times random retry is longer then 2ms")
	assert.Less(t, elapsed, 150*time.Millisecond, "3 times random retry is shorter then 150ms")
}

// TestMaxDelay checks that MaxDelay caps the growing default delay, keeping
// five failing attempts (10ms base, 50ms cap) between 120ms and 275ms overall.
// It is skipped on the macOS CI runner, whose timing is too erratic.
func TestMaxDelay(t *testing.T) {
	if os.Getenv("OS") == "macos-latest" {
		t.Skip("Skipping testing in MacOS GitHub actions - too slow, duration is wrong")
	}

	start := time.Now()
	err := New(
		Attempts(5),
		Delay(10*time.Millisecond),
		MaxDelay(50*time.Millisecond),
	).Do(
		func() error { return errors.New("test") },
	)
	dur := time.Since(start)
	assert.Error(t, err)
	// The messages describe the expected condition, matching the assertion.
	assert.Greater(t, dur, 120*time.Millisecond, "5 times with maximum delay retry is longer than 120ms")
	assert.Less(t, dur, 275*time.Millisecond, "5 times with maximum delay retry is shorter than 275ms")
}

func TestBackOffDelay(t *testing.T) {
	for _, c := range []struct {
		label         string
		delay         time.Duration
		expectedMaxN  uint
		n             uint
		expectedDelay time.Duration
	}{
		{
			label:         "negative-delay",
			delay:         -1,
			expectedMaxN:  62,
			n:             2,
			expectedDelay: 0,
		},
		{
			label:         "zero-delay",
			delay:         0,
			expectedMaxN:  62,
			n:             65,
			expectedDelay: 0,
		},
		{
			label:         "one-second",
			delay:         time.Second,
			expectedMaxN:  33,
			n:             62,
			expectedDelay: time.Second << 33,
		},
		{
			label:         "one-second-n",
			delay:         time.Second,
			expectedMaxN:  33,
			n:             1,
			expectedDelay: time.Second,
		},
	} {
		t.Run(
			c.label,
			func(t *testing.T) {
				retrier := New(Delay(c.delay))
				delay := BackOffDelay(c.n, nil, retrier)
				assert.Equal(t, c.expectedMaxN, retrier.maxBackOffN, "max n mismatch")
				assert.Equal(t, c.expectedDelay, delay, "delay duration mismatch")
			},
		)
	}
}

// TestCombineDelay checks that CombineDelay adds the durations produced by its
// member functions, returns 0 when given none, accepts negative members, and
// saturates at the largest time.Duration instead of overflowing.
func TestCombineDelay(t *testing.T) {
	// constant builds a DelayTypeFunc that ignores its inputs and returns d.
	constant := func(d time.Duration) DelayTypeFunc {
		return func(uint, error, DelayContext) time.Duration { return d }
	}
	const maxDuration = time.Duration(1<<63 - 1)

	cases := []struct {
		label    string
		delays   []time.Duration
		expected time.Duration
	}{
		{label: "empty"},
		{
			label:    "single",
			delays:   []time.Duration{time.Second},
			expected: time.Second,
		},
		{
			label:    "negative",
			delays:   []time.Duration{time.Second, -time.Millisecond},
			expected: time.Second - time.Millisecond,
		},
		{
			label:    "overflow",
			delays:   []time.Duration{maxDuration, time.Second, time.Millisecond},
			expected: maxDuration,
		},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			funcs := make([]DelayTypeFunc, 0, len(tc.delays))
			for _, d := range tc.delays {
				funcs = append(funcs, constant(d))
			}
			got := CombineDelay(funcs...)(0, nil, nil)
			assert.Equal(t, tc.expected, got, "delay duration mismatch")
		})
	}
}

// TestContext covers how Do reacts to a context supplied via Context(ctx):
// cancellation before the first attempt, cancellation triggered from OnRetry
// with a finite and with an unlimited attempt count, and the
// WrapContextErrorWithLastError option, which makes the returned error match
// both the context error and the last error from the retried function.
func TestContext(t *testing.T) {
	const defaultDelay = 100 * time.Millisecond
	t.Run("cancel before", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		cancel()

		retrySum := 0
		start := time.Now()
		err := New(
			OnRetry(func(n uint, err error) { retrySum += 1 }),
			Context(ctx),
		).Do(
			func() error { return errors.New("test") },
		)
		dur := time.Since(start)
		assert.Error(t, err)
		assert.True(t, dur < defaultDelay, "immediately cancellation")
		assert.Equal(t, 0, retrySum, "OnRetry never called")
	})

	t.Run("cancel in retry progress", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		retrySum := 0
		err := New(
			OnRetry(func(n uint, err error) {
				retrySum += 1
				if retrySum > 1 {
					cancel()
				}
			}),
			Context(ctx),
		).Do(
			func() error { return errors.New("test") },
		)
		assert.Error(t, err)

		expectedErrorFormat := `All attempts fail:
#1: test
#2: test
#3: context canceled`
		assert.Len(t, err, 3)
		assert.Equal(t, expectedErrorFormat, err.Error(), "retry error format")
		assert.Equal(t, 2, retrySum, "OnRetry called exactly twice")
	})

	t.Run("cancel in retry progress - last error only", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		retrySum := 0
		err := New(
			OnRetry(func(n uint, err error) {
				retrySum += 1
				if retrySum > 1 {
					cancel()
				}
			}),
			Context(ctx),
			LastErrorOnly(true),
		).Do(
			func() error { return errors.New("test") },
		)
		assert.Equal(t, context.Canceled, err)

		assert.Equal(t, 2, retrySum, "OnRetry called exactly twice")
	})

	// This subtest previously ran inside an un-awaited goroutine, so the
	// subtest finished before Do returned and its assertions never affected
	// the result. Running it synchronously makes them effective.
	t.Run("cancel in retry progress - infinite attempts", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		retrySum := 0
		err := New(
			OnRetry(func(n uint, err error) {
				retrySum += 1
				if retrySum > 1 {
					cancel()
				}
			}),
			Context(ctx),
			Attempts(0),
		).Do(
			func() error { return errors.New("test") },
		)

		assert.Equal(t, context.Canceled, err)

		assert.Equal(t, 2, retrySum, "OnRetry called exactly twice")
	})

	t.Run("cancelled on retry infinte attempts - wraps context error with last retried function error", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		retrySum := 0
		err := New(
			OnRetry(func(n uint, err error) {
				retrySum += 1
				if retrySum == 2 {
					cancel()
				}
			}),
			Context(ctx),
			Attempts(0),
			WrapContextErrorWithLastError(true),
		).Do(
			func() error { return fooErr{str: fmt.Sprintf("error %d", retrySum+1)} },
		)
		assert.ErrorIs(t, err, context.Canceled)
		assert.ErrorIs(t, err, fooErr{str: "error 2"})
	})

	t.Run("timed out on retry infinte attempts - wraps context error with last retried function error", func(t *testing.T) {
		ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200)
		defer cancel()

		retrySum := 0
		err := New(
			OnRetry(func(n uint, err error) {
				retrySum += 1
			}),
			Context(ctx),
			Attempts(0),
			WrapContextErrorWithLastError(true),
		).Do(
			func() error { return fooErr{str: fmt.Sprintf("error %d", retrySum+1)} },
		)
		assert.ErrorIs(t, err, context.DeadlineExceeded)
		assert.ErrorIs(t, err, fooErr{str: "error 2"})
	})
}

// testTimer delegates to time.After and records that it was consulted, which
// lets tests observe whether a timer passed via WithTimer is actually used.
type testTimer struct {
	called bool // true once After has been invoked at least once
}

// After flags the timer as used and returns a channel that fires after d.
func (tm *testTimer) After(d time.Duration) <-chan time.Time {
	ch := time.After(d)
	tm.called = true
	return ch
}

// TestTimerInterface checks that a Timer supplied through WithTimer is the one
// used to wait between attempts. Two attempts are required: the retrier does
// not wait after the final attempt, so with a single attempt the custom timer
// would never be consulted and the test would prove nothing.
func TestTimerInterface(t *testing.T) {
	var timer testTimer
	err := New(
		Attempts(2),
		Delay(10*time.Millisecond),
		MaxDelay(50*time.Millisecond),
		WithTimer(&timer),
	).Do(
		func() error { return errors.New("test") },
	)

	assert.Error(t, err)
	assert.True(t, timer.called, "custom timer was not used for the retry delay")
}

func TestErrorIs(t *testing.T) {
	var e Error
	expectErr := errors.New("error")
	closedErr := os.ErrClosed
	e = append(e, expectErr)
	e = append(e, closedErr)

	assert.True(t, errors.Is(e, expectErr))
	assert.True(t, errors.Is(e, closedErr))
	assert.False(t, errors.Is(e, errors.New("error")))
}

// fooErr is a test-only error type whose message is exactly its str field.
type fooErr struct {
	str string
}

// Error returns the stored message.
func (f fooErr) Error() string { return f.str }

// barErr is a second test-only error type, distinct from fooErr, whose message
// is exactly its str field.
type barErr struct {
	str string
}

// Error returns the stored message.
func (b barErr) Error() string { return b.str }

// TestErrorAs checks that errors.As extracts a matching error type from an
// Error list and reports false for a type the list does not contain.
func TestErrorAs(t *testing.T) {
	e := Error{fooErr{str: "foo"}}

	var asFoo fooErr
	var asBar barErr

	assert.True(t, errors.As(e, &asFoo))
	assert.False(t, errors.As(e, &asBar))
	assert.Equal(t, "foo", asFoo.str)
}

func TestUnwrap(t *testing.T) {
	testError := errors.New("test error")
	err := New(
		Attempts(1),
	).Do(
		func() error {
			return testError
		},
	)

	assert.Error(t, err)
	assert.Nil(t, errors.Unwrap(err))

	// Check for Go 1.20 Unwrap() []error
	type unwrapper interface {
		Unwrap() []error
	}
	u, ok := err.(unwrapper)
	assert.True(t, ok)
	assert.Equal(t, []error{testError}, u.Unwrap())
}

func TestWrappedErrors(t *testing.T) {
	testError := errors.New("test error")
	err := New(
		Attempts(1),
	).Do(
		func() error {
			return testError
		},
	)

	assert.Error(t, err)

	type wrapper interface {
		WrappedErrors() []error
	}
	w, ok := err.(wrapper)
	assert.True(t, ok)
	assert.Equal(t, []error{testError}, w.WrappedErrors())
}

// TestErrorsUnwrapReturnsNil documents how the standard errors helpers treat a
// retry Error. errors.Unwrap only follows the single-error Unwrap() error form,
// so it returns nil for an Error, just as it does for errors.Join results; see
// https://pkg.go.dev/errors#Unwrap. errors.Is and errors.As, by contrast, walk
// the Unwrap() []error tree and still reach the wrapped errors.
func TestErrorsUnwrapReturnsNil(t *testing.T) {
	first := errors.New("error 1")
	second := errors.New("error 2")
	joined := Error{first, second}

	assert.Nil(t, errors.Unwrap(joined), "errors.Unwrap should return nil for retry.Error")

	assert.True(t, errors.Is(joined, first), "errors.Is should work with first error")
	assert.True(t, errors.Is(joined, second), "errors.Is should work with second error")

	// errors.As must also reach a pointer-typed error held in the list.
	withPtr := Error{&fooErr{str: "custom"}}
	var target *fooErr
	assert.True(t, errors.As(withPtr, &target), "errors.As should work")
	assert.Equal(t, "custom", target.str)
}

// BenchmarkDo_ImmediateSuccess measures the per-call overhead of Do when the
// retried function succeeds on its first call.
func BenchmarkDo_ImmediateSuccess(b *testing.B) {
	retrier := New(Attempts(10), Delay(0))
	succeed := func() error { return nil }

	b.ReportAllocs()
	// Keep retrier construction out of the measured region.
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_ = retrier.Do(succeed)
	}
}

// BenchmarkDoWithData_ImmediateSuccess measures the per-call overhead of the
// data-returning retrier when the function succeeds on its first call.
func BenchmarkDoWithData_ImmediateSuccess(b *testing.B) {
	retrier := NewWithData[int](Attempts(10), Delay(0))
	succeed := func() (int, error) { return 0, nil }

	b.ReportAllocs()
	// Keep retrier construction out of the measured region.
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, _ = retrier.Do(succeed)
	}
}

// BenchmarkDo_OneRetry measures Do when the function fails once and then
// succeeds. The shared counter alternates odd/even across calls, so every Do
// sees exactly one failure followed by one success.
//
// NOTE(review): Delay(0) only sets the base delay; if the default delay type
// adds random jitter, the wait between the two calls is included in the
// timing. Confirm against the default DelayType.
func BenchmarkDo_OneRetry(b *testing.B) {
	counter := 0
	retryOnceFunc := func() error {
		counter++
		if counter%2 == 1 {
			return errors.New("temporary error")
		}
		return nil
	}
	retrier := New(Attempts(10), Delay(0))

	b.ReportAllocs()
	// Keep retrier construction out of the measured region.
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_ = retrier.Do(retryOnceFunc)
	}
}

// TestIsRecoverable checks that plain errors are recoverable, and that errors
// marked Unrecoverable stay unrecoverable even after being wrapped with %w.
func TestIsRecoverable(t *testing.T) {
	base := errors.New("err")
	assert.True(t, IsRecoverable(base))

	marked := Unrecoverable(base)
	assert.False(t, IsRecoverable(marked))

	wrapped := fmt.Errorf("wrapping: %w", marked)
	assert.False(t, IsRecoverable(wrapped))
}

// TestFullJitterBackoffDelay checks that FullJitterBackoffDelay always returns
// a delay in [0, ceiling]. The ceiling is Delay * 2^n, capped by MaxDelay when
// one is configured. The delay itself is random, so only these bounds are
// asserted. No seeding is done, matching the other randomized tests.
func TestFullJitterBackoffDelay(t *testing.T) {
	testErr := errors.New("test error")

	// backoffCeiling returns base*2^n, capped at limit when limit > 0. Results
	// beyond the int64 range saturate at the largest time.Duration so the
	// float-to-Duration conversion cannot overflow.
	backoffCeiling := func(base, limit time.Duration, n uint) time.Duration {
		c := float64(base) * math.Pow(2, float64(n))
		if limit > 0 && c > float64(limit) {
			return limit
		}
		if c >= float64(math.MaxInt64) {
			return time.Duration(math.MaxInt64)
		}
		return time.Duration(c)
	}

	baseDelay := 50 * time.Millisecond
	maxDelay := 500 * time.Millisecond

	config := New(Delay(baseDelay), MaxDelay(maxDelay))
	// Same base delay with MaxDelay disabled; built once because it does not
	// depend on the attempt number.
	configNoMax := New(Delay(baseDelay))

	for _, n := range []uint{0, 1, 2, 3, 4, 5, 6, 10} {
		ceiling := backoffCeiling(baseDelay, maxDelay, n)
		delay := FullJitterBackoffDelay(n, testErr, config)

		assert.True(t, delay >= 0, "Delay should be non-negative. Got: %v for attempt %d", delay, n)
		assert.True(t, delay <= ceiling,
			"Delay %v should be less than or equal to current backoff ceiling %v for attempt %d", delay, ceiling, n)

		t.Logf("Attempt %d: BaseDelay=%v, MaxDelay=%v, Calculated Ceiling=~%v, Actual Delay=%v",
			n, baseDelay, maxDelay, ceiling, delay)

		ceilingNoMax := backoffCeiling(baseDelay, 0, n)
		delayNoMax := FullJitterBackoffDelay(n, testErr, configNoMax)
		assert.True(t, delayNoMax >= 0, "Delay (no max) should be non-negative. Got: %v for attempt %d", delayNoMax, n)
		assert.True(t, delayNoMax <= ceilingNoMax,
			"Delay (no max) %v should be less than or equal to current backoff ceiling %v for attempt %d", delayNoMax, ceilingNoMax, n)
	}

	// A zero base delay must give a zero delay regardless of the attempt.
	configZeroBase := New(Delay(0), MaxDelay(maxDelay))
	delayZeroBase := FullJitterBackoffDelay(0, testErr, configZeroBase)
	assert.Equal(t, time.Duration(0), delayZeroBase, "Delay with zero base delay should be 0")

	delayZeroBaseAttempt1 := FullJitterBackoffDelay(1, testErr, configZeroBase)
	assert.Equal(t, time.Duration(0), delayZeroBaseAttempt1, "Delay with zero base delay (attempt > 0) should be 0")

	// A 1ns base delay with a 100ns cap must still respect the ceiling.
	smallBaseDelay := time.Nanosecond
	smallMaxDelay := 100 * time.Nanosecond
	configSmallBase := New(Delay(smallBaseDelay), MaxDelay(smallMaxDelay))
	for i := uint(0); i < 5; i++ {
		d := FullJitterBackoffDelay(i, testErr, configSmallBase)
		ceiling := backoffCeiling(smallBaseDelay, smallMaxDelay, i)
		assert.True(t, d <= ceiling, "Delay %v should not exceed ceiling %v for attempt %d", d, ceiling, i)
	}
}
