James-Stein Estimate

  • Plots
  • James-Stein Estimate



Stein's Paradox concerns the use of observed averages to estimate unobserved quantities.
It relies on a technique called shrinkage, which pulls all the individual averages towards the grand mean.

If an individual mean is higher than the grand mean, its estimate is pulled down; if it is lower than the grand mean, its estimate is pulled up.
You can observe this in the application. The resulting shrunken value is denoted by z.

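z is calculated with the following equation:
z = gm + c(im - gm),
where gm is the grand mean, im is the individual mean, and c is the shrinkage factor: a single value shared by all groups, between 0 and 1, where a smaller c pulls each estimate closer to the grand mean.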

In most cases z is a better estimator than the individual mean: when three or more means are estimated together, its expected total squared error is lower, as can be observed in the present app.
Find more information on Wikipedia
    show with app
    • app.R
    • begin2.R
    # Martin Kolnes, 
    # Margaret Laurie, 
    # Konrad Lehmann, 
    # Diana Orghian, 
    # Francesca Serra 
    # Don van den Bergh
    
    library(shiny)
    library(plotly) # unused
    library(shinythemes)
    source('begin2.R')
    
    # User interface: a sidebar with the simulation settings next to the
    # shrinkage table and plots, plus a tab explaining the James-Stein estimate.
    ui <- fluidPage(theme = shinytheme("journal"),
      includeScript("../../../Matomo-tquant.js"),
      headerPanel("James-Stein Estimate"),

      # NOTE(review): the tabsetPanel is passed as navbarPage()'s first
      # argument (its `title`), so the navbar itself carries no tabs; kept to
      # preserve the existing layout.
      navbarPage(
        tabsetPanel(type = "tabs",
          tabPanel("Plots",
            sidebarPanel(
              numericInput(inputId = "nBetween", label = "Input the number of samples (groups of data, e.g. subjects) you have observed:",
                           min = 1, max = 1e3, value = 50),
              sliderInput(inputId = "grandMean", label = "Choose the grand mean of your data:",
                          min = 5, max = 100, value = 50),
              # The server passes grandVar and groupDistVar to genData() as
              # `sd`, so the labels call them standard deviations.
              sliderInput(inputId = "grandVar", label = "Choose the standard deviation of the samples:",
                          min = .1, max = 5, value = 1),
              sliderInput(inputId = "nWithin", label = "Choose the number of observations within each sample:",
                          min = 5, max = 100, value = 50),
              sliderInput(inputId = "groupDistVar", label = "Choose the standard deviation of the observations:",
                          min = .1, max = 5, value = 1)
            ),
            mainPanel(
              tableOutput("table"),
              br(),
              plotOutput("plot"),
              br(),
              plotOutput("plot2")
            )
          ),
          tabPanel("James-Stein Estimate", br(),
            "Stein's Paradox concerns the use of observed averages to estimate unobserved quantities.", br(),
            "It relies on a method named 'Shrinking' that pulls all the individual averages towards the grand mean.", br(), br(),
            "If the individual mean is higher than the grand mean the estimation will shrink, and the estimate will increase if it is lower than the grand mean.", br(),
            "You can observe this using the application. The resulting shrunken value is designated by z.", br(), br(),
            # Matches JS.est(), which computes yh + c * (y - yh).
            "z is calculated by the following equation:", br(), "z = gm + c(im - gm),", br(), "where gm is the grand mean, the im is the individual mean and c a constant (shrinking factor).", br(), br(),
            "For the most of the cases z is better estimator than the individual mean as it can be observed in the present app.", br(), br(),
            # Kept inside the tab: tabsetPanel() expects only tabPanel() children.
            a("Find more information on Wikipedia", href = "https://en.wikipedia.org/wiki/James-Stein_estimator")
          )
        )
      )
    )
    
    # Server: simulates grouped data from the current inputs, computes the
    # James-Stein estimates and renders the shrinkage table and the two plots.
    server <- function(input, output) {

      # All inputs the simulation depends on, bundled in one reactive.
      allinput <- reactive(list(nBetween = input$nBetween,
                                nWithin = input$nWithin,
                                groupDistVar = input$groupDistVar,
                                grandMean = input$grandMean,
                                grandVar = input$grandVar))

      # Simulate once per input change and share the result with every output,
      # instead of regenerating the same data in each render function.
      # set.seed(1) keeps the draw reproducible for a given set of inputs.
      simulated <- reactive({
        p <- allinput()
        set.seed(1)
        sim <- genData(nBetween = p$nBetween,
                       nWithin = p$nWithin,
                       args1 = list(mean = p$grandMean, sd = p$grandVar),
                       args2 = list(sd = p$groupDistVar))
        groups <- sim$data
        list(means = unlist(lapply(groups, mean)),
             sds = unlist(lapply(groups, sd)),
             JS = JS.est(groups))
      })

      # One-row table: shrinkage factor c and the implied shrinkage in
      # percent, (1 - c) * 100.
      output$table <- renderTable({
        shrinkFactor <- round(simulated()$JS$c, digits = 4)
        tab <- data.frame(shrinkFactor, (1 - shrinkFactor) * 100)
        colnames(tab) <- c("Shrinkage Factor", "Shrink Percent")
        # Return the table itself; print() only echoed it to the console.
        tab
      }, rownames = FALSE)

      # Each group's raw mean (top, 'Mean') joined to its JS estimate
      # (bottom, 'JS').
      output$plot <- renderPlot({
        par(las = 1, bty = 'n')
        sim <- simulated()
        shrinkagePlot(JS = sim$JS, Mu = sim$means, pch = 19, lty = 1, yaxt = 'n',
                      type = 'b', ylab = '', xlab = 'Estimate', main = "",
                      col = rainbow(length(sim$means)))
        axis(2, at = 0:1, labels = c('JS', 'Mean'))
      })

      # Side-by-side densities: one normal curve per group centred on its raw
      # mean (left) or its JS estimate (right), plus the distribution the true
      # group means were drawn from (thick black line).
      output$plot2 <- renderPlot({
        par(las = 1, bty = 'n', mfrow = 1:2)
        sim <- simulated()
        p <- allinput()
        nGroups <- length(sim$means)
        col <- c(rainbow(nGroups), 'black')
        lwd <- c(rep(1, nGroups), 2.5)
        # The true means are drawn with rnorm(mean = grandMean, sd = grandVar)
        # in genData(), so the total distribution uses grandVar as its sd
        # rather than a fixed 1 that ignored the slider.
        densityPlot(mu = sim$means, sd = sim$sds,
                    grandMean = p$grandMean, grandSd = p$grandVar,
                    type = 'l', lty = 1, col = col, lwd = lwd,
                    main = "Each Group Mean Estimate")
        densityPlot(mu = sim$JS$z, sd = sim$sds,
                    grandMean = p$grandMean, grandSd = p$grandVar,
                    type = 'l', lty = 1, col = col, lwd = lwd,
                    main = "Each Group's JS Estimate")
        legend('topright', legend = 'Total Distribution', lwd = 2.5, col = 'black', bty = 'n')
      })
    }

    shinyApp(ui = ui, server = server)
    # Simulate grouped data: draw one true mean per group with `fun1`, then
    # draw `nWithin` observations around each true mean with `fun2`.
    #
    # nBetween      number of groups (0 gives an empty list of groups).
    # nWithin       number of observations per group.
    # fun1, fun2    names of the random generators (default 'rnorm'); `fun2`
    #               is called with `n` and `mean`, so it must accept both.
    # args1, args2  extra arguments for fun1 / fun2; names the generator's
    #               formals do not contain are dropped.
    # Returns list(data = list of nBetween numeric vectors, mu = true means).
    genData <- function(nBetween, nWithin, fun1 = 'rnorm', fun2 = 'rnorm',
                        args1 = list(mean = 0, sd = 1), args2 = list(sd = 5)) {
      # keep only the arguments each generator actually accepts
      args11 <- args1[names(args1) %in% names(formals(fun1))]
      args22 <- args2[names(args2) %in% names(formals(fun2))]
      mu <- do.call(fun1, c(list(n = nBetween), args11))
      # seq_len() yields no iterations for nBetween == 0 (1:length() would run
      # i = 1, 0 and fail); groups are still drawn in order 1..nBetween, so the
      # random stream is consumed exactly as before.
      data.list <- lapply(seq_len(nBetween), function(i) {
        do.call(fun2, c(list(n = nWithin, mean = mu[i]), args22))
      })
      list(data = data.list, mu = mu)
    }
    
    # James-Stein estimate of a set of group means.
    #
    #   z = yh + c * (y - yh),   c = 1 - (k - 3) * sigma2 / sum((y - yh)^2)
    #
    # y are the group means, yh the mean of all observations pooled, k the
    # number of groups and sigma2 the sampling variance of one group mean
    # (within-group variance / group size, averaged over groups).
    #
    # x  list of numeric vectors, one per group.
    # Returns list(z = shrunken group means, c = shrinkage factor in [0, 1]).
    JS.est <- function(x) {
      y <- vapply(x, mean, numeric(1))   # group means
      yh <- mean(unlist(x))              # grand mean of all observations
      k <- length(x)                     # number of means to estimate
      # sigma2 must be the noise variance of each observed mean. Using the
      # spread *between* means, var(y), would make c = 2 / (k - 1) for equal
      # group sizes -- a constant that ignores the data entirely.
      sigma2 <- mean(vapply(x, var, numeric(1)) / lengths(x), na.rm = TRUE)
      if (is.na(sigma2)) {
        sigma2 <- 0                      # e.g. every group has one observation
      }
      shrink <- if (k < 3) {
        1                                # formula would expand (c > 1); leave means as-is
      } else {
        1 - (k - 3) * sigma2 / sum((y - yh)^2)
      }
      if (is.nan(shrink)) {
        shrink <- 1                      # 0 / 0: no spread and no noise
      }
      # positive-part estimator: never shrink past the grand mean
      shrink <- max(0, shrink)
      list(z = yh + shrink * (y - yh), c = shrink)
    }
    
    # Plot each group's raw mean (drawn at y = 1) joined to its James-Stein
    # estimate (drawn at y = 0), one line per group, groups ordered by raw mean.
    #
    # JS      list whose element `z` holds the shrunken estimates (as returned
    #         by JS.est()).
    # Mu      raw group means, in the same order as JS$z.
    # plotly  FALSE (default): draw with base-graphics matplot(), forwarding
    #         `...`; TRUE: return a plotly figure (`...` is not used).
    shrinkagePlot <- function(JS, Mu, plotly = FALSE, ...) {
      # 2 x k matrix: row 1 = raw means, row 2 = JS estimates
      x <- rbind(Mu, JS$z)
      # reorder the columns (groups) by their raw mean (first row);
      # drop = FALSE keeps a matrix even for a single group
      x <- x[, order(x[1, ]), drop = FALSE]
      # y positions matching x: 1 for the raw mean, 0 for the JS estimate
      ypos <- matrix(c(1, 0), nrow = 2, ncol = ncol(x))
      if (!plotly) {
        matplot(x, ypos, ...)
      } else {
        # long format: two points (mean, JS) per group
        dat <- data.frame(x = c(x), y = c(ypos),
                          group = rep(seq_len(ncol(x)), each = 2))
        # plotly >= 4: columns are mapped with formulas and lines are scatter
        # traces; one coloured trace per group.
        plot_ly(dat, x = ~x, y = ~y, color = ~factor(group),
                colors = rainbow(ncol(x)), type = 'scatter',
                mode = 'lines+markers')
      }
    }
    
    # Overlay normal density curves, one per group, on a common x grid.
    #
    # mu         centres of the group curves.
    # sd         their standard deviations (default: 1 for every group).
    # grandMean  if supplied, a curve N(grandMean, grandSd) is added as the last
    #            column (grandSd defaults to 1).
    # x          evaluation grid; by default range(mu) widened on each side by
    #            max(9, 4 * largest sd) so wide curves are not cut off.
    # ...        forwarded to matplot().
    densityPlot <- function(mu, sd, grandMean, grandSd, x, ...) {
      if (missing(sd)) {
        sd <- rep(1, length(mu))
      }
      hasGrand <- !missing(grandMean)
      if (hasGrand && missing(grandSd)) {
        grandSd <- 1
      }
      if (missing(x)) {
        # at least the historical +/-9 padding, more when curves are wider
        pad <- max(9, 4 * sd, if (hasGrand) 4 * grandSd, na.rm = TRUE)
        r <- range(mu) + c(-pad, pad)
        x <- seq.int(r[1], r[2], .01)
      }
      # one column of densities per group
      y <- mapply(dnorm, mean = mu, sd = sd, MoreArgs = list(x = x))
      if (hasGrand) {
        y <- cbind(y, dnorm(x, grandMean, grandSd))
      }
      matplot(x, y, ...)
    }
    
    
    
    # Scratch check of the plotly line layout used by shrinkagePlot(): five
    # groups, each a segment from (i, 0) to (i + 0.1, 1).
    y <- rep(0:1, each = 5)
    x <- c(1:5, 1:5 + .1)
    dat <- data.frame(x = x, y = y, group = rep(1:5, 2))
    # plotly >= 4: columns are referenced with formulas, lines are 'scatter'
    # traces, and per-group traces come from `split` (not `group`).
    plot_ly(dat, x = ~x, y = ~y, split = ~group, type = 'scatter', mode = 'lines')
    
    
    
    # Scratch check: simulate 100 groups of 10 observations, then show the
    # shrinkage factor and the plotly version of the shrinkage plot.
    dat <- genData(100, 10)
    Mu <- vapply(dat$data, mean, numeric(1))
    JS <- JS.est(dat$data)

    JS$c

    shrinkagePlot(JS = JS, Mu = Mu, plotly = TRUE)