# Martin Kolnes,
# Margaret Laurie,
# Konrad Lehmann,
# Diana Orghian,
# Francesca Serra
# Don van den Bergh
library(shiny)
library(plotly) # unused
library(shinythemes)
source('begin2.R')
# User interface of the app.
# - "Plots" tab: simulation controls in the sidebar; the shrinkage table and
#   the two plots in the main panel.
# - "James-Stein Estimate" tab: a short explanation of the method.
# The formula shown to the user matches the one the estimator code computes,
# z = gm + c * (im - gm): the estimate starts at the grand mean and keeps only
# a fraction c of each mean's distance from it.
# NOTE(review): the tabsetPanel is passed as the first argument of navbarPage()
# and the Wikipedia link is a direct child of the tabsetPanel rather than of a
# tabPanel; that layout is kept unchanged - confirm it renders as intended.
ui <- fluidPage(
  theme = shinytheme("journal"),
  includeScript("../../../Matomo-tquant.js"),
  headerPanel("James-Stein Estimate"),
  fluidRow(),
  navbarPage( # "James-Stein Estimate", # theme = "journal",
    tabsetPanel(
      type = "tabs",
      tabPanel(
        "Plots",
        sidebarPanel(
          numericInput(
            inputId = "nBetween",
            label = "Input the number of samples (groups of data, e.g. subjects) you have observed:",
            min = 1, max = 1e3, value = 50
          ),
          sliderInput(
            inputId = "grandMean",
            label = "Choose the grand mean of your data:",
            min = 5, max = 100, value = 50
          ),
          sliderInput(
            inputId = "grandVar",
            label = "Choose the variance of the samples:",
            min = .1, max = 5, value = 1
          ),
          sliderInput(
            inputId = "nWithin",
            label = "Choose the number of observations within each sample:",
            min = 5, max = 100, value = 50
          ),
          sliderInput(
            inputId = "groupDistVar",
            label = "Choose the variance of the observations:",
            min = .1, max = 5, value = 1
          )
        ),
        mainPanel(
          tableOutput("table"),
          br(),
          plotOutput("plot"),
          br(),
          plotOutput("plot2")
        )
      ),
      tabPanel(
        "James-Stein Estimate", br(),
        "Stein's Paradox concerns the use of observed averages to estimate unobserved quantities.", br(),
        "It relies on a method named 'Shrinking' that pulls all the individual averages towards the grand mean.", br(), br(),
        "If the individual mean is higher than the grand mean the estimation will shrink, and the estimate will increase if it is lower than the grand mean.", br(),
        "You can observe this using the application. The resulting shrunken value is designated by z.", br(), br(),
        "z is calculated by the following equation:", br(),
        # Sign fixed: the estimate is the grand mean PLUS c times the deviation.
        "z = gm + c(im - gm),", br(),
        "where gm is the grand mean, the im is the individual mean and c a constant (shrinking factor).", br(), br(),
        "For the most of the cases z is better estimator than the individual mean as it can be observed in the present app."
      ),
      a("Find more information on Wikipedia", href = "https://en.wikipedia.org/wiki/James-Stein_estimator")
    )
  )
)
# Server logic: simulate grouped data from the current inputs and render the
# shrinkage table and both plots from that one simulated data set.
#
# input  - Shiny input list; reads nBetween, nWithin, grandMean, grandVar and
#          groupDistVar.
# output - Shiny output list; fills "table", "plot" and "plot2".
#
# genData(), JS.est(), shrinkagePlot() and densityPlot() are helper functions
# defined outside this block.
server <- function(input, output) {
  # A single simulation shared by every output, so the table and the plots are
  # guaranteed to describe the same data and the work is done once per change.
  simulated <- reactive({
    n_between <- input$nBetween
    # The numeric box can be emptied or set below 1; wait for a usable value
    # instead of failing inside genData().
    req(is.numeric(n_between), n_between >= 1)

    # Same seed before every simulation: identical inputs give identical data.
    set.seed(1)
    sim <- genData(
      nBetween = n_between,
      nWithin = input$nWithin,
      # The sliders ask for variances, while the `sd` arguments are standard
      # deviations, hence the square roots.
      args1 = list(mean = input$grandMean, sd = sqrt(input$grandVar)),
      args2 = list(sd = sqrt(input$groupDistVar))
    )
    groups <- sim$data
    list(
      means = unlist(lapply(groups, mean)),
      sds = unlist(lapply(groups, sd)),
      JS = JS.est(groups)
    )
  })

  # One-row table with the shrinkage factor and the corresponding percentage.
  output$table <- renderTable({
    shrink_factor <- round(simulated()$JS$c, digits = 4)
    shrink_table <- data.frame(shrink_factor, (1 - shrink_factor) * 100)
    colnames(shrink_table) <- c("Shrinkage Factor", "Shrink Percent")
    # Return the data frame (no console printing); row names are switched off
    # in renderTable() itself.
    shrink_table
  }, rownames = FALSE)

  # Lines connecting every group mean (top) to its James-Stein estimate (bottom).
  output$plot <- renderPlot({
    sim <- simulated()
    par(las = 1, bty = 'n')
    shrinkagePlot(JS = sim$JS, Mu = sim$means, pch = 19, lty = 1, yaxt = 'n', type = 'b',
                  ylab = '', xlab = 'Estimate', main = "",
                  col = rainbow(length(sim$means)))
    axis(2, at = 0:1, labels = c('JS', 'Mean'))
  })

  # Side-by-side density curves centred on the group means and on the
  # James-Stein estimates, plus one thicker black curve at the grand mean.
  output$plot2 <- renderPlot({
    sim <- simulated()
    par(las = 1, bty = 'n', mfrow = 1:2)
    # One colour / line width per group, then the entry for the black curve.
    col <- c(rainbow(length(sim$means)), 'black')
    lwd <- c(rep(1, length(sim$means)), 2.5)
    # NOTE(review): grandSd is fixed at 1 here and does not follow the
    # "grandVar" slider - confirm this is intended.
    densityPlot(mu = sim$means, sd = sim$sds,
                grandMean = input$grandMean, grandSd = 1,
                type = 'l', lty = 1, col = col, lwd = lwd,
                main = "Each Group Mean Estimate")
    densityPlot(mu = sim$JS$z, sd = sim$sds,
                grandMean = input$grandMean, grandSd = 1,
                type = 'l', lty = 1, col = col, lwd = lwd,
                main = "Each Group's JS Estimate")
    legend('topright', legend = 'Total Distribution', lwd = 2, col = 'black', bty = 'n')
  })
}
# Create the Shiny app object from the `ui` definition and the `server` function.
shinyApp(ui = ui, server = server)
# Simulate two-level data: one location per group, then observations per group.
#
# nBetween - number of groups.
# nWithin  - number of observations drawn for each group.
# fun1     - generator for the group locations; called with n = nBetween.
# fun2     - generator for the observations; called with n = nWithin and
#            mean = the location of the current group.
# args1    - extra arguments for fun1; entries that are not formal arguments
#            of fun1 are dropped.
# args2    - extra arguments for fun2, filtered the same way.
#
# Returns a list with
#   data - list of length nBetween, one vector of observations per group;
#   mu   - the locations drawn by fun1.
genData <- function(nBetween, nWithin, fun1 = 'rnorm', fun2 = 'rnorm',
                    args1 = list(mean = 0, sd = 1), args2 = list(sd = 5)) {
  data.list <- vector('list', nBetween) # one slot per group

  # Only retain the arguments each generator actually declares.
  args11 <- args1[names(args1) %in% names(formals(fun1))]
  args22 <- args2[names(args2) %in% names(formals(fun2))]

  mu <- do.call(fun1, c(list(n = nBetween), args11))

  # seq_along() yields an empty sequence when there are no groups, whereas
  # 1:length() would iterate over c(1, 0) and fail.
  for (i in seq_along(data.list)) {
    data.list[[i]] <- do.call(fun2, c(list(n = nWithin, mean = mu[i]), args22))
  }

  list(data = data.list, mu = mu)
}
# James-Stein estimates of the group means.
#
# x - list of numeric vectors, one vector of observations per group.
#
# Returns a list with
#   z - the shrunken estimate for every group, grand mean + c * (mean - grand mean);
#   c - the shrinkage factor, restricted to [0, 1].
#
# The factor follows c = 1 - (k - 3) * sigma^2 / sum((y - ybar)^2), where
# sigma^2 is the sampling variance of a single group mean. It is estimated as
# the within-group variance divided by the group size, averaged over groups
# (an approximation when group sizes differ). Using the variance *between* the
# group means instead would make c equal to 2 / (k - 1) for equally sized
# groups, independent of the data.
JS.est <- function(x) {
  k <- length(x) # number of means to estimate (== number of groups)
  if (k == 0) {
    return(list(z = numeric(0), c = 1))
  }

  y <- vapply(x, mean, numeric(1)) # mean by group
  yh <- mean(unlist(x))            # mean of all data together

  # Sampling variance of each group mean; groups with a single observation
  # give NA and are left out of the average.
  mean_var <- vapply(x, function(g) var(g) / length(g), numeric(1))
  sig <- mean(mean_var, na.rm = TRUE)
  if (is.na(sig)) {
    sig <- 0 # no group has two or more observations: nothing to shrink with
  }

  shrink <- 1 - (k - 3) * sig / sum((y - yh)^2)
  if (is.nan(shrink)) {
    shrink <- 1 # 0 / 0: all means identical and no estimated noise
  }
  # Positive-part rule: never push estimates past the grand mean (shrink < 0)
  # and never move them away from it (shrink > 1, which happens for k < 3).
  shrink <- min(1, max(0, shrink))

  z <- yh + shrink * (y - yh)
  list(z = z, c = shrink)
}
# Plot how every group mean moves to its James-Stein estimate.
#
# JS     - list with an element `z` holding the James-Stein estimates.
# Mu     - the raw group means, in the same order as JS$z.
# plotly - FALSE (default): draw with matplot(); TRUE: build a plot_ly() plot.
# ...    - forwarded to matplot() when plotly is FALSE.
#
# Group means are placed at height 1 and their estimates at height 0, one
# connected pair per group. Returns whatever the chosen plotting call returns.
shrinkagePlot <- function(JS, Mu, plotly = FALSE, ...) {
  x <- t(cbind(Mu, JS$z)) # row 1: group means, row 2: JS estimates
  # Order the columns (groups) by the group mean in row 1; drop = FALSE keeps
  # the 2-row matrix shape when there is a single group.
  x <- x[, order(x[1, ]), drop = FALSE]
  if (!plotly) {
    # The call is written with these exact argument expressions because
    # matplot() builds its default axis labels from them.
    matplot(x, t(matrix(1:0, ncol(x), 2, TRUE)), ...)
  } else {
    dat <- data.frame(x = c(x),
                      y = c(t(matrix(1:0, ncol(x), 2, TRUE))),
                      group = rep(seq_len(ncol(x)), each = 2))
    # Pass the columns explicitly: a bare `y` is not defined in this function
    # and would otherwise have to be resolved from outside it.
    plot_ly(data = dat, x = dat$x, y = dat$y, group = dat$group, type = 'line',
            color = dat$group)
  }
}
# Draw one normal density curve per entry of `mu`, optionally followed by an
# extra curve for the overall distribution.
#
# mu        - centres of the individual curves.
# sd        - spreads of the individual curves; 1 for every curve if omitted.
# grandMean - centre of the extra curve; no extra curve is drawn if omitted.
# grandSd   - spread of the extra curve; 1 if omitted.
# x         - grid on which the densities are evaluated; if omitted, a grid
#             with step 0.01 reaching 9 units beyond the range of `mu`.
# ...       - forwarded to matplot().
densityPlot <- function(mu, sd, grandMean, grandSd, x, ...) {
  if (missing(sd)) sd <- rep(1, length(mu))

  if (missing(x)) {
    span <- range(mu) + c(-9, 9)
    x <- seq.int(span[1], span[2], .01)
  }

  # One column of density values per (mean, sd) pair.
  y <- mapply(dnorm, mean = mu, sd = sd, MoreArgs = list(x = x))

  if (!missing(grandMean)) {
    total_sd <- if (missing(grandSd)) 1 else grandSd
    # The overall curve goes last, after the individual curves.
    y <- cbind(y, 1 * dnorm(x, grandMean, total_sd))
  }

  # The arguments stay named `x` and `y`: matplot() uses them as the default
  # axis labels.
  matplot(x, y, ...)
}
# ---- Scratch / demo statements ----
# NOTE(review): these are top-level statements, so they execute (and create the
# globals x, y, dat, JS and Mu) every time the file containing them is run or
# sourced - confirm that is intended.

# Toy data: five two-point segments; segment i goes from (i, 0) to (i + 0.1, 1).
y = rep(0:1, 1, NA, 5) # positional rep() arguments: times = 1, length.out = NA, each = 5
x = c(1:5, 1:5+.1) # 1..5 followed by 1.1..5.1
dat = data.frame(x = x, y = y, group = rep(1:5, 2))
plot_ly(dat, x = x, y = y, group = dat$group, type = 'line')

# Simulate 100 groups of 10 observations with genData()'s default generators,
# compute the James-Stein estimates and draw the shrinkage plot via plotly.
dat = genData(100, 10)
JS = JS.est(dat$data)
Mu = sapply(dat$data, mean) # sample mean of every group
JS$c # the shrinkage factor
shrinkagePlot(JS = JS, Mu = Mu, plotly = TRUE)